diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/neqo-transport/src | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-transport/src')
56 files changed, 34799 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/ackrate.rs b/third_party/rust/neqo-transport/src/ackrate.rs new file mode 100644 index 0000000000..cf68f9021f --- /dev/null +++ b/third_party/rust/neqo-transport/src/ackrate.rs @@ -0,0 +1,213 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Management of the peer's ack rate. +#![deny(clippy::pedantic)] + +use std::{cmp::max, convert::TryFrom, time::Duration}; + +use neqo_common::qtrace; + +use crate::{ + connection::params::ACK_RATIO_SCALE, frame::FRAME_TYPE_ACK_FREQUENCY, packet::PacketBuilder, + recovery::RecoveryToken, stats::FrameStats, +}; + +#[derive(Debug, Clone)] +pub struct AckRate { + /// The maximum number of packets that can be received without sending an ACK. + packets: usize, + /// The maximum delay before sending an ACK. + delay: Duration, +} + +impl AckRate { + pub fn new(minimum: Duration, ratio: u8, cwnd: usize, mtu: usize, rtt: Duration) -> Self { + const PACKET_RATIO: usize = ACK_RATIO_SCALE as usize; + // At worst, ask for an ACK for every other packet. + const MIN_PACKETS: usize = 2; + // At worst, require an ACK every 256 packets. 
+ const MAX_PACKETS: usize = 256; + const RTT_RATIO: u32 = ACK_RATIO_SCALE as u32; + const MAX_DELAY: Duration = Duration::from_millis(50); + + let packets = cwnd * PACKET_RATIO / mtu / usize::from(ratio); + let packets = packets.clamp(MIN_PACKETS, MAX_PACKETS) - 1; + let delay = rtt * RTT_RATIO / u32::from(ratio); + let delay = delay.clamp(minimum, MAX_DELAY); + qtrace!("AckRate inputs: {}/{}/{}, {:?}", cwnd, mtu, ratio, rtt); + Self { packets, delay } + } + + pub fn write_frame(&self, builder: &mut PacketBuilder, seqno: u64) -> bool { + builder.write_varint_frame(&[ + FRAME_TYPE_ACK_FREQUENCY, + seqno, + u64::try_from(self.packets + 1).unwrap(), + u64::try_from(self.delay.as_micros()).unwrap(), + 0, + ]) + } + + /// Determine whether to send an update frame. + pub fn needs_update(&self, target: &Self) -> bool { + if self.packets != target.packets { + return true; + } + // Allow more flexibility for delays, as those can change + // by small amounts fairly easily. + let delta = target.delay / 4; + target.delay + delta < self.delay || target.delay > self.delay + delta + } +} + +#[derive(Debug, Clone)] +pub struct FlexibleAckRate { + current: AckRate, + target: AckRate, + next_frame_seqno: u64, + frame_outstanding: bool, + min_ack_delay: Duration, + ratio: u8, +} + +impl FlexibleAckRate { + fn new( + max_ack_delay: Duration, + min_ack_delay: Duration, + ratio: u8, + cwnd: usize, + mtu: usize, + rtt: Duration, + ) -> Self { + qtrace!( + "FlexibleAckRate: {:?} {:?} {}", + max_ack_delay, + min_ack_delay, + ratio + ); + let ratio = max(ACK_RATIO_SCALE, ratio); // clamp it + Self { + current: AckRate { + packets: 1, + delay: max_ack_delay, + }, + target: AckRate::new(min_ack_delay, ratio, cwnd, mtu, rtt), + next_frame_seqno: 0, + frame_outstanding: false, + min_ack_delay, + ratio, + } + } + + fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if !self.frame_outstanding + && 
self.current.needs_update(&self.target) + && self.target.write_frame(builder, self.next_frame_seqno) + { + qtrace!("FlexibleAckRate: write frame {:?}", self.target); + self.frame_outstanding = true; + self.next_frame_seqno += 1; + tokens.push(RecoveryToken::AckFrequency(self.target.clone())); + stats.ack_frequency += 1; + } + } + + fn frame_acked(&mut self, acked: &AckRate) { + self.frame_outstanding = false; + self.current = acked.clone(); + } + + fn frame_lost(&mut self, _lost: &AckRate) { + self.frame_outstanding = false; + } + + fn update(&mut self, cwnd: usize, mtu: usize, rtt: Duration) { + self.target = AckRate::new(self.min_ack_delay, self.ratio, cwnd, mtu, rtt); + qtrace!("FlexibleAckRate: {:?} -> {:?}", self.current, self.target); + } + + fn peer_ack_delay(&self) -> Duration { + max(self.current.delay, self.target.delay) + } +} + +#[derive(Debug, Clone)] +pub enum PeerAckDelay { + Fixed(Duration), + Flexible(FlexibleAckRate), +} + +impl PeerAckDelay { + pub fn fixed(max_ack_delay: Duration) -> Self { + Self::Fixed(max_ack_delay) + } + + pub fn flexible( + max_ack_delay: Duration, + min_ack_delay: Duration, + ratio: u8, + cwnd: usize, + mtu: usize, + rtt: Duration, + ) -> Self { + Self::Flexible(FlexibleAckRate::new( + max_ack_delay, + min_ack_delay, + ratio, + cwnd, + mtu, + rtt, + )) + } + + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if let Self::Flexible(rate) = self { + rate.write_frames(builder, tokens, stats); + } + } + + pub fn frame_acked(&mut self, r: &AckRate) { + if let Self::Flexible(rate) = self { + rate.frame_acked(r); + } + } + + pub fn frame_lost(&mut self, r: &AckRate) { + if let Self::Flexible(rate) = self { + rate.frame_lost(r); + } + } + + pub fn max(&self) -> Duration { + match self { + Self::Flexible(rate) => rate.peer_ack_delay(), + Self::Fixed(delay) => *delay, + } + } + + pub fn update(&mut self, cwnd: usize, mtu: usize, rtt: Duration) { + if 
let Self::Flexible(rate) = self { + rate.update(cwnd, mtu, rtt); + } + } +} + +impl Default for PeerAckDelay { + fn default() -> Self { + Self::fixed(Duration::from_millis(25)) + } +} diff --git a/third_party/rust/neqo-transport/src/addr_valid.rs b/third_party/rust/neqo-transport/src/addr_valid.rs new file mode 100644 index 0000000000..b5ed2d07d1 --- /dev/null +++ b/third_party/rust/neqo-transport/src/addr_valid.rs @@ -0,0 +1,508 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// This file implements functions necessary for address validation. + +use std::{ + convert::TryFrom, + net::{IpAddr, SocketAddr}, + time::{Duration, Instant}, +}; + +use neqo_common::{qinfo, qtrace, Decoder, Encoder, Role}; +use neqo_crypto::{ + constants::{TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3}, + selfencrypt::SelfEncrypt, +}; +use smallvec::SmallVec; + +use crate::{ + cid::ConnectionId, packet::PacketBuilder, recovery::RecoveryToken, stats::FrameStats, Res, +}; + +/// A prefix we add to Retry tokens to distinguish them from NEW_TOKEN tokens. +const TOKEN_IDENTIFIER_RETRY: &[u8] = &[0x52, 0x65, 0x74, 0x72, 0x79]; +/// A prefix on NEW_TOKEN tokens, that is maximally Hamming distant from NEW_TOKEN. +/// Together, these need to have a low probability of collision, even if there is +/// corruption of individual bits in transit. +const TOKEN_IDENTIFIER_NEW_TOKEN: &[u8] = &[0xad, 0x9a, 0x8b, 0x8d, 0x86]; + +/// The maximum number of tokens we'll save from NEW_TOKEN frames. +/// This should be the same as the value of MAX_TICKETS in neqo-crypto. +const MAX_NEW_TOKEN: usize = 4; +/// The number of tokens we'll track for the purposes of looking for duplicates. 
+/// This is based on how many might be received over a period where could be +/// retransmissions. It should be at least `MAX_NEW_TOKEN`. +const MAX_SAVED_TOKENS: usize = 8; + +/// `ValidateAddress` determines what sort of address validation is performed. +/// In short, this determines when a Retry packet is sent. +#[derive(Debug, PartialEq, Eq)] +pub enum ValidateAddress { + /// Require address validation never. + Never, + /// Require address validation unless a NEW_TOKEN token is provided. + NoToken, + /// Require address validation even if a NEW_TOKEN token is provided. + Always, +} + +pub enum AddressValidationResult { + Pass, + ValidRetry(ConnectionId), + Validate, + Invalid, +} + +pub struct AddressValidation { + /// What sort of validation is performed. + validation: ValidateAddress, + /// A self-encryption object used for protecting Retry tokens. + self_encrypt: SelfEncrypt, + /// When this object was created. + start_time: Instant, +} + +impl AddressValidation { + pub fn new(now: Instant, validation: ValidateAddress) -> Res<Self> { + Ok(Self { + validation, + self_encrypt: SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256)?, + start_time: now, + }) + } + + fn encode_aad(peer_address: SocketAddr, retry: bool) -> Encoder { + // Let's be "clever" by putting the peer's address in the AAD. + // We don't need to encode these into the token as they should be + // available when we need to check the token. 
+ let mut aad = Encoder::default(); + if retry { + aad.encode(TOKEN_IDENTIFIER_RETRY); + } else { + aad.encode(TOKEN_IDENTIFIER_NEW_TOKEN); + } + match peer_address.ip() { + IpAddr::V4(a) => { + aad.encode_byte(4); + aad.encode(&a.octets()); + } + IpAddr::V6(a) => { + aad.encode_byte(6); + aad.encode(&a.octets()); + } + } + if retry { + aad.encode_uint(2, peer_address.port()); + } + aad + } + + pub fn generate_token( + &self, + dcid: Option<&ConnectionId>, + peer_address: SocketAddr, + now: Instant, + ) -> Res<Vec<u8>> { + const EXPIRATION_RETRY: Duration = Duration::from_secs(5); + const EXPIRATION_NEW_TOKEN: Duration = Duration::from_secs(60 * 60 * 24); + + // TODO(mt) rotate keys on a fixed schedule. + let retry = dcid.is_some(); + let mut data = Encoder::default(); + let end = now + + if retry { + EXPIRATION_RETRY + } else { + EXPIRATION_NEW_TOKEN + }; + let end_millis = u32::try_from(end.duration_since(self.start_time).as_millis())?; + data.encode_uint(4, end_millis); + if let Some(dcid) = dcid { + data.encode(dcid); + } + + // Include the token identifier ("Retry"/~) in the AAD, then keep it for plaintext. + let mut buf = Self::encode_aad(peer_address, retry); + let encrypted = self.self_encrypt.seal(buf.as_ref(), data.as_ref())?; + buf.truncate(TOKEN_IDENTIFIER_RETRY.len()); + buf.encode(&encrypted); + Ok(buf.into()) + } + + /// This generates a token for use with Retry. + pub fn generate_retry_token( + &self, + dcid: &ConnectionId, + peer_address: SocketAddr, + now: Instant, + ) -> Res<Vec<u8>> { + self.generate_token(Some(dcid), peer_address, now) + } + + /// This generates a token for use with NEW_TOKEN. 
+ pub fn generate_new_token(&self, peer_address: SocketAddr, now: Instant) -> Res<Vec<u8>> { + self.generate_token(None, peer_address, now) + } + + pub fn set_validation(&mut self, validation: ValidateAddress) { + qtrace!("AddressValidation {:p}: set to {:?}", self, validation); + self.validation = validation; + } + + /// Decrypts `token` and returns the connection ID it contains. + /// Returns a tuple with a boolean indicating whether this thinks + /// that the token was a Retry token, and a connection ID, that is + /// None if the token wasn't successfully decrypted. + fn decrypt_token( + &self, + token: &[u8], + peer_address: SocketAddr, + retry: bool, + now: Instant, + ) -> Option<ConnectionId> { + let peer_addr = Self::encode_aad(peer_address, retry); + let data = self.self_encrypt.open(peer_addr.as_ref(), token).ok()?; + let mut dec = Decoder::new(&data); + match dec.decode_uint(4) { + Some(d) => { + let end = self.start_time + Duration::from_millis(d); + if end < now { + qtrace!("Expired token: {:?} vs. {:?}", end, now); + return None; + } + } + _ => return None, + } + Some(ConnectionId::from(dec.decode_remainder())) + } + + /// Calculate the Hamming difference between our identifier and the target. + /// Less than one difference per byte indicates that it is likely not a Retry. + /// This generous interpretation allows for a lot of damage in transit. + /// Note that if this check fails, then the token will be treated like it came + /// from NEW_TOKEN instead. If there truly is corruption of packets that causes + /// validation failure, it will be a failure that we try to recover from. 
+ fn is_likely_retry(token: &[u8]) -> bool { + let mut difference = 0; + for i in 0..TOKEN_IDENTIFIER_RETRY.len() { + difference += (token[i] ^ TOKEN_IDENTIFIER_RETRY[i]).count_ones(); + } + usize::try_from(difference).unwrap() < TOKEN_IDENTIFIER_RETRY.len() + } + + pub fn validate( + &self, + token: &[u8], + peer_address: SocketAddr, + now: Instant, + ) -> AddressValidationResult { + qtrace!( + "AddressValidation {:p}: validate {:?}", + self, + self.validation + ); + + if token.is_empty() { + if self.validation == ValidateAddress::Never { + qinfo!("AddressValidation: no token; accepting"); + return AddressValidationResult::Pass; + } else { + qinfo!("AddressValidation: no token; validating"); + return AddressValidationResult::Validate; + } + } + if token.len() <= TOKEN_IDENTIFIER_RETRY.len() { + // Treat bad tokens strictly. + qinfo!("AddressValidation: too short token"); + return AddressValidationResult::Invalid; + } + let retry = Self::is_likely_retry(token); + let enc = &token[TOKEN_IDENTIFIER_RETRY.len()..]; + // Note that this allows the token identifier part to be corrupted. + // That's OK here as we don't depend on that being authenticated. + if let Some(cid) = self.decrypt_token(enc, peer_address, retry, now) { + if retry { + // This is from Retry, so we should have an ODCID >= 8. + if cid.len() >= 8 { + qinfo!("AddressValidation: valid Retry token for {}", cid); + AddressValidationResult::ValidRetry(cid) + } else { + panic!("AddressValidation: Retry token with small CID {}", cid); + } + } else if cid.is_empty() { + // An empty connection ID means NEW_TOKEN. 
+ if self.validation == ValidateAddress::Always { + qinfo!("AddressValidation: valid NEW_TOKEN token; validating again"); + AddressValidationResult::Validate + } else { + qinfo!("AddressValidation: valid NEW_TOKEN token; accepting"); + AddressValidationResult::Pass + } + } else { + panic!("AddressValidation: NEW_TOKEN token with CID {}", cid); + } + } else { + // From here on, we have a token that we couldn't decrypt. + // We've either lost the keys or we've received junk. + if retry { + // If this looked like a Retry, treat it as being bad. + qinfo!("AddressValidation: invalid Retry token; rejecting"); + AddressValidationResult::Invalid + } else if self.validation == ValidateAddress::Never { + // We don't require validation, so OK. + qinfo!("AddressValidation: invalid NEW_TOKEN token; accepting"); + AddressValidationResult::Pass + } else { + // This might be an invalid NEW_TOKEN token, or a valid one + // for which we have since lost the keys. Check again. + qinfo!("AddressValidation: invalid NEW_TOKEN token; validating again"); + AddressValidationResult::Validate + } + } + } +} + +// Note: these lint override can be removed in later versions where the lints +// either don't trip a false positive or don't apply. rustc 1.46 is fine. +#[allow(dead_code, clippy::large_enum_variant)] +pub enum NewTokenState { + Client { + /// Tokens that haven't been taken yet. + pending: SmallVec<[Vec<u8>; MAX_NEW_TOKEN]>, + /// Tokens that have been taken, saved so that we can discard duplicates. + old: SmallVec<[Vec<u8>; MAX_SAVED_TOKENS]>, + }, + Server(NewTokenSender), +} + +impl NewTokenState { + pub fn new(role: Role) -> Self { + match role { + Role::Client => Self::Client { + pending: SmallVec::<[_; MAX_NEW_TOKEN]>::new(), + old: SmallVec::<[_; MAX_SAVED_TOKENS]>::new(), + }, + Role::Server => Self::Server(NewTokenSender::default()), + } + } + + /// Is there a token available? + pub fn has_token(&self) -> bool { + match self { + Self::Client { ref pending, .. 
} => !pending.is_empty(), + Self::Server(..) => false, + } + } + + /// If this is a client, take a token if there is one. + /// If this is a server, panic. + pub fn take_token(&mut self) -> Option<&[u8]> { + if let Self::Client { + ref mut pending, + ref mut old, + } = self + { + if let Some(t) = pending.pop() { + if old.len() >= MAX_SAVED_TOKENS { + old.remove(0); + } + old.push(t); + Some(&old[old.len() - 1]) + } else { + None + } + } else { + unreachable!(); + } + } + + /// If this is a client, save a token. + /// If this is a server, panic. + pub fn save_token(&mut self, token: Vec<u8>) { + if let Self::Client { + ref mut pending, + ref old, + } = self + { + for t in old.iter().rev().chain(pending.iter().rev()) { + if t == &token { + qinfo!("NewTokenState discarding duplicate NEW_TOKEN"); + return; + } + } + + if pending.len() >= MAX_NEW_TOKEN { + pending.remove(0); + } + pending.push(token); + } else { + unreachable!(); + } + } + + /// If this is a server, maybe send a frame. + /// If this is a client, do nothing. + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> Res<()> { + if let Self::Server(ref mut sender) = self { + sender.write_frames(builder, tokens, stats)?; + } + Ok(()) + } + + /// If this a server, buffer a NEW_TOKEN for sending. + /// If this is a client, panic. + pub fn send_new_token(&mut self, token: Vec<u8>) { + if let Self::Server(ref mut sender) = self { + sender.send_new_token(token); + } else { + unreachable!(); + } + } + + /// If this a server, process a lost signal for a NEW_TOKEN frame. + /// If this is a client, panic. + pub fn lost(&mut self, seqno: usize) { + if let Self::Server(ref mut sender) = self { + sender.lost(seqno); + } else { + unreachable!(); + } + } + + /// If this a server, process remove the acknowledged NEW_TOKEN frame. + /// If this is a client, panic. 
+ pub fn acked(&mut self, seqno: usize) { + if let Self::Server(ref mut sender) = self { + sender.acked(seqno); + } else { + unreachable!(); + } + } +} + +struct NewTokenFrameStatus { + seqno: usize, + token: Vec<u8>, + needs_sending: bool, +} + +impl NewTokenFrameStatus { + fn len(&self) -> usize { + 1 + Encoder::vvec_len(self.token.len()) + } +} + +#[derive(Default)] +pub struct NewTokenSender { + /// The unacknowledged NEW_TOKEN frames we are yet to send. + tokens: Vec<NewTokenFrameStatus>, + /// A sequence number that is used to track individual tokens + /// by reference (so that recovery tokens can be simple). + next_seqno: usize, +} + +impl NewTokenSender { + /// Add a token to be sent. + pub fn send_new_token(&mut self, token: Vec<u8>) { + self.tokens.push(NewTokenFrameStatus { + seqno: self.next_seqno, + token, + needs_sending: true, + }); + self.next_seqno += 1; + } + + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> Res<()> { + for t in self.tokens.iter_mut() { + if t.needs_sending && t.len() <= builder.remaining() { + t.needs_sending = false; + + builder.encode_varint(crate::frame::FRAME_TYPE_NEW_TOKEN); + builder.encode_vvec(&t.token); + + tokens.push(RecoveryToken::NewToken(t.seqno)); + stats.new_token += 1; + } + } + Ok(()) + } + + pub fn lost(&mut self, seqno: usize) { + for t in self.tokens.iter_mut() { + if t.seqno == seqno { + t.needs_sending = true; + break; + } + } + } + + pub fn acked(&mut self, seqno: usize) { + self.tokens.retain(|i| i.seqno != seqno); + } +} + +#[cfg(test)] +mod tests { + use neqo_common::Role; + + use super::NewTokenState; + + const ONE: &[u8] = &[1, 2, 3]; + const TWO: &[u8] = &[4, 5]; + + #[test] + fn duplicate_saved() { + let mut tokens = NewTokenState::new(Role::Client); + tokens.save_token(ONE.to_vec()); + tokens.save_token(TWO.to_vec()); + tokens.save_token(ONE.to_vec()); + assert!(tokens.has_token()); + 
assert!(tokens.take_token().is_some()); // probably TWO + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably ONE + assert!(!tokens.has_token()); + assert!(tokens.take_token().is_none()); + } + + #[test] + fn duplicate_after_take() { + let mut tokens = NewTokenState::new(Role::Client); + tokens.save_token(ONE.to_vec()); + tokens.save_token(TWO.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably TWO + tokens.save_token(ONE.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably ONE + assert!(!tokens.has_token()); + assert!(tokens.take_token().is_none()); + } + + #[test] + fn duplicate_after_empty() { + let mut tokens = NewTokenState::new(Role::Client); + tokens.save_token(ONE.to_vec()); + tokens.save_token(TWO.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably TWO + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably ONE + tokens.save_token(ONE.to_vec()); + assert!(!tokens.has_token()); + assert!(tokens.take_token().is_none()); + } +} diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs new file mode 100644 index 0000000000..6f4a01d795 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs @@ -0,0 +1,1186 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// Congestion control +#![deny(clippy::pedantic)] + +use std::{ + cmp::{max, min}, + fmt::{self, Debug, Display}, + time::{Duration, Instant}, +}; + +use super::CongestionControl; +use crate::{ + cc::MAX_DATAGRAM_SIZE, + packet::PacketNumber, + qlog::{self, QlogMetric}, + rtt::RttEstimate, + sender::PACING_BURST_SIZE, + tracking::SentPacket, +}; +#[rustfmt::skip] // to keep `::` and thus prevent conflict with `crate::qlog` +use ::qlog::events::{quic::CongestionStateUpdated, EventData}; +use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace}; + +pub const CWND_INITIAL_PKTS: usize = 10; +pub const CWND_INITIAL: usize = const_min( + CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE, + const_max(2 * MAX_DATAGRAM_SIZE, 14720), +); +pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2; +const PERSISTENT_CONG_THRESH: u32 = 3; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + /// In either slow start or congestion avoidance, not recovery. + SlowStart, + /// In congestion avoidance. + CongestionAvoidance, + /// In a recovery period, but no packets have been sent yet. This is a + /// transient state because we want to exempt the first packet sent after + /// entering recovery from the congestion window. + RecoveryStart, + /// In a recovery period, with the first packet sent at this time. + Recovery, + /// Start of persistent congestion, which is transient, like `RecoveryStart`. + PersistentCongestion, +} + +impl State { + pub fn in_recovery(self) -> bool { + matches!(self, Self::RecoveryStart | Self::Recovery) + } + + pub fn in_slow_start(self) -> bool { + self == Self::SlowStart + } + + /// These states are transient, we tell qlog on entry, but not on exit. + pub fn transient(self) -> bool { + matches!(self, Self::RecoveryStart | Self::PersistentCongestion) + } + + /// Update a transient state to the true state. 
+ pub fn update(&mut self) { + *self = match self { + Self::PersistentCongestion => Self::SlowStart, + Self::RecoveryStart => Self::Recovery, + _ => unreachable!(), + }; + } + + pub fn to_qlog(self) -> &'static str { + match self { + Self::SlowStart | Self::PersistentCongestion => "slow_start", + Self::CongestionAvoidance => "congestion_avoidance", + Self::Recovery | Self::RecoveryStart => "recovery", + } + } +} + +pub trait WindowAdjustment: Display + Debug { + /// This is called when an ack is received. + /// The function calculates the amount of acked bytes congestion controller needs + /// to collect before increasing its cwnd by `MAX_DATAGRAM_SIZE`. + fn bytes_for_cwnd_increase( + &mut self, + curr_cwnd: usize, + new_acked_bytes: usize, + min_rtt: Duration, + now: Instant, + ) -> usize; + /// This function is called when a congestion event has beed detected and it + /// returns new (decreased) values of `curr_cwnd` and `acked_bytes`. + /// This value can be very small; the calling code is responsible for ensuring that the + /// congestion window doesn't drop below the minimum of `CWND_MIN`. + fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize); + /// Cubic needs this signal to reset its epoch. + fn on_app_limited(&mut self); + #[cfg(test)] + fn last_max_cwnd(&self) -> f64; + #[cfg(test)] + fn set_last_max_cwnd(&mut self, last_max_cwnd: f64); +} + +#[derive(Debug)] +pub struct ClassicCongestionControl<T> { + cc_algorithm: T, + state: State, + congestion_window: usize, // = kInitialWindow + bytes_in_flight: usize, + acked_bytes: usize, + ssthresh: usize, + recovery_start: Option<PacketNumber>, + /// `first_app_limited` indicates the packet number after which the application might be + /// underutilizing the congestion window. 
When underutilizing the congestion window due to not + /// sending out enough data, we SHOULD NOT increase the congestion window.[1] Packets sent + /// before this point are deemed to fully utilize the congestion window and count towards + /// increasing the congestion window. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc9002#section-7.8 + first_app_limited: PacketNumber, + + qlog: NeqoQlog, +} + +impl<T: WindowAdjustment> Display for ClassicCongestionControl<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} CongCtrl {}/{} ssthresh {}", + self.cc_algorithm, self.bytes_in_flight, self.congestion_window, self.ssthresh, + )?; + Ok(()) + } +} + +impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> { + fn set_qlog(&mut self, qlog: NeqoQlog) { + self.qlog = qlog; + } + + #[must_use] + fn cwnd(&self) -> usize { + self.congestion_window + } + + #[must_use] + fn bytes_in_flight(&self) -> usize { + self.bytes_in_flight + } + + #[must_use] + fn cwnd_avail(&self) -> usize { + // BIF can be higher than cwnd due to PTO packets, which are sent even + // if avail is 0, but still count towards BIF. + self.congestion_window.saturating_sub(self.bytes_in_flight) + } + + // Multi-packet version of OnPacketAckedCC + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant) { + let mut is_app_limited = true; + let mut new_acked = 0; + for pkt in acked_pkts { + qinfo!( + "packet_acked this={:p}, pn={}, ps={}, ignored={}, lost={}, rtt_est={:?}", + self, + pkt.pn, + pkt.size, + i32::from(!pkt.cc_outstanding()), + i32::from(pkt.lost()), + rtt_est, + ); + if !pkt.cc_outstanding() { + continue; + } + if pkt.pn < self.first_app_limited { + is_app_limited = false; + } + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + + if !self.after_recovery_start(pkt) { + // Do not increase congestion window for packets sent before + // recovery last started. 
+ continue; + } + + if self.state.in_recovery() { + self.set_state(State::CongestionAvoidance); + qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]); + } + + new_acked += pkt.size; + } + + if is_app_limited { + self.cc_algorithm.on_app_limited(); + qinfo!("on_packets_acked this={:p}, limited=1, bytes_in_flight={}, cwnd={}, state={:?}, new_acked={}", self, self.bytes_in_flight, self.congestion_window, self.state, new_acked); + return; + } + + // Slow start, up to the slow start threshold. + if self.congestion_window < self.ssthresh { + self.acked_bytes += new_acked; + let increase = min(self.ssthresh - self.congestion_window, self.acked_bytes); + self.congestion_window += increase; + self.acked_bytes -= increase; + qinfo!([self], "slow start += {}", increase); + if self.congestion_window == self.ssthresh { + // This doesn't look like it is necessary, but it can happen + // after persistent congestion. + self.set_state(State::CongestionAvoidance); + } + } + // Congestion avoidance, above the slow start threshold. + if self.congestion_window >= self.ssthresh { + // The following function return the amount acked bytes a controller needs + // to collect to be allowed to increase its cwnd by MAX_DATAGRAM_SIZE. + let bytes_for_increase = self.cc_algorithm.bytes_for_cwnd_increase( + self.congestion_window, + new_acked, + rtt_est.minimum(), + now, + ); + debug_assert!(bytes_for_increase > 0); + // If enough credit has been accumulated already, apply them gradually. + // If we have sudden increase in allowed rate we actually increase cwnd gently. + if self.acked_bytes >= bytes_for_increase { + self.acked_bytes = 0; + self.congestion_window += MAX_DATAGRAM_SIZE; + } + self.acked_bytes += new_acked; + if self.acked_bytes >= bytes_for_increase { + self.acked_bytes -= bytes_for_increase; + self.congestion_window += MAX_DATAGRAM_SIZE; // or is this the current MTU? + } + // The number of bytes we require can go down over time with Cubic. 
+ // That might result in an excessive rate of increase, so limit the number of unused + // acknowledged bytes after increasing the congestion window twice. + self.acked_bytes = min(bytes_for_increase, self.acked_bytes); + } + qlog::metrics_updated( + &mut self.qlog, + &[ + QlogMetric::CongestionWindow(self.congestion_window), + QlogMetric::BytesInFlight(self.bytes_in_flight), + ], + ); + qinfo!([self], "on_packets_acked this={:p}, limited=0, bytes_in_flight={}, cwnd={}, state={:?}, new_acked={}", self, self.bytes_in_flight, self.congestion_window, self.state, new_acked); + } + + /// Update congestion controller state based on lost packets. + fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) -> bool { + if lost_packets.is_empty() { + return false; + } + + for pkt in lost_packets.iter().filter(|pkt| pkt.cc_in_flight()) { + qinfo!( + "packet_lost this={:p}, pn={}, ps={}", + self, + pkt.pn, + pkt.size + ); + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + } + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + + let congestion = self.on_congestion_event(lost_packets.last().unwrap()); + let persistent_congestion = self.detect_persistent_congestion( + first_rtt_sample_time, + prev_largest_acked_sent, + pto, + lost_packets, + ); + qinfo!( + "on_packets_lost this={:p}, bytes_in_flight={}, cwnd={}, state={:?}", + self, + self.bytes_in_flight, + self.congestion_window, + self.state + ); + congestion || persistent_congestion + } + + fn discard(&mut self, pkt: &SentPacket) { + if pkt.cc_outstanding() { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + qtrace!([self], "Ignore pkt with size {}", pkt.size); + } + } + + fn discard_in_flight(&mut 
self) { + self.bytes_in_flight = 0; + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + } + + fn on_packet_sent(&mut self, pkt: &SentPacket) { + // Record the recovery time and exit any transient state. + if self.state.transient() { + self.recovery_start = Some(pkt.pn); + self.state.update(); + } + + if !pkt.cc_in_flight() { + return; + } + if !self.app_limited() { + // Given the current non-app-limited condition, we're fully utilizing the congestion + // window. Assume that all in-flight packets up to this one are NOT app-limited. + // However, subsequent packets might be app-limited. Set `first_app_limited` to the + // next packet number. + self.first_app_limited = pkt.pn + 1; + } + + self.bytes_in_flight += pkt.size; + qinfo!( + "packet_sent this={:p}, pn={}, ps={}", + self, + pkt.pn, + pkt.size + ); + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + } + + /// Whether a packet can be sent immediately as a result of entering recovery. 
fn recovery_packet(&self) -> bool {
    self.state == State::RecoveryStart
}
}

impl<T: WindowAdjustment> ClassicCongestionControl<T> {
    /// Create a controller that starts in slow start with the initial congestion
    /// window, nothing in flight, an unbounded slow start threshold and qlog disabled.
    /// `cc_algorithm` supplies the window adjustment policy.
    pub fn new(cc_algorithm: T) -> Self {
        Self {
            cc_algorithm,
            state: State::SlowStart,
            congestion_window: CWND_INITIAL,
            bytes_in_flight: 0,
            acked_bytes: 0,
            ssthresh: usize::MAX,
            recovery_start: None,
            qlog: NeqoQlog::disabled(),
            first_app_limited: 0,
        }
    }

    /// Test-only accessor for the slow start threshold.
    #[cfg(test)]
    #[must_use]
    pub fn ssthresh(&self) -> usize {
        self.ssthresh
    }

    /// Test-only override of the slow start threshold.
    #[cfg(test)]
    pub fn set_ssthresh(&mut self, v: usize) {
        self.ssthresh = v;
    }

    /// Test-only accessor that forwards to the window adjustment algorithm.
    #[cfg(test)]
    pub fn last_max_cwnd(&self) -> f64 {
        self.cc_algorithm.last_max_cwnd()
    }

    /// Test-only setter that forwards to the window adjustment algorithm.
    #[cfg(test)]
    pub fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) {
        self.cc_algorithm.set_last_max_cwnd(last_max_cwnd);
    }

    /// Test-only accessor for the bytes acknowledged but not yet used to grow the window.
    #[cfg(test)]
    pub fn acked_bytes(&self) -> usize {
        self.acked_bytes
    }

    /// Move to `state`, logging the transition and reporting it to qlog.
    /// Does nothing when the state is unchanged.
    fn set_state(&mut self, state: State) {
        if self.state == state {
            return;
        }
        qdebug!([self], "state -> {:?}", state);
        let old_state = self.state;
        self.qlog.add_event_data(|| {
            // No need to tell qlog about exit from transient states.
            if old_state.transient() {
                None
            } else {
                let ev_data = EventData::CongestionStateUpdated(CongestionStateUpdated {
                    old: Some(old_state.to_qlog().to_owned()),
                    new: state.to_qlog().to_owned(),
                    trigger: None,
                });
                Some(ev_data)
            }
        });
        self.state = state;
    }

    /// Look for persistent congestion in `lost_packets`.
    ///
    /// Persistent congestion is declared when a run of lost packets with contiguous
    /// packet numbers contains two packets for which `cc_in_flight()` is true and
    /// whose send times are more than `pto * PERSISTENT_CONG_THRESH` apart. Packets
    /// sent before the later of `first_rtt_sample_time` and `prev_largest_acked_sent`
    /// are ignored. On detection the congestion window drops to `CWND_MIN`,
    /// `acked_bytes` is cleared and the state becomes `PersistentCongestion`.
    ///
    /// Returns `true` if persistent congestion was declared. Always returns `false`
    /// when there is no RTT sample yet.
    ///
    /// # Panics
    ///
    /// Panics with "time is monotonic" if an in-flight packet of a contiguous run was
    /// sent before the packet that started that run.
    fn detect_persistent_congestion(
        &mut self,
        first_rtt_sample_time: Option<Instant>,
        prev_largest_acked_sent: Option<Instant>,
        pto: Duration,
        lost_packets: &[SentPacket],
    ) -> bool {
        if first_rtt_sample_time.is_none() {
            return false;
        }

        let pc_period = pto * PERSISTENT_CONG_THRESH;

        let mut last_pn = 1 << 62; // Impossibly large, but not enough to overflow.
        let mut start = None;

        // Look for the first lost packet after the previous largest acknowledged.
        // Ignore packets that weren't ack-eliciting for the start of this range.
        // Also, make sure to ignore any packets sent before we got an RTT estimate
        // as we might not have sent PTO packets soon enough after those.
        let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent);
        for p in lost_packets
            .iter()
            .skip_while(|p| Some(p.time_sent) < cutoff)
        {
            if p.pn != last_pn + 1 {
                // Not a contiguous range of lost packets, start over.
                start = None;
            }
            last_pn = p.pn;
            if !p.cc_in_flight() {
                // Not interesting, keep looking.
                continue;
            }
            if let Some(t) = start {
                let elapsed = p
                    .time_sent
                    .checked_duration_since(t)
                    .expect("time is monotonic");
                if elapsed > pc_period {
                    qinfo!([self], "persistent congestion");
                    self.congestion_window = CWND_MIN;
                    self.acked_bytes = 0;
                    self.set_state(State::PersistentCongestion);
                    qlog::metrics_updated(
                        &mut self.qlog,
                        &[QlogMetric::CongestionWindow(self.congestion_window)],
                    );
                    return true;
                }
            } else {
                // First in-flight packet of this contiguous run.
                start = Some(p.time_sent);
            }
        }
        false
    }

    /// Whether `packet` was sent after the current recovery period started.
    ///
    /// This only reads state, so it borrows `self` immutably.
    #[must_use]
    fn after_recovery_start(&self, packet: &SentPacket) -> bool {
        // At the start of the recovery period, the state is transient and
        // all packets will have been sent before recovery. When sending out
        // the first packet we transition to the non-transient `Recovery`
        // state and update the variable `self.recovery_start`. Before the
        // first recovery there is no recorded start, so every packet counts
        // as sent after it, which allows the cwnd to be reduced on the first
        // congestion event.
        !self.state.transient() && self.recovery_start.map_or(true, |pn| packet.pn >= pn)
    }

    /// Handle a congestion event.
    ///
    /// Asks the window adjustment algorithm for a reduced window (never going
    /// below `CWND_MIN`), sets `ssthresh` to the new window and enters
    /// `RecoveryStart`.
    ///
    /// Returns true if this was a true congestion event, that is, `last_packet`
    /// was sent after the start of the current recovery period.
    fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool {
        // Start a new congestion event if lost packet was sent after the start
        // of the previous congestion recovery period.
        if !self.after_recovery_start(last_packet) {
            return false;
        }

        let (cwnd, acked_bytes) = self
            .cc_algorithm
            .reduce_cwnd(self.congestion_window, self.acked_bytes);
        self.congestion_window = max(cwnd, CWND_MIN);
        self.acked_bytes = acked_bytes;
        self.ssthresh = self.congestion_window;
        qinfo!(
            [self],
            "Cong event -> recovery; cwnd {}, ssthresh {}",
            self.congestion_window,
            self.ssthresh
        );
        qlog::metrics_updated(
            &mut self.qlog,
            &[
                QlogMetric::CongestionWindow(self.congestion_window),
                QlogMetric::SsThresh(self.ssthresh),
                QlogMetric::InRecovery(true),
            ],
        );
        self.set_state(State::RecoveryStart);
        true
    }

    /// Whether the sender is using so little of the congestion window that it
    /// is limited by the application rather than by the window.
    fn app_limited(&self) -> bool {
        if self.bytes_in_flight >= self.congestion_window {
            false
        } else if self.state.in_slow_start() {
            // Allow for potential doubling of the congestion window during slow start.
            // That is, the application might not have been able to send enough to respond
            // to increases to the congestion window.
            self.bytes_in_flight < self.congestion_window / 2
        } else {
            // We're not limited if the in-flight data is within a single burst of the
            // congestion window.
            (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        convert::TryFrom,
        time::{Duration, Instant},
    };

    use neqo_common::qinfo;
    use test_fixture::now;

    use super::{
        ClassicCongestionControl, WindowAdjustment, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH,
    };
    use crate::{
        cc::{
            classic_cc::State,
            cubic::{Cubic, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR},
            new_reno::NewReno,
            CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE,
        },
        packet::{PacketNumber, PacketType},
        rtt::RttEstimate,
        tracking::SentPacket,
    };

    const PTO: Duration = Duration::from_millis(100);
    const RTT: Duration = Duration::from_millis(98);
    const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98));
    const ZERO: Duration = Duration::from_secs(0);
    const EPSILON: Duration = Duration::from_nanos(1);
    const GAP: Duration = Duration::from_secs(1);
    /// The largest time between packets without causing persistent congestion.
    const SUB_PC: Duration = Duration::from_millis(100 * PERSISTENT_CONG_THRESH as u64);
    /// The minimum time between packets to cause persistent congestion.
    /// Uses an odd expression because `Duration` arithmetic isn't `const`.
+ const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1); + + fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.ssthresh(), usize::MAX); + } + + fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL / 2); + assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); + } + + fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket { + SentPacket::new( + PacketType::Short, + pn, + now() + t, + ack_eliciting, + Vec::new(), + 100, + ) + } + + fn congestion_control(cc: CongestionControlAlgorithm) -> Box<dyn CongestionControl> { + match cc { + CongestionControlAlgorithm::NewReno => { + Box::new(ClassicCongestionControl::new(NewReno::default())) + } + CongestionControlAlgorithm::Cubic => { + Box::new(ClassicCongestionControl::new(Cubic::default())) + } + } + } + + fn persistent_congestion_by_algorithm( + cc_alg: CongestionControlAlgorithm, + reduced_cwnd: usize, + lost_packets: &[SentPacket], + persistent_expected: bool, + ) { + let mut cc = congestion_control(cc_alg); + for p in lost_packets { + cc.on_packet_sent(p); + } + + cc.on_packets_lost(Some(now()), None, PTO, lost_packets); + + let persistent = if cc.cwnd() == reduced_cwnd { + false + } else if cc.cwnd() == CWND_MIN { + true + } else { + panic!("unexpected cwnd"); + }; + assert_eq!(persistent, persistent_expected); + } + + fn persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool) { + persistent_congestion_by_algorithm( + CongestionControlAlgorithm::NewReno, + CWND_INITIAL / 2, + lost_packets, + persistent_expected, + ); + persistent_congestion_by_algorithm( + CongestionControlAlgorithm::Cubic, + CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + lost_packets, + persistent_expected, + ); + } + + /// A span of exactly the PC threshold only reduces the window on loss. 
+ #[test] + fn persistent_congestion_none() { + persistent_congestion(&[lost(1, true, ZERO), lost(2, true, SUB_PC)], false); + } + + /// A span of just more than the PC threshold causes persistent congestion. + #[test] + fn persistent_congestion_simple() { + persistent_congestion(&[lost(1, true, ZERO), lost(2, true, PC)], true); + } + + /// Both packets need to be ack-eliciting. + #[test] + fn persistent_congestion_non_ack_eliciting() { + persistent_congestion(&[lost(1, false, ZERO), lost(2, true, PC)], false); + persistent_congestion(&[lost(1, true, ZERO), lost(2, false, PC)], false); + } + + /// Packets in the middle, of any type, are OK. + #[test] + fn persistent_congestion_middle() { + persistent_congestion( + &[lost(1, true, ZERO), lost(2, false, RTT), lost(3, true, PC)], + true, + ); + persistent_congestion( + &[lost(1, true, ZERO), lost(2, true, RTT), lost(3, true, PC)], + true, + ); + } + + /// Leading non-ack-eliciting packets are skipped. + #[test] + fn persistent_congestion_leading_non_ack_eliciting() { + persistent_congestion( + &[lost(1, false, ZERO), lost(2, true, RTT), lost(3, true, PC)], + false, + ); + persistent_congestion( + &[ + lost(1, false, ZERO), + lost(2, true, RTT), + lost(3, true, RTT + PC), + ], + true, + ); + } + + /// Trailing non-ack-eliciting packets aren't relevant. + #[test] + fn persistent_congestion_trailing_non_ack_eliciting() { + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PC), + lost(3, false, PC + EPSILON), + ], + true, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, SUB_PC), + lost(3, false, PC), + ], + false, + ); + } + + /// Gaps in the middle, of any type, restart the count. 
+ #[test] + fn persistent_congestion_gap_reset() { + persistent_congestion(&[lost(1, true, ZERO), lost(3, true, PC)], false); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, RTT), + lost(4, true, GAP), + lost(5, true, GAP + PTO * PERSISTENT_CONG_THRESH), + ], + false, + ); + } + + /// A span either side of a gap will cause persistent congestion. + #[test] + fn persistent_congestion_gap_or() { + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PC), + lost(4, true, GAP), + lost(5, true, GAP + PTO), + ], + true, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, true, GAP), + lost(5, true, GAP + PC), + ], + true, + ); + } + + /// A gap only restarts after an ack-eliciting packet. + #[test] + fn persistent_congestion_gap_non_ack_eliciting() { + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + PC), + ], + false, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + RTT), + lost(6, true, GAP + RTT + SUB_PC), + ], + false, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + RTT), + lost(6, true, GAP + RTT + PC), + ], + true, + ); + } + + /// Get a time, in multiples of `PTO`, relative to `now()`. + fn by_pto(t: u32) -> Instant { + now() + (PTO * t) + } + + /// Make packets that will be made lost. + /// `times` is the time of sending, in multiples of `PTO`, relative to `now()`. + fn make_lost(times: &[u32]) -> Vec<SentPacket> { + times + .iter() + .enumerate() + .map(|(i, &t)| { + SentPacket::new( + PacketType::Short, + u64::try_from(i).unwrap(), + by_pto(t), + true, + Vec::new(), + 1000, + ) + }) + .collect::<Vec<_>>() + } + + /// Call `detect_persistent_congestion` using times relative to now and the fixed PTO time. 
+ /// `last_ack` and `rtt_time` are times in multiples of `PTO`, relative to `now()`, + /// for the time of the largest acknowledged and the first RTT sample, respectively. + fn persistent_congestion_by_pto<T: WindowAdjustment>( + mut cc: ClassicCongestionControl<T>, + last_ack: u32, + rtt_time: u32, + lost: &[SentPacket], + ) -> bool { + assert_eq!(cc.cwnd(), CWND_INITIAL); + + let last_ack = Some(by_pto(last_ack)); + let rtt_time = Some(by_pto(rtt_time)); + + // Persistent congestion is never declared if the RTT time is `None`. + cc.detect_persistent_congestion(None, None, PTO, lost); + assert_eq!(cc.cwnd(), CWND_INITIAL); + cc.detect_persistent_congestion(None, last_ack, PTO, lost); + assert_eq!(cc.cwnd(), CWND_INITIAL); + + cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost); + cc.cwnd() == CWND_MIN + } + + /// No persistent congestion can be had if there are no lost packets. + #[test] + fn persistent_congestion_no_lost() { + let lost = make_lost(&[]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// No persistent congestion can be had if there is only one lost packet. + #[test] + fn persistent_congestion_one_lost() { + let lost = make_lost(&[1]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// Persistent congestion can't happen based on old packets. + #[test] + fn persistent_congestion_past() { + // Packets sent prior to either the last acknowledged or the first RTT + // sample are not considered. So 0 is ignored. 
+ let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 1, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 1, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 1, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 1, + 0, + &lost + )); + } + + /// Persistent congestion doesn't start unless the packet is ack-eliciting. + #[test] + fn persistent_congestion_ack_eliciting() { + let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + lost[0] = SentPacket::new( + lost[0].pt, + lost[0].pn, + lost[0].time_sent, + false, + Vec::new(), + lost[0].size, + ); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// Detect persistent congestion. Note that the first lost packet needs to have a time + /// greater than the previously acknowledged packet AND the first RTT sample. And the + /// difference in times needs to be greater than the persistent congestion threshold. 
+ #[test] + fn persistent_congestion_min() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + assert!(persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// Make sure that not having a previous largest acknowledged also results + /// in detecting persistent congestion. (This is not expected to happen, but + /// the code permits it). + #[test] + fn persistent_congestion_no_prev_ack_newreno() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + let mut cc = ClassicCongestionControl::new(NewReno::default()); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); + assert_eq!(cc.cwnd(), CWND_MIN); + } + + #[test] + fn persistent_congestion_no_prev_ack_cubic() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + let mut cc = ClassicCongestionControl::new(Cubic::default()); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); + assert_eq!(cc.cwnd(), CWND_MIN); + } + + /// The code asserts on ordering errors. + #[test] + #[should_panic(expected = "time is monotonic")] + fn persistent_congestion_unsorted_newreno() { + let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + } + + /// The code asserts on ordering errors. 
+ #[test] + #[should_panic(expected = "time is monotonic")] + fn persistent_congestion_unsorted_cubic() { + let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + #[test] + fn app_limited_slow_start() { + const BELOW_APP_LIMIT_PKTS: usize = 5; + const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let cwnd = cc.congestion_window; + let mut now = now(); + let mut next_pn = 0; + + // simulate packet bursts below app_limit + for packet_burst_size in 1..=BELOW_APP_LIMIT_PKTS { + // always stay below app_limit during sent. + let mut pkts = Vec::new(); + for _ in 0..packet_burst_size { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE); + now += RTT; + cc.on_packets_acked(&pkts, &RTT_ESTIMATE, now); + assert_eq!(cc.bytes_in_flight(), 0); + assert_eq!(cc.acked_bytes, 0); + assert_eq!(cwnd, cc.congestion_window); // CWND doesn't grow because we're app limited + } + + // Fully utilize the congestion window by sending enough packets to + // have `bytes_in_flight` above the `app_limited` threshold. 
+ let mut pkts = Vec::new(); + for _ in 0..ABOVE_APP_LIMIT_PKTS { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!( + cc.bytes_in_flight(), + ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE + ); + now += RTT; + // Check if congestion window gets increased for all packets currently in flight + for (i, pkt) in pkts.into_iter().enumerate() { + cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); + + assert_eq!( + cc.bytes_in_flight(), + (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE + ); + // increase acked_bytes with each packet + qinfo!("{} {}", cc.congestion_window, cwnd + i * MAX_DATAGRAM_SIZE); + assert_eq!(cc.congestion_window, cwnd + (i + 1) * MAX_DATAGRAM_SIZE); + assert_eq!(cc.acked_bytes, 0); + } + } + + #[test] + fn app_limited_congestion_avoidance() { + const CWND_PKTS_CA: usize = CWND_INITIAL_PKTS / 2; + const BELOW_APP_LIMIT_PKTS: usize = CWND_PKTS_CA - 2; + const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; + + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut now = now(); + + // Change state to congestion avoidance by introducing loss. + + let p_lost = SentPacket::new( + PacketType::Short, + 1, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p_lost); + cwnd_is_default(&cc); + now += PTO; + cc.on_packets_lost(Some(now), None, PTO, &[p_lost]); + cwnd_is_halved(&cc); + let p_not_lost = SentPacket::new( + PacketType::Short, + 2, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p_not_lost); + now += RTT; + cc.on_packets_acked(&[p_not_lost], &RTT_ESTIMATE, now); + cwnd_is_halved(&cc); + // cc is app limited therefore cwnd in not increased. 
+ assert_eq!(cc.acked_bytes, 0); + + // Now we are in the congestion avoidance state. + assert_eq!(cc.state, State::CongestionAvoidance); + // simulate packet bursts below app_limit + let mut next_pn = 3; + for packet_burst_size in 1..=BELOW_APP_LIMIT_PKTS { + // always stay below app_limit during sent. + let mut pkts = Vec::new(); + for _ in 0..packet_burst_size { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE); + now += RTT; + for (i, pkt) in pkts.into_iter().enumerate() { + cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); + + assert_eq!( + cc.bytes_in_flight(), + (packet_burst_size - i - 1) * MAX_DATAGRAM_SIZE + ); + cwnd_is_halved(&cc); // CWND doesn't grow because we're app limited + assert_eq!(cc.acked_bytes, 0); + } + } + + // Fully utilize the congestion window by sending enough packets to + // have `bytes_in_flight` above the `app_limited` threshold. 
+ let mut pkts = Vec::new(); + for _ in 0..ABOVE_APP_LIMIT_PKTS { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!( + cc.bytes_in_flight(), + ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE + ); + now += RTT; + let mut last_acked_bytes = 0; + // Check if congestion window gets increased for all packets currently in flight + for (i, pkt) in pkts.into_iter().enumerate() { + cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); + + assert_eq!( + cc.bytes_in_flight(), + (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE + ); + // The cwnd doesn't increase, but the acked_bytes do, which will eventually lead to an + // increase, once the number of bytes reaches the necessary level + cwnd_is_halved(&cc); + // increase acked_bytes with each packet + assert_ne!(cc.acked_bytes, last_acked_bytes); + last_acked_bytes = cc.acked_bytes; + } + } +} diff --git a/third_party/rust/neqo-transport/src/cc/cubic.rs b/third_party/rust/neqo-transport/src/cc/cubic.rs new file mode 100644 index 0000000000..c04a29b443 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/cubic.rs @@ -0,0 +1,215 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![deny(clippy::pedantic)] + +use std::{ + convert::TryFrom, + fmt::{self, Display}, + time::{Duration, Instant}, +}; + +use neqo_common::qtrace; + +use crate::cc::{classic_cc::WindowAdjustment, MAX_DATAGRAM_SIZE_F64}; + +// CUBIC congestion control + +// C is a constant fixed to determine the aggressiveness of window +// increase in high BDP networks. 
+pub const CUBIC_C: f64 = 0.4; +pub const CUBIC_ALPHA: f64 = 3.0 * (1.0 - 0.7) / (1.0 + 0.7); + +// CUBIC_BETA = 0.7; +pub const CUBIC_BETA_USIZE_DIVIDEND: usize = 7; +pub const CUBIC_BETA_USIZE_DIVISOR: usize = 10; + +/// The fast convergence ratio further reduces the congestion window when a congestion event +/// occurs before reaching the previous `W_max`. +pub const CUBIC_FAST_CONVERGENCE: f64 = 0.85; // (1.0 + CUBIC_BETA) / 2.0; + +/// The minimum number of multiples of the datagram size that need +/// to be received to cause an increase in the congestion window. +/// When there is no loss, Cubic can return to exponential increase, but +/// this value reduces the magnitude of the resulting growth by a constant factor. +/// A value of 1.0 would mean a return to the rate used in slow start. +const EXPONENTIAL_GROWTH_REDUCTION: f64 = 2.0; + +/// Convert an integer congestion window value into a floating point value. +/// This has the effect of reducing larger values to `1<<53`. +/// If you have a congestion window that large, something is probably wrong. 
+fn convert_to_f64(v: usize) -> f64 { + let mut f_64 = f64::from(u32::try_from(v >> 21).unwrap_or(u32::MAX)); + f_64 *= 2_097_152.0; // f_64 <<= 21 + f_64 += f64::from(u32::try_from(v & 0x1f_ffff).unwrap()); + f_64 +} + +#[derive(Debug)] +pub struct Cubic { + last_max_cwnd: f64, + estimated_tcp_cwnd: f64, + k: f64, + w_max: f64, + ca_epoch_start: Option<Instant>, + tcp_acked_bytes: f64, +} + +impl Default for Cubic { + fn default() -> Self { + Self { + last_max_cwnd: 0.0, + estimated_tcp_cwnd: 0.0, + k: 0.0, + w_max: 0.0, + ca_epoch_start: None, + tcp_acked_bytes: 0.0, + } + } +} + +impl Display for Cubic { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Cubic [last_max_cwnd: {}, k: {}, w_max: {}, ca_epoch_start: {:?}]", + self.last_max_cwnd, self.k, self.w_max, self.ca_epoch_start + )?; + Ok(()) + } +} + +#[allow(clippy::doc_markdown)] +impl Cubic { + /// Original equations is: + /// K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2 RFC8312) + /// W_max is number of segments of the maximum segment size (MSS). + /// + /// K is actually the time that W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) would + /// take to increase to W_max. We use bytes not MSS units, therefore this + /// equation will be: W_cubic(t) = C*MSS*(t-K)^3 + W_max. + /// + /// From that equation we can calculate K as: + /// K = cubic_root((W_max - W_cubic) / C / MSS); + fn calc_k(&self, curr_cwnd: f64) -> f64 { + ((self.w_max - curr_cwnd) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt() + } + + /// W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) + /// t is relative to the start of the congestion avoidance phase and it is in seconds. 
+ fn w_cubic(&self, t: f64) -> f64 { + CUBIC_C * (t - self.k).powi(3) * MAX_DATAGRAM_SIZE_F64 + self.w_max + } + + fn start_epoch(&mut self, curr_cwnd_f64: f64, new_acked_f64: f64, now: Instant) { + self.ca_epoch_start = Some(now); + // reset tcp_acked_bytes and estimated_tcp_cwnd; + self.tcp_acked_bytes = new_acked_f64; + self.estimated_tcp_cwnd = curr_cwnd_f64; + if self.last_max_cwnd <= curr_cwnd_f64 { + self.w_max = curr_cwnd_f64; + self.k = 0.0; + } else { + self.w_max = self.last_max_cwnd; + self.k = self.calc_k(curr_cwnd_f64); + } + qtrace!([self], "New epoch"); + } +} + +impl WindowAdjustment for Cubic { + // This is because of the cast in the last line from f64 to usize. + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + fn bytes_for_cwnd_increase( + &mut self, + curr_cwnd: usize, + new_acked_bytes: usize, + min_rtt: Duration, + now: Instant, + ) -> usize { + let curr_cwnd_f64 = convert_to_f64(curr_cwnd); + let new_acked_f64 = convert_to_f64(new_acked_bytes); + if self.ca_epoch_start.is_none() { + // This is a start of a new congestion avoidance phase. + self.start_epoch(curr_cwnd_f64, new_acked_f64, now); + } else { + self.tcp_acked_bytes += new_acked_f64; + } + + let time_ca = self + .ca_epoch_start + .map_or(min_rtt, |t| { + if now + min_rtt < t { + // This only happens when processing old packets + // that were saved and replayed with old timestamps. 
+ min_rtt + } else { + now + min_rtt - t + } + }) + .as_secs_f64(); + let target_cubic = self.w_cubic(time_ca); + + let tcp_cnt = self.estimated_tcp_cwnd / CUBIC_ALPHA; + while self.tcp_acked_bytes > tcp_cnt { + self.tcp_acked_bytes -= tcp_cnt; + self.estimated_tcp_cwnd += MAX_DATAGRAM_SIZE_F64; + } + + let target_cwnd = target_cubic.max(self.estimated_tcp_cwnd); + + // Calculate the number of bytes that would need to be acknowledged for an increase + // of `MAX_DATAGRAM_SIZE` to match the increase of `target - cwnd / cwnd` as defined + // in the specification (Sections 4.4 and 4.5). + // The amount of data required therefore reduces asymptotically as the target increases. + // If the target is not significantly higher than the congestion window, require a very + // large amount of acknowledged data (effectively block increases). + let mut acked_to_increase = + MAX_DATAGRAM_SIZE_F64 * curr_cwnd_f64 / (target_cwnd - curr_cwnd_f64).max(1.0); + + // Limit increase to max 1 MSS per EXPONENTIAL_GROWTH_REDUCTION ack packets. + // This effectively limits target_cwnd to (1 + 1 / EXPONENTIAL_GROWTH_REDUCTION) cwnd. + acked_to_increase = + acked_to_increase.max(EXPONENTIAL_GROWTH_REDUCTION * MAX_DATAGRAM_SIZE_F64); + acked_to_increase as usize + } + + fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) { + let curr_cwnd_f64 = convert_to_f64(curr_cwnd); + // Fast Convergence + // If congestion event occurs before the maximum congestion window before the last + // congestion event, we reduce the the maximum congestion window and thereby W_max. + // check cwnd + MAX_DATAGRAM_SIZE instead of cwnd because with cwnd in bytes, cwnd may be + // slightly off. 
+ self.last_max_cwnd = if curr_cwnd_f64 + MAX_DATAGRAM_SIZE_F64 < self.last_max_cwnd { + curr_cwnd_f64 * CUBIC_FAST_CONVERGENCE + } else { + curr_cwnd_f64 + }; + self.ca_epoch_start = None; + ( + curr_cwnd * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + acked_bytes * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + ) + } + + fn on_app_limited(&mut self) { + // Reset ca_epoch_start. Let it start again when the congestion controller + // exits the app-limited period. + self.ca_epoch_start = None; + } + + #[cfg(test)] + fn last_max_cwnd(&self) -> f64 { + self.last_max_cwnd + } + + #[cfg(test)] + fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) { + self.last_max_cwnd = last_max_cwnd; + } +} diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs new file mode 100644 index 0000000000..a1a43bd157 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/mod.rs @@ -0,0 +1,87 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// Congestion control +#![deny(clippy::pedantic)] + +use std::{ + fmt::{Debug, Display}, + str::FromStr, + time::{Duration, Instant}, +}; + +use neqo_common::qlog::NeqoQlog; + +use crate::{path::PATH_MTU_V6, rtt::RttEstimate, tracking::SentPacket, Error}; + +mod classic_cc; +mod cubic; +mod new_reno; + +pub use classic_cc::ClassicCongestionControl; +#[cfg(test)] +pub use classic_cc::{CWND_INITIAL, CWND_INITIAL_PKTS, CWND_MIN}; +pub use cubic::Cubic; +pub use new_reno::NewReno; + +pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6; +#[allow(clippy::cast_precision_loss)] +pub const MAX_DATAGRAM_SIZE_F64: f64 = MAX_DATAGRAM_SIZE as f64; + +pub trait CongestionControl: Display + Debug { + fn set_qlog(&mut self, qlog: NeqoQlog); + + #[must_use] + fn cwnd(&self) -> usize; + + #[must_use] + fn bytes_in_flight(&self) -> usize; + + #[must_use] + fn cwnd_avail(&self) -> usize; + + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant); + + /// Returns true if the congestion window was reduced. + fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) -> bool; + + #[must_use] + fn recovery_packet(&self) -> bool; + + fn discard(&mut self, pkt: &SentPacket); + + fn on_packet_sent(&mut self, pkt: &SentPacket); + + fn discard_in_flight(&mut self); +} + +#[derive(Debug, Copy, Clone)] +pub enum CongestionControlAlgorithm { + NewReno, + Cubic, +} + +// A `FromStr` implementation so that this can be used in command-line interfaces. 
+impl FromStr for CongestionControlAlgorithm { + type Err = Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.trim().to_ascii_lowercase().as_str() { + "newreno" | "reno" => Ok(Self::NewReno), + "cubic" => Ok(Self::Cubic), + _ => Err(Error::InvalidInput), + } + } +} + +#[cfg(test)] +mod tests; diff --git a/third_party/rust/neqo-transport/src/cc/new_reno.rs b/third_party/rust/neqo-transport/src/cc/new_reno.rs new file mode 100644 index 0000000000..e51b3d6cc0 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/new_reno.rs @@ -0,0 +1,51 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] + +use std::{ + fmt::{self, Display}, + time::{Duration, Instant}, +}; + +use crate::cc::classic_cc::WindowAdjustment; + +#[derive(Debug, Default)] +pub struct NewReno {} + +impl Display for NewReno { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NewReno")?; + Ok(()) + } +} + +impl WindowAdjustment for NewReno { + fn bytes_for_cwnd_increase( + &mut self, + curr_cwnd: usize, + _new_acked_bytes: usize, + _min_rtt: Duration, + _now: Instant, + ) -> usize { + curr_cwnd + } + + fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) { + (curr_cwnd / 2, acked_bytes / 2) + } + + fn on_app_limited(&mut self) {} + + #[cfg(test)] + fn last_max_cwnd(&self) -> f64 { + 0.0 + } + + #[cfg(test)] + fn set_last_max_cwnd(&mut self, _last_max_cwnd: f64) {} +} diff --git a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs new file mode 100644 index 0000000000..0c82e47817 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs @@ -0,0 +1,333 @@ +// 
Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::cast_sign_loss)] + +use std::{ + convert::TryFrom, + ops::Sub, + time::{Duration, Instant}, +}; + +use test_fixture::now; + +use crate::{ + cc::{ + classic_cc::{ClassicCongestionControl, CWND_INITIAL}, + cubic::{ + Cubic, CUBIC_ALPHA, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR, CUBIC_C, + CUBIC_FAST_CONVERGENCE, + }, + CongestionControl, MAX_DATAGRAM_SIZE, MAX_DATAGRAM_SIZE_F64, + }, + packet::PacketType, + rtt::RttEstimate, + tracking::SentPacket, +}; + +const RTT: Duration = Duration::from_millis(100); +const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(100)); +const CWND_INITIAL_F64: f64 = 10.0 * MAX_DATAGRAM_SIZE_F64; +const CWND_INITIAL_10_F64: f64 = 10.0 * CWND_INITIAL_F64; +const CWND_INITIAL_10: usize = 10 * CWND_INITIAL; +const CWND_AFTER_LOSS: usize = CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR; +const CWND_AFTER_LOSS_SLOW_START: usize = + (CWND_INITIAL + MAX_DATAGRAM_SIZE) * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR; + +fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: Instant) -> u64 { + while cc.bytes_in_flight() < cc.cwnd() { + let sent = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&sent); + next_pn += 1; + } + next_pn +} + +fn ack_packet(cc: &mut ClassicCongestionControl<Cubic>, pn: u64, now: Instant) { + let acked = SentPacket::new( + PacketType::Short, + pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // 
size + ); + cc.on_packets_acked(&[acked], &RTT_ESTIMATE, now); +} + +fn packet_lost(cc: &mut ClassicCongestionControl<Cubic>, pn: u64) { + const PTO: Duration = Duration::from_millis(120); + let p_lost = SentPacket::new( + PacketType::Short, + pn, // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packets_lost(None, None, PTO, &[p_lost]); +} + +fn expected_tcp_acks(cwnd_rtt_start: usize) -> u64 { + (f64::from(i32::try_from(cwnd_rtt_start).unwrap()) / MAX_DATAGRAM_SIZE_F64 / CUBIC_ALPHA) + .round() as u64 +} + +#[test] +fn tcp_phase() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + // change to congestion avoidance state. + cubic.set_ssthresh(1); + + let mut now = now(); + let start_time = now; + // helper variables to remember the next packet number to be sent/acked. + let mut next_pn_send = 0; + let mut next_pn_ack = 0; + + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + + // This will start with TCP phase. + // in this phase cwnd is increase by CUBIC_ALPHA every RTT. We can look at it as + // increase of MAX_DATAGRAM_SIZE every 1 / CUBIC_ALPHA RTTs. + // The phase will end when cwnd calculated with cubic equation is equal to TCP estimate: + // CUBIC_C * (n * RTT / CUBIC_ALPHA)^3 * MAX_DATAGRAM_SIZE = n * MAX_DATAGRAM_SIZE + // from this n = sqrt(CUBIC_ALPHA^3/ (CUBIC_C * RTT^3)). + let num_tcp_increases = (CUBIC_ALPHA.powi(3) / (CUBIC_C * RTT.as_secs_f64().powi(3))) + .sqrt() + .floor() as u64; + + for _ in 0..num_tcp_increases { + let cwnd_rtt_start = cubic.cwnd(); + // Expected acks during a period of RTT / CUBIC_ALPHA. + let acks = expected_tcp_acks(cwnd_rtt_start); + // The time between acks if they are ideally paced over a RTT. 
+ let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap(); + + for _ in 0..acks { + now += time_increase; + ack_packet(&mut cubic, next_pn_ack, now); + next_pn_ack += 1; + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + } + + assert_eq!(cubic.cwnd() - cwnd_rtt_start, MAX_DATAGRAM_SIZE); + } + + // The next increase will be according to the cubic equation. + + let cwnd_rtt_start = cubic.cwnd(); + // cwnd_rtt_start has change, therefore calculate new time_increase (the time + // between acks if they are ideally paced over a RTT). + let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap(); + let mut num_acks = 0; // count the number of acks. until cwnd is increased by MAX_DATAGRAM_SIZE. + + while cwnd_rtt_start == cubic.cwnd() { + num_acks += 1; + now += time_increase; + ack_packet(&mut cubic, next_pn_ack, now); + next_pn_ack += 1; + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + } + + // Make sure that the increase is not according to TCP equation, i.e., that it took + // less than RTT / CUBIC_ALPHA. + let expected_ack_tcp_increase = expected_tcp_acks(cwnd_rtt_start); + assert!(num_acks < expected_ack_tcp_increase); + + // This first increase after a TCP phase may be shorter than what it would take by a regular + // cubic phase, because of the proper byte counting and the credit it already had before + // entering this phase. Therefore We will perform another round and compare it to expected + // increase using the cubic equation. + + let cwnd_rtt_start_after_tcp = cubic.cwnd(); + let elapsed_time = now - start_time; + + // calculate new time_increase. + let time_increase = RTT / u32::try_from(cwnd_rtt_start_after_tcp / MAX_DATAGRAM_SIZE).unwrap(); + let mut num_acks2 = 0; // count the number of acks. until cwnd is increased by MAX_DATAGRAM_SIZE. 
+ + while cwnd_rtt_start_after_tcp == cubic.cwnd() { + num_acks2 += 1; + now += time_increase; + ack_packet(&mut cubic, next_pn_ack, now); + next_pn_ack += 1; + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + } + + let expected_ack_tcp_increase2 = expected_tcp_acks(cwnd_rtt_start_after_tcp); + assert!(num_acks2 < expected_ack_tcp_increase2); + + // The time needed to increase cwnd by MAX_DATAGRAM_SIZE using the cubic equation will be + // calculates from: W_cubic(elapsed_time + t_to_increase) - W_cubis(elapsed_time) = + // MAX_DATAGRAM_SIZE => CUBIC_C * (elapsed_time + t_to_increase)^3 * MAX_DATAGRAM_SIZE + + // CWND_INITIAL - CUBIC_C * elapsed_time^3 * MAX_DATAGRAM_SIZE + CWND_INITIAL = + // MAX_DATAGRAM_SIZE => t_to_increase = cbrt((1 + CUBIC_C * elapsed_time^3) / CUBIC_C) - + // elapsed_time (t_to_increase is in seconds) + // number of ack needed is t_to_increase / time_increase. + let expected_ack_cubic_increase = + ((((1.0 + CUBIC_C * (elapsed_time).as_secs_f64().powi(3)) / CUBIC_C).cbrt() + - elapsed_time.as_secs_f64()) + / time_increase.as_secs_f64()) + .ceil() as u64; + // num_acks is very close to the calculated value. The exact value is hard to calculate + // because the proportional increase(i.e. curr_cwnd_f64 / (target - curr_cwnd_f64) * + // MAX_DATAGRAM_SIZE_F64) and the byte counting. + assert_eq!(num_acks2, expected_ack_cubic_increase + 2); +} + +#[test] +fn cubic_phase() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + // Set last_max_cwnd to a higher number make sure that cc is the cubic phase (cwnd is calculated + // by the cubic equation). + cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. 
+ cubic.set_ssthresh(1); + let mut now = now(); + let mut next_pn_send = 0; + let mut next_pn_ack = 0; + + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + + let k = ((CWND_INITIAL_10_F64 - CWND_INITIAL_F64) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt(); + let epoch_start = now; + + // The number of RTT until W_max is reached. + let num_rtts_w_max = (k / RTT.as_secs_f64()).round() as u64; + for _ in 0..num_rtts_w_max { + let cwnd_rtt_start = cubic.cwnd(); + // Expected acks + let acks = cwnd_rtt_start / MAX_DATAGRAM_SIZE; + let time_increase = RTT / u32::try_from(acks).unwrap(); + for _ in 0..acks { + now += time_increase; + ack_packet(&mut cubic, next_pn_ack, now); + next_pn_ack += 1; + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + } + + let expected = + (CUBIC_C * ((now - epoch_start).as_secs_f64() - k).powi(3) * MAX_DATAGRAM_SIZE_F64 + + CWND_INITIAL_10_F64) + .round() as usize; + + assert_within(cubic.cwnd(), expected, MAX_DATAGRAM_SIZE); + } + assert_eq!(cubic.cwnd(), CWND_INITIAL_10); +} + +fn assert_within<T: Sub<Output = T> + PartialOrd + Copy>(value: T, expected: T, margin: T) { + if value >= expected { + assert!(value - expected < margin); + } else { + assert!(expected - value < margin); + } +} + +#[test] +fn congestion_event_slow_start() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + _ = fill_cwnd(&mut cubic, 0, now()); + ack_packet(&mut cubic, 0, now()); + + assert_within(cubic.last_max_cwnd(), 0.0, f64::EPSILON); + + // cwnd is increased by 1 in slow start phase, after an ack. + assert_eq!(cubic.cwnd(), CWND_INITIAL + MAX_DATAGRAM_SIZE); + + // Trigger a congestion_event in slow start phase + packet_lost(&mut cubic, 1); + + // last_max_cwnd is equal to cwnd before decrease. 
+ assert_within( + cubic.last_max_cwnd(), + CWND_INITIAL_F64 + MAX_DATAGRAM_SIZE_F64, + f64::EPSILON, + ); + assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS_SLOW_START); +} + +#[test] +fn congestion_event_congestion_avoidance() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. + cubic.set_ssthresh(1); + + // Set last_max_cwnd to something smaller than cwnd so that the fast convergence is not + // triggered. + cubic.set_last_max_cwnd(3.0 * MAX_DATAGRAM_SIZE_F64); + + _ = fill_cwnd(&mut cubic, 0, now()); + ack_packet(&mut cubic, 0, now()); + + assert_eq!(cubic.cwnd(), CWND_INITIAL); + + // Trigger a congestion_event in slow start phase + packet_lost(&mut cubic, 1); + + assert_within(cubic.last_max_cwnd(), CWND_INITIAL_F64, f64::EPSILON); + assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS); +} + +#[test] +fn congestion_event_congestion_avoidance_2() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. + cubic.set_ssthresh(1); + + // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered. + cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + + _ = fill_cwnd(&mut cubic, 0, now()); + ack_packet(&mut cubic, 0, now()); + + assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON); + assert_eq!(cubic.cwnd(), CWND_INITIAL); + + // Trigger a congestion_event. 
+ packet_lost(&mut cubic, 1); + + assert_within( + cubic.last_max_cwnd(), + CWND_INITIAL_F64 * CUBIC_FAST_CONVERGENCE, + f64::EPSILON, + ); + assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS); +} + +#[test] +fn congestion_event_congestion_avoidance_test_no_overflow() { + const PTO: Duration = Duration::from_millis(120); + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. + cubic.set_ssthresh(1); + + // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered. + cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + + _ = fill_cwnd(&mut cubic, 0, now()); + ack_packet(&mut cubic, 1, now()); + + assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON); + assert_eq!(cubic.cwnd(), CWND_INITIAL); + + // Now ack packet that was send earlier. + ack_packet(&mut cubic, 0, now().checked_sub(PTO).unwrap()); +} diff --git a/third_party/rust/neqo-transport/src/cc/tests/mod.rs b/third_party/rust/neqo-transport/src/cc/tests/mod.rs new file mode 100644 index 0000000000..238a7ad012 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/tests/mod.rs @@ -0,0 +1,7 @@ +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod cubic; +mod new_reno; diff --git a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs new file mode 100644 index 0000000000..a73844a755 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs @@ -0,0 +1,219 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] + +use std::time::Duration; + +use test_fixture::now; + +use crate::{ + cc::{ + new_reno::NewReno, ClassicCongestionControl, CongestionControl, CWND_INITIAL, + MAX_DATAGRAM_SIZE, + }, + packet::PacketType, + rtt::RttEstimate, + tracking::SentPacket, +}; + +const PTO: Duration = Duration::from_millis(100); +const RTT: Duration = Duration::from_millis(98); +const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98)); + +fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.ssthresh(), usize::MAX); +} + +fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL / 2); + assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); +} + +#[test] +fn issue_876() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let time_now = now(); + let time_before = time_now.checked_sub(Duration::from_millis(100)).unwrap(); + let time_after = time_now + Duration::from_millis(150); + + let sent_packets = &[ + SentPacket::new( + PacketType::Short, + 1, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 1, // size + ), + SentPacket::new( + PacketType::Short, + 2, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 2, // size + ), + SentPacket::new( + PacketType::Short, + 3, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 4, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 5, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + 
SentPacket::new( + PacketType::Short, + 6, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 7, // pn + time_after, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 3, // size + ), + ]; + + // Send some more packets so that the cc is not app-limited. + for p in &sent_packets[..6] { + cc.on_packet_sent(p); + } + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_default(&cc); + assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 3); + + cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[0..1]); + + // We are now in recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + + // Send a packet after recovery starts + cc.on_packet_sent(&sent_packets[6]); + assert!(!cc.recovery_packet()); + cwnd_is_halved(&cc); + assert_eq!(cc.acked_bytes(), 0); + assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 5); + + // and ack it. cwnd increases slightly + cc.on_packets_acked(&sent_packets[6..], &RTT_ESTIMATE, time_now); + assert_eq!(cc.acked_bytes(), sent_packets[6].size); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + + // Packet from before is lost. Should not hurt cwnd. 
+ cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]); + assert!(!cc.recovery_packet()); + assert_eq!(cc.acked_bytes(), sent_packets[6].size); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE); +} + +#[test] +// https://github.com/mozilla/neqo/pull/1465 +fn issue_1465() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut pn = 0; + let mut now = now(); + let mut next_packet = |now| { + let p = SentPacket::new( + PacketType::Short, + pn, // pn + now, // time_sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + pn += 1; + p + }; + let mut send_next = |cc: &mut ClassicCongestionControl<NewReno>, now| { + let p = next_packet(now); + cc.on_packet_sent(&p); + p + }; + + let p1 = send_next(&mut cc, now); + let p2 = send_next(&mut cc, now); + let p3 = send_next(&mut cc, now); + + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_default(&cc); + assert_eq!(cc.bytes_in_flight(), 3 * MAX_DATAGRAM_SIZE); + + // advance one rtt to detect lost packet there this simplifies the timers, because + // on_packet_loss would only be called after RTO, but that is not relevant to the problem + now += RTT; + cc.on_packets_lost(Some(now), None, PTO, &[p1]); + + // We are now in recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE); + + // Don't reduce the cwnd again on second packet loss + cc.on_packets_lost(Some(now), None, PTO, &[p3]); + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_halved(&cc); // still the same as after first packet loss + assert_eq!(cc.bytes_in_flight(), MAX_DATAGRAM_SIZE); + + // the acked packets before on_packet_sent were the cause of + // https://github.com/mozilla/neqo/pull/1465 + cc.on_packets_acked(&[p2], &RTT_ESTIMATE, now); + + assert_eq!(cc.bytes_in_flight(), 0); + + // send out recovery packet and get it acked to get out of recovery state + let p4 = 
send_next(&mut cc, now); + cc.on_packet_sent(&p4); + now += RTT; + cc.on_packets_acked(&[p4], &RTT_ESTIMATE, now); + + // do the same as in the first rtt but now the bug appears + let p5 = send_next(&mut cc, now); + let p6 = send_next(&mut cc, now); + now += RTT; + + let cur_cwnd = cc.cwnd(); + cc.on_packets_lost(Some(now), None, PTO, &[p5]); + + // go back into recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.cwnd(), cur_cwnd / 2); + assert_eq!(cc.acked_bytes(), 0); + assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE); + + // this shouldn't introduce further cwnd reduction, but it did before https://github.com/mozilla/neqo/pull/1465 + cc.on_packets_lost(Some(now), None, PTO, &[p6]); + assert_eq!(cc.cwnd(), cur_cwnd / 2); +} diff --git a/third_party/rust/neqo-transport/src/cid.rs b/third_party/rust/neqo-transport/src/cid.rs new file mode 100644 index 0000000000..be202daf25 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cid.rs @@ -0,0 +1,609 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Representation and management of connection IDs. + +use std::{ + borrow::Borrow, + cell::{Ref, RefCell}, + cmp::{max, min}, + convert::{AsRef, TryFrom}, + ops::Deref, + rc::Rc, +}; + +use neqo_common::{hex, hex_with_len, qinfo, Decoder, Encoder}; +use neqo_crypto::random; +use smallvec::SmallVec; + +use crate::{ + frame::FRAME_TYPE_NEW_CONNECTION_ID, packet::PacketBuilder, recovery::RecoveryToken, + stats::FrameStats, Error, Res, +}; + +pub const MAX_CONNECTION_ID_LEN: usize = 20; +pub const LOCAL_ACTIVE_CID_LIMIT: usize = 8; +pub const CONNECTION_ID_SEQNO_INITIAL: u64 = 0; +pub const CONNECTION_ID_SEQNO_PREFERRED: u64 = 1; +/// A special value. See `ConnectionIdManager::add_odcid`. 
+const CONNECTION_ID_SEQNO_ODCID: u64 = u64::MAX; +/// A special value. See `ConnectionIdEntry::empty_remote`. +const CONNECTION_ID_SEQNO_EMPTY: u64 = u64::MAX - 1; + +#[derive(Clone, Default, Eq, Hash, PartialEq)] +pub struct ConnectionId { + pub(crate) cid: SmallVec<[u8; MAX_CONNECTION_ID_LEN]>, +} + +impl ConnectionId { + pub fn generate(len: usize) -> Self { + assert!(matches!(len, 0..=MAX_CONNECTION_ID_LEN)); + Self::from(random(len)) + } + + // Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long. + pub fn generate_initial() -> Self { + let v = random(1); + // Bias selection toward picking 8 (>50% of the time). + let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into(); + Self::generate(len) + } + + pub fn as_cid_ref(&self) -> ConnectionIdRef { + ConnectionIdRef::from(&self.cid[..]) + } +} + +impl AsRef<[u8]> for ConnectionId { + fn as_ref(&self) -> &[u8] { + self.borrow() + } +} + +impl Borrow<[u8]> for ConnectionId { + fn borrow(&self) -> &[u8] { + &self.cid + } +} + +impl From<SmallVec<[u8; MAX_CONNECTION_ID_LEN]>> for ConnectionId { + fn from(cid: SmallVec<[u8; MAX_CONNECTION_ID_LEN]>) -> Self { + Self { cid } + } +} + +impl From<Vec<u8>> for ConnectionId { + fn from(cid: Vec<u8>) -> Self { + Self::from(SmallVec::from(cid)) + } +} + +impl<T: AsRef<[u8]> + ?Sized> From<&T> for ConnectionId { + fn from(buf: &T) -> Self { + Self::from(SmallVec::from(buf.as_ref())) + } +} + +impl<'a> From<ConnectionIdRef<'a>> for ConnectionId { + fn from(cidref: ConnectionIdRef<'a>) -> Self { + Self::from(SmallVec::from(cidref.cid)) + } +} + +impl std::ops::Deref for ConnectionId { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.cid + } +} + +impl ::std::fmt::Debug for ConnectionId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CID {}", hex_with_len(&self.cid)) + } +} + +impl ::std::fmt::Display for ConnectionId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + 
write!(f, "{}", hex(&self.cid)) + } +} + +impl<'a> PartialEq<ConnectionIdRef<'a>> for ConnectionId { + fn eq(&self, other: &ConnectionIdRef<'a>) -> bool { + &self.cid[..] == other.cid + } +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy)] +pub struct ConnectionIdRef<'a> { + cid: &'a [u8], +} + +impl<'a> ::std::fmt::Debug for ConnectionIdRef<'a> { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CID {}", hex_with_len(self.cid)) + } +} + +impl<'a> ::std::fmt::Display for ConnectionIdRef<'a> { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}", hex(self.cid)) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a T> for ConnectionIdRef<'a> { + fn from(cid: &'a T) -> Self { + Self { cid: cid.as_ref() } + } +} + +impl<'a> std::ops::Deref for ConnectionIdRef<'a> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.cid + } +} + +impl<'a> PartialEq<ConnectionId> for ConnectionIdRef<'a> { + fn eq(&self, other: &ConnectionId) -> bool { + self.cid == &other.cid[..] + } +} + +pub trait ConnectionIdDecoder { + /// Decodes a connection ID from the provided decoder. + fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>>; +} + +pub trait ConnectionIdGenerator: ConnectionIdDecoder { + /// Generates a connection ID. This can return `None` if the generator + /// is exhausted. + fn generate_cid(&mut self) -> Option<ConnectionId>; + /// Indicates whether the connection IDs are zero-length. + /// If this returns true, `generate_cid` must always produce an empty value + /// and never `None`. + /// If this returns false, `generate_cid` must never produce an empty value, + /// though it can return `None`. + /// + /// You should not need to implement this: if you want zero-length connection IDs, + /// use `EmptyConnectionIdGenerator` instead. 
+ fn generates_empty_cids(&self) -> bool { + false + } + fn as_decoder(&self) -> &dyn ConnectionIdDecoder; +} + +/// An `EmptyConnectionIdGenerator` generates empty connection IDs. +#[derive(Default)] +pub struct EmptyConnectionIdGenerator {} + +impl ConnectionIdDecoder for EmptyConnectionIdGenerator { + fn decode_cid<'a>(&self, _: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> { + Some(ConnectionIdRef::from(&[])) + } +} + +impl ConnectionIdGenerator for EmptyConnectionIdGenerator { + fn generate_cid(&mut self) -> Option<ConnectionId> { + Some(ConnectionId::from(&[])) + } + fn as_decoder(&self) -> &dyn ConnectionIdDecoder { + self + } + fn generates_empty_cids(&self) -> bool { + true + } +} + +/// An RandomConnectionIdGenerator produces connection IDs of +/// a fixed length and random content. No effort is made to +/// prevent collisions. +pub struct RandomConnectionIdGenerator { + len: usize, +} + +impl RandomConnectionIdGenerator { + pub fn new(len: usize) -> Self { + Self { len } + } +} + +impl ConnectionIdDecoder for RandomConnectionIdGenerator { + fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> { + dec.decode(self.len).map(ConnectionIdRef::from) + } +} + +impl ConnectionIdGenerator for RandomConnectionIdGenerator { + fn generate_cid(&mut self) -> Option<ConnectionId> { + Some(ConnectionId::from(&random(self.len))) + } + + fn as_decoder(&self) -> &dyn ConnectionIdDecoder { + self + } + + fn generates_empty_cids(&self) -> bool { + self.len == 0 + } +} + +/// A single connection ID, as saved from NEW_CONNECTION_ID. +/// This is templated so that the connection ID entries from a peer can be +/// saved with a stateless reset token. Local entries don't need that. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct ConnectionIdEntry<SRT: Clone + PartialEq> { + /// The sequence number. + seqno: u64, + /// The connection ID. + cid: ConnectionId, + /// The corresponding stateless reset token. 
+ srt: SRT, +} + +impl ConnectionIdEntry<[u8; 16]> { + /// Create a random stateless reset token so that it is hard to guess the correct + /// value and reset the connection. + fn random_srt() -> [u8; 16] { + <[u8; 16]>::try_from(&random(16)[..]).unwrap() + } + + /// Create the first entry, which won't have a stateless reset token. + pub fn initial_remote(cid: ConnectionId) -> Self { + Self::new(CONNECTION_ID_SEQNO_INITIAL, cid, Self::random_srt()) + } + + /// Create an empty for when the peer chooses empty connection IDs. + /// This uses a special sequence number just because it can. + pub fn empty_remote() -> Self { + Self::new( + CONNECTION_ID_SEQNO_EMPTY, + ConnectionId::from(&[]), + Self::random_srt(), + ) + } + + fn token_equal(a: &[u8; 16], b: &[u8; 16]) -> bool { + // rustc might decide to optimize this and make this non-constant-time + // with respect to `t`, but it doesn't appear to currently. + let mut c = 0; + for (&a, &b) in a.iter().zip(b) { + c |= a ^ b; + } + c == 0 + } + + /// Determine whether this is a valid stateless reset. + pub fn is_stateless_reset(&self, token: &[u8; 16]) -> bool { + // A sequence number of 2^62 or more has no corresponding stateless reset token. + (self.seqno < (1 << 62)) && Self::token_equal(&self.srt, token) + } + + /// Return true if the two contain any equal parts. + fn any_part_equal(&self, other: &Self) -> bool { + self.seqno == other.seqno || self.cid == other.cid || self.srt == other.srt + } + + /// The sequence number of this entry. + pub fn sequence_number(&self) -> u64 { + self.seqno + } +} + +impl ConnectionIdEntry<()> { + /// Create an initial entry. + pub fn initial_local(cid: ConnectionId) -> Self { + Self::new(0, cid, ()) + } +} + +impl<SRT: Clone + PartialEq> ConnectionIdEntry<SRT> { + pub fn new(seqno: u64, cid: ConnectionId, srt: SRT) -> Self { + Self { seqno, cid, srt } + } + + /// Update the stateless reset token. This panics if the sequence number is non-zero. 
+ pub fn set_stateless_reset_token(&mut self, srt: SRT) { + assert_eq!(self.seqno, CONNECTION_ID_SEQNO_INITIAL); + self.srt = srt; + } + + /// Replace the connection ID. This panics if the sequence number is non-zero. + pub fn update_cid(&mut self, cid: ConnectionId) { + assert_eq!(self.seqno, CONNECTION_ID_SEQNO_INITIAL); + self.cid = cid; + } + + pub fn connection_id(&self) -> &ConnectionId { + &self.cid + } + + pub fn reset_token(&self) -> &SRT { + &self.srt + } +} + +pub type RemoteConnectionIdEntry = ConnectionIdEntry<[u8; 16]>; + +/// A collection of connection IDs that are indexed by a sequence number. +/// Used to store connection IDs that are provided by a peer. +#[derive(Debug, Default)] +pub struct ConnectionIdStore<SRT: Clone + PartialEq> { + cids: SmallVec<[ConnectionIdEntry<SRT>; 8]>, +} + +impl<SRT: Clone + PartialEq> ConnectionIdStore<SRT> { + pub fn retire(&mut self, seqno: u64) { + self.cids.retain(|c| c.seqno != seqno); + } + + pub fn contains(&self, cid: ConnectionIdRef) -> bool { + self.cids.iter().any(|c| c.cid == cid) + } + + pub fn next(&mut self) -> Option<ConnectionIdEntry<SRT>> { + if self.cids.is_empty() { + None + } else { + Some(self.cids.remove(0)) + } + } + + pub fn len(&self) -> usize { + self.cids.len() + } +} + +impl ConnectionIdStore<[u8; 16]> { + pub fn add_remote(&mut self, entry: ConnectionIdEntry<[u8; 16]>) -> Res<()> { + // It's OK if this perfectly matches an existing entry. + if self.cids.iter().any(|c| c == &entry) { + return Ok(()); + } + // It's not OK if any individual piece matches though. + if self.cids.iter().any(|c| c.any_part_equal(&entry)) { + qinfo!("ConnectionIdStore found reused part in NEW_CONNECTION_ID"); + return Err(Error::ProtocolViolation); + } + + // Insert in order so that we use them in order where possible. 
+ if let Err(idx) = self.cids.binary_search_by_key(&entry.seqno, |e| e.seqno) { + self.cids.insert(idx, entry); + Ok(()) + } else { + Err(Error::ProtocolViolation) + } + } + + // Retire connection IDs and return the sequence numbers of those that were retired. + pub fn retire_prior_to(&mut self, retire_prior: u64) -> Vec<u64> { + let mut retired = Vec::new(); + self.cids.retain(|e| { + if e.seqno < retire_prior { + retired.push(e.seqno); + false + } else { + true + } + }); + retired + } +} + +impl ConnectionIdStore<()> { + fn add_local(&mut self, entry: ConnectionIdEntry<()>) { + self.cids.push(entry); + } +} + +pub struct ConnectionIdDecoderRef<'a> { + generator: Ref<'a, dyn ConnectionIdGenerator>, +} + +// Ideally this would be an implementation of `Deref`, but it doesn't +// seem to be possible to convince the compiler to build anything useful. +impl<'a: 'b, 'b> ConnectionIdDecoderRef<'a> { + pub fn as_ref(&'a self) -> &'b dyn ConnectionIdDecoder { + self.generator.as_decoder() + } +} + +/// A connection ID manager looks after the generation of connection IDs, +/// the set of connection IDs that are valid for the connection, and the +/// generation of `NEW_CONNECTION_ID` frames. +pub struct ConnectionIdManager { + /// The `ConnectionIdGenerator` instance that is used to create connection IDs. + generator: Rc<RefCell<dyn ConnectionIdGenerator>>, + /// The connection IDs that we will accept. + /// This includes any we advertise in `NEW_CONNECTION_ID` that haven't been bound to a path + /// yet. During the handshake at the server, it also includes the randomized DCID pick by + /// the client. + connection_ids: ConnectionIdStore<()>, + /// The maximum number of connection IDs this will accept. This is at least 2 and won't + /// be more than `LOCAL_ACTIVE_CID_LIMIT`. + limit: usize, + /// The next sequence number that will be used for sending `NEW_CONNECTION_ID` frames. + next_seqno: u64, + /// Outstanding, but lost NEW_CONNECTION_ID frames will be stored here. 
+ lost_new_connection_id: Vec<ConnectionIdEntry<[u8; 16]>>, +} + +impl ConnectionIdManager { + pub fn new(generator: Rc<RefCell<dyn ConnectionIdGenerator>>, initial: ConnectionId) -> Self { + let mut connection_ids = ConnectionIdStore::default(); + connection_ids.add_local(ConnectionIdEntry::initial_local(initial)); + Self { + generator, + connection_ids, + // A note about initializing the limit to 2. + // For a server, the number of connection IDs that are tracked at the point that + // it is first possible to send `NEW_CONNECTION_ID` is 2. One is the client-generated + // destination connection (stored with a sequence number of `HANDSHAKE_SEQNO`); the + // other being the handshake value (seqno 0). As a result, `NEW_CONNECTION_ID` + // won't be sent until until after the handshake completes, because this initial + // value remains until the connection completes and transport parameters are handled. + limit: 2, + next_seqno: 1, + lost_new_connection_id: Vec::new(), + } + } + + pub fn generator(&self) -> Rc<RefCell<dyn ConnectionIdGenerator>> { + Rc::clone(&self.generator) + } + + pub fn decoder(&self) -> ConnectionIdDecoderRef { + ConnectionIdDecoderRef { + generator: self.generator.deref().borrow(), + } + } + + /// Generate a connection ID and stateless reset token for a preferred address. 
+ pub fn preferred_address_cid(&mut self) -> Res<(ConnectionId, [u8; 16])> { + if self.generator.deref().borrow().generates_empty_cids() { + return Err(Error::ConnectionIdsExhausted); + } + if let Some(cid) = self.generator.borrow_mut().generate_cid() { + assert_ne!(cid.len(), 0); + debug_assert_eq!(self.next_seqno, CONNECTION_ID_SEQNO_PREFERRED); + self.connection_ids + .add_local(ConnectionIdEntry::new(self.next_seqno, cid.clone(), ())); + self.next_seqno += 1; + + let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap(); + Ok((cid, srt)) + } else { + Err(Error::ConnectionIdsExhausted) + } + } + + pub fn is_valid(&self, cid: ConnectionIdRef) -> bool { + self.connection_ids.contains(cid) + } + + pub fn retire(&mut self, seqno: u64) { + // TODO(mt) - consider keeping connection IDs around for a short while. + + self.connection_ids.retire(seqno); + self.lost_new_connection_id.retain(|cid| cid.seqno != seqno); + } + + /// During the handshake, a server needs to regard the client's choice of destination + /// connection ID as valid. This function saves it in the store in a special place. + /// Note that this is only done *after* an Initial packet from the client is + /// successfully processed. + pub fn add_odcid(&mut self, cid: ConnectionId) { + let entry = ConnectionIdEntry::new(CONNECTION_ID_SEQNO_ODCID, cid, ()); + self.connection_ids.add_local(entry); + } + + /// Stop treating the original destination connection ID as valid. 
+ pub fn remove_odcid(&mut self) { + self.connection_ids.retire(CONNECTION_ID_SEQNO_ODCID); + } + + pub fn set_limit(&mut self, limit: u64) { + debug_assert!(limit >= 2); + self.limit = min( + LOCAL_ACTIVE_CID_LIMIT, + usize::try_from(limit).unwrap_or(LOCAL_ACTIVE_CID_LIMIT), + ); + } + + fn write_entry( + &mut self, + entry: &ConnectionIdEntry<[u8; 16]>, + builder: &mut PacketBuilder, + stats: &mut FrameStats, + ) -> Res<bool> { + let len = 1 + Encoder::varint_len(entry.seqno) + 1 + 1 + entry.cid.len() + 16; + if builder.remaining() < len { + return Ok(false); + } + + builder.encode_varint(FRAME_TYPE_NEW_CONNECTION_ID); + builder.encode_varint(entry.seqno); + builder.encode_varint(0u64); + builder.encode_vec(1, &entry.cid); + builder.encode(&entry.srt); + stats.new_connection_id += 1; + Ok(true) + } + + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> Res<()> { + if self.generator.deref().borrow().generates_empty_cids() { + debug_assert_eq!(self.generator.borrow_mut().generate_cid().unwrap().len(), 0); + return Ok(()); + } + + while let Some(entry) = self.lost_new_connection_id.pop() { + if self.write_entry(&entry, builder, stats)? { + tokens.push(RecoveryToken::NewConnectionId(entry)); + } else { + // This shouldn't happen often. + self.lost_new_connection_id.push(entry); + break; + } + } + + // Keep writing while we have fewer than the limit of active connection IDs + // and while there is room for more. This uses the longest connection ID + // length to simplify (assuming Retire Prior To is just 1 byte). + while self.connection_ids.len() < self.limit && builder.remaining() >= 47 { + let maybe_cid = self.generator.borrow_mut().generate_cid(); + if let Some(cid) = maybe_cid { + assert_ne!(cid.len(), 0); + // TODO: generate the stateless reset tokens from the connection ID and a key. 
+ let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap(); + + let seqno = self.next_seqno; + self.next_seqno += 1; + self.connection_ids + .add_local(ConnectionIdEntry::new(seqno, cid.clone(), ())); + + let entry = ConnectionIdEntry::new(seqno, cid, srt); + debug_assert!(self.write_entry(&entry, builder, stats)?); + tokens.push(RecoveryToken::NewConnectionId(entry)); + } + } + Ok(()) + } + + pub fn lost(&mut self, entry: &ConnectionIdEntry<[u8; 16]>) { + self.lost_new_connection_id.push(entry.clone()); + } + + pub fn acked(&mut self, entry: &ConnectionIdEntry<[u8; 16]>) { + self.lost_new_connection_id + .retain(|e| e.seqno != entry.seqno); + } +} + +#[cfg(test)] +mod tests { + use test_fixture::fixture_init; + + use super::*; + + #[test] + fn generate_initial_cid() { + fixture_init(); + for _ in 0..100 { + let cid = ConnectionId::generate_initial(); + if !matches!(cid.len(), 8..=MAX_CONNECTION_ID_LEN) { + panic!("connection ID {:?}", cid); + } + } + } +} diff --git a/third_party/rust/neqo-transport/src/connection/dump.rs b/third_party/rust/neqo-transport/src/connection/dump.rs new file mode 100644 index 0000000000..77d51c605c --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/dump.rs @@ -0,0 +1,46 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Enable just this file for logging to just see packets. +// e.g. "RUST_LOG=neqo_transport::dump neqo-client ..." 
+ +use std::fmt::Write; + +use neqo_common::{qdebug, Decoder}; + +use crate::{ + connection::Connection, + frame::Frame, + packet::{PacketNumber, PacketType}, + path::PathRef, +}; + +#[allow(clippy::module_name_repetitions)] +pub fn dump_packet( + conn: &Connection, + path: &PathRef, + dir: &str, + pt: PacketType, + pn: PacketNumber, + payload: &[u8], +) { + if !log::log_enabled!(log::Level::Debug) { + return; + } + + let mut s = String::from(""); + let mut d = Decoder::from(payload); + while d.remaining() > 0 { + let Ok(f) = Frame::decode(&mut d) else { + s.push_str(" [broken]..."); + break; + }; + if let Some(x) = f.dump() { + write!(&mut s, "\n {} {}", dir, &x).unwrap(); + } + } + qdebug!([conn], "pn={} type={:?} {}{}", pn, pt, path.borrow(), s); +} diff --git a/third_party/rust/neqo-transport/src/connection/idle.rs b/third_party/rust/neqo-transport/src/connection/idle.rs new file mode 100644 index 0000000000..e33f3defb3 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/idle.rs @@ -0,0 +1,120 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cmp::{max, min}, + time::{Duration, Instant}, +}; + +use neqo_common::qtrace; + +use crate::recovery::RecoveryToken; + +#[derive(Debug, Clone)] +/// There's a little bit of different behavior for resetting idle timeout. See +/// -transport 10.2 ("Idle Timeout"). +enum IdleTimeoutState { + Init, + PacketReceived(Instant), + AckElicitingPacketSent(Instant), +} + +#[derive(Debug, Clone)] +/// There's a little bit of different behavior for resetting idle timeout. See +/// -transport 10.2 ("Idle Timeout"). 
+pub struct IdleTimeout { + timeout: Duration, + state: IdleTimeoutState, + keep_alive_outstanding: bool, +} + +impl IdleTimeout { + pub fn new(timeout: Duration) -> Self { + Self { + timeout, + state: IdleTimeoutState::Init, + keep_alive_outstanding: false, + } + } +} + +impl IdleTimeout { + pub fn set_peer_timeout(&mut self, peer_timeout: Duration) { + self.timeout = min(self.timeout, peer_timeout); + } + + pub fn expiry(&self, now: Instant, pto: Duration, keep_alive: bool) -> Instant { + let start = match self.state { + IdleTimeoutState::Init => now, + IdleTimeoutState::PacketReceived(t) | IdleTimeoutState::AckElicitingPacketSent(t) => t, + }; + let delay = if keep_alive && !self.keep_alive_outstanding { + // For a keep-alive timer, wait for half the timeout interval, but be sure + // not to wait too little or we will send many unnecessary probes. + max(self.timeout / 2, pto) + } else { + max(self.timeout, pto * 3) + }; + qtrace!( + "IdleTimeout::expiry@{now:?} pto={pto:?}, ka={keep_alive} => {t:?}", + t = start + delay + ); + start + delay + } + + pub fn on_packet_sent(&mut self, now: Instant) { + // Only reset idle timeout if we've received a packet since the last + // time we reset the timeout here. + match self.state { + IdleTimeoutState::AckElicitingPacketSent(_) => {} + IdleTimeoutState::Init | IdleTimeoutState::PacketReceived(_) => { + self.state = IdleTimeoutState::AckElicitingPacketSent(now); + } + } + } + + pub fn on_packet_received(&mut self, now: Instant) { + // Only update if this doesn't rewind the idle timeout. + // We sometimes process packets after caching them, which uses + // the time the packet was received. That could be in the past. 
+ let update = match self.state { + IdleTimeoutState::Init => true, + IdleTimeoutState::AckElicitingPacketSent(t) | IdleTimeoutState::PacketReceived(t) => { + t <= now + } + }; + if update { + self.state = IdleTimeoutState::PacketReceived(now); + } + } + + pub fn expired(&self, now: Instant, pto: Duration) -> bool { + now >= self.expiry(now, pto, false) + } + + pub fn send_keep_alive( + &mut self, + now: Instant, + pto: Duration, + tokens: &mut Vec<RecoveryToken>, + ) -> bool { + if !self.keep_alive_outstanding && now >= self.expiry(now, pto, true) { + self.keep_alive_outstanding = true; + tokens.push(RecoveryToken::KeepAlive); + true + } else { + false + } + } + + pub fn lost_keep_alive(&mut self) { + self.keep_alive_outstanding = false; + } + + pub fn ack_keep_alive(&mut self) { + self.keep_alive_outstanding = false; + } +} diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs new file mode 100644 index 0000000000..2de388418a --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/mod.rs @@ -0,0 +1,3241 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// The class implementing a QUIC connection. 
+ +use std::{ + cell::RefCell, + cmp::{max, min}, + convert::TryFrom, + fmt::{self, Debug}, + mem, + net::{IpAddr, SocketAddr}, + ops::RangeInclusive, + rc::{Rc, Weak}, + time::{Duration, Instant}, +}; + +use neqo_common::{ + event::Provider as EventProvider, hex, hex_snip_middle, hrtime, qdebug, qerror, qinfo, + qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role, +}; +use neqo_crypto::{ + agent::CertificateInfo, random, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group, + HandshakeState, PrivateKey, PublicKey, ResumptionToken, SecretAgentInfo, SecretAgentPreInfo, + Server, ZeroRttChecker, +}; +use smallvec::SmallVec; + +use crate::{ + addr_valid::{AddressValidation, NewTokenState}, + cid::{ + ConnectionId, ConnectionIdEntry, ConnectionIdGenerator, ConnectionIdManager, + ConnectionIdRef, ConnectionIdStore, LOCAL_ACTIVE_CID_LIMIT, + }, + crypto::{Crypto, CryptoDxState, CryptoSpace}, + events::{ConnectionEvent, ConnectionEvents, OutgoingDatagramOutcome}, + frame::{ + CloseError, Frame, FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT, + }, + packet::{DecryptedPacket, PacketBuilder, PacketNumber, PacketType, PublicPacket}, + path::{Path, PathRef, Paths}, + qlog, + quic_datagrams::{DatagramTracking, QuicDatagrams}, + recovery::{LossRecovery, RecoveryToken, SendProfile}, + recv_stream::RecvStreamStats, + rtt::GRANULARITY, + stats::{Stats, StatsCell}, + stream_id::StreamType, + streams::{SendOrder, Streams}, + tparams::{ + self, TransportParameter, TransportParameterId, TransportParameters, + TransportParametersHandler, + }, + tracking::{AckTracker, PacketNumberSpace, SentPacket}, + version::{Version, WireVersion}, + AppError, ConnectionError, Error, Res, StreamId, +}; +mod dump; +mod idle; +pub mod params; +mod saved; +mod state; +#[cfg(test)] +pub mod test_internal; +use dump::dump_packet; +use idle::IdleTimeout; +pub use params::ConnectionParameters; +use params::PreferredAddressConfig; +#[cfg(test)] 
+pub use params::ACK_RATIO_SCALE; +use saved::SavedDatagrams; +use state::StateSignaling; +pub use state::{ClosingFrame, State}; + +pub use crate::send_stream::{RetransmissionPriority, SendStreamStats, TransmissionPriority}; + +#[derive(Debug, Default)] +struct Packet(Vec<u8>); + +/// The number of Initial packets that the client will send in response +/// to receiving an undecryptable packet during the early part of the +/// handshake. This is a hack, but a useful one. +const EXTRA_INITIALS: usize = 4; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ZeroRttState { + Init, + Sending, + AcceptedClient, + AcceptedServer, + Rejected, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +/// Type returned from process() and `process_output()`. Users are required to +/// call these repeatedly until `Callback` or `None` is returned. +pub enum Output { + /// Connection requires no action. + None, + /// Connection requires the datagram be sent. + Datagram(Datagram), + /// Connection requires `process_input()` be called when the `Duration` + /// elapses. + Callback(Duration), +} + +impl Output { + /// Convert into an `Option<Datagram>`. + #[must_use] + pub fn dgram(self) -> Option<Datagram> { + match self { + Self::Datagram(dg) => Some(dg), + _ => None, + } + } + + /// Get a reference to the Datagram, if any. + pub fn as_dgram_ref(&self) -> Option<&Datagram> { + match self { + Self::Datagram(dg) => Some(dg), + _ => None, + } + } + + /// Ask how long the caller should wait before calling back. + #[must_use] + pub fn callback(&self) -> Duration { + match self { + Self::Callback(t) => *t, + _ => Duration::new(0, 0), + } + } +} + +/// Used by inner functions like Connection::output. +enum SendOption { + /// Yes, please send this datagram. + Yes(Datagram), + /// Don't send. If this was blocked on the pacer (the arg is true). 
+ No(bool), +} + +impl Default for SendOption { + fn default() -> Self { + Self::No(false) + } +} + +/// Used by `Connection::preprocess` to determine what to do +/// with an packet before attempting to remove protection. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PreprocessResult { + /// End processing and return successfully. + End, + /// Stop processing this datagram and move on to the next. + Next, + /// Continue and process this packet. + Continue, +} + +/// `AddressValidationInfo` holds information relevant to either +/// responding to address validation (`NewToken`, `Retry`) or generating +/// tokens for address validation (`Server`). +enum AddressValidationInfo { + None, + // We are a client and have information from `NEW_TOKEN`. + NewToken(Vec<u8>), + // We are a client and have received a `Retry` packet. + Retry { + token: Vec<u8>, + retry_source_cid: ConnectionId, + }, + // We are a server and can generate tokens. + Server(Weak<RefCell<AddressValidation>>), +} + +impl AddressValidationInfo { + pub fn token(&self) -> &[u8] { + match self { + Self::NewToken(token) | Self::Retry { token, .. } => token, + _ => &[], + } + } + + pub fn generate_new_token( + &mut self, + peer_address: SocketAddr, + now: Instant, + ) -> Option<Vec<u8>> { + match self { + Self::Server(ref w) => { + if let Some(validation) = w.upgrade() { + validation + .borrow() + .generate_new_token(peer_address, now) + .ok() + } else { + None + } + } + Self::None => None, + _ => unreachable!("called a server function on a client"), + } + } +} + +/// A QUIC Connection +/// +/// First, create a new connection using `new_client()` or `new_server()`. +/// +/// For the life of the connection, handle activity in the following manner: +/// 1. Perform operations using the `stream_*()` methods. +/// 1. Call `process_input()` when a datagram is received or the timer +/// expires. Obtain information on connection state changes by checking +/// `events()`. +/// 1. 
Having completed handling current activity, repeatedly call +/// `process_output()` for packets to send, until it returns `Output::Callback` +/// or `Output::None`. +/// +/// After the connection is closed (either by calling `close()` or by the +/// remote) continue processing until `state()` returns `Closed`. +pub struct Connection { + role: Role, + version: Version, + state: State, + tps: Rc<RefCell<TransportParametersHandler>>, + /// What we are doing with 0-RTT. + zero_rtt_state: ZeroRttState, + /// All of the network paths that we are aware of. + paths: Paths, + /// This object will generate connection IDs for the connection. + cid_manager: ConnectionIdManager, + address_validation: AddressValidationInfo, + /// The connection IDs that were provided by the peer. + connection_ids: ConnectionIdStore<[u8; 16]>, + + /// The source connection ID that this endpoint uses for the handshake. + /// Since we need to communicate this to our peer in tparams, setting this + /// value is part of constructing the struct. + local_initial_source_cid: ConnectionId, + /// The source connection ID from the first packet from the other end. + /// This is checked against the peer's transport parameters. + remote_initial_source_cid: Option<ConnectionId>, + /// The destination connection ID from the first packet from the client. + /// This is checked by the client against the server's transport parameters. + original_destination_cid: Option<ConnectionId>, + + /// We sometimes save a datagram against the possibility that keys will later + /// become available. This avoids reporting packets as dropped during the handshake + /// when they are either just reordered or we haven't been able to install keys yet. + /// In particular, this occurs when asynchronous certificate validation happens. + saved_datagrams: SavedDatagrams, + /// Some packets were received, but not tracked. 
+ received_untracked: bool, + + /// This is responsible for the QuicDatagrams' handling: + /// <https://datatracker.ietf.org/doc/html/draft-ietf-quic-datagram> + quic_datagrams: QuicDatagrams, + + pub(crate) crypto: Crypto, + pub(crate) acks: AckTracker, + idle_timeout: IdleTimeout, + streams: Streams, + state_signaling: StateSignaling, + loss_recovery: LossRecovery, + events: ConnectionEvents, + new_token: NewTokenState, + stats: StatsCell, + qlog: NeqoQlog, + /// A session ticket was received without NEW_TOKEN, + /// this is when that turns into an event without NEW_TOKEN. + release_resumption_token_timer: Option<Instant>, + conn_params: ConnectionParameters, + hrtime: hrtime::Handle, + + /// For testing purposes it is sometimes necessary to inject frames that wouldn't + /// otherwise be sent, just to see how a connection handles them. Inserting them + /// into packets proper mean that the frames follow the entire processing path. + #[cfg(test)] + pub test_frame_writer: Option<Box<dyn test_internal::FrameWriter>>, +} + +impl Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{:?} Connection: {:?} {:?}", + self.role, + self.state, + self.paths.primary_fallible() + ) + } +} + +impl Connection { + /// A long default for timer resolution, so that we don't tax the + /// system too hard when we don't need to. + const LOOSE_TIMER_RESOLUTION: Duration = Duration::from_millis(50); + + /// Create a new QUIC connection with Client role. 
+ pub fn new_client( + server_name: impl Into<String>, + protocols: &[impl AsRef<str>], + cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>, + local_addr: SocketAddr, + remote_addr: SocketAddr, + conn_params: ConnectionParameters, + now: Instant, + ) -> Res<Self> { + let dcid = ConnectionId::generate_initial(); + let mut c = Self::new( + Role::Client, + Agent::from(Client::new(server_name.into(), conn_params.is_greasing())?), + cid_generator, + protocols, + conn_params, + )?; + c.crypto.states.init( + c.conn_params.get_versions().compatible(), + Role::Client, + &dcid, + ); + c.original_destination_cid = Some(dcid); + let path = Path::temporary( + local_addr, + remote_addr, + c.conn_params.get_cc_algorithm(), + c.conn_params.pacing_enabled(), + NeqoQlog::default(), + now, + ); + c.setup_handshake_path(&Rc::new(RefCell::new(path)), now); + Ok(c) + } + + /// Create a new QUIC connection with Server role. + pub fn new_server( + certs: &[impl AsRef<str>], + protocols: &[impl AsRef<str>], + cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>, + conn_params: ConnectionParameters, + ) -> Res<Self> { + Self::new( + Role::Server, + Agent::from(Server::new(certs)?), + cid_generator, + protocols, + conn_params, + ) + } + + fn new<P: AsRef<str>>( + role: Role, + agent: Agent, + cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>, + protocols: &[P], + conn_params: ConnectionParameters, + ) -> Res<Self> { + // Setup the local connection ID. 
+ let local_initial_source_cid = cid_generator + .borrow_mut() + .generate_cid() + .ok_or(Error::ConnectionIdsExhausted)?; + let mut cid_manager = + ConnectionIdManager::new(cid_generator, local_initial_source_cid.clone()); + let mut tps = conn_params.create_transport_parameter(role, &mut cid_manager)?; + tps.local.set_bytes( + tparams::INITIAL_SOURCE_CONNECTION_ID, + local_initial_source_cid.to_vec(), + ); + + let tphandler = Rc::new(RefCell::new(tps)); + let crypto = Crypto::new( + conn_params.get_versions().initial(), + agent, + protocols.iter().map(P::as_ref).map(String::from).collect(), + Rc::clone(&tphandler), + conn_params.is_fuzzing(), + )?; + + let stats = StatsCell::default(); + let events = ConnectionEvents::default(); + let quic_datagrams = QuicDatagrams::new( + conn_params.get_datagram_size(), + conn_params.get_outgoing_datagram_queue(), + conn_params.get_incoming_datagram_queue(), + events.clone(), + ); + + let c = Self { + role, + version: conn_params.get_versions().initial(), + state: State::Init, + paths: Paths::default(), + cid_manager, + tps: tphandler.clone(), + zero_rtt_state: ZeroRttState::Init, + address_validation: AddressValidationInfo::None, + local_initial_source_cid, + remote_initial_source_cid: None, + original_destination_cid: None, + saved_datagrams: SavedDatagrams::default(), + received_untracked: false, + crypto, + acks: AckTracker::default(), + idle_timeout: IdleTimeout::new(conn_params.get_idle_timeout()), + streams: Streams::new(tphandler, role, events.clone()), + connection_ids: ConnectionIdStore::default(), + state_signaling: StateSignaling::Idle, + loss_recovery: LossRecovery::new(stats.clone(), conn_params.get_fast_pto()), + events, + new_token: NewTokenState::new(role), + stats, + qlog: NeqoQlog::disabled(), + release_resumption_token_timer: None, + conn_params, + hrtime: hrtime::Time::get(Self::LOOSE_TIMER_RESOLUTION), + quic_datagrams, + #[cfg(test)] + test_frame_writer: None, + }; + 
c.stats.borrow_mut().init(format!("{c}")); + Ok(c) + } + + pub fn server_enable_0rtt( + &mut self, + anti_replay: &AntiReplay, + zero_rtt_checker: impl ZeroRttChecker + 'static, + ) -> Res<()> { + self.crypto + .server_enable_0rtt(self.tps.clone(), anti_replay, zero_rtt_checker) + } + + pub fn server_enable_ech( + &mut self, + config: u8, + public_name: &str, + sk: &PrivateKey, + pk: &PublicKey, + ) -> Res<()> { + self.crypto.server_enable_ech(config, public_name, sk, pk) + } + + /// Get the active ECH configuration, which is empty if ECH is disabled. + pub fn ech_config(&self) -> &[u8] { + self.crypto.ech_config() + } + + pub fn client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> { + self.crypto.client_enable_ech(ech_config_list) + } + + /// Set or clear the qlog for this connection. + pub fn set_qlog(&mut self, qlog: NeqoQlog) { + self.loss_recovery.set_qlog(qlog.clone()); + self.paths.set_qlog(qlog.clone()); + self.qlog = qlog; + } + + /// Get the qlog (if any) for this connection. + pub fn qlog_mut(&mut self) -> &mut NeqoQlog { + &mut self.qlog + } + + /// Get the original destination connection id for this connection. This + /// will always be present for Role::Client but not if Role::Server is in + /// State::Init. + pub fn odcid(&self) -> Option<&ConnectionId> { + self.original_destination_cid.as_ref() + } + + /// Set a local transport parameter, possibly overriding a default value. + /// This only sets transport parameters without dealing with other aspects of + /// setting the value. + /// + /// # Panics + /// + /// This panics if the transport parameter is known to this crate. 
+ pub fn set_local_tparam(&self, tp: TransportParameterId, value: TransportParameter) -> Res<()> { + #[cfg(not(test))] + { + assert!(!tparams::INTERNAL_TRANSPORT_PARAMETERS.contains(&tp)); + } + if *self.state() == State::Init { + self.tps.borrow_mut().local.set(tp, value); + Ok(()) + } else { + qerror!("Current state: {:?}", self.state()); + qerror!("Cannot set local tparam when not in an initial connection state."); + Err(Error::ConnectionState) + } + } + + /// `odcid` is their original choice for our CID, which we get from the Retry token. + /// `remote_cid` is the value from the Source Connection ID field of an incoming packet: what + /// the peer wants us to use now. `retry_cid` is what we asked them to use when we sent the + /// Retry. + pub(crate) fn set_retry_cids( + &mut self, + odcid: ConnectionId, + remote_cid: ConnectionId, + retry_cid: ConnectionId, + ) { + debug_assert_eq!(self.role, Role::Server); + qtrace!( + [self], + "Retry CIDs: odcid={} remote={} retry={}", + odcid, + remote_cid, + retry_cid + ); + // We advertise "our" choices in transport parameters. + let local_tps = &mut self.tps.borrow_mut().local; + local_tps.set_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID, odcid.to_vec()); + local_tps.set_bytes(tparams::RETRY_SOURCE_CONNECTION_ID, retry_cid.to_vec()); + + // ...and save their choices for later validation. + self.remote_initial_source_cid = Some(remote_cid); + } + + fn retry_sent(&self) -> bool { + self.tps + .borrow() + .local + .get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID) + .is_some() + } + + /// Set ALPN preferences. Strings that appear earlier in the list are given + /// higher preference. + pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> { + self.crypto.tls.set_alpn(protocols)?; + Ok(()) + } + + /// Enable a set of ciphers. 
+ pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> { + if self.state != State::Init { + qerror!([self], "Cannot enable ciphers in state {:?}", self.state); + return Err(Error::ConnectionState); + } + self.crypto.tls.set_ciphers(ciphers)?; + Ok(()) + } + + /// Enable a set of key exchange groups. + pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> { + if self.state != State::Init { + qerror!([self], "Cannot enable groups in state {:?}", self.state); + return Err(Error::ConnectionState); + } + self.crypto.tls.set_groups(groups)?; + Ok(()) + } + + /// Set the number of additional key shares to send in the client hello. + pub fn send_additional_key_shares(&mut self, count: usize) -> Res<()> { + if self.state != State::Init { + qerror!([self], "Cannot enable groups in state {:?}", self.state); + return Err(Error::ConnectionState); + } + self.crypto.tls.send_additional_key_shares(count)?; + Ok(()) + } + + fn make_resumption_token(&mut self) -> ResumptionToken { + debug_assert_eq!(self.role, Role::Client); + debug_assert!(self.crypto.has_resumption_token()); + let rtt = self.paths.primary().borrow().rtt().estimate(); + self.crypto + .create_resumption_token( + self.new_token.take_token(), + self.tps + .borrow() + .remote + .as_ref() + .expect("should have transport parameters"), + self.version, + u64::try_from(rtt.as_millis()).unwrap_or(0), + ) + .unwrap() + } + + /// Get the simplest PTO calculation for all those cases where we need + /// a value of this approximate order. Don't use this for loss recovery, + /// only use it where a more precise value is not important. 
+ fn pto(&self) -> Duration { + self.paths + .primary() + .borrow() + .rtt() + .pto(PacketNumberSpace::ApplicationData) + } + + fn create_resumption_token(&mut self, now: Instant) { + if self.role == Role::Server || self.state < State::Connected { + return; + } + + qtrace!( + [self], + "Maybe create resumption token: {} {}", + self.crypto.has_resumption_token(), + self.new_token.has_token() + ); + + while self.crypto.has_resumption_token() && self.new_token.has_token() { + let token = self.make_resumption_token(); + self.events.client_resumption_token(token); + } + + // If we have a resumption ticket check or set a timer. + if self.crypto.has_resumption_token() { + let arm = if let Some(expiration_time) = self.release_resumption_token_timer { + if expiration_time <= now { + let token = self.make_resumption_token(); + self.events.client_resumption_token(token); + self.release_resumption_token_timer = None; + + // This means that we release one session ticket every 3 PTOs + // if no NEW_TOKEN frame is received. + self.crypto.has_resumption_token() + } else { + false + } + } else { + true + }; + + if arm { + self.release_resumption_token_timer = Some(now + 3 * self.pto()); + } + } + } + + /// The correct way to obtain a resumption token is to wait for the + /// `ConnectionEvent::ResumptionToken` event. To emit the event we are waiting for a + /// resumption token and a `NEW_TOKEN` frame to arrive. Some servers don't send `NEW_TOKEN` + /// frames and in this case, we wait for 3xPTO before emitting an event. This is especially a + /// problem for short-lived connections, where the connection is closed before any events are + /// released. This function retrieves the token, without waiting for a `NEW_TOKEN` frame to + /// arrive. + /// + /// # Panics + /// + /// If this is called on a server. 
+ pub fn take_resumption_token(&mut self, now: Instant) -> Option<ResumptionToken> { + assert_eq!(self.role, Role::Client); + + if self.crypto.has_resumption_token() { + let token = self.make_resumption_token(); + if self.crypto.has_resumption_token() { + self.release_resumption_token_timer = Some(now + 3 * self.pto()); + } + Some(token) + } else { + None + } + } + + /// Enable resumption, using a token previously provided. + /// This can only be called once and only on the client. + /// After calling the function, it should be possible to attempt 0-RTT + /// if the token supports that. + pub fn enable_resumption(&mut self, now: Instant, token: impl AsRef<[u8]>) -> Res<()> { + if self.state != State::Init { + qerror!([self], "set token in state {:?}", self.state); + return Err(Error::ConnectionState); + } + if self.role == Role::Server { + return Err(Error::ConnectionState); + } + + qinfo!( + [self], + "resumption token {}", + hex_snip_middle(token.as_ref()) + ); + let mut dec = Decoder::from(token.as_ref()); + + let version = + Version::try_from(dec.decode_uint(4).ok_or(Error::InvalidResumptionToken)? 
as u32)?; + qtrace!([self], " version {:?}", version); + if !self.conn_params.get_versions().all().contains(&version) { + return Err(Error::DisabledVersion); + } + + let rtt = Duration::from_millis(dec.decode_varint().ok_or(Error::InvalidResumptionToken)?); + qtrace!([self], " RTT {:?}", rtt); + + let tp_slice = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?; + qtrace!([self], " transport parameters {}", hex(tp_slice)); + let mut dec_tp = Decoder::from(tp_slice); + let tp = + TransportParameters::decode(&mut dec_tp).map_err(|_| Error::InvalidResumptionToken)?; + + let init_token = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?; + qtrace!([self], " Initial token {}", hex(init_token)); + + let tok = dec.decode_remainder(); + qtrace!([self], " TLS token {}", hex(tok)); + + match self.crypto.tls { + Agent::Client(ref mut c) => { + let res = c.enable_resumption(tok); + if let Err(e) = res { + self.absorb_error::<Error>(now, Err(Error::from(e))); + return Ok(()); + } + } + Agent::Server(_) => return Err(Error::WrongRole), + } + + self.version = version; + self.conn_params.get_versions_mut().set_initial(version); + self.tps.borrow_mut().set_version(version); + self.tps.borrow_mut().remote_0rtt = Some(tp); + if !init_token.is_empty() { + self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec()); + } + self.paths.primary().borrow_mut().rtt_mut().set_initial(rtt); + self.set_initial_limits(); + // Start up TLS, which has the effect of setting up all the necessary + // state for 0-RTT. This only stages the CRYPTO frames. + let res = self.client_start(now); + self.absorb_error(now, res); + Ok(()) + } + + pub(crate) fn set_validation(&mut self, validation: Rc<RefCell<AddressValidation>>) { + qtrace!([self], "Enabling NEW_TOKEN"); + assert_eq!(self.role, Role::Server); + self.address_validation = AddressValidationInfo::Server(Rc::downgrade(&validation)); + } + + /// Send a TLS session ticket AND a NEW_TOKEN frame (if possible). 
+ pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<()> { + if self.role == Role::Client { + return Err(Error::WrongRole); + } + + let tps = &self.tps; + if let Agent::Server(ref mut s) = self.crypto.tls { + let mut enc = Encoder::default(); + enc.encode_vvec_with(|enc_inner| { + tps.borrow().local.encode(enc_inner); + }); + enc.encode(extra); + let records = s.send_ticket(now, enc.as_ref())?; + qinfo!([self], "send session ticket {}", hex(&enc)); + self.crypto.buffer_records(records)?; + } else { + unreachable!(); + } + + // If we are able, also send a NEW_TOKEN frame. + // This should be recording all remote addresses that are valid, + // but there are just 0 or 1 in the current implementation. + if let Some(path) = self.paths.primary_fallible() { + if let Some(token) = self + .address_validation + .generate_new_token(path.borrow().remote_address(), now) + { + self.new_token.send_new_token(token); + } + Ok(()) + } else { + Err(Error::NotConnected) + } + } + + pub fn tls_info(&self) -> Option<&SecretAgentInfo> { + self.crypto.tls.info() + } + + pub fn tls_preinfo(&self) -> Res<SecretAgentPreInfo> { + Ok(self.crypto.tls.preinfo()?) + } + + /// Get the peer's certificate chain and other info. + pub fn peer_certificate(&self) -> Option<CertificateInfo> { + self.crypto.tls.peer_certificate() + } + + /// Call by application when the peer cert has been verified. + /// + /// This panics if there is no active peer. It's OK to call this + /// when authentication isn't needed, that will likely only cause + /// the connection to fail. However, if no packets have been + /// exchanged, it's not OK. + pub fn authenticated(&mut self, status: AuthenticationStatus, now: Instant) { + qinfo!([self], "Authenticated {:?}", status); + self.crypto.tls.authenticated(status); + let res = self.handshake(now, self.version, PacketNumberSpace::Handshake, None); + self.absorb_error(now, res); + self.process_saved(now); + } + + /// Get the role of the connection. 
+ pub fn role(&self) -> Role { + self.role + } + + /// Get the state of the connection. + pub fn state(&self) -> &State { + &self.state + } + + /// The QUIC version in use. + pub fn version(&self) -> Version { + self.version + } + + /// Get the 0-RTT state of the connection. + pub fn zero_rtt_state(&self) -> ZeroRttState { + self.zero_rtt_state + } + + /// Get a snapshot of collected statistics. + pub fn stats(&self) -> Stats { + let mut v = self.stats.borrow().clone(); + if let Some(p) = self.paths.primary_fallible() { + let p = p.borrow(); + v.rtt = p.rtt().estimate(); + v.rttvar = p.rtt().rttvar(); + } + v + } + + // This function wraps a call to another function and sets the connection state + // properly if that call fails. + fn capture_error<T>( + &mut self, + path: Option<PathRef>, + now: Instant, + frame_type: FrameType, + res: Res<T>, + ) -> Res<T> { + if let Err(v) = &res { + #[cfg(debug_assertions)] + let msg = format!("{v:?}"); + #[cfg(not(debug_assertions))] + let msg = ""; + let error = ConnectionError::Transport(v.clone()); + match &self.state { + State::Closing { error: err, .. } + | State::Draining { error: err, .. } + | State::Closed(err) => { + qwarn!([self], "Closing again after error {:?}", err); + } + State::Init => { + // We have not even sent anything just close the connection without sending any + // error. This may happen when client_start fails. + self.set_state(State::Closed(error)); + } + State::WaitInitial => { + // We don't have any state yet, so don't bother with + // the closing state, just send one CONNECTION_CLOSE. 
+ if let Some(path) = path.or_else(|| self.paths.primary_fallible()) { + self.state_signaling + .close(path, error.clone(), frame_type, msg); + } + self.set_state(State::Closed(error)); + } + _ => { + if let Some(path) = path.or_else(|| self.paths.primary_fallible()) { + self.state_signaling + .close(path, error.clone(), frame_type, msg); + if matches!(v, Error::KeysExhausted) { + self.set_state(State::Closed(error)); + } else { + self.set_state(State::Closing { + error, + timeout: self.get_closing_period_time(now), + }); + } + } else { + self.set_state(State::Closed(error)); + } + } + } + } + res + } + + /// For use with process_input(). Errors there can be ignored, but this + /// needs to ensure that the state is updated. + fn absorb_error<T>(&mut self, now: Instant, res: Res<T>) -> Option<T> { + self.capture_error(None, now, 0, res).ok() + } + + fn process_timer(&mut self, now: Instant) { + match &self.state { + // Only the client runs timers while waiting for Initial packets. + State::WaitInitial => debug_assert_eq!(self.role, Role::Client), + // If Closing or Draining, check if it is time to move to Closed. 
+ State::Closing { error, timeout } | State::Draining { error, timeout } => { + if *timeout <= now { + let st = State::Closed(error.clone()); + self.set_state(st); + qinfo!("Closing timer expired"); + return; + } + } + State::Closed(_) => { + qdebug!("Timer fired while closed"); + return; + } + _ => (), + } + + let pto = self.pto(); + if self.idle_timeout.expired(now, pto) { + qinfo!([self], "idle timeout expired"); + self.set_state(State::Closed(ConnectionError::Transport( + Error::IdleTimeout, + ))); + return; + } + + self.streams.cleanup_closed_streams(); + + let res = self.crypto.states.check_key_update(now); + self.absorb_error(now, res); + + let lost = self.loss_recovery.timeout(&self.paths.primary(), now); + self.handle_lost_packets(&lost); + qlog::packets_lost(&mut self.qlog, &lost); + + if self.release_resumption_token_timer.is_some() { + self.create_resumption_token(now); + } + + if !self.paths.process_timeout(now, pto) { + qinfo!([self], "last available path failed"); + self.absorb_error::<Error>(now, Err(Error::NoAvailablePath)); + } + } + + /// Process new input datagrams on the connection. + pub fn process_input(&mut self, d: &Datagram, now: Instant) { + self.input(d, now, now); + self.process_saved(now); + self.streams.cleanup_closed_streams(); + } + + /// Process new input datagrams on the connection. + pub fn process_multiple_input<'a, I>(&mut self, dgrams: I, now: Instant) + where + I: IntoIterator<Item = &'a Datagram>, + I::IntoIter: ExactSizeIterator, + { + let dgrams = dgrams.into_iter(); + if dgrams.len() == 0 { + return; + } + + for d in dgrams { + self.input(d, now, now); + } + self.process_saved(now); + self.streams.cleanup_closed_streams(); + } + + /// Get the time that we next need to be called back, relative to `now`. + fn next_delay(&mut self, now: Instant, paced: bool) -> Duration { + qtrace!([self], "Get callback delay {:?}", now); + + // Only one timer matters when closing... + if let State::Closing { timeout, .. 
} | State::Draining { timeout, .. } = self.state { + self.hrtime.update(Self::LOOSE_TIMER_RESOLUTION); + return timeout.duration_since(now); + } + + let mut delays = SmallVec::<[_; 6]>::new(); + if let Some(ack_time) = self.acks.ack_time(now) { + qtrace!([self], "Delayed ACK timer {:?}", ack_time); + delays.push(ack_time); + } + + if let Some(p) = self.paths.primary_fallible() { + let path = p.borrow(); + let rtt = path.rtt(); + let pto = rtt.pto(PacketNumberSpace::ApplicationData); + + let keep_alive = self.streams.need_keep_alive(); + let idle_time = self.idle_timeout.expiry(now, pto, keep_alive); + qtrace!([self], "Idle/keepalive timer {:?}", idle_time); + delays.push(idle_time); + + if let Some(lr_time) = self.loss_recovery.next_timeout(rtt) { + qtrace!([self], "Loss recovery timer {:?}", lr_time); + delays.push(lr_time); + } + + if paced { + if let Some(pace_time) = path.sender().next_paced(rtt.estimate()) { + qtrace!([self], "Pacing timer {:?}", pace_time); + delays.push(pace_time); + } + } + + if let Some(path_time) = self.paths.next_timeout(pto) { + qtrace!([self], "Path probe timer {:?}", path_time); + delays.push(path_time); + } + } + + if let Some(key_update_time) = self.crypto.states.update_time() { + qtrace!([self], "Key update timer {:?}", key_update_time); + delays.push(key_update_time); + } + + // `release_resumption_token_timer` is not considered here, because + // it is not important enough to force the application to set a + // timeout for it It is expected that other activities will + // drive it. + + let earliest = delays.into_iter().min().unwrap(); + // TODO(agrover, mt) - need to analyze and fix #47 + // rather than just clamping to zero here. + debug_assert!(earliest > now); + let delay = earliest.saturating_duration_since(now); + qdebug!([self], "delay duration {:?}", delay); + self.hrtime.update(delay / 4); + delay + } + + /// Get output packets, as a result of receiving packets, or actions taken + /// by the application. 
+ /// Returns datagrams to send, and how long to wait before calling again + /// even if no incoming packets. + #[must_use = "Output of the process_output function must be handled"] + pub fn process_output(&mut self, now: Instant) -> Output { + qtrace!([self], "process_output {:?} {:?}", self.state, now); + + match (&self.state, self.role) { + (State::Init, Role::Client) => { + let res = self.client_start(now); + self.absorb_error(now, res); + } + (State::Init | State::WaitInitial, Role::Server) => { + return Output::None; + } + _ => { + self.process_timer(now); + } + } + + match self.output(now) { + SendOption::Yes(dgram) => Output::Datagram(dgram), + SendOption::No(paced) => match self.state { + State::Init | State::Closed(_) => Output::None, + State::Closing { timeout, .. } | State::Draining { timeout, .. } => { + Output::Callback(timeout.duration_since(now)) + } + _ => Output::Callback(self.next_delay(now, paced)), + }, + } + } + + /// Process input and generate output. + #[must_use = "Output of the process function must be handled"] + pub fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output { + if let Some(d) = dgram { + self.input(d, now, now); + self.process_saved(now); + } + self.process_output(now) + } + + fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) { + qinfo!([self], "received Retry"); + if matches!(self.address_validation, AddressValidationInfo::Retry { .. }) { + self.stats.borrow_mut().pkt_dropped("Extra Retry"); + return; + } + if packet.token().is_empty() { + self.stats.borrow_mut().pkt_dropped("Retry without a token"); + return; + } + if !packet.is_valid_retry(self.original_destination_cid.as_ref().unwrap()) { + self.stats + .borrow_mut() + .pkt_dropped("Retry with bad integrity tag"); + return; + } + // At this point, we should only have the connection ID that we generated. + // Update to the one that the server prefers. 
+ let path = self.paths.primary(); + path.borrow_mut().set_remote_cid(packet.scid()); + + let retry_scid = ConnectionId::from(packet.scid()); + qinfo!( + [self], + "Valid Retry received, token={} scid={}", + hex(packet.token()), + retry_scid + ); + + let lost_packets = self.loss_recovery.retry(&path, now); + self.handle_lost_packets(&lost_packets); + + self.crypto.states.init( + self.conn_params.get_versions().compatible(), + self.role, + &retry_scid, + ); + self.address_validation = AddressValidationInfo::Retry { + token: packet.token().to_vec(), + retry_source_cid: retry_scid, + }; + } + + fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) { + if self.crypto.discard(space) { + qinfo!([self], "Drop packet number space {}", space); + let primary = self.paths.primary(); + self.loss_recovery.discard(&primary, space, now); + self.acks.drop_space(space); + } + } + + fn is_stateless_reset(&self, path: &PathRef, d: &Datagram) -> bool { + // If the datagram is too small, don't try. + // If the connection is connected, then the reset token will be invalid. + if d.len() < 16 || !self.state.connected() { + return false; + } + let token = <&[u8; 16]>::try_from(&d[d.len() - 16..]).unwrap(); + path.borrow().is_stateless_reset(token) + } + + fn check_stateless_reset( + &mut self, + path: &PathRef, + d: &Datagram, + first: bool, + now: Instant, + ) -> Res<()> { + if first && self.is_stateless_reset(path, d) { + // Failing to process a packet in a datagram might + // indicate that there is a stateless reset present. + qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..])); + self.state_signaling.reset(); + self.set_state(State::Draining { + error: ConnectionError::Transport(Error::StatelessReset), + timeout: self.get_closing_period_time(now), + }); + Err(Error::StatelessReset) + } else { + Ok(()) + } + } + + /// Process any saved datagrams that might be available for processing. 
+ fn process_saved(&mut self, now: Instant) { + while let Some(cspace) = self.saved_datagrams.available() { + qdebug!([self], "process saved for space {:?}", cspace); + debug_assert!(self.crypto.states.rx_hp(self.version, cspace).is_some()); + for saved in self.saved_datagrams.take_saved() { + qtrace!([self], "input saved @{:?}: {:?}", saved.t, saved.d); + self.input(&saved.d, saved.t, now); + } + } + } + + /// In case a datagram arrives that we can only partially process, save any + /// part that we don't have keys for. + fn save_datagram(&mut self, cspace: CryptoSpace, d: &Datagram, remaining: usize, now: Instant) { + let d = if remaining < d.len() { + Datagram::new( + d.source(), + d.destination(), + d.tos(), + d.ttl(), + &d[d.len() - remaining..], + ) + } else { + d.clone() + }; + self.saved_datagrams.save(cspace, d, now); + self.stats.borrow_mut().saved_datagrams += 1; + } + + /// Perform version negotiation. + fn version_negotiation(&mut self, supported: &[WireVersion], now: Instant) -> Res<()> { + debug_assert_eq!(self.role, Role::Client); + + if let Some(version) = self.conn_params.get_versions().preferred(supported) { + assert_ne!(self.version, version); + + qinfo!([self], "Version negotiation: trying {:?}", version); + let local_addr = self.paths.primary().borrow().local_address(); + let remote_addr = self.paths.primary().borrow().remote_address(); + let conn_params = self + .conn_params + .clone() + .versions(version, self.conn_params.get_versions().all().to_vec()); + let mut c = Self::new_client( + self.crypto.server_name().unwrap(), + self.crypto.protocols(), + self.cid_manager.generator(), + local_addr, + remote_addr, + conn_params, + now, + )?; + c.conn_params + .get_versions_mut() + .set_initial(self.conn_params.get_versions().initial()); + mem::swap(self, &mut c); + qlog::client_version_information_negotiated( + &mut self.qlog, + self.conn_params.get_versions().all(), + supported, + version, + ); + Ok(()) + } else { + qinfo!([self], "Version 
negotiation: failed with {:?}", supported); + // This error goes straight to closed. + self.set_state(State::Closed(ConnectionError::Transport( + Error::VersionNegotiation, + ))); + Err(Error::VersionNegotiation) + } + } + + /// Perform any processing that we might have to do on packets prior to + /// attempting to remove protection. + fn preprocess_packet( + &mut self, + packet: &PublicPacket, + path: &PathRef, + dcid: Option<&ConnectionId>, + now: Instant, + ) -> Res<PreprocessResult> { + if dcid.map_or(false, |d| d != &packet.dcid()) { + self.stats + .borrow_mut() + .pkt_dropped("Coalesced packet has different DCID"); + return Ok(PreprocessResult::Next); + } + + if (packet.packet_type() == PacketType::Initial + || packet.packet_type() == PacketType::Handshake) + && self.role == Role::Client + && !path.borrow().is_primary() + { + // If we have received a packet from a different address than we have sent to + // we should ignore the packet. In such a case a path will be a newly created + // temporary path, not the primary path. + return Ok(PreprocessResult::Next); + } + + match (packet.packet_type(), &self.state, &self.role) { + (PacketType::Initial, State::Init, Role::Server) => { + let version = *packet.version().as_ref().unwrap(); + if !packet.is_valid_initial() + || !self.conn_params.get_versions().all().contains(&version) + { + self.stats.borrow_mut().pkt_dropped("Invalid Initial"); + return Ok(PreprocessResult::Next); + } + qinfo!( + [self], + "Received valid Initial packet with scid {:?} dcid {:?}", + packet.scid(), + packet.dcid() + ); + // Record the client's selected CID so that it can be accepted until + // the client starts using a real connection ID. + let dcid = ConnectionId::from(packet.dcid()); + self.crypto.states.init_server(version, &dcid); + self.original_destination_cid = Some(dcid); + self.set_state(State::WaitInitial); + + // We need to make sure that we set this transport parameter. 
+ // This has to happen prior to processing the packet so that + // the TLS handshake has all it needs. + if !self.retry_sent() { + self.tps.borrow_mut().local.set_bytes( + tparams::ORIGINAL_DESTINATION_CONNECTION_ID, + packet.dcid().to_vec(), + ); + } + } + (PacketType::VersionNegotiation, State::WaitInitial, Role::Client) => { + if let Ok(versions) = packet.supported_versions() { + if versions.is_empty() + || versions.contains(&self.version().wire_version()) + || versions.contains(&0) + || &packet.scid() != self.odcid().unwrap() + || matches!(self.address_validation, AddressValidationInfo::Retry { .. }) + { + // Ignore VersionNegotiation packets that contain the current version. + // Or don't have the right connection ID. + // Or are received after a Retry. + self.stats.borrow_mut().pkt_dropped("Invalid VN"); + } else { + self.version_negotiation(&versions, now)?; + } + } else { + self.stats.borrow_mut().pkt_dropped("VN with no versions"); + }; + return Ok(PreprocessResult::End); + } + (PacketType::Retry, State::WaitInitial, Role::Client) => { + self.handle_retry(packet, now); + return Ok(PreprocessResult::Next); + } + (PacketType::Handshake | PacketType::Short, State::WaitInitial, Role::Client) => { + // This packet can't be processed now, but it could be a sign + // that Initial packets were lost. + // Resend Initial CRYPTO frames immediately a few times just + // in case. As we don't have an RTT estimate yet, this helps + // when there is a short RTT and losses. + if dcid.is_none() + && self.cid_manager.is_valid(packet.dcid()) + && self.stats.borrow().saved_datagrams <= EXTRA_INITIALS + { + self.crypto.resend_unacked(PacketNumberSpace::Initial); + } + } + (PacketType::VersionNegotiation | PacketType::Retry | PacketType::OtherVersion, ..) 
=> { + self.stats + .borrow_mut() + .pkt_dropped(format!("{:?}", packet.packet_type())); + return Ok(PreprocessResult::Next); + } + _ => {} + } + + let res = match self.state { + State::Init => { + self.stats + .borrow_mut() + .pkt_dropped("Received while in Init state"); + PreprocessResult::Next + } + State::WaitInitial => PreprocessResult::Continue, + State::WaitVersion | State::Handshaking | State::Connected | State::Confirmed => { + if !self.cid_manager.is_valid(packet.dcid()) { + self.stats + .borrow_mut() + .pkt_dropped(format!("Invalid DCID {:?}", packet.dcid())); + PreprocessResult::Next + } else { + if self.role == Role::Server && packet.packet_type() == PacketType::Handshake { + // Server has received a Handshake packet -> discard Initial keys and states + self.discard_keys(PacketNumberSpace::Initial, now); + } + PreprocessResult::Continue + } + } + State::Closing { .. } => { + // Don't bother processing the packet. Instead ask to get a + // new close frame. + self.state_signaling.send_close(); + PreprocessResult::Next + } + State::Draining { .. } | State::Closed(..) => { + // Do nothing. + self.stats + .borrow_mut() + .pkt_dropped(format!("State {:?}", self.state)); + PreprocessResult::Next + } + }; + Ok(res) + } + + /// After a Initial, Handshake, ZeroRtt, or Short packet is successfully processed. + fn postprocess_packet( + &mut self, + path: &PathRef, + d: &Datagram, + packet: &PublicPacket, + migrate: bool, + now: Instant, + ) { + if self.state == State::WaitInitial { + self.start_handshake(path, packet, now); + } + + if self.state.connected() { + self.handle_migration(path, d, migrate, now); + } else if self.role != Role::Client + && (packet.packet_type() == PacketType::Handshake + || (packet.dcid().len() >= 8 && packet.dcid() == self.local_initial_source_cid)) + { + // We only allow one path during setup, so apply handshake + // path validation to this path. + path.borrow_mut().set_valid(now); + } + } + + /// Take a datagram as input. 
This reports an error if the packet was bad. + /// This takes two times: when the datagram was received, and the current time. + fn input(&mut self, d: &Datagram, received: Instant, now: Instant) { + // First determine the path. + let path = self.paths.find_path_with_rebinding( + d.destination(), + d.source(), + self.conn_params.get_cc_algorithm(), + self.conn_params.pacing_enabled(), + now, + ); + path.borrow_mut().add_received(d.len()); + let res = self.input_path(&path, d, received); + self.capture_error(Some(path), now, 0, res).ok(); + } + + fn input_path(&mut self, path: &PathRef, d: &Datagram, now: Instant) -> Res<()> { + let mut slc = &d[..]; + let mut dcid = None; + + qtrace!([self], "{} input {}", path.borrow(), hex(&**d)); + let pto = path.borrow().rtt().pto(PacketNumberSpace::ApplicationData); + + // Handle each packet in the datagram. + while !slc.is_empty() { + self.stats.borrow_mut().packets_rx += 1; + let (packet, remainder) = + match PublicPacket::decode(slc, self.cid_manager.decoder().as_ref()) { + Ok((packet, remainder)) => (packet, remainder), + Err(e) => { + qinfo!([self], "Garbage packet: {}", e); + qtrace!([self], "Garbage packet contents: {}", hex(slc)); + self.stats.borrow_mut().pkt_dropped("Garbage packet"); + break; + } + }; + match self.preprocess_packet(&packet, path, dcid.as_ref(), now)? { + PreprocessResult::Continue => (), + PreprocessResult::Next => break, + PreprocessResult::End => return Ok(()), + } + + qtrace!([self], "Received unverified packet {:?}", packet); + + match packet.decrypt(&mut self.crypto.states, now + pto) { + Ok(payload) => { + // OK, we have a valid packet. 
+ self.idle_timeout.on_packet_received(now); + dump_packet( + self, + path, + "-> RX", + payload.packet_type(), + payload.pn(), + &payload[..], + ); + + qlog::packet_received(&mut self.qlog, &packet, &payload); + let space = PacketNumberSpace::from(payload.packet_type()); + if self.acks.get_mut(space).unwrap().is_duplicate(payload.pn()) { + qdebug!([self], "Duplicate packet {}-{}", space, payload.pn()); + self.stats.borrow_mut().dups_rx += 1; + } else { + match self.process_packet(path, &payload, now) { + Ok(migrate) => self.postprocess_packet(path, d, &packet, migrate, now), + Err(e) => { + self.ensure_error_path(path, &packet, now); + return Err(e); + } + } + } + } + Err(e) => { + match e { + Error::KeysPending(cspace) => { + // This packet can't be decrypted because we don't have the keys yet. + // Don't check this packet for a stateless reset, just return. + let remaining = slc.len(); + self.save_datagram(cspace, d, remaining, now); + return Ok(()); + } + Error::KeysExhausted => { + // Exhausting read keys is fatal. + return Err(e); + } + Error::KeysDiscarded(cspace) => { + // This was a valid-appearing Initial packet: maybe probe with + // a Handshake packet to keep the handshake moving. + self.received_untracked |= + self.role == Role::Client && cspace == CryptoSpace::Initial; + } + _ => (), + } + // Decryption failure, or not having keys is not fatal. + // If the state isn't available, or we can't decrypt the packet, drop + // the rest of the datagram on the floor, but don't generate an error. + self.check_stateless_reset(path, d, dcid.is_none(), now)?; + self.stats.borrow_mut().pkt_dropped("Decryption failure"); + qlog::packet_dropped(&mut self.qlog, &packet); + } + } + slc = remainder; + dcid = Some(ConnectionId::from(packet.dcid())); + } + self.check_stateless_reset(path, d, dcid.is_none(), now)?; + Ok(()) + } + + /// Process a packet. Returns true if the packet might initiate migration. 
+ fn process_packet( + &mut self, + path: &PathRef, + packet: &DecryptedPacket, + now: Instant, + ) -> Res<bool> { + // TODO(ekr@rtfm.com): Have the server blow away the initial + // crypto state if this fails? Otherwise, we will get a panic + // on the assert for doesn't exist. + // OK, we have a valid packet. + + let mut ack_eliciting = false; + let mut probing = true; + let mut d = Decoder::from(&packet[..]); + let mut consecutive_padding = 0; + while d.remaining() > 0 { + let mut f = Frame::decode(&mut d)?; + + // Skip padding + while f == Frame::Padding && d.remaining() > 0 { + consecutive_padding += 1; + f = Frame::decode(&mut d)?; + } + if consecutive_padding > 0 { + qdebug!( + [self], + "PADDING frame repeated {} times", + consecutive_padding + ); + consecutive_padding = 0; + } + + ack_eliciting |= f.ack_eliciting(); + probing &= f.path_probing(); + let t = f.get_type(); + if let Err(e) = self.input_frame(path, packet.version(), packet.packet_type(), f, now) { + self.capture_error(Some(Rc::clone(path)), now, t, Err(e))?; + } + } + + let largest_received = if let Some(space) = self + .acks + .get_mut(PacketNumberSpace::from(packet.packet_type())) + { + space.set_received(now, packet.pn(), ack_eliciting) + } else { + qdebug!( + [self], + "processed a {:?} packet without tracking it", + packet.packet_type(), + ); + // This was a valid packet that caused the same packet number to be + // discarded. This happens when the client discards the Initial packet + // number space after receiving the ServerHello. Remember this so + // that we guarantee that we send a Handshake packet. + self.received_untracked = true; + // We don't migrate during the handshake, so return false. + false + }; + + Ok(largest_received && !probing) + } + + /// During connection setup, the first path needs to be setup. + /// This uses the connection IDs that were provided during the handshake + /// to setup that path. 
+ #[allow(clippy::or_fun_call)] // Remove when MSRV >= 1.59 + fn setup_handshake_path(&mut self, path: &PathRef, now: Instant) { + self.paths.make_permanent( + path, + Some(self.local_initial_source_cid.clone()), + // Ideally we know what the peer wants us to use for the remote CID. + // But we will use our own guess if necessary. + ConnectionIdEntry::initial_remote( + self.remote_initial_source_cid + .as_ref() + .or(self.original_destination_cid.as_ref()) + .unwrap() + .clone(), + ), + ); + path.borrow_mut().set_valid(now); + } + + /// If the path isn't permanent, assign it a connection ID to make it so. + fn ensure_permanent(&mut self, path: &PathRef) -> Res<()> { + if self.paths.is_temporary(path) { + // If there isn't a connection ID to use for this path, the packet + // will be processed, but it won't be attributed to a path. That means + // no path probes or PATH_RESPONSE. But it's not fatal. + if let Some(cid) = self.connection_ids.next() { + self.paths.make_permanent(path, None, cid); + Ok(()) + } else if self.paths.primary().borrow().remote_cid().is_empty() { + self.paths + .make_permanent(path, None, ConnectionIdEntry::empty_remote()); + Ok(()) + } else { + qtrace!([self], "Unable to make path permanent: {}", path.borrow()); + Err(Error::InvalidMigration) + } + } else { + Ok(()) + } + } + + /// After an error, a permanent path is needed to send the CONNECTION_CLOSE. + /// This attempts to ensure that this exists. As the connection is now + /// temporary, there is no reason to do anything special here. + fn ensure_error_path(&mut self, path: &PathRef, packet: &PublicPacket, now: Instant) { + path.borrow_mut().set_valid(now); + if self.paths.is_temporary(path) { + // First try to fill in handshake details. + if packet.packet_type() == PacketType::Initial { + self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid())); + self.setup_handshake_path(path, now); + } else { + // Otherwise try to get a usable connection ID. 
+ mem::drop(self.ensure_permanent(path)); + } + } + } + + fn start_handshake(&mut self, path: &PathRef, packet: &PublicPacket, now: Instant) { + qtrace!([self], "starting handshake"); + debug_assert_eq!(packet.packet_type(), PacketType::Initial); + self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid())); + + let got_version = if self.role == Role::Server { + self.cid_manager + .add_odcid(self.original_destination_cid.as_ref().unwrap().clone()); + // Make a path on which to run the handshake. + self.setup_handshake_path(path, now); + + self.zero_rtt_state = match self.crypto.enable_0rtt(self.version, self.role) { + Ok(true) => { + qdebug!([self], "Accepted 0-RTT"); + ZeroRttState::AcceptedServer + } + _ => ZeroRttState::Rejected, + }; + + // The server knows the final version if it has remote transport parameters. + self.tps.borrow().remote.is_some() + } else { + qdebug!([self], "Changing to use Server CID={}", packet.scid()); + debug_assert!(path.borrow().is_primary()); + path.borrow_mut().set_remote_cid(packet.scid()); + + // The client knows the final version if it processed a CRYPTO frame. + self.stats.borrow().frame_rx.crypto > 0 + }; + if got_version { + self.set_state(State::Handshaking); + } else { + self.set_state(State::WaitVersion); + } + } + + /// Migrate to the provided path. + /// Either local or remote address (but not both) may be provided as `None` to have + /// the address from the current primary path used. + /// If `force` is true, then migration is immediate. + /// Otherwise, migration occurs after the path is probed successfully. + /// Either way, the path is probed and will be abandoned if the probe fails. + /// + /// # Errors + /// + /// Fails if this is not a client, not confirmed, or there are not enough connection + /// IDs available to use. 
pub fn migrate(
    &mut self,
    local: Option<SocketAddr>,
    remote: Option<SocketAddr>,
    force: bool,
    now: Instant,
) -> Res<()> {
    // Only a client whose connection is in the `Confirmed` state may start a migration.
    if self.role != Role::Client {
        return Err(Error::InvalidMigration);
    }
    if !matches!(self.state(), State::Confirmed) {
        return Err(Error::InvalidMigration);
    }

    // Fill in the blanks, using the current primary path.
    if local.is_none() && remote.is_none() {
        // Pointless migration is pointless.
        return Err(Error::InvalidMigration);
    }
    let local = local.unwrap_or_else(|| self.paths.primary().borrow().local_address());
    let remote = remote.unwrap_or_else(|| self.paths.primary().borrow().remote_address());

    // Comparing enum discriminants of `IpAddr` checks V4-vs-V6 without looking at the address.
    if mem::discriminant(&local.ip()) != mem::discriminant(&remote.ip()) {
        // Can't mix address families.
        return Err(Error::InvalidMigration);
    }
    if local.port() == 0 || remote.ip().is_unspecified() || remote.port() == 0 {
        // All but the local address need to be specified.
        return Err(Error::InvalidMigration);
    }
    if (local.ip().is_loopback() ^ remote.ip().is_loopback()) && !local.ip().is_unspecified() {
        // Block attempts to migrate to a path with loopback on only one end, unless the local
        // address is unspecified.
        return Err(Error::InvalidMigration);
    }

    let path = self.paths.find_path(
        local,
        remote,
        self.conn_params.get_cc_algorithm(),
        self.conn_params.pacing_enabled(),
        now,
    );
    // Any failure to make the path permanent is returned to the caller.
    self.ensure_permanent(&path)?;
    qinfo!(
        [self],
        "Migrate to {} probe {}",
        path.borrow(),
        if force { "now" } else { "after" }
    );
    // NOTE(review): a `true` result presumably means the migration took effect
    // immediately, which is why loss recovery is notified here -- confirm in `Paths`.
    if self.paths.migrate(&path, force, now) {
        self.loss_recovery.migrate();
    }
    Ok(())
}

/// Attempt to migrate to the preferred address the peer advertised in its
/// transport parameters.
///
/// Does nothing when preferred-address handling is configured as
/// `PreferredAddressConfig::Disabled` or the peer provided no preferred address.
/// A preferred address that cannot be used (wrong address family, a move from
/// non-loopback to loopback, or one that `migrate` rejects) is logged and ignored.
///
/// # Errors
///
/// Only a failure to store the connection ID that accompanies the preferred
/// address (`connection_ids.add_remote`) is returned.
fn migrate_to_preferred_address(&mut self, now: Instant) -> Res<()> {
    let spa = if matches!(
        self.conn_params.get_preferred_address(),
        PreferredAddressConfig::Disabled
    ) {
        None
    } else {
        self.tps.borrow_mut().remote().get_preferred_address()
    };
    if let Some((addr, cid)) = spa {
        // The connection ID isn't special, so just save it.
        self.connection_ids.add_remote(cid)?;

        // The preferred address doesn't dictate what the local address is, so this
        // has to use the existing address. So only pay attention to a preferred
        // address from the same family as is currently in use. More thought will
        // be needed to work out how to get addresses from a different family.
        let prev = self.paths.primary().borrow().remote_address();
        let remote = match prev.ip() {
            IpAddr::V4(_) => addr.ipv4().map(SocketAddr::V4),
            IpAddr::V6(_) => addr.ipv6().map(SocketAddr::V6),
        };

        if let Some(remote) = remote {
            // Ignore preferred address that move to loopback from non-loopback.
            // `migrate` doesn't enforce this rule.
            if !prev.ip().is_loopback() && remote.ip().is_loopback() {
                qwarn!([self], "Ignoring a move to a loopback address: {}", remote);
                return Ok(());
            }

            // Keep the local address (`None`) and probe before switching (`force == false`).
            if self.migrate(None, Some(remote), false, now).is_err() {
                qwarn!([self], "Ignoring bad preferred address: {}", remote);
            }
        } else {
            qwarn!([self], "Unable to migrate to a different address family");
        }
    }
    Ok(())
}

/// React to a packet that indicates the peer has migrated.
///
/// This is a no-op unless `migrate` is set and this endpoint is not a client.
/// If the path cannot be made permanent the migration is only logged.
fn handle_migration(&mut self, path: &PathRef, d: &Datagram, migrate: bool, now: Instant) {
    if !migrate {
        return;
    }
    if self.role == Role::Client {
        return;
    }

    if self.ensure_permanent(path).is_ok() {
        self.paths.handle_migration(path, d.source(), now);
    } else {
        qinfo!(
            [self],
            "{} Peer migrated, but no connection ID available",
            path.borrow()
        );
    }
}

/// Produce the next thing to send, if anything.
///
/// Before closing, a path is selected and `output_path` builds the datagram.
/// In the closing, draining and closed states, only a pending close frame (if
/// `state_signaling` has one) is sent. Errors are routed through
/// `capture_error`; if that still yields an error, the default `SendOption`
/// is returned.
fn output(&mut self, now: Instant) -> SendOption {
    qtrace!([self], "output {:?}", now);
    let res = match &self.state {
        State::Init
        | State::WaitInitial
        | State::WaitVersion
        | State::Handshaking
        | State::Connected
        | State::Confirmed => {
            if let Some(path) = self.paths.select_path() {
                let res = self.output_path(&path, now);
                self.capture_error(Some(path), now, 0, res)
            } else {
                Ok(SendOption::default())
            }
        }
        State::Closing { .. } | State::Draining { .. } | State::Closed(_) => {
            if let Some(details) = self.state_signaling.close_frame() {
                // Clone the path handle first: `details` is moved into `output_close`.
                let path = Rc::clone(details.path());
                let res = self.output_close(details);
                self.capture_error(Some(path), now, 0, res)
            } else {
                Ok(SendOption::default())
            }
        }
    };
    res.unwrap_or_default()
}

/// Start a packet of the type that corresponds to `cspace`.
///
/// A short header is used for `PacketType::Short`; every other type gets a
/// long header carrying `version` and both connection IDs from `path`.
/// If the builder still has room after the header, the header is scrambled
/// (passing `grease_quic_bit`) and, for Initial packets, the address
/// validation token is written.
///
/// Returns the packet type together with the builder, which takes ownership
/// of `encoder`.
fn build_packet_header(
    path: &Path,
    cspace: CryptoSpace,
    encoder: Encoder,
    tx: &CryptoDxState,
    address_validation: &AddressValidationInfo,
    version: Version,
    grease_quic_bit: bool,
) -> (PacketType, PacketBuilder) {
    let pt = PacketType::from(cspace);
    let mut builder = if pt == PacketType::Short {
        qdebug!("Building Short dcid {}", path.remote_cid());
        PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid())
    } else {
        qdebug!(
            "Building {:?} dcid {} scid {}",
            pt,
            path.remote_cid(),
            path.local_cid(),
        );

        PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid())
    };
    // Skip the remaining header work if the header did not fit.
    if builder.remaining() > 0 {
        builder.scramble(grease_quic_bit);
        if pt == PacketType::Initial {
            builder.initial_token(address_validation.token());
        }
    }

    (pt, builder)
}

/// Write the next packet number from `tx` into `builder` and return it.
///
/// The encoded length is the number of bytes needed to represent twice the
/// distance between the new packet number and `largest_acknowledged`, or
/// `pn + 1` when nothing has been acknowledged in this space.
#[must_use]
fn add_packet_number(
    builder: &mut PacketBuilder,
    tx: &CryptoDxState,
    largest_acknowledged: Option<PacketNumber>,
) -> PacketNumber {
    // Get the packet number and work out how long it is.
    let pn = tx.next_pn();
    let unacked_range = if let Some(la) = largest_acknowledged {
        // Double the range from this to the last acknowledged in this space.
        (pn - la) << 1
    } else {
        pn + 1
    };
    // Count how many bytes in this range are non-zero.
    let pn_len = mem::size_of::<PacketNumber>()
        - usize::try_from(unacked_range.leading_zeros() / 8).unwrap();
    // pn_len can't be zero (unacked_range is > 0)
    // TODO(mt) also use `4*path CWND/path MTU` to set a minimum length.
    builder.pn(pn, pn_len);
    pn
}

/// Whether the peer advertised the `GREASE_QUIC_BIT` transport parameter.
///
/// The peer's current transport parameters are consulted first; if those are
/// not available, the ones remembered for 0-RTT are used. With neither, the
/// answer is `false`.
fn can_grease_quic_bit(&self) -> bool {
    let tph = self.tps.borrow();
    if let Some(r) = &tph.remote {
        r.get_empty(tparams::GREASE_QUIC_BIT)
    } else if let Some(r) = &tph.remote_0rtt {
        r.get_empty(tparams::GREASE_QUIC_BIT)
    } else {
        false
    }
}

/// Build a datagram that carries the closing frame.
///
/// One packet is produced for every packet number space that currently has
/// transmit keys, all written into the same datagram. Outside of the
/// application data space, the sanitized form of the frame is used when
/// `close.sanitize()` provides one.
///
/// # Errors
///
/// Returns any error from protecting a packet (`builder.build`).
fn output_close(&mut self, close: ClosingFrame) -> Res<SendOption> {
    let mut encoder = Encoder::with_capacity(256);
    let grease_quic_bit = self.can_grease_quic_bit();
    let version = self.version();
    for space in PacketNumberSpace::iter() {
        let Some((cspace, tx)) = self.crypto.states.select_tx_mut(self.version, *space) else {
            continue;
        };

        let path = close.path().borrow();
        let (_, mut builder) = Self::build_packet_header(
            &path,
            cspace,
            encoder,
            tx,
            &AddressValidationInfo::None,
            version,
            grease_quic_bit,
        );
        // The packet number is written to the builder; its value is not needed here.
        _ = Self::add_packet_number(
            &mut builder,
            tx,
            self.loss_recovery.largest_acknowledged_pn(*space),
        );
        // The builder will set the limit to 0 if there isn't enough space for the header.
        if builder.is_full() {
            encoder = builder.abort();
            break;
        }
        // NOTE(review): this subtraction assumes the smaller of the amplification limit
        // and the MTU is at least `tx.expansion()` -- confirm that callers guarantee it.
        builder.set_limit(min(path.amplification_limit(), path.mtu()) - tx.expansion());
        debug_assert!(builder.limit() <= 2048);

        // ConnectionError::Application is only allowed at 1RTT.
        let sanitized = if *space == PacketNumberSpace::ApplicationData {
            None
        } else {
            close.sanitize()
        };
        sanitized
            .as_ref()
            .unwrap_or(&close)
            .write_frame(&mut builder);
        encoder = builder.build(tx)?;
    }

    Ok(SendOption::Yes(close.path().borrow().datagram(encoder)))
}

/// Write the frames that are exchanged in the application data space.
/// The order of calls here determines the relative priority of frames.
fn write_appdata_frames(
    &mut self,
    builder: &mut PacketBuilder,
    tokens: &mut Vec<RecoveryToken>,
) -> Res<()> {
    // After each writer below, stop as soon as the packet is full.
    let stats = &mut self.stats.borrow_mut();
    let frame_stats = &mut stats.frame_tx;
    // Only a server writes the frame counted as `handshake_done`.
    if self.role == Role::Server {
        if let Some(t) = self.state_signaling.write_done(builder)? {
            tokens.push(t);
            frame_stats.handshake_done += 1;
        }
    }

    self.streams
        .write_frames(TransmissionPriority::Critical, builder, tokens, frame_stats);
    if builder.is_full() {
        return Ok(());
    }

    self.streams.write_frames(
        TransmissionPriority::Important,
        builder,
        tokens,
        frame_stats,
    );
    if builder.is_full() {
        return Ok(());
    }

    // NEW_CONNECTION_ID, RETIRE_CONNECTION_ID, and ACK_FREQUENCY.
    self.cid_manager
        .write_frames(builder, tokens, frame_stats)?;
    if builder.is_full() {
        return Ok(());
    }
    self.paths.write_frames(builder, tokens, frame_stats);
    if builder.is_full() {
        return Ok(());
    }

    self.streams
        .write_frames(TransmissionPriority::High, builder, tokens, frame_stats);
    if builder.is_full() {
        return Ok(());
    }

    self.streams
        .write_frames(TransmissionPriority::Normal, builder, tokens, frame_stats);
    if builder.is_full() {
        return Ok(());
    }

    // Datagrams are best-effort and unreliable. Let streams starve them for now.
    self.quic_datagrams.write_frames(builder, tokens, stats);
    if builder.is_full() {
        return Ok(());
    }

    // The datagram writer took the whole stats object; re-borrow the frame counters.
    let frame_stats = &mut stats.frame_tx;
    // CRYPTO here only includes NewSessionTicket, plus NEW_TOKEN.
    // Both of these are only used for resumption and so can be relatively low priority.
    self.crypto.write_frame(
        PacketNumberSpace::ApplicationData,
        builder,
        tokens,
        frame_stats,
    )?;
    if builder.is_full() {
        return Ok(());
    }
    self.new_token.write_frames(builder, tokens, frame_stats)?;
    if builder.is_full() {
        return Ok(());
    }

    self.streams
        .write_frames(TransmissionPriority::Low, builder, tokens, frame_stats);

    // Test builds can append extra frames through `test_frame_writer`.
    #[cfg(test)]
    {
        if let Some(w) = &mut self.test_frame_writer {
            w.write_frames(builder);
        }
    }

    Ok(())
}

// Maybe send a probe. Return true if the packet was ack-eliciting.
//
// `ack_end` is the builder length recorded after any ACK frame was written, so
// anything beyond it was written by a later frame writer.
// This always clears `self.received_untracked`.
fn maybe_probe(
    &mut self,
    path: &PathRef,
    force_probe: bool,
    builder: &mut PacketBuilder,
    ack_end: usize,
    tokens: &mut Vec<RecoveryToken>,
    now: Instant,
) -> bool {
    let untracked = self.received_untracked && !self.state.connected();
    self.received_untracked = false;

    // Anything written after an ACK already elicits acknowledgment.
    // If we need to probe and nothing has been written, send a PING.
    if builder.len() > ack_end {
        return true;
    }

    // `&&` binds tighter than `||`: this is `(untracked && packet_empty) || force_probe`.
    let probe = if untracked && builder.packet_empty() || force_probe {
        // If we received an untracked packet and we aren't probing already
        // or the PTO timer fired: probe.
        true
    } else {
        let pto = path.borrow().rtt().pto(PacketNumberSpace::ApplicationData);
        if !builder.packet_empty() {
            // The packet only contains an ACK. Check whether we want to
            // force an ACK with a PING so we can stop tracking packets.
            self.loss_recovery.should_probe(pto, now)
        } else if self.streams.need_keep_alive() {
            // We need to keep the connection alive, including sending
            // a PING again.
            self.idle_timeout.send_keep_alive(now, pto, tokens)
        } else {
            false
        }
    };
    if probe {
        // Nothing ack-eliciting and we need to probe; send PING.
        debug_assert_ne!(builder.remaining(), 0);
        builder.encode_varint(crate::frame::FRAME_TYPE_PING);
        let stats = &mut self.stats.borrow_mut().frame_tx;
        stats.ping += 1;
        stats.all += 1;
    }
    probe
}

/// Write frames to the provided builder. Returns a list of tokens used for
/// tracking loss or acknowledgment, whether any frame was ACK eliciting, and
/// whether the packet was padded.
///
/// # Errors
///
/// Returns any error reported by the frame writers for the chosen space.
fn write_frames(
    &mut self,
    path: &PathRef,
    space: PacketNumberSpace,
    profile: &SendProfile,
    builder: &mut PacketBuilder,
    now: Instant,
) -> Res<(Vec<RecoveryToken>, bool, bool)> {
    let mut tokens = Vec::new();
    let primary = path.borrow().is_primary();
    let mut ack_eliciting = false;

    // ACK frames are only written on the primary path.
    if primary {
        let stats = &mut self.stats.borrow_mut().frame_tx;
        self.acks.write_frame(
            space,
            now,
            path.borrow().rtt().estimate(),
            builder,
            &mut tokens,
            stats,
        );
    }
    // Remember where the ACK ended so `maybe_probe` can tell whether more was written.
    let ack_end = builder.len();

    // Avoid sending probes until the handshake completes,
    // but send them even when we don't have space.
    let full_mtu = profile.limit() == path.borrow().mtu();
    if space == PacketNumberSpace::ApplicationData && self.state.connected() {
        // Probes should only be padded if the full MTU is available.
        // The probing code needs to know so it can track that.
        if path.borrow_mut().write_frames(
            builder,
            &mut self.stats.borrow_mut().frame_tx,
            full_mtu,
            now,
        ) {
            builder.enable_padding(true);
        }
    }

    if profile.ack_only(space) {
        // If we are CC limited we can only send acks!
        return Ok((tokens, false, false));
    }

    if primary {
        if space == PacketNumberSpace::ApplicationData {
            self.write_appdata_frames(builder, &mut tokens)?;
        } else {
            let stats = &mut self.stats.borrow_mut().frame_tx;
            self.crypto
                .write_frame(space, builder, &mut tokens, stats)?;
        }
    }

    // Maybe send a probe now, either to probe for losses or to keep the connection live.
    let force_probe = profile.should_probe(space);
    ack_eliciting |= self.maybe_probe(path, force_probe, builder, ack_end, &mut tokens, now);
    // If this is not the primary path, this should be ack-eliciting.
    debug_assert!(primary || ack_eliciting);

    // Add padding. Only pad 1-RTT packets so that we don't prevent coalescing.
    // And avoid padding packets that otherwise only contain ACK because adding PADDING
    // causes those packets to consume congestion window, which is not tracked (yet).
    // And avoid padding if we don't have a full MTU available.
    let stats = &mut self.stats.borrow_mut().frame_tx;
    let padded = if ack_eliciting && full_mtu && builder.pad() {
        stats.padding += 1;
        stats.all += 1;
        true
    } else {
        false
    };

    // Every recovery token is counted as one frame in the totals.
    stats.all += tokens.len();
    Ok((tokens, ack_eliciting, padded))
}

/// Build a datagram, possibly from multiple packets (for different PN
/// spaces) and each containing 1+ frames.
///
/// Returns `SendOption::No` carrying `profile.paced()` when nothing was
/// written, otherwise `SendOption::Yes` with the datagram for `path`.
///
/// # Errors
///
/// Returns errors from writing frames, protecting a packet, or the automatic
/// key update check.
fn output_path(&mut self, path: &PathRef, now: Instant) -> Res<SendOption> {
    let mut initial_sent = None;
    let mut needs_padding = false;
    let grease_quic_bit = self.can_grease_quic_bit();
    let version = self.version();

    // Determine how we are sending packets (PTO, etc..).
    let mtu = path.borrow().mtu();
    let profile = self.loss_recovery.send_profile(&path.borrow(), now);
    qdebug!([self], "output_path send_profile {:?}", profile);

    // Frames for different epochs must go in different packets, but then these
    // packets can go in a single datagram
    let mut encoder = Encoder::with_capacity(profile.limit());
    for space in PacketNumberSpace::iter() {
        // Ensure we have tx crypto state for this epoch, or skip it.
        let Some((cspace, tx)) = self.crypto.states.select_tx_mut(self.version, *space) else {
            continue;
        };

        // Packets are appended to the same encoder; note where this one starts.
        let header_start = encoder.len();
        let (pt, mut builder) = Self::build_packet_header(
            &path.borrow(),
            cspace,
            encoder,
            tx,
            &self.address_validation,
            version,
            grease_quic_bit,
        );
        let pn = Self::add_packet_number(
            &mut builder,
            tx,
            self.loss_recovery.largest_acknowledged_pn(*space),
        );
        // The builder will set the limit to 0 if there isn't enough space for the header.
        if builder.is_full() {
            encoder = builder.abort();
            break;
        }

        // Configure the limits and padding for this packet.
        let aead_expansion = tx.expansion();
        builder.set_limit(profile.limit() - aead_expansion);
        builder.enable_padding(needs_padding);
        debug_assert!(builder.limit() <= 2048);
        if builder.is_full() {
            encoder = builder.abort();
            break;
        }

        // Add frames to the packet.
        let payload_start = builder.len();
        let (tokens, ack_eliciting, padded) =
            self.write_frames(path, *space, &profile, &mut builder, now)?;
        if builder.packet_empty() {
            // Nothing to include in this packet.
            encoder = builder.abort();
            continue;
        }

        dump_packet(
            self,
            path,
            "TX ->",
            pt,
            pn,
            &builder.as_ref()[payload_start..],
        );
        qlog::packet_sent(
            &mut self.qlog,
            pt,
            pn,
            builder.len() - header_start + aead_expansion,
            &builder.as_ref()[payload_start..],
        );

        self.stats.borrow_mut().packets_tx += 1;
        // Look up the transmit state again; the earlier `tx` is not used past this point.
        let tx = self.crypto.states.tx_mut(self.version, cspace).unwrap();
        encoder = builder.build(tx)?;
        debug_assert!(encoder.len() <= mtu);
        self.crypto.states.auto_update()?;

        if ack_eliciting {
            self.idle_timeout.on_packet_sent(now);
        }
        let sent = SentPacket::new(
            pt,
            pn,
            now,
            ack_eliciting,
            tokens,
            encoder.len() - header_start,
        );
        if padded {
            needs_padding = false;
            self.loss_recovery.on_packet_sent(path, sent);
        } else if pt == PacketType::Initial && (self.role == Role::Client || ack_eliciting) {
            // Packets containing Initial packets might need padding, and we want to
            // track that padding along with the Initial packet. So defer tracking.
            initial_sent = Some(sent);
            needs_padding = true;
        } else {
            if pt == PacketType::Handshake && self.role == Role::Client {
                needs_padding = false;
            }
            self.loss_recovery.on_packet_sent(path, sent);
        }

        if *space == PacketNumberSpace::Handshake
            && self.role == Role::Server
            && self.state == State::Confirmed
        {
            // We could discard handshake keys in set_state,
            // but wait until after sending an ACK.
            self.discard_keys(PacketNumberSpace::Handshake, now);
        }
    }

    if encoder.is_empty() {
        qinfo!("TX blocked, profile={:?} ", profile);
        Ok(SendOption::No(profile.paced()))
    } else {
        // Perform additional padding for Initial packets as necessary.
        let mut packets: Vec<u8> = encoder.into();
        if let Some(mut initial) = initial_sent.take() {
            if needs_padding {
                qdebug!(
                    [self],
                    "pad Initial from {} to path MTU {}",
                    packets.len(),
                    mtu
                );
                // Charge the padding bytes to the deferred Initial packet.
                initial.size += mtu - packets.len();
                packets.resize(mtu, 0);
            }
            self.loss_recovery.on_packet_sent(path, initial);
        }
        path.borrow_mut().add_sent(packets.len());
        Ok(SendOption::Yes(path.borrow().datagram(packets)))
    }
}

/// Start a key update.
///
/// # Errors
///
/// `KeyUpdateBlocked` if the connection is not in the `Confirmed` state;
/// otherwise whatever `initiate_key_update` on the crypto states returns.
pub fn initiate_key_update(&mut self) -> Res<()> {
    if self.state == State::Confirmed {
        let la = self
            .loss_recovery
            .largest_acknowledged_pn(PacketNumberSpace::ApplicationData);
        qinfo!([self], "Initiating key update");
        self.crypto.states.initiate_key_update(la)
    } else {
        Err(Error::KeyUpdateBlocked)
    }
}

/// Test-only accessor that forwards to `get_epochs` on the crypto states.
#[cfg(test)]
pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
    self.crypto.states.get_epochs()
}

/// Start the handshake as a client.
///
/// Runs the first handshake step in the Initial space, moves to
/// `State::WaitInitial`, and sets the 0-RTT state to `Sending` if 0-RTT
/// could be enabled (`Init` otherwise).
///
/// # Errors
///
/// Returns errors from the handshake step or from enabling 0-RTT.
fn client_start(&mut self, now: Instant) -> Res<()> {
    qinfo!([self], "client_start");
    debug_assert_eq!(self.role, Role::Client);
    qlog::client_connection_started(&mut self.qlog, &self.paths.primary());
    qlog::client_version_information_initiated(&mut self.qlog, self.conn_params.get_versions());

    self.handshake(now, self.version, PacketNumberSpace::Initial, None)?;
    self.set_state(State::WaitInitial);
    self.zero_rtt_state = if self.crypto.enable_0rtt(self.version, self.role)? {
        qdebug!([self], "Enabled 0-RTT");
        ZeroRttState::Sending
    } else {
        ZeroRttState::Init
    };
    Ok(())
}

/// The time at which a closing or draining period that starts at `now` ends.
fn get_closing_period_time(&self, now: Instant) -> Instant {
    // Spec says close time should be at least PTO times 3.
    now + (self.pto() * 3)
}

/// Close the connection.
pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) {
    let error = ConnectionError::Application(app_error);
    let timeout = self.get_closing_period_time(now);
    // With a primary path, enter `Closing` so the close can be signalled to the peer;
    // without one, go straight to `Closed`.
    if let Some(path) = self.paths.primary_fallible() {
        self.state_signaling.close(path, error.clone(), 0, msg);
        self.set_state(State::Closing { error, timeout });
    } else {
        self.set_state(State::Closed(error));
    }
}

/// Apply limits taken from the peer's transport parameters: the stream
/// limits, the peer's idle timeout (only when it is non-zero), and the
/// maximum DATAGRAM frame size.
fn set_initial_limits(&mut self) {
    self.streams.set_initial_limits();
    let peer_timeout = self
        .tps
        .borrow()
        .remote()
        .get_integer(tparams::IDLE_TIMEOUT);
    if peer_timeout > 0 {
        self.idle_timeout
            .set_peer_timeout(Duration::from_millis(peer_timeout));
    }

    self.quic_datagrams.set_remote_datagram_size(
        self.tps
            .borrow()
            .remote()
            .get_integer(tparams::MAX_DATAGRAM_FRAME_SIZE),
    );
}

/// Whether `stream_id` is currently allowed, as decided by the stream state.
pub fn is_stream_id_allowed(&self, stream_id: StreamId) -> bool {
    self.streams.is_stream_id_allowed(stream_id)
}

/// Process the final set of transport parameters.
///
/// # Errors
///
/// Returns errors from connection ID or version validation, and
/// `TransportParameterError` for an unacceptable preferred address or a
/// minimum ACK delay that exceeds the maximum ACK delay.
fn process_tps(&mut self) -> Res<()> {
    self.validate_cids()?;
    self.validate_versions()?;
    {
        let tps = self.tps.borrow();
        let remote = tps.remote.as_ref().unwrap();

        // If the peer provided a preferred address, then we have to be a client
        // and they have to be using a non-empty connection ID.
        if remote.get_preferred_address().is_some()
            && (self.role == Role::Server
                || self.remote_initial_source_cid.as_ref().unwrap().is_empty())
        {
            return Err(Error::TransportParameterError);
        }

        let reset_token = if let Some(token) = remote.get_bytes(tparams::STATELESS_RESET_TOKEN)
        {
            <[u8; 16]>::try_from(token).unwrap()
        } else {
            // The other side didn't provide a stateless reset token.
            // That's OK, they can try guessing this.
            <[u8; 16]>::try_from(&random(16)[..]).unwrap()
        };
        self.paths
            .primary()
            .borrow_mut()
            .set_reset_token(reset_token);

        // The maximum ACK delay is read as milliseconds, the minimum as microseconds.
        let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY));
        let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) {
            let min_ad = Duration::from_micros(remote.get_integer(tparams::MIN_ACK_DELAY));
            if min_ad > max_ad {
                return Err(Error::TransportParameterError);
            }
            Some(min_ad)
        } else {
            None
        };
        self.paths.primary().borrow_mut().set_ack_delay(
            max_ad,
            min_ad,
            self.conn_params.get_ack_ratio(),
        );

        let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT);
        self.cid_manager.set_limit(max_active_cids);
    }
    self.set_initial_limits();
    qlog::connection_tparams_set(&mut self.qlog, &self.tps.borrow());
    Ok(())
}

/// Check the connection IDs in the peer's transport parameters against the
/// ones this connection recorded.
///
/// The initial source connection ID is always checked. A client also checks
/// the original destination connection ID and the retry source connection ID,
/// which must be present exactly when address validation used a Retry.
///
/// # Errors
///
/// `ProtocolViolation` if any of the checked values does not match.
fn validate_cids(&mut self) -> Res<()> {
    let tph = self.tps.borrow();
    let remote_tps = tph.remote.as_ref().unwrap();

    let tp = remote_tps.get_bytes(tparams::INITIAL_SOURCE_CONNECTION_ID);
    if self
        .remote_initial_source_cid
        .as_ref()
        .map(ConnectionId::as_cid_ref)
        != tp.map(ConnectionIdRef::from)
    {
        qwarn!(
            [self],
            "ISCID test failed: self cid {:?} != tp cid {:?}",
            self.remote_initial_source_cid,
            tp.map(hex),
        );
        return Err(Error::ProtocolViolation);
    }

    if self.role == Role::Client {
        let tp = remote_tps.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID);
        if self
            .original_destination_cid
            .as_ref()
            .map(ConnectionId::as_cid_ref)
            != tp.map(ConnectionIdRef::from)
        {
            qwarn!(
                [self],
                "ODCID test failed: self cid {:?} != tp cid {:?}",
                self.original_destination_cid,
                tp.map(hex),
            );
            return Err(Error::ProtocolViolation);
        }

        let tp = remote_tps.get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID);
        // Without a Retry, `expected` is `None`, so the parameter must be absent.
        let expected = if let AddressValidationInfo::Retry {
            retry_source_cid, ..
        } = &self.address_validation
        {
            Some(retry_source_cid.as_cid_ref())
        } else {
            None
        };
        if expected != tp.map(ConnectionIdRef::from) {
            qwarn!(
                [self],
                "RSCID test failed. self cid {:?} != tp cid {:?}",
                expected,
                tp.map(hex),
            );
            return Err(Error::ProtocolViolation);
        }
    }

    Ok(())
}

/// Validate the `version_negotiation` transport parameter from the peer.
///
/// # Errors
///
/// `VersionNegotiation` if the peer's values are not acceptable, or if the
/// parameter is missing while using a version other than version 1 or a draft.
fn validate_versions(&mut self) -> Res<()> {
    let tph = self.tps.borrow();
    let remote_tps = tph.remote.as_ref().unwrap();
    // `current` and `other` are the value from the peer's transport parameters.
    // We're checking that these match our expectations.
    if let Some((current, other)) = remote_tps.get_versions() {
        qtrace!(
            [self],
            "validate_versions: current={:x} chosen={:x} other={:x?}",
            self.version.wire_version(),
            current,
            other,
        );
        if self.role == Role::Server {
            // 1. A server acts on transport parameters, with validation
            // of `current` happening in the transport parameter handler.
            // All we need to do is confirm that the transport parameter
            // was provided.
            Ok(())
        } else if self.version().wire_version() != current {
            qinfo!([self], "validate_versions: current version mismatch");
            Err(Error::VersionNegotiation)
        } else if self
            .conn_params
            .get_versions()
            .initial()
            .is_compatible(self.version)
        {
            // 2. The current version is compatible with what we attempted.
            // That's a compatible upgrade and that's OK.
            Ok(())
        } else {
            // 3. The initial version we attempted isn't compatible. Check that
            // the one we would have chosen is compatible with this one.
            let mut all_versions = other.to_owned();
            all_versions.push(current);
            if self
                .conn_params
                .get_versions()
                .preferred(&all_versions)
                .ok_or(Error::VersionNegotiation)?
                .is_compatible(self.version)
            {
                Ok(())
            } else {
                qinfo!([self], "validate_versions: failed");
                Err(Error::VersionNegotiation)
            }
        }
    } else if self.version != Version::Version1 && !self.version.is_draft() {
        qinfo!([self], "validate_versions: missing extension");
        Err(Error::VersionNegotiation)
    } else {
        Ok(())
    }
}

/// Settle on `v` as the version for this connection, logging if it differs
/// from the version in use, and pass it to the crypto state.
fn confirm_version(&mut self, v: Version) {
    if self.version != v {
        qinfo!([self], "Compatible upgrade {:?} ==> {:?}", self.version, v);
    }
    self.crypto.confirm_version(v);
    self.version = v;
}

/// Handle a possible compatible version upgrade.
///
/// Only acts in the `WaitInitial` and `WaitVersion` states. A client adopts
/// the version of the packet it received. A server waits for the peer's
/// transport parameters, then initializes its crypto state for the version
/// recorded in the transport parameter handler and confirms that version.
fn compatible_upgrade(&mut self, packet_version: Version) {
    if !matches!(self.state, State::WaitInitial | State::WaitVersion) {
        return;
    }

    if self.role == Role::Client {
        self.confirm_version(packet_version);
    } else if self.tps.borrow().remote.is_some() {
        let version = self.tps.borrow().version();
        let dcid = self.original_destination_cid.as_ref().unwrap();
        self.crypto.states.init_server(version, dcid);
        self.confirm_version(version);
    }
}

/// Drive the TLS handshake with optional new CRYPTO `data` for `space`.
///
/// Emits authentication events as needed and marks the connection as
/// connected when the handshake completes. When `data` is provided, this also
/// checks for a compatible version upgrade, applies initial limits once the
/// peer's transport parameters are known, and installs new keys.
///
/// # Errors
///
/// Returns errors from the handshake itself, from `set_connected`, and from
/// key installation.
fn handshake(
    &mut self,
    now: Instant,
    packet_version: Version,
    space: PacketNumberSpace,
    data: Option<&[u8]>,
) -> Res<()> {
    qtrace!([self], "Handshake space={} data={:0x?}", space, data);

    let try_update = data.is_some();
    match self.crypto.handshake(now, space, data)? {
        HandshakeState::Authenticated(_) | HandshakeState::InProgress => (),
        HandshakeState::AuthenticationPending => self.events.authentication_needed(),
        HandshakeState::EchFallbackAuthenticationPending(public_name) => self
            .events
            .ech_fallback_authentication_needed(public_name.clone()),
        HandshakeState::Complete(_) => {
            if !self.state.connected() {
                self.set_connected(now)?;
            }
        }
        _ => {
            unreachable!("Crypto state should not be new or failed after successful handshake")
        }
    }

    // There is a chance that this could be called less often, but getting the
    // conditions right is a little tricky, so call whenever CRYPTO data is used.
    if try_update {
        self.compatible_upgrade(packet_version);
        // We have transport parameters, it's go time.
        if self.tps.borrow().remote.is_some() {
            self.set_initial_limits();
        }
        if self.crypto.install_keys(self.role)? {
            if self.role == Role::Client {
                // We won't acknowledge Initial packets as a result of this, but the
                // server can rely on implicit acknowledgment.
                self.discard_keys(PacketNumberSpace::Initial, now);
            }
            self.saved_datagrams.make_available(CryptoSpace::Handshake);
        }
    }

    Ok(())
}

/// Process one frame received in a packet of type `packet_type` on `path`.
///
/// Stream-related frames are handed to the stream state; all other frames
/// are handled here.
///
/// # Errors
///
/// `ProtocolViolation` if the frame is not allowed in this packet type, plus
/// any error reported while handling the specific frame.
fn input_frame(
    &mut self,
    path: &PathRef,
    packet_version: Version,
    packet_type: PacketType,
    frame: Frame,
    now: Instant,
) -> Res<()> {
    if !frame.is_allowed(packet_type) {
        qinfo!("frame not allowed: {:?} {:?}", frame, packet_type);
        return Err(Error::ProtocolViolation);
    }
    self.stats.borrow_mut().frame_rx.all += 1;
    let space = PacketNumberSpace::from(packet_type);
    if frame.is_stream() {
        return self
            .streams
            .input_frame(frame, &mut self.stats.borrow_mut().frame_rx);
    }
    match frame {
        Frame::Padding => {
            // Note: This counts contiguous padding as a single frame.
            self.stats.borrow_mut().frame_rx.padding += 1;
        }
        Frame::Ping => {
            // If we get a PING and there are outstanding CRYPTO frames,
            // prepare to resend them.
            self.stats.borrow_mut().frame_rx.ping += 1;
            self.crypto.resend_unacked(space);
            if space == PacketNumberSpace::ApplicationData {
                // Send an ACK immediately if we might not otherwise do so.
                self.acks.immediate_ack(now);
            }
        }
        Frame::Ack {
            largest_acknowledged,
            ack_delay,
            first_ack_range,
            ack_ranges,
        } => {
            let ranges =
                Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?;
            self.handle_ack(space, largest_acknowledged, ranges, ack_delay, now);
        }
        Frame::Crypto { offset, data } => {
            qtrace!(
                [self],
                "Crypto frame on space={} offset={}, data={:0x?}",
                space,
                offset,
                &data
            );
            self.stats.borrow_mut().frame_rx.crypto += 1;
            self.crypto.streams.inbound_frame(space, offset, data)?;
            if self.crypto.streams.data_ready(space) {
                let mut buf = Vec::new();
                let read = self.crypto.streams.read_to_end(space, &mut buf);
                qdebug!("Read {} bytes", read);
                self.handshake(now, packet_version, space, Some(&buf))?;
                self.create_resumption_token(now);
            } else {
                // If we get a useless CRYPTO frame send outstanding CRYPTO frames again.
                self.crypto.resend_unacked(space);
            }
        }
        Frame::NewToken { token } => {
            self.stats.borrow_mut().frame_rx.new_token += 1;
            self.new_token.save_token(token.to_vec());
            self.create_resumption_token(now);
        }
        Frame::NewConnectionId {
            sequence_number,
            connection_id,
            stateless_reset_token,
            retire_prior,
        } => {
            self.stats.borrow_mut().frame_rx.new_connection_id += 1;
            self.connection_ids.add_remote(ConnectionIdEntry::new(
                sequence_number,
                ConnectionId::from(connection_id),
                stateless_reset_token.to_owned(),
            ))?;
            // Retire first, so that the limit check below only counts what remains.
            self.paths
                .retire_cids(retire_prior, &mut self.connection_ids);
            if self.connection_ids.len() >= LOCAL_ACTIVE_CID_LIMIT {
                qinfo!([self], "received too many connection IDs");
                return Err(Error::ConnectionIdLimitExceeded);
            }
        }
        Frame::RetireConnectionId { sequence_number } => {
            self.stats.borrow_mut().frame_rx.retire_connection_id += 1;
            self.cid_manager.retire(sequence_number);
        }
        Frame::PathChallenge { data } => {
            self.stats.borrow_mut().frame_rx.path_challenge += 1;
            // If we were challenged, try to make the path permanent.
            // Report an error if we don't have enough connection IDs.
            self.ensure_permanent(path)?;
            path.borrow_mut().challenged(data);
        }
        Frame::PathResponse { data } => {
            self.stats.borrow_mut().frame_rx.path_response += 1;
            if self.paths.path_response(data, now) {
                // This PATH_RESPONSE enabled migration; tell loss recovery.
                self.loss_recovery.migrate();
            }
        }
        Frame::ConnectionClose {
            error_code,
            frame_type,
            reason_phrase,
        } => {
            self.stats.borrow_mut().frame_rx.connection_close += 1;
            let reason_phrase = String::from_utf8_lossy(&reason_phrase);
            qinfo!(
                [self],
                "ConnectionClose received. Error code: {:?} frame type {:x} reason {}",
                error_code,
                frame_type,
                reason_phrase
            );
            let (detail, frame_type) = if let CloseError::Application(_) = error_code {
                // Use a transport error here because we want to send
                // NO_ERROR in this case.
                (
                    Error::PeerApplicationError(error_code.code()),
                    FRAME_TYPE_CONNECTION_CLOSE_APPLICATION,
                )
            } else {
                (
                    Error::PeerError(error_code.code()),
                    FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
                )
            };
            let error = ConnectionError::Transport(detail);
            self.state_signaling
                .drain(Rc::clone(path), error.clone(), frame_type, "");
            self.set_state(State::Draining {
                error,
                timeout: self.get_closing_period_time(now),
            });
        }
        Frame::HandshakeDone => {
            self.stats.borrow_mut().frame_rx.handshake_done += 1;
            // Only a client that has already connected may receive this frame.
            if self.role == Role::Server || !self.state.connected() {
                return Err(Error::ProtocolViolation);
            }
            self.set_state(State::Confirmed);
            self.discard_keys(PacketNumberSpace::Handshake, now);
            self.migrate_to_preferred_address(now)?;
        }
        Frame::AckFrequency {
            seqno,
            tolerance,
            delay,
            ignore_order,
        } => {
            self.stats.borrow_mut().frame_rx.ack_frequency += 1;
            let delay = Duration::from_micros(delay);
            if delay < GRANULARITY {
                return Err(Error::ProtocolViolation);
            }
            // NOTE(review): `tolerance - 1` assumes the frame decoder has already
            // rejected a tolerance of zero -- confirm in the frame parsing code.
            self.acks
                .ack_freq(seqno, tolerance - 1, delay, ignore_order);
        }
        Frame::Datagram { data, .. } => {
            self.stats.borrow_mut().frame_rx.datagram += 1;
            self.quic_datagrams
                .handle_datagram(data, &mut self.stats.borrow_mut())?;
        }
        _ => unreachable!("All other frames are for streams"),
    };

    Ok(())
}

/// Given a set of `SentPacket` instances, ensure that the source of the packet
/// is told that they are lost. This gives the frame generation code a chance
/// to retransmit the frame as needed.
fn handle_lost_packets(&mut self, lost_packets: &[SentPacket]) {
    for lost in lost_packets {
        for token in &lost.tokens {
            qdebug!([self], "Lost: {:?}", token);
            match token {
                // Nothing is done for a lost ACK.
                RecoveryToken::Ack(_) => {}
                RecoveryToken::Crypto(ct) => self.crypto.lost(ct),
                RecoveryToken::HandshakeDone => self.state_signaling.handshake_done(),
                RecoveryToken::NewToken(seqno) => self.new_token.lost(*seqno),
                RecoveryToken::NewConnectionId(ncid) => self.cid_manager.lost(ncid),
                RecoveryToken::RetireConnectionId(seqno) => self.paths.lost_retire_cid(*seqno),
                RecoveryToken::AckFrequency(rate) => self.paths.lost_ack_frequency(rate),
                RecoveryToken::KeepAlive => self.idle_timeout.lost_keep_alive(),
                RecoveryToken::Stream(stream_token) => self.streams.lost(stream_token),
                // A lost datagram is reported to the application and counted.
                RecoveryToken::Datagram(dgram_tracker) => {
                    self.events
                        .datagram_outcome(dgram_tracker, OutgoingDatagramOutcome::Lost);
                    self.stats.borrow_mut().datagram_tx.lost += 1;
                }
            }
        }
    }
}

/// Convert the ACK delay field `v` of an ACK frame into a `Duration` by
/// shifting it left by the peer's `ACK_DELAY_EXPONENT` and reading the result
/// as microseconds.
fn decode_ack_delay(&self, v: u64) -> Duration {
    // If we have remote transport parameters, use them.
    // Otherwise, ack delay should be zero (because it's the handshake).
    if let Some(r) = self.tps.borrow().remote.as_ref() {
        let exponent = u32::try_from(r.get_integer(tparams::ACK_DELAY_EXPONENT)).unwrap();
        // `checked_shl` only fails for a shift of 64 bits or more; bits shifted
        // out of the value itself are dropped silently.
        Duration::from_micros(v.checked_shl(exponent).unwrap_or(u64::MAX))
    } else {
        Duration::new(0, 0)
    }
}

/// Process an ACK for `space`.
///
/// Loss recovery works out which packets were acknowledged and which are now
/// considered lost. The originator of each frame in those packets is then
/// told about the outcome, and the ACK statistics are updated.
fn handle_ack<R>(
    &mut self,
    space: PacketNumberSpace,
    largest_acknowledged: u64,
    ack_ranges: R,
    ack_delay: u64,
    now: Instant,
) where
    R: IntoIterator<Item = RangeInclusive<u64>> + Debug,
    R::IntoIter: ExactSizeIterator,
{
    qinfo!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges);

    let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received(
        &self.paths.primary(),
        space,
        largest_acknowledged,
        ack_ranges,
        self.decode_ack_delay(ack_delay),
        now,
    );
    for acked in acked_packets {
        for token in &acked.tokens {
            match token {
                RecoveryToken::Stream(stream_token) => self.streams.acked(stream_token),
                RecoveryToken::Ack(at) => self.acks.acked(at),
                RecoveryToken::Crypto(ct) => self.crypto.acked(ct),
                RecoveryToken::NewToken(seqno) => self.new_token.acked(*seqno),
                RecoveryToken::NewConnectionId(entry) => self.cid_manager.acked(entry),
                RecoveryToken::RetireConnectionId(seqno) => self.paths.acked_retire_cid(*seqno),
                RecoveryToken::AckFrequency(rate) => self.paths.acked_ack_frequency(rate),
                RecoveryToken::KeepAlive => self.idle_timeout.ack_keep_alive(),
                RecoveryToken::Datagram(dgram_tracker) => self
                    .events
                    .datagram_outcome(dgram_tracker, OutgoingDatagramOutcome::Acked),
                // We only worry when these are lost
                RecoveryToken::HandshakeDone => (),
            }
        }
    }
    self.handle_lost_packets(&lost_packets);
    qlog::packets_lost(&mut self.qlog, &lost_packets);
    let stats = &mut self.stats.borrow_mut().frame_rx;
    stats.ack += 1;
    stats.largest_acknowledged = max(stats.largest_acknowledged, largest_acknowledged);
}

/// When the server rejects 0-RTT we need to drop a bunch of stuff.
fn client_0rtt_rejected(&mut self, now: Instant) {
    // Nothing to undo unless 0-RTT data was actually being sent.
    if !matches!(self.zero_rtt_state, ZeroRttState::Sending) {
        return;
    }
    qdebug!([self], "0-RTT rejected");

    // Tell 0-RTT packets that they were "lost".
    let dropped = self.loss_recovery.drop_0rtt(&self.paths.primary(), now);
    self.handle_lost_packets(&dropped);

    self.streams.zero_rtt_rejected();

    self.crypto.states.discard_0rtt_keys();
    self.events.client_0rtt_rejected();
}

/// Finish the handshake once TLS reports completion.
///
/// Fails with `Error::CryptoAlert(120)` when no ALPN value was negotiated.
/// A server removes the original destination CID, marks the primary path as
/// valid and logs the start of the connection; a client records whether
/// early data was accepted. Both roles then install application keys,
/// process transport parameters, move to `State::Connected`, create a
/// resumption token and release saved application-data datagrams. A server
/// additionally requests HANDSHAKE_DONE signaling and moves straight on to
/// `State::Confirmed`.
fn set_connected(&mut self, now: Instant) -> Res<()> {
    qinfo!([self], "TLS connection complete");
    if self.crypto.tls.info().map(SecretAgentInfo::alpn).is_none() {
        qwarn!([self], "No ALPN. Closing connection.");
        // 120 = no_application_protocol
        return Err(Error::CryptoAlert(120));
    }
    if self.role == Role::Server {
        // Remove the randomized client CID from the list of acceptable CIDs.
        self.cid_manager.remove_odcid();
        // Mark the path as validated, if it isn't already.
        let path = self.paths.primary();
        path.borrow_mut().set_valid(now);
        // Generate a qlog event that the server connection started.
        qlog::server_connection_started(&mut self.qlog, &path);
    } else {
        self.zero_rtt_state = if self.crypto.tls.info().unwrap().early_data_accepted() {
            ZeroRttState::AcceptedClient
        } else {
            self.client_0rtt_rejected(now);
            ZeroRttState::Rejected
        };
    }

    // Setting application keys has to occur after 0-RTT rejection.
    let pto = self.pto();
    self.crypto
        .install_application_keys(self.version, now + pto)?;
    self.process_tps()?;
    self.set_state(State::Connected);
    self.create_resumption_token(now);
    self.saved_datagrams
        .make_available(CryptoSpace::ApplicationData);
    self.stats.borrow_mut().resumed = self.crypto.tls.info().unwrap().resumed();
    if self.role == Role::Server {
        self.state_signaling.handshake_done();
        self.set_state(State::Confirmed);
    }
    qinfo!([self], "Connection established");
    Ok(())
}

/// Advance the connection to `state`.
///
/// Only a move to a later state (per the ordering on `State`) takes effect:
/// streams are cleared when the new state is a closed one, and the change is
/// reported through the event queue and qlog. A request that is not an
/// advance and names a different kind of state is ignored, and debug builds
/// assert that it is a closing/draining request on an already closed
/// connection.
fn set_state(&mut self, state: State) {
    if state > self.state {
        qinfo!([self], "State change from {:?} -> {:?}", self.state, state);
        self.state = state.clone();
        if self.state.closed() {
            self.streams.clear_streams();
        }
        self.events.connection_state_change(state);
        qlog::connection_state_updated(&mut self.qlog, &self.state);
    } else if mem::discriminant(&state) != mem::discriminant(&self.state) {
        // Only tolerate a regression in state if the new state is closing
        // and the connection is already closed.
        debug_assert!(matches!(
            state,
            State::Closing { .. } | State::Draining { .. }
        ));
        debug_assert!(self.state.closed());
    }
}

/// Create a stream.
/// Returns the new stream ID.
///
/// # Errors
///
/// `ConnectionState` if the connection state does not allow streams to be created.
/// `StreamLimitError` if we are limited by the server's stream concurrency.
pub fn stream_create(&mut self, st: StreamType) -> Res<StreamId> {
    // Can't make streams while closing, otherwise rely on the stream limits.
    match self.state {
        State::Closing { .. } | State::Draining { .. } | State::Closed { .. } => {
            return Err(Error::ConnectionState);
        }
        State::WaitInitial | State::Handshaking => {
            // During the handshake only a client that is sending 0-RTT may open streams.
            if self.role == Role::Client && self.zero_rtt_state != ZeroRttState::Sending {
                return Err(Error::ConnectionState);
            }
        }
        // In all other states, trust that the stream limits are correct.
        _ => (),
    }

    self.streams.stream_create(st)
}

/// Set the priority of a stream.
///
/// # Errors
///
/// `InvalidStreamId` the stream does not exist.
pub fn stream_priority(
    &mut self,
    stream_id: StreamId,
    transmission: TransmissionPriority,
    retransmission: RetransmissionPriority,
) -> Res<()> {
    self.streams
        .get_send_stream_mut(stream_id)?
        .set_priority(transmission, retransmission);
    Ok(())
}

/// Set the `SendOrder` of a stream. Re-enqueues to keep the ordering correct.
///
/// # Errors
///
/// Returns `InvalidStreamId` if the stream id doesn't exist.
pub fn stream_sendorder(
    &mut self,
    stream_id: StreamId,
    sendorder: Option<SendOrder>,
) -> Res<()> {
    self.streams.set_sendorder(stream_id, sendorder)
}

/// Set the fairness of a stream.
///
/// # Errors
///
/// Returns `InvalidStreamId` if the stream id doesn't exist.
pub fn stream_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> {
    self.streams.set_fairness(stream_id, fairness)
}

/// Return the statistics of the send stream `stream_id`.
///
/// # Errors
///
/// Propagates the error from the send stream lookup.
pub fn send_stream_stats(&self, stream_id: StreamId) -> Res<SendStreamStats> {
    self.streams.get_send_stream(stream_id).map(|s| s.stats())
}

/// Return the statistics of the receive stream `stream_id`.
///
/// # Errors
///
/// Propagates the error from the receive stream lookup.
pub fn recv_stream_stats(&mut self, stream_id: StreamId) -> Res<RecvStreamStats> {
    let stream = self.streams.get_recv_stream_mut(stream_id)?;

    Ok(stream.stats())
}

/// Send data on a stream.
/// Returns how many bytes were successfully sent. Could be less
/// than total, based on receiver credit space available, etc.
///
/// # Errors
///
/// `InvalidStreamId` the stream does not exist,
/// `InvalidInput` if length of `data` is zero,
/// `FinalSizeError` if the stream has already been closed.
pub fn stream_send(&mut self, stream_id: StreamId, data: &[u8]) -> Res<usize> {
    self.streams.get_send_stream_mut(stream_id)?.send(data)
}

/// Send all data or nothing on a stream. May cause DATA_BLOCKED or
/// STREAM_DATA_BLOCKED frames to be sent.
+ /// Returns true if data was successfully sent, otherwise false. + /// + /// # Errors + /// + /// `InvalidStreamId` the stream does not exist, + /// `InvalidInput` if length of `data` is zero, + /// `FinalSizeError` if the stream has already been closed. + pub fn stream_send_atomic(&mut self, stream_id: StreamId, data: &[u8]) -> Res<bool> { + let val = self + .streams + .get_send_stream_mut(stream_id)? + .send_atomic(data); + if let Ok(val) = val { + debug_assert!( + val == 0 || val == data.len(), + "Unexpected value {} when trying to send {} bytes atomically", + val, + data.len() + ); + } + val.map(|v| v == data.len()) + } + + /// Bytes that stream_send() is guaranteed to accept for sending. + /// i.e. that will not be blocked by flow credits or send buffer max + /// capacity. + pub fn stream_avail_send_space(&self, stream_id: StreamId) -> Res<usize> { + Ok(self.streams.get_send_stream(stream_id)?.avail()) + } + + /// Close the stream. Enqueued data will be sent. + pub fn stream_close_send(&mut self, stream_id: StreamId) -> Res<()> { + self.streams.get_send_stream_mut(stream_id)?.close(); + Ok(()) + } + + /// Abandon transmission of in-flight and future stream data. + pub fn stream_reset_send(&mut self, stream_id: StreamId, err: AppError) -> Res<()> { + self.streams.get_send_stream_mut(stream_id)?.reset(err); + Ok(()) + } + + /// Read buffered data from stream. bool says whether read bytes includes + /// the final data on stream. + /// + /// # Errors + /// + /// `InvalidStreamId` if the stream does not exist. + /// `NoMoreData` if data and fin bit were previously read by the application. + pub fn stream_recv(&mut self, stream_id: StreamId, data: &mut [u8]) -> Res<(usize, bool)> { + let stream = self.streams.get_recv_stream_mut(stream_id)?; + + let rb = stream.read(data)?; + Ok(rb) + } + + /// Application is no longer interested in this stream. 
+ pub fn stream_stop_sending(&mut self, stream_id: StreamId, err: AppError) -> Res<()> { + let stream = self.streams.get_recv_stream_mut(stream_id)?; + + stream.stop_sending(err); + Ok(()) + } + + /// Increases `max_stream_data` for a `stream_id`. + /// + /// # Errors + /// + /// Returns `InvalidStreamId` if a stream does not exist or the receiving + /// side is closed. + pub fn set_stream_max_data(&mut self, stream_id: StreamId, max_data: u64) -> Res<()> { + let stream = self.streams.get_recv_stream_mut(stream_id)?; + + stream.set_stream_max_data(max_data); + Ok(()) + } + + /// Mark a receive stream as being important enough to keep the connection alive + /// (if `keep` is `true`) or no longer important (if `keep` is `false`). If any + /// stream is marked this way, PING frames will be used to keep the connection + /// alive, even when there is no activity. + /// + /// # Errors + /// + /// Returns `InvalidStreamId` if a stream does not exist or the receiving + /// side is closed. + pub fn stream_keep_alive(&mut self, stream_id: StreamId, keep: bool) -> Res<()> { + self.streams.keep_alive(stream_id, keep) + } + + pub fn remote_datagram_size(&self) -> u64 { + self.quic_datagrams.remote_datagram_size() + } + + /// Returns the current max size of a datagram that can fit into a packet. + /// The value will change over time depending on the encoded size of the + /// packet number, ack frames, etc. + /// + /// # Error + /// + /// The function returns `NotAvailable` if datagrams are not enabled. 
pub fn max_datagram_size(&self) -> Res<u64> {
    // A remote size of zero is reported as datagrams not being available.
    let max_dgram_size = self.quic_datagrams.remote_datagram_size();
    if max_dgram_size == 0 {
        return Err(Error::NotAvailable);
    }
    let version = self.version();
    // Application-data send keys are needed to know the packet overhead.
    let Some((cspace, tx)) = self
        .crypto
        .states
        .select_tx(self.version, PacketNumberSpace::ApplicationData)
    else {
        return Err(Error::NotAvailable);
    };
    let path = self.paths.primary_fallible().ok_or(Error::NotAvailable)?;
    let mtu = path.borrow().mtu();
    let encoder = Encoder::with_capacity(mtu);

    // Build a throwaway header and packet number; only its length is used below.
    let (_, mut builder) = Self::build_packet_header(
        &path.borrow(),
        cspace,
        encoder,
        tx,
        &self.address_validation,
        version,
        false,
    );
    _ = Self::add_packet_number(
        &mut builder,
        tx,
        self.loss_recovery
            .largest_acknowledged_pn(PacketNumberSpace::ApplicationData),
    );

    // Whatever is left of the MTU after packet protection expansion, the
    // header, and one more byte (presumably the DATAGRAM frame type — TODO
    // confirm) is the room available for datagram payload.
    let data_len_possible =
        u64::try_from(mtu.saturating_sub(tx.expansion() + builder.len() + 1)).unwrap();
    Ok(min(data_len_possible, max_dgram_size))
}

/// Queue a datagram for sending.
///
/// # Errors
///
/// The function returns `TooMuchData` if the supplied buffer is bigger than
/// the allowed remote datagram size. The function does not check if the
/// datagram can fit into a packet (i.e. MTU limit). This is checked during
/// creation of an actual packet and the datagram will be dropped if it does
/// not fit into the packet. The app is encouraged to use `max_datagram_size`
/// to check the estimated max datagram size and to use smaller datagrams.
/// `max_datagram_size` is just a current estimate and will change over
/// time depending on the encoded size of the packet number, ack frames, etc.
pub fn send_datagram(&mut self, buf: &[u8], id: impl Into<DatagramTracking>) -> Res<()> {
    self.quic_datagrams
        .add_datagram(buf, id.into(), &mut self.stats.borrow_mut())
}
}

impl EventProvider for Connection {
    type Event = ConnectionEvent;

    /// Return true if there are outstanding events.
fn has_events(&self) -> bool {
    self.events.has_events()
}

/// Get events that indicate state changes on the connection. This method
/// correctly handles cases where handling one event can obsolete
/// previously-queued events, or cause new events to be generated.
fn next_event(&mut self) -> Option<Self::Event> {
    self.events.next_event()
}
}

/// Formats the connection as its role followed by the value returned by
/// `odcid()`, or by `...` when `odcid()` returns `None`.
impl ::std::fmt::Display for Connection {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "{:?} ", self.role)?;
        if let Some(cid) = self.odcid() {
            std::fmt::Display::fmt(&cid, f)
        } else {
            write!(f, "...")
        }
    }
}

#[cfg(test)]
mod tests;
diff --git a/third_party/rust/neqo-transport/src/connection/params.rs b/third_party/rust/neqo-transport/src/connection/params.rs
new file mode 100644
index 0000000000..48aba4303b
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/params.rs
@@ -0,0 +1,392 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::{cmp::max, convert::TryFrom, time::Duration};

pub use crate::recovery::FAST_PTO_SCALE;
use crate::{
    connection::{ConnectionIdManager, Role, LOCAL_ACTIVE_CID_LIMIT},
    recv_stream::RECV_BUFFER_SIZE,
    rtt::GRANULARITY,
    stream_id::StreamType,
    tparams::{self, PreferredAddress, TransportParameter, TransportParametersHandler},
    tracking::DEFAULT_ACK_DELAY,
    version::{Version, VersionConfig},
    CongestionControlAlgorithm, Res,
};

/// Default connection-level flow control limit.
const LOCAL_MAX_DATA: u64 = 0x3FFF_FFFF_FFFF_FFFF; // 2^62-1
/// Default limit on bidirectional streams.
const LOCAL_STREAM_LIMIT_BIDI: u64 = 16;
/// Default limit on unidirectional streams.
const LOCAL_STREAM_LIMIT_UNI: u64 = 16;
/// See `ConnectionParameters.ack_ratio` for a discussion of this value.
pub const ACK_RATIO_SCALE: u8 = 10;
/// By default, aim to have the peer acknowledge 4 times per round trip time.
/// See `ConnectionParameters.ack_ratio` for more.
const DEFAULT_ACK_RATIO: u8 = 4 * ACK_RATIO_SCALE;
/// The local value for the idle timeout period.
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// Default length of both the outgoing and the incoming datagram queue.
const MAX_QUEUED_DATAGRAMS_DEFAULT: usize = 10;

/// What to do with preferred addresses.
#[derive(Debug, Clone)]
pub enum PreferredAddressConfig {
    /// Disabled, whether for client or server.
    Disabled,
    /// Enabled at a client, disabled at a server.
    Default,
    /// Enabled at both client and server.
    Address(PreferredAddress),
}

/// `ConnectionParameters` is used for setting initial values for QUIC parameters.
/// This collects configuration like initial limits, protocol version, and
/// congestion control algorithm.
#[derive(Debug, Clone)]
pub struct ConnectionParameters {
    /// The QUIC versions to attempt and accept.
    versions: VersionConfig,
    /// The congestion control algorithm to use.
    cc_algorithm: CongestionControlAlgorithm,
    /// Initial connection-level flow control limit.
    max_data: u64,
    /// Initial flow control limit for receiving data on bidirectional streams that the peer
    /// creates.
    max_stream_data_bidi_remote: u64,
    /// Initial flow control limit for receiving data on bidirectional streams that this endpoint
    /// creates.
    max_stream_data_bidi_local: u64,
    /// Initial flow control limit for receiving data on unidirectional streams that the peer
    /// creates.
    max_stream_data_uni: u64,
    /// Initial limit on bidirectional streams that the peer creates.
    max_streams_bidi: u64,
    /// Initial limit on unidirectional streams that the peer creates.
    // NOTE(review): this previously read "this endpoint creates"; the wording of
    // the bidirectional field and the definition of the corresponding transport
    // parameter suggest the limit applies to the peer — confirm.
    max_streams_uni: u64,
    /// The ACK ratio determines how many acknowledgements we will request as a
    /// fraction of both the current congestion window (expressed in packets) and
    /// as a fraction of the current round trip time. This value is scaled by
    /// `ACK_RATIO_SCALE`; that is, if the goal is to have at least five
    /// acknowledgments every round trip, set the value to `5 * ACK_RATIO_SCALE`.
    /// Values less than `ACK_RATIO_SCALE` are clamped to `ACK_RATIO_SCALE`.
    ack_ratio: u8,
    /// The duration of the idle timeout for the connection.
    idle_timeout: Duration,
    /// How preferred addresses are handled.
    preferred_address: PreferredAddressConfig,
    /// Maximum datagram frame size advertised to the peer; the default is 0.
    datagram_size: u64,
    /// Length of the queue of outgoing datagrams (at least 1).
    outgoing_datagram_queue: usize,
    /// Length of the queue of incoming datagrams (at least 1).
    incoming_datagram_queue: usize,
    /// Scaling factor for the PTO timer; see `fast_pto()`.
    fast_pto: u8,
    fuzzing: bool,
    grease: bool,
    pacing: bool,
}

impl Default for ConnectionParameters {
    fn default() -> Self {
        Self {
            versions: VersionConfig::default(),
            cc_algorithm: CongestionControlAlgorithm::NewReno,
            max_data: LOCAL_MAX_DATA,
            // All per-stream limits default to the size of the receive buffer.
            max_stream_data_bidi_remote: u64::try_from(RECV_BUFFER_SIZE).unwrap(),
            max_stream_data_bidi_local: u64::try_from(RECV_BUFFER_SIZE).unwrap(),
            max_stream_data_uni: u64::try_from(RECV_BUFFER_SIZE).unwrap(),
            max_streams_bidi: LOCAL_STREAM_LIMIT_BIDI,
            max_streams_uni: LOCAL_STREAM_LIMIT_UNI,
            ack_ratio: DEFAULT_ACK_RATIO,
            idle_timeout: DEFAULT_IDLE_TIMEOUT,
            preferred_address: PreferredAddressConfig::Default,
            datagram_size: 0,
            outgoing_datagram_queue: MAX_QUEUED_DATAGRAMS_DEFAULT,
            incoming_datagram_queue: MAX_QUEUED_DATAGRAMS_DEFAULT,
            fast_pto: FAST_PTO_SCALE,
            fuzzing: false,
            grease: true,
            pacing: true,
        }
    }
}

impl ConnectionParameters {
    /// The configured set of versions.
    pub fn get_versions(&self) -> &VersionConfig {
        &self.versions
    }

    /// Mutable access to the configured set of versions (crate-internal).
    pub(crate) fn get_versions_mut(&mut self) -> &mut VersionConfig {
        &mut self.versions
    }

    /// Describe the initial version that should be attempted and all the
    /// versions that should be enabled. This list should contain the initial
    /// version and be in order of preference, with more preferred versions
    /// before less preferred.
+ pub fn versions(mut self, initial: Version, all: Vec<Version>) -> Self { + self.versions = VersionConfig::new(initial, all); + self + } + + pub fn get_cc_algorithm(&self) -> CongestionControlAlgorithm { + self.cc_algorithm + } + + pub fn cc_algorithm(mut self, v: CongestionControlAlgorithm) -> Self { + self.cc_algorithm = v; + self + } + + pub fn get_max_data(&self) -> u64 { + self.max_data + } + + pub fn max_data(mut self, v: u64) -> Self { + self.max_data = v; + self + } + + pub fn get_max_streams(&self, stream_type: StreamType) -> u64 { + match stream_type { + StreamType::BiDi => self.max_streams_bidi, + StreamType::UniDi => self.max_streams_uni, + } + } + + /// # Panics + /// + /// If v > 2^60 (the maximum allowed by the protocol). + pub fn max_streams(mut self, stream_type: StreamType, v: u64) -> Self { + assert!(v <= (1 << 60), "max_streams is too large"); + match stream_type { + StreamType::BiDi => { + self.max_streams_bidi = v; + } + StreamType::UniDi => { + self.max_streams_uni = v; + } + } + self + } + + /// Get the maximum stream data that we will accept on different types of streams. + /// + /// # Panics + /// + /// If `StreamType::UniDi` and `false` are passed as that is not a valid combination. + pub fn get_max_stream_data(&self, stream_type: StreamType, remote: bool) -> u64 { + match (stream_type, remote) { + (StreamType::BiDi, false) => self.max_stream_data_bidi_local, + (StreamType::BiDi, true) => self.max_stream_data_bidi_remote, + (StreamType::UniDi, false) => { + panic!("Can't get receive limit on a stream that can only be sent.") + } + (StreamType::UniDi, true) => self.max_stream_data_uni, + } + } + + /// Set the maximum stream data that we will accept on different types of streams. + /// + /// # Panics + /// + /// If `StreamType::UniDi` and `false` are passed as that is not a valid combination + /// or if v >= 62 (the maximum allowed by the protocol). 
+ pub fn max_stream_data(mut self, stream_type: StreamType, remote: bool, v: u64) -> Self { + assert!(v < (1 << 62), "max stream data is too large"); + match (stream_type, remote) { + (StreamType::BiDi, false) => { + self.max_stream_data_bidi_local = v; + } + (StreamType::BiDi, true) => { + self.max_stream_data_bidi_remote = v; + } + (StreamType::UniDi, false) => { + panic!("Can't set receive limit on a stream that can only be sent.") + } + (StreamType::UniDi, true) => { + self.max_stream_data_uni = v; + } + } + self + } + + /// Set a preferred address (which only has an effect for a server). + pub fn preferred_address(mut self, preferred: PreferredAddress) -> Self { + self.preferred_address = PreferredAddressConfig::Address(preferred); + self + } + + /// Disable the use of preferred addresses. + pub fn disable_preferred_address(mut self) -> Self { + self.preferred_address = PreferredAddressConfig::Disabled; + self + } + + pub fn get_preferred_address(&self) -> &PreferredAddressConfig { + &self.preferred_address + } + + pub fn ack_ratio(mut self, ack_ratio: u8) -> Self { + self.ack_ratio = ack_ratio; + self + } + + pub fn get_ack_ratio(&self) -> u8 { + self.ack_ratio + } + + /// # Panics + /// + /// If `timeout` is 2^62 milliseconds or more. + pub fn idle_timeout(mut self, timeout: Duration) -> Self { + assert!(timeout.as_millis() < (1 << 62), "idle timeout is too long"); + self.idle_timeout = timeout; + self + } + + pub fn get_idle_timeout(&self) -> Duration { + self.idle_timeout + } + + pub fn get_datagram_size(&self) -> u64 { + self.datagram_size + } + + pub fn datagram_size(mut self, v: u64) -> Self { + self.datagram_size = v; + self + } + + pub fn get_outgoing_datagram_queue(&self) -> usize { + self.outgoing_datagram_queue + } + + pub fn outgoing_datagram_queue(mut self, v: usize) -> Self { + // The max queue length must be at least 1. 
+ self.outgoing_datagram_queue = max(v, 1); + self + } + + pub fn get_incoming_datagram_queue(&self) -> usize { + self.incoming_datagram_queue + } + + pub fn incoming_datagram_queue(mut self, v: usize) -> Self { + // The max queue length must be at least 1. + self.incoming_datagram_queue = max(v, 1); + self + } + + pub fn get_fast_pto(&self) -> u8 { + self.fast_pto + } + + /// Scale the PTO timer. A value of `FAST_PTO_SCALE` follows the spec, a smaller + /// value does not, but produces more probes with the intent of ensuring lower + /// latency in the event of tail loss. A value of `FAST_PTO_SCALE/4` is quite + /// aggressive. Smaller values (other than zero) are not rejected, but could be + /// very wasteful. Values greater than `FAST_PTO_SCALE` delay probes and could + /// reduce performance. It should not be possible to increase the PTO timer by + /// too much based on the range of valid values, but a maximum value of 255 will + /// result in very poor performance. + /// Scaling PTO this way does not affect when persistent congestion is declared, + /// but may change how many retransmissions are sent before declaring persistent + /// congestion. + /// + /// # Panics + /// + /// A value of 0 is invalid and will cause a panic. 
pub fn fast_pto(mut self, scale: u8) -> Self {
    assert_ne!(scale, 0);
    self.fast_pto = scale;
    self
}

/// Whether the fuzzing flag is set.
pub fn is_fuzzing(&self) -> bool {
    self.fuzzing
}

/// Set or clear the fuzzing flag.
pub fn fuzzing(mut self, enable: bool) -> Self {
    self.fuzzing = enable;
    self
}

/// Whether the grease flag is set.
pub fn is_greasing(&self) -> bool {
    self.grease
}

/// Set or clear the grease flag.
pub fn grease(mut self, grease: bool) -> Self {
    self.grease = grease;
    self
}

/// Whether the pacing flag is set.
pub fn pacing_enabled(&self) -> bool {
    self.pacing
}

/// Set or clear the pacing flag.
pub fn pacing(mut self, pacing: bool) -> Self {
    self.pacing = pacing;
    self
}

/// Build the handler holding the local transport parameters for `role`.
///
/// A fixed set of parameters is always written, followed by the values
/// configured on `self`. A preferred address is only included for a server
/// that was configured with one.
///
/// # Errors
///
/// Propagates the error from `cid_manager.preferred_address_cid()`.
pub fn create_transport_parameter(
    &self,
    role: Role,
    cid_manager: &mut ConnectionIdManager,
) -> Res<TransportParametersHandler> {
    let mut tps = TransportParametersHandler::new(role, self.versions.clone());
    // default parameters
    tps.local.set_integer(
        tparams::ACTIVE_CONNECTION_ID_LIMIT,
        u64::try_from(LOCAL_ACTIVE_CID_LIMIT).unwrap(),
    );
    tps.local.set_empty(tparams::DISABLE_MIGRATION);
    tps.local.set_empty(tparams::GREASE_QUIC_BIT);
    // The maximum ACK delay is written in milliseconds ...
    tps.local.set_integer(
        tparams::MAX_ACK_DELAY,
        u64::try_from(DEFAULT_ACK_DELAY.as_millis()).unwrap(),
    );
    // ... whereas the minimum ACK delay is written in microseconds.
    tps.local.set_integer(
        tparams::MIN_ACK_DELAY,
        u64::try_from(GRANULARITY.as_micros()).unwrap(),
    );

    // set configurable parameters
    tps.local
        .set_integer(tparams::INITIAL_MAX_DATA, self.max_data);
    tps.local.set_integer(
        tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
        self.max_stream_data_bidi_local,
    );
    tps.local.set_integer(
        tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
        self.max_stream_data_bidi_remote,
    );
    tps.local.set_integer(
        tparams::INITIAL_MAX_STREAM_DATA_UNI,
        self.max_stream_data_uni,
    );
    tps.local
        .set_integer(tparams::INITIAL_MAX_STREAMS_BIDI, self.max_streams_bidi);
    tps.local
        .set_integer(tparams::INITIAL_MAX_STREAMS_UNI, self.max_streams_uni);
    // A millisecond count that does not fit in a `u64` is written as 0.
    tps.local.set_integer(
        tparams::IDLE_TIMEOUT,
        u64::try_from(self.idle_timeout.as_millis()).unwrap_or(0),
    );
    if let PreferredAddressConfig::Address(preferred) = &self.preferred_address {
        if role == Role::Server {
            let (cid, srt) = cid_manager.preferred_address_cid()?;
            tps.local.set(
                tparams::PREFERRED_ADDRESS,
                TransportParameter::PreferredAddress {
                    v4: preferred.ipv4(),
                    v6: preferred.ipv6(),
                    cid,
                    srt,
                },
            );
        }
    }
    tps.local
        .set_integer(tparams::MAX_DATAGRAM_FRAME_SIZE, self.datagram_size);
    Ok(tps)
}
}
diff --git a/third_party/rust/neqo-transport/src/connection/saved.rs b/third_party/rust/neqo-transport/src/connection/saved.rs
new file mode 100644
index 0000000000..f5616c732a
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/saved.rs
@@ -0,0 +1,68 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::{mem, time::Instant};

use neqo_common::{qdebug, qinfo, Datagram};

use crate::crypto::CryptoSpace;

/// The number of datagrams that are saved during the handshake when
/// keys to decrypt them are not yet available.
const MAX_SAVED_DATAGRAMS: usize = 4;

/// A datagram that was set aside, together with its arrival time.
pub struct SavedDatagram {
    /// The datagram.
    pub d: Datagram,
    /// The time that the datagram was received.
    pub t: Instant,
}

/// Storage for datagrams that were saved for later processing, kept
/// separately for the handshake and application-data spaces.
#[derive(Default)]
pub struct SavedDatagrams {
    handshake: Vec<SavedDatagram>,
    application_data: Vec<SavedDatagram>,
    /// The space whose saved datagrams can be taken with `take_saved`, if any.
    available: Option<CryptoSpace>,
}

impl SavedDatagrams {
    /// Select the store for `cspace`.
    ///
    /// # Panics
    ///
    /// If `cspace` is neither `Handshake` nor `ApplicationData`.
    fn store(&mut self, cspace: CryptoSpace) -> &mut Vec<SavedDatagram> {
        match cspace {
            CryptoSpace::Handshake => &mut self.handshake,
            CryptoSpace::ApplicationData => &mut self.application_data,
            _ => panic!("unexpected space"),
        }
    }

    /// Save datagram `d`, received at time `t`, for `cspace`. The datagram is
    /// not saved once `MAX_SAVED_DATAGRAMS` are already held for that space.
    pub fn save(&mut self, cspace: CryptoSpace, d: Datagram, t: Instant) {
        let store = self.store(cspace);

        if store.len() < MAX_SAVED_DATAGRAMS {
            qdebug!("saving datagram of {} bytes", d.len());
            store.push(SavedDatagram { d, t });
        } else {
            qinfo!("not saving datagram of {} bytes", d.len());
        }
    }

    /// Mark the datagrams saved for `cspace` as available, provided that
    /// there are any.
    pub fn make_available(&mut self, cspace: CryptoSpace) {
        debug_assert_ne!(cspace, CryptoSpace::ZeroRtt);
        debug_assert_ne!(cspace, CryptoSpace::Initial);
        if !self.store(cspace).is_empty() {
            self.available = Some(cspace);
        }
    }

    /// The space that currently has datagrams available, if any.
    pub fn available(&self) -> Option<CryptoSpace> {
        self.available
    }

    /// Take the datagrams of the available space, leaving that store empty
    /// and clearing the availability marker. Returns an empty vector when
    /// nothing is available.
    pub fn take_saved(&mut self) -> Vec<SavedDatagram> {
        self.available
            .take()
            .map_or_else(Vec::new, |cspace| mem::take(self.store(cspace)))
    }
}
diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs
new file mode 100644
index 0000000000..9afb42174f
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/state.rs
@@ -0,0 +1,281 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
+ +use std::{ + cmp::{min, Ordering}, + mem, + rc::Rc, + time::Instant, +}; + +use neqo_common::Encoder; + +use crate::{ + frame::{ + FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT, + FRAME_TYPE_HANDSHAKE_DONE, + }, + packet::PacketBuilder, + path::PathRef, + recovery::RecoveryToken, + ConnectionError, Error, Res, +}; + +#[derive(Clone, Debug, PartialEq, Eq)] +/// The state of the Connection. +pub enum State { + /// A newly created connection. + Init, + /// Waiting for the first Initial packet. + WaitInitial, + /// Waiting to confirm which version was selected. + /// For a client, this is confirmed when a CRYPTO frame is received; + /// the version of the packet determines the version. + /// For a server, this is confirmed when transport parameters are + /// received and processed. + WaitVersion, + /// Exchanging Handshake packets. + Handshaking, + Connected, + Confirmed, + Closing { + error: ConnectionError, + timeout: Instant, + }, + Draining { + error: ConnectionError, + timeout: Instant, + }, + Closed(ConnectionError), +} + +impl State { + #[must_use] + pub fn connected(&self) -> bool { + matches!(self, Self::Connected | Self::Confirmed) + } + + #[must_use] + pub fn closed(&self) -> bool { + matches!( + self, + Self::Closing { .. } | Self::Draining { .. } | Self::Closed(_) + ) + } + + pub fn error(&self) -> Option<&ConnectionError> { + if let Self::Closing { error, .. } | Self::Draining { error, .. } | Self::Closed(error) = + self + { + Some(error) + } else { + None + } + } +} + +// Implement `PartialOrd` so that we can enforce monotonic state progression. 
impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for State {
    /// States are ordered by their position in the connection lifecycle; any
    /// data they carry is ignored, so two states of the same kind compare
    /// as equal.
    fn cmp(&self, other: &Self) -> Ordering {
        /// Position of a state in the lifecycle, earliest first.
        fn rank(state: &State) -> u8 {
            match state {
                State::Init => 0,
                State::WaitInitial => 1,
                State::WaitVersion => 2,
                State::Handshaking => 3,
                State::Connected => 4,
                State::Confirmed => 5,
                State::Closing { .. } => 6,
                State::Draining { .. } => 7,
                State::Closed(_) => 8,
            }
        }

        if mem::discriminant(self) == mem::discriminant(other) {
            return Ordering::Equal;
        }
        rank(self).cmp(&rank(other))
    }
}

/// Everything needed to write a CONNECTION_CLOSE frame.
#[derive(Debug, Clone)]
pub struct ClosingFrame {
    path: PathRef,
    error: ConnectionError,
    frame_type: FrameType,
    reason_phrase: Vec<u8>,
}

impl ClosingFrame {
    /// Capture the details of a close; `message` becomes the reason phrase.
    fn new(
        path: PathRef,
        error: ConnectionError,
        frame_type: FrameType,
        message: impl AsRef<str>,
    ) -> Self {
        Self {
            path,
            error,
            frame_type,
            reason_phrase: message.as_ref().as_bytes().to_vec(),
        }
    }

    /// The path associated with this frame.
    pub fn path(&self) -> &PathRef {
        &self.path
    }

    /// For an application error, produce the generic replacement frame that
    /// is used when the error has to be sent in an Initial or Handshake
    /// packet. Returns `None` for any other error.
    pub fn sanitize(&self) -> Option<Self> {
        let is_application = matches!(self.error, ConnectionError::Application(_));
        is_application.then(|| Self {
            path: Rc::clone(&self.path),
            error: ConnectionError::Transport(Error::ApplicationError),
            frame_type: 0,
            reason_phrase: Vec::new(),
        })
    }

    /// Write the CONNECTION_CLOSE frame into `builder`, or write nothing if
    /// there is too little space left.
    pub fn write_frame(&self, builder: &mut PacketBuilder) {
        // The space check leaves 8 bytes for the reason phrase, so that a
        // truncated phrase still has at least a few bytes of the value.
        if builder.remaining() < 1 + 8 + 8 + 2 + 8 {
            return;
        }
        match &self.error {
            ConnectionError::Application(code) => {
                builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_APPLICATION);
                builder.encode_varint(*code);
            }
            ConnectionError::Transport(e) => {
                builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT);
                builder.encode_varint(e.code());
                builder.encode_varint(self.frame_type);
            }
        }
        // This frame is sent in multiple packet number spaces, so the reason
        // phrase is capped at 256 bytes overall and cut short when it does
        // not fit into what remains of the packet.
        let available = min(256, builder.remaining());
        let phrase = self.reason_phrase.as_slice();
        let end = if available < Encoder::vvec_len(phrase.len()) {
            available - 2
        } else {
            phrase.len()
        };
        builder.encode_vvec(&phrase[..end]);
    }
}

/// `StateSignaling` manages whether we need to send HANDSHAKE_DONE and CONNECTION_CLOSE.
/// Valid state transitions are:
/// * Idle -> HandshakeDone: at the server when the handshake completes
/// * HandshakeDone -> Idle: when a HANDSHAKE_DONE frame is sent
/// * Idle/HandshakeDone -> Closing/Draining: when closing or draining
/// * Closing/Draining -> CloseSent: after sending CONNECTION_CLOSE
/// * CloseSent -> Closing: any time a new CONNECTION_CLOSE is needed
/// * -> Reset: from any state in case of a stateless reset
#[derive(Debug, Clone)]
pub enum StateSignaling {
    Idle,
    HandshakeDone,
    /// These states save the frame that needs to be sent.
    Closing(ClosingFrame),
    Draining(ClosingFrame),
    /// This state saves the frame that might need to be sent again.
    /// If it is `None`, then we are draining and don't send.
    CloseSent(Option<ClosingFrame>),
    Reset,
}

impl StateSignaling {
    /// Record that HANDSHAKE_DONE needs to be sent. Only valid from `Idle`;
    /// from any other state this asserts in debug builds and otherwise does
    /// nothing.
    pub fn handshake_done(&mut self) {
        if matches!(self, Self::Idle) {
            *self = Self::HandshakeDone;
        } else {
            debug_assert!(false, "StateSignaling must be in Idle state.");
        }
    }

    /// Write a HANDSHAKE_DONE frame if one is pending and `builder` has room
    /// for it, returning the token that tracks the frame.
    pub fn write_done(&mut self, builder: &mut PacketBuilder) -> Res<Option<RecoveryToken>> {
        if !matches!(self, Self::HandshakeDone) || builder.remaining() < 1 {
            return Ok(None);
        }
        *self = Self::Idle;
        builder.encode_varint(FRAME_TYPE_HANDSHAKE_DONE);
        Ok(Some(RecoveryToken::HandshakeDone))
    }

    /// Enter the closing state, unless a stateless reset was already seen.
    pub fn close(
        &mut self,
        path: PathRef,
        error: ConnectionError,
        frame_type: FrameType,
        message: impl AsRef<str>,
    ) {
        if matches!(self, Self::Reset) {
            return;
        }
        *self = Self::Closing(ClosingFrame::new(path, error, frame_type, message));
    }

    /// Enter the draining state, unless a stateless reset was already seen.
    pub fn drain(
        &mut self,
        path: PathRef,
        error: ConnectionError,
        frame_type: FrameType,
        message: impl AsRef<str>,
    ) {
        if matches!(self, Self::Reset) {
            return;
        }
        *self = Self::Draining(ClosingFrame::new(path, error, frame_type, message));
    }

    /// Take the pending close frame, if there is one.
    pub fn close_frame(&mut self) -> Option<ClosingFrame> {
        // While closing the frame is kept because it might need to be sent
        // again; while draining it is sent just once.
        let (frame, keep) = match self {
            Self::Closing(frame) => (frame.clone(), true),
            Self::Draining(frame) => (frame.clone(), false),
            _ => return None,
        };
        *self = Self::CloseSent(keep.then(|| frame.clone()));
        Some(frame)
    }

    /// Get ready to send the close frame again, if it is allowed to be resent.
    pub fn send_close(&mut self) {
        let Self::CloseSent(Some(frame)) = self else {
            return;
        };
        let frame = frame.clone();
        *self = Self::Closing(frame);
    }

    /// We just got a stateless reset. Terminate.
+ pub fn reset(&mut self) { + *self = Self::Reset; + } +} diff --git a/third_party/rust/neqo-transport/src/connection/test_internal.rs b/third_party/rust/neqo-transport/src/connection/test_internal.rs new file mode 100644 index 0000000000..353c38e526 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/test_internal.rs @@ -0,0 +1,13 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Some access to internal connection stuff for testing purposes. + +use crate::packet::PacketBuilder; + +pub trait FrameWriter { + fn write_frames(&mut self, builder: &mut PacketBuilder); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/ackrate.rs b/third_party/rust/neqo-transport/src/connection/tests/ackrate.rs new file mode 100644 index 0000000000..1b83d42acd --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/ackrate.rs @@ -0,0 +1,194 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{mem, time::Duration}; + +use test_fixture::{addr_v4, assertions}; + +use super::{ + super::{ConnectionParameters, ACK_RATIO_SCALE}, + ack_bytes, connect_rtt_idle, default_client, default_server, fill_cwnd, increase_cwnd, + induce_persistent_congestion, new_client, new_server, send_something, DEFAULT_RTT, +}; +use crate::stream_id::StreamType; + +/// With the default RTT here (100ms) and default ratio (4), endpoints won't send +/// `ACK_FREQUENCY` as the ACK delay isn't different enough from the default. 
/// Growing the congestion window must not, on its own, produce `ACK_FREQUENCY`.
#[test]
fn ack_rate_slow_start() {
    let mut client = default_client();
    let mut server = default_server();
    let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);

    // Grow the congestion window three times over.
    let stream = client.stream_create(StreamType::UniDi).unwrap();
    for _ in 0..3 {
        now = increase_cwnd(&mut client, &mut server, stream, now);
    }

    // Even though the computed value would have changed, the client has not
    // sent an ACK_FREQUENCY frame and the server has not received one.
    assert_eq!(client.stats().frame_tx.ack_frequency, 0);
    assert_eq!(server.stats().frame_rx.ack_frequency, 0);
}
/// A collapse of the congestion window results in a new `ACK_FREQUENCY` frame.
#[test]
fn ack_rate_persistent_congestion() {
    // This configuration causes the value to be set once the handshake is over.
    const RTT: Duration = Duration::from_millis(3);
    let params = ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE);
    let mut client = new_client(params);
    let mut server = default_server();
    let connected_at = connect_rtt_idle(&mut client, &mut server, RTT);

    // One frame has already gone out.
    assert_eq!(client.stats().frame_tx.ack_frequency, 1);

    // Fill the window, let the server see the data, but throw its ACK away.
    let stream = client.stream_create(StreamType::UniDi).unwrap();
    let (dgrams, filled_at) = fill_cwnd(&mut client, stream, connected_at);
    let delivered_at = filled_at + RTT / 2;
    mem::drop(ack_bytes(&mut server, stream, dgrams, delivered_at));

    let now = induce_persistent_congestion(&mut client, &mut server, stream, delivered_at);

    // The next output from the client carries a second ACK_FREQUENCY frame.
    let af = client.process_output(now).dgram();
    assert!(af.is_some());
    assert_eq!(client.stats().frame_tx.ack_frequency, 2);
}
/// The server-side ACK ratio setting takes effect: the client ends up
/// delaying its acknowledgments by half an RTT.
#[test]
fn ack_rate_server_half_rtt() {
    const RTT: Duration = Duration::from_millis(10);
    let half_rtt = RTT / 2;

    let mut client = default_client();
    let server_params = ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE * 2);
    let mut server = new_server(server_params);
    let mut now = connect_rtt_idle(&mut client, &mut server, RTT);

    // First packet from the server: the client acknowledges at once, as more
    // than an RTT has passed since its previous acknowledgment.
    let first = send_something(&mut server, now);
    now += half_rtt;
    let immediate = client.process(Some(&first), now);
    assert!(immediate.as_dgram_ref().is_some());

    // Second packet: the client now arms its delayed-ACK timer instead.
    let second = send_something(&mut server, now);
    now += half_rtt;
    let delay = client.process(Some(&second), now).callback();
    assert_eq!(delay, RTT / 2);

    assert_eq!(server.stats().frame_tx.ack_frequency, 1);
}
+ let mut client = new_client(ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE)); + let mut server = default_server(); + let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), true, now) + .unwrap(); + + let client1 = send_something(&mut client, now); + assertions::assert_v4_path(&client1, true); // Contains PATH_CHALLENGE. + let client2 = send_something(&mut client, now); + assertions::assert_v4_path(&client2, false); // Doesn't. Is dropped. + now += DEFAULT_RTT / 2; + server.process_input(&client1, now); + + let stream = client.stream_create(StreamType::UniDi).unwrap(); + let now = increase_cwnd(&mut client, &mut server, stream, now); + let now = increase_cwnd(&mut client, &mut server, stream, now); + let now = increase_cwnd(&mut client, &mut server, stream, now); + + // Now lose a packet and force the client to update + let (mut pkts, mut now) = fill_cwnd(&mut client, stream, now); + pkts.remove(0); + now += DEFAULT_RTT / 2; + let ack = ack_bytes(&mut server, stream, pkts, now); + + // After noticing this new loss, the client sends ACK_FREQUENCY. + // It has sent a few before (as we dropped `client2`), so ignore those. + let ad_before = client.stats().frame_tx.ack_frequency; + let af = client.process(Some(&ack), now).dgram(); + assert!(af.is_some()); + assert_eq!(client.stats().frame_tx.ack_frequency, ad_before + 1); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs new file mode 100644 index 0000000000..b3467ea67c --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs @@ -0,0 +1,429 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
#[test]
/// The amount sent after the handshake matches the initial congestion window.
fn cc_slow_start() {
    let mut client = default_client();
    let mut server = default_server();
    let start = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);

    // Push as much data as the congestion controller will allow.
    let stream_id = client.stream_create(StreamType::UniDi).unwrap();
    let sent = fill_cwnd(&mut client, stream_id, start).0;
    assert_full_cwnd(&sent, POST_HANDSHAKE_CWND);
    assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT);
}
#[test]
/// Acknowledging packets that were sent before the recovery period began
/// leaves the congestion window untouched.
fn cc_cong_avoidance_recovery_period_unchanged() {
    let mut client = default_client();
    let mut server = default_server();
    let start = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);

    // Create stream 0
    let stream_id = client.stream_create(StreamType::BiDi).unwrap();
    assert_eq!(stream_id, 0);

    // Fill the congestion window with data.
    let (mut first_batch, now) = fill_cwnd(&mut client, stream_id, start);
    assert_full_cwnd(&first_batch, POST_HANDSHAKE_CWND);

    // Lose packet 0; its absence, once noticed, should put the client into
    // the congestion avoidance recovery period (CARP).
    first_batch.remove(0);

    // Hold back everything after the first five remaining packets.
    let second_batch = first_batch.split_off(5);

    // The server acknowledges the first batch and the client processes that.
    let ack1 = ack_bytes(&mut server, stream_id, first_batch, now);
    client.process_input(&ack1, now);
    let cwnd_in_recovery = cwnd(&client);

    // The second batch is acknowledged too, but all of those packets were
    // sent before the recovery period ended.
    let ack2 = ack_bytes(&mut server, stream_id, second_batch, now);
    client.process_input(&ack2, now);

    // So the congestion window must be the same as it was.
    assert_eq!(cwnd_in_recovery, cwnd(&client));
}
#[test]
/// Verify that CC moves out of recovery period when packet sent after start
/// of recovery period is acked.
///
/// Once out of recovery, the test checks over five rounds that each fully
/// acknowledged congestion window grows the window by exactly
/// `MAX_DATAGRAM_SIZE`.
fn cc_cong_avoidance_recovery_period_to_cong_avoidance() {
    let mut client = default_client();
    let mut server = default_server();
    let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);

    // Create stream 0
    let stream_id = client.stream_create(StreamType::BiDi).unwrap();
    assert_eq!(stream_id, 0);

    // Buffer up lot of data and generate packets
    let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream_id, now);

    // Drop 0th packet. When acked, this should put client into CARP
    // (the congestion avoidance recovery period).
    c_tx_dgrams.remove(0);

    // Server: Receive and generate ack
    now += DEFAULT_RTT / 2;
    let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);

    // Client: Process ack
    now += DEFAULT_RTT / 2;
    client.process_input(&s_ack, now);

    // Should be in CARP now.
    now += DEFAULT_RTT / 2;
    qinfo!("moving to congestion avoidance {}", cwnd(&client));

    // Now make sure that we increase congestion window according to the
    // accurate byte counting version of congestion avoidance.
    // Check over several increases to be sure.
    let mut expected_cwnd = cwnd(&client);
    // Fill cwnd.
    let (mut c_tx_dgrams, next_now) = fill_cwnd(&mut client, stream_id, now);
    now = next_now;
    for i in 0..5 {
        qinfo!("iteration {}", i);

        // Everything sent in the previous round must add up to exactly one window.
        let c_tx_size: usize = c_tx_dgrams.iter().map(|d| d.len()).sum();
        qinfo!(
            "client sending {} bytes into cwnd of {}",
            c_tx_size,
            cwnd(&client)
        );
        assert_eq!(c_tx_size, expected_cwnd);

        // As ACKs arrive, the client keeps refilling the congestion window;
        // every packet it sends during this cycle is collected in
        // `next_c_tx_dgrams` for the next iteration.
        let mut next_c_tx_dgrams: Vec<Datagram> = Vec::new();

        // Until we process all the packets, the congestion window remains the same.
        // Note that we need the client to process ACK frames in stages, so split the
        // datagrams into two, ensuring that we allow for an ACK for each batch.
        let most = c_tx_dgrams.len() - usize::try_from(DEFAULT_ACK_PACKET_TOLERANCE).unwrap() - 1;
        let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams.drain(..most), now);
        assert_eq!(cwnd(&client), expected_cwnd);
        client.process_input(&s_ack, now);
        // make sure to fill cwnd again.
        let (mut new_pkts, next_now) = fill_cwnd(&mut client, stream_id, now);
        now = next_now;
        next_c_tx_dgrams.append(&mut new_pkts);

        // Second stage: acknowledge whatever is left of this round's packets.
        let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
        assert_eq!(cwnd(&client), expected_cwnd);
        client.process_input(&s_ack, now);
        // make sure to fill cwnd again.
        let (mut new_pkts, next_now) = fill_cwnd(&mut client, stream_id, now);
        now = next_now;
        next_c_tx_dgrams.append(&mut new_pkts);

        // A whole window has now been acknowledged: expect one datagram of growth.
        expected_cwnd += MAX_DATAGRAM_SIZE;
        assert_eq!(cwnd(&client), expected_cwnd);
        c_tx_dgrams = next_c_tx_dgrams;
    }
}
#[test]
/// Persistent congestion is still reached when an earlier flight was
/// acknowledged and only the later flight is lost.
fn cc_slow_start_to_persistent_congestion_some_acks() {
    let step = Duration::from_millis(100);

    let mut client = default_client();
    let mut server = default_server();
    let start = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);

    // Create stream 0
    let stream = client.stream_create(StreamType::BiDi).unwrap();

    // First flight: fill the congestion window.
    let (first_flight, mut now) = fill_cwnd(&mut client, stream, start);
    assert_full_cwnd(&first_flight, POST_HANDSHAKE_CWND);

    // The server receives the flight and acknowledges it ...
    now += step;
    let s_ack = ack_bytes(&mut server, stream, first_flight, now);

    // ... and the client gets that acknowledgment.
    now += step;
    client.process_input(&s_ack, now);

    // Second flight: these bytes are never delivered.
    let (_, sent_at) = fill_cwnd(&mut client, stream, now);
    now = sent_at + step;

    induce_persistent_congestion(&mut client, &mut server, stream, now);
}
/// ACK-only packets are not subject to congestion control: with its
/// congestion window full, the client can still acknowledge a packet
/// from the server.
#[test]
fn ack_are_not_cc() {
    let mut client = default_client();
    let mut server = default_server();
    let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);

    // Create a stream
    let stream = client.stream_create(StreamType::BiDi).unwrap();
    assert_eq!(stream, 0);

    // Buffer up lot of data and generate packets, so that cc window is filled.
    let (c_tx_dgrams, now) = fill_cwnd(&mut client, stream, now);
    assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);

    // The server hasn't received any of these packets yet, so it has nothing
    // to ACK. Have it send ack-eliciting packets of its own instead.
    qdebug!([server], "Sending ack-eliciting");
    let other_stream = server.stream_create(StreamType::BiDi).unwrap();
    assert_eq!(other_stream, 1);
    server.stream_send(other_stream, b"dropped").unwrap();
    let dropped_packet = server.process(None, now).dgram();
    assert!(dropped_packet.is_some()); // Now drop this one.

    // Now the server sends a packet that will force an ACK,
    // because the client will detect a gap.
    server.stream_send(other_stream, b"sent").unwrap();
    let ack_eliciting_packet = server.process(None, now).dgram();
    assert!(ack_eliciting_packet.is_some());

    // The client can ack the server packet even if the cc window is full.
    qdebug!([client], "Process ack-eliciting");
    let ack_pkt = client.process(ack_eliciting_packet.as_ref(), now).dgram();
    assert!(ack_pkt.is_some());
    qdebug!([server], "Handle ACK");
    // The server's count of received ACK frames must go up by exactly one.
    let prev_ack_count = server.stats().frame_rx.ack;
    server.process_input(&ack_pkt.unwrap(), now);
    assert_eq!(server.stats().frame_rx.ack, prev_ack_count + 1);
}
+ count += 1; + break; + } + Output::None => panic!(), + } + now += gap; + let dgram = client.process_output(now).dgram(); + assert!(dgram.is_some()); + count += 1; + } + let dgram = client.process_output(now).dgram(); + assert!(dgram.is_none()); + assert_eq!(count, cwnd_packets(POST_HANDSHAKE_CWND)); + let fin = client.process_output(now).callback(); + assert_ne!(fin, Duration::new(0, 0)); + assert_ne!(fin, gap); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs new file mode 100644 index 0000000000..f45e77e549 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs @@ -0,0 +1,210 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::time::Duration; + +use test_fixture::{self, datagram, now}; + +use super::{ + super::{Connection, Output, State}, + connect, connect_force_idle, default_client, default_server, send_something, +}; +use crate::{ + tparams::{self, TransportParameter}, + AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE, +}; + +fn assert_draining(c: &Connection, expected: &Error) { + assert!(c.state().closed()); + if let State::Draining { + error: ConnectionError::Transport(error), + .. 
+ } = c.state() + { + assert_eq!(error, expected); + } else { + panic!(); + } +} + +#[test] +fn connection_close() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let now = now(); + + client.close(now, 42, ""); + + let out = client.process(None, now); + + server.process_input(&out.dgram().unwrap(), now); + assert_draining(&server, &Error::PeerApplicationError(42)); +} + +#[test] +fn connection_close_with_long_reason_string() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let now = now(); + // Create a long string and use it as the close reason. + let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap(); + client.close(now, 42, long_reason); + + let out = client.process(None, now); + + server.process_input(&out.dgram().unwrap(), now); + assert_draining(&server, &Error::PeerApplicationError(42)); +} + +// During the handshake, an application close should be sanitized. +#[test] +fn early_application_close() { + let mut client = default_client(); + let mut server = default_server(); + + // One flight each. + let dgram = client.process(None, now()).dgram(); + assert!(dgram.is_some()); + let dgram = server.process(dgram.as_ref(), now()).dgram(); + assert!(dgram.is_some()); + + server.close(now(), 77, String::new()); + assert!(server.state().closed()); + let dgram = server.process(None, now()).dgram(); + assert!(dgram.is_some()); + + client.process_input(&dgram.unwrap(), now()); + assert_draining(&client, &Error::PeerError(ERROR_APPLICATION_CLOSE)); +} + +#[test] +fn bad_tls_version() { + let mut client = default_client(); + // Do a bad, bad thing. 
/// Check how the loss recovery timer and the closing timer interact:
/// closing the connection must replace a pending loss recovery timeout.
#[test]
fn closing_timers_interation() {
    let mut client = default_client();
    let mut server = default_server();
    connect(&mut client, &mut server);

    let mut now = now();

    // Send two packets and deliver only the second, so that time-based loss
    // recovery arms its timer for the first.
    let _lost = send_something(&mut client, now);
    let delivered = send_something(&mut client, now);
    let ack = server.process(Some(&delivered), now).dgram();
    assert!(ack.is_some()); // This is an ACK.

    // With the ACK processed, the client is waiting on the loss recovery timer.
    let loss_timer = client.process(ack.as_ref(), now).callback();
    assert_ne!(loss_timer, Duration::from_secs(0));
    now += loss_timer;

    // Close the connection at the moment the timer would have fired.
    client.close(now, 0, "");
    let client_close = client.process(None, now).dgram();
    assert!(client_close.is_some());

    // The reported wait must be the end of the closing period, not a
    // zero-duration wait driven by the (now defunct) loss recovery timer.
    let closing_timer = client.process(None, now).callback();
    assert_ne!(closing_timer, Duration::from_secs(0));
}
+ let p1 = send_something(&mut client, now()); + + // Close the connection. + client.close(now(), APP_ERROR, ""); + let client_close = client.process(None, now()).dgram(); + assert!(client_close.is_some()); + let client_close_timer = client.process(None, now()).callback(); + assert_ne!(client_close_timer, Duration::from_secs(0)); + + // The client will spit out the same packet in response to anything it receives. + let p3 = send_something(&mut server, now()); + let client_close2 = client.process(Some(&p3), now()).dgram(); + assert_eq!( + client_close.as_ref().unwrap().len(), + client_close2.as_ref().unwrap().len() + ); + + // After this time, the client should transition to closed. + let end = client.process(None, now() + client_close_timer); + assert_eq!(end, Output::None); + assert_eq!( + *client.state(), + State::Closed(ConnectionError::Application(APP_ERROR)) + ); + + // When the server receives the close, it too should generate CONNECTION_CLOSE. + let server_close = server.process(client_close.as_ref(), now()).dgram(); + assert!(server.state().closed()); + assert!(server_close.is_some()); + // .. but it ignores any further close packets. + let server_close_timer = server.process(client_close2.as_ref(), now()).callback(); + assert_ne!(server_close_timer, Duration::from_secs(0)); + // Even a legitimate packet without a close in it. + let server_close_timer2 = server.process(Some(&p1), now()).callback(); + assert_eq!(server_close_timer, server_close_timer2); + + let end = server.process(None, now() + server_close_timer); + assert_eq!(end, Output::None); + assert_eq!( + *server.state(), + State::Closed(ConnectionError::Transport(Error::PeerApplicationError( + APP_ERROR + ))) + ); +} + +/// Test that a client can handle a stateless reset correctly. 
/// A client that receives a stateless reset ends up draining with
/// `Error::StatelessReset`.
#[test]
fn stateless_reset_client() {
    let mut client = default_client();
    let mut server = default_server();

    // Have the server advertise a known stateless reset token.
    let token = TransportParameter::Bytes(vec![77; 16]);
    server
        .set_local_tparam(tparams::STATELESS_RESET_TOKEN, token)
        .unwrap();
    connect_force_idle(&mut client, &mut server);

    // A datagram made up entirely of the token byte, longer than the token.
    let reset = datagram(vec![77; 21]);
    client.process_input(&reset, now());
    assert_draining(&client, &Error::StatelessReset);
}
+ +use std::{cell::RefCell, convert::TryFrom, rc::Rc}; + +use neqo_common::event::Provider; +use test_fixture::now; + +use super::{ + assert_error, connect_force_idle, default_client, default_server, new_client, new_server, + AT_LEAST_PTO, +}; +use crate::{ + events::{ConnectionEvent, OutgoingDatagramOutcome}, + frame::FRAME_TYPE_DATAGRAM, + packet::PacketBuilder, + quic_datagrams::MAX_QUIC_DATAGRAM, + send_stream::{RetransmissionPriority, TransmissionPriority}, + Connection, ConnectionError, ConnectionParameters, Error, StreamType, +}; + +const DATAGRAM_LEN_MTU: u64 = 1310; +const DATA_MTU: &[u8] = &[1; 1310]; +const DATA_BIGGER_THAN_MTU: &[u8] = &[0; 2620]; +const DATAGRAM_LEN_SMALLER_THAN_MTU: u64 = 1200; +const DATA_SMALLER_THAN_MTU: &[u8] = &[0; 1200]; +const DATA_SMALLER_THAN_MTU_2: &[u8] = &[0; 600]; +const OUTGOING_QUEUE: usize = 2; + +struct InsertDatagram<'a> { + data: &'a [u8], +} + +impl crate::connection::test_internal::FrameWriter for InsertDatagram<'_> { + fn write_frames(&mut self, builder: &mut PacketBuilder) { + builder.encode_varint(FRAME_TYPE_DATAGRAM); + builder.encode(self.data); + } +} + +#[test] +fn datagram_disabled_both() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + assert_eq!(client.max_datagram_size(), Err(Error::NotAvailable)); + assert_eq!(server.max_datagram_size(), Err(Error::NotAvailable)); + assert_eq!( + client.send_datagram(DATA_SMALLER_THAN_MTU, None), + Err(Error::TooMuchData) + ); + assert_eq!(server.stats().frame_tx.datagram, 0); + assert_eq!( + server.send_datagram(DATA_SMALLER_THAN_MTU, None), + Err(Error::TooMuchData) + ); + assert_eq!(server.stats().frame_tx.datagram, 0); +} + +#[test] +fn datagram_enabled_on_client() { + let mut client = + new_client(ConnectionParameters::default().datagram_size(DATAGRAM_LEN_SMALLER_THAN_MTU)); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + 
assert_eq!(client.max_datagram_size(), Err(Error::NotAvailable)); + assert_eq!( + server.max_datagram_size(), + Ok(DATAGRAM_LEN_SMALLER_THAN_MTU) + ); + assert_eq!( + client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), + Err(Error::TooMuchData) + ); + let dgram_sent = server.stats().frame_tx.datagram; + assert_eq!(server.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(())); + let out = server.process_output(now()).dgram().unwrap(); + assert_eq!(server.stats().frame_tx.datagram, dgram_sent + 1); + + client.process_input(&out, now()); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU + )); +} + +#[test] +fn datagram_enabled_on_server() { + let mut client = default_client(); + let mut server = + new_server(ConnectionParameters::default().datagram_size(DATAGRAM_LEN_SMALLER_THAN_MTU)); + connect_force_idle(&mut client, &mut server); + + assert_eq!( + client.max_datagram_size(), + Ok(DATAGRAM_LEN_SMALLER_THAN_MTU) + ); + assert_eq!(server.max_datagram_size(), Err(Error::NotAvailable)); + assert_eq!( + server.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), + Err(Error::TooMuchData) + ); + let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(())); + let out = client.process_output(now()).dgram().unwrap(); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); + + server.process_input(&out, now()); + assert!(matches!( + server.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU + )); +} + +fn connect_datagram() -> (Connection, Connection) { + let mut client = new_client( + ConnectionParameters::default() + .datagram_size(MAX_QUIC_DATAGRAM) + .outgoing_datagram_queue(OUTGOING_QUEUE), + ); + let mut server = new_server(ConnectionParameters::default().datagram_size(MAX_QUIC_DATAGRAM)); + connect_force_idle(&mut client, &mut server); + (client, server) +} + +#[test] +fn mtu_limit() { + let 
(client, server) = connect_datagram(); + + assert_eq!(client.max_datagram_size(), Ok(DATAGRAM_LEN_MTU)); + assert_eq!(server.max_datagram_size(), Ok(DATAGRAM_LEN_MTU)); +} + +#[test] +fn limit_data_size() { + let (mut client, mut server) = connect_datagram(); + + assert!(u64::try_from(DATA_BIGGER_THAN_MTU.len()).unwrap() > DATAGRAM_LEN_MTU); + // Datagram can be queued because they are smaller than allowed by the peer, + // but they cannot be sent. + assert_eq!(server.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(())); + + let dgram_dropped_s = server.stats().datagram_tx.dropped_too_big; + let dgram_sent_s = server.stats().frame_tx.datagram; + assert!(server.process_output(now()).dgram().is_none()); + assert_eq!( + server.stats().datagram_tx.dropped_too_big, + dgram_dropped_s + 1 + ); + assert_eq!(server.stats().frame_tx.datagram, dgram_sent_s); + assert!(matches!( + server.next_event().unwrap(), + ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedTooBig + )); + + // The same test for the client side. + assert_eq!(client.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(())); + let dgram_sent_c = client.stats().frame_tx.datagram; + assert!(client.process_output(now()).dgram().is_none()); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent_c); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedTooBig + )); +} + +#[test] +fn after_dgram_dropped_continue_writing_frames() { + let (mut client, _) = connect_datagram(); + + assert!(u64::try_from(DATA_BIGGER_THAN_MTU.len()).unwrap() > DATAGRAM_LEN_MTU); + // Datagram can be queued because they are smaller than allowed by the peer, + // but they cannot be sent. 
+ assert_eq!(client.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(())); + assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(2)), Ok(())); + + let datagram_dropped = |e| { + matches!( + e, + ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedTooBig) + }; + + let dgram_dropped_c = client.stats().datagram_tx.dropped_too_big; + let dgram_sent_c = client.stats().frame_tx.datagram; + + assert!(client.process_output(now()).dgram().is_some()); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent_c + 1); + assert_eq!( + client.stats().datagram_tx.dropped_too_big, + dgram_dropped_c + 1 + ); + assert!(client.events().any(datagram_dropped)); +} + +#[test] +fn datagram_acked() { + let (mut client, mut server) = connect_datagram(); + + let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(())); + let out = client.process_output(now()).dgram(); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); + + let dgram_received = server.stats().frame_rx.datagram; + server.process_input(&out.unwrap(), now()); + assert_eq!(server.stats().frame_rx.datagram, dgram_received + 1); + let now = now() + AT_LEAST_PTO; + // Ack should be sent + let ack_sent = server.stats().frame_tx.ack; + let out = server.process_output(now).dgram(); + assert_eq!(server.stats().frame_tx.ack, ack_sent + 1); + + assert!(matches!( + server.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU + )); + + client.process_input(&out.unwrap(), now); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::Acked + )); +} + +fn send_packet_and_get_server_event( + client: &mut Connection, + server: &mut Connection, +) -> ConnectionEvent { + let out = client.process_output(now()).dgram(); + server.process_input(&out.unwrap(), now()); 
+ let mut events: Vec<_> = server + .events() + .filter_map(|evt| match evt { + ConnectionEvent::RecvStreamReadable { .. } | ConnectionEvent::Datagram { .. } => { + Some(evt) + } + _ => None, + }) + .collect(); + // We should only get one event - either RecvStreamReadable or Datagram. + assert_eq!(events.len(), 1); + events.remove(0) +} + +/// Write a datagram that is big enough to fill a packet, but then see that +/// normal priority stream data is sent first. +#[test] +fn datagram_after_stream_data() { + let (mut client, mut server) = connect_datagram(); + + // Write a datagram first. + let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_MTU, Some(1)), Ok(())); + + // Create a stream with normal priority and send some data. + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[6; 1200]).unwrap(); + + assert!( + matches!(send_packet_and_get_server_event(&mut client, &mut server), ConnectionEvent::RecvStreamReadable { stream_id: s } if s == stream_id) + ); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent); + + if let ConnectionEvent::Datagram(data) = + &send_packet_and_get_server_event(&mut client, &mut server) + { + assert_eq!(data, DATA_MTU); + } else { + panic!(); + } + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); +} + +#[test] +fn datagram_before_stream_data() { + let (mut client, mut server) = connect_datagram(); + + // Create a stream with low priority and send some data before datagram. + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client + .stream_priority( + stream_id, + TransmissionPriority::Low, + RetransmissionPriority::default(), + ) + .unwrap(); + client.stream_send(stream_id, &[6; 1200]).unwrap(); + + // Write a datagram. 
+ let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_MTU, Some(1)), Ok(())); + + if let ConnectionEvent::Datagram(data) = + &send_packet_and_get_server_event(&mut client, &mut server) + { + assert_eq!(data, DATA_MTU); + } else { + panic!(); + } + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); + + assert!( + matches!(send_packet_and_get_server_event(&mut client, &mut server), ConnectionEvent::RecvStreamReadable { stream_id: s } if s == stream_id) + ); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); +} + +#[test] +fn datagram_lost() { + let (mut client, _) = connect_datagram(); + + let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(())); + let _out = client.process_output(now()).dgram(); // This packet will be lost. + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); + + // Wait for PTO + let now = now() + AT_LEAST_PTO; + let dgram_sent2 = client.stats().frame_tx.datagram; + let pings_sent = client.stats().frame_tx.ping; + let dgram_lost = client.stats().datagram_tx.lost; + let out = client.process_output(now).dgram(); + assert!(out.is_some()); // PING probing + // Datagram is not sent again. 
+ assert_eq!(client.stats().frame_tx.ping, pings_sent + 1); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent2); + assert_eq!(client.stats().datagram_tx.lost, dgram_lost + 1); + + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::Lost + )); +} + +#[test] +fn datagram_sent_once() { + let (mut client, _) = connect_datagram(); + + let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(())); + let _out = client.process_output(now()).dgram(); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); + + // Call process_output again should not send any new Datagram. + assert!(client.process_output(now()).dgram().is_none()); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); +} + +#[test] +fn dgram_no_allowed() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + server.test_frame_writer = Some(Box::new(InsertDatagram { data: DATA_MTU })); + let out = server.process_output(now()).dgram().unwrap(); + server.test_frame_writer = None; + + client.process_input(&out, now()); + + assert_error( + &client, + &ConnectionError::Transport(Error::ProtocolViolation), + ); +} + +#[test] +#[allow(clippy::assertions_on_constants)] // this is a static assert, thanks +fn dgram_too_big() { + let mut client = + new_client(ConnectionParameters::default().datagram_size(DATAGRAM_LEN_SMALLER_THAN_MTU)); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + assert!(DATAGRAM_LEN_MTU > DATAGRAM_LEN_SMALLER_THAN_MTU); + server.test_frame_writer = Some(Box::new(InsertDatagram { data: DATA_MTU })); + let out = server.process_output(now()).dgram().unwrap(); + server.test_frame_writer = None; + + client.process_input(&out, now()); + + assert_error( + &client, + 
&ConnectionError::Transport(Error::ProtocolViolation), + ); +} + +#[test] +fn outgoing_datagram_queue_full() { + let (mut client, mut server) = connect_datagram(); + + let dgram_sent = client.stats().frame_tx.datagram; + assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(())); + assert_eq!( + client.send_datagram(DATA_SMALLER_THAN_MTU_2, Some(2)), + Ok(()) + ); + + // The outgoing datagram queue limit is 2, therefore the datagram with id 1 + // will be dropped after adding one more datagram. + let dgram_dropped = client.stats().datagram_tx.dropped_queue_full; + assert_eq!(client.send_datagram(DATA_MTU, Some(3)), Ok(())); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedQueueFull + )); + assert_eq!( + client.stats().datagram_tx.dropped_queue_full, + dgram_dropped + 1 + ); + + // Send DATA_SMALLER_THAN_MTU_2 datagram + let out = client.process_output(now()).dgram(); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1); + server.process_input(&out.unwrap(), now()); + assert!(matches!( + server.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU_2 + )); + + // Send DATA_SMALLER_THAN_MTU_2 datagram + let dgram_sent2 = client.stats().frame_tx.datagram; + let out = client.process_output(now()).dgram(); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent2 + 1); + server.process_input(&out.unwrap(), now()); + assert!(matches!( + server.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == DATA_MTU + )); +} + +fn send_datagram(sender: &mut Connection, receiver: &mut Connection, data: &[u8]) { + let dgram_sent = sender.stats().frame_tx.datagram; + assert_eq!(sender.send_datagram(data, Some(1)), Ok(())); + let out = sender.process_output(now()).dgram().unwrap(); + assert_eq!(sender.stats().frame_tx.datagram, dgram_sent + 1); + + let dgram_received = 
receiver.stats().frame_rx.datagram; + receiver.process_input(&out, now()); + assert_eq!(receiver.stats().frame_rx.datagram, dgram_received + 1); +} + +#[test] +fn multiple_datagram_events() { + const DATA_SIZE: usize = 1200; + const MAX_QUEUE: usize = 3; + const FIRST_DATAGRAM: &[u8] = &[0; DATA_SIZE]; + const SECOND_DATAGRAM: &[u8] = &[1; DATA_SIZE]; + const THIRD_DATAGRAM: &[u8] = &[2; DATA_SIZE]; + const FOURTH_DATAGRAM: &[u8] = &[3; DATA_SIZE]; + + let mut client = new_client( + ConnectionParameters::default() + .datagram_size(u64::try_from(DATA_SIZE).unwrap()) + .incoming_datagram_queue(MAX_QUEUE), + ); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + send_datagram(&mut server, &mut client, FIRST_DATAGRAM); + send_datagram(&mut server, &mut client, SECOND_DATAGRAM); + send_datagram(&mut server, &mut client, THIRD_DATAGRAM); + + let mut datagrams = client.events().filter_map(|evt| { + if let ConnectionEvent::Datagram(d) = evt { + Some(d) + } else { + None + } + }); + assert_eq!(datagrams.next().unwrap(), FIRST_DATAGRAM); + assert_eq!(datagrams.next().unwrap(), SECOND_DATAGRAM); + assert_eq!(datagrams.next().unwrap(), THIRD_DATAGRAM); + assert!(datagrams.next().is_none()); + + // New events can be queued. 
+ send_datagram(&mut server, &mut client, FOURTH_DATAGRAM); + let mut datagrams = client.events().filter_map(|evt| { + if let ConnectionEvent::Datagram(d) = evt { + Some(d) + } else { + None + } + }); + assert_eq!(datagrams.next().unwrap(), FOURTH_DATAGRAM); + assert!(datagrams.next().is_none()); +} + +#[test] +fn too_many_datagram_events() { + const DATA_SIZE: usize = 1200; + const MAX_QUEUE: usize = 2; + const FIRST_DATAGRAM: &[u8] = &[0; DATA_SIZE]; + const SECOND_DATAGRAM: &[u8] = &[1; DATA_SIZE]; + const THIRD_DATAGRAM: &[u8] = &[2; DATA_SIZE]; + const FOURTH_DATAGRAM: &[u8] = &[3; DATA_SIZE]; + + let mut client = new_client( + ConnectionParameters::default() + .datagram_size(u64::try_from(DATA_SIZE).unwrap()) + .incoming_datagram_queue(MAX_QUEUE), + ); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + send_datagram(&mut server, &mut client, FIRST_DATAGRAM); + send_datagram(&mut server, &mut client, SECOND_DATAGRAM); + send_datagram(&mut server, &mut client, THIRD_DATAGRAM); + + // Datagram with FIRST_DATAGRAM data will be dropped. + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == SECOND_DATAGRAM + )); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::IncomingDatagramDropped + )); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == THIRD_DATAGRAM + )); + assert!(client.next_event().is_none()); + assert_eq!(client.stats().incoming_datagram_dropped, 1); + + // New events can be queued. 
+ send_datagram(&mut server, &mut client, FOURTH_DATAGRAM); + assert!(matches!( + client.next_event().unwrap(), + ConnectionEvent::Datagram(data) if data == FOURTH_DATAGRAM + )); + assert!(client.next_event().is_none()); + assert_eq!(client.stats().incoming_datagram_dropped, 1); +} + +#[test] +fn multiple_quic_datagrams_in_one_packet() { + let (mut client, mut server) = connect_datagram(); + + let dgram_sent = client.stats().frame_tx.datagram; + // Enqueue 2 datagrams that can fit in a single packet. + assert_eq!( + client.send_datagram(DATA_SMALLER_THAN_MTU_2, Some(1)), + Ok(()) + ); + assert_eq!( + client.send_datagram(DATA_SMALLER_THAN_MTU_2, Some(2)), + Ok(()) + ); + + let out = client.process_output(now()).dgram(); + assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 2); + server.process_input(&out.unwrap(), now()); + let datagram = |e: &_| matches!(e, ConnectionEvent::Datagram(..)); + assert_eq!(server.events().filter(datagram).count(), 2); +} + +/// Datagrams that are close to the capacity of the packet need special +/// handling. They need to use the packet-filling frame type and +/// they cannot allow other frames to follow. +#[test] +fn datagram_fill() { + struct PanickingFrameWriter {} + impl crate::connection::test_internal::FrameWriter for PanickingFrameWriter { + fn write_frames(&mut self, builder: &mut PacketBuilder) { + panic!( + "builder invoked with {} bytes remaining", + builder.remaining() + ); + } + } + struct TrackingFrameWriter { + called: Rc<RefCell<bool>>, + } + impl crate::connection::test_internal::FrameWriter for TrackingFrameWriter { + fn write_frames(&mut self, builder: &mut PacketBuilder) { + assert_eq!(builder.remaining(), 2); + *self.called.borrow_mut() = true; + } + } + + let (mut client, mut server) = connect_datagram(); + + // Work out how much space we have for a datagram. 
+ let space = { + let p = client.paths.primary(); + let path = p.borrow(); + // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, + // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. + path.mtu() - path.remote_cid().len() - 19 + }; + assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large. + + // This should not be called. + client.test_frame_writer = Some(Box::new(PanickingFrameWriter {})); + + let buf = vec![9; space]; + // This will completely fill available space. + send_datagram(&mut client, &mut server, &buf); + // This will leave 1 byte free, but more frames won't be added in this space. + send_datagram(&mut client, &mut server, &buf[..buf.len() - 1]); + // This will leave 2 bytes free, which is enough space for a length field, + // but not enough space for another frame after that. + send_datagram(&mut client, &mut server, &buf[..buf.len() - 2]); + // Three bytes free will be space enough for a length frame, but not enough + // space left over for another frame (we need 2 bytes). + send_datagram(&mut client, &mut server, &buf[..buf.len() - 3]); + + // Four bytes free is enough space for another frame. + let called = Rc::new(RefCell::new(false)); + client.test_frame_writer = Some(Box::new(TrackingFrameWriter { + called: Rc::clone(&called), + })); + send_datagram(&mut client, &mut server, &buf[..buf.len() - 4]); + assert!(*called.borrow()); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs b/third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs new file mode 100644 index 0000000000..5425e1a16e --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +#![cfg_attr(feature = "deny-warnings", deny(warnings))] +#![warn(clippy::pedantic)] +#![cfg(feature = "fuzzing")] + +use neqo_crypto::FIXED_TAG_FUZZING; +use test_fixture::now; + +use super::{connect_force_idle, default_client, default_server}; +use crate::StreamType; + +#[test] +fn no_encryption() { + const DATA_CLIENT: &[u8] = &[2; 40]; + const DATA_SERVER: &[u8] = &[3; 50]; + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + + client.stream_send(stream_id, DATA_CLIENT).unwrap(); + let client_pkt = client.process_output(now()).dgram().unwrap(); + assert!(client_pkt[..client_pkt.len() - FIXED_TAG_FUZZING.len()].ends_with(DATA_CLIENT)); + + server.process_input(&client_pkt, now()); + let mut buf = vec![0; 100]; + let (len, _) = server.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(len, DATA_CLIENT.len()); + assert_eq!(&buf[..len], DATA_CLIENT); + server.stream_send(stream_id, DATA_SERVER).unwrap(); + let server_pkt = server.process_output(now()).dgram().unwrap(); + assert!(server_pkt[..server_pkt.len() - FIXED_TAG_FUZZING.len()].ends_with(DATA_SERVER)); + + client.process_input(&server_pkt, now()); + let (len, _) = client.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(len, DATA_SERVER.len()); + assert_eq!(&buf[..len], DATA_SERVER); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs new file mode 100644 index 0000000000..93385ac1bc --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs @@ -0,0 +1,1137 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or 
http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + convert::TryFrom, + mem, + net::{IpAddr, Ipv6Addr, SocketAddr}, + rc::Rc, + time::Duration, +}; + +use neqo_common::{event::Provider, qdebug, Datagram}; +use neqo_crypto::{ + constants::TLS_CHACHA20_POLY1305_SHA256, generate_ech_keys, AuthenticationStatus, +}; +use test_fixture::{ + self, addr, assertions, assertions::assert_coalesced_0rtt, datagram, fixture_init, now, + split_datagram, +}; + +use super::{ + super::{Connection, Output, State}, + assert_error, connect, connect_force_idle, connect_with_rtt, default_client, default_server, + get_tokens, handshake, maybe_authenticate, resumed_server, send_something, + CountingConnectionIdGenerator, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA, +}; +use crate::{ + connection::AddressValidation, + events::ConnectionEvent, + path::PATH_MTU_V6, + server::ValidateAddress, + tparams::{TransportParameter, MIN_ACK_DELAY}, + tracking::DEFAULT_ACK_DELAY, + ConnectionError, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version, +}; + +const ECH_CONFIG_ID: u8 = 7; +const ECH_PUBLIC_NAME: &str = "public.example"; + +#[test] +fn full_handshake() { + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + + qdebug!("---- client: cert verification"); + let out = client.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + + let out = server.process(out.as_dgram_ref(), now()); + 
assert!(out.as_dgram_ref().is_none()); + + assert!(maybe_authenticate(&mut client)); + + qdebug!("---- client: SH..FIN -> FIN"); + let out = client.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(*client.state(), State::Connected); + + qdebug!("---- server: FIN -> ACKS"); + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(*server.state(), State::Confirmed); + + qdebug!("---- client: ACKS -> 0"); + let out = client.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_none()); + assert_eq!(*client.state(), State::Confirmed); +} + +#[test] +fn handshake_failed_authentication() { + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + + qdebug!("---- client: cert verification"); + let out = client.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_none()); + + let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded); + assert!(client.events().any(authentication_needed)); + qdebug!("---- client: Alert(certificate_revoked)"); + client.authenticated(AuthenticationStatus::CertRevoked, now()); + + qdebug!("---- client: -> Alert(certificate_revoked)"); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + + qdebug!("---- server: Alert(certificate_revoked)"); + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44))); + assert_error(&server, &ConnectionError::Transport(Error::PeerError(300))); 
+} + +#[test] +fn no_alpn() { + fixture_init(); + let mut client = Connection::new_client( + "example.com", + &["bad-alpn"], + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + addr(), + addr(), + ConnectionParameters::default(), + now(), + ) + .unwrap(); + let mut server = default_server(); + + handshake(&mut client, &mut server, now(), Duration::new(0, 0)); + // TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE + // and the client never sees the server's rejection of its handshake. + // assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120))); + assert_error( + &server, + &ConnectionError::Transport(Error::CryptoAlert(120)), + ); +} + +#[test] +fn dup_server_flight1() { + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let out_to_rep = server.process(out.as_dgram_ref(), now()); + assert!(out_to_rep.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out_to_rep.as_dgram_ref()); + + qdebug!("---- client: cert verification"); + let out = client.process(Some(out_to_rep.as_dgram_ref().unwrap()), now()); + assert!(out.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_none()); + + assert!(maybe_authenticate(&mut client)); + + qdebug!("---- client: SH..FIN -> FIN"); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + + assert_eq!(3, client.stats().packets_rx); + assert_eq!(0, client.stats().dups_rx); + assert_eq!(1, client.stats().dropped_rx); + + qdebug!("---- Dup, ignored"); + let out = 
client.process(out_to_rep.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_none()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + + // Four packets total received, 1 of them is a dup and one has been dropped because Initial keys + // are dropped. Add 2 counts of the padding that the server adds to Initial packets. + assert_eq!(6, client.stats().packets_rx); + assert_eq!(1, client.stats().dups_rx); + assert_eq!(3, client.stats().dropped_rx); +} + +// Test that we split crypto data if they cannot fit into one packet. +// To test this we will use a long server certificate. +#[test] +fn crypto_frame_split() { + let mut client = default_client(); + + let mut server = Connection::new_server( + test_fixture::LONG_CERT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + ConnectionParameters::default(), + ) + .expect("create a server"); + + let client1 = client.process(None, now()); + assert!(client1.as_dgram_ref().is_some()); + + // The entire server flight doesn't fit in a single packet because the + // certificate is large, therefore the server will produce 2 packets. + let server1 = server.process(client1.as_dgram_ref(), now()); + assert!(server1.as_dgram_ref().is_some()); + let server2 = server.process(None, now()); + assert!(server2.as_dgram_ref().is_some()); + + let client2 = client.process(server1.as_dgram_ref(), now()); + // This is an ack. + assert!(client2.as_dgram_ref().is_some()); + // The client might have the certificate now, so we can't guarantee that + // this will work. + let auth1 = maybe_authenticate(&mut client); + assert_eq!(*client.state(), State::Handshaking); + + // let server process the ack for the first packet. + let server3 = server.process(client2.as_dgram_ref(), now()); + assert!(server3.as_dgram_ref().is_none()); + + // Consume the second packet from the server. + let client3 = client.process(server2.as_dgram_ref(), now()); + + // Check authentication. 
+ let auth2 = maybe_authenticate(&mut client); + assert!(auth1 ^ auth2); + // Now client has all data to finish handshake. + assert_eq!(*client.state(), State::Connected); + + let client4 = client.process(server3.as_dgram_ref(), now()); + // One of these will contain data depending on whether Authentication was completed + // after the first or second server packet. + assert!(client3.as_dgram_ref().is_some() ^ client4.as_dgram_ref().is_some()); + + mem::drop(server.process(client3.as_dgram_ref(), now())); + mem::drop(server.process(client4.as_dgram_ref(), now())); + + assert_eq!(*client.state(), State::Connected); + assert_eq!(*server.state(), State::Confirmed); +} + +/// Run a single ChaCha20-Poly1305 test and get a PTO. +#[test] +fn chacha20poly1305() { + let mut server = default_server(); + let mut client = Connection::new_client( + test_fixture::DEFAULT_SERVER_NAME, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())), + addr(), + addr(), + ConnectionParameters::default(), + now(), + ) + .expect("create a default client"); + client.set_ciphers(&[TLS_CHACHA20_POLY1305_SHA256]).unwrap(); + connect_force_idle(&mut client, &mut server); +} + +/// Test that a server can send 0.5 RTT application data. +#[test] +fn send_05rtt() { + let mut client = default_client(); + let mut server = default_server(); + + let c1 = client.process(None, now()).dgram(); + assert!(c1.is_some()); + let s1 = server.process(c1.as_ref(), now()).dgram().unwrap(); + assert_eq!(s1.len(), PATH_MTU_V6); + + // The server should accept writes at this point. + let s2 = send_something(&mut server, now()); + + // Complete the handshake at the client. + client.process_input(&s1, now()); + maybe_authenticate(&mut client); + assert_eq!(*client.state(), State::Connected); + + // The client should receive the 0.5-RTT data now. 
+ client.process_input(&s2, now()); + let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1]; + let stream_id = client + .events() + .find_map(|e| { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + Some(stream_id) + } else { + None + } + }) + .unwrap(); + let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(&buf[..l], DEFAULT_STREAM_DATA); + assert!(ended); +} + +/// Test that a client buffers 0.5-RTT data when it arrives early. +#[test] +fn reorder_05rtt() { + let mut client = default_client(); + let mut server = default_server(); + + let c1 = client.process(None, now()).dgram(); + assert!(c1.is_some()); + let s1 = server.process(c1.as_ref(), now()).dgram().unwrap(); + + // The server should accept writes at this point. + let s2 = send_something(&mut server, now()); + + // We can't use the standard facility to complete the handshake, so + // drive it as aggressively as possible. + // Deliver the 0.5-RTT datagram (`s2`) before the server's handshake + // datagram (`s1`); the client is expected to save it rather than drop it. + client.process_input(&s2, now()); + assert_eq!(client.stats().saved_datagrams, 1); + + // After processing the first packet, the client should go back and + // process the 0.5-RTT packet data, which should make data available. + client.process_input(&s1, now()); + // We can't use `maybe_authenticate` here as that consumes events.
+ client.authenticated(AuthenticationStatus::Ok, now()); + assert_eq!(*client.state(), State::Connected); + + let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1]; + let stream_id = client + .events() + .find_map(|e| { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + Some(stream_id) + } else { + None + } + }) + .unwrap(); + let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(&buf[..l], DEFAULT_STREAM_DATA); + assert!(ended); +} + +/// Test that a resuming client that sent 0-RTT saves 0.5-RTT data that arrives +/// before the server's handshake packets, and that processing the saved data +/// (and the old ACK it carries) later leaves the client's RTT estimate intact. +#[test] +fn reorder_05rtt_with_0rtt() { + const RTT: Duration = Duration::from_millis(100); + + let mut client = default_client(); + let mut server = default_server(); + let validation = AddressValidation::new(now(), ValidateAddress::NoToken).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + + // Include RTT in sending the ticket or the ticket age reported by the + // client is wrong, which causes the server to reject 0-RTT. + now += RTT / 2; + server.send_ticket(now, &[]).unwrap(); + let ticket = server.process_output(now).dgram().unwrap(); + now += RTT / 2; + client.process_input(&ticket, now); + + // Start over with a fresh client that resumes with the token, and a + // server created to match it. + let token = get_tokens(&mut client).pop().unwrap(); + let mut client = default_client(); + client.enable_resumption(now, token).unwrap(); + let mut server = resumed_server(&client); + + // Send ClientHello and some 0-RTT. + let c1 = send_something(&mut client, now); + assertions::assert_coalesced_0rtt(&c1[..]); + // Drop the 0-RTT from the coalesced datagram, so that the server + // acknowledges the next 0-RTT packet. + let (c1, _) = split_datagram(&c1); + let c2 = send_something(&mut client, now); + + // Handle the first packet and send 0.5-RTT in response. Drop the response. + now += RTT / 2; + mem::drop(server.process(Some(&c1), now).dgram().unwrap()); + // The gap in 0-RTT will result in this 0.5 RTT containing an ACK.
+ server.process_input(&c2, now); + let s2 = send_something(&mut server, now); + + // Deliver the 0.5 RTT datagram; the client is expected to save it. + now += RTT / 2; + client.process_input(&s2, now); + assert_eq!(client.stats().saved_datagrams, 1); + + // Now PTO at the client and cause the server to re-send handshake packets. + now += AT_LEAST_PTO; + let c3 = client.process(None, now).dgram(); + assert_coalesced_0rtt(c3.as_ref().unwrap()); + + now += RTT / 2; + let s3 = server.process(c3.as_ref(), now).dgram().unwrap(); + + // The client should be able to process the 0.5 RTT now. + // This should contain an ACK, so we are processing an ACK from the past. + now += RTT / 2; + client.process_input(&s3, now); + maybe_authenticate(&mut client); + let c4 = client.process(None, now).dgram(); + assert_eq!(*client.state(), State::Connected); + // The ACK from the past must leave the client's RTT estimate at exactly RTT. + assert_eq!(client.paths.rtt(), RTT); + + now += RTT / 2; + server.process_input(&c4.unwrap(), now); + assert_eq!(*server.state(), State::Confirmed); + // Don't check server RTT as it will be massively inflated by a + // poor initial estimate received when the server dropped the + // Initial packet number space. +} + +/// Test that a server that coalesces 0.5 RTT with handshake packets +/// doesn't cause the client to drop application data. +#[test] +fn coalesce_05rtt() { + const RTT: Duration = Duration::from_millis(100); + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + + // The first exchange doesn't offer a chance for the server to send. + // So drop the server flight and wait for the PTO. + let c1 = client.process(None, now).dgram(); + assert!(c1.is_some()); + now += RTT / 2; + let s1 = server.process(c1.as_ref(), now).dgram(); + assert!(s1.is_some()); + + // Drop the server flight. Then send some data.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap(); + assert!(server.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok()); + assert!(server.stream_close_send(stream_id).is_ok()); + + // Now after a PTO the client can send another packet. + // The server should then send its entire flight again, + // including the application data, which it sends in a 1-RTT packet. + now += AT_LEAST_PTO; + let c2 = client.process(None, now).dgram(); + assert!(c2.is_some()); + now += RTT / 2; + let s2 = server.process(c2.as_ref(), now).dgram(); + // Even though there is a 1-RTT packet at the end of the datagram, the + // flight should be padded to full size. + assert_eq!(s2.as_ref().unwrap().len(), PATH_MTU_V6); + + // The client should process the datagram. It can't process the 1-RTT + // packet until authentication completes though. So it saves it. + now += RTT / 2; + assert_eq!(client.stats().dropped_rx, 0); + mem::drop(client.process(s2.as_ref(), now).dgram()); + // This packet will contain an ACK, but we can ignore it. + assert_eq!(client.stats().dropped_rx, 0); + assert_eq!(client.stats().packets_rx, 3); + assert_eq!(client.stats().saved_datagrams, 1); + + // After (successful) authentication, the saved packet is processed: + // `packets_rx` goes up by one while `saved_datagrams` stays at 1. + maybe_authenticate(&mut client); + let c3 = client.process(None, now).dgram(); + assert!(c3.is_some()); + assert_eq!(client.stats().dropped_rx, 0); // No Initial padding. + assert_eq!(client.stats().packets_rx, 4); + assert_eq!(client.stats().saved_datagrams, 1); + assert_eq!(client.stats().frame_rx.padding, 1); // Padding uses frames. + + // Allow the handshake to complete. + now += RTT / 2; + let s3 = server.process(c3.as_ref(), now).dgram(); + assert!(s3.is_some()); + assert_eq!(*server.state(), State::Confirmed); + now += RTT / 2; + mem::drop(client.process(s3.as_ref(), now).dgram()); + assert_eq!(*client.state(), State::Confirmed); + + assert_eq!(client.stats().dropped_rx, 0); // No dropped packets.
+} + +/// Test that a client saves server Handshake packets that arrive before the +/// server Initial, processes them once the Initial arrives, and that handling +/// them late does not change the client's RTT estimate. +#[test] +fn reorder_handshake() { + const RTT: Duration = Duration::from_millis(100); + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + + let c1 = client.process(None, now).dgram(); + assert!(c1.is_some()); + + now += RTT / 2; + let s1 = server.process(c1.as_ref(), now).dgram(); + assert!(s1.is_some()); + + // Drop the Initial packet from this. + let (_, s_hs) = split_datagram(&s1.unwrap()); + assert!(s_hs.is_some()); + + // Pass just the handshake packet in and the client can't handle it yet. + // It can only send another Initial packet. + now += RTT / 2; + let dgram = client.process(s_hs.as_ref(), now).dgram(); + assertions::assert_initial(dgram.as_ref().unwrap(), false); + assert_eq!(client.stats().saved_datagrams, 1); + assert_eq!(client.stats().packets_rx, 1); + + // Get the server to try again. + // Though we currently allow the server to arm its PTO timer, use + // a second client Initial packet to cause it to send again. + now += AT_LEAST_PTO; + let c2 = client.process(None, now).dgram(); + now += RTT / 2; + let s2 = server.process(c2.as_ref(), now).dgram(); + assert!(s2.is_some()); + + let (s_init, s_hs) = split_datagram(&s2.unwrap()); + assert!(s_hs.is_some()); + + // Processing the Handshake packet first should save it. + now += RTT / 2; + client.process_input(&s_hs.unwrap(), now); + assert_eq!(client.stats().saved_datagrams, 2); + assert_eq!(client.stats().packets_rx, 2); + + // Now deliver the server Initial. + client.process_input(&s_init, now); + // Each saved packet should now be "received" again. + assert_eq!(client.stats().packets_rx, 7); + maybe_authenticate(&mut client); + let c3 = client.process(None, now).dgram(); + assert!(c3.is_some()); + + // Note that though packets were saved and processed very late, + // they don't cause the RTT to change.
+ now += RTT / 2; + let s3 = server.process(c3.as_ref(), now).dgram(); + assert_eq!(*server.state(), State::Confirmed); + // Don't check server RTT estimate as it will be inflated due to + // it making a guess based on retransmissions when it dropped + // the Initial packet number space. + + now += RTT / 2; + client.process_input(&s3.unwrap(), now); + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(client.paths.rtt(), RTT); +} + +/// Test that a server saves packets carrying stream data that arrive before +/// the client's handshake flight (`c2`), processes them once that flight +/// arrives, and ends up with all stream data readable and an RTT of `RTT`. +#[test] +fn reorder_1rtt() { + const RTT: Duration = Duration::from_millis(100); + const PACKETS: usize = 4; // Many, but not enough to overflow cwnd. + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + + let c1 = client.process(None, now).dgram(); + assert!(c1.is_some()); + + now += RTT / 2; + let s1 = server.process(c1.as_ref(), now).dgram(); + assert!(s1.is_some()); + + now += RTT / 2; + client.process_input(&s1.unwrap(), now); + maybe_authenticate(&mut client); + let c2 = client.process(None, now).dgram(); + assert!(c2.is_some()); + + // Now get a bunch of packets from the client. + // Give them to the server before giving it `c2`. + for _ in 0..PACKETS { + let d = send_something(&mut client, now); + server.process_input(&d, now + RTT / 2); + } + // The server has now received those packets, and saved them. + // The two extra received are Initial + the junk we use for padding. + assert_eq!(server.stats().packets_rx, PACKETS + 2); + assert_eq!(server.stats().saved_datagrams, PACKETS); + assert_eq!(server.stats().dropped_rx, 1); + + now += RTT / 2; + let s2 = server.process(c2.as_ref(), now).dgram(); + // With `c2` processed, the saved packets are processed too and are counted + // as received a second time; the number of saved datagrams is unchanged. + // The two additional are a Handshake and a 1-RTT (w/ NEW_CONNECTION_ID).
+ assert_eq!(server.stats().packets_rx, PACKETS * 2 + 4); + assert_eq!(server.stats().saved_datagrams, PACKETS); + assert_eq!(server.stats().dropped_rx, 1); + assert_eq!(*server.state(), State::Confirmed); + assert_eq!(server.paths.rtt(), RTT); + + now += RTT / 2; + client.process_input(&s2.unwrap(), now); + assert_eq!(client.paths.rtt(), RTT); + + // All the stream data that was sent should now be available. + let streams = server + .events() + .filter_map(|e| { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + Some(stream_id) + } else { + None + } + }) + .collect::<Vec<_>>(); + assert_eq!(streams.len(), PACKETS); + for stream_id in streams { + let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1]; + let (recvd, fin) = server.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(recvd, DEFAULT_STREAM_DATA.len()); + assert!(fin); + } +} + +/// Test that a server drops a client Initial datagram in which one byte has +/// been corrupted, without saving anything. Not built with the "fuzzing" feature. +#[cfg(not(feature = "fuzzing"))] +#[test] +fn corrupted_initial() { + let mut client = default_client(); + let mut server = default_server(); + let d = client.process(None, now()).dgram().unwrap(); + let mut corrupted = Vec::from(&d[..]); + // Find the last non-zero value and corrupt that. + let (idx, _) = corrupted + .iter() + .enumerate() + .rev() + .find(|(_, &v)| v != 0) + .unwrap(); + corrupted[idx] ^= 0x76; + let dgram = Datagram::new(d.source(), d.destination(), d.tos(), d.ttl(), corrupted); + server.process_input(&dgram, now()); + // The server should have counted two packets as received; + // both should be dropped and none saved. + assert_eq!(server.stats().packets_rx, 2); + assert_eq!(server.stats().dropped_rx, 2); + assert_eq!(server.stats().saved_datagrams, 0); +} + +#[test] +// Absent path MTU discovery, max v6 packet size should be PATH_MTU_V6.
+fn verify_pkt_honors_mtu() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let now = now(); + + let res = client.process(None, now); + let idle_timeout = ConnectionParameters::default().get_idle_timeout(); + assert_eq!(res, Output::Callback(idle_timeout)); + + // Try to send a large stream and verify that the first packet is correctly sized. + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.stream_send(stream_id, &[0xbb; 2000]).unwrap(), 2000); + let pkt0 = client.process(None, now); + assert!(matches!(pkt0, Output::Datagram(_))); + assert_eq!(pkt0.as_dgram_ref().unwrap().len(), PATH_MTU_V6); +} + +/// Test that a client answers an undecryptable Handshake packet with another +/// Initial packet a limited number of times, then stops, and that the +/// handshake can still be completed after a PTO. +#[test] +fn extra_initial_hs() { + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + + let c_init = client.process(None, now).dgram(); + assert!(c_init.is_some()); + now += DEFAULT_RTT / 2; + let s_init = server.process(c_init.as_ref(), now).dgram(); + assert!(s_init.is_some()); + now += DEFAULT_RTT / 2; + + // Drop the Initial packet, keep only the Handshake. + let (_, undecryptable) = split_datagram(&s_init.unwrap()); + assert!(undecryptable.is_some()); + + // Feed the same undecryptable packet into the client a few times. + // The inclusive range makes that EXTRA_INITIALS + 1 times, and each time + // the client will emit another Initial packet. + for _ in 0..=super::super::EXTRA_INITIALS { + let c_init = client.process(undecryptable.as_ref(), now).dgram(); + assertions::assert_initial(c_init.as_ref().unwrap(), false); + now += DEFAULT_RTT / 10; + } + + // After EXTRA_INITIALS, the client stops sending Initial packets. + let nothing = client.process(undecryptable.as_ref(), now).dgram(); + assert!(nothing.is_none()); + + // Until PTO, where another Initial can be used to complete the handshake.
+ now += AT_LEAST_PTO; + let c_init = client.process(None, now).dgram(); + assertions::assert_initial(c_init.as_ref().unwrap(), false); + now += DEFAULT_RTT / 2; + let s_init = server.process(c_init.as_ref(), now).dgram(); + now += DEFAULT_RTT / 2; + client.process_input(&s_init.unwrap(), now); + maybe_authenticate(&mut client); + let c_fin = client.process_output(now).dgram(); + assert_eq!(*client.state(), State::Connected); + now += DEFAULT_RTT / 2; + server.process_input(&c_fin.unwrap(), now); + assert_eq!(*server.state(), State::Confirmed); +} + +/// Test that a client does not send an extra Initial in response to a +/// Handshake packet that carries the wrong connection ID. +#[test] +fn extra_initial_invalid_cid() { + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + + let c_init = client.process(None, now).dgram(); + assert!(c_init.is_some()); + now += DEFAULT_RTT / 2; + let s_init = server.process(c_init.as_ref(), now).dgram(); + assert!(s_init.is_some()); + now += DEFAULT_RTT / 2; + + // If the client receives a packet that contains the wrong connection + // ID, it won't send another Initial. + let (_, hs) = split_datagram(&s_init.unwrap()); + let hs = hs.unwrap(); + let mut copy = hs.to_vec(); + assert_ne!(copy[5], 0); // Byte 5 is the DCID length; the DCID should be non-zero length.
+ copy[6] ^= 0xc4; + let dgram_copy = Datagram::new(hs.destination(), hs.source(), hs.tos(), hs.ttl(), copy); + let nothing = client.process(Some(&dgram_copy), now).dgram(); + assert!(nothing.is_none()); +} + +#[test] +fn connect_one_version() { + fn connect_v(version: Version) { + fixture_init(); + let mut client = Connection::new_client( + test_fixture::DEFAULT_SERVER_NAME, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + addr(), + addr(), + ConnectionParameters::default().versions(version, vec![version]), + now(), + ) + .unwrap(); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + ConnectionParameters::default().versions(version, vec![version]), + ) + .unwrap(); + connect_force_idle(&mut client, &mut server); + assert_eq!(client.version(), version); + assert_eq!(server.version(), version); + } + + for v in Version::all() { + println!("Connecting with {v:?}"); + connect_v(v); + } +} + +#[test] +fn anti_amplification() { + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + + // With a gigantic transport parameter, the server is unable to complete + // the handshake within the amplification limit. + let very_big = TransportParameter::Bytes(vec![0; PATH_MTU_V6 * 3]); + server.set_local_tparam(0xce16, very_big).unwrap(); + + let c_init = client.process_output(now).dgram(); + now += DEFAULT_RTT / 2; + let s_init1 = server.process(c_init.as_ref(), now).dgram().unwrap(); + assert_eq!(s_init1.len(), PATH_MTU_V6); + let s_init2 = server.process_output(now).dgram().unwrap(); + assert_eq!(s_init2.len(), PATH_MTU_V6); + + // Skip the gap for pacing here. 
+ let s_pacing = server.process_output(now).callback(); + assert_ne!(s_pacing, Duration::new(0, 0)); + now += s_pacing; + + let s_init3 = server.process_output(now).dgram().unwrap(); + assert_eq!(s_init3.len(), PATH_MTU_V6); + let cb = server.process_output(now).callback(); + assert_ne!(cb, Duration::new(0, 0)); + + now += DEFAULT_RTT / 2; + client.process_input(&s_init1, now); + client.process_input(&s_init2, now); + let ack_count = client.stats().frame_tx.ack; + let frame_count = client.stats().frame_tx.all; + let ack = client.process(Some(&s_init3), now).dgram().unwrap(); + assert!(!maybe_authenticate(&mut client)); // No authentication is needed yet. + + // The client sends an unpadded datagram, with just ACK for Handshake: + // exactly one new frame is sent and that frame is an ACK. + assert_eq!(client.stats().frame_tx.ack, ack_count + 1); + assert_eq!(client.stats().frame_tx.all, frame_count + 1); + assert_ne!(ack.len(), PATH_MTU_V6); // Not padded (it includes Handshake). + + now += DEFAULT_RTT / 2; + let remainder = server.process(Some(&ack), now).dgram(); + + now += DEFAULT_RTT / 2; + client.process_input(&remainder.unwrap(), now); + assert!(maybe_authenticate(&mut client)); // OK, we have all of it.
+ let fin = client.process_output(now).dgram(); + assert_eq!(*client.state(), State::Connected); + + now += DEFAULT_RTT / 2; + server.process_input(&fin.unwrap(), now); + assert_eq!(*server.state(), State::Confirmed); +} + +/// Test that a server produces no output for a client datagram whose Initial +/// packet has had its last byte corrupted. Not built with the "fuzzing" feature. +#[cfg(not(feature = "fuzzing"))] +#[test] +fn garbage_initial() { + let mut client = default_client(); + let mut server = default_server(); + + let dgram = client.process_output(now()).dgram().unwrap(); + let (initial, rest) = split_datagram(&dgram); + // Copy the Initial packet with its last byte altered, then re-attach + // whatever followed it in the datagram (if anything). + let mut corrupted = Vec::from(&initial[..initial.len() - 1]); + corrupted.push(initial[initial.len() - 1] ^ 0xb7); + corrupted.extend_from_slice(rest.as_ref().map_or(&[], |r| &r[..])); + let garbage = datagram(corrupted); + assert_eq!(Output::None, server.process(Some(&garbage), now())); +} + +/// Test that a client sends nothing in response to a server datagram that is +/// delivered with a different source address. +#[test] +fn drop_initial_packet_from_wrong_address() { + let mut client = default_client(); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + + let mut server = default_server(); + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + + // Rebuild the server's datagram with a replaced source address; the + // destination, TOS, TTL and payload are kept. + let p = out.dgram().unwrap(); + let dgram = Datagram::new( + SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2)), 443), + p.destination(), + p.tos(), + p.ttl(), + &p[..], + ); + + let out = client.process(Some(&dgram), now()); + assert!(out.as_dgram_ref().is_none()); +} + +/// Test that, after accepting the server Initial, a client sends nothing in +/// response to a Handshake packet delivered with a different source address. +#[test] +fn drop_handshake_packet_from_wrong_address() { + let mut client = default_client(); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + + let mut server = default_server(); + let out = server.process(out.as_dgram_ref(), now()); + assert!(out.as_dgram_ref().is_some()); + + let (s_in, s_hs) = split_datagram(&out.dgram().unwrap()); + + // Pass the initial packet.
+ mem::drop(client.process(Some(&s_in), now()).dgram()); + + // Rebuild the Handshake packet's datagram with a replaced source address. + let p = s_hs.unwrap(); + let dgram = Datagram::new( + SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2)), 443), + p.destination(), + p.tos(), + p.ttl(), + &p[..], + ); + + let out = client.process(Some(&dgram), now()); + assert!(out.as_dgram_ref().is_none()); +} + +/// Test that ECH is reported as accepted by both endpoints when the client +/// is given the server's ECH configuration. +#[test] +fn ech() { + let mut server = default_server(); + let (sk, pk) = generate_ech_keys().unwrap(); + server + .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk) + .unwrap(); + + let mut client = default_client(); + client.client_enable_ech(server.ech_config()).unwrap(); + + connect(&mut client, &mut server); + + assert!(client.tls_info().unwrap().ech_accepted()); + assert!(server.tls_info().unwrap().ech_accepted()); + assert!(client.tls_preinfo().unwrap().ech_accepted().unwrap()); + assert!(server.tls_preinfo().unwrap().ech_accepted().unwrap()); +} + +/// Return a copy of `config` with its config_id altered, so that the server +/// does not recognize it. Panics if `config` does not have the expected +/// version bytes (0xfe, 0x0d) and config_id at the expected offsets. +fn damaged_ech_config(config: &[u8]) -> Vec<u8> { + let mut cfg = Vec::from(config); + // Ensure that the version and config_id are correct. + assert_eq!(cfg[2], 0xfe); + assert_eq!(cfg[3], 0x0d); + assert_eq!(cfg[6], ECH_CONFIG_ID); + // Change the config_id so that the server doesn't recognize it.
+ cfg[6] ^= 0x94; + cfg +} + +/// Test that a client offering a damaged ECH configuration is asked to +/// authenticate the fallback public name, fails with `EchRetry`, and that a +/// new connection using the configuration carried in that error has ECH accepted. +#[test] +fn ech_retry() { + fixture_init(); + let mut server = default_server(); + let (sk, pk) = generate_ech_keys().unwrap(); + server + .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk) + .unwrap(); + + let mut client = default_client(); + client + .client_enable_ech(&damaged_ech_config(server.ech_config())) + .unwrap(); + + let dgram = client.process_output(now()).dgram(); + let dgram = server.process(dgram.as_ref(), now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + let auth_event = ConnectionEvent::EchFallbackAuthenticationNeeded { + public_name: String::from(ECH_PUBLIC_NAME), + }; + assert!(client.events().any(|e| e == auth_event)); + client.authenticated(AuthenticationStatus::Ok, now()); + assert!(client.state().error().is_some()); + + // Tell the server about the error. + // The code is 0x100 plus a TLS alert number; 121 is the ech_required alert. + let dgram = client.process_output(now()).dgram(); + server.process_input(&dgram.unwrap(), now()); + assert_eq!( + server.state().error(), + Some(&ConnectionError::Transport(Error::PeerError(0x100 + 121))) + ); + + let Some(ConnectionError::Transport(Error::EchRetry(updated_config))) = client.state().error() + else { + panic!( + "Client state should be failed with EchRetry, is {:?}", + client.state() + ); + }; + + // Try again with fresh endpoints, this time giving the client the + // configuration that was carried in the error. + let mut server = default_server(); + server + .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk) + .unwrap(); + let mut client = default_client(); + client.client_enable_ech(updated_config).unwrap(); + + connect(&mut client, &mut server); + + assert!(client.tls_info().unwrap().ech_accepted()); + assert!(server.tls_info().unwrap().ech_accepted()); + assert!(client.tls_preinfo().unwrap().ech_accepted().unwrap()); + assert!(server.tls_preinfo().unwrap().ech_accepted().unwrap()); +} + +/// Test that when authentication of the ECH fallback name is rejected, the +/// client fails with an error other than `EchRetry` and the server sees a +/// bad_certificate alert from its peer. +#[test] +fn ech_retry_fallback_rejected() { + fixture_init(); + let mut server = default_server(); + let (sk, pk) = generate_ech_keys().unwrap(); + server + .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk) + .unwrap(); + +
let mut client = default_client(); + client + .client_enable_ech(&damaged_ech_config(server.ech_config())) + .unwrap(); + + let dgram = client.process_output(now()).dgram(); + let dgram = server.process(dgram.as_ref(), now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + let auth_event = ConnectionEvent::EchFallbackAuthenticationNeeded { + public_name: String::from(ECH_PUBLIC_NAME), + }; + assert!(client.events().any(|e| e == auth_event)); + client.authenticated(AuthenticationStatus::PolicyRejection, now()); + assert!(client.state().error().is_some()); + + if let Some(ConnectionError::Transport(Error::EchRetry(_))) = client.state().error() { + panic!("Client should not get EchRetry error"); + } + + // Pass the error on. + let dgram = client.process_output(now()).dgram(); + server.process_input(&dgram.unwrap(), now()); + assert_eq!( + server.state().error(), + Some(&ConnectionError::Transport(Error::PeerError(298))) + ); // 298 = 0x100 + 42: a bad_certificate alert. +} + +/// Test that a client rejects a server `min_ack_delay` transport parameter +/// that is larger than the default ACK delay, failing with a transport +/// parameter error that the server then observes as a peer error. +#[test] +fn bad_min_ack_delay() { + const EXPECTED_ERROR: ConnectionError = + ConnectionError::Transport(Error::TransportParameterError); + let mut server = default_server(); + // Advertise a min_ack_delay one microsecond above the default ACK delay. + let max_ad = u64::try_from(DEFAULT_ACK_DELAY.as_micros()).unwrap(); + server + .set_local_tparam(MIN_ACK_DELAY, TransportParameter::Integer(max_ad + 1)) + .unwrap(); + let mut client = default_client(); + + let dgram = client.process_output(now()).dgram(); + let dgram = server.process(dgram.as_ref(), now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + client.authenticated(AuthenticationStatus::Ok, now()); + assert_eq!(client.state().error(), Some(&EXPECTED_ERROR)); + let dgram = client.process_output(now()).dgram(); + + server.process_input(&dgram.unwrap(), now()); + assert_eq!( + server.state().error(), + Some(&ConnectionError::Transport(Error::PeerError( + Error::TransportParameterError.code() + ))) + ); +} + +/// Ensure that the client probes correctly if it only receives Initial packets +/// from the server.
+#[test] +fn only_server_initial() { + let mut server = default_server(); + let mut client = default_client(); + let mut now = now(); + + let client_dgram = client.process_output(now).dgram(); + + // Now fetch two flights of messages from the server: one in response to + // the client's datagram and a second one a PTO later. + let server_dgram1 = server.process(client_dgram.as_ref(), now).dgram(); + let server_dgram2 = server.process_output(now + AT_LEAST_PTO).dgram(); + + // Only pass on the Initial from the first. We should get a Handshake in return. + let (initial, handshake) = split_datagram(&server_dgram1.unwrap()); + assert!(handshake.is_some()); + + // The client will not acknowledge the Initial as it discards keys. + // It sends a Handshake probe instead, containing just a PING frame. + assert_eq!(client.stats().frame_tx.ping, 0); + let probe = client.process(Some(&initial), now).dgram(); + assertions::assert_handshake(&probe.unwrap()); + assert_eq!(client.stats().dropped_rx, 0); + assert_eq!(client.stats().frame_tx.ping, 1); + + // Split the second server flight the same way. + let (initial, handshake) = split_datagram(&server_dgram2.unwrap()); + assert!(handshake.is_some()); + + // The same happens after a PTO, even though the client will discard the Initial packet. + // The discard shows up as `dropped_rx` increasing by one. + now += AT_LEAST_PTO; + assert_eq!(client.stats().frame_tx.ping, 1); + let discarded = client.stats().dropped_rx; + let probe = client.process(Some(&initial), now).dgram(); + assertions::assert_handshake(&probe.unwrap()); + assert_eq!(client.stats().frame_tx.ping, 2); + assert_eq!(client.stats().dropped_rx, discarded + 1); + + // Pass the Handshake packet and complete the handshake. + client.process_input(&handshake.unwrap(), now); + maybe_authenticate(&mut client); + let dgram = client.process_output(now).dgram(); + let dgram = server.process(dgram.as_ref(), now).dgram(); + client.process_input(&dgram.unwrap(), now); + + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(*server.state(), State::Confirmed); +} + +// Collect a few spare Initial packets as the handshake is exchanged.
+// Later, replay those packets to see if they result in additional probes; they should not. +#[test] +fn no_extra_probes_after_confirmed() { + let mut server = default_server(); + let mut client = default_client(); + let mut now = now(); + + // First, collect a client Initial. + let spare_initial = client.process_output(now).dgram(); + assert!(spare_initial.is_some()); + + // Collect ANOTHER client Initial. + now += AT_LEAST_PTO; + let dgram = client.process_output(now).dgram(); + let (replay_initial, _) = split_datagram(dgram.as_ref().unwrap()); + + // Finally, run the handshake. + now += AT_LEAST_PTO * 2; + let dgram = client.process_output(now).dgram(); + let dgram = server.process(dgram.as_ref(), now).dgram(); + + // The server should have dropped the Initial keys now, so passing in the Initial + // should elicit a retransmit rather than having it completely ignored. + let spare_handshake = server.process(Some(&replay_initial), now).dgram(); + assert!(spare_handshake.is_some()); + + client.process_input(&dgram.unwrap(), now); + maybe_authenticate(&mut client); + let dgram = client.process_output(now).dgram(); + let dgram = server.process(dgram.as_ref(), now).dgram(); + client.process_input(&dgram.unwrap(), now); + + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(*server.state(), State::Confirmed); + + // With both endpoints confirmed, replaying the spare packets must not + // produce any output from either side. + let probe = server.process(spare_initial.as_ref(), now).dgram(); + assert!(probe.is_none()); + let probe = client.process(spare_handshake.as_ref(), now).dgram(); + assert!(probe.is_none()); +} + +/// Test that a server ends up with an RTT estimate equal to the simulated RTT +/// after the handshake exchange, without having received any acknowledgments. +#[test] +fn implicit_rtt_server() { + const RTT: Duration = Duration::from_secs(2); + let mut server = default_server(); + let mut client = default_client(); + let mut now = now(); + + let dgram = client.process_output(now).dgram(); + now += RTT / 2; + let dgram = server.process(dgram.as_ref(), now).dgram(); + now += RTT / 2; + let dgram = client.process(dgram.as_ref(), now).dgram(); + assertions::assert_handshake(dgram.as_ref().unwrap()); + now += RTT / 2;
+ server.process_input(&dgram.unwrap(), now); + + // The server doesn't receive any acknowledgments, but it can infer + // an RTT estimate from having discarded the Initial packet number space. + assert_eq!(server.stats().rtt, RTT); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/idle.rs b/third_party/rust/neqo-transport/src/connection/tests/idle.rs new file mode 100644 index 0000000000..c33726917a --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/idle.rs @@ -0,0 +1,752 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + mem, + time::{Duration, Instant}, +}; + +use neqo_common::{qtrace, Encoder}; +use test_fixture::{self, now, split_datagram}; + +use super::{ + super::{Connection, ConnectionParameters, IdleTimeout, Output, State}, + connect, connect_force_idle, connect_rtt_idle, connect_with_rtt, default_client, + default_server, maybe_authenticate, new_client, new_server, send_and_receive, send_something, + AT_LEAST_PTO, DEFAULT_STREAM_DATA, +}; +use crate::{ + packet::PacketBuilder, + stats::FrameStats, + stream_id::{StreamId, StreamType}, + tparams::{self, TransportParameter}, + tracking::PacketNumberSpace, +}; + +/// The idle timeout from the default connection parameters. +fn default_timeout() -> Duration { + ConnectionParameters::default().get_idle_timeout() +} + +/// Connect `client` and `server`, force them idle, and check that the client +/// reports `timeout` as its next callback, is still confirmed one second +/// before `timeout` elapses, and is closed once `timeout` has elapsed. +/// +/// # Panics +/// +/// Panics if `timeout` is not longer than one second, or if any check fails. +fn test_idle_timeout(client: &mut Connection, server: &mut Connection, timeout: Duration) { + // One second is subtracted from `timeout` below, so it has to be longer than that. + assert!(timeout > Duration::from_secs(1)); + connect_force_idle(client, server); + + let now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(timeout)); + + // Still connected after timeout-1 seconds.
Idle timer not reset + mem::drop(client.process( + None, + now + timeout.checked_sub(Duration::from_secs(1)).unwrap(), + )); + assert!(matches!(client.state(), State::Confirmed)); + + mem::drop(client.process(None, now + timeout)); + + // Not connected after timeout. + assert!(matches!(client.state(), State::Closed(_))); +} + +/// Test the idle timeout with default parameters at both endpoints. +#[test] +fn idle_timeout() { + let mut client = default_client(); + let mut server = default_server(); + test_idle_timeout(&mut client, &mut server, default_timeout()); +} + +/// Test that an idle timeout configured only at the client is the timeout +/// the client applies. +#[test] +fn idle_timeout_custom_client() { + const IDLE_TIMEOUT: Duration = Duration::from_secs(5); + let mut client = new_client(ConnectionParameters::default().idle_timeout(IDLE_TIMEOUT)); + let mut server = default_server(); + test_idle_timeout(&mut client, &mut server, IDLE_TIMEOUT); +} + +/// Test that an idle timeout configured only at the server is the timeout +/// the client applies. +#[test] +fn idle_timeout_custom_server() { + const IDLE_TIMEOUT: Duration = Duration::from_secs(5); + let mut client = default_client(); + let mut server = new_server(ConnectionParameters::default().idle_timeout(IDLE_TIMEOUT)); + test_idle_timeout(&mut client, &mut server, IDLE_TIMEOUT); +} + +/// Test that when both endpoints configure an idle timeout, the client +/// applies the lower of the two values. +#[test] +fn idle_timeout_custom_both() { + const LOWER_TIMEOUT: Duration = Duration::from_secs(5); + const HIGHER_TIMEOUT: Duration = Duration::from_secs(10); + let mut client = new_client(ConnectionParameters::default().idle_timeout(HIGHER_TIMEOUT)); + let mut server = new_server(ConnectionParameters::default().idle_timeout(LOWER_TIMEOUT)); + test_idle_timeout(&mut client, &mut server, LOWER_TIMEOUT); +} + +/// Test that when the server's idle timeout is overwritten with a value below +/// the default, both endpoints report that lower value once they are idle. +#[test] +fn asymmetric_idle_timeout() { + const LOWER_TIMEOUT_MS: u64 = 1000; + const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS); + // Sanity check the constant. + assert!(LOWER_TIMEOUT < default_timeout()); + + let mut client = default_client(); + let mut server = default_server(); + + // Overwrite the default at the server.
+ server + .tps + .borrow_mut() + .local + .set_integer(tparams::IDLE_TIMEOUT, LOWER_TIMEOUT_MS); + server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT); + + // Now connect and force idleness manually. + // We do that by following what `force_idle` does and have each endpoint + // send two packets, which are delivered out of order. See `force_idle`. + connect(&mut client, &mut server); + let c1 = send_something(&mut client, now()); + let c2 = send_something(&mut client, now()); + server.process_input(&c2, now()); + server.process_input(&c1, now()); + let s1 = send_something(&mut server, now()); + let s2 = send_something(&mut server, now()); + client.process_input(&s2, now()); + let ack = client.process(Some(&s1), now()).dgram(); + assert!(ack.is_some()); + // Now both should have received ACK frames so should be idle. + assert_eq!( + server.process(ack.as_ref(), now()), + Output::Callback(LOWER_TIMEOUT) + ); + assert_eq!(client.process(None, now()), Output::Callback(LOWER_TIMEOUT)); +} + +#[test] +fn tiny_idle_timeout() { + const RTT: Duration = Duration::from_millis(500); + const LOWER_TIMEOUT_MS: u64 = 100; + const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS); + // We won't respect a value that is lower than 3*PTO, sanity check. + assert!(LOWER_TIMEOUT < 3 * RTT); + + let mut client = default_client(); + let mut server = default_server(); + + // Overwrite the default at the server. + server + .set_local_tparam( + tparams::IDLE_TIMEOUT, + TransportParameter::Integer(LOWER_TIMEOUT_MS), + ) + .unwrap(); + server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT); + + // Now connect with an RTT and force idleness manually. 
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + let c1 = send_something(&mut client, now); + let c2 = send_something(&mut client, now); + now += RTT / 2; + server.process_input(&c2, now); + server.process_input(&c1, now); + let s1 = send_something(&mut server, now); + let s2 = send_something(&mut server, now); + now += RTT / 2; + client.process_input(&s2, now); + let ack = client.process(Some(&s1), now).dgram(); + assert!(ack.is_some()); + + // The client should be idle now, but with a different timer. + if let Output::Callback(t) = client.process(None, now) { + assert!(t > LOWER_TIMEOUT); + } else { + panic!("Client not idle"); + } + + // The server should go idle after the ACK, but again with a larger timeout. + now += RTT / 2; + if let Output::Callback(t) = client.process(ack.as_ref(), now) { + assert!(t > LOWER_TIMEOUT); + } else { + panic!("Client not idle"); + } +} + +#[test] +fn idle_send_packet1() { + const DELTA: Duration = Duration::from_millis(10); + + let mut client = default_client(); + let mut server = default_server(); + let mut now = now(); + connect_force_idle(&mut client, &mut server); + + let timeout = client.process(None, now).callback(); + assert_eq!(timeout, default_timeout()); + + now += Duration::from_secs(10); + let dgram = send_and_receive(&mut client, &mut server, now); + assert!(dgram.is_some()); // the server will want to ACK, we can drop that. + + // Still connected after 39 seconds because idle timer reset by the + // outgoing packet. + now += default_timeout() - DELTA; + let dgram = client.process(None, now).dgram(); + assert!(dgram.is_some()); // PTO + assert!(client.state().connected()); + + // Not connected after 40 seconds. 
+ now += DELTA; + let out = client.process(None, now); + assert!(matches!(out, Output::None)); + assert!(client.state().closed()); +} + +#[test] +fn idle_send_packet2() { + const GAP: Duration = Duration::from_secs(10); + const DELTA: Duration = Duration::from_millis(10); + + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut now = now(); + + let timeout = client.process(None, now).callback(); + assert_eq!(timeout, default_timeout()); + + // First transmission at t=GAP. + now += GAP; + mem::drop(send_something(&mut client, now)); + + // Second transmission at t=2*GAP. + mem::drop(send_something(&mut client, now + GAP)); + assert!((GAP * 2 + DELTA) < default_timeout()); + + // Still connected just before GAP + default_timeout(). + now += default_timeout() - DELTA; + let dgram = client.process(None, now).dgram(); + assert!(dgram.is_some()); // PTO + assert!(matches!(client.state(), State::Confirmed)); + + // Not connected after 40 seconds because timer not reset by second + // outgoing packet + now += DELTA; + let out = client.process(None, now); + assert!(matches!(out, Output::None)); + assert!(matches!(client.state(), State::Closed(_))); +} + +#[test] +fn idle_recv_packet() { + const FUDGE: Duration = Duration::from_millis(10); + + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(default_timeout())); + + let stream = client.stream_create(StreamType::BiDi).unwrap(); + assert_eq!(stream, 0); + assert_eq!(client.stream_send(stream, b"hello").unwrap(), 5); + + // Respond with another packet. + // Note that it is important that this not result in the RTT increasing above 0. + // Otherwise, the eventual timeout will be extended (and we're not testing that). 
+ now += Duration::from_secs(10); + let out = client.process(None, now); + server.process_input(&out.dgram().unwrap(), now); + assert_eq!(server.stream_send(stream, b"world").unwrap(), 5); + let out = server.process_output(now); + assert_ne!(out.as_dgram_ref(), None); + mem::drop(client.process(out.as_dgram_ref(), now)); + assert!(matches!(client.state(), State::Confirmed)); + + // Add a little less than the idle timeout and we're still connected. + now += default_timeout() - FUDGE; + mem::drop(client.process(None, now)); + assert!(matches!(client.state(), State::Confirmed)); + + now += FUDGE; + mem::drop(client.process(None, now)); + + assert!(matches!(client.state(), State::Closed(_))); +} + +/// Caching packets should not cause the connection to become idle. +/// This requires a few tricks to keep the connection from going +/// idle while preventing any progress on the handshake. +#[test] +fn idle_caching() { + let mut client = default_client(); + let mut server = default_server(); + let start = now(); + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + + // Perform the first round trip, but drop the Initial from the server. + // The client then caches the Handshake packet. + let dgram = client.process_output(start).dgram(); + let dgram = server.process(dgram.as_ref(), start).dgram(); + let (_, handshake) = split_datagram(&dgram.unwrap()); + client.process_input(&handshake.unwrap(), start); + + // Perform an exchange and keep the connection alive. + // Only allow a packet containing a PING to pass. + let middle = start + AT_LEAST_PTO; + mem::drop(client.process_output(middle)); + let dgram = client.process_output(middle).dgram(); + + // Get the server to send its first probe and throw that away. + mem::drop(server.process_output(middle).dgram()); + // Now let the server process the client PING. This causes the server + // to send CRYPTO frames again, so manually extract and discard those. 
+ let ping_before_s = server.stats().frame_rx.ping; + server.process_input(&dgram.unwrap(), middle); + assert_eq!(server.stats().frame_rx.ping, ping_before_s + 1); + let mut tokens = Vec::new(); + server + .crypto + .streams + .write_frame( + PacketNumberSpace::Initial, + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ) + .unwrap(); + assert_eq!(tokens.len(), 1); + tokens.clear(); + server + .crypto + .streams + .write_frame( + PacketNumberSpace::Initial, + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ) + .unwrap(); + assert!(tokens.is_empty()); + let dgram = server.process_output(middle).dgram(); + + // Now only allow the Initial packet from the server through; + // it shouldn't contain a CRYPTO frame. + let (initial, _) = split_datagram(&dgram.unwrap()); + let ping_before_c = client.stats().frame_rx.ping; + let ack_before = client.stats().frame_rx.ack; + client.process_input(&initial, middle); + assert_eq!(client.stats().frame_rx.ping, ping_before_c + 1); + assert_eq!(client.stats().frame_rx.ack, ack_before + 1); + + let end = start + default_timeout() + (AT_LEAST_PTO / 2); + // Now let the server Initial through, with the CRYPTO frame. + let dgram = server.process_output(end).dgram(); + let (initial, _) = split_datagram(&dgram.unwrap()); + neqo_common::qwarn!("client ingests initial, finally"); + mem::drop(client.process(Some(&initial), end)); + maybe_authenticate(&mut client); + let dgram = client.process_output(end).dgram(); + let dgram = server.process(dgram.as_ref(), end).dgram(); + client.process_input(&dgram.unwrap(), end); + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(*server.state(), State::Confirmed); +} + +/// This function opens a bidirectional stream and leaves both endpoints +/// idle, with the stream left open. +/// The stream ID of that stream is returned (along with the new time). 
+fn create_stream_idle_rtt( + initiator: &mut Connection, + responder: &mut Connection, + mut now: Instant, + rtt: Duration, +) -> (Instant, StreamId) { + let check_idle = |endpoint: &mut Connection, now: Instant| { + let delay = endpoint.process_output(now).callback(); + qtrace!([endpoint], "idle timeout {:?}", delay); + if rtt < default_timeout() / 4 { + assert_eq!(default_timeout(), delay); + } else { + assert!(delay > default_timeout()); + } + }; + + // Exchange a message each way on a stream. + let stream = initiator.stream_create(StreamType::BiDi).unwrap(); + _ = initiator.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(); + let req = initiator.process_output(now).dgram(); + now += rtt / 2; + responder.process_input(&req.unwrap(), now); + + // Reordering two packets from the responder forces the initiator to be idle. + _ = responder.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(); + let resp1 = responder.process_output(now).dgram(); + _ = responder.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(); + let resp2 = responder.process_output(now).dgram(); + + now += rtt / 2; + initiator.process_input(&resp2.unwrap(), now); + initiator.process_input(&resp1.unwrap(), now); + let ack = initiator.process_output(now).dgram(); + assert!(ack.is_some()); + check_idle(initiator, now); + + // Receiving the ACK should return the responder to idle too. + now += rtt / 2; + responder.process_input(&ack.unwrap(), now); + check_idle(responder, now); + + (now, stream) +} + +fn create_stream_idle(initiator: &mut Connection, responder: &mut Connection) -> StreamId { + let (_, stream) = create_stream_idle_rtt(initiator, responder, now(), Duration::new(0, 0)); + stream +} + +fn assert_idle(endpoint: &mut Connection, now: Instant, expected: Duration) { + assert_eq!(endpoint.process_output(now).callback(), expected); +} + +/// The creator of a stream marks it as important enough to use a keep-alive. 
+#[test] +fn keep_alive_initiator() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let stream = create_stream_idle(&mut server, &mut client); + let mut now = now(); + + // Marking the stream for keep-alive changes the idle timeout. + server.stream_keep_alive(stream, true).unwrap(); + assert_idle(&mut server, now, default_timeout() / 2); + + // Wait that long and the server should send a PING frame. + now += default_timeout() / 2; + let pings_before = server.stats().frame_tx.ping; + let ping = server.process_output(now).dgram(); + assert!(ping.is_some()); + assert_eq!(server.stats().frame_tx.ping, pings_before + 1); + + // Exchange ack for the PING. + let out = client.process(ping.as_ref(), now).dgram(); + let out = server.process(out.as_ref(), now).dgram(); + assert!(client.process(out.as_ref(), now).dgram().is_none()); + + // Check that there will be next keep-alive ping after default_timeout() / 2. + assert_idle(&mut server, now, default_timeout() / 2); + now += default_timeout() / 2; + let pings_before2 = server.stats().frame_tx.ping; + let ping = server.process_output(now).dgram(); + assert!(ping.is_some()); + assert_eq!(server.stats().frame_tx.ping, pings_before2 + 1); +} + +/// Test a keep-alive ping is retransmitted if lost. +#[test] +fn keep_alive_lost() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let stream = create_stream_idle(&mut server, &mut client); + let mut now = now(); + + // Marking the stream for keep-alive changes the idle timeout. + server.stream_keep_alive(stream, true).unwrap(); + assert_idle(&mut server, now, default_timeout() / 2); + + // Wait that long and the server should send a PING frame. 
+ now += default_timeout() / 2; + let pings_before = server.stats().frame_tx.ping; + let ping = server.process_output(now).dgram(); + assert!(ping.is_some()); + assert_eq!(server.stats().frame_tx.ping, pings_before + 1); + + // Wait for ping to be marked lost. + assert!(server.process_output(now).callback() < AT_LEAST_PTO); + now += AT_LEAST_PTO; + let pings_before2 = server.stats().frame_tx.ping; + let ping = server.process_output(now).dgram(); + assert!(ping.is_some()); + assert_eq!(server.stats().frame_tx.ping, pings_before2 + 1); + + // Exchange ack for the PING. + let out = client.process(ping.as_ref(), now).dgram(); + + now += Duration::from_millis(20); + let out = server.process(out.as_ref(), now).dgram(); + + assert!(client.process(out.as_ref(), now).dgram().is_none()); + + // TODO: if we run server.process with current value of now, the server will + // return some small timeout for the recovry although it does not have + // any outstanding data. Therefore we call it after AT_LEAST_PTO. + now += AT_LEAST_PTO; + assert_idle(&mut server, now, default_timeout() / 2 - AT_LEAST_PTO); +} + +/// The other peer can also keep it alive. +#[test] +fn keep_alive_responder() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let stream = create_stream_idle(&mut server, &mut client); + let mut now = now(); + + // Marking the stream for keep-alive changes the idle timeout. + client.stream_keep_alive(stream, true).unwrap(); + assert_idle(&mut client, now, default_timeout() / 2); + + // Wait that long and the client should send a PING frame. + now += default_timeout() / 2; + let pings_before = client.stats().frame_tx.ping; + let ping = client.process_output(now).dgram(); + assert!(ping.is_some()); + assert_eq!(client.stats().frame_tx.ping, pings_before + 1); +} + +/// Unmark a stream as being keep-alive. 
+#[test]
+fn keep_alive_unmark() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+    let stream = create_stream_idle(&mut client, &mut server);
+
+    // Marking the stream halves the callback delay; unmarking it restores
+    // the full idle timeout.
+    for (keep_alive, expected) in [(true, default_timeout() / 2), (false, default_timeout())] {
+        client.stream_keep_alive(stream, keep_alive).unwrap();
+        assert_idle(&mut client, now(), expected);
+    }
+}
+
+/// Flush whatever the sender has pending and leave the receiver idle.
+/// The sender emits its pending datagram plus one extra packet; the receiver
+/// is given the two in reverse order and its resulting ACK is fed back to
+/// the sender.
+/// Note that the sender might not be idle if the thing that it
+/// sends results in something in addition to an ACK.
+fn transfer_force_idle(sender: &mut Connection, receiver: &mut Connection) {
+    let pending = sender.process_output(now()).dgram();
+    let chaff = send_something(sender, now());
+    // The later packet goes in first so that the receiver sees reordering.
+    receiver.process_input(&chaff, now());
+    receiver.process_input(&pending.unwrap(), now());
+    // The receiver must produce an ACK, which the sender then consumes.
+    sender.process_input(&receiver.process_output(now()).dgram().unwrap(), now());
+}
+
+/// Receiving the end of the stream stops keep-alives for that stream.
+/// Even if that data hasn't been read.
+#[test]
+fn keep_alive_close() {
+    let full = default_timeout();
+    let halved = full / 2;
+
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+    let stream = create_stream_idle(&mut client, &mut server);
+
+    client.stream_keep_alive(stream, true).unwrap();
+    assert_idle(&mut client, now(), halved);
+
+    // Closing only the client's sending side leaves the keep-alive in place.
+    client.stream_close_send(stream).unwrap();
+    transfer_force_idle(&mut client, &mut server);
+    assert_idle(&mut client, now(), halved);
+
+    // Once the server has closed its side too, the client falls back to the
+    // full idle timeout.
+    server.stream_close_send(stream).unwrap();
+    transfer_force_idle(&mut server, &mut client);
+    assert_idle(&mut client, now(), full);
+}
+
+/// Receiving `RESET_STREAM` stops keep-alives for that stream, but only once
+/// the sending side is also closed.
+#[test] +fn keep_alive_reset() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let stream = create_stream_idle(&mut client, &mut server); + + client.stream_keep_alive(stream, true).unwrap(); + assert_idle(&mut client, now(), default_timeout() / 2); + + client.stream_close_send(stream).unwrap(); + transfer_force_idle(&mut client, &mut server); + assert_idle(&mut client, now(), default_timeout() / 2); + + server.stream_reset_send(stream, 0).unwrap(); + transfer_force_idle(&mut server, &mut client); + assert_idle(&mut client, now(), default_timeout()); + + // The client will fade away from here. + let t = now() + (default_timeout() / 2); + assert_eq!(client.process_output(t).callback(), default_timeout() / 2); + let t = now() + default_timeout(); + assert_eq!(client.process_output(t), Output::None); +} + +/// Stopping sending also cancels the keep-alive. +#[test] +fn keep_alive_stop_sending() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let stream = create_stream_idle(&mut client, &mut server); + + client.stream_keep_alive(stream, true).unwrap(); + assert_idle(&mut client, now(), default_timeout() / 2); + + client.stream_close_send(stream).unwrap(); + client.stream_stop_sending(stream, 0).unwrap(); + transfer_force_idle(&mut client, &mut server); + // The server will have sent RESET_STREAM, which the client will + // want to acknowledge, so force that out. + let junk = send_something(&mut server, now()); + let ack = client.process(Some(&junk), now()).dgram(); + assert!(ack.is_some()); + + // Now the client should be idle. + assert_idle(&mut client, now(), default_timeout()); +} + +/// Multiple active streams are tracked properly. 
+#[test]
+fn keep_alive_multiple_stop() {
+    let full = default_timeout();
+    let halved = full / 2;
+
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+    let first = create_stream_idle(&mut client, &mut server);
+
+    // One marked stream: keep-alive timing applies.
+    client.stream_keep_alive(first, true).unwrap();
+    assert_idle(&mut client, now(), halved);
+
+    // A second marked stream does not change the timing.
+    let second = client.stream_create(StreamType::BiDi).unwrap();
+    client.stream_keep_alive(second, true).unwrap();
+    assert_idle(&mut client, now(), halved);
+
+    // Unmarking one of the two still leaves the keep-alive active.
+    client.stream_keep_alive(first, false).unwrap();
+    assert_idle(&mut client, now(), halved);
+
+    // Only when the last stream is unmarked does the full timeout return.
+    client.stream_keep_alive(second, false).unwrap();
+    assert_idle(&mut client, now(), full);
+}
+
+/// If the RTT is too long relative to the idle timeout, the keep-alive is large too.
+#[test]
+fn keep_alive_large_rtt() {
+    let mut client = default_client();
+    let mut server = default_server();
+    // Use an RTT that is large enough to cause the PTO timer to exceed half
+    // the idle timeout.
+    let rtt = default_timeout() * 3 / 4;
+    let connected_at = connect_with_rtt(&mut client, &mut server, now(), rtt);
+    let (idle_at, stream) = create_stream_idle_rtt(&mut server, &mut client, connected_at, rtt);
+
+    // Calculating PTO here is tricky as RTTvar has eroded after multiple round trips.
+    // Just check that the delay is larger than the baseline and the RTT.
+    let baseline = default_timeout() / 2;
+    for endpoint in [client, server].iter_mut() {
+        endpoint.stream_keep_alive(stream, true).unwrap();
+        let delay = endpoint.process_output(idle_at).callback();
+        qtrace!([endpoint], "new delay {:?}", delay);
+        assert!(delay > baseline);
+        assert!(delay > rtt);
+    }
+}
+
+/// Only the recipient of a unidirectional stream can keep it alive.
+#[test] +fn keep_alive_uni() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let stream = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_keep_alive(stream, true).unwrap_err(); + _ = client.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(); + let dgram = client.process_output(now()).dgram(); + + server.process_input(&dgram.unwrap(), now()); + server.stream_keep_alive(stream, true).unwrap(); +} + +/// Test a keep-alive ping is send if there are outstading ack-eliciting packets and that +/// the connection is closed after the idle timeout passes. +#[test] +fn keep_alive_with_ack_eliciting_packet_lost() { + const RTT: Duration = Duration::from_millis(500); // PTO will be ~1.1125s + + // The idle time out will be set to ~ 5 * PTO. (IDLE_TIMEOUT/2 > pto and IDLE_TIMEOUT/2 < pto + // + 2pto) After handshake all packets will be lost. The following steps will happen after + // the handshake: + // - data will be sent on a stream that is marked for keep-alive, (at start time) + // - PTO timer will trigger first, and the data will be retransmited toghether with a PING, (at + // the start time + pto) + // - keep-alive timer will trigger and a keep-alive PING will be sent, (at the start time + + // IDLE_TIMEOUT / 2) + // - PTO timer will trigger again. (at the start time + pto + 2*pto) + // - Idle time out will trigger (at the timeout + IDLE_TIMEOUT) + const IDLE_TIMEOUT: Duration = Duration::from_millis(6000); + + let mut client = new_client(ConnectionParameters::default().idle_timeout(IDLE_TIMEOUT)); + let mut server = default_server(); + let mut now = connect_rtt_idle(&mut client, &mut server, RTT); + // connect_rtt_idle increase now by RTT / 2; + now -= RTT / 2; + assert_idle(&mut client, now, IDLE_TIMEOUT); + + // Create a stream. + let stream = client.stream_create(StreamType::BiDi).unwrap(); + // Marking the stream for keep-alive changes the idle timeout. 
+ client.stream_keep_alive(stream, true).unwrap(); + assert_idle(&mut client, now, IDLE_TIMEOUT / 2); + + // Send data on the stream that will be lost. + _ = client.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(); + let _lost_packet = client.process_output(now).dgram(); + + let pto = client.process_output(now).callback(); + // Wait for packet to be marked lost. + assert!(pto < IDLE_TIMEOUT / 2); + now += pto; + let retransmit = client.process_output(now).dgram(); + assert!(retransmit.is_some()); + let retransmit = client.process_output(now).dgram(); + assert!(retransmit.is_some()); + + // The next callback should be for an idle PING. + assert_eq!( + client.process_output(now).callback(), + IDLE_TIMEOUT / 2 - pto + ); + + // Wait that long and the client should send a PING frame. + now += IDLE_TIMEOUT / 2 - pto; + let pings_before = client.stats().frame_tx.ping; + let ping = client.process_output(now).dgram(); + assert!(ping.is_some()); + assert_eq!(client.stats().frame_tx.ping, pings_before + 1); + + // The next callback is for a PTO, the PTO timer is 2 * pto now. + assert_eq!(client.process_output(now).callback(), pto * 2); + now += pto * 2; + // Now we will retransmit stream data. + let retransmit = client.process_output(now).dgram(); + assert!(retransmit.is_some()); + let retransmit = client.process_output(now).dgram(); + assert!(retransmit.is_some()); + + // The next callback will be an idle timeout. 
+ assert_eq!( + client.process_output(now).callback(), + IDLE_TIMEOUT / 2 - 2 * pto + ); + + now += IDLE_TIMEOUT / 2 - 2 * pto; + let out = client.process_output(now); + assert!(matches!(out, Output::None)); + assert!(matches!(client.state(), State::Closed(_))); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs new file mode 100644 index 0000000000..c247bba670 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs @@ -0,0 +1,346 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::mem; + +use neqo_common::{qdebug, Datagram}; +use test_fixture::{self, now}; + +use super::{ + super::{ + super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED}, + Connection, ConnectionParameters, Error, Output, State, StreamType, + }, + connect, connect_force_idle, default_client, default_server, maybe_authenticate, + send_and_receive, send_something, AT_LEAST_PTO, +}; +use crate::{ + crypto::{OVERWRITE_INVOCATIONS, UPDATE_WRITE_KEYS_AT}, + packet::PacketNumber, + path::PATH_MTU_V6, +}; + +fn check_discarded( + peer: &mut Connection, + pkt: &Datagram, + response: bool, + dropped: usize, + dups: usize, +) { + // Make sure to flush any saved datagrams before doing this. 
+ mem::drop(peer.process_output(now())); + + let before = peer.stats(); + let out = peer.process(Some(pkt), now()); + assert_eq!(out.as_dgram_ref().is_some(), response); + let after = peer.stats(); + assert_eq!(dropped, after.dropped_rx - before.dropped_rx); + assert_eq!(dups, after.dups_rx - before.dups_rx); +} + +fn assert_update_blocked(c: &mut Connection) { + assert_eq!( + c.initiate_key_update().unwrap_err(), + Error::KeyUpdateBlocked + ); +} + +fn overwrite_invocations(n: PacketNumber) { + OVERWRITE_INVOCATIONS.with(|v| { + *v.borrow_mut() = Some(n); + }); +} + +#[test] +fn discarded_initial_keys() { + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let init_pkt_c = client.process(None, now()).dgram(); + assert!(init_pkt_c.is_some()); + assert_eq!(init_pkt_c.as_ref().unwrap().len(), PATH_MTU_V6); + + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let init_pkt_s = server.process(init_pkt_c.as_ref(), now()).dgram(); + assert!(init_pkt_s.is_some()); + + qdebug!("---- client: cert verification"); + let out = client.process(init_pkt_s.as_ref(), now()).dgram(); + assert!(out.is_some()); + + // The client has received a handshake packet. It will remove the Initial keys. + // We will check this by processing init_pkt_s a second time. + // The initial packet should be dropped. The packet contains a Handshake packet as well, which + // will be marked as dup. And it will contain padding, which will be "dropped". + // The client will generate a Handshake packet here to avoid stalling. + check_discarded(&mut client, &init_pkt_s.unwrap(), true, 2, 1); + + assert!(maybe_authenticate(&mut client)); + + // The server has not removed the Initial keys yet, because it has not yet received a Handshake + // packet from the client. + // We will check this by processing init_pkt_c a second time. + // The dropped packet is padding. The Initial packet has been mark dup. 
+ check_discarded(&mut server, &init_pkt_c.clone().unwrap(), false, 1, 1); + + qdebug!("---- client: SH..FIN -> FIN"); + let out = client.process(None, now()).dgram(); + assert!(out.is_some()); + + // The server will process the first Handshake packet. + // After this the Initial keys will be dropped. + let out = server.process(out.as_ref(), now()).dgram(); + assert!(out.is_some()); + + // Check that the Initial keys are dropped at the server + // We will check this by processing init_pkt_c a third time. + // The Initial packet has been dropped and padding that follows it. + // There is no dups, everything has been dropped. + check_discarded(&mut server, &init_pkt_c.unwrap(), false, 1, 0); +} + +#[test] +fn key_update_client() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let mut now = now(); + + assert_eq!(client.get_epochs(), (Some(3), Some(3))); // (write, read) + assert_eq!(server.get_epochs(), (Some(3), Some(3))); + + assert!(client.initiate_key_update().is_ok()); + assert_update_blocked(&mut client); + + // Initiating an update should only increase the write epoch. + let idle_timeout = ConnectionParameters::default().get_idle_timeout(); + assert_eq!(Output::Callback(idle_timeout), client.process(None, now)); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + + // Send something to propagate the update. + // Note that the server will acknowledge immediately when RTT is zero. + assert!(send_and_receive(&mut client, &mut server, now).is_some()); + + // The server should now be waiting to discharge read keys. + assert_eq!(server.get_epochs(), (Some(4), Some(3))); + let res = server.process(None, now); + if let Output::Callback(t) = res { + assert!(t < idle_timeout); + } else { + panic!("server should now be waiting to clear keys"); + } + + // Without having had time to purge old keys, more updates are blocked. 
+ // The spec would permits it at this point, but we are more conservative. + assert_update_blocked(&mut client); + // The server can't update until it receives an ACK for a packet. + assert_update_blocked(&mut server); + + // Waiting now for at least a PTO should cause the server to drop old keys. + // But at this point the client hasn't received a key update from the server. + // It will be stuck with old keys. + now += AT_LEAST_PTO; + let dgram = client.process(None, now).dgram(); + assert!(dgram.is_some()); // Drop this packet. + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + mem::drop(server.process(None, now)); + assert_eq!(server.get_epochs(), (Some(4), Some(4))); + + // Even though the server has updated, it hasn't received an ACK yet. + assert_update_blocked(&mut server); + + // Now get an ACK from the server. + // The previous PTO packet (see above) was dropped, so we should get an ACK here. + let dgram = send_and_receive(&mut client, &mut server, now); + assert!(dgram.is_some()); + let res = client.process(dgram.as_ref(), now); + // This is the first packet that the client has received from the server + // with new keys, so its read timer just started. + if let Output::Callback(t) = res { + assert!(t < ConnectionParameters::default().get_idle_timeout()); + } else { + panic!("client should now be waiting to clear keys"); + } + + assert_update_blocked(&mut client); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + // The server can't update until it gets something from the client. + assert_update_blocked(&mut server); + + now += AT_LEAST_PTO; + mem::drop(client.process(None, now)); + assert_eq!(client.get_epochs(), (Some(4), Some(4))); +} + +#[test] +fn key_update_consecutive() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let now = now(); + + assert!(server.initiate_key_update().is_ok()); + assert_eq!(server.get_epochs(), (Some(4), Some(3))); + + // Server sends something. 
+ // Send twice and drop the first to induce an ACK from the client. + mem::drop(send_something(&mut server, now)); // Drop this. + + // Another packet from the server will cause the client to ACK and update keys. + let dgram = send_and_receive(&mut server, &mut client, now); + assert!(dgram.is_some()); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + + // Have the server process the ACK. + if let Output::Callback(_) = server.process(dgram.as_ref(), now) { + assert_eq!(server.get_epochs(), (Some(4), Some(3))); + // Now move the server temporarily into the future so that it + // rotates the keys. The client stays in the present. + mem::drop(server.process(None, now + AT_LEAST_PTO)); + assert_eq!(server.get_epochs(), (Some(4), Some(4))); + } else { + panic!("server should have a timer set"); + } + + // Now update keys on the server again. + assert!(server.initiate_key_update().is_ok()); + assert_eq!(server.get_epochs(), (Some(5), Some(4))); + + let dgram = send_something(&mut server, now + AT_LEAST_PTO); + + // However, as the server didn't wait long enough to update again, the + // client hasn't rotated its keys, so the packet gets dropped. + check_discarded(&mut client, &dgram, false, 1, 0); +} + +// Key updates can't be initiated too early. 
+#[test] +fn key_update_before_confirmed() { + let mut client = default_client(); + assert_update_blocked(&mut client); + let mut server = default_server(); + assert_update_blocked(&mut server); + + // Client Initial + let dgram = client.process(None, now()).dgram(); + assert!(dgram.is_some()); + assert_update_blocked(&mut client); + + // Server Initial + Handshake + let dgram = server.process(dgram.as_ref(), now()).dgram(); + assert!(dgram.is_some()); + assert_update_blocked(&mut server); + + // Client Handshake + client.process_input(&dgram.unwrap(), now()); + assert_update_blocked(&mut client); + + assert!(maybe_authenticate(&mut client)); + assert_update_blocked(&mut client); + + let dgram = client.process(None, now()).dgram(); + assert!(dgram.is_some()); + assert_update_blocked(&mut client); + + // Server HANDSHAKE_DONE + let dgram = server.process(dgram.as_ref(), now()).dgram(); + assert!(dgram.is_some()); + assert!(server.initiate_key_update().is_ok()); + + // Client receives HANDSHAKE_DONE + let dgram = client.process(dgram.as_ref(), now()).dgram(); + assert!(dgram.is_none()); + assert!(client.initiate_key_update().is_ok()); +} + +/// After `overwrite_invocations(0)`, the client produces no datagram when it +/// tries to send and closes with `Error::KeysExhausted`. +#[test] +fn exhaust_write_keys() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + overwrite_invocations(0); + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert!(client.stream_send(stream_id, b"explode!").is_ok()); + let dgram = client.process_output(now()).dgram(); + assert!(dgram.is_none()); + assert!(matches!( + client.state(), + State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + )); +} + +#[test] +fn exhaust_read_keys() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let dgram = send_something(&mut client, now()); + + overwrite_invocations(0); + let dgram = server.process(Some(&dgram), now()).dgram(); + assert!(matches!( + 
server.state(), + State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + )); + + client.process_input(&dgram.unwrap(), now()); + assert!(matches!( + client.state(), + State::Draining { + error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), + .. + } + )); +} + +#[test] +fn automatic_update_write_keys() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + overwrite_invocations(UPDATE_WRITE_KEYS_AT); + mem::drop(send_something(&mut client, now())); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); +} + +#[test] +fn automatic_update_write_keys_later() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + overwrite_invocations(UPDATE_WRITE_KEYS_AT + 2); + // No update after the first. + mem::drop(send_something(&mut client, now())); + assert_eq!(client.get_epochs(), (Some(3), Some(3))); + // The second will update though. + mem::drop(send_something(&mut client, now())); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); +} + +/// If an automatic write key update is needed while a manually initiated +/// update is still outstanding, the client closes with `Error::KeysExhausted`. +#[test] +fn automatic_update_write_keys_blocked() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // An outstanding key update will block the automatic update. + client.initiate_key_update().unwrap(); + + overwrite_invocations(UPDATE_WRITE_KEYS_AT); + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert!(client.stream_send(stream_id, b"explode!").is_ok()); + let dgram = client.process_output(now()).dgram(); + // Not being able to update is fatal. 
+ assert!(dgram.is_none()); + assert!(matches!( + client.state(), + State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + )); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/migration.rs b/third_party/rust/neqo-transport/src/connection/tests/migration.rs new file mode 100644 index 0000000000..8307a7dd84 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/migration.rs @@ -0,0 +1,953 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + net::{IpAddr, Ipv6Addr, SocketAddr}, + rc::Rc, + time::{Duration, Instant}, +}; + +use neqo_common::{Datagram, Decoder}; +use test_fixture::{ + self, addr, addr_v4, + assertions::{assert_v4_path, assert_v6_path}, + fixture_init, new_neqo_qlog, now, +}; + +use super::{ + super::{Connection, Output, State, StreamType}, + connect_fail, connect_force_idle, connect_rtt_idle, default_client, default_server, + maybe_authenticate, new_client, new_server, send_something, CountingConnectionIdGenerator, +}; +use crate::{ + cid::LOCAL_ACTIVE_CID_LIMIT, + connection::tests::send_something_paced, + frame::FRAME_TYPE_NEW_CONNECTION_ID, + packet::PacketBuilder, + path::{PATH_MTU_V4, PATH_MTU_V6}, + tparams::{self, PreferredAddress, TransportParameter}, + ConnectionError, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef, + ConnectionParameters, EmptyConnectionIdGenerator, Error, +}; + +/// This should be a valid-seeming transport parameter. +/// And it should have different values to `addr` and `addr_v4`. 
+const SAMPLE_PREFERRED_ADDRESS: &[u8] = &[ + 0xc0, 0x00, 0x02, 0x02, 0x01, 0xbb, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0xbb, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, 0x03, 0x03, + 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, +]; + +// These tests generally use two paths: +// The connection is established on a path with the same IPv6 address on both ends. +// Migrations move to a path with the same IPv4 address on both ends. +// This simplifies validation as the same assertions can be used for client and server. +// The risk is that there is a place where source/destination local/remote is inverted. + +/// The IPv6 loopback address (`::1`) on port 443. +fn loopback() -> SocketAddr { + SocketAddr::new(IpAddr::V6(Ipv6Addr::from(1)), 443) +} + +/// Copies `d` onto a path where `a` is both the source and the destination. +fn change_path(d: &Datagram, a: SocketAddr) -> Datagram { + Datagram::new(a, a, d.tos(), d.ttl(), &d[..]) +} + +/// Returns `a` with 410 added to its port number, wrapping on overflow. +fn new_port(a: SocketAddr) -> SocketAddr { + let (port, _) = a.port().overflowing_add(410); + SocketAddr::new(a.ip(), port) +} + +/// Copies `d` with its source address passed through `new_port`; the +/// destination is kept. +fn change_source_port(d: &Datagram) -> Datagram { + Datagram::new( + new_port(d.source()), + d.destination(), + d.tos(), + d.ttl(), + &d[..], + ) +} + +/// As these tests use a new path, that path often has a non-zero RTT. +/// Pacing can be a problem when testing that path. This skips time forward. +fn skip_pacing(c: &mut Connection, now: Instant) -> Instant { + let pacing = c.process_output(now).callback(); + assert_ne!(pacing, Duration::new(0, 0)); + now + pacing +} + +#[test] +fn rebinding_port() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let dgram = send_something(&mut client, now()); + let dgram = change_source_port(&dgram); + + server.process_input(&dgram, now()); + // Have the server send something so that it generates a packet. 
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap(); + server.stream_close_send(stream_id).unwrap(); + let dgram = server.process_output(now()).dgram(); + let dgram = dgram.unwrap(); + assert_eq!(dgram.source(), addr()); + assert_eq!(dgram.destination(), new_port(addr())); +} + +/// This simulates an attack where a valid packet is forwarded on +/// a different path. This shows how both paths are probed and the +/// server eventually returns to the original path. +#[test] +fn path_forwarding_attack() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let mut now = now(); + + let dgram = send_something(&mut client, now); + let dgram = change_path(&dgram, addr_v4()); + server.process_input(&dgram, now); + + // The server now probes the new (primary) path. + let new_probe = server.process_output(now).dgram().unwrap(); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + assert_v4_path(&new_probe, false); // Can't be padded. + + // The server also probes the old path. + let old_probe = server.process_output(now).dgram().unwrap(); + assert_eq!(server.stats().frame_tx.path_challenge, 2); + assert_v6_path(&old_probe, true); + + // New data from the server is sent on the new path, but that is + // now constrained by the amplification limit. + let stream_id = server.stream_create(StreamType::UniDi).unwrap(); + server.stream_close_send(stream_id).unwrap(); + assert!(server.process_output(now).dgram().is_none()); + + // The client should respond to the challenge on the new path. + // The server couldn't pad, so the client is also amplification limited. 
+ let new_resp = client.process(Some(&new_probe), now).dgram().unwrap(); + assert_eq!(client.stats().frame_rx.path_challenge, 1); + assert_eq!(client.stats().frame_tx.path_challenge, 1); + assert_eq!(client.stats().frame_tx.path_response, 1); + assert_v4_path(&new_resp, false); + + // The client also responds to probes on the old path. + let old_resp = client.process(Some(&old_probe), now).dgram().unwrap(); + assert_eq!(client.stats().frame_rx.path_challenge, 2); + assert_eq!(client.stats().frame_tx.path_challenge, 1); + assert_eq!(client.stats().frame_tx.path_response, 2); + assert_v6_path(&old_resp, true); + + // But the client still sends data on the old path. + let client_data1 = send_something(&mut client, now); + assert_v6_path(&client_data1, false); // Just data. + + // Receiving the PATH_RESPONSE from the client opens the amplification + // limit enough for the server to respond. + // This is padded because it includes PATH_CHALLENGE. + let server_data1 = server.process(Some(&new_resp), now).dgram().unwrap(); + assert_v4_path(&server_data1, true); + assert_eq!(server.stats().frame_tx.path_challenge, 3); + + // The client responds to this probe on the new path. + client.process_input(&server_data1, now); + let stream_before = client.stats().frame_tx.stream; + let padded_resp = send_something(&mut client, now); + assert_eq!(stream_before, client.stats().frame_tx.stream); + assert_v4_path(&padded_resp, true); // This is padded! + + // But new data from the client stays on the old path. + let client_data2 = client.process_output(now).dgram().unwrap(); + assert_v6_path(&client_data2, false); + + // The server keeps sending on the new path. + now = skip_pacing(&mut server, now); + let server_data2 = send_something(&mut server, now); + assert_v4_path(&server_data2, false); + + // Until new data is received from the client on the old path. + server.process_input(&client_data2, now); + // The server sends a probe on the "old" path. 
+ let server_data3 = send_something(&mut server, now); + assert_v4_path(&server_data3, true); + // But switches data transmission to the "new" path. + let server_data4 = server.process_output(now).dgram().unwrap(); + assert_v6_path(&server_data4, false); +} + +#[test] +fn migrate_immediate() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let now = now(); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), true, now) + .unwrap(); + + let client1 = send_something(&mut client, now); + assert_v4_path(&client1, true); // Contains PATH_CHALLENGE. + let client2 = send_something(&mut client, now); + assert_v4_path(&client2, false); // Doesn't. + + let server_delayed = send_something(&mut server, now); + + // The server accepts the first packet and migrates (but probes). + let server1 = server.process(Some(&client1), now).dgram().unwrap(); + assert_v4_path(&server1, true); + let server2 = server.process_output(now).dgram().unwrap(); + assert_v6_path(&server2, true); + + // The second packet has no real effect, it just elicits an ACK. + let all_before = server.stats().frame_tx.all; + let ack_before = server.stats().frame_tx.ack; + let server3 = server.process(Some(&client2), now).dgram(); + assert!(server3.is_some()); + assert_eq!(server.stats().frame_tx.all, all_before + 1); + assert_eq!(server.stats().frame_tx.ack, ack_before + 1); + + // Receiving a packet sent by the server before migration doesn't change path. + client.process_input(&server_delayed, now); + // The client has sent two unpaced packets and this new path has no RTT estimate + // so this might be paced. + let (client3, _t) = send_something_paced(&mut client, now, true); + assert_v4_path(&client3, false); +} + +/// RTT estimates for paths should be preserved across migrations. 
+#[test] +fn migrate_rtt() { + const RTT: Duration = Duration::from_millis(20); + let mut client = default_client(); + let mut server = default_server(); + let now = connect_rtt_idle(&mut client, &mut server, RTT); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), true, now) + .unwrap(); + // The RTT might be increased for the new path, so allow a little flexibility. + let rtt = client.paths.rtt(); + assert!(rtt > RTT); + assert!(rtt < RTT * 2); +} + +#[test] +fn migrate_immediate_fail() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let mut now = now(); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), true, now) + .unwrap(); + + let probe = client.process_output(now).dgram().unwrap(); + assert_v4_path(&probe, true); // Contains PATH_CHALLENGE. + + for _ in 0..2 { + let cb = client.process_output(now).callback(); + assert_ne!(cb, Duration::new(0, 0)); + now += cb; + + let before = client.stats().frame_tx; + let probe = client.process_output(now).dgram().unwrap(); + assert_v4_path(&probe, true); // Contains PATH_CHALLENGE. + let after = client.stats().frame_tx; + assert_eq!(after.path_challenge, before.path_challenge + 1); + assert_eq!(after.padding, before.padding + 1); + assert_eq!(after.all, before.all + 2); + + // This might be a PTO, which will result in sending a probe. + if let Some(probe) = client.process_output(now).dgram() { + assert_v4_path(&probe, false); // A PTO probe carrying a PING; not padded. + let after = client.stats().frame_tx; + assert_eq!(after.ping, before.ping + 1); + assert_eq!(after.all, before.all + 3); + } + } + + let pto = client.process_output(now).callback(); + assert_ne!(pto, Duration::new(0, 0)); + now += pto; + + // The client should fall back to the original path and retire the connection ID. 
+ let fallback = client.process_output(now).dgram(); + assert_v6_path(&fallback.unwrap(), false); + assert_eq!(client.stats().frame_tx.retire_connection_id, 1); +} + +/// Migrating to the same path shouldn't do anything special, +/// except that the path is probed. +#[test] +fn migrate_same() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let now = now(); + + client + .migrate(Some(addr()), Some(addr()), true, now) + .unwrap(); + + let probe = client.process_output(now).dgram().unwrap(); + assert_v6_path(&probe, true); // Contains PATH_CHALLENGE. + assert_eq!(client.stats().frame_tx.path_challenge, 1); + + let resp = server.process(Some(&probe), now).dgram().unwrap(); + assert_v6_path(&resp, true); + assert_eq!(server.stats().frame_tx.path_response, 1); + assert_eq!(server.stats().frame_tx.path_challenge, 0); + + // Everything continues happily. + client.process_input(&resp, now); + let contd = send_something(&mut client, now); + assert_v6_path(&contd, false); +} + +/// Migrating to the same path, if it fails, causes the connection to fail. +#[test] +fn migrate_same_fail() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let mut now = now(); + + client + .migrate(Some(addr()), Some(addr()), true, now) + .unwrap(); + + let probe = client.process_output(now).dgram().unwrap(); + assert_v6_path(&probe, true); // Contains PATH_CHALLENGE. + + for _ in 0..2 { + let cb = client.process_output(now).callback(); + assert_ne!(cb, Duration::new(0, 0)); + now += cb; + + let before = client.stats().frame_tx; + let probe = client.process_output(now).dgram().unwrap(); + assert_v6_path(&probe, true); // Contains PATH_CHALLENGE. 
+ let after = client.stats().frame_tx; + assert_eq!(after.path_challenge, before.path_challenge + 1); + assert_eq!(after.padding, before.padding + 1); + assert_eq!(after.all, before.all + 2); + + // This might be a PTO, which will result in sending a probe. + if let Some(probe) = client.process_output(now).dgram() { + assert_v6_path(&probe, false); // A PTO probe carrying a PING; not padded. + let after = client.stats().frame_tx; + assert_eq!(after.ping, before.ping + 1); + assert_eq!(after.all, before.all + 3); + } + } + + let pto = client.process_output(now).callback(); + assert_ne!(pto, Duration::new(0, 0)); + now += pto; + + // The client should mark this path as failed and close immediately. + let res = client.process_output(now); + assert!(matches!(res, Output::None)); + assert!(matches!( + client.state(), + State::Closed(ConnectionError::Transport(Error::NoAvailablePath)) + )); +} + +/// This gets the connection ID from a datagram using the default +/// connection ID generator/decoder. +fn get_cid(d: &Datagram) -> ConnectionIdRef { + let gen = CountingConnectionIdGenerator::default(); + assert_eq!(d[0] & 0x80, 0); // Only support short packets for now. + gen.decode_cid(&mut Decoder::from(&d[1..])).unwrap() +} + +/// Shared body for the graceful migration tests: connects `client` to a +/// default server, then migrates it to the IPv4 path and checks each step. +fn migration(mut client: Connection) { + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let now = now(); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), false, now) + .unwrap(); + + let probe = client.process_output(now).dgram().unwrap(); + assert_v4_path(&probe, true); // Contains PATH_CHALLENGE. 
+ assert_eq!(client.stats().frame_tx.path_challenge, 1); + let probe_cid = ConnectionId::from(get_cid(&probe)); + + let resp = server.process(Some(&probe), now).dgram().unwrap(); + assert_v4_path(&resp, true); + assert_eq!(server.stats().frame_tx.path_response, 1); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + + // Data continues to be exchanged on the new path. 
+ let client_data = send_something(&mut client, now); + assert_ne!(get_cid(&client_data), probe_cid); + assert_v6_path(&client_data, false); + server.process_input(&client_data, now); + let server_data = send_something(&mut server, now); + assert_v6_path(&server_data, false); + + // Once the client receives the probe response, it migrates to the new path. + client.process_input(&resp, now); + assert_eq!(client.stats().frame_rx.path_challenge, 1); + let migrate_client = send_something(&mut client, now); + assert_v4_path(&migrate_client, true); // Responds to server probe. + + // The server now sees the migration and will switch over. + // However, it will probe the old path again, even though it has just + // received a response to its last probe, because it needs to verify + // that the migration is genuine. + server.process_input(&migrate_client, now); + let stream_before = server.stats().frame_tx.stream; + let probe_old_server = send_something(&mut server, now); + // This is just the double-check probe; no STREAM frames. + assert_v6_path(&probe_old_server, true); + assert_eq!(server.stats().frame_tx.path_challenge, 2); + assert_eq!(server.stats().frame_tx.stream, stream_before); + + // The server then sends data on the new path. + let migrate_server = server.process_output(now).dgram().unwrap(); + assert_v4_path(&migrate_server, false); + assert_eq!(server.stats().frame_tx.path_challenge, 2); + assert_eq!(server.stats().frame_tx.stream, stream_before + 1); + + // The client receives these checks and responds to the probe, but uses the new path. + client.process_input(&migrate_server, now); + client.process_input(&probe_old_server, now); + let old_probe_resp = send_something(&mut client, now); + assert_v6_path(&old_probe_resp, true); + let client_confirmation = client.process_output(now).dgram().unwrap(); + assert_v4_path(&client_confirmation, false); + + // The server has now sent 2 packets, so it is blocked on the pacer. Wait. 
+ let server_pacing = server.process_output(now).callback(); + assert_ne!(server_pacing, Duration::new(0, 0)); + // ... then confirm that the server sends on the new path still. + let server_confirmation = send_something(&mut server, now + server_pacing); + assert_v4_path(&server_confirmation, false); +} + +#[test] +fn migration_graceful() { + migration(default_client()); +} + +/// A client should be able to migrate when it has a zero-length connection ID. +#[test] +fn migration_client_empty_cid() { + fixture_init(); + let client = Connection::new_client( + test_fixture::DEFAULT_SERVER_NAME, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())), + addr(), + addr(), + ConnectionParameters::default(), + now(), + ) + .unwrap(); + migration(client); +} + +/// Drive the handshake in the most expeditious fashion. +/// Returns the packet containing `HANDSHAKE_DONE` from the server. +fn fast_handshake(client: &mut Connection, server: &mut Connection) -> Option<Datagram> { + let dgram = client.process_output(now()).dgram(); + let dgram = server.process(dgram.as_ref(), now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + assert!(maybe_authenticate(client)); + let dgram = client.process_output(now()).dgram(); + server.process(dgram.as_ref(), now()).dgram() +} + +fn preferred_address(hs_client: SocketAddr, hs_server: SocketAddr, preferred: SocketAddr) { + let mtu = match hs_client.ip() { + IpAddr::V4(_) => PATH_MTU_V4, + IpAddr::V6(_) => PATH_MTU_V6, + }; + let assert_orig_path = |d: &Datagram, full_mtu: bool| { + assert_eq!( + d.destination(), + if d.source() == hs_client { + hs_server + } else if d.source() == hs_server { + hs_client + } else { + panic!(); + } + ); + if full_mtu { + assert_eq!(d.len(), mtu); + } + }; + let assert_toward_spa = |d: &Datagram, full_mtu: bool| { + assert_eq!(d.destination(), preferred); + assert_eq!(d.source(), hs_client); + if full_mtu { + assert_eq!(d.len(), mtu); + } + }; + let 
assert_from_spa = |d: &Datagram, full_mtu: bool| { + assert_eq!(d.destination(), hs_client); + assert_eq!(d.source(), preferred); + if full_mtu { + assert_eq!(d.len(), mtu); + } + }; + + fixture_init(); + let (log, _contents) = new_neqo_qlog(); + let mut client = Connection::new_client( + test_fixture::DEFAULT_SERVER_NAME, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())), + hs_client, + hs_server, + ConnectionParameters::default(), + now(), + ) + .unwrap(); + client.set_qlog(log); + let spa = match preferred { + SocketAddr::V6(v6) => PreferredAddress::new(None, Some(v6)), + SocketAddr::V4(v4) => PreferredAddress::new(Some(v4), None), + }; + let mut server = new_server(ConnectionParameters::default().preferred_address(spa)); + + let dgram = fast_handshake(&mut client, &mut server); + + // The client is about to process HANDSHAKE_DONE. + // It should start probing toward the server's preferred address. + let probe = client.process(dgram.as_ref(), now()).dgram().unwrap(); + assert_toward_spa(&probe, true); + assert_eq!(client.stats().frame_tx.path_challenge, 1); + assert_ne!(client.process_output(now()).callback(), Duration::new(0, 0)); + + // Data continues on the main path for the client. + let data = send_something(&mut client, now()); + assert_orig_path(&data, false); + + // The server responds to the probe. + let resp = server.process(Some(&probe), now()).dgram().unwrap(); + assert_from_spa(&resp, true); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + assert_eq!(server.stats().frame_tx.path_response, 1); + + // Data continues on the main path for the server. + server.process_input(&data, now()); + let data = send_something(&mut server, now()); + assert_orig_path(&data, false); + + // Client gets the probe response back and it migrates. 
+ client.process_input(&resp, now()); + client.process_input(&data, now()); + let data = send_something(&mut client, now()); + assert_toward_spa(&data, true); + assert_eq!(client.stats().frame_tx.stream, 2); + assert_eq!(client.stats().frame_tx.path_response, 1); + + // The server sees the migration and probes the old path. + let probe = server.process(Some(&data), now()).dgram().unwrap(); + assert_orig_path(&probe, true); + assert_eq!(server.stats().frame_tx.path_challenge, 2); + + // But data now goes on the new path. + let data = send_something(&mut server, now()); + assert_from_spa(&data, false); +} + +/// Migration works for a new port number. +#[test] +fn preferred_address_new_port() { + let a = addr(); + preferred_address(a, a, new_port(a)); +} + +/// Migration works for a new address too. +#[test] +fn preferred_address_new_address() { + let mut preferred = addr(); + preferred.set_ip(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2))); + preferred_address(addr(), addr(), preferred); +} + +/// Migration works for IPv4 addresses. +#[test] +fn preferred_address_new_port_v4() { + let a = addr_v4(); + preferred_address(a, a, new_port(a)); +} + +/// Migrating to a loopback address is OK if we started there. +#[test] +fn preferred_address_loopback() { + let a = loopback(); + preferred_address(a, a, new_port(a)); +} + +fn expect_no_migration(client: &mut Connection, server: &mut Connection) { + let dgram = fast_handshake(client, server); + + // The client won't probe now, though it could; it remains idle. + let out = client.process(dgram.as_ref(), now()); + assert_ne!(out.callback(), Duration::new(0, 0)); + + // Data continues on the main path for the client. 
+ let data = send_something(client, now()); + assert_v6_path(&data, false); + assert_eq!(client.stats().frame_tx.path_challenge, 0); +} + +fn preferred_address_ignored(spa: PreferredAddress) { + let mut client = default_client(); + let mut server = new_server(ConnectionParameters::default().preferred_address(spa)); + + expect_no_migration(&mut client, &mut server); +} + +/// Using a loopback address in the preferred address is ignored. +#[test] +fn preferred_address_ignore_loopback() { + preferred_address_ignored(PreferredAddress::new_any(None, Some(loopback()))); +} + +/// A preferred address in the wrong address family is ignored. +#[test] +fn preferred_address_ignore_different_family() { + preferred_address_ignored(PreferredAddress::new_any(Some(addr_v4()), None)); +} + +/// Disabling preferred addresses at the client means that it ignores a perfectly +/// good preferred address. +#[test] +fn preferred_address_disabled_client() { + let mut client = new_client(ConnectionParameters::default().disable_preferred_address()); + let mut preferred = addr(); + preferred.set_ip(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2))); + let spa = PreferredAddress::new_any(None, Some(preferred)); + let mut server = new_server(ConnectionParameters::default().preferred_address(spa)); + + expect_no_migration(&mut client, &mut server); +} + +#[test] +fn preferred_address_empty_cid() { + fixture_init(); + + let spa = PreferredAddress::new_any(None, Some(new_port(addr()))); + let res = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())), + ConnectionParameters::default().preferred_address(spa), + ); + assert_eq!(res.unwrap_err(), Error::ConnectionIdsExhausted); +} + +/// A server cannot include a preferred address if it chooses an empty connection ID. 
+#[test] +fn preferred_address_server_empty_cid() { + let mut client = default_client(); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())), + ConnectionParameters::default(), + ) + .unwrap(); + + server + .set_local_tparam( + tparams::PREFERRED_ADDRESS, + TransportParameter::Bytes(SAMPLE_PREFERRED_ADDRESS.to_vec()), + ) + .unwrap(); + + connect_fail( + &mut client, + &mut server, + Error::TransportParameterError, + Error::PeerError(Error::TransportParameterError.code()), + ); +} + +/// A client shouldn't send a preferred address transport parameter. +#[test] +fn preferred_address_client() { + let mut client = default_client(); + let mut server = default_server(); + + client + .set_local_tparam( + tparams::PREFERRED_ADDRESS, + TransportParameter::Bytes(SAMPLE_PREFERRED_ADDRESS.to_vec()), + ) + .unwrap(); + + connect_fail( + &mut client, + &mut server, + Error::PeerError(Error::TransportParameterError.code()), + Error::TransportParameterError, + ); +} + +/// Test that migration isn't permitted if the connection isn't in the right state. +/// This checks a client and a server before the handshake, a connected server, +/// and both endpoints after the client closes the connection. 
+#[test] +fn migration_invalid_state() { + let mut client = default_client(); + assert!(client + .migrate(Some(addr()), Some(addr()), false, now()) + .is_err()); + + let mut server = default_server(); + assert!(server + .migrate(Some(addr()), Some(addr()), false, now()) + .is_err()); + connect_force_idle(&mut client, &mut server); + + assert!(server + .migrate(Some(addr()), Some(addr()), false, now()) + .is_err()); + + client.close(now(), 0, "closing"); + assert!(client + .migrate(Some(addr()), Some(addr()), false, now()) + .is_err()); + let close = client.process(None, now()).dgram(); + + let dgram = server.process(close.as_ref(), now()).dgram(); + assert!(server + .migrate(Some(addr()), Some(addr()), false, now()) + .is_err()); + + client.process_input(&dgram.unwrap(), now()); + assert!(client + .migrate(Some(addr()), Some(addr()), false, now()) + .is_err()); +} + +#[test] +fn migration_invalid_address() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut cant_migrate = |local, remote| { + assert_eq!( + client.migrate(local, remote, true, now()).unwrap_err(), + Error::InvalidMigration + ); + }; + + // Providing neither address is pointless and therefore an error. + cant_migrate(None, None); + + // Providing a zero port number isn't valid. + let mut zero_port = addr(); + zero_port.set_port(0); + cant_migrate(None, Some(zero_port)); + cant_migrate(Some(zero_port), None); + + // An unspecified remote address is bad. + let mut remote_unspecified = addr(); + remote_unspecified.set_ip(IpAddr::V6(Ipv6Addr::from(0))); + cant_migrate(None, Some(remote_unspecified)); + + // Mixed address families is bad. + cant_migrate(Some(addr()), Some(addr_v4())); + cant_migrate(Some(addr_v4()), Some(addr())); + + // Loopback to non-loopback is bad. 
+ cant_migrate(Some(addr()), Some(loopback())); + cant_migrate(Some(loopback()), Some(addr())); + assert_eq!( + client + .migrate(Some(addr()), Some(loopback()), true, now()) + .unwrap_err(), + Error::InvalidMigration + ); + assert_eq!( + client + .migrate(Some(loopback()), Some(addr()), true, now()) + .unwrap_err(), + Error::InvalidMigration + ); +} + +/// This inserts a frame into packets that provides a single new +/// connection ID and retires all others. +struct RetireAll { + cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>>, +} + +impl crate::connection::test_internal::FrameWriter for RetireAll { + fn write_frames(&mut self, builder: &mut PacketBuilder) { + // Use a sequence number that is large enough that all existing values + // will be lower (so they get retired). As the code doesn't care about + // gaps in sequence numbers, this is safe, even though the gap might + // hint that there are more outstanding connection IDs than are allowed. + const SEQNO: u64 = 100; + let cid = self.cid_gen.borrow_mut().generate_cid().unwrap(); + builder + .encode_varint(FRAME_TYPE_NEW_CONNECTION_ID) + .encode_varint(SEQNO) + .encode_varint(SEQNO) // Retire Prior To + .encode_vec(1, &cid) + .encode(&[0x7f; 16]); + } +} + +/// Test that forcing retirement of connection IDs forces retirement of all active +/// connection IDs and the use of a newer one. 
+#[test] +fn retire_all() { + let mut client = default_client(); + let cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>> = + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::clone(&cid_gen), + ConnectionParameters::default(), + ) + .unwrap(); + connect_force_idle(&mut client, &mut server); + + let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now()))); + + server.test_frame_writer = Some(Box::new(RetireAll { cid_gen })); + let ncid = send_something(&mut server, now()); + server.test_frame_writer = None; + + let new_cid_before = client.stats().frame_rx.new_connection_id; + let retire_cid_before = client.stats().frame_tx.retire_connection_id; + client.process_input(&ncid, now()); + let retire = send_something(&mut client, now()); + assert_eq!( + client.stats().frame_rx.new_connection_id, + new_cid_before + 1 + ); + assert_eq!( + client.stats().frame_tx.retire_connection_id, + retire_cid_before + LOCAL_ACTIVE_CID_LIMIT + ); + + assert_ne!(get_cid(&retire), original_cid); +} + +/// During a graceful migration, if the probed path can't get a new connection ID due +/// to being forced to retire the one it is using, the migration will fail. +#[test] +fn retire_prior_to_migration_failure() { + let mut client = default_client(); + let cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>> = + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::clone(&cid_gen), + ConnectionParameters::default(), + ) + .unwrap(); + connect_force_idle(&mut client, &mut server); + + let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now()))); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), false, now()) + .unwrap(); + + // The client now probes the new path. 
+ let probe = client.process_output(now()).dgram().unwrap(); + assert_v4_path(&probe, true); + assert_eq!(client.stats().frame_tx.path_challenge, 1); + let probe_cid = ConnectionId::from(get_cid(&probe)); + assert_ne!(original_cid, probe_cid); + + // Have the server receive the probe, but separately have it decide to + // retire all of the available connection IDs. + server.test_frame_writer = Some(Box::new(RetireAll { cid_gen })); + let retire_all = send_something(&mut server, now()); + server.test_frame_writer = None; + + let resp = server.process(Some(&probe), now()).dgram().unwrap(); + assert_v4_path(&resp, true); + assert_eq!(server.stats().frame_tx.path_response, 1); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + + // Have the client receive the NEW_CONNECTION_ID with Retire Prior To. + client.process_input(&retire_all, now()); + // This packet contains the probe response, which should be fine, but it + // also includes PATH_CHALLENGE for the new path, and the client can't + // respond without a connection ID. We treat this as a connection error. + client.process_input(&resp, now()); + assert!(matches!( + client.state(), + State::Closing { + error: ConnectionError::Transport(Error::InvalidMigration), + .. + } + )); +} + +/// The timing of when frames arrive can mean that the migration path can +/// get the last available connection ID. 
+#[test] +fn retire_prior_to_migration_success() { + let mut client = default_client(); + let cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>> = + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::clone(&cid_gen), + ConnectionParameters::default(), + ) + .unwrap(); + connect_force_idle(&mut client, &mut server); + + let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now()))); + + client + .migrate(Some(addr_v4()), Some(addr_v4()), false, now()) + .unwrap(); + + // The client now probes the new path. + let probe = client.process_output(now()).dgram().unwrap(); + assert_v4_path(&probe, true); + assert_eq!(client.stats().frame_tx.path_challenge, 1); + let probe_cid = ConnectionId::from(get_cid(&probe)); + assert_ne!(original_cid, probe_cid); + + // Have the server receive the probe, but separately have it decide to + // retire all of the available connection IDs. + server.test_frame_writer = Some(Box::new(RetireAll { cid_gen })); + let retire_all = send_something(&mut server, now()); + server.test_frame_writer = None; + + let resp = server.process(Some(&probe), now()).dgram().unwrap(); + assert_v4_path(&resp, true); + assert_eq!(server.stats().frame_tx.path_response, 1); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + + // Have the client receive the NEW_CONNECTION_ID with Retire Prior To second. + // As this occurs in a very specific order, migration succeeds. + client.process_input(&resp, now()); + client.process_input(&retire_all, now()); + + // Migration succeeds and the new path gets the last connection ID. 
+ let dgram = send_something(&mut client, now()); + assert_v4_path(&dgram, false); + assert_ne!(get_cid(&dgram), original_cid); + assert_ne!(get_cid(&dgram), probe_cid); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs new file mode 100644 index 0000000000..8a999f4048 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs @@ -0,0 +1,614 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![deny(clippy::pedantic)] + +use std::{ + cell::RefCell, + cmp::min, + convert::TryFrom, + mem, + rc::Rc, + time::{Duration, Instant}, +}; + +use enum_map::enum_map; +use neqo_common::{event::Provider, qdebug, qtrace, Datagram, Decoder, Role}; +use neqo_crypto::{random, AllowZeroRtt, AuthenticationStatus, ResumptionToken}; +use test_fixture::{self, addr, fixture_init, new_neqo_qlog, now}; + +use super::{Connection, ConnectionError, ConnectionId, Output, State}; +use crate::{ + addr_valid::{AddressValidation, ValidateAddress}, + cc::{CWND_INITIAL_PKTS, CWND_MIN}, + cid::ConnectionIdRef, + events::ConnectionEvent, + frame::FRAME_TYPE_PING, + packet::PacketBuilder, + path::PATH_MTU_V6, + recovery::ACK_ONLY_SIZE_LIMIT, + stats::{FrameStats, Stats, MAX_PTO_COUNTS}, + ConnectionIdDecoder, ConnectionIdGenerator, ConnectionParameters, Error, StreamId, StreamType, + Version, +}; + +// All the tests. 
+mod ackrate; +mod cc; +mod close; +mod datagram; +mod fuzzing; +mod handshake; +mod idle; +mod keys; +mod migration; +mod priority; +mod recovery; +mod resumption; +mod stream; +mod vn; +mod zerortt; + +const DEFAULT_RTT: Duration = Duration::from_millis(100); +const AT_LEAST_PTO: Duration = Duration::from_secs(1); +const DEFAULT_STREAM_DATA: &[u8] = b"message"; +/// The number of 1-RTT packets sent in `force_idle` by a client. +const CLIENT_HANDSHAKE_1RTT_PACKETS: usize = 1; + +/// WARNING! In this module, this version of the generator needs to be used. +/// This copies the implementation from +/// `test_fixture::CountingConnectionIdGenerator`, but it uses the different +/// types that are exposed to this module. See also `default_client`. +/// +/// This version doesn't randomize the length; as the congestion control tests +/// count the amount of data sent precisely. +#[derive(Debug, Default)] +pub struct CountingConnectionIdGenerator { + counter: u32, +} + +impl ConnectionIdDecoder for CountingConnectionIdGenerator { + fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> { + let len = usize::from(dec.peek_byte().unwrap()); + dec.decode(len).map(ConnectionIdRef::from) + } +} + +impl ConnectionIdGenerator for CountingConnectionIdGenerator { + fn generate_cid(&mut self) -> Option<ConnectionId> { + let mut r = random(20); + r[0] = 8; + r[1] = u8::try_from(self.counter >> 24).unwrap(); + r[2] = u8::try_from((self.counter >> 16) & 0xff).unwrap(); + r[3] = u8::try_from((self.counter >> 8) & 0xff).unwrap(); + r[4] = u8::try_from(self.counter & 0xff).unwrap(); + self.counter += 1; + Some(ConnectionId::from(&r[..8])) + } + + fn as_decoder(&self) -> &dyn ConnectionIdDecoder { + self + } +} + +// This is fabulous: because test_fixture uses the public API for Connection, +// it gets a different type to the ones that are referenced via super::super::*. 
+// Thus, this code can't use default_client() and default_server() from +// test_fixture because they produce different - and incompatible - types. +// +// These are a direct copy of those functions. +pub fn new_client(params: ConnectionParameters) -> Connection { + fixture_init(); + let (log, _contents) = new_neqo_qlog(); + let mut client = Connection::new_client( + test_fixture::DEFAULT_SERVER_NAME, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + addr(), + addr(), + params, + now(), + ) + .expect("create a default client"); + client.set_qlog(log); + client +} + +pub fn default_client() -> Connection { + new_client(ConnectionParameters::default()) +} + +pub fn new_server(params: ConnectionParameters) -> Connection { + fixture_init(); + let (log, _contents) = new_neqo_qlog(); + let mut c = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + params, + ) + .expect("create a default server"); + c.set_qlog(log); + c.server_enable_0rtt(&test_fixture::anti_replay(), AllowZeroRtt {}) + .expect("enable 0-RTT"); + c +} +pub fn default_server() -> Connection { + new_server(ConnectionParameters::default()) +} +pub fn resumed_server(client: &Connection) -> Connection { + new_server(ConnectionParameters::default().versions(client.version(), Version::all())) +} + +/// If state is `AuthenticationNeeded` call `authenticated()`. This function will +/// consume all outstanding events on the connection. +pub fn maybe_authenticate(conn: &mut Connection) -> bool { + let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded); + if conn.events().any(authentication_needed) { + conn.authenticated(AuthenticationStatus::Ok, now()); + return true; + } + false +} + +/// Compute the RTT variance after `n` ACKs or other RTT updates. 
+pub fn rttvar_after_n_updates(n: usize, rtt: Duration) -> Duration { + assert!(n > 0); + let mut rttvar = rtt / 2; + for _ in 1..n { + rttvar = rttvar * 3 / 4; + } + rttvar +} + +/// This inserts a PING frame into packets. +struct PingWriter {} + +impl crate::connection::test_internal::FrameWriter for PingWriter { + fn write_frames(&mut self, builder: &mut PacketBuilder) { + builder.encode_varint(FRAME_TYPE_PING); + } +} + +/// Drive the handshake between the client and server. +fn handshake( + client: &mut Connection, + server: &mut Connection, + now: Instant, + rtt: Duration, +) -> Instant { + let mut a = client; + let mut b = server; + let mut now = now; + + let mut input = None; + let is_done = |c: &mut Connection| { + matches!( + c.state(), + State::Confirmed | State::Closing { .. } | State::Closed(..) + ) + }; + + let mut did_ping = enum_map! {_ => false}; + while !is_done(a) { + _ = maybe_authenticate(a); + let had_input = input.is_some(); + // Insert a PING frame into the first application data packet an endpoint sends, + // in order to force the peer to ACK it. For the server, this is depending on the + // client's connection state, which is accessible during the tests. + // + // We're doing this to prevent packet loss from delaying ACKs, which would cause + // cwnd to shrink, and also to prevent the delayed ACK timer from being armed after + // the handshake, which is not something the tests are written to account for. 
+ let should_ping = !did_ping[a.role()] + && (a.role() == Role::Client && *a.state() == State::Connected + || (a.role() == Role::Server && *b.state() == State::Connected)); + if should_ping { + a.test_frame_writer = Some(Box::new(PingWriter {})); + } + let output = a.process(input.as_ref(), now).dgram(); + if should_ping { + a.test_frame_writer = None; + did_ping[a.role()] = true; + } + assert!(had_input || output.is_some()); + input = output; + qtrace!("handshake: t += {:?}", rtt / 2); + now += rtt / 2; + mem::swap(&mut a, &mut b); + } + if let Some(d) = input { + a.process_input(&d, now); + } + now +} + +fn connect_fail( + client: &mut Connection, + server: &mut Connection, + client_error: Error, + server_error: Error, +) { + handshake(client, server, now(), Duration::new(0, 0)); + assert_error(client, &ConnectionError::Transport(client_error)); + assert_error(server, &ConnectionError::Transport(server_error)); +} + +fn connect_with_rtt( + client: &mut Connection, + server: &mut Connection, + now: Instant, + rtt: Duration, +) -> Instant { + fn check_rtt(stats: &Stats, rtt: Duration) { + assert_eq!(stats.rtt, rtt); + // Validate that rttvar has been computed correctly based on the number of RTT updates. + let n = stats.frame_rx.ack + usize::from(stats.rtt_init_guess); + assert_eq!(stats.rttvar, rttvar_after_n_updates(n, rtt)); + } + let now = handshake(client, server, now, rtt); + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(*server.state(), State::Confirmed); + + check_rtt(&client.stats(), rtt); + check_rtt(&server.stats(), rtt); + now +} + +fn connect(client: &mut Connection, server: &mut Connection) { + connect_with_rtt(client, server, now(), Duration::new(0, 0)); +} + +fn assert_error(c: &Connection, expected: &ConnectionError) { + match c.state() { + State::Closing { error, .. } | State::Draining { error, .. 
} | State::Closed(error) => { + assert_eq!(*error, *expected, "{c} error mismatch"); + } + _ => panic!("bad state {:?}", c.state()), + } +} + +fn exchange_ticket( + client: &mut Connection, + server: &mut Connection, + now: Instant, +) -> ResumptionToken { + let validation = AddressValidation::new(now, ValidateAddress::NoToken).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + server.send_ticket(now, &[]).expect("can send ticket"); + let ticket = server.process_output(now).dgram(); + assert!(ticket.is_some()); + client.process_input(&ticket.unwrap(), now); + assert_eq!(*client.state(), State::Confirmed); + get_tokens(client).pop().expect("should have token") +} + +/// The `handshake` method inserts PING frames into the first application data packets, +/// which forces each peer to ACK them. As a side effect, that causes both sides of the +/// connection to be idle aftwerwards. This method simply verifies that this is the case. +fn assert_idle(client: &mut Connection, server: &mut Connection, rtt: Duration, now: Instant) { + let idle_timeout = min( + client.conn_params.get_idle_timeout(), + server.conn_params.get_idle_timeout(), + ); + // Client started its idle period half an RTT before now. + assert_eq!( + client.process_output(now), + Output::Callback(idle_timeout - rtt / 2) + ); + assert_eq!(server.process_output(now), Output::Callback(idle_timeout)); +} + +/// Connect with an RTT and then force both peers to be idle. +fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Duration) -> Instant { + let now = connect_with_rtt(client, server, now(), rtt); + assert_idle(client, server, rtt, now); + // Drain events from both as well. 
+ _ = client.events().count(); + _ = server.events().count(); + qtrace!("----- connected and idle with RTT {:?}", rtt); + now +} + +fn connect_force_idle(client: &mut Connection, server: &mut Connection) { + connect_rtt_idle(client, server, Duration::new(0, 0)); +} + +fn fill_stream(c: &mut Connection, stream: StreamId) { + const BLOCK_SIZE: usize = 4_096; + loop { + let bytes_sent = c.stream_send(stream, &[0x42; BLOCK_SIZE]).unwrap(); + qtrace!("fill_cwnd wrote {} bytes", bytes_sent); + if bytes_sent < BLOCK_SIZE { + break; + } + } +} + +/// This fills the congestion window from a single source. +/// As the pacer will interfere with this, this moves time forward +/// as `Output::Callback` is received. Because it is hard to tell +/// from the return value whether a timeout is an ACK delay, PTO, or +/// pacing, this looks at the congestion window to tell when to stop. +/// Returns a list of datagrams and the new time. +fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec<Datagram>, Instant) { + // Train wreck function to get the remaining congestion window on the primary path. + fn cwnd(c: &Connection) -> usize { + c.paths.primary().borrow().sender().cwnd_avail() + } + + qtrace!("fill_cwnd starting cwnd: {}", cwnd(c)); + fill_stream(c, stream); + + let mut total_dgrams = Vec::new(); + loop { + let pkt = c.process_output(now); + qtrace!("fill_cwnd cwnd remaining={}, output: {:?}", cwnd(c), pkt); + match pkt { + Output::Datagram(dgram) => { + total_dgrams.push(dgram); + } + Output::Callback(t) => { + if cwnd(c) < ACK_ONLY_SIZE_LIMIT { + break; + } + now += t; + } + Output::None => panic!(), + } + } + + qtrace!( + "fill_cwnd sent {} bytes", + total_dgrams.iter().map(|d| d.len()).sum::<usize>() + ); + (total_dgrams, now) +} + +/// This function is like the combination of `fill_cwnd` and `ack_bytes`. +/// However, it acknowledges everything inline and preserves an RTT of `DEFAULT_RTT`. 
+fn increase_cwnd( + sender: &mut Connection, + receiver: &mut Connection, + stream: StreamId, + mut now: Instant, +) -> Instant { + fill_stream(sender, stream); + loop { + let pkt = sender.process_output(now); + match pkt { + Output::Datagram(dgram) => { + receiver.process_input(&dgram, now + DEFAULT_RTT / 2); + } + Output::Callback(t) => { + if t < DEFAULT_RTT { + now += t; + } else { + break; // We're on PTO now. + } + } + Output::None => panic!(), + } + } + + // Now acknowledge all those packets at once. + now += DEFAULT_RTT / 2; + let ack = receiver.process_output(now).dgram(); + now += DEFAULT_RTT / 2; + sender.process_input(&ack.unwrap(), now); + now +} + +/// Receive multiple packets and generate an ack-only packet. +/// +/// # Panics +/// +/// The caller is responsible for ensuring that `dest` has received +/// enough data that it wants to generate an ACK. This panics if +/// no ACK frame is generated. +fn ack_bytes<D>(dest: &mut Connection, stream: StreamId, in_dgrams: D, now: Instant) -> Datagram +where + D: IntoIterator<Item = Datagram>, + D::IntoIter: ExactSizeIterator, +{ + let mut srv_buf = [0; 4_096]; + + let in_dgrams = in_dgrams.into_iter(); + qdebug!([dest], "ack_bytes {} datagrams", in_dgrams.len()); + for dgram in in_dgrams { + dest.process_input(&dgram, now); + } + + loop { + let (bytes_read, _fin) = dest.stream_recv(stream, &mut srv_buf).unwrap(); + qtrace!([dest], "ack_bytes read {} bytes", bytes_read); + if bytes_read == 0 { + break; + } + } + + dest.process_output(now).dgram().unwrap() +} + +// Get the current congestion window for the connection. 
+fn cwnd(c: &Connection) -> usize { + c.paths.primary().borrow().sender().cwnd() +} +fn cwnd_avail(c: &Connection) -> usize { + c.paths.primary().borrow().sender().cwnd_avail() +} + +fn induce_persistent_congestion( + client: &mut Connection, + server: &mut Connection, + stream: StreamId, + mut now: Instant, +) -> Instant { + // Note: wait some arbitrary time that should be longer than pto + // timer. This is rather brittle. + qtrace!([client], "induce_persistent_congestion"); + now += AT_LEAST_PTO; + + let mut pto_counts = [0; MAX_PTO_COUNTS]; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qtrace!([client], "first PTO"); + let (c_tx_dgrams, next_now) = fill_cwnd(client, stream, now); + now = next_now; + assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qtrace!([client], "second PTO"); + now += AT_LEAST_PTO * 2; + let (c_tx_dgrams, next_now) = fill_cwnd(client, stream, now); + now = next_now; + assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets + + pto_counts[0] = 0; + pto_counts[1] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qtrace!([client], "third PTO"); + now += AT_LEAST_PTO * 4; + let (c_tx_dgrams, next_now) = fill_cwnd(client, stream, now); + now = next_now; + assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets + + pto_counts[1] = 0; + pto_counts[2] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // An ACK for the third PTO causes persistent congestion. + let s_ack = ack_bytes(server, stream, c_tx_dgrams, now); + client.process_input(&s_ack, now); + assert_eq!(cwnd(client), CWND_MIN); + now +} + +/// This magic number is the size of the client's CWND after the handshake completes. +/// This is the same as the initial congestion window, because during the handshake +/// the cc is app limited and cwnd is not increased. 
+/// +/// As we change how we build packets, or even as NSS changes, +/// this number might be different. The tests that depend on this +/// value could fail as a result of variations, so it's OK to just +/// change this value, but it is good to first understand where the +/// change came from. +const POST_HANDSHAKE_CWND: usize = PATH_MTU_V6 * CWND_INITIAL_PKTS; + +/// Determine the number of packets required to fill the CWND. +const fn cwnd_packets(data: usize) -> usize { + // Add one if the last chunk is >= ACK_ONLY_SIZE_LIMIT. + (data + PATH_MTU_V6 - ACK_ONLY_SIZE_LIMIT) / PATH_MTU_V6 +} + +/// Determine the size of the last packet. +/// The minimal size of a packet is `ACK_ONLY_SIZE_LIMIT`. +fn last_packet(cwnd: usize) -> usize { + if (cwnd % PATH_MTU_V6) > ACK_ONLY_SIZE_LIMIT { + cwnd % PATH_MTU_V6 + } else { + PATH_MTU_V6 + } +} + +/// Assert that the set of packets fill the CWND. +fn assert_full_cwnd(packets: &[Datagram], cwnd: usize) { + assert_eq!(packets.len(), cwnd_packets(cwnd)); + let (last, rest) = packets.split_last().unwrap(); + assert!(rest.iter().all(|d| d.len() == PATH_MTU_V6)); + assert_eq!(last.len(), last_packet(cwnd)); +} + +/// Send something on a stream from `sender` to `receiver`, maybe allowing for pacing. +/// Return the resulting datagram and the new time. 
+#[must_use] +fn send_something_paced( + sender: &mut Connection, + mut now: Instant, + allow_pacing: bool, +) -> (Datagram, Instant) { + let stream_id = sender.stream_create(StreamType::UniDi).unwrap(); + assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok()); + assert!(sender.stream_close_send(stream_id).is_ok()); + qdebug!([sender], "send_something on {}", stream_id); + let dgram = match sender.process_output(now) { + Output::Callback(t) => { + assert!(allow_pacing, "send_something: unexpected delay"); + now += t; + sender + .process_output(now) + .dgram() + .expect("send_something: should have something to send") + } + Output::Datagram(d) => d, + Output::None => panic!("send_something: got Output::None"), + }; + (dgram, now) +} + +/// Send something on a stream from `sender` to `receiver`. +/// Return the resulting datagram. +fn send_something(sender: &mut Connection, now: Instant) -> Datagram { + send_something_paced(sender, now, false).0 +} + +/// Send something on a stream from `sender` to `receiver`. +/// Return any ACK that might result. 
+fn send_and_receive( + sender: &mut Connection, + receiver: &mut Connection, + now: Instant, +) -> Option<Datagram> { + let dgram = send_something(sender, now); + receiver.process(Some(&dgram), now).dgram() +} + +fn get_tokens(client: &mut Connection) -> Vec<ResumptionToken> { + client + .events() + .filter_map(|e| { + if let ConnectionEvent::ResumptionToken(token) = e { + Some(token) + } else { + None + } + }) + .collect() +} + +fn assert_default_stats(stats: &Stats) { + assert_eq!(stats.packets_rx, 0); + assert_eq!(stats.packets_tx, 0); + let dflt_frames = FrameStats::default(); + assert_eq!(stats.frame_rx, dflt_frames); + assert_eq!(stats.frame_tx, dflt_frames); +} + +#[test] +fn create_client() { + let client = default_client(); + assert_eq!(client.role(), Role::Client); + assert!(matches!(client.state(), State::Init)); + let stats = client.stats(); + assert_default_stats(&stats); + assert_eq!(stats.rtt, crate::rtt::INITIAL_RTT); + assert_eq!(stats.rttvar, crate::rtt::INITIAL_RTT / 2); +} + +#[test] +fn create_server() { + let server = default_server(); + assert_eq!(server.role(), Role::Server); + assert!(matches!(server.state(), State::Init)); + let stats = server.stats(); + assert_default_stats(&stats); + // Server won't have a default path, so no RTT. + assert_eq!(stats.rtt, Duration::from_secs(0)); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/priority.rs b/third_party/rust/neqo-transport/src/connection/tests/priority.rs new file mode 100644 index 0000000000..1f86aa22e5 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/priority.rs @@ -0,0 +1,404 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::{cell::RefCell, mem, rc::Rc}; + +use neqo_common::event::Provider; +use test_fixture::{self, now}; + +use super::{ + super::{Connection, Error, Output}, + connect, default_client, default_server, fill_cwnd, maybe_authenticate, +}; +use crate::{ + addr_valid::{AddressValidation, ValidateAddress}, + send_stream::{RetransmissionPriority, TransmissionPriority}, + ConnectionEvent, StreamId, StreamType, +}; + +const BLOCK_SIZE: usize = 4_096; + +fn fill_stream(c: &mut Connection, id: StreamId) { + loop { + if c.stream_send(id, &[0x42; BLOCK_SIZE]).unwrap() < BLOCK_SIZE { + return; + } + } +} + +/// A receive stream cannot be prioritized (yet). +#[test] +fn receive_stream() { + const MESSAGE: &[u8] = b"hello"; + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(MESSAGE.len(), client.stream_send(id, MESSAGE).unwrap()); + let dgram = client.process_output(now()).dgram(); + + server.process_input(&dgram.unwrap(), now()); + assert_eq!( + server + .stream_priority( + id, + TransmissionPriority::default(), + RetransmissionPriority::default() + ) + .unwrap_err(), + Error::InvalidStreamId, + "Priority doesn't apply to inbound unidirectional streams" + ); + + // But the stream does exist and can be read. + let mut buf = [0; 10]; + let (len, end) = server.stream_recv(id, &mut buf).unwrap(); + assert_eq!(MESSAGE, &buf[..len]); + assert!(!end); +} + +/// Higher priority streams get sent ahead of lower ones, even when +/// the higher priority stream is written to later. +#[test] +fn relative() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // id_normal is created first, but it is lower priority. 
+ let id_normal = client.stream_create(StreamType::UniDi).unwrap(); + fill_stream(&mut client, id_normal); + let high = client.stream_create(StreamType::UniDi).unwrap(); + fill_stream(&mut client, high); + client + .stream_priority( + high, + TransmissionPriority::High, + RetransmissionPriority::default(), + ) + .unwrap(); + + let dgram = client.process_output(now()).dgram(); + server.process_input(&dgram.unwrap(), now()); + + // The "id_normal" stream will get a `NewStream` event, but no data. + for e in server.events() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + assert_ne!(stream_id, id_normal); + } + } +} + +/// Check that changing priority has effect on the next packet that is sent. +#[test] +fn reprioritize() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // id_normal is created first, but it is lower priority. + let id_normal = client.stream_create(StreamType::UniDi).unwrap(); + fill_stream(&mut client, id_normal); + let id_high = client.stream_create(StreamType::UniDi).unwrap(); + fill_stream(&mut client, id_high); + client + .stream_priority( + id_high, + TransmissionPriority::High, + RetransmissionPriority::default(), + ) + .unwrap(); + + let dgram = client.process_output(now()).dgram(); + server.process_input(&dgram.unwrap(), now()); + + // The "id_normal" stream will get a `NewStream` event, but no data. + for e in server.events() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + assert_ne!(stream_id, id_normal); + } + } + + // When the high priority stream drops in priority, the streams are equal + // priority and so their stream ID determines what is sent. 
+ client + .stream_priority( + id_high, + TransmissionPriority::Normal, + RetransmissionPriority::default(), + ) + .unwrap(); + let dgram = client.process_output(now()).dgram(); + server.process_input(&dgram.unwrap(), now()); + + for e in server.events() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + assert_ne!(stream_id, id_high); + } + } +} + +/// Retransmission can be prioritized differently (usually higher). +#[test] +fn repairing_loss() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let mut now = now(); + + // Send a few packets at low priority, lose one. + let id_low = client.stream_create(StreamType::UniDi).unwrap(); + fill_stream(&mut client, id_low); + client + .stream_priority( + id_low, + TransmissionPriority::Low, + RetransmissionPriority::Higher, + ) + .unwrap(); + + let _lost = client.process_output(now).dgram(); + for _ in 0..5 { + match client.process_output(now) { + Output::Datagram(d) => server.process_input(&d, now), + Output::Callback(delay) => now += delay, + Output::None => unreachable!(), + } + } + + // Generate an ACK. The first packet is now considered lost. + let ack = server.process_output(now).dgram(); + _ = server.events().count(); // Drain events. + + let id_normal = client.stream_create(StreamType::UniDi).unwrap(); + fill_stream(&mut client, id_normal); + + let dgram = client.process(ack.as_ref(), now).dgram(); + assert_eq!(client.stats().lost, 1); // Client should have noticed the loss. + server.process_input(&dgram.unwrap(), now); + + // Only the low priority stream has data as the retransmission of the data from + // the lost packet is now more important than new data from the high priority stream. + for e in server.events() { + println!("Event: {e:?}"); + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + assert_eq!(stream_id, id_low); + } + } + + // However, only the retransmission is prioritized. 
+ // Though this might contain some retransmitted data, as other frames might push + // the retransmitted data into a second packet, it will also contain data from the + // normal priority stream. + let dgram = client.process_output(now).dgram(); + server.process_input(&dgram.unwrap(), now); + assert!(server.events().any( + |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id_normal), + )); +} + +#[test] +fn critical() { + let mut client = default_client(); + let mut server = default_server(); + let now = now(); + + // Rather than connect, send stream data in 0.5-RTT. + // That allows this to test that critical streams pre-empt most frame types. + let dgram = client.process_output(now).dgram(); + let dgram = server.process(dgram.as_ref(), now).dgram(); + client.process_input(&dgram.unwrap(), now); + maybe_authenticate(&mut client); + + let id = server.stream_create(StreamType::UniDi).unwrap(); + server + .stream_priority( + id, + TransmissionPriority::Critical, + RetransmissionPriority::default(), + ) + .unwrap(); + + // Can't use fill_cwnd here because the server is blocked on the amplification + // limit, so it can't fill the congestion window. + while server.stream_create(StreamType::UniDi).is_ok() {} + + fill_stream(&mut server, id); + let stats_before = server.stats().frame_tx; + let dgram = server.process_output(now).dgram(); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto); + assert_eq!(stats_after.streams_blocked, 0); + assert_eq!(stats_after.new_connection_id, 0); + assert_eq!(stats_after.new_token, 0); + assert_eq!(stats_after.handshake_done, 0); + + // Complete the handshake. + let dgram = client.process(dgram.as_ref(), now).dgram(); + server.process_input(&dgram.unwrap(), now); + + // Critical beats everything but HANDSHAKE_DONE. 
+ let stats_before = server.stats().frame_tx; + mem::drop(fill_cwnd(&mut server, id, now)); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto); + assert_eq!(stats_after.streams_blocked, 0); + assert_eq!(stats_after.new_connection_id, 0); + assert_eq!(stats_after.new_token, 0); + assert_eq!(stats_after.handshake_done, 1); +} + +#[test] +fn important() { + let mut client = default_client(); + let mut server = default_server(); + let now = now(); + + // Rather than connect, send stream data in 0.5-RTT. + // That allows this to test that important streams pre-empt most frame types. + let dgram = client.process_output(now).dgram(); + let dgram = server.process(dgram.as_ref(), now).dgram(); + client.process_input(&dgram.unwrap(), now); + maybe_authenticate(&mut client); + + let id = server.stream_create(StreamType::UniDi).unwrap(); + server + .stream_priority( + id, + TransmissionPriority::Important, + RetransmissionPriority::default(), + ) + .unwrap(); + fill_stream(&mut server, id); + + // Important beats everything but flow control. + // Make enough streams to get a STREAMS_BLOCKED frame out. + while server.stream_create(StreamType::UniDi).is_ok() {} + + let stats_before = server.stats().frame_tx; + let dgram = server.process_output(now).dgram(); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto); + assert_eq!(stats_after.streams_blocked, 1); + assert_eq!(stats_after.new_connection_id, 0); + assert_eq!(stats_after.new_token, 0); + assert_eq!(stats_after.handshake_done, 0); + assert_eq!(stats_after.stream, stats_before.stream + 1); + + // Complete the handshake. + let dgram = client.process(dgram.as_ref(), now).dgram(); + server.process_input(&dgram.unwrap(), now); + + // Important beats everything but flow control. 
+ let stats_before = server.stats().frame_tx; + mem::drop(fill_cwnd(&mut server, id, now)); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto); + assert_eq!(stats_after.streams_blocked, 1); + assert_eq!(stats_after.new_connection_id, 0); + assert_eq!(stats_after.new_token, 0); + assert_eq!(stats_after.handshake_done, 1); + assert!(stats_after.stream > stats_before.stream); +} + +#[test] +fn high_normal() { + let mut client = default_client(); + let mut server = default_server(); + let now = now(); + + // Rather than connect, send stream data in 0.5-RTT. + // That allows this to test that important streams pre-empt most frame types. + let dgram = client.process_output(now).dgram(); + let dgram = server.process(dgram.as_ref(), now).dgram(); + client.process_input(&dgram.unwrap(), now); + maybe_authenticate(&mut client); + + let id = server.stream_create(StreamType::UniDi).unwrap(); + server + .stream_priority( + id, + TransmissionPriority::High, + RetransmissionPriority::default(), + ) + .unwrap(); + fill_stream(&mut server, id); + + // Important beats everything but flow control. + // Make enough streams to get a STREAMS_BLOCKED frame out. + while server.stream_create(StreamType::UniDi).is_ok() {} + + let stats_before = server.stats().frame_tx; + let dgram = server.process_output(now).dgram(); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto); + assert_eq!(stats_after.streams_blocked, 1); + assert_eq!(stats_after.new_connection_id, 0); + assert_eq!(stats_after.new_token, 0); + assert_eq!(stats_after.handshake_done, 0); + assert_eq!(stats_after.stream, stats_before.stream + 1); + + // Complete the handshake. + let dgram = client.process(dgram.as_ref(), now).dgram(); + server.process_input(&dgram.unwrap(), now); + + // High or Normal doesn't beat NEW_CONNECTION_ID, + // but they beat CRYPTO/NEW_TOKEN. 
+ let stats_before = server.stats().frame_tx; + server.send_ticket(now, &[]).unwrap(); + mem::drop(fill_cwnd(&mut server, id, now)); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto); + assert_eq!(stats_after.streams_blocked, 1); + assert_ne!(stats_after.new_connection_id, 0); // Note: > 0 + assert_eq!(stats_after.new_token, 0); + assert_eq!(stats_after.handshake_done, 1); + assert!(stats_after.stream > stats_before.stream); +} + +#[test] +fn low() { + let mut client = default_client(); + let mut server = default_server(); + let now = now(); + // Use address validation; note that we need to hold a strong reference + // as the server will only hold a weak reference. + let validation = Rc::new(RefCell::new( + AddressValidation::new(now, ValidateAddress::Never).unwrap(), + )); + server.set_validation(Rc::clone(&validation)); + connect(&mut client, &mut server); + + let id = server.stream_create(StreamType::UniDi).unwrap(); + server + .stream_priority( + id, + TransmissionPriority::Low, + RetransmissionPriority::default(), + ) + .unwrap(); + fill_stream(&mut server, id); + + // Send a session ticket and make it big enough to require a whole packet. + // The resulting CRYPTO frame beats out the stream data. + let stats_before = server.stats().frame_tx; + server.send_ticket(now, &[0; 2048]).unwrap(); + mem::drop(server.process_output(now)); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto + 1); + assert_eq!(stats_after.stream, stats_before.stream); + + // The above can't test if NEW_TOKEN wins because once that fits in a packet, + // it is very hard to ensure that the STREAM frame won't also fit. + // However, we can ensure that the next packet doesn't consist of just STREAM. 
+ let stats_before = server.stats().frame_tx; + mem::drop(server.process_output(now)); + let stats_after = server.stats().frame_tx; + assert_eq!(stats_after.crypto, stats_before.crypto + 1); + assert_eq!(stats_after.new_token, 1); + assert_eq!(stats_after.stream, stats_before.stream + 1); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/recovery.rs b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs new file mode 100644 index 0000000000..0f12d03107 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs @@ -0,0 +1,804 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + mem, + time::{Duration, Instant}, +}; + +use neqo_common::qdebug; +use neqo_crypto::AuthenticationStatus; +use test_fixture::{ + assertions::{assert_handshake, assert_initial}, + now, split_datagram, +}; + +use super::{ + super::{Connection, ConnectionParameters, Output, State}, + assert_full_cwnd, connect, connect_force_idle, connect_rtt_idle, connect_with_rtt, cwnd, + default_client, default_server, fill_cwnd, maybe_authenticate, new_client, send_and_receive, + send_something, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA, POST_HANDSHAKE_CWND, +}; +use crate::{ + cc::CWND_MIN, + path::PATH_MTU_V6, + recovery::{ + FAST_PTO_SCALE, MAX_OUTSTANDING_UNACK, MAX_PTO_PACKET_COUNT, MIN_OUTSTANDING_UNACK, + }, + rtt::GRANULARITY, + stats::MAX_PTO_COUNTS, + tparams::TransportParameter, + tracking::DEFAULT_ACK_DELAY, + StreamType, +}; + +#[test] +fn pto_works_basic() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut now = now(); + + let res = client.process(None, now); + let idle_timeout = 
ConnectionParameters::default().get_idle_timeout(); + assert_eq!(res, Output::Callback(idle_timeout)); + + // Send data on two streams + let stream1 = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.stream_send(stream1, b"hello").unwrap(), 5); + assert_eq!(client.stream_send(stream1, b" world!").unwrap(), 7); + + let stream2 = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.stream_send(stream2, b"there!").unwrap(), 6); + + // Send a packet after some time. + now += Duration::from_secs(10); + let out = client.process(None, now); + assert!(out.dgram().is_some()); + + // Nothing to do, should return callback + let out = client.process(None, now); + assert!(matches!(out, Output::Callback(_))); + + // One second later, it should want to send PTO packet + now += AT_LEAST_PTO; + let out = client.process(None, now); + + let stream_before = server.stats().frame_rx.stream; + server.process_input(&out.dgram().unwrap(), now); + assert_eq!(server.stats().frame_rx.stream, stream_before + 2); +} + +#[test] +fn pto_works_full_cwnd() { + let mut client = default_client(); + let mut server = default_server(); + let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + // Send lots of data. + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + let (dgrams, now) = fill_cwnd(&mut client, stream_id, now); + assert_full_cwnd(&dgrams, POST_HANDSHAKE_CWND); + + // Fill the CWND after waiting for a PTO. + let (dgrams, now) = fill_cwnd(&mut client, stream_id, now + AT_LEAST_PTO); + // Two packets in the PTO. + // The first should be full sized; the second might be small. + assert_eq!(dgrams.len(), 2); + assert_eq!(dgrams[0].len(), PATH_MTU_V6); + + // Both datagrams contain one or more STREAM frames. 
+ for d in dgrams { + let stream_before = server.stats().frame_rx.stream; + server.process_input(&d, now); + assert!(server.stats().frame_rx.stream > stream_before); + } +} + +#[test] +fn pto_works_ping() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let mut now = now() + Duration::from_secs(10); + + // Send a few packets from the client. + let pkt0 = send_something(&mut client, now); + let pkt1 = send_something(&mut client, now); + let pkt2 = send_something(&mut client, now); + let pkt3 = send_something(&mut client, now); + + // Nothing to do, should return callback + let cb = client.process(None, now).callback(); + // The PTO timer is calculated with: + // RTT + max(rttvar * 4, GRANULARITY) + max_ack_delay + // With zero RTT and rttvar, max_ack_delay is minimum too (GRANULARITY) + assert_eq!(cb, GRANULARITY * 2); + + // Process these by server, skipping pkt0 + let srv0 = server.process(Some(&pkt1), now).dgram(); + assert!(srv0.is_some()); // ooo, ack client pkt1 + + now += Duration::from_millis(20); + + // process pkt2 (immediate ack because last ack was more than an RTT ago; RTT=0) + let srv1 = server.process(Some(&pkt2), now).dgram(); + assert!(srv1.is_some()); // this is now dropped + + now += Duration::from_millis(20); + // process pkt3 (acked for same reason) + let srv2 = server.process(Some(&pkt3), now).dgram(); + // ack client pkt 2 & 3 + assert!(srv2.is_some()); + + // client processes ack + let pkt4 = client.process(srv2.as_ref(), now).dgram(); + // client resends data from pkt0 + assert!(pkt4.is_some()); + + // server sees ooo pkt0 and generates immediate ack + let srv3 = server.process(Some(&pkt0), now).dgram(); + assert!(srv3.is_some()); + + // Accept the acknowledgment. + let pkt5 = client.process(srv3.as_ref(), now).dgram(); + assert!(pkt5.is_none()); + + now += Duration::from_millis(70); + // PTO expires. No unacked data. Only send PING. 
+ let client_pings = client.stats().frame_tx.ping; + let pkt6 = client.process(None, now).dgram(); + assert_eq!(client.stats().frame_tx.ping, client_pings + 1); + + let server_pings = server.stats().frame_rx.ping; + server.process_input(&pkt6.unwrap(), now); + assert_eq!(server.stats().frame_rx.ping, server_pings + 1); +} + +#[test] +fn pto_initial() { + const INITIAL_PTO: Duration = Duration::from_millis(300); + let mut now = now(); + + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let pkt1 = client.process(None, now).dgram(); + assert!(pkt1.is_some()); + assert_eq!(pkt1.clone().unwrap().len(), PATH_MTU_V6); + + let delay = client.process(None, now).callback(); + assert_eq!(delay, INITIAL_PTO); + + // Resend initial after PTO. + now += delay; + let pkt2 = client.process(None, now).dgram(); + assert!(pkt2.is_some()); + assert_eq!(pkt2.unwrap().len(), PATH_MTU_V6); + + let delay = client.process(None, now).callback(); + // PTO has doubled. + assert_eq!(delay, INITIAL_PTO * 2); + + // Server process the first initial pkt. + let mut server = default_server(); + let out = server.process(pkt1.as_ref(), now).dgram(); + assert!(out.is_some()); + + // Client receives ack for the first initial packet as well a Handshake packet. + // After the handshake packet the initial keys and the crypto stream for the initial + // packet number space will be discarded. + // Here only an ack for the Handshake packet will be sent. + let out = client.process(out.as_ref(), now).dgram(); + assert!(out.is_some()); + + // We do not have PTO for the resent initial packet any more, but + // the Handshake PTO timer should be armed. As the RTT is apparently + // the same as the initial PTO value, and there is only one sample, + // the PTO will be 3x the INITIAL PTO. + let delay = client.process(None, now).callback(); + assert_eq!(delay, INITIAL_PTO * 3); +} + +/// A complete handshake that involves a PTO in the Handshake space. 
+#[test] +fn pto_handshake_complete() { + const HALF_RTT: Duration = Duration::from_millis(10); + + let mut now = now(); + // start handshake + let mut client = default_client(); + let mut server = default_server(); + + let pkt = client.process(None, now).dgram(); + assert_initial(pkt.as_ref().unwrap(), false); + let cb = client.process(None, now).callback(); + assert_eq!(cb, Duration::from_millis(300)); + + now += HALF_RTT; + let pkt = server.process(pkt.as_ref(), now).dgram(); + assert_initial(pkt.as_ref().unwrap(), false); + + now += HALF_RTT; + let pkt = client.process(pkt.as_ref(), now).dgram(); + assert_handshake(pkt.as_ref().unwrap()); + + let cb = client.process(None, now).callback(); + // The client now has a single RTT estimate (20ms), so + // the handshake PTO is set based on that. + assert_eq!(cb, HALF_RTT * 6); + + now += HALF_RTT; + let pkt = server.process(pkt.as_ref(), now).dgram(); + assert!(pkt.is_none()); + + now += HALF_RTT; + client.authenticated(AuthenticationStatus::Ok, now); + + qdebug!("---- client: SH..FIN -> FIN"); + let pkt1 = client.process(None, now).dgram(); + assert_handshake(pkt1.as_ref().unwrap()); + assert_eq!(*client.state(), State::Connected); + + let cb = client.process(None, now).callback(); + assert_eq!(cb, HALF_RTT * 6); + + let mut pto_counts = [0; MAX_PTO_COUNTS]; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Wait for PTO to expire and resend a handshake packet. + // Wait long enough that the 1-RTT PTO also fires. + qdebug!("---- client: PTO"); + now += HALF_RTT * 6; + let pkt2 = client.process(None, now).dgram(); + assert_handshake(pkt2.as_ref().unwrap()); + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Get a second PTO packet. + // Add some application data to this datagram, then split the 1-RTT off. + // We'll use that packet to force the server to acknowledge 1-RTT. 
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_close_send(stream_id).unwrap(); + let pkt3 = client.process(None, now).dgram(); + assert_handshake(pkt3.as_ref().unwrap()); + let (pkt3_hs, pkt3_1rtt) = split_datagram(&pkt3.unwrap()); + assert_handshake(&pkt3_hs); + assert!(pkt3_1rtt.is_some()); + + // PTO has been doubled. + let cb = client.process(None, now).callback(); + assert_eq!(cb, HALF_RTT * 12); + + // We still have only a single PTO + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qdebug!("---- server: receive FIN and send ACK"); + now += HALF_RTT; + // Now let the server have pkt1 and expect an immediate Handshake ACK. + // The output will be a Handshake packet with ACK and 1-RTT packet with + // HANDSHAKE_DONE and (because of pkt3_1rtt) an ACK. + // This should remove the 1-RTT PTO from messing this test up. + let server_acks = server.stats().frame_tx.ack; + let server_done = server.stats().frame_tx.handshake_done; + server.process_input(&pkt3_1rtt.unwrap(), now); + let ack = server.process(pkt1.as_ref(), now).dgram(); + assert!(ack.is_some()); + assert_eq!(server.stats().frame_tx.ack, server_acks + 2); + assert_eq!(server.stats().frame_tx.handshake_done, server_done + 1); + + // Check that the other packets (pkt2, pkt3) are Handshake packets. + // The server discarded the Handshake keys already, therefore they are dropped. + // Note that these don't include 1-RTT packets, because 1-RTT isn't send on PTO. 
+ let (pkt2_hs, pkt2_1rtt) = split_datagram(&pkt2.unwrap()); + assert_handshake(&pkt2_hs); + assert!(pkt2_1rtt.is_some()); + let dropped_before1 = server.stats().dropped_rx; + let server_frames = server.stats().frame_rx.all; + server.process_input(&pkt2_hs, now); + assert_eq!(1, server.stats().dropped_rx - dropped_before1); + assert_eq!(server.stats().frame_rx.all, server_frames); + + server.process_input(&pkt2_1rtt.unwrap(), now); + let server_frames2 = server.stats().frame_rx.all; + let dropped_before2 = server.stats().dropped_rx; + server.process_input(&pkt3_hs, now); + assert_eq!(1, server.stats().dropped_rx - dropped_before2); + assert_eq!(server.stats().frame_rx.all, server_frames2); + + now += HALF_RTT; + + // Let the client receive the ACK. + // It should now be wait to acknowledge the HANDSHAKE_DONE. + let cb = client.process(ack.as_ref(), now).callback(); + // The default ack delay is the RTT divided by the default ACK ratio of 4. + let expected_ack_delay = HALF_RTT * 2 / 4; + assert_eq!(cb, expected_ack_delay); + + // Let the ACK delay timer expire. + now += cb; + let out = client.process(None, now).dgram(); + assert!(out.is_some()); +} + +/// Test that PTO in the Handshake space contains the right frames. 
+#[test] +fn pto_handshake_frames() { + let mut now = now(); + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let pkt = client.process(None, now); + + now += Duration::from_millis(10); + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let pkt = server.process(pkt.as_dgram_ref(), now); + + now += Duration::from_millis(10); + qdebug!("---- client: cert verification"); + let pkt = client.process(pkt.as_dgram_ref(), now); + + now += Duration::from_millis(10); + mem::drop(server.process(pkt.as_dgram_ref(), now)); + + now += Duration::from_millis(10); + client.authenticated(AuthenticationStatus::Ok, now); + + let stream = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(stream, 2); + assert_eq!(client.stream_send(stream, b"zero").unwrap(), 4); + qdebug!("---- client: SH..FIN -> FIN and 1RTT packet"); + let pkt1 = client.process(None, now).dgram(); + assert!(pkt1.is_some()); + + // Get PTO timer. + let out = client.process(None, now); + assert_eq!(out, Output::Callback(Duration::from_millis(60))); + + // Wait for PTO to expire and resend a handshake packet. + now += Duration::from_millis(60); + let pkt2 = client.process(None, now).dgram(); + assert!(pkt2.is_some()); + + now += Duration::from_millis(10); + let crypto_before = server.stats().frame_rx.crypto; + server.process_input(&pkt2.unwrap(), now); + assert_eq!(server.stats().frame_rx.crypto, crypto_before + 1); +} + +/// In the case that the Handshake takes too many packets, the server might +/// be stalled on the anti-amplification limit. If a Handshake ACK from the +/// client is lost, the client has to keep the PTO timer armed or the server +/// might be unable to send anything, causing a deadlock. 
+#[test] +fn handshake_ack_pto() { + const RTT: Duration = Duration::from_millis(10); + let mut now = now(); + let mut client = default_client(); + let mut server = default_server(); + // This is a greasing transport parameter, and large enough that the + // server needs to send two Handshake packets. + let big = TransportParameter::Bytes(vec![0; PATH_MTU_V6]); + server.set_local_tparam(0xce16, big).unwrap(); + + let c1 = client.process(None, now).dgram(); + + now += RTT / 2; + let s1 = server.process(c1.as_ref(), now).dgram(); + assert!(s1.is_some()); + let s2 = server.process(None, now).dgram(); + assert!(s1.is_some()); + + // Now let the client have the Initial, but drop the first coalesced Handshake packet. + now += RTT / 2; + let (initial, _) = split_datagram(&s1.unwrap()); + client.process_input(&initial, now); + let c2 = client.process(s2.as_ref(), now).dgram(); + assert!(c2.is_some()); // This is an ACK. Drop it. + let delay = client.process(None, now).callback(); + assert_eq!(delay, RTT * 3); + + let mut pto_counts = [0; MAX_PTO_COUNTS]; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Wait for the PTO and ensure that the client generates a packet. + now += delay; + let c3 = client.process(None, now).dgram(); + assert!(c3.is_some()); + + now += RTT / 2; + let ping_before = server.stats().frame_rx.ping; + server.process_input(&c3.unwrap(), now); + assert_eq!(server.stats().frame_rx.ping, ping_before + 1); + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Now complete the handshake as cheaply as possible. 
+ let dgram = server.process(None, now).dgram(); + client.process_input(&dgram.unwrap(), now); + maybe_authenticate(&mut client); + let dgram = client.process(None, now).dgram(); + assert_eq!(*client.state(), State::Connected); + let dgram = server.process(dgram.as_ref(), now).dgram(); + assert_eq!(*server.state(), State::Confirmed); + client.process_input(&dgram.unwrap(), now); + assert_eq!(*client.state(), State::Confirmed); + + assert_eq!(client.stats.borrow().pto_counts, pto_counts); +} + +#[test] +fn loss_recovery_crash() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let now = now(); + + // The server sends something, but we will drop this. + mem::drop(send_something(&mut server, now)); + + // Then send something again, but let it through. + let ack = send_and_receive(&mut server, &mut client, now); + assert!(ack.is_some()); + + // Have the server process the ACK. + let cb = server.process(ack.as_ref(), now).callback(); + assert!(cb > Duration::from_secs(0)); + + // Now we leap into the future. The server should regard the first + // packet as lost based on time alone. + let dgram = server.process(None, now + AT_LEAST_PTO).dgram(); + assert!(dgram.is_some()); + + // This crashes. + mem::drop(send_something(&mut server, now + AT_LEAST_PTO)); +} + +// If we receive packets after the PTO timer has fired, we won't clear +// the PTO state, but we might need to acknowledge those packets. +// This shouldn't happen, but we found that some implementations do this. +#[test] +fn ack_after_pto() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut now = now(); + + // The client sends and is forced into a PTO. + mem::drop(send_something(&mut client, now)); + + // Jump forward to the PTO and drain the PTO packets. + now += AT_LEAST_PTO; + // We can use MAX_PTO_PACKET_COUNT, because we know the handshake is over. 
+ for _ in 0..MAX_PTO_PACKET_COUNT { + let dgram = client.process(None, now).dgram(); + assert!(dgram.is_some()); + } + assert!(client.process(None, now).dgram().is_none()); + + // The server now needs to send something that will cause the + // client to want to acknowledge it. A little out of order + // delivery is just the thing. + // Note: The server can't ACK anything here, but none of what + // the client has sent so far has been transferred. + mem::drop(send_something(&mut server, now)); + let dgram = send_something(&mut server, now); + + // The client is now after a PTO, but if it receives something + // that demands acknowledgment, it will send just the ACK. + let ack = client.process(Some(&dgram), now).dgram(); + assert!(ack.is_some()); + + // Make sure that the packet only contained an ACK frame. + let all_frames_before = server.stats().frame_rx.all; + let ack_before = server.stats().frame_rx.ack; + server.process_input(&ack.unwrap(), now); + assert_eq!(server.stats().frame_rx.all, all_frames_before + 1); + assert_eq!(server.stats().frame_rx.ack, ack_before + 1); +} + +/// When we declare a packet as lost, we keep it around for a while for another loss period. +/// Those packets should not affect how we report the loss recovery timer. +/// As the loss recovery timer based on RTT we use that to drive the state. +#[test] +fn lost_but_kept_and_lr_timer() { + const RTT: Duration = Duration::from_secs(1); + let mut client = default_client(); + let mut server = default_server(); + let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + + // Two packets (p1, p2) are sent at around t=0. The first is lost. + let _p1 = send_something(&mut client, now); + let p2 = send_something(&mut client, now); + + // At t=RTT/2 the server receives the packet and ACKs it. + now += RTT / 2; + let ack = server.process(Some(&p2), now).dgram(); + assert!(ack.is_some()); + // The client also sends another two packets (p3, p4), again losing the first. 
+ let _p3 = send_something(&mut client, now); + let p4 = send_something(&mut client, now); + + // At t=RTT the client receives the ACK and goes into timed loss recovery. + // The client doesn't call p1 lost at this stage, but it will soon. + now += RTT / 2; + let res = client.process(ack.as_ref(), now); + // The client should be on a loss recovery timer as p1 is missing. + let lr_timer = res.callback(); + // Loss recovery timer should be RTT/8, but only check for 0 or >=RTT/2. + assert_ne!(lr_timer, Duration::from_secs(0)); + assert!(lr_timer < (RTT / 2)); + // The server also receives and acknowledges p4, again sending an ACK. + let ack = server.process(Some(&p4), now).dgram(); + assert!(ack.is_some()); + + // At t=RTT*3/2 the client should declare p1 to be lost. + now += RTT / 2; + // So the client will send the data from p1 again. + let res = client.process(None, now); + assert!(res.dgram().is_some()); + // When the client processes the ACK, it should engage the + // loss recovery timer for p3, not p1 (even though it still tracks p1). + let res = client.process(ack.as_ref(), now); + let lr_timer2 = res.callback(); + assert_eq!(lr_timer, lr_timer2); +} + +/// We should not be setting the loss recovery timer based on packets +/// that are sent prior to the largest acknowledged. +/// Testing this requires that we construct a case where one packet +/// number space causes the loss recovery timer to be engaged. At the same time, +/// there is a packet in another space that hasn't been acknowledged AND +/// that packet number space has not received acknowledgments for later packets. +#[test] +fn loss_time_past_largest_acked() { + const RTT: Duration = Duration::from_secs(10); + const INCR: Duration = Duration::from_millis(1); + let mut client = default_client(); + let mut server = default_server(); + + let mut now = now(); + + // Start the handshake. 
+ let c_in = client.process(None, now).dgram(); + now += RTT / 2; + let s_hs1 = server.process(c_in.as_ref(), now).dgram(); + + // Get some spare server handshake packets for the client to ACK. + // This involves a time machine, so be a little cautious. + // This test uses an RTT of 10s, but our server starts + // with a much lower RTT estimate, so the PTO at this point should + // be much smaller than an RTT and so the server shouldn't see + // time go backwards. + let s_pto = server.process(None, now).callback(); + assert_ne!(s_pto, Duration::from_secs(0)); + assert!(s_pto < RTT); + let s_hs2 = server.process(None, now + s_pto).dgram(); + assert!(s_hs2.is_some()); + let s_hs3 = server.process(None, now + s_pto).dgram(); + assert!(s_hs3.is_some()); + + // Get some Handshake packets from the client. + // We need one to be left unacknowledged before one that is acknowledged. + // So that the client engages the loss recovery timer. + // This is complicated by the fact that it is hard to cause the client + // to generate an ack-eliciting packet. For that, we use the Finished message. + // Reordering delivery ensures that the later packet is also acknowledged. + now += RTT / 2; + let c_hs1 = client.process(s_hs1.as_ref(), now).dgram(); + assert!(c_hs1.is_some()); // This comes first, so it's useless. + maybe_authenticate(&mut client); + let c_hs2 = client.process(None, now).dgram(); + assert!(c_hs2.is_some()); // This one will elicit an ACK. + + // The we need the outstanding packet to be sent after the + // application data packet, so space these out a tiny bit. + let _p1 = send_something(&mut client, now + INCR); + let c_hs3 = client.process(s_hs2.as_ref(), now + (INCR * 2)).dgram(); + assert!(c_hs3.is_some()); // This will be left outstanding. + let c_hs4 = client.process(s_hs3.as_ref(), now + (INCR * 3)).dgram(); + assert!(c_hs4.is_some()); // This will be acknowledged. + + // Process c_hs2 and c_hs4, but skip c_hs3. + // Then get an ACK for the client. 
+ now += RTT / 2; + // Deliver c_hs4 first, but don't generate a packet. + server.process_input(&c_hs4.unwrap(), now); + let s_ack = server.process(c_hs2.as_ref(), now).dgram(); + assert!(s_ack.is_some()); + // This includes an ACK, but it also includes HANDSHAKE_DONE, + // which we need to remove because that will cause the Handshake loss + // recovery state to be dropped. + let (s_hs_ack, _s_ap_ack) = split_datagram(&s_ack.unwrap()); + + // Now the client should start its loss recovery timer based on the ACK. + now += RTT / 2; + let c_ack = client.process(Some(&s_hs_ack), now).dgram(); + assert!(c_ack.is_none()); + // The client should now have the loss recovery timer active. + let lr_time = client.process(None, now).callback(); + assert_ne!(lr_time, Duration::from_secs(0)); + assert!(lr_time < (RTT / 2)); +} + +/// `sender` sends a little, `receiver` acknowledges it. +/// Repeat until `count` acknowledgements are sent. +/// Returns the last packet containing acknowledgements, if any. +fn trickle(sender: &mut Connection, receiver: &mut Connection, mut count: usize, now: Instant) { + let id = sender.stream_create(StreamType::UniDi).unwrap(); + let mut maybe_ack = None; + while count > 0 { + qdebug!("trickle: remaining={}", count); + assert_eq!(sender.stream_send(id, &[9]).unwrap(), 1); + let dgram = sender.process(maybe_ack.as_ref(), now).dgram(); + + maybe_ack = receiver.process(dgram.as_ref(), now).dgram(); + count -= usize::from(maybe_ack.is_some()); + } + sender.process_input(&maybe_ack.unwrap(), now); +} + +/// Ensure that a PING frame is sent with ACK sometimes. +/// `fast` allows testing of when `MAX_OUTSTANDING_UNACK` packets are +/// outstanding (`fast` is `true`) within 1 PTO and when only +/// `MIN_OUTSTANDING_UNACK` packets arrive after 2 PTOs (`fast` is `false`). 
+fn ping_with_ack(fast: bool) { + let mut sender = default_client(); + let mut receiver = default_server(); + let mut now = now(); + connect_force_idle(&mut sender, &mut receiver); + let sender_acks_before = sender.stats().frame_tx.ack; + let receiver_acks_before = receiver.stats().frame_tx.ack; + let count = if fast { + MAX_OUTSTANDING_UNACK + } else { + MIN_OUTSTANDING_UNACK + }; + trickle(&mut sender, &mut receiver, count, now); + assert_eq!(sender.stats().frame_tx.ack, sender_acks_before); + assert_eq!(receiver.stats().frame_tx.ack, receiver_acks_before + count); + assert_eq!(receiver.stats().frame_tx.ping, 0); + + if !fast { + // Wait at least one PTO, from the reciever's perspective. + // A receiver that hasn't received MAX_OUTSTANDING_UNACK won't send PING. + now += receiver.pto() + Duration::from_micros(1); + trickle(&mut sender, &mut receiver, 1, now); + assert_eq!(receiver.stats().frame_tx.ping, 0); + } + + // After a second PTO (or the first if fast), new acknowledgements come + // with a PING frame and cause an ACK to be sent by the sender. 
+ now += receiver.pto() + Duration::from_micros(1); + trickle(&mut sender, &mut receiver, 1, now); + assert_eq!(receiver.stats().frame_tx.ping, 1); + if let Output::Callback(t) = sender.process_output(now) { + assert_eq!(t, DEFAULT_ACK_DELAY); + assert!(sender.process_output(now + t).dgram().is_some()); + } + assert_eq!(sender.stats().frame_tx.ack, sender_acks_before + 1); +} + +#[test] +fn ping_with_ack_fast() { + ping_with_ack(true); +} + +#[test] +fn ping_with_ack_slow() { + ping_with_ack(false); +} + +#[test] +fn ping_with_ack_min() { + const COUNT: usize = MIN_OUTSTANDING_UNACK - 2; + let mut sender = default_client(); + let mut receiver = default_server(); + let mut now = now(); + connect_force_idle(&mut sender, &mut receiver); + let sender_acks_before = sender.stats().frame_tx.ack; + let receiver_acks_before = receiver.stats().frame_tx.ack; + trickle(&mut sender, &mut receiver, COUNT, now); + assert_eq!(sender.stats().frame_tx.ack, sender_acks_before); + assert_eq!(receiver.stats().frame_tx.ack, receiver_acks_before + COUNT); + assert_eq!(receiver.stats().frame_tx.ping, 0); + + // After 3 PTO, no PING because there are too few outstanding packets. + now += receiver.pto() * 3 + Duration::from_micros(1); + trickle(&mut sender, &mut receiver, 1, now); + assert_eq!(receiver.stats().frame_tx.ping, 0); +} + +/// This calculates the PTO timer immediately after connection establishment. +/// It depends on there only being 2 RTT samples in the handshake. +fn expected_pto(rtt: Duration) -> Duration { + // PTO calculation is rtt + 4rttvar + ack delay. + // rttvar should be (rtt + 4 * (rtt / 2) * (3/4)^n + 25ms)/2 + // where n is the number of round trips + // This uses a 25ms ack delay as the ACK delay extension + // is negotiated and no ACK_DELAY frame has been received. 
+ rtt + rtt * 9 / 8 + Duration::from_millis(25) +} + +#[test] +fn fast_pto() { + let mut client = new_client(ConnectionParameters::default().fast_pto(FAST_PTO_SCALE / 2)); + let mut server = default_server(); + let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + let res = client.process(None, now); + let idle_timeout = ConnectionParameters::default().get_idle_timeout() - (DEFAULT_RTT / 2); + assert_eq!(res, Output::Callback(idle_timeout)); + + // Send data on a stream. + let stream = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(), + DEFAULT_STREAM_DATA.len() + ); + + // Send a packet after some time. + now += idle_timeout / 2; + let dgram = client.process_output(now).dgram(); + assert!(dgram.is_some()); + + // Nothing to do, should return a callback. + let cb = client.process_output(now).callback(); + assert_eq!(expected_pto(DEFAULT_RTT) / 2, cb); + + // Once the PTO timer expires, a PTO packet should be sent. + now += cb; + let dgram = client.process(None, now).dgram(); + + let stream_before = server.stats().frame_rx.stream; + server.process_input(&dgram.unwrap(), now); + assert_eq!(server.stats().frame_rx.stream, stream_before + 1); +} + +/// Even if the PTO timer is slowed right down, persistent congestion is declared +/// based on the "true" value of the timer. +#[test] +fn fast_pto_persistent_congestion() { + let mut client = new_client(ConnectionParameters::default().fast_pto(FAST_PTO_SCALE * 2)); + let mut server = default_server(); + let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + let res = client.process(None, now); + let idle_timeout = ConnectionParameters::default().get_idle_timeout() - (DEFAULT_RTT / 2); + assert_eq!(res, Output::Callback(idle_timeout)); + + // Send packets spaced by the PTO timer. And lose them. 
+ // Note: This timing is a tiny bit higher than the client will use + // to determine persistent congestion. The ACK below adds another RTT + // estimate, which will reduce rttvar by 3/4, so persistent congestion + // will occur at `rtt + rtt*27/32 + 25ms`. + // That is OK as we're still showing that this interval is less than + // six times the PTO, which is what would be used if the scaling + // applied to the PTO used to determine persistent congestion. + let pc_interval = expected_pto(DEFAULT_RTT) * 3; + println!("pc_interval {pc_interval:?}"); + let _drop1 = send_something(&mut client, now); + + // Check that the PTO matches expectations. + let cb = client.process_output(now).callback(); + assert_eq!(expected_pto(DEFAULT_RTT) * 2, cb); + + now += pc_interval; + let _drop2 = send_something(&mut client, now); + let _drop3 = send_something(&mut client, now); + let _drop4 = send_something(&mut client, now); + let dgram = send_something(&mut client, now); + + // Now acknowledge the tail packet and enter persistent congestion. + now += DEFAULT_RTT / 2; + let ack = server.process(Some(&dgram), now).dgram(); + now += DEFAULT_RTT / 2; + client.process_input(&ack.unwrap(), now); + assert_eq!(cwnd(&client), CWND_MIN); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/resumption.rs b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs new file mode 100644 index 0000000000..a8c45a9f06 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs @@ -0,0 +1,246 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::{cell::RefCell, mem, rc::Rc, time::Duration}; + +use test_fixture::{self, assertions, now}; + +use super::{ + connect, connect_with_rtt, default_client, default_server, exchange_ticket, get_tokens, + new_client, resumed_server, send_something, AT_LEAST_PTO, +}; +use crate::{ + addr_valid::{AddressValidation, ValidateAddress}, + ConnectionParameters, Error, Version, +}; + +#[test] +fn resume() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = resumed_server(&client); + connect(&mut client, &mut server); + assert!(client.tls_info().unwrap().resumed()); + assert!(server.tls_info().unwrap().resumed()); +} + +#[test] +fn remember_smoothed_rtt() { + const RTT1: Duration = Duration::from_millis(130); + const RTT2: Duration = Duration::from_millis(70); + + let mut client = default_client(); + let mut server = default_server(); + + let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT1); + assert_eq!(client.paths.rtt(), RTT1); + + // We can't use exchange_ticket here because it doesn't respect RTT. + // Also, connect_with_rtt() ends with the server receiving a packet it + // wants to acknowledge; so the ticket will include an ACK frame too. 
+ let validation = AddressValidation::new(now, ValidateAddress::NoToken).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + server.send_ticket(now, &[]).expect("can send ticket"); + let ticket = server.process_output(now).dgram(); + assert!(ticket.is_some()); + now += RTT1 / 2; + client.process_input(&ticket.unwrap(), now); + let token = get_tokens(&mut client).pop().unwrap(); + + let mut client = default_client(); + client.enable_resumption(now, token).unwrap(); + assert_eq!( + client.paths.rtt(), + RTT1, + "client should remember previous RTT" + ); + let mut server = resumed_server(&client); + + connect_with_rtt(&mut client, &mut server, now, RTT2); + assert_eq!( + client.paths.rtt(), + RTT2, + "previous RTT should be completely erased" + ); +} + +/// Check that a resumed connection uses a token on Initial packets. +#[test] +fn address_validation_token_resume() { + const RTT: Duration = Duration::from_millis(10); + + let mut client = default_client(); + let mut server = default_server(); + let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + + let token = exchange_ticket(&mut client, &mut server, now); + let mut client = default_client(); + client.enable_resumption(now, token).unwrap(); + let mut server = resumed_server(&client); + + // Grab an Initial packet from the client. + let dgram = client.process(None, now).dgram(); + assertions::assert_initial(dgram.as_ref().unwrap(), true); + + // Now try to complete the handshake after giving time for a client PTO. 
+ now += AT_LEAST_PTO; + connect_with_rtt(&mut client, &mut server, now, RTT); + assert!(client.crypto.tls.info().unwrap().resumed()); + assert!(server.crypto.tls.info().unwrap().resumed()); +} + +fn can_resume(token: impl AsRef<[u8]>, initial_has_token: bool) { + let mut client = default_client(); + client.enable_resumption(now(), token).unwrap(); + let initial = client.process_output(now()).dgram(); + assertions::assert_initial(initial.as_ref().unwrap(), initial_has_token); +} + +#[test] +fn two_tickets_on_timer() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // Send two tickets and then bundle those into a packet. + server.send_ticket(now(), &[]).expect("send ticket1"); + server.send_ticket(now(), &[]).expect("send ticket2"); + let pkt = send_something(&mut server, now()); + + // process() will return an ack first + assert!(client.process(Some(&pkt), now()).dgram().is_some()); + // We do not have a ResumptionToken event yet, because NEW_TOKEN was not sent. + assert_eq!(get_tokens(&mut client).len(), 0); + + // We need to wait for release_resumption_token_timer to expire. The timer will be + // set to 3 * PTO + let mut now = now() + 3 * client.pto(); + mem::drop(client.process(None, now)); + let mut recv_tokens = get_tokens(&mut client); + assert_eq!(recv_tokens.len(), 1); + let token1 = recv_tokens.pop().unwrap(); + // Wait for another 3 * PTO to get the next token. + now += 3 * client.pto(); + mem::drop(client.process(None, now)); + let mut recv_tokens = get_tokens(&mut client); + assert_eq!(recv_tokens.len(), 1); + let token2 = recv_tokens.pop().unwrap(); + // Wait for 3 * PTO, but now there are no more tokens. 
+ now += 3 * client.pto(); + mem::drop(client.process(None, now)); + assert_eq!(get_tokens(&mut client).len(), 0); + assert_ne!(token1.as_ref(), token2.as_ref()); + + can_resume(token1, false); + can_resume(token2, false); +} + +#[test] +fn two_tickets_with_new_token() { + let mut client = default_client(); + let mut server = default_server(); + let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + connect(&mut client, &mut server); + + // Send two tickets with tokens and then bundle those into a packet. + server.send_ticket(now(), &[]).expect("send ticket1"); + server.send_ticket(now(), &[]).expect("send ticket2"); + let pkt = send_something(&mut server, now()); + + client.process_input(&pkt, now()); + let mut all_tokens = get_tokens(&mut client); + assert_eq!(all_tokens.len(), 2); + let token1 = all_tokens.pop().unwrap(); + let token2 = all_tokens.pop().unwrap(); + assert_ne!(token1.as_ref(), token2.as_ref()); + + can_resume(token1, true); + can_resume(token2, true); +} + +/// By disabling address validation, the server won't send `NEW_TOKEN`, but +/// we can take the session ticket still. +#[test] +fn take_token() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + server.send_ticket(now(), &[]).unwrap(); + let dgram = server.process(None, now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + + // There should be no ResumptionToken event here. + let tokens = get_tokens(&mut client); + assert_eq!(tokens.len(), 0); + + // But we should be able to get the token directly, and use it. + let token = client.take_resumption_token(now()).unwrap(); + can_resume(token, false); +} + +/// If a version is selected and subsequently disabled, resumption fails. 
+#[test] +fn resume_disabled_version() { + let mut client = new_client( + ConnectionParameters::default().versions(Version::Version1, vec![Version::Version1]), + ); + let mut server = default_server(); + connect(&mut client, &mut server); + let token = exchange_ticket(&mut client, &mut server, now()); + + let mut client = new_client( + ConnectionParameters::default().versions(Version::Version2, vec![Version::Version2]), + ); + assert_eq!( + client.enable_resumption(now(), token).unwrap_err(), + Error::DisabledVersion + ); +} + +/// It's not possible to resume once a packet has been sent. +#[test] +fn resume_after_packet() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let token = exchange_ticket(&mut client, &mut server, now()); + + let mut client = default_client(); + mem::drop(client.process_output(now()).dgram().unwrap()); + assert_eq!( + client.enable_resumption(now(), token).unwrap_err(), + Error::ConnectionState + ); +} + +/// It's not possible to resume at the server. +#[test] +fn resume_server() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let token = exchange_ticket(&mut client, &mut server, now()); + + let mut server = default_server(); + assert_eq!( + server.enable_resumption(now(), token).unwrap_err(), + Error::ConnectionState + ); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs new file mode 100644 index 0000000000..586a537b9d --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs @@ -0,0 +1,1162 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{cmp::max, collections::HashMap, convert::TryFrom, mem}; + +use neqo_common::{event::Provider, qdebug}; +use test_fixture::now; + +use super::{ + super::State, assert_error, connect, connect_force_idle, default_client, default_server, + maybe_authenticate, new_client, new_server, send_something, DEFAULT_STREAM_DATA, +}; +use crate::{ + events::ConnectionEvent, + recv_stream::RECV_BUFFER_SIZE, + send_stream::{OrderGroup, SendStreamState, SEND_BUFFER_SIZE}, + streams::{SendOrder, StreamOrder}, + tparams::{self, TransportParameter}, + // tracking::DEFAULT_ACK_PACKET_TOLERANCE, + Connection, + ConnectionError, + ConnectionParameters, + Error, + StreamId, + StreamType, +}; + +#[test] +fn stream_create() { + let mut client = default_client(); + + let out = client.process(None, now()); + let mut server = default_server(); + let out = server.process(out.as_dgram_ref(), now()); + + let out = client.process(out.as_dgram_ref(), now()); + mem::drop(server.process(out.as_dgram_ref(), now())); + assert!(maybe_authenticate(&mut client)); + let out = client.process(None, now()); + + // client now in State::Connected + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 6); + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 4); + + mem::drop(server.process(out.as_dgram_ref(), now())); + // server now in State::Connected + assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 3); + assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 7); + assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 1); + assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 5); +} + +#[test] +// tests stream send/recv after connection is established. 
+fn transfer() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + qdebug!("---- client sends"); + // Send + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[6; 100]).unwrap(); + client.stream_send(client_stream_id, &[7; 40]).unwrap(); + client.stream_send(client_stream_id, &[8; 4000]).unwrap(); + + // Send to another stream but some data after fin has been set + let client_stream_id2 = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id2, &[6; 60]).unwrap(); + client.stream_close_send(client_stream_id2).unwrap(); + client.stream_send(client_stream_id2, &[7; 50]).unwrap_err(); + // Sending this much takes a few datagrams. + let mut datagrams = vec![]; + let mut out = client.process_output(now()); + while let Some(d) = out.dgram() { + datagrams.push(d); + out = client.process_output(now()); + } + assert_eq!(datagrams.len(), 4); + assert_eq!(*client.state(), State::Confirmed); + + qdebug!("---- server receives"); + for d in datagrams { + let out = server.process(Some(&d), now()); + // With an RTT of zero, the server will acknowledge every packet immediately. + assert!(out.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + } + assert_eq!(*server.state(), State::Confirmed); + + let mut buf = vec![0; 4000]; + + let mut stream_ids = server.events().filter_map(|evt| match evt { + ConnectionEvent::NewStream { stream_id, .. 
} => Some(stream_id), + _ => None, + }); + let first_stream = stream_ids.next().expect("should have a new stream event"); + let second_stream = stream_ids + .next() + .expect("should have a second new stream event"); + assert!(stream_ids.next().is_none()); + let (received1, fin1) = server.stream_recv(first_stream, &mut buf).unwrap(); + assert_eq!(received1, 4000); + assert!(!fin1); + let (received2, fin2) = server.stream_recv(first_stream, &mut buf).unwrap(); + assert_eq!(received2, 140); + assert!(!fin2); + + let (received3, fin3) = server.stream_recv(second_stream, &mut buf).unwrap(); + assert_eq!(received3, 60); + assert!(fin3); +} + +#[derive(PartialEq, Eq, PartialOrd, Ord)] +struct IdEntry { + sendorder: StreamOrder, + stream_id: StreamId, +} + +// tests stream sendorder priorization +fn sendorder_test(order_of_sendorder: &[Option<SendOrder>]) { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + qdebug!("---- client sends"); + // open all streams and set the sendorders + let mut ordered = Vec::new(); + let mut streams = Vec::<StreamId>::new(); + for sendorder in order_of_sendorder { + let id = client.stream_create(StreamType::UniDi).unwrap(); + streams.push(id); + ordered.push((id, *sendorder)); + // must be set before sendorder + client.streams.set_fairness(id, true).ok(); + client.streams.set_sendorder(id, *sendorder).ok(); + } + // Write some data to all the streams + for stream_id in streams { + client.stream_send(stream_id, &[6; 100]).unwrap(); + } + + // Sending this much takes a few datagrams. 
+ // Note: this test uses an RTT of 0 which simplifies things (no pacing) + let mut datagrams = Vec::new(); + let mut out = client.process_output(now()); + while let Some(d) = out.dgram() { + datagrams.push(d); + out = client.process_output(now()); + } + assert_eq!(*client.state(), State::Confirmed); + + qdebug!("---- server receives"); + for d in datagrams { + let out = server.process(Some(&d), now()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + } + assert_eq!(*server.state(), State::Confirmed); + + let stream_ids = server + .events() + .filter_map(|evt| match evt { + ConnectionEvent::RecvStreamReadable { stream_id, .. } => Some(stream_id), + _ => None, + }) + .enumerate() + .map(|(a, b)| (b, a)) + .collect::<HashMap<_, _>>(); + + // streams should arrive in priority order, not order of creation, if sendorder prioritization + // is working correctly + + // 'ordered' has the send order currently. Re-sort it by sendorder, but + // if two items from the same sendorder exist, secondarily sort by the ordering in + // the stream_ids vector (HashMap<StreamId, index: usize>) + ordered.sort_unstable_by_key(|(stream_id, sendorder)| { + ( + StreamOrder { + sendorder: *sendorder, + }, + stream_ids[stream_id], + ) + }); + // make sure everything now is in the same order, since we modified the order of + // same-sendorder items to match the ordering of those we saw in reception + for (i, (stream_id, _sendorder)) in ordered.iter().enumerate() { + assert_eq!(i, stream_ids[stream_id]); + } +} + +#[test] +fn sendorder_0() { + sendorder_test(&[None, Some(1), Some(2), Some(3)]); +} +#[test] +fn sendorder_1() { + sendorder_test(&[Some(3), Some(2), Some(1), None]); +} +#[test] +fn sendorder_2() { + sendorder_test(&[Some(3), None, Some(2), Some(1)]); +} +#[test] +fn sendorder_3() { + sendorder_test(&[Some(1), Some(2), None, Some(3)]); +} +#[test] +fn sendorder_4() { + sendorder_test(&[ + Some(1), + Some(2), + Some(1), + None, + Some(3), + Some(1), + Some(3), + None, + ]); +} + +// 
Tests stream sendorder priorization +// Converts Vecs of u64's into StreamIds +fn fairness_test<S, R>(source: S, number_iterates: usize, truncate_to: usize, result_array: &R) +where + S: IntoIterator, + S::Item: Into<StreamId>, + R: IntoIterator + std::fmt::Debug, + R::Item: Into<StreamId>, + Vec<u64>: PartialEq<R>, +{ + // test the OrderGroup code used for fairness + let mut group: OrderGroup = OrderGroup::default(); + for stream_id in source { + group.insert(stream_id.into()); + } + { + let mut iterator1 = group.iter(); + // advance_by() would help here + let mut n = number_iterates; + while n > 0 { + iterator1.next(); + n -= 1; + } + // let iterator1 go out of scope + } + group.truncate(truncate_to); + + let iterator2 = group.iter(); + let result: Vec<u64> = iterator2.map(StreamId::as_u64).collect(); + assert_eq!(result, *result_array); +} + +#[test] +fn ordergroup_0() { + let source: [u64; 0] = []; + let result: [u64; 0] = []; + fairness_test(source, 1, usize::MAX, &result); +} + +#[test] +fn ordergroup_1() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [1, 2, 3, 4, 5, 0]; + fairness_test(source, 1, usize::MAX, &result); +} + +#[test] +fn ordergroup_2() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [2, 3, 4, 5, 0, 1]; + fairness_test(source, 2, usize::MAX, &result); +} + +#[test] +fn ordergroup_3() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [0, 1, 2, 3, 4, 5]; + fairness_test(source, 10, usize::MAX, &result); +} + +#[test] +fn ordergroup_4() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [0, 1, 2, 3, 4, 5]; + fairness_test(source, 0, usize::MAX, &result); +} + +#[test] +fn ordergroup_5() { + let source: [u64; 1] = [0]; + let result: [u64; 1] = [0]; + fairness_test(source, 1, usize::MAX, &result); +} + +#[test] +fn ordergroup_6() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [5, 0, 1, 2, 3, 4]; + fairness_test(source, 5, usize::MAX, 
&result); +} + +#[test] +fn ordergroup_7() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 3] = [0, 1, 2]; + fairness_test(source, 5, 3, &result); +} + +#[test] +// Send fin even if a peer closes a remote bidi send stream before sending any data. +fn report_fin_when_stream_closed_wo_data() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + mem::drop(server.process(out.as_dgram_ref(), now())); + + server.stream_close_send(stream_id).unwrap(); + let out = server.process(None, now()); + mem::drop(client.process(out.as_dgram_ref(), now())); + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. 
}); + assert!(client.events().any(stream_readable)); +} + +fn exchange_data(client: &mut Connection, server: &mut Connection) { + let mut input = None; + loop { + let out = client.process(input.as_ref(), now()).dgram(); + let c_done = out.is_none(); + let out = server.process(out.as_ref(), now()).dgram(); + if out.is_none() && c_done { + break; + } + input = out; + } +} + +#[test] +fn sending_max_data() { + const SMALL_MAX_DATA: usize = 2048; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()), + ); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected) + assert_eq!(stream_id, 2); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + + assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1]) + .unwrap(), + SMALL_MAX_DATA + ); + + exchange_data(&mut client, &mut server); + + let mut buf = vec![0; 40000]; + let (received, fin) = server.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(received, SMALL_MAX_DATA); + assert!(!fin); + + let out = server.process(None, now()).dgram(); + client.process_input(&out.unwrap(), now()); + + assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1]) + .unwrap(), + SMALL_MAX_DATA + ); +} + +#[test] +fn max_data() { + const SMALL_MAX_DATA: usize = 16383; + + let mut client = default_client(); + let mut server = default_server(); + + server + .set_local_tparam( + tparams::INITIAL_MAX_DATA, + TransportParameter::Integer(u64::try_from(SMALL_MAX_DATA).unwrap()), + ) + .unwrap(); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected) + assert_eq!(stream_id, 2); + assert_eq!( + 
client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1]) + .unwrap(), + SMALL_MAX_DATA + ); + assert_eq!(client.events().count(), 0); + + assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0); + client + .streams + .get_send_stream_mut(stream_id) + .unwrap() + .mark_as_sent(0, 4096, false); + assert_eq!(client.events().count(), 0); + client + .streams + .get_send_stream_mut(stream_id) + .unwrap() + .mark_as_acked(0, 4096, false); + assert_eq!(client.events().count(), 0); + + assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0); + // no event because still limited by conn max data + assert_eq!(client.events().count(), 0); + + // Increase max data. Avail space now limited by stream credit + client.streams.handle_max_data(100_000_000); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SEND_BUFFER_SIZE - SMALL_MAX_DATA + ); + + // Increase max stream data. Avail space now limited by tx buffer + client + .streams + .get_send_stream_mut(stream_id) + .unwrap() + .set_max_stream_data(100_000_000); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SEND_BUFFER_SIZE - SMALL_MAX_DATA + 4096 + ); + + let evts = client.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 1); + assert!(matches!( + evts[0], + ConnectionEvent::SendStreamWritable { .. 
} + )); +} + +#[test] +fn exceed_max_data() { + const SMALL_MAX_DATA: usize = 1024; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()), + ); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected) + assert_eq!(stream_id, 2); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1]) + .unwrap(), + SMALL_MAX_DATA + ); + + assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0); + + // Artificially trick the client to think that it has more flow control credit. + client.streams.handle_max_data(100_000_000); + assert_eq!(client.stream_send(stream_id, b"h").unwrap(), 1); + + exchange_data(&mut client, &mut server); + + assert_error( + &client, + &ConnectionError::Transport(Error::PeerError(Error::FlowControlError.code())), + ); + assert_error( + &server, + &ConnectionError::Transport(Error::FlowControlError), + ); +} + +#[test] +// If we send a stop_sending to the peer, we should not accept more data from the peer. +fn do_not_accept_data_after_stop_sending() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + mem::drop(server.process(out.as_dgram_ref(), now())); + + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. }); + assert!(server.events().any(stream_readable)); + + // Send one more packet from client. 
The packet should arrive after the server + // has already requested stop_sending. + client.stream_send(stream_id, &[0x00]).unwrap(); + let out_second_data_frame = client.process(None, now()); + // Call stop sending. + assert_eq!( + Ok(()), + server.stream_stop_sending(stream_id, Error::NoError.code()) + ); + + // Receive the second data frame. The frame should be ignored and + // DataReadable events shouldn't be posted. + let out = server.process(out_second_data_frame.as_dgram_ref(), now()); + assert!(!server.events().any(stream_readable)); + + mem::drop(client.process(out.as_dgram_ref(), now())); + assert_eq!( + Err(Error::FinalSizeError), + client.stream_send(stream_id, &[0x00]) + ); +} + +#[test] +// Server sends stop_sending, the client simultaneous sends reset. +fn simultaneous_stop_sending_and_reset() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + let ack = server.process(out.as_dgram_ref(), now()).dgram(); + + let stream_readable = + |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id: id } if id == stream_id); + assert!(server.events().any(stream_readable)); + + // The client resets the stream. The packet with reset should arrive after the server + // has already requested stop_sending. + client.stream_reset_send(stream_id, 0).unwrap(); + let out_reset_frame = client.process(ack.as_ref(), now()).dgram(); + + // Send something out of order to force the server to generate an + // acknowledgment at the next opportunity. + let force_ack = send_something(&mut client, now()); + server.process_input(&force_ack, now()); + + // Call stop sending. + server.stream_stop_sending(stream_id, 0).unwrap(); + // Receive the second data frame. 
The frame should be ignored and + // DataReadable events shouldn't be posted. + let ack = server.process(out_reset_frame.as_ref(), now()).dgram(); + assert!(ack.is_some()); + assert!(!server.events().any(stream_readable)); + + // The client gets the STOP_SENDING frame. + client.process_input(&ack.unwrap(), now()); + assert_eq!( + Err(Error::InvalidStreamId), + client.stream_send(stream_id, &[0x00]) + ); +} + +#[test] +fn client_fin_reorder() { + let mut client = default_client(); + let mut server = default_server(); + + // Send ClientHello. + let client_hs = client.process(None, now()); + assert!(client_hs.as_dgram_ref().is_some()); + + let server_hs = server.process(client_hs.as_dgram_ref(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc... + + let client_ack = client.process(server_hs.as_dgram_ref(), now()); + assert!(client_ack.as_dgram_ref().is_some()); + + let server_out = server.process(client_ack.as_dgram_ref(), now()); + assert!(server_out.as_dgram_ref().is_none()); + + assert!(maybe_authenticate(&mut client)); + assert_eq!(*client.state(), State::Connected); + + let client_fin = client.process(None, now()); + assert!(client_fin.as_dgram_ref().is_some()); + + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[1, 2, 3]).unwrap(); + let client_stream_data = client.process(None, now()); + assert!(client_stream_data.as_dgram_ref().is_some()); + + // Now stream data gets before client_fin + let server_out = server.process(client_stream_data.as_dgram_ref(), now()); + assert!(server_out.as_dgram_ref().is_none()); // the packet will be discarded + + assert_eq!(*server.state(), State::Handshaking); + let server_out = server.process(client_fin.as_dgram_ref(), now()); + assert!(server_out.as_dgram_ref().is_some()); +} + +#[test] +fn after_fin_is_read_conn_events_for_stream_should_be_removed() { + let mut client = default_client(); + let mut server = default_server(); + 
connect(&mut client, &mut server); + + let id = server.stream_create(StreamType::BiDi).unwrap(); + server.stream_send(id, &[6; 10]).unwrap(); + server.stream_close_send(id).unwrap(); + let out = server.process(None, now()).dgram(); + assert!(out.is_some()); + + mem::drop(client.process(out.as_ref(), now())); + + // read from the stream before checking connection events. + let mut buf = vec![0; 4000]; + let (_, fin) = client.stream_recv(id, &mut buf).unwrap(); + assert!(fin); + + // Make sure we do not have RecvStreamReadable events for the stream when fin has been read. + let readable_stream_evt = + |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id); + assert!(!client.events().any(readable_stream_evt)); +} + +#[test] +fn after_stream_stop_sending_is_called_conn_events_for_stream_should_be_removed() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let id = server.stream_create(StreamType::BiDi).unwrap(); + server.stream_send(id, &[6; 10]).unwrap(); + server.stream_close_send(id).unwrap(); + let out = server.process(None, now()).dgram(); + assert!(out.is_some()); + + mem::drop(client.process(out.as_ref(), now())); + + // send stop sending. + client + .stream_stop_sending(id, Error::NoError.code()) + .unwrap(); + + // Make sure we do not have RecvStreamReadable events for the stream after stream_stop_sending + // has been called. + let readable_stream_evt = + |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id); + assert!(!client.events().any(readable_stream_evt)); +} + +#[test] +fn stream_data_blocked_generates_max_stream_data() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let now = now(); + + // Send some data and consume some flow control. 
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap(); + _ = server.stream_send(stream_id, DEFAULT_STREAM_DATA).unwrap(); + let dgram = server.process(None, now).dgram(); + assert!(dgram.is_some()); + + // Consume the data. + client.process_input(&dgram.unwrap(), now); + let mut buf = [0; 10]; + let (count, end) = client.stream_recv(stream_id, &mut buf[..]).unwrap(); + assert_eq!(count, DEFAULT_STREAM_DATA.len()); + assert!(!end); + + // Now send `STREAM_DATA_BLOCKED`. + let internal_stream = server.streams.get_send_stream_mut(stream_id).unwrap(); + if let SendStreamState::Send { fc, .. } = internal_stream.state() { + fc.blocked(); + } else { + panic!("unexpected stream state"); + } + let dgram = server.process_output(now).dgram(); + assert!(dgram.is_some()); + + let sdb_before = client.stats().frame_rx.stream_data_blocked; + let dgram = client.process(dgram.as_ref(), now).dgram(); + assert_eq!(client.stats().frame_rx.stream_data_blocked, sdb_before + 1); + assert!(dgram.is_some()); + + // Client should have sent a MAX_STREAM_DATA frame with just a small increase + // on the default window size. + let msd_before = server.stats().frame_rx.max_stream_data; + server.process_input(&dgram.unwrap(), now); + assert_eq!(server.stats().frame_rx.max_stream_data, msd_before + 1); + + // Test that the entirety of the receive buffer is available now. 
+ let mut written = 0; + loop { + const LARGE_BUFFER: &[u8] = &[0; 1024]; + let amount = server.stream_send(stream_id, LARGE_BUFFER).unwrap(); + if amount == 0 { + break; + } + written += amount; + } + assert_eq!(written, RECV_BUFFER_SIZE); +} + +/// See <https://github.com/mozilla/neqo/issues/871> +#[test] +fn max_streams_after_bidi_closed() { + const REQUEST: &[u8] = b"ping"; + const RESPONSE: &[u8] = b"pong"; + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + while client.stream_create(StreamType::BiDi).is_ok() { + // Exhaust the stream limit. + } + // Write on the one stream and send that out. + _ = client.stream_send(stream_id, REQUEST).unwrap(); + client.stream_close_send(stream_id).unwrap(); + let dgram = client.process(None, now()).dgram(); + + // Now handle the stream and send an incomplete response. + server.process_input(&dgram.unwrap(), now()); + server.stream_send(stream_id, RESPONSE).unwrap(); + let dgram = server.process_output(now()).dgram(); + + // The server shouldn't have released more stream credit. + client.process_input(&dgram.unwrap(), now()); + let e = client.stream_create(StreamType::BiDi).unwrap_err(); + assert!(matches!(e, Error::StreamLimitError)); + + // Closing the stream isn't enough. + server.stream_close_send(stream_id).unwrap(); + let dgram = server.process_output(now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + assert!(client.stream_create(StreamType::BiDi).is_err()); + + // The server needs to see an acknowledgment from the client for its + // response AND the server has to read all of the request. + // Read the request first.
+ let mut buf = [0; REQUEST.len()]; + let (count, fin) = server.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(&buf[..count], REQUEST); + assert!(fin); + + // We need an ACK from the client now, but that isn't guaranteed, + // so give the client one more packet just in case. + let dgram = send_something(&mut server, now()); + client.process_input(&dgram, now()); + + // Now get the client to send the ACK and have the server handle that. + let dgram = send_something(&mut client, now()); + let dgram = server.process(Some(&dgram), now()).dgram(); + client.process_input(&dgram.unwrap(), now()); + assert!(client.stream_create(StreamType::BiDi).is_ok()); + assert!(client.stream_create(StreamType::BiDi).is_err()); +} + +#[test] +fn no_dupdata_readable_events() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + mem::drop(server.process(out.as_dgram_ref(), now())); + + // We have a data_readable event. + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. }); + assert!(server.events().any(stream_readable)); + + // Send one more data frame from client. The previous stream data has not been read yet, + // therefore there should not be a new DataReadable event. + client.stream_send(stream_id, &[0x00]).unwrap(); + let out_second_data_frame = client.process(None, now()); + mem::drop(server.process(out_second_data_frame.as_dgram_ref(), now())); + assert!(!server.events().any(stream_readable)); + + // One more frame with a fin will not produce a new DataReadable event, because the + // previous stream data has not been read yet. 
+ client.stream_send(stream_id, &[0x00]).unwrap(); + client.stream_close_send(stream_id).unwrap(); + let out_third_data_frame = client.process(None, now()); + mem::drop(server.process(out_third_data_frame.as_dgram_ref(), now())); + assert!(!server.events().any(stream_readable)); +} + +#[test] +fn no_dupdata_readable_events_empty_last_frame() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + mem::drop(server.process(out.as_dgram_ref(), now())); + + // We have a data_readable event. + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. }); + assert!(server.events().any(stream_readable)); + + // An empty frame with a fin will not produce a new DataReadable event, because + // the previous stream data has not been read yet. + client.stream_close_send(stream_id).unwrap(); + let out_second_data_frame = client.process(None, now()); + mem::drop(server.process(out_second_data_frame.as_dgram_ref(), now())); + assert!(!server.events().any(stream_readable)); +} + +fn change_flow_control(stream_type: StreamType, new_fc: u64) { + const RECV_BUFFER_START: u64 = 300; + + let mut client = new_client( + ConnectionParameters::default() + .max_stream_data(StreamType::BiDi, true, RECV_BUFFER_START) + .max_stream_data(StreamType::UniDi, true, RECV_BUFFER_START), + ); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = server.stream_create(stream_type).unwrap(); + let written1 = server.stream_send(stream_id, &[0x0; 10000]).unwrap(); + assert_eq!(u64::try_from(written1).unwrap(), RECV_BUFFER_START); + + // Send the stream to the client. 
+ let out = server.process(None, now()); + mem::drop(client.process(out.as_dgram_ref(), now())); + + // change max_stream_data for stream_id. + client.set_stream_max_data(stream_id, new_fc).unwrap(); + + // server should receive a MAX_STREAM_DATA frame if the flow control window is updated. + let out2 = client.process(None, now()); + let out3 = server.process(out2.as_dgram_ref(), now()); + let expected = usize::from(RECV_BUFFER_START < new_fc); + assert_eq!(server.stats().frame_rx.max_stream_data, expected); + + // If the flow control window has been increased, server can write more data. + let written2 = server.stream_send(stream_id, &[0x0; 10000]).unwrap(); + if RECV_BUFFER_START < new_fc { + assert_eq!(u64::try_from(written2).unwrap(), new_fc - RECV_BUFFER_START); + } else { + assert_eq!(written2, 0); + } + + // Exchange packets so that client gets all data. + let out4 = client.process(out3.as_dgram_ref(), now()); + let out5 = server.process(out4.as_dgram_ref(), now()); + mem::drop(client.process(out5.as_dgram_ref(), now())); + + // read all data by client + let mut buf = [0x0; 10000]; + let (read, _) = client.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(u64::try_from(read).unwrap(), max(RECV_BUFFER_START, new_fc)); + + let out4 = client.process(None, now()); + mem::drop(server.process(out4.as_dgram_ref(), now())); + + let written3 = server.stream_send(stream_id, &[0x0; 10000]).unwrap(); + assert_eq!(u64::try_from(written3).unwrap(), new_fc); +} + +#[test] +fn increase_decrease_flow_control() { + const RECV_BUFFER_NEW_BIGGER: u64 = 400; + const RECV_BUFFER_NEW_SMALLER: u64 = 200; + + change_flow_control(StreamType::UniDi, RECV_BUFFER_NEW_BIGGER); + change_flow_control(StreamType::BiDi, RECV_BUFFER_NEW_BIGGER); + + change_flow_control(StreamType::UniDi, RECV_BUFFER_NEW_SMALLER); + change_flow_control(StreamType::BiDi, RECV_BUFFER_NEW_SMALLER); +} + +#[test] +fn session_flow_control_stop_sending_state_recv() { + const SMALL_MAX_DATA: usize = 1024; + +
let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()), + ); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + + // send 1 byte so that the server learns about the stream. + assert_eq!(client.stream_send(stream_id, b"a").unwrap(), 1); + + exchange_data(&mut client, &mut server); + + server + .stream_stop_sending(stream_id, Error::NoError.code()) + .unwrap(); + + assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA]) + .unwrap(), + SMALL_MAX_DATA - 1 + ); + + // In this case the final size is only known after RESET frame is received. + // The server sends STOP_SENDING -> the client sends RESET -> the server + // sends MAX_DATA. + let out = server.process(None, now()).dgram(); + let out = client.process(out.as_ref(), now()).dgram(); + // the client is still limited. + let stream_id2 = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.stream_avail_send_space(stream_id2).unwrap(), 0); + let out = server.process(out.as_ref(), now()).dgram(); + client.process_input(&out.unwrap(), now()); + assert_eq!( + client.stream_avail_send_space(stream_id2).unwrap(), + SMALL_MAX_DATA + ); +} + +#[test] +fn session_flow_control_stop_sending_state_size_known() { + const SMALL_MAX_DATA: usize = 1024; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()), + ); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + + // Attempt to send SMALL_MAX_DATA + 1 bytes; only SMALL_MAX_DATA bytes are accepted.
+ assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1]) + .unwrap(), + SMALL_MAX_DATA + ); + + let out1 = client.process(None, now()).dgram(); + // Delay this packet and let the server receive fin first (it will enter SizeKnown state). + client.stream_close_send(stream_id).unwrap(); + let out2 = client.process(None, now()).dgram(); + + server.process_input(&out2.unwrap(), now()); + + server + .stream_stop_sending(stream_id, Error::NoError.code()) + .unwrap(); + + // In this case the final size is known when stream_stop_sending is called + // and the server releases flow control immediately and sends STOP_SENDING and + // MAX_DATA in the same packet. + let out = server.process(out1.as_ref(), now()).dgram(); + client.process_input(&out.unwrap(), now()); + + // The flow control should have been updated and the client can again send + // SMALL_MAX_DATA. + let stream_id2 = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id2).unwrap(), + SMALL_MAX_DATA + ); +} + +#[test] +fn session_flow_control_stop_sending_state_data_recvd() { + const SMALL_MAX_DATA: usize = 1024; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()), + ); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + + // Attempt to send SMALL_MAX_DATA + 1 bytes; only SMALL_MAX_DATA bytes are accepted.
+ assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1]) + .unwrap(), + SMALL_MAX_DATA + ); + + client.stream_close_send(stream_id).unwrap(); + + exchange_data(&mut client, &mut server); + + // The stream is DataRecvd state + server + .stream_stop_sending(stream_id, Error::NoError.code()) + .unwrap(); + + exchange_data(&mut client, &mut server); + + // The flow control should have been updated and the client can again send + // SMALL_MAX_DATA. + let stream_id2 = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id2).unwrap(), + SMALL_MAX_DATA + ); +} + +#[test] +fn session_flow_control_affects_all_streams() { + const SMALL_MAX_DATA: usize = 1024; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()), + ); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + + let stream_id2 = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!( + client.stream_avail_send_space(stream_id2).unwrap(), + SMALL_MAX_DATA + ); + + assert_eq!( + client + .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA / 2 + 1]) + .unwrap(), + SMALL_MAX_DATA / 2 + 1 + ); + + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA / 2 - 1 + ); + assert_eq!( + client.stream_avail_send_space(stream_id2).unwrap(), + SMALL_MAX_DATA / 2 - 1 + ); + + exchange_data(&mut client, &mut server); + + let mut buf = [0x0; SMALL_MAX_DATA]; + let (read, _) = server.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(read, SMALL_MAX_DATA / 2 + 1); + + exchange_data(&mut client, &mut server); + + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + + assert_eq!( + client.stream_avail_send_space(stream_id2).unwrap(), + SMALL_MAX_DATA 
+ ); +} + +fn connect_w_different_limit(bidi_limit: u64, unidi_limit: u64) { + let mut client = default_client(); + let out = client.process(None, now()); + let mut server = new_server( + ConnectionParameters::default() + .max_streams(StreamType::BiDi, bidi_limit) + .max_streams(StreamType::UniDi, unidi_limit), + ); + let out = server.process(out.as_dgram_ref(), now()); + + let out = client.process(out.as_dgram_ref(), now()); + mem::drop(server.process(out.as_dgram_ref(), now())); + + assert!(maybe_authenticate(&mut client)); + + let mut bidi_events = 0; + let mut unidi_events = 0; + let mut connected_events = 0; + for e in client.events() { + match e { + ConnectionEvent::SendStreamCreatable { stream_type } => { + if stream_type == StreamType::BiDi { + bidi_events += 1; + } else { + unidi_events += 1; + } + } + ConnectionEvent::StateChange(State::Connected) => { + connected_events += 1; + } + _ => {} + } + } + assert_eq!(bidi_events, usize::from(bidi_limit > 0)); + assert_eq!(unidi_events, usize::from(unidi_limit > 0)); + assert_eq!(connected_events, 1); +} + +#[test] +fn client_stream_creatable_event() { + connect_w_different_limit(0, 0); + connect_w_different_limit(0, 1); + connect_w_different_limit(1, 0); + connect_w_different_limit(1, 1); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs new file mode 100644 index 0000000000..22f15c991c --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs @@ -0,0 +1,482 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::{mem, time::Duration}; + +use neqo_common::{event::Provider, Decoder, Encoder}; +use test_fixture::{self, assertions, datagram, now}; + +use super::{ + super::{ConnectionError, ConnectionEvent, Output, State, ZeroRttState}, + connect, connect_fail, default_client, default_server, exchange_ticket, new_client, new_server, + send_something, +}; +use crate::{ + packet::PACKET_BIT_LONG, + tparams::{self, TransportParameter}, + ConnectionParameters, Error, Version, +}; + +// The expected PTO duration after the first Initial is sent. +const INITIAL_PTO: Duration = Duration::from_millis(300); + +#[test] +fn unknown_version() { + let mut client = default_client(); + // Start the handshake. + mem::drop(client.process(None, now()).dgram()); + + let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a]; + unknown_version_packet.resize(1200, 0x0); + mem::drop(client.process(Some(&datagram(unknown_version_packet)), now())); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn server_receive_unknown_first_packet() { + let mut server = default_server(); + + let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a]; + unknown_version_packet.resize(1200, 0x0); + + assert_eq!( + server.process(Some(&datagram(unknown_version_packet,)), now(),), + Output::None + ); + + assert_eq!(1, server.stats().dropped_rx); +} + +fn create_vn(initial_pkt: &[u8], versions: &[u32]) -> Vec<u8> { + let mut dec = Decoder::from(&initial_pkt[5..]); // Skip past version. + let dst_cid = dec.decode_vec(1).expect("client DCID"); + let src_cid = dec.decode_vec(1).expect("client SCID"); + + let mut encoder = Encoder::default(); + encoder.encode_byte(PACKET_BIT_LONG); + encoder.encode(&[0; 4]); // Zero version == VN. 
+ encoder.encode_vec(1, src_cid); + encoder.encode_vec(1, dst_cid); + + for v in versions { + encoder.encode_uint(4, *v); + } + encoder.into() +} + +#[test] +fn version_negotiation_current_version() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn( + &initial_pkt, + &[0x1a1a_1a1a, Version::default().wire_version()], + ); + + let dgram = datagram(vn); + let delay = client.process(Some(&dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_version0() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0, 0x1a1a_1a1a]); + + let dgram = datagram(vn); + let delay = client.process(Some(&dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_only_reserved() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]); + + let dgram = datagram(vn); + assert_eq!(client.process(Some(&dgram), now()), Output::None); + match client.state() { + State::Closed(err) => { + assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)); + } + _ => panic!("Invalid client state"), + } +} + +#[test] +fn version_negotiation_corrupted() { + let mut client = default_client(); + // Start the handshake. 
+ let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]); + + let dgram = datagram(vn[..vn.len() - 1].to_vec()); + let delay = client.process(Some(&dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_empty() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[]); + + let dgram = datagram(vn); + let delay = client.process(Some(&dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_not_supported() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]); + let dgram = datagram(vn); + assert_eq!(client.process(Some(&dgram), now()), Output::None); + match client.state() { + State::Closed(err) => { + assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)); + } + _ => panic!("Invalid client state"), + } +} + +#[test] +fn version_negotiation_bad_cid() { + let mut client = default_client(); + // Start the handshake. 
+ let mut initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + initial_pkt[6] ^= 0xc4; + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]); + + let dgram = datagram(vn); + let delay = client.process(Some(&dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn compatible_upgrade() { + let mut client = default_client(); + let mut server = default_server(); + + connect(&mut client, &mut server); + assert_eq!(client.version(), Version::Version2); + assert_eq!(server.version(), Version::Version2); +} + +/// When the first packet from the client is gigantic, the server might generate acknowledgment +/// packets in version 1. Both client and server need to handle that gracefully. +#[test] +fn compatible_upgrade_large_initial() { + let params = ConnectionParameters::default().versions( + Version::Version1, + vec![Version::Version2, Version::Version1], + ); + let mut client = new_client(params.clone()); + client + .set_local_tparam( + 0x0845_de37_00ac_a5f9, + TransportParameter::Bytes(vec![0; 2048]), + ) + .unwrap(); + let mut server = new_server(params); + + // Client Initial should take 2 packets. + // Each should elicit a Version 1 ACK from the server. + let dgram = client.process_output(now()).dgram(); + assert!(dgram.is_some()); + let dgram = server.process(dgram.as_ref(), now()).dgram(); + assert!(dgram.is_some()); + // The following uses the Version from *outside* this crate. + assertions::assert_version(dgram.as_ref().unwrap(), Version::Version1.wire_version()); + client.process_input(&dgram.unwrap(), now()); + + connect(&mut client, &mut server); + assert_eq!(client.version(), Version::Version2); + assert_eq!(server.version(), Version::Version2); + // Only handshake padding is "dropped". 
+ assert_eq!(client.stats().dropped_rx, 1); + assert_eq!(server.stats().dropped_rx, 1); +} + +/// A server that supports versions 1 and 2 might prefer version 1 and that's OK. +/// This one starts with version 1 and stays there. +#[test] +fn compatible_no_upgrade() { + let mut client = new_client(ConnectionParameters::default().versions( + Version::Version1, + vec![Version::Version2, Version::Version1], + )); + let mut server = new_server(ConnectionParameters::default().versions( + Version::Version1, + vec![Version::Version1, Version::Version2], + )); + + connect(&mut client, &mut server); + assert_eq!(client.version(), Version::Version1); + assert_eq!(server.version(), Version::Version1); +} + +/// A server that supports versions 1 and 2 might prefer version 1 and that's OK. +/// This one starts with version 2 and downgrades to version 1. +#[test] +fn compatible_downgrade() { + let mut client = new_client(ConnectionParameters::default().versions( + Version::Version2, + vec![Version::Version2, Version::Version1], + )); + let mut server = new_server(ConnectionParameters::default().versions( + Version::Version2, + vec![Version::Version1, Version::Version2], + )); + + connect(&mut client, &mut server); + assert_eq!(client.version(), Version::Version1); + assert_eq!(server.version(), Version::Version1); +} + +/// Inject a Version Negotiation packet, which the client detects when it validates the +/// server `version_negotiation` transport parameter. +#[test] +fn version_negotiation_downgrade() { + const DOWNGRADE: Version = Version::Draft29; + + let mut client = default_client(); + // The server sets the current version in the transport parameter and + // protects Initial packets with the version in its configuration. + // When a server `Connection` is created by a `Server`, the configuration is set + // to match the version of the packet it first receives. This replicates that. 
+ let mut server = + new_server(ConnectionParameters::default().versions(DOWNGRADE, Version::all())); + + // Start the handshake and spoof a VN packet. + let initial = client.process_output(now()).dgram().unwrap(); + let vn = create_vn(&initial, &[DOWNGRADE.wire_version()]); + let dgram = datagram(vn); + client.process_input(&dgram, now()); + + connect_fail( + &mut client, + &mut server, + Error::VersionNegotiation, + Error::PeerError(Error::VersionNegotiation.code()), + ); +} + +/// A server connection needs to be configured with the version that the client attempts. +/// Otherwise, it will object to the client transport parameters and not do anything. +#[test] +fn invalid_server_version() { + let mut client = + new_client(ConnectionParameters::default().versions(Version::Version1, Version::all())); + let mut server = + new_server(ConnectionParameters::default().versions(Version::Version2, Version::all())); + + let dgram = client.process_output(now()).dgram(); + server.process_input(&dgram.unwrap(), now()); + + // One packet received. + assert_eq!(server.stats().packets_rx, 1); + // None dropped; the server will have decrypted it successfully. + assert_eq!(server.stats().dropped_rx, 0); + assert_eq!(server.stats().saved_datagrams, 0); + // The server effectively hasn't reacted here. 
+ match server.state() { + State::Closed(err) => { + assert_eq!(*err, ConnectionError::Transport(Error::CryptoAlert(47))); + } + _ => panic!("invalid server state"), + } +} + +#[test] +fn invalid_current_version_client() { + const OTHER_VERSION: Version = Version::Draft29; + + let mut client = default_client(); + let mut server = default_server(); + + assert_ne!(OTHER_VERSION, client.version()); + client + .set_local_tparam( + tparams::VERSION_INFORMATION, + TransportParameter::Versions { + current: OTHER_VERSION.wire_version(), + other: Version::all() + .iter() + .copied() + .map(Version::wire_version) + .collect(), + }, + ) + .unwrap(); + + connect_fail( + &mut client, + &mut server, + Error::PeerError(Error::CryptoAlert(47).code()), + Error::CryptoAlert(47), + ); +} + +/// To test this, we need to disable compatible upgrade so that the server doesn't update +/// its transport parameters. Then, we can overwrite its transport parameters without +/// them being overwritten. Otherwise, it would be hard to find a window during which +/// the transport parameter can be modified. 
+#[test] +fn invalid_current_version_server() { + const OTHER_VERSION: Version = Version::Draft29; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default().versions(Version::default(), vec![Version::default()]), + ); + + assert!(!Version::default().is_compatible(OTHER_VERSION)); + server + .set_local_tparam( + tparams::VERSION_INFORMATION, + TransportParameter::Versions { + current: OTHER_VERSION.wire_version(), + other: vec![OTHER_VERSION.wire_version()], + }, + ) + .unwrap(); + + connect_fail( + &mut client, + &mut server, + Error::CryptoAlert(47), + Error::PeerError(Error::CryptoAlert(47).code()), + ); +} + +#[test] +fn no_compatible_version() { + const OTHER_VERSION: Version = Version::Draft29; + + let mut client = default_client(); + let mut server = default_server(); + + assert_ne!(OTHER_VERSION, client.version()); + client + .set_local_tparam( + tparams::VERSION_INFORMATION, + TransportParameter::Versions { + current: Version::default().wire_version(), + other: vec![OTHER_VERSION.wire_version()], + }, + ) + .unwrap(); + + connect_fail( + &mut client, + &mut server, + Error::PeerError(Error::CryptoAlert(47).code()), + Error::CryptoAlert(47), + ); +} + +/// When a compatible upgrade chooses a different version, 0-RTT is rejected. +#[test] +fn compatible_upgrade_0rtt_rejected() { + // This is the baseline configuration where v1 is attempted and v2 preferred. + let prefer_v2 = ConnectionParameters::default().versions( + Version::Version1, + vec![Version::Version2, Version::Version1], + ); + let mut client = new_client(prefer_v2.clone()); + // The server will start with this so that the client resumes with v1. 
+ let just_v1 = + ConnectionParameters::default().versions(Version::Version1, vec![Version::Version1]); + let mut server = new_server(just_v1); + + connect(&mut client, &mut server); + assert_eq!(client.version(), Version::Version1); + let token = exchange_ticket(&mut client, &mut server, now()); + + // Now upgrade the server to the preferred configuration. + let mut client = new_client(prefer_v2.clone()); + let mut server = new_server(prefer_v2); + client.enable_resumption(now(), token).unwrap(); + + // Create a packet with 0-RTT from the client. + let initial = send_something(&mut client, now()); + assertions::assert_version(&initial, Version::Version1.wire_version()); + assertions::assert_coalesced_0rtt(&initial); + server.process_input(&initial, now()); + assert!(!server + .events() + .any(|e| matches!(e, ConnectionEvent::NewStream { .. }))); + + // Finalize the connection. Don't use connect() because it uses + // maybe_authenticate() too liberally and that eats the events we want to check. 
+ let dgram = server.process_output(now()).dgram(); // ServerHello flight + let dgram = client.process(dgram.as_ref(), now()).dgram(); // Client Finished (note: no authentication) + let dgram = server.process(dgram.as_ref(), now()).dgram(); // HANDSHAKE_DONE + client.process_input(&dgram.unwrap(), now()); + + assert!(matches!(client.state(), State::Confirmed)); + assert!(matches!(server.state(), State::Confirmed)); + + assert!(client.events().any(|e| { + println!(" client event: {e:?}"); + matches!(e, ConnectionEvent::ZeroRttRejected) + })); + assert_eq!(client.zero_rtt_state(), ZeroRttState::Rejected); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs new file mode 100644 index 0000000000..0aa5573c98 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs @@ -0,0 +1,257 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{cell::RefCell, rc::Rc}; + +use neqo_common::event::Provider; +use neqo_crypto::{AllowZeroRtt, AntiReplay}; +use test_fixture::{self, assertions, now}; + +use super::{ + super::Connection, connect, default_client, default_server, exchange_ticket, new_server, + resumed_server, CountingConnectionIdGenerator, +}; +use crate::{events::ConnectionEvent, ConnectionParameters, Error, StreamType, Version}; + +#[test] +fn zero_rtt_negotiate() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. 
+ let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = resumed_server(&client); + connect(&mut client, &mut server); + assert!(client.tls_info().unwrap().early_data_accepted()); + assert!(server.tls_info().unwrap().early_data_accepted()); +} + +#[test] +fn zero_rtt_send_recv() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = resumed_server(&client); + + // Send ClientHello. + let client_hs = client.process(None, now()); + assert!(client_hs.as_dgram_ref().is_some()); + + // Now send a 0-RTT packet. + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[1, 2, 3]).unwrap(); + let client_0rtt = client.process(None, now()); + assert!(client_0rtt.as_dgram_ref().is_some()); + // 0-RTT packets on their own shouldn't be padded to 1200. + assert!(client_0rtt.as_dgram_ref().unwrap().len() < 1200); + + let server_hs = server.process(client_hs.as_dgram_ref(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc... + + let all_frames = server.stats().frame_tx.all; + let ack_frames = server.stats().frame_tx.ack; + let server_process_0rtt = server.process(client_0rtt.as_dgram_ref(), now()); + assert!(server_process_0rtt.as_dgram_ref().is_some()); + assert_eq!(server.stats().frame_tx.all, all_frames + 1); + assert_eq!(server.stats().frame_tx.ack, ack_frames + 1); + + let server_stream_id = server + .events() + .find_map(|evt| match evt { + ConnectionEvent::NewStream { stream_id, .. 
} => Some(stream_id), + _ => None, + }) + .expect("should have received a new stream event"); + assert_eq!(client_stream_id, server_stream_id.as_u64()); +} + +#[test] +fn zero_rtt_send_coalesce() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = resumed_server(&client); + + // Write 0-RTT before generating any packets. + // This should result in a datagram that coalesces Initial and 0-RTT. + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[1, 2, 3]).unwrap(); + let client_0rtt = client.process(None, now()); + assert!(client_0rtt.as_dgram_ref().is_some()); + + assertions::assert_coalesced_0rtt(&client_0rtt.as_dgram_ref().unwrap()[..]); + + let server_hs = server.process(client_0rtt.as_dgram_ref(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc... 
+ + let server_stream_id = server + .events() + .find_map(|evt| match evt { + ConnectionEvent::NewStream { stream_id } => Some(stream_id), + _ => None, + }) + .expect("should have received a new stream event"); + assert_eq!(client_stream_id, server_stream_id.as_u64()); +} + +#[test] +fn zero_rtt_before_resumption_token() { + let mut client = default_client(); + assert!(client.stream_create(StreamType::BiDi).is_err()); +} + +#[test] +fn zero_rtt_send_reject() { + const MESSAGE: &[u8] = &[1, 2, 3]; + + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(CountingConnectionIdGenerator::default())), + ConnectionParameters::default().versions(client.version(), Version::all()), + ) + .unwrap(); + // Using a freshly initialized anti-replay context + // should result in the server rejecting 0-RTT. + let ar = + AntiReplay::new(now(), test_fixture::ANTI_REPLAY_WINDOW, 1, 3).expect("setup anti-replay"); + server + .server_enable_0rtt(&ar, AllowZeroRtt {}) + .expect("enable 0-RTT"); + + // Send ClientHello. + let client_hs = client.process(None, now()); + assert!(client_hs.as_dgram_ref().is_some()); + + // Write some data on the client. + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(stream_id, MESSAGE).unwrap(); + let client_0rtt = client.process(None, now()); + assert!(client_0rtt.as_dgram_ref().is_some()); + + let server_hs = server.process(client_hs.as_dgram_ref(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc... 
+ let server_ignored = server.process(client_0rtt.as_dgram_ref(), now()); + assert!(server_ignored.as_dgram_ref().is_none()); + + // The server shouldn't receive that 0-RTT data. + let recvd_stream_evt = |e| matches!(e, ConnectionEvent::NewStream { .. }); + assert!(!server.events().any(recvd_stream_evt)); + + // Client should get a rejection. + let client_fin = client.process(server_hs.as_dgram_ref(), now()); + let recvd_0rtt_reject = |e| e == ConnectionEvent::ZeroRttRejected; + assert!(client.events().any(recvd_0rtt_reject)); + + // Server consume client_fin + let server_ack = server.process(client_fin.as_dgram_ref(), now()); + assert!(server_ack.as_dgram_ref().is_some()); + let client_out = client.process(server_ack.as_dgram_ref(), now()); + assert!(client_out.as_dgram_ref().is_none()); + + // ...and the client stream should be gone. + let res = client.stream_send(stream_id, MESSAGE); + assert!(res.is_err()); + assert_eq!(res.unwrap_err(), Error::InvalidStreamId); + + // Open a new stream and send data. StreamId should start with 0. 
+ let stream_id_after_reject = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(stream_id, stream_id_after_reject); + client.stream_send(stream_id_after_reject, MESSAGE).unwrap(); + let client_after_reject = client.process(None, now()).dgram(); + assert!(client_after_reject.is_some()); + + // The server should receive new stream + server.process_input(&client_after_reject.unwrap(), now()); + assert!(server.events().any(recvd_stream_evt)); +} + +#[test] +fn zero_rtt_update_flow_control() { + const LOW: u64 = 3; + const HIGH: u64 = 10; + #[allow(clippy::cast_possible_truncation)] + const MESSAGE: &[u8] = &[0; HIGH as usize]; + + let mut client = default_client(); + let mut server = new_server( + ConnectionParameters::default() + .max_stream_data(StreamType::UniDi, true, LOW) + .max_stream_data(StreamType::BiDi, true, LOW), + ); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = new_server( + ConnectionParameters::default() + .max_stream_data(StreamType::UniDi, true, HIGH) + .max_stream_data(StreamType::BiDi, true, HIGH) + .versions(client.version, Version::all()), + ); + + // Stream limits should be low for 0-RTT. + let client_hs = client.process(None, now()).dgram(); + let uni_stream = client.stream_create(StreamType::UniDi).unwrap(); + assert!(!client.stream_send_atomic(uni_stream, MESSAGE).unwrap()); + let bidi_stream = client.stream_create(StreamType::BiDi).unwrap(); + assert!(!client.stream_send_atomic(bidi_stream, MESSAGE).unwrap()); + + // Now get the server transport parameters. + let server_hs = server.process(client_hs.as_ref(), now()).dgram(); + client.process_input(&server_hs.unwrap(), now()); + + // The streams should report a writeable event. 
+ let mut uni_stream_event = false; + let mut bidi_stream_event = false; + for e in client.events() { + if let ConnectionEvent::SendStreamWritable { stream_id } = e { + if stream_id.is_uni() { + uni_stream_event = true; + } else { + bidi_stream_event = true; + } + } + } + assert!(uni_stream_event); + assert!(bidi_stream_event); + // But no MAX_STREAM_DATA frame was received. + assert_eq!(client.stats().frame_rx.max_stream_data, 0); + + // And the new limit applies. + assert!(client.stream_send_atomic(uni_stream, MESSAGE).unwrap()); + assert!(client.stream_send_atomic(bidi_stream, MESSAGE).unwrap()); +} diff --git a/third_party/rust/neqo-transport/src/crypto.rs b/third_party/rust/neqo-transport/src/crypto.rs new file mode 100644 index 0000000000..f6cc7c0e2f --- /dev/null +++ b/third_party/rust/neqo-transport/src/crypto.rs @@ -0,0 +1,1583 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::{ + cell::RefCell, + cmp::{max, min}, + collections::HashMap, + convert::TryFrom, + mem, + ops::{Index, IndexMut, Range}, + rc::Rc, + time::Instant, +}; + +use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role}; +use neqo_crypto::{ + hkdf, hp::HpKey, Aead, Agent, AntiReplay, Cipher, Epoch, Error as CryptoError, HandshakeState, + PrivateKey, PublicKey, Record, RecordList, ResumptionToken, SymKey, ZeroRttChecker, + TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE, + TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT, + TLS_GRP_EC_SECP256R1, TLS_GRP_EC_SECP384R1, TLS_GRP_EC_SECP521R1, TLS_GRP_EC_X25519, + TLS_VERSION_1_3, +}; + +use crate::{ + cid::ConnectionIdRef, + packet::{PacketBuilder, PacketNumber}, + recovery::RecoveryToken, + recv_stream::RxStreamOrderer, + send_stream::TxBuffer, + stats::FrameStats, + tparams::{TpZeroRttChecker, TransportParameters, TransportParametersHandler}, + tracking::PacketNumberSpace, + version::Version, + Error, Res, +}; + +const MAX_AUTH_TAG: usize = 32; +/// The number of invocations remaining on a write cipher before we try +/// to update keys. This has to be much smaller than the number returned +/// by `CryptoDxState::limit` or updates will happen too often. As we don't +/// need to ask permission to update, this can be quite small. +pub(crate) const UPDATE_WRITE_KEYS_AT: PacketNumber = 100; + +// This is a testing kludge that allows for overwriting the number of +// invocations of the next cipher to operate. With this, it is possible +// to test what happens when the number of invocations reaches 0, or +// when it hits `UPDATE_WRITE_KEYS_AT` and an automatic update should occur. +// This is a little crude, but it saves a lot of plumbing. 
#[cfg(test)]
thread_local!(pub(crate) static OVERWRITE_INVOCATIONS: RefCell<Option<PacketNumber>> = RefCell::default());

/// The TLS agent for a connection together with the buffered CRYPTO
/// streams and the packet protection key state.
#[derive(Debug)]
pub struct Crypto {
    /// The QUIC version handshake keys are installed for; replaced by
    /// `confirm_version`.
    version: Version,
    /// The ALPN protocols that were configured on the agent in `new`.
    protocols: Vec<String>,
    pub(crate) tls: Agent,
    pub(crate) streams: CryptoStreams,
    pub(crate) states: CryptoStates,
}

type TpHandler = Rc<RefCell<TransportParametersHandler>>;

impl Crypto {
    /// Configure `agent` for use with QUIC and wrap it.
    ///
    /// The agent is restricted to TLS 1.3, given a fixed list of cipher
    /// suites and groups, told to send one additional key share, configured
    /// with `protocols` as ALPN, and has end-of-early-data disabled.
    /// `tphandler` is registered as the handler for the transport
    /// parameters extension, whose codepoint depends on `version`.
    ///
    /// # Errors
    /// Any error reported by the agent while applying this configuration.
    pub fn new(
        version: Version,
        mut agent: Agent,
        protocols: Vec<String>,
        tphandler: TpHandler,
        fuzzing: bool,
    ) -> Res<Self> {
        agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
        agent.set_ciphers(&[
            TLS_AES_128_GCM_SHA256,
            TLS_AES_256_GCM_SHA384,
            TLS_CHACHA20_POLY1305_SHA256,
        ])?;
        agent.set_groups(&[
            TLS_GRP_EC_X25519,
            TLS_GRP_EC_SECP256R1,
            TLS_GRP_EC_SECP384R1,
            TLS_GRP_EC_SECP521R1,
        ])?;
        agent.send_additional_key_shares(1)?;
        agent.set_alpn(&protocols)?;
        agent.disable_end_of_early_data()?;
        // Always enable 0-RTT on the client, but the server needs
        // more configuration passed to server_enable_0rtt.
        if let Agent::Client(c) = &mut agent {
            c.enable_0rtt()?;
        }
        // Final versions and drafts use different extension codepoints.
        let extension = match version {
            Version::Version2 | Version::Version1 => 0x39,
            Version::Draft29 | Version::Draft30 | Version::Draft31 | Version::Draft32 => 0xffa5,
        };
        agent.extension_handler(extension, tphandler)?;
        Ok(Self {
            version,
            protocols,
            tls: agent,
            streams: Default::default(),
            states: CryptoStates {
                fuzzing,
                ..Default::default()
            },
        })
    }

    /// Get the name of the server.  (Only works for the client currently).
    pub fn server_name(&self) -> Option<&str> {
        if let Agent::Client(c) = &self.tls {
            Some(c.server_name())
        } else {
            None
        }
    }

    /// Get the set of enabled protocols.
+ pub fn protocols(&self) -> &[String] { + &self.protocols + } + + pub fn server_enable_0rtt( + &mut self, + tphandler: TpHandler, + anti_replay: &AntiReplay, + zero_rtt_checker: impl ZeroRttChecker + 'static, + ) -> Res<()> { + if let Agent::Server(s) = &mut self.tls { + Ok(s.enable_0rtt( + anti_replay, + 0xffff_ffff, + TpZeroRttChecker::wrap(tphandler, zero_rtt_checker), + )?) + } else { + panic!("not a server"); + } + } + + pub fn server_enable_ech( + &mut self, + config: u8, + public_name: &str, + sk: &PrivateKey, + pk: &PublicKey, + ) -> Res<()> { + if let Agent::Server(s) = &mut self.tls { + s.enable_ech(config, public_name, sk, pk)?; + Ok(()) + } else { + panic!("not a client"); + } + } + + pub fn client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> { + if let Agent::Client(c) = &mut self.tls { + c.enable_ech(ech_config_list)?; + Ok(()) + } else { + panic!("not a client"); + } + } + + /// Get the active ECH configuration, which is empty if ECH is disabled. + pub fn ech_config(&self) -> &[u8] { + self.tls.ech_config() + } + + pub fn handshake( + &mut self, + now: Instant, + space: PacketNumberSpace, + data: Option<&[u8]>, + ) -> Res<&HandshakeState> { + let input = data.map(|d| { + qtrace!("Handshake record received {:0x?} ", d); + let epoch = match space { + PacketNumberSpace::Initial => TLS_EPOCH_INITIAL, + PacketNumberSpace::Handshake => TLS_EPOCH_HANDSHAKE, + // Our epoch progresses forward, but the TLS epoch is fixed to 3. 
+ PacketNumberSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA, + }; + Record { + ct: TLS_CT_HANDSHAKE, + epoch, + data: d.to_vec(), + } + }); + + match self.tls.handshake_raw(now, input) { + Ok(output) => { + self.buffer_records(output)?; + Ok(self.tls.state()) + } + Err(CryptoError::EchRetry(v)) => Err(Error::EchRetry(v)), + Err(e) => { + qinfo!("Handshake failed {:?}", e); + Err(match self.tls.alert() { + Some(a) => Error::CryptoAlert(*a), + _ => Error::CryptoError(e), + }) + } + } + } + + /// Enable 0-RTT and return `true` if it is enabled successfully. + pub fn enable_0rtt(&mut self, version: Version, role: Role) -> Res<bool> { + let info = self.tls.preinfo()?; + // `info.early_data()` returns false for a server, + // so use `early_data_cipher()` to tell if 0-RTT is enabled. + let cipher = info.early_data_cipher(); + if cipher.is_none() { + return Ok(false); + } + let (dir, secret) = match role { + Role::Client => ( + CryptoDxDirection::Write, + self.tls.write_secret(TLS_EPOCH_ZERO_RTT), + ), + Role::Server => ( + CryptoDxDirection::Read, + self.tls.read_secret(TLS_EPOCH_ZERO_RTT), + ), + }; + let secret = secret.ok_or(Error::InternalError)?; + self.states + .set_0rtt_keys(version, dir, &secret, cipher.unwrap()); + Ok(true) + } + + /// Lock in a compatible upgrade. + pub fn confirm_version(&mut self, confirmed: Version) { + self.states.confirm_version(self.version, confirmed); + self.version = confirmed; + } + + /// Returns true if new handshake keys were installed. 
+ pub fn install_keys(&mut self, role: Role) -> Res<bool> { + if !self.tls.state().is_final() { + let installed_hs = self.install_handshake_keys()?; + if role == Role::Server { + self.maybe_install_application_write_key(self.version)?; + } + Ok(installed_hs) + } else { + Ok(false) + } + } + + fn install_handshake_keys(&mut self) -> Res<bool> { + qtrace!([self], "Attempt to install handshake keys"); + let Some(write_secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) else { + // No keys is fine. + return Ok(false); + }; + let read_secret = self + .tls + .read_secret(TLS_EPOCH_HANDSHAKE) + .ok_or(Error::InternalError)?; + let cipher = match self.tls.info() { + None => self.tls.preinfo()?.cipher_suite(), + Some(info) => Some(info.cipher_suite()), + } + .ok_or(Error::InternalError)?; + self.states + .set_handshake_keys(self.version, &write_secret, &read_secret, cipher); + qdebug!([self], "Handshake keys installed"); + Ok(true) + } + + fn maybe_install_application_write_key(&mut self, version: Version) -> Res<()> { + qtrace!([self], "Attempt to install application write key"); + if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) { + self.states.set_application_write_key(version, secret)?; + qdebug!([self], "Application write key installed"); + } + Ok(()) + } + + pub fn install_application_keys(&mut self, version: Version, expire_0rtt: Instant) -> Res<()> { + self.maybe_install_application_write_key(version)?; + // The write key might have been installed earlier, but it should + // always be installed now. + debug_assert!(self.states.app_write.is_some()); + let read_secret = self + .tls + .read_secret(TLS_EPOCH_APPLICATION_DATA) + .ok_or(Error::InternalError)?; + self.states + .set_application_read_key(version, read_secret, expire_0rtt)?; + qdebug!([self], "application read keys installed"); + Ok(()) + } + + /// Buffer crypto records for sending. 
///
/// Each record is added to the crypto stream for the packet number space
/// that corresponds to the record's epoch.
///
/// # Errors
/// `Error::ProtocolViolation` if a record does not have the handshake
/// content type.  Records that precede it have already been buffered.
pub fn buffer_records(&mut self, records: RecordList) -> Res<()> {
    for r in records {
        if r.ct != TLS_CT_HANDSHAKE {
            return Err(Error::ProtocolViolation);
        }
        qtrace!([self], "Adding CRYPTO data {:?}", r);
        self.streams.send(PacketNumberSpace::from(r.epoch), &r.data);
    }
    Ok(())
}

/// Write CRYPTO frames for `space` into `builder`; this delegates
/// entirely to the crypto streams.
pub fn write_frame(
    &mut self,
    space: PacketNumberSpace,
    builder: &mut PacketBuilder,
    tokens: &mut Vec<RecoveryToken>,
    stats: &mut FrameStats,
) -> Res<()> {
    self.streams.write_frame(space, builder, tokens, stats)
}

/// Record that the CRYPTO frame described by `token` was acknowledged.
pub fn acked(&mut self, token: &CryptoRecoveryToken) {
    qinfo!(
        "Acked crypto frame space={} offset={} length={}",
        token.space,
        token.offset,
        token.length
    );
    self.streams.acked(token);
}

/// Record that the CRYPTO frame described by `token` was lost.
pub fn lost(&mut self, token: &CryptoRecoveryToken) {
    qinfo!(
        "Lost crypto frame space={} offset={} length={}",
        token.space,
        token.offset,
        token.length
    );
    self.streams.lost(token);
}

/// Mark any outstanding frames in the indicated space as "lost" so
/// that they can be sent again.
pub fn resend_unacked(&mut self, space: PacketNumberSpace) {
    self.streams.resend_unacked(space);
}

/// Discard state for a packet number space and return true
/// if something was discarded.
pub fn discard(&mut self, space: PacketNumberSpace) -> bool {
    // The stream state is always dropped; the return value reports only
    // what the key state did.
    self.streams.discard(space);
    self.states.discard(space)
}

/// Build a resumption token from the TLS resumption token, if the TLS
/// client has one.
///
/// The encoding is, in order: the 4-byte wire version, `rtt` as a varint,
/// the encoded transport parameters (length-prefixed), `new_token`
/// (length-prefixed, empty if `None`), and the TLS token.
///
/// # Panics
/// If this is a server.
pub fn create_resumption_token(
    &mut self,
    new_token: Option<&[u8]>,
    tps: &TransportParameters,
    version: Version,
    rtt: u64,
) -> Option<ResumptionToken> {
    if let Agent::Client(ref mut c) = self.tls {
        if let Some(ref t) = c.resumption_token() {
            qtrace!("TLS token {}", hex(t.as_ref()));
            let mut enc = Encoder::default();
            enc.encode_uint(4, version.wire_version());
            enc.encode_varint(rtt);
            enc.encode_vvec_with(|enc_inner| {
                tps.encode(enc_inner);
            });
            enc.encode_vvec(new_token.unwrap_or(&[]));
            enc.encode(t.as_ref());
            qinfo!("resumption token {}", hex_snip_middle(enc.as_ref()));
            // The combined token expires when the TLS token does.
            Some(ResumptionToken::new(enc.into(), t.expiration_time()))
        } else {
            None
        }
    } else {
        unreachable!("It is a server.");
    }
}

/// Whether the TLS client holds a resumption token.
///
/// # Panics
/// If this is a server.
pub fn has_resumption_token(&self) -> bool {
    if let Agent::Client(c) = &self.tls {
        c.has_resumption_token()
    } else {
        unreachable!("It is a server.");
    }
}
}

impl ::std::fmt::Display for Crypto {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "Crypto")
    }
}

/// The direction in which a set of keys is used.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CryptoDxDirection {
    Read,
    Write,
}

/// Packet protection state for a single direction and epoch.
#[derive(Debug)]
pub struct CryptoDxState {
    /// The QUIC version.
    version: Version,
    /// Whether packets protected with this state will be read or written.
    direction: CryptoDxDirection,
    /// The epoch of this crypto state.  This initially tracks TLS epochs
    /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application.
    /// But we don't need to keep that, and QUIC isn't limited in how
    /// many times keys can be updated, so we don't use `u16` for this.
    epoch: usize,
    aead: Aead,
    hpkey: HpKey,
    /// This tracks the range of packet numbers that have been seen.  This allows
    /// for verifying that packet numbers before a key update are strictly lower
    /// than packet numbers after a key update.
    used_pn: Range<PacketNumber>,
    /// This is the minimum packet number that is allowed.
    min_pn: PacketNumber,
    /// The total number of operations that are remaining before the keys
    /// become exhausted and can't be used any more.
    invocations: PacketNumber,
    fuzzing: bool,
}

impl CryptoDxState {
    /// Create protection state from `secret` for the given version,
    /// direction, epoch and cipher.
    ///
    /// # Panics
    /// If the AEAD or the header protection key cannot be created.
    #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range.
    pub fn new(
        version: Version,
        direction: CryptoDxDirection,
        epoch: Epoch,
        secret: &SymKey,
        cipher: Cipher,
        fuzzing: bool,
    ) -> Self {
        qinfo!(
            "Making {:?} {} CryptoDxState, v={:?} cipher={}",
            direction,
            epoch,
            version,
            cipher,
        );
        // The header protection label is the version's label prefix plus "hp".
        let hplabel = String::from(version.label_prefix()) + "hp";
        Self {
            version,
            direction,
            epoch: usize::from(epoch),
            aead: Aead::new(
                fuzzing,
                TLS_VERSION_1_3,
                cipher,
                secret,
                version.label_prefix(),
            )
            .unwrap(),
            hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, &hplabel).unwrap(),
            used_pn: 0..0,
            min_pn: 0,
            invocations: Self::limit(direction, cipher),
            fuzzing,
        }
    }

    /// Create Initial protection state for `version`, derived from the
    /// version's initial salt, `dcid` and `label`, always using
    /// `TLS_AES_128_GCM_SHA256`.
    ///
    /// # Panics
    /// If any of the key derivation steps fail.
    pub fn new_initial(
        version: Version,
        direction: CryptoDxDirection,
        label: &str,
        dcid: &[u8],
        fuzzing: bool,
    ) -> Self {
        qtrace!("new_initial {:?} {}", version, ConnectionIdRef::from(dcid));
        let salt = version.initial_salt();
        let cipher = TLS_AES_128_GCM_SHA256;
        let initial_secret = hkdf::extract(
            TLS_VERSION_1_3,
            cipher,
            Some(hkdf::import_key(TLS_VERSION_1_3, salt).as_ref().unwrap()),
            hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(),
        )
        .unwrap();

        let secret =
            hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap();

        Self::new(
            version,
            direction,
            TLS_EPOCH_INITIAL,
            &secret,
            cipher,
            fuzzing,
        )
    }

    /// Determine the confidentiality and integrity limits for the cipher.
fn limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber {
    match direction {
        // This uses the smaller limits for 2^16 byte packets
        // as we don't control incoming packet size.
        CryptoDxDirection::Read => match cipher {
            TLS_AES_128_GCM_SHA256 => 1 << 52,
            TLS_AES_256_GCM_SHA384 => PacketNumber::MAX,
            TLS_CHACHA20_POLY1305_SHA256 => 1 << 36,
            _ => unreachable!(),
        },
        // This uses the larger limits for 2^11 byte packets.
        CryptoDxDirection::Write => match cipher {
            TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => 1 << 28,
            TLS_CHACHA20_POLY1305_SHA256 => PacketNumber::MAX,
            _ => unreachable!(),
        },
    }
}

/// Consume one invocation of these keys.
/// Fails with `Error::KeysExhausted` when none remain.
fn invoked(&mut self) -> Res<()> {
    // In tests only: a pending override replaces the remaining count
    // before it is decremented, and is consumed by doing so.
    #[cfg(test)]
    OVERWRITE_INVOCATIONS.with(|v| {
        if let Some(i) = v.borrow_mut().take() {
            neqo_common::qwarn!("Setting {:?} invocations to {}", self.direction, i);
            self.invocations = i;
        }
    });
    self.invocations = self
        .invocations
        .checked_sub(1)
        .ok_or(Error::KeysExhausted)?;
    Ok(())
}

/// Determine whether we should initiate a key update.
pub fn should_update(&self) -> bool {
    // There is no point in updating read keys as the limit is global.
    debug_assert_eq!(self.direction, CryptoDxDirection::Write);
    self.invocations <= UPDATE_WRITE_KEYS_AT
}

/// Create the state for the next epoch from `next_secret`.
///
/// The new state starts its packet number range at this state's next
/// packet number and reuses this state's header protection key.
///
/// # Panics
/// If the AEAD cannot be created.
pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self {
    let pn = self.next_pn();
    // We count invocations of each write key just for that key, but all
    // read attempts count toward a single limit, so the remaining
    // count carries over for read keys.
    // This doesn't count use of Handshake keys.
    let invocations = if self.direction == CryptoDxDirection::Read {
        self.invocations
    } else {
        Self::limit(CryptoDxDirection::Write, cipher)
    };
    Self {
        version: self.version,
        direction: self.direction,
        epoch: self.epoch + 1,
        aead: Aead::new(
            self.fuzzing,
            TLS_VERSION_1_3,
            cipher,
            next_secret,
            self.version.label_prefix(),
        )
        .unwrap(),
        hpkey: self.hpkey.clone(),
        used_pn: pn..pn,
        min_pn: pn,
        invocations,
        fuzzing: self.fuzzing,
    }
}

/// The QUIC version this state was created for.
#[must_use]
pub fn version(&self) -> Version {
    self.version
}

/// The key phase for this epoch: `false` for odd epochs, `true` for even.
#[must_use]
pub fn key_phase(&self) -> bool {
    // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ...
    self.epoch & 1 != 1
}

/// This is a continuation of a previous, so adjust the range accordingly.
/// Fail if the two ranges overlap.  The directions are expected to match;
/// this is only checked by a debug assertion.
pub fn continuation(&mut self, prev: &Self) -> Res<()> {
    debug_assert_eq!(self.direction, prev.direction);
    let next = prev.next_pn();
    // The minimum is raised before the overlap check, so it is updated
    // even when an error is returned.
    self.min_pn = next;
    if self.used_pn.is_empty() {
        self.used_pn = next..next;
        Ok(())
    } else if prev.used_pn.end > self.used_pn.start {
        qdebug!(
            [self],
            "Found packet with too new packet number {} > {}, compared to {}",
            self.used_pn.start,
            prev.used_pn.end,
            prev,
        );
        Err(Error::PacketNumberOverlap)
    } else {
        self.used_pn.start = next;
        Ok(())
    }
}

/// Mark a packet number as used.  If this is too low, reject it.
/// Note that this won't catch a value that is too high if packets protected with
/// old keys are received after a key update.  That needs to be caught elsewhere.
pub fn used(&mut self, pn: PacketNumber) -> Res<()> {
    if pn < self.min_pn {
        qdebug!(
            [self],
            "Found packet with too old packet number: {} < {}",
            pn,
            self.min_pn
        );
        return Err(Error::PacketNumberOverlap);
    }
    // An empty range means this is the first packet number recorded.
    if self.used_pn.start == self.used_pn.end {
        self.used_pn.start = pn;
    }
    // The end of the range is exclusive and never moves backward.
    self.used_pn.end = max(pn + 1, self.used_pn.end);
    Ok(())
}

#[must_use]
pub fn needs_update(&self) -> bool {
    // Only initiate a key update if we have processed exactly one packet
    // and we are in an epoch greater than 3.
    self.used_pn.start + 1 == self.used_pn.end
        && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA)
}

/// Whether a key update is permitted given the largest acknowledged
/// packet number, which must fall inside the range used with these keys.
#[must_use]
pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool {
    if let Some(la) = largest_acknowledged {
        self.used_pn.contains(&la)
    } else {
        // If we haven't received any acknowledgments, it's OK to update
        // the first application data epoch.
        self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA)
    }
}

/// Compute the header protection mask for `sample`.
pub fn compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>> {
    let mask = self.hpkey.mask(sample)?;
    qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask));
    Ok(mask)
}

/// The exclusive end of the range of packet numbers used so far.
#[must_use]
pub fn next_pn(&self) -> PacketNumber {
    self.used_pn.end
}

/// Protect `body` with `hdr` as associated data and return the ciphertext.
///
/// # Errors
/// `Error::InternalError` if `body` is longer than 2048 bytes,
/// `Error::KeysExhausted` if no invocations remain, or any error from the
/// AEAD or from recording `pn` as used.
pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
    debug_assert_eq!(self.direction, CryptoDxDirection::Write);
    qtrace!(
        [self],
        "encrypt pn={} hdr={} body={}",
        pn,
        hex(hdr),
        hex(body)
    );
    // The numbers in `Self::limit` assume a maximum packet size of 2^11.
    if body.len() > 2048 {
        debug_assert!(false);
        return Err(Error::InternalError);
    }
    // An invocation is consumed before the AEAD runs.
    self.invoked()?;

    let size = body.len() + MAX_AUTH_TAG;
    let mut out = vec![0; size];
    let res = self.aead.encrypt(pn, hdr, body, &mut out)?;

    qtrace!([self], "encrypt ct={}", hex(res));
    debug_assert_eq!(pn, self.next_pn());
    self.used(pn)?;
    Ok(res.to_vec())
}

/// The AEAD expansion, as reported by the AEAD.
#[must_use]
pub fn expansion(&self) -> usize {
    self.aead.expansion()
}

/// Remove protection from `body` with `hdr` as associated data and return
/// the plaintext.
///
/// # Errors
/// `Error::KeysExhausted` if no invocations remain, or any error from the
/// AEAD or from recording `pn` as used.
pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
    debug_assert_eq!(self.direction, CryptoDxDirection::Read);
    qtrace!(
        [self],
        "decrypt pn={} hdr={} body={}",
        pn,
        hex(hdr),
        hex(body)
    );
    // An invocation is consumed even if decryption then fails.
    self.invoked()?;
    let mut out = vec![0; body.len()];
    let res = self.aead.decrypt(pn, hdr, body, &mut out)?;
    // The packet number is only recorded after successful decryption.
    self.used(pn)?;
    Ok(res.to_vec())
}

/// A write-direction Initial state with fixed inputs, for tests.
#[cfg(all(test, not(feature = "fuzzing")))]
pub(crate) fn test_default() -> Self {
    // This matches the value in packet.rs
    const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
    Self::new_initial(
        Version::default(),
        CryptoDxDirection::Write,
        "server in",
        CLIENT_CID,
        false,
    )
}

/// Get the amount of extra padding packets protected with this profile need.
/// This is the difference between the size of the header protection sample
/// and the AEAD expansion.
pub fn extra_padding(&self) -> usize {
    // Saturates at zero when the expansion is at least the sample size.
    self.hpkey
        .sample_size()
        .saturating_sub(self.aead.expansion())
}
}

impl std::fmt::Display for CryptoDxState {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "epoch {} {:?}", self.epoch, self.direction)
    }
}

/// A pair of protection states, one for each direction.
#[derive(Debug)]
pub struct CryptoState {
    tx: CryptoDxState,
    rx: CryptoDxState,
}

/// Indexing by direction selects `rx` for `Read` and `tx` for `Write`.
impl Index<CryptoDxDirection> for CryptoState {
    type Output = CryptoDxState;

    fn index(&self, dir: CryptoDxDirection) -> &Self::Output {
        match dir {
            CryptoDxDirection::Read => &self.rx,
            CryptoDxDirection::Write => &self.tx,
        }
    }
}

impl IndexMut<CryptoDxDirection> for CryptoState {
    fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output {
        match dir {
            CryptoDxDirection::Read => &mut self.rx,
            CryptoDxDirection::Write => &mut self.tx,
        }
    }
}

/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys.
/// This includes the secret needed to generate the next set of keys.
#[derive(Debug)]
pub(crate) struct CryptoDxAppData {
    dx: CryptoDxState,
    cipher: Cipher,
    // Not the secret used to create `self.dx`, but the one needed for the next iteration.
    next_secret: SymKey,
    fuzzing: bool,
}

impl CryptoDxAppData {
    /// Create application data keys from `secret` and derive the secret
    /// for the following key update.
    ///
    /// # Errors
    /// If deriving the next secret fails.
    pub fn new(
        version: Version,
        dir: CryptoDxDirection,
        secret: SymKey,
        cipher: Cipher,
        fuzzing: bool,
    ) -> Res<Self> {
        Ok(Self {
            dx: CryptoDxState::new(
                version,
                dir,
                TLS_EPOCH_APPLICATION_DATA,
                &secret,
                cipher,
                fuzzing,
            ),
            cipher,
            next_secret: Self::update_secret(cipher, &secret)?,
            fuzzing,
        })
    }

    /// Derive the secret for the next key update from `secret` using the
    /// "quic ku" label.
    fn update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey> {
        let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?;
        Ok(next)
    }

    /// Produce the keys for the next epoch.
    ///
    /// # Errors
    /// `Error::KeysExhausted` if the epoch counter cannot be incremented,
    /// or any error from deriving the following secret.
    pub fn next(&self) -> Res<Self> {
        if self.dx.epoch == usize::max_value() {
            // Guard against too many key updates.
            return Err(Error::KeysExhausted);
        }
        // The new state is keyed with the stored `next_secret`; the secret
        // derived here is the one kept for the update after that.
        let next_secret = Self::update_secret(self.cipher, &self.next_secret)?;
        Ok(Self {
            dx: self.dx.next(&self.next_secret, self.cipher),
            cipher: self.cipher,
            next_secret,
            fuzzing: self.fuzzing,
        })
    }

    /// The epoch of the wrapped state.
    pub fn epoch(&self) -> usize {
        self.dx.epoch
    }
}

/// The key spaces for which packet protection keys can exist.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum CryptoSpace {
    Initial,
    ZeroRtt,
    Handshake,
    ApplicationData,
}

/// All of the keying material needed for a connection.
///
/// Note that the methods on this struct take a version but those are only ever
/// used for Initial keys; a version has been selected at the time we need to
/// get other keys, so those have fixed versions.
#[derive(Debug, Default)]
pub struct CryptoStates {
    initials: HashMap<Version, CryptoState>,
    handshake: Option<CryptoState>,
    zero_rtt: Option<CryptoDxState>, // One direction only!
    cipher: Cipher,
    app_write: Option<CryptoDxAppData>,
    app_read: Option<CryptoDxAppData>,
    app_read_next: Option<CryptoDxAppData>,
    // If this is set, then we have noticed a genuine update.
    // Once this time passes, we should switch in new keys.
    read_update_time: Option<Instant>,
    fuzzing: bool,
}

impl CryptoStates {
    /// Select a `CryptoDxState` and `CryptoSpace` for the given `PacketNumberSpace`.
    /// This selects 0-RTT keys for `PacketNumberSpace::ApplicationData` if 1-RTT keys are
    /// not yet available.
+ pub fn select_tx_mut( + &mut self, + version: Version, + space: PacketNumberSpace, + ) -> Option<(CryptoSpace, &mut CryptoDxState)> { + match space { + PacketNumberSpace::Initial => self + .tx_mut(version, CryptoSpace::Initial) + .map(|dx| (CryptoSpace::Initial, dx)), + PacketNumberSpace::Handshake => self + .tx_mut(version, CryptoSpace::Handshake) + .map(|dx| (CryptoSpace::Handshake, dx)), + PacketNumberSpace::ApplicationData => { + if let Some(app) = self.app_write.as_mut() { + Some((CryptoSpace::ApplicationData, &mut app.dx)) + } else { + self.zero_rtt.as_mut().map(|dx| (CryptoSpace::ZeroRtt, dx)) + } + } + } + } + + pub fn tx_mut<'a>( + &'a mut self, + version: Version, + cspace: CryptoSpace, + ) -> Option<&'a mut CryptoDxState> { + let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx); + match cspace { + CryptoSpace::Initial => tx(self.initials.get_mut(&version)), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_mut() + .filter(|z| z.direction == CryptoDxDirection::Write), + CryptoSpace::Handshake => tx(self.handshake.as_mut()), + CryptoSpace::ApplicationData => self.app_write.as_mut().map(|app| &mut app.dx), + } + } + + pub fn tx<'a>(&'a self, version: Version, cspace: CryptoSpace) -> Option<&'a CryptoDxState> { + let tx = |k: Option<&'a CryptoState>| k.map(|dx| &dx.tx); + match cspace { + CryptoSpace::Initial => tx(self.initials.get(&version)), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_ref() + .filter(|z| z.direction == CryptoDxDirection::Write), + CryptoSpace::Handshake => tx(self.handshake.as_ref()), + CryptoSpace::ApplicationData => self.app_write.as_ref().map(|app| &app.dx), + } + } + + pub fn select_tx( + &self, + version: Version, + space: PacketNumberSpace, + ) -> Option<(CryptoSpace, &CryptoDxState)> { + match space { + PacketNumberSpace::Initial => self + .tx(version, CryptoSpace::Initial) + .map(|dx| (CryptoSpace::Initial, dx)), + PacketNumberSpace::Handshake => self + .tx(version, CryptoSpace::Handshake) + .map(|dx| 
(CryptoSpace::Handshake, dx)), + PacketNumberSpace::ApplicationData => { + if let Some(app) = self.app_write.as_ref() { + Some((CryptoSpace::ApplicationData, &app.dx)) + } else { + self.zero_rtt.as_ref().map(|dx| (CryptoSpace::ZeroRtt, dx)) + } + } + } + } + + pub fn rx_hp(&mut self, version: Version, cspace: CryptoSpace) -> Option<&mut CryptoDxState> { + if let CryptoSpace::ApplicationData = cspace { + self.app_read.as_mut().map(|ar| &mut ar.dx) + } else { + self.rx(version, cspace, false) + } + } + + pub fn rx<'a>( + &'a mut self, + version: Version, + cspace: CryptoSpace, + key_phase: bool, + ) -> Option<&'a mut CryptoDxState> { + let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx); + match cspace { + CryptoSpace::Initial => rx(self.initials.get_mut(&version)), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_mut() + .filter(|z| z.direction == CryptoDxDirection::Read), + CryptoSpace::Handshake => rx(self.handshake.as_mut()), + CryptoSpace::ApplicationData => { + let f = |a: Option<&'a mut CryptoDxAppData>| { + a.filter(|ar| ar.dx.key_phase() == key_phase) + }; + // XOR to reduce the leakage about which key is chosen. + f(self.app_read.as_mut()) + .xor(f(self.app_read_next.as_mut())) + .map(|ar| &mut ar.dx) + } + } + } + + /// Whether keys for processing packets in the indicated space are pending. + /// This allows the caller to determine whether to save a packet for later + /// when keys are not available. + /// NOTE: 0-RTT keys are not considered here. The expectation is that a + /// server will have to save 0-RTT packets in a different place. Though it + /// is possible to attribute 0-RTT packets to an existing connection if there + /// is a multi-packet Initial, that is an unusual circumstance, so we + /// don't do caching for that in those places that call this function. 
+ pub fn rx_pending(&self, space: CryptoSpace) -> bool { + match space { + CryptoSpace::Initial | CryptoSpace::ZeroRtt => false, + CryptoSpace::Handshake => self.handshake.is_none() && !self.initials.is_empty(), + CryptoSpace::ApplicationData => self.app_read.is_none(), + } + } + + /// Create the initial crypto state. + /// Note that the version here can change and that's OK. + pub fn init<'v, V>(&mut self, versions: V, role: Role, dcid: &[u8]) + where + V: IntoIterator<Item = &'v Version>, + { + const CLIENT_INITIAL_LABEL: &str = "client in"; + const SERVER_INITIAL_LABEL: &str = "server in"; + + let (write, read) = match role { + Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL), + Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL), + }; + + for v in versions { + qinfo!( + [self], + "Creating initial cipher state v={:?}, role={:?} dcid={}", + v, + role, + hex(dcid) + ); + + let mut initial = CryptoState { + tx: CryptoDxState::new_initial( + *v, + CryptoDxDirection::Write, + write, + dcid, + self.fuzzing, + ), + rx: CryptoDxState::new_initial( + *v, + CryptoDxDirection::Read, + read, + dcid, + self.fuzzing, + ), + }; + if let Some(prev) = self.initials.get(v) { + qinfo!( + [self], + "Continue packet numbers for initial after retry (write is {:?})", + prev.rx.used_pn, + ); + initial.tx.continuation(&prev.tx).unwrap(); + } + self.initials.insert(*v, initial); + } + } + + /// At a server, we can be more targeted in initializing. + /// Initialize on demand: either to decrypt Initial packets that we receive + /// or after a version has been selected. + /// This is maybe slightly inefficient in the first case, because we might + /// not need the send keys if the packet is subsequently discarded, but + /// the overall effort is small enough to write off. 
+ pub fn init_server(&mut self, version: Version, dcid: &[u8]) { + if !self.initials.contains_key(&version) { + self.init(&[version], Role::Server, dcid); + } + } + + pub fn confirm_version(&mut self, orig: Version, confirmed: Version) { + if orig != confirmed { + // This part where the old data is removed and then re-added is to + // appease the borrow checker. + // Note that on the server, we might not have initials for |orig| if it + // was configured for |orig| and only |confirmed| Initial packets arrived. + if let Some(prev) = self.initials.remove(&orig) { + let next = self.initials.get_mut(&confirmed).unwrap(); + next.tx.continuation(&prev.tx).unwrap(); + self.initials.insert(orig, prev); + } + } + } + + pub fn set_0rtt_keys( + &mut self, + version: Version, + dir: CryptoDxDirection, + secret: &SymKey, + cipher: Cipher, + ) { + qtrace!([self], "install 0-RTT keys"); + self.zero_rtt = Some(CryptoDxState::new( + version, + dir, + TLS_EPOCH_ZERO_RTT, + secret, + cipher, + self.fuzzing, + )); + } + + /// Discard keys and return true if that happened. 
+ pub fn discard(&mut self, space: PacketNumberSpace) -> bool { + match space { + PacketNumberSpace::Initial => { + let empty = self.initials.is_empty(); + self.initials.clear(); + !empty + } + PacketNumberSpace::Handshake => self.handshake.take().is_some(), + PacketNumberSpace::ApplicationData => panic!("Can't drop application data keys"), + } + } + + pub fn discard_0rtt_keys(&mut self) { + qtrace!([self], "discard 0-RTT keys"); + assert!( + self.app_read.is_none(), + "Can't discard 0-RTT after setting application keys" + ); + self.zero_rtt = None; + } + + pub fn set_handshake_keys( + &mut self, + version: Version, + write_secret: &SymKey, + read_secret: &SymKey, + cipher: Cipher, + ) { + self.cipher = cipher; + self.handshake = Some(CryptoState { + tx: CryptoDxState::new( + version, + CryptoDxDirection::Write, + TLS_EPOCH_HANDSHAKE, + write_secret, + cipher, + self.fuzzing, + ), + rx: CryptoDxState::new( + version, + CryptoDxDirection::Read, + TLS_EPOCH_HANDSHAKE, + read_secret, + cipher, + self.fuzzing, + ), + }); + } + + pub fn set_application_write_key(&mut self, version: Version, secret: SymKey) -> Res<()> { + debug_assert!(self.app_write.is_none()); + debug_assert_ne!(self.cipher, 0); + let mut app = CryptoDxAppData::new( + version, + CryptoDxDirection::Write, + secret, + self.cipher, + self.fuzzing, + )?; + if let Some(z) = &self.zero_rtt { + if z.direction == CryptoDxDirection::Write { + app.dx.continuation(z)?; + } + } + self.zero_rtt = None; + self.app_write = Some(app); + Ok(()) + } + + pub fn set_application_read_key( + &mut self, + version: Version, + secret: SymKey, + expire_0rtt: Instant, + ) -> Res<()> { + debug_assert!(self.app_write.is_some(), "should have write keys installed"); + debug_assert!(self.app_read.is_none()); + let mut app = CryptoDxAppData::new( + version, + CryptoDxDirection::Read, + secret, + self.cipher, + self.fuzzing, + )?; + if let Some(z) = &self.zero_rtt { + if z.direction == CryptoDxDirection::Read { + 
app.dx.continuation(z)?; + } + self.read_update_time = Some(expire_0rtt); + } + self.app_read_next = Some(app.next()?); + self.app_read = Some(app); + Ok(()) + } + + /// Update the write keys. + pub fn initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()> { + // Only update if we are able to. We can only do this if we have + // received an acknowledgement for a packet in the current phase. + // Also, skip this if we are waiting for read keys on the existing + // key update to be rolled over. + let write = &self.app_write.as_ref().unwrap().dx; + if write.can_update(largest_acknowledged) && self.read_update_time.is_none() { + // This call additionally checks that we don't advance to the next + // epoch while a key update is in progress. + if self.maybe_update_write()? { + Ok(()) + } else { + qdebug!([self], "Write keys already updated"); + Err(Error::KeyUpdateBlocked) + } + } else { + qdebug!([self], "Waiting for ACK or blocked on read key timer"); + Err(Error::KeyUpdateBlocked) + } + } + + /// Try to update, and return true if it happened. + fn maybe_update_write(&mut self) -> Res<bool> { + // Update write keys. But only do so if the write keys are not already + // ahead of the read keys. If we initiated the key update, the write keys + // will already be ahead. + debug_assert!(self.read_update_time.is_none()); + let write = &self.app_write.as_ref().unwrap(); + let read = &self.app_read.as_ref().unwrap(); + if write.epoch() == read.epoch() { + qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1); + self.app_write = Some(write.next()?); + Ok(true) + } else { + Ok(false) + } + } + + /// Check whether write keys are close to running out of invocations. + /// If that is close, update them if possible. Failing to update at + /// this stage is cause for a fatal error. 
+ pub fn auto_update(&mut self) -> Res<()> { + if let Some(app_write) = self.app_write.as_ref() { + if app_write.dx.should_update() { + qinfo!([self], "Initiating automatic key update"); + if !self.maybe_update_write()? { + return Err(Error::KeysExhausted); + } + } + } + Ok(()) + } + + fn has_0rtt_read(&self) -> bool { + self.zero_rtt + .as_ref() + .filter(|z| z.direction == CryptoDxDirection::Read) + .is_some() + } + + /// Prepare to update read keys. This doesn't happen immediately as + /// we want to ensure that we can continue to receive any delayed + /// packets that use the old keys. So we just set a timer. + pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> { + qtrace!([self], "Key update received"); + // If we received a key update, then we assume that the peer has + // acknowledged a packet we sent in this epoch. It's OK to do that + // because they aren't allowed to update without first having received + // something from us. If the ACK isn't in the packet that triggered this + // key update, it must be in some other packet they have sent. + _ = self.maybe_update_write()?; + + // We shouldn't have 0-RTT keys at this point, but if we do, dump them. + debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read()); + if self.has_0rtt_read() { + self.zero_rtt = None; + } + self.read_update_time = Some(expiration); + Ok(()) + } + + #[must_use] + pub fn update_time(&self) -> Option<Instant> { + self.read_update_time + } + + /// Check if time has passed for updating key update parameters. + /// If it has, then swap keys over and allow more key updates to be initiated. + /// This is also used to discard 0-RTT read keys at the server in the same way. + pub fn check_key_update(&mut self, now: Instant) -> Res<()> { + if let Some(expiry) = self.read_update_time { + // If enough time has passed, then install new keys and clear the timer. 
+ if now >= expiry { + if self.has_0rtt_read() { + qtrace!([self], "Discarding 0-RTT keys"); + self.zero_rtt = None; + } else { + qtrace!([self], "Rotating read keys"); + mem::swap(&mut self.app_read, &mut self.app_read_next); + self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?); + } + self.read_update_time = None; + } + } + Ok(()) + } + + /// Get the current/highest epoch. This returns (write, read) epochs. + #[cfg(test)] + pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) { + let to_epoch = |app: &Option<CryptoDxAppData>| app.as_ref().map(|a| a.dx.epoch); + (to_epoch(&self.app_write), to_epoch(&self.app_read)) + } + + /// While we are awaiting the completion of a key update, we might receive + /// valid packets that are protected with old keys. We need to ensure that + /// these don't carry packet numbers higher than those in packets protected + /// with the newer keys. To ensure that, this is called after every decryption. + pub fn check_pn_overlap(&mut self) -> Res<()> { + // We only need to do the check while we are waiting for read keys to be updated. + if self.read_update_time.is_some() { + qtrace!([self], "Checking for PN overlap"); + let next_dx = &mut self.app_read_next.as_mut().unwrap().dx; + next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?; + } + Ok(()) + } + + /// Make some state for removing protection in tests. 
+ #[cfg(not(feature = "fuzzing"))] + #[cfg(test)] + pub(crate) fn test_default() -> Self { + let read = |epoch| { + let mut dx = CryptoDxState::test_default(); + dx.direction = CryptoDxDirection::Read; + dx.epoch = epoch; + dx + }; + let app_read = |epoch| CryptoDxAppData { + dx: read(epoch), + cipher: TLS_AES_128_GCM_SHA256, + next_secret: hkdf::import_key(TLS_VERSION_1_3, &[0xaa; 32]).unwrap(), + fuzzing: false, + }; + let mut initials = HashMap::new(); + initials.insert( + Version::Version1, + CryptoState { + tx: CryptoDxState::test_default(), + rx: read(0), + }, + ); + Self { + initials, + handshake: None, + zero_rtt: None, + cipher: TLS_AES_128_GCM_SHA256, + // This isn't used, but the epoch is read to check for a key update. + app_write: Some(app_read(3)), + app_read: Some(app_read(3)), + app_read_next: Some(app_read(4)), + read_update_time: None, + fuzzing: false, + } + } + + #[cfg(all(not(feature = "fuzzing"), test))] + pub(crate) fn test_chacha() -> Self { + const SECRET: &[u8] = &[ + 0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad, + 0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3, + 0x0f, 0x21, 0x63, 0x2b, + ]; + let secret = hkdf::import_key(TLS_VERSION_1_3, SECRET).unwrap(); + let app_read = |epoch| CryptoDxAppData { + dx: CryptoDxState { + version: Version::Version1, + direction: CryptoDxDirection::Read, + epoch, + aead: Aead::new( + false, + TLS_VERSION_1_3, + TLS_CHACHA20_POLY1305_SHA256, + &secret, + "quic ", // This is a v1 test so hard-code the label. 
+ ) + .unwrap(), + hpkey: HpKey::extract( + TLS_VERSION_1_3, + TLS_CHACHA20_POLY1305_SHA256, + &secret, + "quic hp", + ) + .unwrap(), + used_pn: 0..645_971_972, + min_pn: 0, + invocations: 10, + fuzzing: false, + }, + cipher: TLS_CHACHA20_POLY1305_SHA256, + next_secret: secret.clone(), + fuzzing: false, + }; + Self { + initials: HashMap::new(), + handshake: None, + zero_rtt: None, + cipher: TLS_CHACHA20_POLY1305_SHA256, + app_write: Some(app_read(3)), + app_read: Some(app_read(3)), + app_read_next: Some(app_read(4)), + read_update_time: None, + fuzzing: false, + } + } +} + +impl std::fmt::Display for CryptoStates { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CryptoStates") + } +} + +#[derive(Debug, Default)] +pub struct CryptoStream { + tx: TxBuffer, + rx: RxStreamOrderer, +} + +#[derive(Debug)] +#[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408 +pub enum CryptoStreams { + Initial { + initial: CryptoStream, + handshake: CryptoStream, + application: CryptoStream, + }, + Handshake { + handshake: CryptoStream, + application: CryptoStream, + }, + ApplicationData { + application: CryptoStream, + }, +} + +impl CryptoStreams { + /// Keep around 64k if a server wants to push excess data at us. + const BUFFER_LIMIT: u64 = 65536; + + pub fn discard(&mut self, space: PacketNumberSpace) { + match space { + PacketNumberSpace::Initial => { + if let Self::Initial { + handshake, + application, + .. + } = self + { + *self = Self::Handshake { + handshake: mem::take(handshake), + application: mem::take(application), + }; + } + } + PacketNumberSpace::Handshake => { + if let Self::Handshake { application, .. } = self { + *self = Self::ApplicationData { + application: mem::take(application), + }; + } else if matches!(self, Self::Initial { .. 
}) { + panic!("Discarding handshake before initial discarded"); + } + } + PacketNumberSpace::ApplicationData => { + panic!("Discarding application data crypto streams") + } + } + } + + pub fn send(&mut self, space: PacketNumberSpace, data: &[u8]) { + self.get_mut(space).unwrap().tx.send(data); + } + + pub fn inbound_frame(&mut self, space: PacketNumberSpace, offset: u64, data: &[u8]) -> Res<()> { + let rx = &mut self.get_mut(space).unwrap().rx; + rx.inbound_frame(offset, data); + if rx.received() - rx.retired() <= Self::BUFFER_LIMIT { + Ok(()) + } else { + Err(Error::CryptoBufferExceeded) + } + } + + pub fn data_ready(&self, space: PacketNumberSpace) -> bool { + self.get(space).map_or(false, |cs| cs.rx.data_ready()) + } + + pub fn read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec<u8>) -> usize { + self.get_mut(space).unwrap().rx.read_to_end(buf) + } + + pub fn acked(&mut self, token: &CryptoRecoveryToken) { + self.get_mut(token.space) + .unwrap() + .tx + .mark_as_acked(token.offset, token.length); + } + + pub fn lost(&mut self, token: &CryptoRecoveryToken) { + // See BZ 1624800, ignore lost packets in spaces we've dropped keys + if let Some(cs) = self.get_mut(token.space) { + cs.tx.mark_as_lost(token.offset, token.length); + } + } + + /// Resend any Initial or Handshake CRYPTO frames that might be outstanding. + /// This can help speed up handshake times. 
+ pub fn resend_unacked(&mut self, space: PacketNumberSpace) { + if space != PacketNumberSpace::ApplicationData { + if let Some(cs) = self.get_mut(space) { + cs.tx.unmark_sent(); + } + } + } + + fn get(&self, space: PacketNumberSpace) -> Option<&CryptoStream> { + let (initial, hs, app) = match self { + Self::Initial { + initial, + handshake, + application, + } => (Some(initial), Some(handshake), Some(application)), + Self::Handshake { + handshake, + application, + } => (None, Some(handshake), Some(application)), + Self::ApplicationData { application } => (None, None, Some(application)), + }; + match space { + PacketNumberSpace::Initial => initial, + PacketNumberSpace::Handshake => hs, + PacketNumberSpace::ApplicationData => app, + } + } + + fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut CryptoStream> { + let (initial, hs, app) = match self { + Self::Initial { + initial, + handshake, + application, + } => (Some(initial), Some(handshake), Some(application)), + Self::Handshake { + handshake, + application, + } => (None, Some(handshake), Some(application)), + Self::ApplicationData { application } => (None, None, Some(application)), + }; + match space { + PacketNumberSpace::Initial => initial, + PacketNumberSpace::Handshake => hs, + PacketNumberSpace::ApplicationData => app, + } + } + + pub fn write_frame( + &mut self, + space: PacketNumberSpace, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> Res<()> { + let cs = self.get_mut(space).unwrap(); + if let Some((offset, data)) = cs.tx.next_bytes() { + let mut header_len = 1 + Encoder::varint_len(offset) + 1; + + // Don't bother if there isn't room for the header and some data. 
+ if builder.remaining() < header_len + 1 { + return Ok(()); + } + // Calculate length of data based on the minimum of: + // - available data + // - remaining space, less the header, which counts only one byte for the length at + // first to avoid underestimating length + let length = min(data.len(), builder.remaining() - header_len); + header_len += Encoder::varint_len(u64::try_from(length).unwrap()) - 1; + let length = min(data.len(), builder.remaining() - header_len); + + builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO); + builder.encode_varint(offset); + builder.encode_vvec(&data[..length]); + + cs.tx.mark_as_sent(offset, length); + + qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length); + tokens.push(RecoveryToken::Crypto(CryptoRecoveryToken { + space, + offset, + length, + })); + stats.crypto += 1; + } + Ok(()) + } +} + +impl Default for CryptoStreams { + fn default() -> Self { + Self::Initial { + initial: CryptoStream::default(), + handshake: CryptoStream::default(), + application: CryptoStream::default(), + } + } +} + +#[derive(Debug, Clone)] +pub struct CryptoRecoveryToken { + space: PacketNumberSpace, + offset: u64, + length: usize, +} diff --git a/third_party/rust/neqo-transport/src/events.rs b/third_party/rust/neqo-transport/src/events.rs new file mode 100644 index 0000000000..88a85250ee --- /dev/null +++ b/third_party/rust/neqo-transport/src/events.rs @@ -0,0 +1,321 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Collecting a list of events relevant to whoever is using the Connection. 
+ +use std::{cell::RefCell, collections::VecDeque, rc::Rc}; + +use neqo_common::event::Provider as EventProvider; +use neqo_crypto::ResumptionToken; + +use crate::{ + connection::State, + quic_datagrams::DatagramTracking, + stream_id::{StreamId, StreamType}, + AppError, Stats, +}; + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)] +pub enum OutgoingDatagramOutcome { + DroppedTooBig, + DroppedQueueFull, + Lost, + Acked, +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)] +pub enum ConnectionEvent { + /// Cert authentication needed + AuthenticationNeeded, + /// Encrypted client hello fallback occurred. The certificate for the + /// public name needs to be authenticated. + EchFallbackAuthenticationNeeded { + public_name: String, + }, + /// A new uni (read) or bidi stream has been opened by the peer. + NewStream { + stream_id: StreamId, + }, + /// Space available in the buffer for an application write to succeed. + SendStreamWritable { + stream_id: StreamId, + }, + /// New bytes available for reading. + RecvStreamReadable { + stream_id: StreamId, + }, + /// Peer reset the stream. + RecvStreamReset { + stream_id: StreamId, + app_error: AppError, + }, + /// Peer has sent STOP_SENDING + SendStreamStopSending { + stream_id: StreamId, + app_error: AppError, + }, + /// Peer has acked everything sent on the stream. + SendStreamComplete { + stream_id: StreamId, + }, + /// Peer increased MAX_STREAMS + SendStreamCreatable { + stream_type: StreamType, + }, + /// Connection state change. + StateChange(State), + /// The server rejected 0-RTT. + /// This event invalidates all state in streams that has been created. + /// Any data written to streams needs to be written again. 
+ ZeroRttRejected, + ResumptionToken(ResumptionToken), + Datagram(Vec<u8>), + OutgoingDatagramOutcome { + id: u64, + outcome: OutgoingDatagramOutcome, + }, + IncomingDatagramDropped, +} + +#[derive(Debug, Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct ConnectionEvents { + events: Rc<RefCell<VecDeque<ConnectionEvent>>>, +} + +impl ConnectionEvents { + pub fn authentication_needed(&self) { + self.insert(ConnectionEvent::AuthenticationNeeded); + } + + pub fn ech_fallback_authentication_needed(&self, public_name: String) { + self.insert(ConnectionEvent::EchFallbackAuthenticationNeeded { public_name }); + } + + pub fn new_stream(&self, stream_id: StreamId) { + self.insert(ConnectionEvent::NewStream { stream_id }); + } + + pub fn recv_stream_readable(&self, stream_id: StreamId) { + self.insert(ConnectionEvent::RecvStreamReadable { stream_id }); + } + + pub fn recv_stream_reset(&self, stream_id: StreamId, app_error: AppError) { + // If reset, no longer readable. + self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64())); + + self.insert(ConnectionEvent::RecvStreamReset { + stream_id, + app_error, + }); + } + + pub fn send_stream_writable(&self, stream_id: StreamId) { + self.insert(ConnectionEvent::SendStreamWritable { stream_id }); + } + + pub fn send_stream_stop_sending(&self, stream_id: StreamId, app_error: AppError) { + // If stopped, no longer writable. + self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id)); + + self.insert(ConnectionEvent::SendStreamStopSending { + stream_id, + app_error, + }); + } + + pub fn send_stream_complete(&self, stream_id: StreamId) { + self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id)); + + self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. 
} if *x == stream_id.as_u64())); + + self.insert(ConnectionEvent::SendStreamComplete { stream_id }); + } + + pub fn send_stream_creatable(&self, stream_type: StreamType) { + self.insert(ConnectionEvent::SendStreamCreatable { stream_type }); + } + + pub fn connection_state_change(&self, state: State) { + // If closing, existing events no longer relevant. + match state { + State::Closing { .. } | State::Closed(_) => self.events.borrow_mut().clear(), + _ => (), + } + self.insert(ConnectionEvent::StateChange(state)); + } + + pub fn client_resumption_token(&self, token: ResumptionToken) { + self.insert(ConnectionEvent::ResumptionToken(token)); + } + + pub fn client_0rtt_rejected(&self) { + // If 0rtt rejected, must start over and existing events are no longer + // relevant. + self.events.borrow_mut().clear(); + self.insert(ConnectionEvent::ZeroRttRejected); + } + + pub fn recv_stream_complete(&self, stream_id: StreamId) { + // If stopped, no longer readable. + self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64())); + } + + // The number of datagrams in the events queue is limited to max_queued_datagrams. + // This function ensure this and deletes the oldest datagrams if needed. 
fn check_datagram_queued(&self, max_queued_datagrams: usize, stats: &mut Stats) {
    let mut q = self.events.borrow_mut();
    let queued = q
        .iter()
        .filter(|evt| matches!(evt, ConnectionEvent::Datagram(_)))
        .count();
    if queued != max_queued_datagrams {
        return;
    }
    // `position` counts from the front of the queue, which is the same
    // indexing `VecDeque::remove` uses, so this evicts the oldest datagram.
    // An index taken from a reversed iterator counts from the back and
    // would remove an unrelated event.
    let oldest = q
        .iter()
        .position(|evt| matches!(evt, ConnectionEvent::Datagram(_)));
    if let Some(index) = oldest {
        q.remove(index);
        q.push_back(ConnectionEvent::IncomingDatagramDropped);
        stats.incoming_datagram_dropped += 1;
    }
}

/// Queue a received datagram as an event.
///
/// If `max_queued_datagrams` datagrams are already queued, the oldest one is
/// dropped first; that is reported with an `IncomingDatagramDropped` event
/// and counted in `stats`.
pub fn add_datagram(&self, max_queued_datagrams: usize, data: &[u8], stats: &mut Stats) {
    self.check_datagram_queued(max_queued_datagrams, stats);
    self.events
        .borrow_mut()
        .push_back(ConnectionEvent::Datagram(data.to_vec()));
}

/// Report what happened to an outgoing datagram.
///
/// An event is only queued when the datagram is tracked with an identifier
/// (`DatagramTracking::Id`); otherwise this does nothing.
pub fn datagram_outcome(
    &self,
    dgram_tracker: &DatagramTracking,
    outcome: OutgoingDatagramOutcome,
) {
    if let DatagramTracking::Id(id) = dgram_tracker {
        self.events
            .borrow_mut()
            .push_back(ConnectionEvent::OutgoingDatagramOutcome { id: *id, outcome });
    }
}

fn insert(&self, event: ConnectionEvent) {
    let mut q = self.events.borrow_mut();

    // Special-case two enums that are not strictly PartialEq equal but that
    // we wish to avoid inserting duplicates.
    let already_present = match &event {
        ConnectionEvent::SendStreamStopSending { stream_id, .. } => q.iter().any(|evt| {
            matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. }
                                    if *x == *stream_id)
        }),
        ConnectionEvent::RecvStreamReset { stream_id, .. } => q.iter().any(|evt| {
            matches!(evt, ConnectionEvent::RecvStreamReset { stream_id: x, ..
} + if *x == *stream_id) + }), + _ => q.contains(&event), + }; + if !already_present { + q.push_back(event); + } + } + + fn remove<F>(&self, f: F) + where + F: Fn(&ConnectionEvent) -> bool, + { + self.events.borrow_mut().retain(|evt| !f(evt)); + } +} + +impl EventProvider for ConnectionEvents { + type Event = ConnectionEvent; + + fn has_events(&self) -> bool { + !self.events.borrow().is_empty() + } + + fn next_event(&mut self) -> Option<Self::Event> { + self.events.borrow_mut().pop_front() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ConnectionError, Error}; + + #[test] + fn event_culling() { + let mut evts = ConnectionEvents::default(); + + evts.client_0rtt_rejected(); + evts.client_0rtt_rejected(); + assert_eq!(evts.events().count(), 1); + assert_eq!(evts.events().count(), 0); + + evts.new_stream(4.into()); + evts.new_stream(4.into()); + assert_eq!(evts.events().count(), 1); + + evts.recv_stream_readable(6.into()); + evts.recv_stream_reset(6.into(), 66); + evts.recv_stream_reset(6.into(), 65); + assert_eq!(evts.events().count(), 1); + + evts.send_stream_writable(8.into()); + evts.send_stream_writable(8.into()); + evts.send_stream_stop_sending(8.into(), 55); + evts.send_stream_stop_sending(8.into(), 56); + let events = evts.events().collect::<Vec<_>>(); + assert_eq!(events.len(), 1); + assert_eq!( + events[0], + ConnectionEvent::SendStreamStopSending { + stream_id: StreamId::new(8), + app_error: 55 + } + ); + + evts.send_stream_writable(8.into()); + evts.send_stream_writable(8.into()); + evts.send_stream_stop_sending(8.into(), 55); + evts.send_stream_stop_sending(8.into(), 56); + evts.send_stream_complete(8.into()); + assert_eq!(evts.events().count(), 1); + + evts.send_stream_writable(8.into()); + evts.send_stream_writable(9.into()); + evts.send_stream_stop_sending(10.into(), 55); + evts.send_stream_stop_sending(11.into(), 56); + evts.send_stream_complete(12.into()); + assert_eq!(evts.events().count(), 5); + + 
/// Sender-side accounting for a flow control limit.
///
/// Tracks the current limit, how much of it has been used, and whether a
/// "blocked" signal for the current limit still needs to be sent.
#[derive(Debug)]
pub struct SenderFlowControl<T>
where
    T: Debug + Sized,
{
    /// The thing that we're counting for.
    subject: T,
    /// The limit.
    limit: u64,
    /// How much of that limit we've used.
    used: u64,
    /// The point at which blocking occurred.  This is updated each time
    /// the sender decides that it is blocked.  It only ever changes
    /// when blocking occurs.  This ensures that blocking at any given limit
    /// is only reported once.
    /// Note: All values are one greater than the corresponding `limit` to
    /// allow distinguishing between blocking at a limit of 0 and no blocking.
    blocked_at: u64,
    /// Whether a blocked frame should be sent.
    blocked_frame: bool,
}

impl<T> SenderFlowControl<T>
where
    T: Debug + Sized,
{
    /// Make a new instance with the initial value and subject.
    pub fn new(subject: T, initial: u64) -> Self {
        Self {
            subject,
            limit: initial,
            used: 0,
            blocked_at: 0,
            blocked_frame: false,
        }
    }

    /// Update the maximum.  Returns `true` if the change was an increase.
    ///
    /// A value that does not exceed the current limit is ignored.  An
    /// increase also cancels any pending blocked frame.
    pub fn update(&mut self, limit: u64) -> bool {
        debug_assert!(limit < u64::MAX);
        if limit > self.limit {
            self.limit = limit;
            self.blocked_frame = false;
            true
        } else {
            false
        }
    }

    /// Consume flow control.
    ///
    /// Staying within the limit is the caller's responsibility; this is
    /// only checked in builds with debug assertions enabled.
    pub fn consume(&mut self, count: usize) {
        let amt = u64::try_from(count).unwrap();
        debug_assert!(self.used + amt <= self.limit);
        self.used += amt;
    }

    /// Get available flow control.
    ///
    /// Returns 0 when more than the limit has been used.  `consume` only
    /// enforces the limit with a `debug_assert!`, so without this saturation
    /// an over-consumed instance would make the subtraction wrap (or panic)
    /// and report an enormous amount of available credit.
    pub fn available(&self) -> usize {
        usize::try_from(self.limit.saturating_sub(self.used)).unwrap_or(usize::MAX)
    }

    /// How much data has been written.
    pub fn used(&self) -> u64 {
        self.used
    }

    /// Mark flow control as blocked.
    /// This only does something if the current limit exceeds the last reported blocking limit.
    pub fn blocked(&mut self) {
        if self.limit >= self.blocked_at {
            // Stored as `limit + 1` so that 0 can mean "never blocked".
            self.blocked_at = self.limit + 1;
            self.blocked_frame = true;
        }
    }

    /// Return whether a blocking frame needs to be sent.
    /// This is `Some` with the active limit if `blocked` has been called,
    /// if a blocking frame has not been sent (or it has been lost), and
    /// if the blocking condition remains.
    fn blocked_needed(&self) -> Option<u64> {
        if self.blocked_frame && self.limit < self.blocked_at {
            Some(self.blocked_at - 1)
        } else {
            None
        }
    }

    /// Clear the need to send a blocked frame.
    fn blocked_sent(&mut self) {
        self.blocked_frame = false;
    }

    /// Mark a blocked frame as having been lost.
    /// Only send again if value of `self.blocked_at` hasn't increased since sending.
    /// That would imply that the limit has since increased.
    pub fn frame_lost(&mut self, limit: u64) {
        if self.blocked_at == limit + 1 {
            self.blocked_frame = true;
        }
    }
}
/// Receiver-side accounting for a flow control limit that we advertise.
///
/// Keeps the window size, the last limit that was sent, and how many items
/// have been retired, and records when a new limit should be advertised.
#[derive(Debug)]
pub struct ReceiverFlowControl<T>
where
    T: Debug + Sized,
{
    /// The thing that we're counting for.
    subject: T,
    /// The maximum amount of items that can be active (e.g., the size of the receive buffer).
    max_active: u64,
    /// Last max allowed sent.
    max_allowed: u64,
    /// Items received, but not retired yet.
    /// This will be used for byte flow control: each stream will remember its largest byte
    /// offset received and session flow control will remember the sum of all bytes consumed
    /// by all streams.
    consumed: u64,
    /// Retired items.
    retired: u64,
    /// Whether a frame carrying an updated limit is waiting to be sent.
    frame_pending: bool,
}

impl<T> ReceiverFlowControl<T>
where
    T: Debug + Sized,
{
    /// Make a new instance with the initial value and subject.
    /// The initial value is both the window size and the first limit.
    pub fn new(subject: T, max: u64) -> Self {
        Self {
            subject,
            max_active: max,
            max_allowed: max,
            consumed: 0,
            retired: 0,
            frame_pending: false,
        }
    }

    /// Record a new absolute retired count.  Values that do not move the
    /// count forward are ignored.  An update is scheduled once the retired
    /// count plus half the window exceeds the last limit sent.
    pub fn retire(&mut self, retired: u64) {
        if retired > self.retired {
            self.retired = retired;
            let half_window = self.max_active / 2;
            self.frame_pending |= self.retired + half_window > self.max_allowed;
        }
    }

    /// This function is called when STREAM_DATA_BLOCKED frame is received.
    /// An update is scheduled if the next limit would exceed the last one sent.
    pub fn send_flowc_update(&mut self) {
        let candidate = self.retired + self.max_active;
        if candidate > self.max_allowed {
            self.frame_pending = true;
        }
    }

    /// Whether an updated limit is waiting to be sent.
    pub fn frame_needed(&self) -> bool {
        self.frame_pending
    }

    /// The limit that the next update would advertise.
    pub fn next_limit(&self) -> u64 {
        self.retired + self.max_active
    }

    /// The current window size.
    pub fn max_active(&self) -> u64 {
        self.max_active
    }

    /// A frame carrying `maximum_data` was lost.  It is only resent if that
    /// value is still the most recent limit that was sent.
    pub fn frame_lost(&mut self, maximum_data: u64) {
        let is_latest = maximum_data == self.max_allowed;
        self.frame_pending |= is_latest;
    }

    /// Record that a frame advertising `new_max` has been written.
    fn frame_sent(&mut self, new_max: u64) {
        self.frame_pending = false;
        self.max_allowed = new_max;
    }

    /// Change the window size.
    /// If the window has grown, an update is scheduled immediately.
    pub fn set_max_active(&mut self, max: u64) {
        if max > self.max_active {
            self.frame_pending = true;
        }
        self.max_active = max;
    }

    /// The number of items retired so far.
    pub fn retired(&self) -> u64 {
        self.retired
    }

    /// The number of items consumed so far.
    pub fn consumed(&self) -> u64 {
        self.consumed
    }
}
ReceiverFlowControl<StreamId> { + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if !self.frame_needed() { + return; + } + let max_allowed = self.next_limit(); + if builder.write_varint_frame(&[ + FRAME_TYPE_MAX_STREAM_DATA, + self.subject.as_u64(), + max_allowed, + ]) { + stats.max_stream_data += 1; + tokens.push(RecoveryToken::Stream(StreamRecoveryToken::MaxStreamData { + stream_id: self.subject, + max_data: max_allowed, + })); + self.frame_sent(max_allowed); + } + } + + pub fn add_retired(&mut self, count: u64) { + debug_assert!(self.retired + count <= self.consumed); + self.retired += count; + if self.retired + self.max_active / 2 > self.max_allowed { + self.frame_pending = true; + } + } + + pub fn set_consumed(&mut self, consumed: u64) -> Res<u64> { + if consumed <= self.consumed { + return Ok(0); + } + + if consumed > self.max_allowed { + qtrace!("Stream RX window exceeded: {}", consumed); + return Err(Error::FlowControlError); + } + let new_consumed = consumed - self.consumed; + self.consumed = consumed; + Ok(new_consumed) + } +} + +impl Default for ReceiverFlowControl<StreamId> { + fn default() -> Self { + Self::new(StreamId::new(0), 0) + } +} + +impl ReceiverFlowControl<StreamType> { + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if !self.frame_needed() { + return; + } + let max_streams = self.next_limit(); + let frame = match self.subject { + StreamType::BiDi => FRAME_TYPE_MAX_STREAMS_BIDI, + StreamType::UniDi => FRAME_TYPE_MAX_STREAMS_UNIDI, + }; + if builder.write_varint_frame(&[frame, max_streams]) { + stats.max_streams += 1; + tokens.push(RecoveryToken::Stream(StreamRecoveryToken::MaxStreams { + stream_type: self.subject, + max_streams, + })); + self.frame_sent(max_streams); + } + } + + /// Check if received item exceeds the allowed flow control limit. 
+ pub fn check_allowed(&self, new_end: u64) -> bool { + new_end < self.max_allowed + } + + /// Retire given amount of additional data. + /// This function will send flow updates immediately. + pub fn add_retired(&mut self, count: u64) { + self.retired += count; + if count > 0 { + self.send_flowc_update(); + } + } +} + +pub struct RemoteStreamLimit { + streams_fc: ReceiverFlowControl<StreamType>, + next_stream: StreamId, +} + +impl RemoteStreamLimit { + pub fn new(stream_type: StreamType, max_streams: u64, role: Role) -> Self { + Self { + streams_fc: ReceiverFlowControl::new(stream_type, max_streams), + // // This is for a stream created by a peer, therefore we use role.remote(). + next_stream: StreamId::init(stream_type, role.remote()), + } + } + + pub fn is_allowed(&self, stream_id: StreamId) -> bool { + let stream_idx = stream_id.as_u64() >> 2; + self.streams_fc.check_allowed(stream_idx) + } + + pub fn is_new_stream(&self, stream_id: StreamId) -> Res<bool> { + if !self.is_allowed(stream_id) { + return Err(Error::StreamLimitError); + } + Ok(stream_id >= self.next_stream) + } + + pub fn take_stream_id(&mut self) -> StreamId { + let new_stream = self.next_stream; + self.next_stream.next(); + assert!(self.is_allowed(new_stream)); + new_stream + } +} + +impl Deref for RemoteStreamLimit { + type Target = ReceiverFlowControl<StreamType>; + fn deref(&self) -> &Self::Target { + &self.streams_fc + } +} + +impl DerefMut for RemoteStreamLimit { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.streams_fc + } +} + +pub struct RemoteStreamLimits { + bidirectional: RemoteStreamLimit, + unidirectional: RemoteStreamLimit, +} + +impl RemoteStreamLimits { + pub fn new(local_max_stream_bidi: u64, local_max_stream_uni: u64, role: Role) -> Self { + Self { + bidirectional: RemoteStreamLimit::new(StreamType::BiDi, local_max_stream_bidi, role), + unidirectional: RemoteStreamLimit::new(StreamType::UniDi, local_max_stream_uni, role), + } + } +} + +impl Index<StreamType> for 
RemoteStreamLimits { + type Output = RemoteStreamLimit; + + fn index(&self, idx: StreamType) -> &Self::Output { + match idx { + StreamType::BiDi => &self.bidirectional, + StreamType::UniDi => &self.unidirectional, + } + } +} + +impl IndexMut<StreamType> for RemoteStreamLimits { + fn index_mut(&mut self, idx: StreamType) -> &mut Self::Output { + match idx { + StreamType::BiDi => &mut self.bidirectional, + StreamType::UniDi => &mut self.unidirectional, + } + } +} + +pub struct LocalStreamLimits { + bidirectional: SenderFlowControl<StreamType>, + unidirectional: SenderFlowControl<StreamType>, + role_bit: u64, +} + +impl LocalStreamLimits { + pub fn new(role: Role) -> Self { + Self { + bidirectional: SenderFlowControl::new(StreamType::BiDi, 0), + unidirectional: SenderFlowControl::new(StreamType::UniDi, 0), + role_bit: StreamId::role_bit(role), + } + } + + pub fn take_stream_id(&mut self, stream_type: StreamType) -> Option<StreamId> { + let fc = match stream_type { + StreamType::BiDi => &mut self.bidirectional, + StreamType::UniDi => &mut self.unidirectional, + }; + if fc.available() > 0 { + let new_stream = fc.used(); + fc.consume(1); + let type_bit = match stream_type { + StreamType::BiDi => 0, + StreamType::UniDi => 2, + }; + Some(StreamId::from((new_stream << 2) + type_bit + self.role_bit)) + } else { + fc.blocked(); + None + } + } +} + +impl Index<StreamType> for LocalStreamLimits { + type Output = SenderFlowControl<StreamType>; + + fn index(&self, idx: StreamType) -> &Self::Output { + match idx { + StreamType::BiDi => &self.bidirectional, + StreamType::UniDi => &self.unidirectional, + } + } +} + +impl IndexMut<StreamType> for LocalStreamLimits { + fn index_mut(&mut self, idx: StreamType) -> &mut Self::Output { + match idx { + StreamType::BiDi => &mut self.bidirectional, + StreamType::UniDi => &mut self.unidirectional, + } + } +} + +#[cfg(test)] +mod test { + use neqo_common::{Encoder, Role}; + + use super::{LocalStreamLimits, ReceiverFlowControl, 
RemoteStreamLimits, SenderFlowControl}; + use crate::{ + packet::PacketBuilder, + stats::FrameStats, + stream_id::{StreamId, StreamType}, + Error, + }; + + #[test] + fn blocked_at_zero() { + let mut fc = SenderFlowControl::new((), 0); + fc.blocked(); + assert_eq!(fc.blocked_needed(), Some(0)); + } + + #[test] + fn blocked() { + let mut fc = SenderFlowControl::new((), 10); + fc.blocked(); + assert_eq!(fc.blocked_needed(), Some(10)); + } + + #[test] + fn update_consume() { + let mut fc = SenderFlowControl::new((), 10); + fc.consume(10); + assert_eq!(fc.available(), 0); + fc.update(5); // An update lower than the current limit does nothing. + assert_eq!(fc.available(), 0); + fc.update(15); + assert_eq!(fc.available(), 5); + fc.consume(3); + assert_eq!(fc.available(), 2); + } + + #[test] + fn update_clears_blocked() { + let mut fc = SenderFlowControl::new((), 10); + fc.blocked(); + assert_eq!(fc.blocked_needed(), Some(10)); + fc.update(5); // An update lower than the current limit does nothing. 
+ assert_eq!(fc.blocked_needed(), Some(10)); + fc.update(11); + assert_eq!(fc.blocked_needed(), None); + } + + #[test] + fn lost_blocked_resent() { + let mut fc = SenderFlowControl::new((), 10); + fc.blocked(); + fc.blocked_sent(); + assert_eq!(fc.blocked_needed(), None); + fc.frame_lost(10); + assert_eq!(fc.blocked_needed(), Some(10)); + } + + #[test] + fn lost_after_increase() { + let mut fc = SenderFlowControl::new((), 10); + fc.blocked(); + fc.blocked_sent(); + assert_eq!(fc.blocked_needed(), None); + fc.update(11); + fc.frame_lost(10); + assert_eq!(fc.blocked_needed(), None); + } + + #[test] + fn lost_after_higher_blocked() { + let mut fc = SenderFlowControl::new((), 10); + fc.blocked(); + fc.blocked_sent(); + fc.update(11); + fc.blocked(); + assert_eq!(fc.blocked_needed(), Some(11)); + fc.blocked_sent(); + fc.frame_lost(10); + assert_eq!(fc.blocked_needed(), None); + } + + #[test] + fn do_no_need_max_allowed_frame_at_start() { + let fc = ReceiverFlowControl::new((), 0); + assert!(!fc.frame_needed()); + } + + #[test] + fn max_allowed_after_items_retired() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.retire(49); + assert!(!fc.frame_needed()); + fc.retire(51); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 151); + } + + #[test] + fn need_max_allowed_frame_after_loss() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.retire(100); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 200); + fc.frame_sent(200); + assert!(!fc.frame_needed()); + fc.frame_lost(200); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 200); + } + + #[test] + fn no_max_allowed_frame_after_old_loss() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.retire(51); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 151); + fc.frame_sent(151); + assert!(!fc.frame_needed()); + fc.retire(102); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 202); + fc.frame_sent(202); + assert!(!fc.frame_needed()); + fc.frame_lost(151); 
+ assert!(!fc.frame_needed()); + } + + #[test] + fn force_send_max_allowed() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.retire(10); + assert!(!fc.frame_needed()); + } + + #[test] + fn multiple_retries_after_frame_pending_is_set() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.retire(51); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 151); + fc.retire(61); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 161); + fc.retire(88); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 188); + fc.retire(90); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 190); + fc.frame_sent(190); + assert!(!fc.frame_needed()); + fc.retire(141); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 241); + fc.frame_sent(241); + assert!(!fc.frame_needed()); + } + + #[test] + fn new_retired_before_loss() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.retire(51); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 151); + fc.frame_sent(151); + assert!(!fc.frame_needed()); + fc.retire(62); + assert!(!fc.frame_needed()); + fc.frame_lost(151); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 162); + } + + #[test] + fn changing_max_active() { + let mut fc = ReceiverFlowControl::new((), 100); + fc.set_max_active(50); + // There is no MAX_STREAM_DATA frame needed. + assert!(!fc.frame_needed()); + // We can still retire more than 50. + fc.retire(60); + // There is no MAX_STREAM_DATA fame needed yet. + assert!(!fc.frame_needed()); + fc.retire(76); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 126); + + // Increase max_active. + fc.set_max_active(60); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 136); + + // We can retire more than 60. 
+ fc.retire(136); + assert!(fc.frame_needed()); + assert_eq!(fc.next_limit(), 196); + } + + fn remote_stream_limits(role: Role, bidi: u64, unidi: u64) { + let mut fc = RemoteStreamLimits::new(2, 1, role); + assert!(fc[StreamType::BiDi] + .is_new_stream(StreamId::from(bidi)) + .unwrap()); + assert!(fc[StreamType::BiDi] + .is_new_stream(StreamId::from(bidi + 4)) + .unwrap()); + assert!(fc[StreamType::UniDi] + .is_new_stream(StreamId::from(unidi)) + .unwrap()); + + // Exceed limits + assert_eq!( + fc[StreamType::BiDi].is_new_stream(StreamId::from(bidi + 8)), + Err(Error::StreamLimitError) + ); + assert_eq!( + fc[StreamType::UniDi].is_new_stream(StreamId::from(unidi + 4)), + Err(Error::StreamLimitError) + ); + + assert_eq!(fc[StreamType::BiDi].take_stream_id(), StreamId::from(bidi)); + assert_eq!( + fc[StreamType::BiDi].take_stream_id(), + StreamId::from(bidi + 4) + ); + assert_eq!( + fc[StreamType::UniDi].take_stream_id(), + StreamId::from(unidi) + ); + + fc[StreamType::BiDi].add_retired(1); + fc[StreamType::BiDi].send_flowc_update(); + // consume the frame + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut tokens = Vec::new(); + fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(tokens.len(), 1); + + // Now 9 can be a new StreamId. + assert!(fc[StreamType::BiDi] + .is_new_stream(StreamId::from(bidi + 8)) + .unwrap()); + assert_eq!( + fc[StreamType::BiDi].take_stream_id(), + StreamId::from(bidi + 8) + ); + // 13 still exceeds limits + assert_eq!( + fc[StreamType::BiDi].is_new_stream(StreamId::from(bidi + 12)), + Err(Error::StreamLimitError) + ); + + fc[StreamType::UniDi].add_retired(1); + fc[StreamType::UniDi].send_flowc_update(); + // consume the frame + fc[StreamType::UniDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(tokens.len(), 2); + + // Now 7 can be a new StreamId. 
+ assert!(fc[StreamType::UniDi] + .is_new_stream(StreamId::from(unidi + 4)) + .unwrap()); + assert_eq!( + fc[StreamType::UniDi].take_stream_id(), + StreamId::from(unidi + 4) + ); + // 11 exceeds limits + assert_eq!( + fc[StreamType::UniDi].is_new_stream(StreamId::from(unidi + 8)), + Err(Error::StreamLimitError) + ); + } + + #[test] + fn remote_stream_limits_new_stream_client() { + remote_stream_limits(Role::Client, 1, 3); + } + + #[test] + fn remote_stream_limits_new_stream_server() { + remote_stream_limits(Role::Server, 0, 2); + } + + #[should_panic(expected = ".is_allowed")] + #[test] + fn remote_stream_limits_asserts_if_limit_exceeded() { + let mut fc = RemoteStreamLimits::new(2, 1, Role::Client); + assert_eq!(fc[StreamType::BiDi].take_stream_id(), StreamId::from(1)); + assert_eq!(fc[StreamType::BiDi].take_stream_id(), StreamId::from(5)); + _ = fc[StreamType::BiDi].take_stream_id(); + } + + fn local_stream_limits(role: Role, bidi: u64, unidi: u64) { + let mut fc = LocalStreamLimits::new(role); + + fc[StreamType::BiDi].update(2); + fc[StreamType::UniDi].update(1); + + // Add streams + assert_eq!( + fc.take_stream_id(StreamType::BiDi).unwrap(), + StreamId::from(bidi) + ); + assert_eq!( + fc.take_stream_id(StreamType::BiDi).unwrap(), + StreamId::from(bidi + 4) + ); + assert_eq!(fc.take_stream_id(StreamType::BiDi), None); + assert_eq!( + fc.take_stream_id(StreamType::UniDi).unwrap(), + StreamId::from(unidi) + ); + assert_eq!(fc.take_stream_id(StreamType::UniDi), None); + + // Increase limit + fc[StreamType::BiDi].update(3); + fc[StreamType::UniDi].update(2); + assert_eq!( + fc.take_stream_id(StreamType::BiDi).unwrap(), + StreamId::from(bidi + 8) + ); + assert_eq!(fc.take_stream_id(StreamType::BiDi), None); + assert_eq!( + fc.take_stream_id(StreamType::UniDi).unwrap(), + StreamId::from(unidi + 4) + ); + assert_eq!(fc.take_stream_id(StreamType::UniDi), None); + } + + #[test] + fn local_stream_limits_new_stream_client() { + local_stream_limits(Role::Client, 0, 2); + } 
+ + #[test] + fn local_stream_limits_new_stream_server() { + local_stream_limits(Role::Server, 1, 3); + } +} diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs new file mode 100644 index 0000000000..f3d567ac7c --- /dev/null +++ b/third_party/rust/neqo-transport/src/frame.rs @@ -0,0 +1,977 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Directly relating to QUIC frames. + +use std::{convert::TryFrom, ops::RangeInclusive}; + +use neqo_common::{qtrace, Decoder}; + +use crate::{ + cid::MAX_CONNECTION_ID_LEN, + packet::PacketType, + stream_id::{StreamId, StreamType}, + AppError, ConnectionError, Error, Res, TransportError, +}; + +#[allow(clippy::module_name_repetitions)] +pub type FrameType = u64; + +const FRAME_TYPE_PADDING: FrameType = 0x0; +pub const FRAME_TYPE_PING: FrameType = 0x1; +pub const FRAME_TYPE_ACK: FrameType = 0x2; +const FRAME_TYPE_ACK_ECN: FrameType = 0x3; +pub const FRAME_TYPE_RESET_STREAM: FrameType = 0x4; +pub const FRAME_TYPE_STOP_SENDING: FrameType = 0x5; +pub const FRAME_TYPE_CRYPTO: FrameType = 0x6; +pub const FRAME_TYPE_NEW_TOKEN: FrameType = 0x7; +const FRAME_TYPE_STREAM: FrameType = 0x8; +const FRAME_TYPE_STREAM_MAX: FrameType = 0xf; +pub const FRAME_TYPE_MAX_DATA: FrameType = 0x10; +pub const FRAME_TYPE_MAX_STREAM_DATA: FrameType = 0x11; +pub const FRAME_TYPE_MAX_STREAMS_BIDI: FrameType = 0x12; +pub const FRAME_TYPE_MAX_STREAMS_UNIDI: FrameType = 0x13; +pub const FRAME_TYPE_DATA_BLOCKED: FrameType = 0x14; +pub const FRAME_TYPE_STREAM_DATA_BLOCKED: FrameType = 0x15; +pub const FRAME_TYPE_STREAMS_BLOCKED_BIDI: FrameType = 0x16; +pub const FRAME_TYPE_STREAMS_BLOCKED_UNIDI: FrameType = 0x17; +pub const FRAME_TYPE_NEW_CONNECTION_ID: 
FrameType = 0x18; +pub const FRAME_TYPE_RETIRE_CONNECTION_ID: FrameType = 0x19; +pub const FRAME_TYPE_PATH_CHALLENGE: FrameType = 0x1a; +pub const FRAME_TYPE_PATH_RESPONSE: FrameType = 0x1b; +pub const FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT: FrameType = 0x1c; +pub const FRAME_TYPE_CONNECTION_CLOSE_APPLICATION: FrameType = 0x1d; +pub const FRAME_TYPE_HANDSHAKE_DONE: FrameType = 0x1e; +// draft-ietf-quic-ack-delay +pub const FRAME_TYPE_ACK_FREQUENCY: FrameType = 0xaf; +// draft-ietf-quic-datagram +pub const FRAME_TYPE_DATAGRAM: FrameType = 0x30; +pub const FRAME_TYPE_DATAGRAM_WITH_LEN: FrameType = 0x31; +const DATAGRAM_FRAME_BIT_LEN: u64 = 0x01; + +const STREAM_FRAME_BIT_FIN: u64 = 0x01; +const STREAM_FRAME_BIT_LEN: u64 = 0x02; +const STREAM_FRAME_BIT_OFF: u64 = 0x04; + +#[derive(PartialEq, Eq, Debug, PartialOrd, Ord, Clone, Copy)] +pub enum CloseError { + Transport(TransportError), + Application(AppError), +} + +impl CloseError { + fn frame_type_bit(self) -> u64 { + match self { + Self::Transport(_) => 0, + Self::Application(_) => 1, + } + } + + fn from_type_bit(bit: u64, code: u64) -> Self { + if (bit & 0x01) == 0 { + Self::Transport(code) + } else { + Self::Application(code) + } + } + + pub fn code(&self) -> u64 { + match self { + Self::Transport(c) | Self::Application(c) => *c, + } + } +} + +impl From<ConnectionError> for CloseError { + fn from(err: ConnectionError) -> Self { + match err { + ConnectionError::Transport(c) => Self::Transport(c.code()), + ConnectionError::Application(c) => Self::Application(c), + } + } +} + +#[derive(PartialEq, Eq, Debug, Default, Clone)] +pub struct AckRange { + pub(crate) gap: u64, + pub(crate) range: u64, +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub enum Frame<'a> { + Padding, + Ping, + Ack { + largest_acknowledged: u64, + ack_delay: u64, + first_ack_range: u64, + ack_ranges: Vec<AckRange>, + }, + ResetStream { + stream_id: StreamId, + application_error_code: AppError, + final_size: u64, + }, + StopSending { + stream_id: 
StreamId, + application_error_code: AppError, + }, + Crypto { + offset: u64, + data: &'a [u8], + }, + NewToken { + token: &'a [u8], + }, + Stream { + stream_id: StreamId, + offset: u64, + data: &'a [u8], + fin: bool, + fill: bool, + }, + MaxData { + maximum_data: u64, + }, + MaxStreamData { + stream_id: StreamId, + maximum_stream_data: u64, + }, + MaxStreams { + stream_type: StreamType, + maximum_streams: u64, + }, + DataBlocked { + data_limit: u64, + }, + StreamDataBlocked { + stream_id: StreamId, + stream_data_limit: u64, + }, + StreamsBlocked { + stream_type: StreamType, + stream_limit: u64, + }, + NewConnectionId { + sequence_number: u64, + retire_prior: u64, + connection_id: &'a [u8], + stateless_reset_token: &'a [u8; 16], + }, + RetireConnectionId { + sequence_number: u64, + }, + PathChallenge { + data: [u8; 8], + }, + PathResponse { + data: [u8; 8], + }, + ConnectionClose { + error_code: CloseError, + frame_type: u64, + // Not a reference as we use this to hold the value. + // This is not used in optimized builds anyway. + reason_phrase: Vec<u8>, + }, + HandshakeDone, + AckFrequency { + /// The current ACK frequency sequence number. + seqno: u64, + /// The number of contiguous packets that can be received without + /// acknowledging immediately. + tolerance: u64, + /// The time to delay after receiving the first packet that is + /// not immediately acknowledged. + delay: u64, + /// Ignore reordering when deciding to immediately acknowledge. + ignore_order: bool, + }, + Datagram { + data: &'a [u8], + fill: bool, + }, +} + +impl<'a> Frame<'a> { + fn get_stream_type_bit(stream_type: StreamType) -> u64 { + match stream_type { + StreamType::BiDi => 0, + StreamType::UniDi => 1, + } + } + + fn stream_type_from_bit(bit: u64) -> StreamType { + if (bit & 0x01) == 0 { + StreamType::BiDi + } else { + StreamType::UniDi + } + } + + pub fn get_type(&self) -> FrameType { + match self { + Self::Padding => FRAME_TYPE_PADDING, + Self::Ping => FRAME_TYPE_PING, + Self::Ack { .. 
} => FRAME_TYPE_ACK, // We don't do ACK ECN. + Self::ResetStream { .. } => FRAME_TYPE_RESET_STREAM, + Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING, + Self::Crypto { .. } => FRAME_TYPE_CRYPTO, + Self::NewToken { .. } => FRAME_TYPE_NEW_TOKEN, + Self::Stream { + fin, offset, fill, .. + } => Self::stream_type(*fin, *offset > 0, *fill), + Self::MaxData { .. } => FRAME_TYPE_MAX_DATA, + Self::MaxStreamData { .. } => FRAME_TYPE_MAX_STREAM_DATA, + Self::MaxStreams { stream_type, .. } => { + FRAME_TYPE_MAX_STREAMS_BIDI + Self::get_stream_type_bit(*stream_type) + } + Self::DataBlocked { .. } => FRAME_TYPE_DATA_BLOCKED, + Self::StreamDataBlocked { .. } => FRAME_TYPE_STREAM_DATA_BLOCKED, + Self::StreamsBlocked { stream_type, .. } => { + FRAME_TYPE_STREAMS_BLOCKED_BIDI + Self::get_stream_type_bit(*stream_type) + } + Self::NewConnectionId { .. } => FRAME_TYPE_NEW_CONNECTION_ID, + Self::RetireConnectionId { .. } => FRAME_TYPE_RETIRE_CONNECTION_ID, + Self::PathChallenge { .. } => FRAME_TYPE_PATH_CHALLENGE, + Self::PathResponse { .. } => FRAME_TYPE_PATH_RESPONSE, + Self::ConnectionClose { error_code, .. } => { + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT + error_code.frame_type_bit() + } + Self::HandshakeDone => FRAME_TYPE_HANDSHAKE_DONE, + Self::AckFrequency { .. } => FRAME_TYPE_ACK_FREQUENCY, + Self::Datagram { fill, .. } => { + if *fill { + FRAME_TYPE_DATAGRAM + } else { + FRAME_TYPE_DATAGRAM_WITH_LEN + } + } + } + } + + pub fn is_stream(&self) -> bool { + matches!( + self, + Self::ResetStream { .. } + | Self::StopSending { .. } + | Self::Stream { .. } + | Self::MaxData { .. } + | Self::MaxStreamData { .. } + | Self::MaxStreams { .. } + | Self::DataBlocked { .. } + | Self::StreamDataBlocked { .. } + | Self::StreamsBlocked { .. 
} + ) + } + + pub fn stream_type(fin: bool, nonzero_offset: bool, fill: bool) -> u64 { + let mut t = FRAME_TYPE_STREAM; + if fin { + t |= STREAM_FRAME_BIT_FIN; + } + if nonzero_offset { + t |= STREAM_FRAME_BIT_OFF; + } + if !fill { + t |= STREAM_FRAME_BIT_LEN; + } + t + } + + /// If the frame causes a recipient to generate an ACK within its + /// advertised maximum acknowledgement delay. + pub fn ack_eliciting(&self) -> bool { + !matches!( + self, + Self::Ack { .. } | Self::Padding | Self::ConnectionClose { .. } + ) + } + + /// If the frame can be sent in a path probe + /// without initiating migration to that path. + pub fn path_probing(&self) -> bool { + matches!( + self, + Self::Padding + | Self::NewConnectionId { .. } + | Self::PathChallenge { .. } + | Self::PathResponse { .. } + ) + } + + /// Converts AckRanges as encoded in a ACK frame (see -transport + /// 19.3.1) into ranges of acked packets (end, start), inclusive of + /// start and end values. + pub fn decode_ack_frame( + largest_acked: u64, + first_ack_range: u64, + ack_ranges: &[AckRange], + ) -> Res<Vec<RangeInclusive<u64>>> { + let mut acked_ranges = Vec::with_capacity(ack_ranges.len() + 1); + + if largest_acked < first_ack_range { + return Err(Error::FrameEncodingError); + } + acked_ranges.push((largest_acked - first_ack_range)..=largest_acked); + if !ack_ranges.is_empty() && largest_acked < first_ack_range + 1 { + return Err(Error::FrameEncodingError); + } + let mut cur = if ack_ranges.is_empty() { + 0 + } else { + largest_acked - first_ack_range - 1 + }; + for r in ack_ranges { + if cur < r.gap + 1 { + return Err(Error::FrameEncodingError); + } + cur = cur - r.gap - 1; + + if cur < r.range { + return Err(Error::FrameEncodingError); + } + acked_ranges.push((cur - r.range)..=cur); + + if cur > r.range + 1 { + cur -= r.range + 1; + } else { + cur -= r.range; + } + } + + Ok(acked_ranges) + } + + pub fn dump(&self) -> Option<String> { + match self { + Self::Crypto { offset, data } => Some(format!( + 
"Crypto {{ offset: {}, len: {} }}", + offset, + data.len() + )), + Self::Stream { + stream_id, + offset, + fill, + data, + fin, + } => Some(format!( + "Stream {{ stream_id: {}, offset: {}, len: {}{}, fin: {} }}", + stream_id.as_u64(), + offset, + if *fill { ">>" } else { "" }, + data.len(), + fin, + )), + Self::Padding => None, + Self::Datagram { data, .. } => Some(format!("Datagram {{ len: {} }}", data.len())), + _ => Some(format!("{self:?}")), + } + } + + pub fn is_allowed(&self, pt: PacketType) -> bool { + match self { + Self::Padding | Self::Ping => true, + Self::Crypto { .. } + | Self::Ack { .. } + | Self::ConnectionClose { + error_code: CloseError::Transport(_), + .. + } => pt != PacketType::ZeroRtt, + Self::NewToken { .. } | Self::ConnectionClose { .. } => pt == PacketType::Short, + _ => pt == PacketType::ZeroRtt || pt == PacketType::Short, + } + } + + pub fn decode(dec: &mut Decoder<'a>) -> Res<Self> { + /// Maximum ACK Range Count in ACK Frame + /// + /// Given a max UDP datagram size of 64k bytes and a minimum ACK Range size of 2 + /// bytes (2 QUIC varints), a single datagram can at most contain 32k ACK + /// Ranges. + /// + /// Note that the maximum (jumbogram) Ethernet MTU of 9216 or on the + /// Internet the regular Ethernet MTU of 1518 are more realistically to + /// be the limiting factor. Though for simplicity the higher limit is chosen. 
+ const MAX_ACK_RANGE_COUNT: u64 = 32 * 1024; + + fn d<T>(v: Option<T>) -> Res<T> { + v.ok_or(Error::NoMoreData) + } + fn dv(dec: &mut Decoder) -> Res<u64> { + d(dec.decode_varint()) + } + + // TODO(ekr@rtfm.com): check for minimal encoding + let t = d(dec.decode_varint())?; + match t { + FRAME_TYPE_PADDING => Ok(Self::Padding), + FRAME_TYPE_PING => Ok(Self::Ping), + FRAME_TYPE_RESET_STREAM => Ok(Self::ResetStream { + stream_id: StreamId::from(dv(dec)?), + application_error_code: d(dec.decode_varint())?, + final_size: match dec.decode_varint() { + Some(v) => v, + _ => return Err(Error::NoMoreData), + }, + }), + FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => { + let la = dv(dec)?; + let ad = dv(dec)?; + let nr = dv(dec).and_then(|nr| { + if nr < MAX_ACK_RANGE_COUNT { + Ok(nr) + } else { + Err(Error::TooMuchData) + } + })?; + let fa = dv(dec)?; + let mut arr: Vec<AckRange> = Vec::with_capacity(nr as usize); + for _ in 0..nr { + let ar = AckRange { + gap: dv(dec)?, + range: dv(dec)?, + }; + arr.push(ar); + } + + // Now check for the values for ACK_ECN. + if t == FRAME_TYPE_ACK_ECN { + dv(dec)?; + dv(dec)?; + dv(dec)?; + } + + Ok(Self::Ack { + largest_acknowledged: la, + ack_delay: ad, + first_ack_range: fa, + ack_ranges: arr, + }) + } + FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending { + stream_id: StreamId::from(dv(dec)?), + application_error_code: d(dec.decode_varint())?, + }), + FRAME_TYPE_CRYPTO => { + let offset = dv(dec)?; + let data = d(dec.decode_vvec())?; + if offset + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) { + return Err(Error::FrameEncodingError); + } + Ok(Self::Crypto { offset, data }) + } + FRAME_TYPE_NEW_TOKEN => { + let token = d(dec.decode_vvec())?; + if token.is_empty() { + return Err(Error::FrameEncodingError); + } + Ok(Self::NewToken { token }) + } + FRAME_TYPE_STREAM..=FRAME_TYPE_STREAM_MAX => { + let s = dv(dec)?; + let o = if t & STREAM_FRAME_BIT_OFF == 0 { + 0 + } else { + dv(dec)? 
+ }; + let fill = (t & STREAM_FRAME_BIT_LEN) == 0; + let data = if fill { + qtrace!("STREAM frame, extends to the end of the packet"); + dec.decode_remainder() + } else { + qtrace!("STREAM frame, with length"); + d(dec.decode_vvec())? + }; + if o + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) { + return Err(Error::FrameEncodingError); + } + Ok(Self::Stream { + fin: (t & STREAM_FRAME_BIT_FIN) != 0, + stream_id: StreamId::from(s), + offset: o, + data, + fill, + }) + } + FRAME_TYPE_MAX_DATA => Ok(Self::MaxData { + maximum_data: dv(dec)?, + }), + FRAME_TYPE_MAX_STREAM_DATA => Ok(Self::MaxStreamData { + stream_id: StreamId::from(dv(dec)?), + maximum_stream_data: dv(dec)?, + }), + FRAME_TYPE_MAX_STREAMS_BIDI | FRAME_TYPE_MAX_STREAMS_UNIDI => { + let m = dv(dec)?; + if m > (1 << 60) { + return Err(Error::StreamLimitError); + } + Ok(Self::MaxStreams { + stream_type: Self::stream_type_from_bit(t), + maximum_streams: m, + }) + } + FRAME_TYPE_DATA_BLOCKED => Ok(Self::DataBlocked { + data_limit: dv(dec)?, + }), + FRAME_TYPE_STREAM_DATA_BLOCKED => Ok(Self::StreamDataBlocked { + stream_id: dv(dec)?.into(), + stream_data_limit: dv(dec)?, + }), + FRAME_TYPE_STREAMS_BLOCKED_BIDI | FRAME_TYPE_STREAMS_BLOCKED_UNIDI => { + Ok(Self::StreamsBlocked { + stream_type: Self::stream_type_from_bit(t), + stream_limit: dv(dec)?, + }) + } + FRAME_TYPE_NEW_CONNECTION_ID => { + let sequence_number = dv(dec)?; + let retire_prior = dv(dec)?; + let connection_id = d(dec.decode_vec(1))?; + if connection_id.len() > MAX_CONNECTION_ID_LEN { + return Err(Error::DecodingFrame); + } + let srt = d(dec.decode(16))?; + let stateless_reset_token = <&[_; 16]>::try_from(srt).unwrap(); + + Ok(Self::NewConnectionId { + sequence_number, + retire_prior, + connection_id, + stateless_reset_token, + }) + } + FRAME_TYPE_RETIRE_CONNECTION_ID => Ok(Self::RetireConnectionId { + sequence_number: dv(dec)?, + }), + FRAME_TYPE_PATH_CHALLENGE => { + let data = d(dec.decode(8))?; + let mut datav: [u8; 8] = [0; 8]; + 
datav.copy_from_slice(data); + Ok(Self::PathChallenge { data: datav }) + } + FRAME_TYPE_PATH_RESPONSE => { + let data = d(dec.decode(8))?; + let mut datav: [u8; 8] = [0; 8]; + datav.copy_from_slice(data); + Ok(Self::PathResponse { data: datav }) + } + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT | FRAME_TYPE_CONNECTION_CLOSE_APPLICATION => { + let error_code = CloseError::from_type_bit(t, d(dec.decode_varint())?); + let frame_type = if t == FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT { + dv(dec)? + } else { + 0 + }; + // We can tolerate this copy for now. + let reason_phrase = d(dec.decode_vvec())?.to_vec(); + Ok(Self::ConnectionClose { + error_code, + frame_type, + reason_phrase, + }) + } + FRAME_TYPE_HANDSHAKE_DONE => Ok(Self::HandshakeDone), + FRAME_TYPE_ACK_FREQUENCY => { + let seqno = dv(dec)?; + let tolerance = dv(dec)?; + if tolerance == 0 { + return Err(Error::FrameEncodingError); + } + let delay = dv(dec)?; + let ignore_order = match d(dec.decode_uint(1))? { + 0 => false, + 1 => true, + _ => return Err(Error::FrameEncodingError), + }; + Ok(Self::AckFrequency { + seqno, + tolerance, + delay, + ignore_order, + }) + } + FRAME_TYPE_DATAGRAM | FRAME_TYPE_DATAGRAM_WITH_LEN => { + let fill = (t & DATAGRAM_FRAME_BIT_LEN) == 0; + let data = if fill { + qtrace!("DATAGRAM frame, extends to the end of the packet"); + dec.decode_remainder() + } else { + qtrace!("DATAGRAM frame, with length"); + d(dec.decode_vvec())? 
+ }; + Ok(Self::Datagram { data, fill }) + } + _ => Err(Error::UnknownFrameType), + } + } +} + +#[cfg(test)] +mod tests { + use neqo_common::{Decoder, Encoder}; + + use super::*; + + fn just_dec(f: &Frame, s: &str) { + let encoded = Encoder::from_hex(s); + let decoded = Frame::decode(&mut encoded.as_decoder()).unwrap(); + assert_eq!(*f, decoded); + } + + #[test] + fn padding() { + let f = Frame::Padding; + just_dec(&f, "00"); + } + + #[test] + fn ping() { + let f = Frame::Ping; + just_dec(&f, "01"); + } + + #[test] + fn ack() { + let ar = vec![AckRange { gap: 1, range: 2 }, AckRange { gap: 3, range: 4 }]; + + let f = Frame::Ack { + largest_acknowledged: 0x1234, + ack_delay: 0x1235, + first_ack_range: 0x1236, + ack_ranges: ar, + }; + + just_dec(&f, "025234523502523601020304"); + + // Try to parse ACK_ECN without ECN values + let enc = Encoder::from_hex("035234523502523601020304"); + let mut dec = enc.as_decoder(); + assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData); + + // Try to parse ACK_ECN with ECN values + let enc = Encoder::from_hex("035234523502523601020304010203"); + let mut dec = enc.as_decoder(); + assert_eq!(Frame::decode(&mut dec).unwrap(), f); + } + + #[test] + fn reset_stream() { + let f = Frame::ResetStream { + stream_id: StreamId::from(0x1234), + application_error_code: 0x77, + final_size: 0x3456, + }; + + just_dec(&f, "04523440777456"); + } + + #[test] + fn stop_sending() { + let f = Frame::StopSending { + stream_id: StreamId::from(63), + application_error_code: 0x77, + }; + + just_dec(&f, "053F4077"); + } + + #[test] + fn crypto() { + let f = Frame::Crypto { + offset: 1, + data: &[1, 2, 3], + }; + + just_dec(&f, "060103010203"); + } + + #[test] + fn new_token() { + let f = Frame::NewToken { + token: &[0x12, 0x34, 0x56], + }; + + just_dec(&f, "0703123456"); + } + + #[test] + fn empty_new_token() { + let mut dec = Decoder::from(&[0x07, 0x00][..]); + assert_eq!( + Frame::decode(&mut dec).unwrap_err(), + Error::FrameEncodingError 
); + } + + #[test] + fn stream() { + // First, just set the length bit. + let f = Frame::Stream { + fin: false, + stream_id: StreamId::from(5), + offset: 0, + data: &[1, 2, 3], + fill: false, + }; + + just_dec(&f, "0a0503010203"); + + // Now with offset != 0 and FIN + let f = Frame::Stream { + fin: true, + stream_id: StreamId::from(5), + offset: 1, + data: &[1, 2, 3], + fill: false, + }; + just_dec(&f, "0f050103010203"); + + // Now to fill the packet. + let f = Frame::Stream { + fin: true, + stream_id: StreamId::from(5), + offset: 0, + data: &[1, 2, 3], + fill: true, + }; + just_dec(&f, "0905010203"); + } + + #[test] + fn max_data() { + let f = Frame::MaxData { + maximum_data: 0x1234, + }; + + just_dec(&f, "105234"); + } + + #[test] + fn max_stream_data() { + let f = Frame::MaxStreamData { + stream_id: StreamId::from(5), + maximum_stream_data: 0x1234, + }; + + just_dec(&f, "11055234"); + } + + #[test] + fn max_streams() { + let mut f = Frame::MaxStreams { + stream_type: StreamType::BiDi, + maximum_streams: 0x1234, + }; + + just_dec(&f, "125234"); + + f = Frame::MaxStreams { + stream_type: StreamType::UniDi, + maximum_streams: 0x1234, + }; + + just_dec(&f, "135234"); + } + + #[test] + fn data_blocked() { + let f = Frame::DataBlocked { data_limit: 0x1234 }; + + just_dec(&f, "145234"); + } + + #[test] + fn stream_data_blocked() { + let f = Frame::StreamDataBlocked { + stream_id: StreamId::from(5), + stream_data_limit: 0x1234, + }; + + just_dec(&f, "15055234"); + } + + #[test] + fn streams_blocked() { + let mut f = Frame::StreamsBlocked { + stream_type: StreamType::BiDi, + stream_limit: 0x1234, + }; + + just_dec(&f, "165234"); + + f = Frame::StreamsBlocked { + stream_type: StreamType::UniDi, + stream_limit: 0x1234, + }; + + just_dec(&f, "175234"); + } + + #[test] + fn new_connection_id() { + let f = Frame::NewConnectionId { + sequence_number: 0x1234, + retire_prior: 0, + connection_id: &[0x01, 0x02], + stateless_reset_token: &[9; 16], + }; + + just_dec(&f, 
"1852340002010209090909090909090909090909090909"); + } + + #[test] + fn too_large_new_connection_id() { + let mut enc = Encoder::from_hex("18523400"); // up to the CID + enc.encode_vvec(&[0x0c; MAX_CONNECTION_ID_LEN + 10]); + enc.encode(&[0x11; 16][..]); + assert_eq!( + Frame::decode(&mut enc.as_decoder()).unwrap_err(), + Error::DecodingFrame + ); + } + + #[test] + fn retire_connection_id() { + let f = Frame::RetireConnectionId { + sequence_number: 0x1234, + }; + + just_dec(&f, "195234"); + } + + #[test] + fn path_challenge() { + let f = Frame::PathChallenge { data: [9; 8] }; + + just_dec(&f, "1a0909090909090909"); + } + + #[test] + fn path_response() { + let f = Frame::PathResponse { data: [9; 8] }; + + just_dec(&f, "1b0909090909090909"); + } + + #[test] + fn connection_close_transport() { + let f = Frame::ConnectionClose { + error_code: CloseError::Transport(0x5678), + frame_type: 0x1234, + reason_phrase: vec![0x01, 0x02, 0x03], + }; + + just_dec(&f, "1c80005678523403010203"); + } + + #[test] + fn connection_close_application() { + let f = Frame::ConnectionClose { + error_code: CloseError::Application(0x5678), + frame_type: 0, + reason_phrase: vec![0x01, 0x02, 0x03], + }; + + just_dec(&f, "1d8000567803010203"); + } + + #[test] + fn test_compare() { + let f1 = Frame::Padding; + let f2 = Frame::Padding; + let f3 = Frame::Crypto { + offset: 0, + data: &[1, 2, 3], + }; + let f4 = Frame::Crypto { + offset: 0, + data: &[1, 2, 3], + }; + let f5 = Frame::Crypto { + offset: 1, + data: &[1, 2, 3], + }; + let f6 = Frame::Crypto { + offset: 0, + data: &[1, 2, 4], + }; + + assert_eq!(f1, f2); + assert_ne!(f1, f3); + assert_eq!(f3, f4); + assert_ne!(f3, f5); + assert_ne!(f3, f6); + } + + #[test] + fn decode_ack_frame() { + let res = Frame::decode_ack_frame(7, 2, &[AckRange { gap: 0, range: 3 }]); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), vec![5..=7, 0..=3]); + } + + #[test] + fn ack_frequency() { + let f = Frame::AckFrequency { + seqno: 10, + tolerance: 5, + delay: 
2000, + ignore_order: true, + }; + just_dec(&f, "40af0a0547d001"); + } + + #[test] + fn ack_frequency_ignore_error_error() { + let enc = Encoder::from_hex("40af0a0547d003"); // ignore_order of 3 + assert_eq!( + Frame::decode(&mut enc.as_decoder()).unwrap_err(), + Error::FrameEncodingError + ); + } + + /// Hopefully this test is eventually redundant. + #[test] + fn ack_frequency_zero_packets() { + let enc = Encoder::from_hex("40af0a000101"); // packets of 0 + assert_eq!( + Frame::decode(&mut enc.as_decoder()).unwrap_err(), + Error::FrameEncodingError + ); + } + + #[test] + fn datagram() { + // Without the length bit. + let f = Frame::Datagram { + data: &[1, 2, 3], + fill: true, + }; + + just_dec(&f, "4030010203"); + + // With the length bit. + let f = Frame::Datagram { + data: &[1, 2, 3], + fill: false, + }; + just_dec(&f, "403103010203"); + } + + #[test] + fn frame_decode_enforces_bound_on_ack_range() { + let mut e = Encoder::new(); + + e.encode_varint(FRAME_TYPE_ACK); + e.encode_varint(0u64); // largest acknowledged + e.encode_varint(0u64); // ACK delay + e.encode_varint(u32::MAX); // ACK range count = huge, but maybe available for allocation + + assert_eq!(Err(Error::TooMuchData), Frame::decode(&mut e.as_decoder())); + } +} diff --git a/third_party/rust/neqo-transport/src/lib.rs b/third_party/rust/neqo-transport/src/lib.rs new file mode 100644 index 0000000000..ecf7ee2f73 --- /dev/null +++ b/third_party/rust/neqo-transport/src/lib.rs @@ -0,0 +1,226 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
#![cfg_attr(feature = "deny-warnings", deny(warnings))]
#![warn(clippy::use_self)]

use neqo_common::qinfo;
use neqo_crypto::Error as CryptoError;

mod ackrate;
mod addr_valid;
mod cc;
mod cid;
mod connection;
mod crypto;
mod events;
mod fc;
mod frame;
mod pace;
mod packet;
mod path;
mod qlog;
mod quic_datagrams;
mod recovery;
// `recv_stream` is only made public when the "bench" feature is enabled.
#[cfg(feature = "bench")]
pub mod recv_stream;
#[cfg(not(feature = "bench"))]
mod recv_stream;
mod rtt;
mod send_stream;
mod sender;
pub mod server;
mod stats;
pub mod stream_id;
pub mod streams;
pub mod tparams;
mod tracking;
pub mod version;

pub use self::{
    cc::CongestionControlAlgorithm,
    cid::{
        ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef,
        EmptyConnectionIdGenerator, RandomConnectionIdGenerator,
    },
    connection::{
        params::{ConnectionParameters, ACK_RATIO_SCALE},
        Connection, Output, State, ZeroRttState,
    },
    events::{ConnectionEvent, ConnectionEvents},
    frame::CloseError,
    quic_datagrams::DatagramTracking,
    recv_stream::{RecvStreamStats, RECV_BUFFER_SIZE},
    send_stream::{SendStreamStats, SEND_BUFFER_SIZE},
    stats::Stats,
    stream_id::{StreamId, StreamType},
    version::Version,
};

/// Numeric error code, as returned by [`Error::code`] and carried by
/// [`Error::PeerError`].
pub type TransportError = u64;
// Codes that `Error::code` reports for `ApplicationError`,
// `CryptoBufferExceeded` and `KeysExhausted` respectively.
const ERROR_APPLICATION_CLOSE: TransportError = 12;
const ERROR_CRYPTO_BUFFER_EXCEEDED: TransportError = 13;
const ERROR_AEAD_LIMIT_REACHED: TransportError = 15;

/// Errors produced by this crate.
///
/// The variants before the "internal errors" marker each have their own
/// numeric code in [`Error::code`]; of the variants after it, only those that
/// `code()` lists explicitly do, and the rest share the catch-all code.
#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
pub enum Error {
    NoError,
    // Each time this error is returned, a different parameter is supplied.
    // This will be used to distinguish each occurrence of this error.
    // NOTE(review): the variant carries no parameter here, so this comment
    // looks stale -- confirm before relying on it.
    InternalError,
    ConnectionRefused,
    FlowControlError,
    StreamLimitError,
    StreamStateError,
    FinalSizeError,
    FrameEncodingError,
    TransportParameterError,
    ProtocolViolation,
    InvalidToken,
    ApplicationError,
    CryptoBufferExceeded,
    CryptoError(CryptoError),
    QlogError,
    CryptoAlert(u8),
    EchRetry(Vec<u8>),

    // All internal errors from here.  Please keep these sorted.
    AckedUnsentPacket,
    ConnectionIdLimitExceeded,
    ConnectionIdsExhausted,
    ConnectionState,
    DecodingFrame,
    DecryptError,
    DisabledVersion,
    HandshakeFailed,
    IdleTimeout,
    IntegerOverflow,
    InvalidInput,
    InvalidMigration,
    InvalidPacket,
    InvalidResumptionToken,
    InvalidRetry,
    InvalidStreamId,
    KeysDiscarded(crypto::CryptoSpace),
    /// Packet protection keys are exhausted.
    /// Also used when too many key updates have happened.
    KeysExhausted,
    /// Packet protection keys aren't available yet for the identified space.
    KeysPending(crypto::CryptoSpace),
    /// An attempt to update keys can be blocked if
    /// a packet sent with the current keys hasn't been acknowledged.
    KeyUpdateBlocked,
    NoAvailablePath,
    NoMoreData,
    NotConnected,
    PacketNumberOverlap,
    PeerApplicationError(AppError),
    PeerError(TransportError),
    StatelessReset,
    TooMuchData,
    UnexpectedMessage,
    UnknownConnectionId,
    UnknownFrameType,
    VersionNegotiation,
    WrongRole,
    NotAvailable,
}

impl Error {
    /// Returns the numeric [`TransportError`] code for this error.
    ///
    /// `NoError`, `IdleTimeout` and the two peer-originated variants map to 0.
    /// A TLS alert `a` maps to `0x100 + a`. Every variant that is not listed
    /// explicitly below is treated as an internal error and maps to 1.
    pub fn code(&self) -> TransportError {
        match self {
            Self::NoError
            | Self::IdleTimeout
            | Self::PeerError(_)
            | Self::PeerApplicationError(_) => 0,
            Self::ConnectionRefused => 2,
            Self::FlowControlError => 3,
            Self::StreamLimitError => 4,
            Self::StreamStateError => 5,
            Self::FinalSizeError => 6,
            Self::FrameEncodingError => 7,
            Self::TransportParameterError => 8,
            Self::ProtocolViolation => 10,
            Self::InvalidToken => 11,
            Self::KeysExhausted => ERROR_AEAD_LIMIT_REACHED,
            Self::ApplicationError => ERROR_APPLICATION_CLOSE,
            Self::NoAvailablePath => 16,
            Self::CryptoBufferExceeded => ERROR_CRYPTO_BUFFER_EXCEEDED,
            Self::CryptoAlert(a) => 0x100 + u64::from(*a),
            // As we have a special error code for ECH fallbacks, we lose the alert.
            // Send the server "ech_required" directly.
            Self::EchRetry(_) => 0x100 + 121,
            Self::VersionNegotiation => 0x53f8,
            // All the rest are internal errors.
            _ => 1,
        }
    }
}

/// Logs the crypto failure, then converts it: `EchRetry` keeps its payload as
/// [`Error::EchRetry`]; anything else is wrapped in [`Error::CryptoError`].
impl From<CryptoError> for Error {
    fn from(err: CryptoError) -> Self {
        qinfo!("Crypto operation failed {:?}", err);
        match err {
            CryptoError::EchRetry(config) => Self::EchRetry(config),
            _ => Self::CryptoError(err),
        }
    }
}

/// Any qlog error collapses to [`Error::QlogError`]; the detail is dropped.
impl From<::qlog::Error> for Error {
    fn from(_err: ::qlog::Error) -> Self {
        Self::QlogError
    }
}

/// Lets `?` turn a failed integer conversion into [`Error::IntegerOverflow`].
impl From<std::num::TryFromIntError> for Error {
    fn from(_: std::num::TryFromIntError) -> Self {
        Self::IntegerOverflow
    }
}

impl ::std::error::Error for Error {
    /// Only `CryptoError` exposes an underlying source error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::CryptoError(e) => Some(e),
            _ => None,
        }
    }
}

impl ::std::fmt::Display for Error {
    /// Formats as `Transport error: ` followed by the `Debug` form.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "Transport error: {self:?}")
    }
}

/// Numeric application error code.
pub type AppError = u64;

/// Either a transport-level [`Error`] or an application error code.
#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
pub enum ConnectionError {
    Transport(Error),
    Application(AppError),
}

impl ConnectionError {
    /// Returns the application error code, or `None` for a transport error.
    pub fn app_code(&self) -> Option<AppError> {
        match self {
            Self::Application(e) => Some(*e),
            Self::Transport(_) => None,
        }
    }
}

/// A transport close code becomes `Transport(Error::PeerError(code))`;
/// an application close code is passed through unchanged.
impl From<CloseError> for ConnectionError {
    fn from(err: CloseError) -> Self {
        match err {
            CloseError::Transport(c) => Self::Transport(Error::PeerError(c)),
            CloseError::Application(c) => Self::Application(c),
        }
    }
}

/// Result type used throughout the crate, with [`Error`] as the error.
pub type Res<T> = std::result::Result<T, Error>;
diff --git a/third_party/rust/neqo-transport/src/pace.rs b/third_party/rust/neqo-transport/src/pace.rs new file mode 100644 index 0000000000..e5214c1bc8 --- /dev/null +++ b/third_party/rust/neqo-transport/src/pace.rs @@ -0,0 +1,165 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms.
+ +// Pacer +#![deny(clippy::pedantic)] + +use std::{ + cmp::min, + convert::TryFrom, + fmt::{Debug, Display}, + time::{Duration, Instant}, +}; + +use neqo_common::qtrace; + +/// This value determines how much faster the pacer operates than the +/// congestion window. +/// +/// A value of 1 would cause all packets to be spaced over the entire RTT, +/// which is a little slow and might act as an additional restriction in +/// the case the congestion controller increases the congestion window. +/// This value spaces packets over half the congestion window, which matches +/// our current congestion controller, which double the window every RTT. +const PACER_SPEEDUP: usize = 2; + +/// A pacer that uses a leaky bucket. +pub struct Pacer { + /// Whether pacing is enabled. + enabled: bool, + /// The last update time. + t: Instant, + /// The maximum capacity, or burst size, in bytes. + m: usize, + /// The current capacity, in bytes. + c: usize, + /// The packet size or minimum capacity for sending, in bytes. + p: usize, +} + +impl Pacer { + /// Create a new `Pacer`. This takes the current time, the maximum burst size, + /// and the packet size. + /// + /// The value of `m` is the maximum capacity in bytes. `m` primes the pacer + /// with credit and determines the burst size. `m` must not exceed + /// the initial congestion window, but it should probably be lower. + /// + /// The value of `p` is the packet size in bytes, which determines the minimum + /// credit needed before a packet is sent. This should be a substantial + /// fraction of the maximum packet size, if not the packet size. + pub fn new(enabled: bool, now: Instant, m: usize, p: usize) -> Self { + assert!(m >= p, "maximum capacity has to be at least one packet"); + Self { + enabled, + t: now, + m, + c: m, + p, + } + } + + /// Determine when the next packet will be available based on the provided RTT + /// and congestion window. This doesn't update state. 
+ /// This returns a time, which could be in the past (this object doesn't know what + /// the current time is). + pub fn next(&self, rtt: Duration, cwnd: usize) -> Instant { + if self.c >= self.p { + qtrace!([self], "next {}/{:?} no wait = {:?}", cwnd, rtt, self.t); + self.t + } else { + // This is the inverse of the function in `spend`: + // self.t + rtt * (self.p - self.c) / (PACER_SPEEDUP * cwnd) + let r = rtt.as_nanos(); + let d = r.saturating_mul(u128::try_from(self.p - self.c).unwrap()); + let add = d / u128::try_from(cwnd * PACER_SPEEDUP).unwrap(); + let w = u64::try_from(add).map(Duration::from_nanos).unwrap_or(rtt); + let nxt = self.t + w; + qtrace!([self], "next {}/{:?} wait {:?} = {:?}", cwnd, rtt, w, nxt); + nxt + } + } + + /// Spend credit. This cannot fail; users of this API are expected to call + /// `next()` to determine when to spend. This takes the current time (`now`), + /// an estimate of the round trip time (`rtt`), the estimated congestion + /// window (`cwnd`), and the number of bytes that were sent (`count`). + pub fn spend(&mut self, now: Instant, rtt: Duration, cwnd: usize, count: usize) { + if !self.enabled { + self.t = now; + return; + } + + qtrace!([self], "spend {} over {}, {:?}", count, cwnd, rtt); + // Increase the capacity by: + // `(now - self.t) * PACER_SPEEDUP * cwnd / rtt` + // That is, the elapsed fraction of the RTT times rate that data is added. + let incr = now + .saturating_duration_since(self.t) + .as_nanos() + .saturating_mul(u128::try_from(cwnd * PACER_SPEEDUP).unwrap()) + .checked_div(rtt.as_nanos()) + .and_then(|i| usize::try_from(i).ok()) + .unwrap_or(self.m); + + // Add the capacity up to a limit of `self.m`, then subtract `count`. 
+ self.c = min(self.m, (self.c + incr).saturating_sub(count)); + self.t = now; + } +} + +impl Display for Pacer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Pacer {}/{}", self.c, self.p) + } +} + +impl Debug for Pacer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Pacer@{:?} {}/{}..{}", self.t, self.c, self.p, self.m) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use test_fixture::now; + + use super::Pacer; + + const RTT: Duration = Duration::from_millis(1000); + const PACKET: usize = 1000; + const CWND: usize = PACKET * 10; + + #[test] + fn even() { + let n = now(); + let mut p = Pacer::new(true, n, PACKET, PACKET); + assert_eq!(p.next(RTT, CWND), n); + p.spend(n, RTT, CWND, PACKET); + assert_eq!(p.next(RTT, CWND), n + (RTT / 20)); + } + + #[test] + fn backwards_in_time() { + let n = now(); + let mut p = Pacer::new(true, n + RTT, PACKET, PACKET); + assert_eq!(p.next(RTT, CWND), n + RTT); + // Now spend some credit in the past using a time machine. + p.spend(n, RTT, CWND, PACKET); + assert_eq!(p.next(RTT, CWND), n + (RTT / 20)); + } + + #[test] + fn pacing_disabled() { + let n = now(); + let mut p = Pacer::new(false, n, PACKET, PACKET); + assert_eq!(p.next(RTT, CWND), n); + p.spend(n, RTT, CWND, PACKET); + assert_eq!(p.next(RTT, CWND), n); + } +} diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs new file mode 100644 index 0000000000..ccfd212d5f --- /dev/null +++ b/third_party/rust/neqo-transport/src/packet/mod.rs @@ -0,0 +1,1457 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Encoding and decoding packets off the wire. 
+use std::{ + cmp::min, + convert::TryFrom, + fmt, + iter::ExactSizeIterator, + ops::{Deref, DerefMut, Range}, + time::Instant, +}; + +use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder}; +use neqo_crypto::random; + +use crate::{ + cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN}, + crypto::{CryptoDxState, CryptoSpace, CryptoStates}, + version::{Version, WireVersion}, + Error, Res, +}; + +pub const PACKET_BIT_LONG: u8 = 0x80; +const PACKET_BIT_SHORT: u8 = 0x00; +const PACKET_BIT_FIXED_QUIC: u8 = 0x40; +const PACKET_BIT_SPIN: u8 = 0x20; +const PACKET_BIT_KEY_PHASE: u8 = 0x04; + +const PACKET_HP_MASK_LONG: u8 = 0x0f; +const PACKET_HP_MASK_SHORT: u8 = 0x1f; + +const SAMPLE_SIZE: usize = 16; +const SAMPLE_OFFSET: usize = 4; +const MAX_PACKET_NUMBER_LEN: usize = 4; + +mod retry; + +pub type PacketNumber = u64; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + VersionNegotiation, + Initial, + Handshake, + ZeroRtt, + Retry, + Short, + OtherVersion, +} + +impl PacketType { + #[must_use] + fn from_byte(t: u8, v: Version) -> Self { + // Version2 adds one to the type, modulo 4 + match t.wrapping_sub(u8::from(v == Version::Version2)) & 3 { + 0 => Self::Initial, + 1 => Self::ZeroRtt, + 2 => Self::Handshake, + 3 => Self::Retry, + _ => panic!("packet type out of range"), + } + } + + #[must_use] + fn to_byte(self, v: Version) -> u8 { + let t = match self { + Self::Initial => 0, + Self::ZeroRtt => 1, + Self::Handshake => 2, + Self::Retry => 3, + _ => panic!("not a long header packet type"), + }; + // Version2 adds one to the type, modulo 4 + (t + u8::from(v == Version::Version2)) & 3 + } +} + +impl From<PacketType> for CryptoSpace { + fn from(v: PacketType) -> Self { + match v { + PacketType::Initial => Self::Initial, + PacketType::ZeroRtt => Self::ZeroRtt, + PacketType::Handshake => Self::Handshake, + PacketType::Short => Self::ApplicationData, + _ => panic!("shouldn't be here"), + } + } +} + +impl 
From<CryptoSpace> for PacketType { + fn from(cs: CryptoSpace) -> Self { + match cs { + CryptoSpace::Initial => Self::Initial, + CryptoSpace::ZeroRtt => Self::ZeroRtt, + CryptoSpace::Handshake => Self::Handshake, + CryptoSpace::ApplicationData => Self::Short, + } + } +} + +struct PacketBuilderOffsets { + /// The bits of the first octet that need masking. + first_byte_mask: u8, + /// The offset of the length field. + len: usize, + /// The location of the packet number field. + pn: Range<usize>, +} + +/// A packet builder that can be used to produce short packets and long packets. +/// This does not produce Retry or Version Negotiation. +pub struct PacketBuilder { + encoder: Encoder, + pn: PacketNumber, + header: Range<usize>, + offsets: PacketBuilderOffsets, + limit: usize, + /// Whether to pad the packet before construction. + padding: bool, +} + +impl PacketBuilder { + /// The minimum useful frame size. If space is less than this, we will claim to be full. + pub const MINIMUM_FRAME_SIZE: usize = 2; + + fn infer_limit(encoder: &Encoder) -> usize { + if encoder.capacity() > 64 { + encoder.capacity() + } else { + 2048 + } + } + + /// Start building a short header packet. + /// + /// This doesn't fail if there isn't enough space; instead it returns a builder that + /// has no available space left. This allows the caller to extract the encoder + /// and any packets that might have been added before as adding a packet header is + /// only likely to fail if there are other packets already written. + /// + /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get + /// the encoder back. + #[allow(clippy::reversed_empty_ranges)] + pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self { + let mut limit = Self::infer_limit(&encoder); + let header_start = encoder.len(); + // Check that there is enough space for the header. 
+ // 5 = 1 (first byte) + 4 (packet number) + if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() { + encoder + .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); + encoder.encode(dcid.as_ref()); + } else { + limit = 0; + } + Self { + encoder, + pn: u64::max_value(), + header: header_start..header_start, + offsets: PacketBuilderOffsets { + first_byte_mask: PACKET_HP_MASK_SHORT, + pn: 0..0, + len: 0, + }, + limit, + padding: false, + } + } + + /// Start building a long header packet. + /// For an Initial packet you will need to call initial_token(), + /// even if the token is empty. + /// + /// See `short()` for more on how to handle this in cases where there is no space. + #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range. + pub fn long( + mut encoder: Encoder, + pt: PacketType, + version: Version, + dcid: impl AsRef<[u8]>, + scid: impl AsRef<[u8]>, + ) -> Self { + let mut limit = Self::infer_limit(&encoder); + let header_start = encoder.len(); + // Check that there is enough space for the header. + // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number) + if limit > encoder.len() + && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len() + { + encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4); + encoder.encode_uint(4, version.wire_version()); + encoder.encode_vec(1, dcid.as_ref()); + encoder.encode_vec(1, scid.as_ref()); + } else { + limit = 0; + } + + Self { + encoder, + pn: u64::max_value(), + header: header_start..header_start, + offsets: PacketBuilderOffsets { + first_byte_mask: PACKET_HP_MASK_LONG, + pn: 0..0, + len: 0, + }, + limit, + padding: false, + } + } + + fn is_long(&self) -> bool { + self.as_ref()[self.header.start] & 0x80 == PACKET_BIT_LONG + } + + /// This stores a value that can be used as a limit. This does not cause + /// this limit to be enforced until encryption occurs. 
Prior to that, it + /// is only used voluntarily by users of the builder, through `remaining()`. + pub fn set_limit(&mut self, limit: usize) { + self.limit = limit; + } + + /// Get the current limit. + #[must_use] + pub fn limit(&mut self) -> usize { + self.limit + } + + /// How many bytes remain against the size limit for the builder. + #[must_use] + pub fn remaining(&self) -> usize { + self.limit.saturating_sub(self.encoder.len()) + } + + /// Returns true if the packet has no more space for frames. + #[must_use] + pub fn is_full(&self) -> bool { + // No useful frame is smaller than 2 bytes long. + self.limit < self.encoder.len() + Self::MINIMUM_FRAME_SIZE + } + + /// Adjust the limit to ensure that no more data is added. + pub fn mark_full(&mut self) { + self.limit = self.encoder.len(); + } + + /// Mark the packet as needing padding (or not). + pub fn enable_padding(&mut self, needs_padding: bool) { + self.padding = needs_padding; + } + + /// Maybe pad with "PADDING" frames. + /// Only does so if padding was needed and this is a short packet. + /// Returns true if padding was added. + pub fn pad(&mut self) -> bool { + if self.padding && !self.is_long() { + self.encoder.pad_to(self.limit, 0); + true + } else { + false + } + } + + /// Add unpredictable values for unprotected parts of the packet. + pub fn scramble(&mut self, quic_bit: bool) { + debug_assert!(self.len() > self.header.start); + let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 } + | if self.is_long() { 0 } else { PACKET_BIT_SPIN }; + let first = self.header.start; + self.encoder.as_mut()[first] ^= random(1)[0] & mask; + } + + /// For an Initial packet, encode the token. + /// If you fail to do this, then you will not get a valid packet. + pub fn initial_token(&mut self, token: &[u8]) { + if Encoder::vvec_len(token.len()) < self.remaining() { + self.encoder.encode_vvec(token); + } else { + self.limit = 0; + } + } + + /// Add a packet number of the given size. 
+ /// For a long header packet, this also inserts a dummy length. + /// The length is filled in after calling `build`. + /// Does nothing if there isn't 4 bytes available other than render this builder + /// unusable; if `remaining()` returns 0 at any point, call `abort()`. + pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) { + if self.remaining() < 4 { + self.limit = 0; + return; + } + + // Reserve space for a length in long headers. + if self.is_long() { + self.offsets.len = self.encoder.len(); + self.encoder.encode(&[0; 2]); + } + + // This allows the input to be >4, which is absurd, but we can eat that. + let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len); + debug_assert_ne!(pn_len, 0); + // Encode the packet number and save its offset. + let pn_offset = self.encoder.len(); + self.encoder.encode_uint(pn_len, pn); + self.offsets.pn = pn_offset..self.encoder.len(); + + // Now encode the packet number length and save the header length. + self.encoder.as_mut()[self.header.start] |= u8::try_from(pn_len - 1).unwrap(); + self.header.end = self.encoder.len(); + self.pn = pn; + } + + fn write_len(&mut self, expansion: usize) { + let len = self.encoder.len() - (self.offsets.len + 2) + expansion; + self.encoder.as_mut()[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8; + self.encoder.as_mut()[self.offsets.len + 1] = (len & 0xff) as u8; + } + + fn pad_for_crypto(&mut self, crypto: &mut CryptoDxState) { + // Make sure that there is enough data in the packet. + // The length of the packet number plus the payload length needs to + // be at least 4 (MAX_PACKET_NUMBER_LEN) plus any amount by which + // the header protection sample exceeds the AEAD expansion. + let crypto_pad = crypto.extra_padding(); + self.encoder.pad_to( + self.offsets.pn.start + MAX_PACKET_NUMBER_LEN + crypto_pad, + 0, + ); + } + + /// A lot of frames here are just a collection of varints. + /// This helper functions writes a frame like that safely, returning `true` if + /// a frame was written. 
+ pub fn write_varint_frame(&mut self, values: &[u64]) -> bool { + let write = self.remaining() + >= values + .iter() + .map(|&v| Encoder::varint_len(v)) + .sum::<usize>(); + if write { + for v in values { + self.encode_varint(*v); + } + debug_assert!(self.len() <= self.limit()); + }; + write + } + + /// Build the packet and return the encoder. + pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> { + if self.len() > self.limit { + qwarn!("Packet contents are more than the limit"); + debug_assert!(false); + return Err(Error::InternalError); + } + + self.pad_for_crypto(crypto); + if self.offsets.len > 0 { + self.write_len(crypto.expansion()); + } + + let hdr = &self.encoder.as_ref()[self.header.clone()]; + let body = &self.encoder.as_ref()[self.header.end..]; + qtrace!( + "Packet build pn={} hdr={} body={}", + self.pn, + hex(hdr), + hex(body) + ); + let ciphertext = crypto.encrypt(self.pn, hdr, body)?; + + // Calculate the mask. + let offset = SAMPLE_OFFSET - self.offsets.pn.len(); + assert!(offset + SAMPLE_SIZE <= ciphertext.len()); + let sample = &ciphertext[offset..offset + SAMPLE_SIZE]; + let mask = crypto.compute_mask(sample)?; + + // Apply the mask. + self.encoder.as_mut()[self.header.start] ^= mask[0] & self.offsets.first_byte_mask; + for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) { + self.encoder.as_mut()[j] ^= mask[i]; + } + + // Finally, cut off the plaintext and add back the ciphertext. + self.encoder.truncate(self.header.end); + self.encoder.encode(&ciphertext); + qtrace!("Packet built {}", hex(&self.encoder)); + Ok(self.encoder) + } + + /// Abort writing of this packet and return the encoder. + #[must_use] + pub fn abort(mut self) -> Encoder { + self.encoder.truncate(self.header.start); + self.encoder + } + + /// Work out if nothing was added after the header. + #[must_use] + pub fn packet_empty(&self) -> bool { + self.encoder.len() == self.header.end + } + + /// Make a retry packet. 
+ /// As this is a simple packet, this is just an associated function. + /// As Retry is odd (it has to be constructed with leading bytes), + /// this returns a [`Vec<u8>`] rather than building on an encoder. + pub fn retry( + version: Version, + dcid: &[u8], + scid: &[u8], + token: &[u8], + odcid: &[u8], + ) -> Res<Vec<u8>> { + let mut encoder = Encoder::default(); + encoder.encode_vec(1, odcid); + let start = encoder.len(); + encoder.encode_byte( + PACKET_BIT_LONG + | PACKET_BIT_FIXED_QUIC + | (PacketType::Retry.to_byte(version) << 4) + | (random(1)[0] & 0xf), + ); + encoder.encode_uint(4, version.wire_version()); + encoder.encode_vec(1, dcid); + encoder.encode_vec(1, scid); + debug_assert_ne!(token.len(), 0); + encoder.encode(token); + let tag = retry::use_aead(version, |aead| { + let mut buf = vec![0; aead.expansion()]; + Ok(aead.encrypt(0, encoder.as_ref(), &[], &mut buf)?.to_vec()) + })?; + encoder.encode(&tag); + let mut complete: Vec<u8> = encoder.into(); + Ok(complete.split_off(start)) + } + + /// Make a Version Negotiation packet. + pub fn version_negotiation( + dcid: &[u8], + scid: &[u8], + client_version: u32, + versions: &[Version], + ) -> Vec<u8> { + let mut encoder = Encoder::default(); + let mut grease = random(4); + // This will not include the "QUIC bit" sometimes. Intentionally. + encoder.encode_byte(PACKET_BIT_LONG | (grease[3] & 0x7f)); + encoder.encode(&[0; 4]); // Zero version == VN. + encoder.encode_vec(1, dcid); + encoder.encode_vec(1, scid); + + for v in versions { + encoder.encode_uint(4, v.wire_version()); + } + // Add a greased version, using the randomness already generated. + for g in &mut grease[..3] { + *g = *g & 0xf0 | 0x0a; + } + + // Ensure our greased version does not collide with the client version + // by making the last byte differ from the client initial. 
+ grease[3] = (client_version.wrapping_add(0x10) & 0xf0) as u8 | 0x0a; + encoder.encode(&grease[..4]); + + Vec::from(encoder) + } +} + +impl Deref for PacketBuilder { + type Target = Encoder; + + fn deref(&self) -> &Self::Target { + &self.encoder + } +} + +impl DerefMut for PacketBuilder { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.encoder + } +} + +impl From<PacketBuilder> for Encoder { + fn from(v: PacketBuilder) -> Self { + v.encoder + } +} + +/// PublicPacket holds information from packets that is public only. This allows for +/// processing of packets prior to decryption. +pub struct PublicPacket<'a> { + /// The packet type. + packet_type: PacketType, + /// The recovered destination connection ID. + dcid: ConnectionIdRef<'a>, + /// The source connection ID, if this is a long header packet. + scid: Option<ConnectionIdRef<'a>>, + /// Any token that is included in the packet (Retry always has a token; Initial sometimes + /// does). This is empty when there is no token. + token: &'a [u8], + /// The size of the header, not including the packet number. + header_len: usize, + /// Protocol version, if present in header. + version: Option<WireVersion>, + /// A reference to the entire packet, including the header. + data: &'a [u8], +} + +impl<'a> PublicPacket<'a> { + fn opt<T>(v: Option<T>) -> Res<T> { + if let Some(v) = v { + Ok(v) + } else { + Err(Error::NoMoreData) + } + } + + /// Decode the type-specific portions of a long header. + /// This includes reading the length and the remainder of the packet. + /// Returns a tuple of any token and the length of the header. 
+ fn decode_long( + decoder: &mut Decoder<'a>, + packet_type: PacketType, + version: Version, + ) -> Res<(&'a [u8], usize)> { + if packet_type == PacketType::Retry { + let header_len = decoder.offset(); + let expansion = retry::expansion(version); + let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?; + if token.is_empty() { + return Err(Error::InvalidPacket); + } + Self::opt(decoder.decode(expansion))?; + return Ok((token, header_len)); + } + let token = if packet_type == PacketType::Initial { + Self::opt(decoder.decode_vvec())? + } else { + &[] + }; + let len = Self::opt(decoder.decode_varint())?; + let header_len = decoder.offset(); + let _body = Self::opt(decoder.decode(usize::try_from(len)?))?; + Ok((token, header_len)) + } + + /// Decode the common parts of a packet. This provides minimal parsing and validation. + /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram. + pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> { + let mut decoder = Decoder::new(data); + let first = Self::opt(decoder.decode_byte())?; + + if first & 0x80 == PACKET_BIT_SHORT { + // Conveniently, this also guarantees that there is enough space + // for a connection ID of any size. + if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { + return Err(Error::InvalidPacket); + } + let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?; + if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { + return Err(Error::InvalidPacket); + } + let header_len = decoder.offset(); + return Ok(( + Self { + packet_type: PacketType::Short, + dcid, + scid: None, + token: &[], + header_len, + version: None, + data, + }, + &[], + )); + } + + // Generic long header. 
+ let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?).unwrap(); + let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + + // Version negotiation. + if version == 0 { + return Ok(( + Self { + packet_type: PacketType::VersionNegotiation, + dcid, + scid: Some(scid), + token: &[], + header_len: decoder.offset(), + version: None, + data, + }, + &[], + )); + } + + // Check that this is a long header from a supported version. + let Ok(version) = Version::try_from(version) else { + return Ok(( + Self { + packet_type: PacketType::OtherVersion, + dcid, + scid: Some(scid), + token: &[], + header_len: decoder.offset(), + version: Some(version), + data, + }, + &[], + )); + }; + + if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN { + return Err(Error::InvalidPacket); + } + let packet_type = PacketType::from_byte((first >> 4) & 3, version); + + // The type-specific code includes a token. This consumes the remainder of the packet. + let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?; + let end = data.len() - decoder.remaining(); + let (data, remainder) = data.split_at(end); + Ok(( + Self { + packet_type, + dcid, + scid: Some(scid), + token, + header_len, + version: Some(version.wire_version()), + data, + }, + remainder, + )) + } + + /// Validate the given packet as though it were a retry. 
+ pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool { + if self.packet_type != PacketType::Retry { + return false; + } + let version = self.version().unwrap(); + let expansion = retry::expansion(version); + if self.data.len() <= expansion { + return false; + } + let (header, tag) = self.data.split_at(self.data.len() - expansion); + let mut encoder = Encoder::with_capacity(self.data.len()); + encoder.encode_vec(1, odcid); + encoder.encode(header); + retry::use_aead(version, |aead| { + let mut buf = vec![0; expansion]; + Ok(aead.decrypt(0, encoder.as_ref(), tag, &mut buf)?.is_empty()) + }) + .unwrap_or(false) + } + + pub fn is_valid_initial(&self) -> bool { + // Packet has to be an initial, with a DCID of 8 bytes, or a token. + // Note: the Server class validates the token and checks the length. + self.packet_type == PacketType::Initial + && (self.dcid().len() >= 8 || !self.token.is_empty()) + } + + pub fn packet_type(&self) -> PacketType { + self.packet_type + } + + pub fn dcid(&self) -> ConnectionIdRef<'a> { + self.dcid + } + + pub fn scid(&self) -> ConnectionIdRef<'a> { + self.scid + .expect("should only be called for long header packets") + } + + pub fn token(&self) -> &'a [u8] { + self.token + } + + pub fn version(&self) -> Option<Version> { + self.version.and_then(|v| Version::try_from(v).ok()) + } + + pub fn wire_version(&self) -> WireVersion { + debug_assert!(self.version.is_some()); + self.version.unwrap_or(0) + } + + pub fn len(&self) -> usize { + self.data.len() + } + + fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber { + let window = 1_u64 << (w * 8); + let candidate = (expected & !(window - 1)) | pn; + if candidate + (window / 2) <= expected { + candidate + window + } else if candidate > expected + (window / 2) { + match candidate.checked_sub(window) { + Some(pn_sub) => pn_sub, + None => candidate, + } + } else { + candidate + } + } + + /// Decrypt the header of the packet. 
+ fn decrypt_header( + &self, + crypto: &mut CryptoDxState, + ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> { + assert_ne!(self.packet_type, PacketType::Retry); + assert_ne!(self.packet_type, PacketType::VersionNegotiation); + + qtrace!( + "unmask hdr={}", + hex(&self.data[..self.header_len + SAMPLE_OFFSET]) + ); + + let sample_offset = self.header_len + SAMPLE_OFFSET; + let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE)) + { + crypto.compute_mask(sample) + } else { + Err(Error::NoMoreData) + }?; + + // Un-mask the leading byte. + let bits = if self.packet_type == PacketType::Short { + PACKET_HP_MASK_SHORT + } else { + PACKET_HP_MASK_LONG + }; + let first_byte = self.data[0] ^ (mask[0] & bits); + + // Make a copy of the header to work on. + let mut hdrbytes = self.data[..self.header_len + 4].to_vec(); + hdrbytes[0] = first_byte; + + // Unmask the PN. + let mut pn_encoded: u64 = 0; + for i in 0..MAX_PACKET_NUMBER_LEN { + hdrbytes[self.header_len + i] ^= mask[1 + i]; + pn_encoded <<= 8; + pn_encoded += u64::from(hdrbytes[self.header_len + i]); + } + + // Now decode the packet number length and apply it, hopefully in constant time. + let pn_len = usize::from((first_byte & 0x3) + 1); + hdrbytes.truncate(self.header_len + pn_len); + pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len); + + qtrace!("unmasked hdr={}", hex(&hdrbytes)); + + let key_phase = self.packet_type == PacketType::Short + && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE; + let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len); + Ok(( + key_phase, + pn, + hdrbytes, + &self.data[self.header_len + pn_len..], + )) + } + + pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> { + let cspace: CryptoSpace = self.packet_type.into(); + // When we don't have a version, the crypto code doesn't need a version + // for lookup, so use the default, but fix it up if decryption succeeds. 
+ let version = self.version().unwrap_or_default(); + // This has to work in two stages because we need to remove header protection + // before picking the keys to use. + if let Some(rx) = crypto.rx_hp(version, cspace) { + // Note that this will dump early, which creates a side-channel. + // This is OK in this case because we the only reason this can + // fail is if the cryptographic module is bad or the packet is + // too small (which is public information). + let (key_phase, pn, header, body) = self.decrypt_header(rx)?; + qtrace!([rx], "decoded header: {:?}", header); + let rx = crypto.rx(version, cspace, key_phase).unwrap(); + let version = rx.version(); // Version fixup; see above. + let d = rx.decrypt(pn, &header, body)?; + // If this is the first packet ever successfully decrypted + // using `rx`, make sure to initiate a key update. + if rx.needs_update() { + crypto.key_update_received(release_at)?; + } + crypto.check_pn_overlap()?; + Ok(DecryptedPacket { + version, + pt: self.packet_type, + pn, + data: d, + }) + } else if crypto.rx_pending(cspace) { + Err(Error::KeysPending(cspace)) + } else { + qtrace!("keys for {:?} already discarded", cspace); + Err(Error::KeysDiscarded(cspace)) + } + } + + pub fn supported_versions(&self) -> Res<Vec<WireVersion>> { + assert_eq!(self.packet_type, PacketType::VersionNegotiation); + let mut decoder = Decoder::new(&self.data[self.header_len..]); + let mut res = Vec::new(); + while decoder.remaining() > 0 { + let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?; + res.push(version); + } + Ok(res) + } +} + +impl fmt::Debug for PublicPacket<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{:?}: {} {}", + self.packet_type(), + hex_with_len(&self.data[..self.header_len]), + hex_with_len(&self.data[self.header_len..]) + ) + } +} + +pub struct DecryptedPacket { + version: Version, + pt: PacketType, + pn: PacketNumber, + data: Vec<u8>, +} + +impl DecryptedPacket { + pub fn 
version(&self) -> Version { + self.version + } + + pub fn packet_type(&self) -> PacketType { + self.pt + } + + pub fn pn(&self) -> PacketNumber { + self.pn + } +} + +impl Deref for DecryptedPacket { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.data[..] + } +} + +#[cfg(all(test, not(feature = "fuzzing")))] +mod tests { + use neqo_common::Encoder; + use test_fixture::{fixture_init, now}; + + use super::*; + use crate::{ + crypto::{CryptoDxState, CryptoStates}, + EmptyConnectionIdGenerator, RandomConnectionIdGenerator, Version, + }; + + const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5]; + + /// This is a connection ID manager, which is only used for decoding short header packets. + fn cid_mgr() -> RandomConnectionIdGenerator { + RandomConnectionIdGenerator::new(SERVER_CID.len()) + } + + const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[ + 0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03, + 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd, + 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04, + 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, + 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14, + 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, + 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04, + ]; + const SAMPLE_INITIAL: &[u8] = &[ + 0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x00, 0x40, 0x75, 0xc0, 0xd9, 0x5a, 0x48, 0x2c, 0xd0, 0x99, 0x1c, 0xd2, 0x5b, 0x0a, 0xac, + 0x40, 0x6a, 0x58, 0x16, 0xb6, 0x39, 0x41, 0x00, 0xf3, 0x7a, 0x1c, 0x69, 0x79, 0x75, 0x54, + 0x78, 0x0b, 0xb3, 0x8c, 0xc5, 0xa9, 0x9f, 0x5e, 0xde, 0x4c, 0xf7, 0x3c, 0x3e, 0xc2, 0x49, + 0x3a, 0x18, 
0x39, 0xb3, 0xdb, 0xcb, 0xa3, 0xf6, 0xea, 0x46, 0xc5, 0xb7, 0x68, 0x4d, 0xf3, + 0x54, 0x8e, 0x7d, 0xde, 0xb9, 0xc3, 0xbf, 0x9c, 0x73, 0xcc, 0x3f, 0x3b, 0xde, 0xd7, 0x4b, + 0x56, 0x2b, 0xfb, 0x19, 0xfb, 0x84, 0x02, 0x2f, 0x8e, 0xf4, 0xcd, 0xd9, 0x37, 0x95, 0xd7, + 0x7d, 0x06, 0xed, 0xbb, 0x7a, 0xaf, 0x2f, 0x58, 0x89, 0x18, 0x50, 0xab, 0xbd, 0xca, 0x3d, + 0x20, 0x39, 0x8c, 0x27, 0x64, 0x56, 0xcb, 0xc4, 0x21, 0x58, 0x40, 0x7d, 0xd0, 0x74, 0xee, + ]; + + #[test] + fn sample_server_initial() { + fixture_init(); + let mut prot = CryptoDxState::test_default(); + + // The spec uses PN=1, but our crypto refuses to skip packet numbers. + // So burn an encryption: + let burn = prot.encrypt(0, &[], &[]).expect("burn OK"); + assert_eq!(burn.len(), prot.expansion()); + + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Initial, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(SERVER_CID), + ); + builder.initial_token(&[]); + builder.pn(1, 2); + builder.encode(SAMPLE_INITIAL_PAYLOAD); + let packet = builder.build(&mut prot).expect("build"); + assert_eq!(packet.as_ref(), SAMPLE_INITIAL); + } + + #[test] + fn decrypt_initial() { + const EXTRA: &[u8] = &[0xce; 33]; + + fixture_init(); + let mut padded = SAMPLE_INITIAL.to_vec(); + padded.extend_from_slice(EXTRA); + let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap(); + assert_eq!(packet.packet_type(), PacketType::Initial); + assert_eq!(&packet.dcid()[..], &[] as &[u8]); + assert_eq!(&packet.scid()[..], SERVER_CID); + assert!(packet.token().is_empty()); + assert_eq!(remainder, EXTRA); + + let decrypted = packet + .decrypt(&mut CryptoStates::test_default(), now()) + .unwrap(); + assert_eq!(decrypted.pn(), 1); + } + + #[test] + fn disallow_long_dcid() { + let mut enc = Encoder::new(); + enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC); + enc.encode_uint(4, Version::default().wire_version()); + enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]); + 
enc.encode_vec(1, &[]); + enc.encode(&[0xff; 40]); // junk + + assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err()); + } + + #[test] + fn disallow_long_scid() { + let mut enc = Encoder::new(); + enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC); + enc.encode_uint(4, Version::default().wire_version()); + enc.encode_vec(1, &[]); + enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]); + enc.encode(&[0xff; 40]); // junk + + assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err()); + } + + const SAMPLE_SHORT: &[u8] = &[ + 0x40, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0xf4, 0xa8, 0x30, 0x39, 0xc4, 0x7d, + 0x99, 0xe3, 0x94, 0x1c, 0x9b, 0xb9, 0x7a, 0x30, 0x1d, 0xd5, 0x8f, 0xf3, 0xdd, 0xa9, + ]; + const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3]; + + #[test] + fn build_short() { + fixture_init(); + let mut builder = + PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); + builder.pn(0, 1); + builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. + let packet = builder + .build(&mut CryptoDxState::test_default()) + .expect("build"); + assert_eq!(packet.as_ref(), SAMPLE_SHORT); + } + + #[test] + fn scramble_short() { + fixture_init(); + let mut firsts = Vec::new(); + for _ in 0..64 { + let mut builder = + PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); + builder.scramble(true); + builder.pn(0, 1); + firsts.push(builder.as_ref()[0]); + } + let is_set = |bit| move |v| v & bit == bit; + // There should be at least one value with the QUIC bit set: + assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC))); + // ... but not all: + assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC))); + // There should be at least one value with the spin bit set: + assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN))); + // ... 
but not all: + assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN))); + } + + #[test] + fn decode_short() { + fixture_init(); + let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap(); + assert_eq!(packet.packet_type(), PacketType::Short); + assert!(remainder.is_empty()); + let decrypted = packet + .decrypt(&mut CryptoStates::test_default(), now()) + .unwrap(); + assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD); + } + + /// By telling the decoder that the connection ID is shorter than it really is, we get a + /// decryption error. + #[test] + fn decode_short_bad_cid() { + fixture_init(); + let (packet, remainder) = PublicPacket::decode( + SAMPLE_SHORT, + &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1), + ) + .unwrap(); + assert_eq!(packet.packet_type(), PacketType::Short); + assert!(remainder.is_empty()); + assert!(packet + .decrypt(&mut CryptoStates::test_default(), now()) + .is_err()); + } + + /// Saying that the connection ID is longer causes the initial decode to fail. + #[test] + fn decode_short_long_cid() { + assert!(PublicPacket::decode( + SAMPLE_SHORT, + &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1) + ) + .is_err()); + } + + #[test] + fn build_two() { + fixture_init(); + let mut prot = CryptoDxState::test_default(); + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + Version::default(), + &ConnectionId::from(SERVER_CID), + &ConnectionId::from(CLIENT_CID), + ); + builder.pn(0, 1); + builder.encode(&[0; 3]); + let encoder = builder.build(&mut prot).expect("build"); + assert_eq!(encoder.len(), 45); + let first = encoder.clone(); + + let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID)); + builder.pn(1, 3); + builder.encode(&[0]); // Minimal size (packet number is big enough). 
+ let encoder = builder.build(&mut prot).expect("build"); + assert_eq!( + first.as_ref(), + &encoder.as_ref()[..first.len()], + "the first packet should be a prefix" + ); + assert_eq!(encoder.len(), 45 + 29); + } + + #[test] + fn build_long() { + const EXPECTED: &[u8] = &[ + 0xe4, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x40, 0x14, 0xfb, 0xa9, 0x32, 0x3a, 0xf8, + 0xbb, 0x18, 0x63, 0xc6, 0xbd, 0x78, 0x0e, 0xba, 0x0c, 0x98, 0x65, 0x58, 0xc9, 0x62, + 0x31, + ]; + + fixture_init(); + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(&[][..]), + ); + builder.pn(0, 1); + builder.encode(&[1, 2, 3]); + let packet = builder.build(&mut CryptoDxState::test_default()).unwrap(); + assert_eq!(packet.as_ref(), EXPECTED); + } + + #[test] + fn scramble_long() { + fixture_init(); + let mut found_unset = false; + let mut found_set = false; + for _ in 1..64 { + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(&[][..]), + ); + builder.pn(0, 1); + builder.scramble(true); + if (builder.as_ref()[0] & PACKET_BIT_FIXED_QUIC) == 0 { + found_unset = true; + } else { + found_set = true; + } + } + assert!(found_unset); + assert!(found_set); + } + + #[test] + fn build_abort() { + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Initial, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(SERVER_CID), + ); + assert_ne!(builder.remaining(), 0); + builder.initial_token(&[]); + assert_ne!(builder.remaining(), 0); + builder.pn(1, 2); + assert_ne!(builder.remaining(), 0); + let encoder = builder.abort(); + assert!(encoder.is_empty()); + } + + #[test] + fn build_insufficient_space() { + fixture_init(); + + let mut builder = PacketBuilder::short( + Encoder::with_capacity(100), + true, + &ConnectionId::from(SERVER_CID), + ); + builder.pn(0, 1); + // 
Pad, but not up to the full capacity. Leave enough space for the + // AEAD expansion and some extra, but not for an entire long header. + builder.set_limit(75); + builder.enable_padding(true); + assert!(builder.pad()); + let encoder = builder.build(&mut CryptoDxState::test_default()).unwrap(); + let encoder_copy = encoder.clone(); + + let builder = PacketBuilder::long( + encoder, + PacketType::Initial, + Version::default(), + &ConnectionId::from(SERVER_CID), + &ConnectionId::from(SERVER_CID), + ); + assert_eq!(builder.remaining(), 0); + assert_eq!(builder.abort(), encoder_copy); + } + + const SAMPLE_RETRY_V2: &[u8] = &[ + 0xcf, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc8, 0x64, 0x6c, 0xe8, 0xbf, 0xe3, 0x39, 0x52, 0xd9, 0x55, + 0x54, 0x36, 0x65, 0xdc, 0xc7, 0xb6, + ]; + + const SAMPLE_RETRY_V1: &[u8] = &[ + 0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58, + 0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba, + ]; + + const SAMPLE_RETRY_29: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a, + 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49, + ]; + + const SAMPLE_RETRY_30: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94, + 0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61, + ]; + + const SAMPLE_RETRY_31: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1, + 0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86, + ]; + + const SAMPLE_RETRY_32: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 
0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e, + 0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85, + ]; + + const RETRY_TOKEN: &[u8] = b"token"; + + fn build_retry_single(version: Version, sample_retry: &[u8]) { + fixture_init(); + let retry = + PacketBuilder::retry(version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap(); + + let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap(); + assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); + assert!(remainder.is_empty()); + + // The builder adds randomness, which makes expectations hard. + // So only do a full check when that randomness matches up. + if retry[0] == sample_retry[0] { + assert_eq!(&retry, &sample_retry); + } else { + // Otherwise, just check that the header is OK. + assert_eq!( + retry[0] & 0xf0, + 0xc0 | (PacketType::Retry.to_byte(version) << 4) + ); + let header_range = 1..retry.len() - 16; + assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]); + } + } + + #[test] + fn build_retry_v2() { + build_retry_single(Version::Version2, SAMPLE_RETRY_V2); + } + + #[test] + fn build_retry_v1() { + build_retry_single(Version::Version1, SAMPLE_RETRY_V1); + } + + #[test] + fn build_retry_29() { + build_retry_single(Version::Draft29, SAMPLE_RETRY_29); + } + + #[test] + fn build_retry_30() { + build_retry_single(Version::Draft30, SAMPLE_RETRY_30); + } + + #[test] + fn build_retry_31() { + build_retry_single(Version::Draft31, SAMPLE_RETRY_31); + } + + #[test] + fn build_retry_32() { + build_retry_single(Version::Draft32, SAMPLE_RETRY_32); + } + + #[test] + fn build_retry_multiple() { + // Run the build_retry test a few times. + // Odds are approximately 1 in 8 that the full comparison doesn't happen + // for a given version. 
+ for _ in 0..32 { + build_retry_v2(); + build_retry_v1(); + build_retry_29(); + build_retry_30(); + build_retry_31(); + build_retry_32(); + } + } + + fn decode_retry(version: Version, sample_retry: &[u8]) { + fixture_init(); + let (packet, remainder) = + PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap(); + assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); + assert_eq!(Some(version), packet.version()); + assert!(packet.dcid().is_empty()); + assert_eq!(&packet.scid()[..], SERVER_CID); + assert_eq!(packet.token(), RETRY_TOKEN); + assert!(remainder.is_empty()); + } + + #[test] + fn decode_retry_v2() { + decode_retry(Version::Version2, SAMPLE_RETRY_V2); + } + + #[test] + fn decode_retry_v1() { + decode_retry(Version::Version1, SAMPLE_RETRY_V1); + } + + #[test] + fn decode_retry_29() { + decode_retry(Version::Draft29, SAMPLE_RETRY_29); + } + + #[test] + fn decode_retry_30() { + decode_retry(Version::Draft30, SAMPLE_RETRY_30); + } + + #[test] + fn decode_retry_31() { + decode_retry(Version::Draft31, SAMPLE_RETRY_31); + } + + #[test] + fn decode_retry_32() { + decode_retry(Version::Draft32, SAMPLE_RETRY_32); + } + + /// Check some packets that are clearly not valid Retry packets. 
+ #[test] + fn invalid_retry() { + fixture_init(); + let cid_mgr = RandomConnectionIdGenerator::new(5); + let odcid = ConnectionId::from(CLIENT_CID); + + assert!(PublicPacket::decode(&[], &cid_mgr).is_err()); + + let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_V1, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(packet.is_valid_retry(&odcid)); + + let mut damaged_retry = SAMPLE_RETRY_V1.to_vec(); + let last = damaged_retry.len() - 1; + damaged_retry[last] ^= 66; + let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(!packet.is_valid_retry(&odcid)); + + damaged_retry.truncate(last); + let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(!packet.is_valid_retry(&odcid)); + + // An invalid token should be rejected sooner. + damaged_retry.truncate(last - 4); + assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + + damaged_retry.truncate(last - 1); + assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + } + + const SAMPLE_VN: &[u8] = &[ + 0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08, + 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x00, 0x00, + 0x01, 0xff, 0x00, 0x00, 0x20, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00, + 0x00, 0x1d, 0x0a, 0x0a, 0x0a, 0x0a, + ]; + + #[test] + fn build_vn() { + fixture_init(); + let mut vn = + PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID, 0x0a0a0a0a, &Version::all()); + // Erase randomness from greasing... 
+ assert_eq!(vn.len(), SAMPLE_VN.len()); + vn[0] &= 0x80; + for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) { + *v &= 0x0f; + } + assert_eq!(&vn, &SAMPLE_VN); + } + + #[test] + fn vn_do_not_repeat_client_grease() { + fixture_init(); + let vn = + PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID, 0x0a0a0a0a, &Version::all()); + assert_ne!(&vn[SAMPLE_VN.len() - 4..], &[0x0a, 0x0a, 0x0a, 0x0a]); + } + + #[test] + fn parse_vn() { + let (packet, remainder) = + PublicPacket::decode(SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(&packet.dcid[..], SERVER_CID); + assert!(packet.scid.is_some()); + assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID); + } + + /// A Version Negotiation packet can have a long connection ID. + #[test] + fn parse_vn_big_cid() { + const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1]; + const BIG_SCID: &[u8] = &[0xee; 255]; + + let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]); + enc.encode_vec(1, BIG_DCID); + enc.encode_vec(1, BIG_SCID); + enc.encode_uint(4, 0x1a2a_3a4a_u64); + enc.encode_uint(4, Version::default().wire_version()); + enc.encode_uint(4, 0x5a6a_7a8a_u64); + + let (packet, remainder) = + PublicPacket::decode(enc.as_ref(), &EmptyConnectionIdGenerator::default()).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(&packet.dcid[..], BIG_DCID); + assert!(packet.scid.is_some()); + assert_eq!(&packet.scid.unwrap()[..], BIG_SCID); + } + + #[test] + fn decode_pn() { + // When the expected value is low, the value doesn't go negative. 
+ assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff); + assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100); + assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2); + assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff); + assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe); + + // This is invalid by spec, as we are expected to check for overflow around 2^62-1, + // but we don't need to worry about overflow + // and hitting this is basically impossible in practice. + assert_eq!( + PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4), + 0x4000_0000_0000_0002 + ); + } + + #[test] + fn chacha20_sample() { + const PACKET: &[u8] = &[ + 0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57, + 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb, + ]; + fixture_init(); + let (packet, slice) = + PublicPacket::decode(PACKET, &EmptyConnectionIdGenerator::default()).unwrap(); + assert!(slice.is_empty()); + let decrypted = packet + .decrypt(&mut CryptoStates::test_chacha(), now()) + .unwrap(); + assert_eq!(decrypted.packet_type(), PacketType::Short); + assert_eq!(decrypted.pn(), 654_360_564); + assert_eq!(&decrypted[..], &[0x01]); + } +} diff --git a/third_party/rust/neqo-transport/src/packet/retry.rs b/third_party/rust/neqo-transport/src/packet/retry.rs new file mode 100644 index 0000000000..004e9de6e7 --- /dev/null +++ b/third_party/rust/neqo-transport/src/packet/retry.rs @@ -0,0 +1,59 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +#![deny(clippy::pedantic)] + +use std::cell::RefCell; + +use neqo_common::qerror; +use neqo_crypto::{hkdf, Aead, TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3}; + +use crate::{version::Version, Error, Res}; + +/// The AEAD used for Retry is fixed, so use thread local storage. +fn make_aead(version: Version) -> Aead { + #[cfg(debug_assertions)] + ::neqo_crypto::assert_initialized(); + + let secret = hkdf::import_key(TLS_VERSION_1_3, version.retry_secret()).unwrap(); + Aead::new( + false, + TLS_VERSION_1_3, + TLS_AES_128_GCM_SHA256, + &secret, + version.label_prefix(), + ) + .unwrap() +} +thread_local!(static RETRY_AEAD_29: RefCell<Aead> = RefCell::new(make_aead(Version::Draft29))); +thread_local!(static RETRY_AEAD_V1: RefCell<Aead> = RefCell::new(make_aead(Version::Version1))); +thread_local!(static RETRY_AEAD_V2: RefCell<Aead> = RefCell::new(make_aead(Version::Version2))); + +/// Run a function with the appropriate Retry AEAD. +pub fn use_aead<F, T>(version: Version, f: F) -> Res<T> +where + F: FnOnce(&Aead) -> Res<T>, +{ + match version { + Version::Version2 => &RETRY_AEAD_V2, + Version::Version1 => &RETRY_AEAD_V1, + Version::Draft29 | Version::Draft30 | Version::Draft31 | Version::Draft32 => &RETRY_AEAD_29, + } + .try_with(|aead| f(&aead.borrow())) + .map_err(|e| { + qerror!("Unable to access Retry AEAD: {:?}", e); + Error::InternalError + })? +} + +/// Determine how large the expansion is for a given key. 
+pub fn expansion(version: Version) -> usize { + if let Ok(ex) = use_aead(version, |aead| Ok(aead.expansion())) { + ex + } else { + panic!("Unable to access Retry AEAD") + } +} diff --git a/third_party/rust/neqo-transport/src/path.rs b/third_party/rust/neqo-transport/src/path.rs new file mode 100644 index 0000000000..d6920c8d94 --- /dev/null +++ b/third_party/rust/neqo-transport/src/path.rs @@ -0,0 +1,1032 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![deny(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +use std::{ + cell::RefCell, + convert::TryFrom, + fmt::{self, Display}, + mem, + net::{IpAddr, SocketAddr}, + rc::Rc, + time::{Duration, Instant}, +}; + +use neqo_common::{hex, qdebug, qinfo, qlog::NeqoQlog, qtrace, Datagram, Encoder, IpTos}; +use neqo_crypto::random; + +use crate::{ + ackrate::{AckRate, PeerAckDelay}, + cc::CongestionControlAlgorithm, + cid::{ConnectionId, ConnectionIdRef, ConnectionIdStore, RemoteConnectionIdEntry}, + frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID}, + packet::PacketBuilder, + recovery::RecoveryToken, + rtt::RttEstimate, + sender::PacketSender, + stats::FrameStats, + tracking::{PacketNumberSpace, SentPacket}, + Stats, +}; + +/// This is the MTU that we assume when using IPv6. +/// We use this size for Initial packets, so we don't need to worry about probing for support. +/// If the path doesn't support this MTU, we will assume that it doesn't support QUIC. +/// +/// This is a multiple of 16 greater than the largest possible short header (1 + 20 + 4). +pub const PATH_MTU_V6: usize = 1337; +/// The path MTU for IPv4 can be 20 bytes larger than for v6. 
+pub const PATH_MTU_V4: usize = PATH_MTU_V6 + 20; +/// The number of times that a path will be probed before it is considered failed. +const MAX_PATH_PROBES: usize = 3; +/// The maximum number of paths that `Paths` will track. +const MAX_PATHS: usize = 15; + +pub type PathRef = Rc<RefCell<Path>>; + +/// A collection for network paths. +/// This holds a collection of paths that have been used for sending or +/// receiving, plus an additional "temporary" path that is held only while +/// processing a packet. +/// This structure limits its storage and will forget about paths if it +/// is exposed to too many paths. +#[derive(Debug, Default)] +pub struct Paths { + /// All of the paths. All of these paths will be permanent. + #[allow(unknown_lints)] // available with Rust v1.75 + #[allow(clippy::struct_field_names)] + paths: Vec<PathRef>, + /// This is the primary path. This will only be `None` initially, so + /// care needs to be taken regarding that only during the handshake. + /// This path will also be in `paths`. + primary: Option<PathRef>, + + /// The path that we would prefer to migrate to. + migration_target: Option<PathRef>, + + /// Connection IDs that need to be retired. + to_retire: Vec<u64>, + + /// QLog handler. + qlog: NeqoQlog, +} + +impl Paths { + /// Find the path for the given addresses. + /// This might be a temporary path. + pub fn find_path( + &self, + local: SocketAddr, + remote: SocketAddr, + cc: CongestionControlAlgorithm, + pacing: bool, + now: Instant, + ) -> PathRef { + self.paths + .iter() + .find_map(|p| { + if p.borrow().received_on(local, remote, false) { + Some(Rc::clone(p)) + } else { + None + } + }) + .unwrap_or_else(|| { + let mut p = Path::temporary(local, remote, cc, pacing, self.qlog.clone(), now); + if let Some(primary) = self.primary.as_ref() { + p.prime_rtt(primary.borrow().rtt()); + } + Rc::new(RefCell::new(p)) + }) + } + + /// Find the path, but allow for rebinding. 
That matches the pair of addresses + /// to paths that match the remote address only based on IP addres, not port. + /// We use this when the other side migrates to skip address validation and + /// creating a new path. + pub fn find_path_with_rebinding( + &self, + local: SocketAddr, + remote: SocketAddr, + cc: CongestionControlAlgorithm, + pacing: bool, + now: Instant, + ) -> PathRef { + self.paths + .iter() + .find_map(|p| { + if p.borrow().received_on(local, remote, false) { + Some(Rc::clone(p)) + } else { + None + } + }) + .or_else(|| { + self.paths.iter().find_map(|p| { + if p.borrow().received_on(local, remote, true) { + Some(Rc::clone(p)) + } else { + None + } + }) + }) + .unwrap_or_else(|| { + Rc::new(RefCell::new(Path::temporary( + local, + remote, + cc, + pacing, + self.qlog.clone(), + now, + ))) + }) + } + + /// Get a reference to the primary path. This will assert if there is no primary + /// path, which happens at a server prior to receiving a valid Initial packet + /// from a client. So be careful using this method. + pub fn primary(&self) -> PathRef { + self.primary_fallible().unwrap() + } + + /// Get a reference to the primary path. Use this prior to handshake completion. + pub fn primary_fallible(&self) -> Option<PathRef> { + self.primary.as_ref().map(Rc::clone) + } + + /// Returns true if the path is not permanent. + pub fn is_temporary(&self, path: &PathRef) -> bool { + // Ask the path first, which is simpler. + path.borrow().is_temporary() || !self.paths.iter().any(|p| Rc::ptr_eq(p, path)) + } + + fn retire(to_retire: &mut Vec<u64>, retired: &PathRef) { + let seqno = retired + .borrow() + .remote_cid + .as_ref() + .unwrap() + .sequence_number(); + to_retire.push(seqno); + } + + /// Adopt a temporary path as permanent. + /// The first path that is made permanent is made primary. 
+ pub fn make_permanent( + &mut self, + path: &PathRef, + local_cid: Option<ConnectionId>, + remote_cid: RemoteConnectionIdEntry, + ) { + debug_assert!(self.is_temporary(path)); + + // Make sure not to track too many paths. + // This protects index 0, which contains the primary path. + if self.paths.len() >= MAX_PATHS { + debug_assert_eq!(self.paths.len(), MAX_PATHS); + let removed = self.paths.remove(1); + Self::retire(&mut self.to_retire, &removed); + if self + .migration_target + .as_ref() + .map_or(false, |target| Rc::ptr_eq(target, &removed)) + { + qinfo!( + [path.borrow()], + "The migration target path had to be removed" + ); + self.migration_target = None; + } + debug_assert_eq!(Rc::strong_count(&removed), 1); + } + + qdebug!([path.borrow()], "Make permanent"); + path.borrow_mut().make_permanent(local_cid, remote_cid); + self.paths.push(Rc::clone(path)); + if self.primary.is_none() { + assert!(self.select_primary(path).is_none()); + } + } + + /// Select a path as the primary. Returns the old primary path. + /// Using the old path is only necessary if this change in path is a reaction + /// to a migration from a peer, in which case the old path needs to be probed. + #[must_use] + fn select_primary(&mut self, path: &PathRef) -> Option<PathRef> { + qinfo!([path.borrow()], "set as primary path"); + let old_path = self.primary.replace(Rc::clone(path)).map(|old| { + old.borrow_mut().set_primary(false); + old + }); + + // Swap the primary path into slot 0, so that it is protected from eviction. + let idx = self + .paths + .iter() + .enumerate() + .find_map(|(i, p)| if Rc::ptr_eq(p, path) { Some(i) } else { None }) + .expect("migration target should be permanent"); + self.paths.swap(0, idx); + + path.borrow_mut().set_primary(true); + old_path + } + + /// Migrate to the identified path. If `force` is true, the path + /// is forcibly marked as valid and the path is used immediately. + /// Otherwise, migration will occur after probing succeeds. 
+ /// The path is always probed and will be abandoned if probing fails. + /// Returns `true` if the path was migrated. + pub fn migrate(&mut self, path: &PathRef, force: bool, now: Instant) -> bool { + debug_assert!(!self.is_temporary(path)); + if force || path.borrow().is_valid() { + path.borrow_mut().set_valid(now); + mem::drop(self.select_primary(path)); + self.migration_target = None; + } else { + self.migration_target = Some(Rc::clone(path)); + } + path.borrow_mut().probe(); + self.migration_target.is_none() + } + + /// Process elapsed time for active paths. + /// Returns an true if there are viable paths remaining after tidying up. + /// + /// TODO(mt) - the paths should own the RTT estimator, so they can find the PTO + /// for themselves. + pub fn process_timeout(&mut self, now: Instant, pto: Duration) -> bool { + let to_retire = &mut self.to_retire; + let mut primary_failed = false; + self.paths.retain(|p| { + if p.borrow_mut().process_timeout(now, pto) { + true + } else { + qdebug!([p.borrow()], "Retiring path"); + if p.borrow().is_primary() { + primary_failed = true; + } + Self::retire(to_retire, p); + false + } + }); + + if primary_failed { + self.primary = None; + // Find a valid path to fall back to. + if let Some(fallback) = self + .paths + .iter() + .rev() // More recent paths are toward the end. + .find(|p| p.borrow().is_valid()) + { + // Need a clone as `fallback` is borrowed from `self`. + let path = Rc::clone(fallback); + qinfo!([path.borrow()], "Failing over after primary path failed"); + mem::drop(self.select_primary(&path)); + true + } else { + false + } + } else { + true + } + } + + /// Get when the next call to `process_timeout()` should be scheduled. + pub fn next_timeout(&self, pto: Duration) -> Option<Instant> { + self.paths + .iter() + .filter_map(|p| p.borrow().next_timeout(pto)) + .min() + } + + /// Set the identified path to be primary. + /// This panics if `make_permanent` hasn't been called. 
+ pub fn handle_migration(&mut self, path: &PathRef, remote: SocketAddr, now: Instant) { + qtrace!([self.primary().borrow()], "handle_migration"); + // The update here needs to match the checks in `Path::received_on`. + // Here, we update the remote port number to match the source port on the + // datagram that was received. This ensures that we send subsequent + // packets back to the right place. + path.borrow_mut().update_port(remote.port()); + + if path.borrow().is_primary() { + // Update when the path was last regarded as valid. + path.borrow_mut().update(now); + return; + } + + if let Some(old_path) = self.select_primary(path) { + // Need to probe the old path if the peer migrates. + old_path.borrow_mut().probe(); + // TODO(mt) - suppress probing if the path was valid within 3PTO. + } + } + + /// Select a path to send on. This will select the first path that has + /// probes to send, then fall back to the primary path. + pub fn select_path(&self) -> Option<PathRef> { + self.paths + .iter() + .find_map(|p| { + if p.borrow().has_probe() { + Some(Rc::clone(p)) + } else { + None + } + }) + .or_else(|| self.primary.as_ref().map(Rc::clone)) + } + + /// A `PATH_RESPONSE` was received. + /// Returns `true` if migration occurred. + #[must_use] + pub fn path_response(&mut self, response: [u8; 8], now: Instant) -> bool { + // TODO(mt) consider recording an RTT measurement here as we don't train + // RTT for non-primary paths. + for p in &self.paths { + if p.borrow_mut().path_response(response, now) { + // The response was accepted. If this path is one we intend + // to migrate to, then migrate. + if self + .migration_target + .as_ref() + .map_or(false, |target| Rc::ptr_eq(target, p)) + { + let primary = self.migration_target.take(); + mem::drop(self.select_primary(&primary.unwrap())); + return true; + } + break; + } + } + false + } + + /// Retire all of the connection IDs prior to the indicated sequence number. 
+ /// Keep active paths if possible by pulling new connection IDs from the provided store. + /// One slightly non-obvious consequence of this is that if migration is being attempted + /// and the new path cannot obtain a new connection ID, the migration attempt will fail. + pub fn retire_cids(&mut self, retire_prior: u64, store: &mut ConnectionIdStore<[u8; 16]>) { + let to_retire = &mut self.to_retire; + let migration_target = &mut self.migration_target; + + // First, tell the store to release any connection IDs that are too old. + let mut retired = store.retire_prior_to(retire_prior); + to_retire.append(&mut retired); + + self.paths.retain(|p| { + let current = p.borrow().remote_cid.as_ref().unwrap().sequence_number(); + if current < retire_prior { + to_retire.push(current); + let new_cid = store.next(); + let has_replacement = new_cid.is_some(); + // There must be a connection ID available for the primary path as we + // keep that path at the first index. + debug_assert!(!p.borrow().is_primary() || has_replacement); + p.borrow_mut().remote_cid = new_cid; + if !has_replacement + && migration_target + .as_ref() + .map_or(false, |target| Rc::ptr_eq(target, p)) + { + qinfo!( + [p.borrow()], + "NEW_CONNECTION_ID with Retire Prior To forced migration to fail" + ); + *migration_target = None; + } + has_replacement + } else { + true + } + }); + } + + /// Write out any `RETIRE_CONNECTION_ID` frames that are outstanding. + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + while let Some(seqno) = self.to_retire.pop() { + if builder.remaining() < 1 + Encoder::varint_len(seqno) { + self.to_retire.push(seqno); + break; + } + builder.encode_varint(FRAME_TYPE_RETIRE_CONNECTION_ID); + builder.encode_varint(seqno); + tokens.push(RecoveryToken::RetireConnectionId(seqno)); + stats.retire_connection_id += 1; + } + + // Write out any ACK_FREQUENCY frames. 
+ self.primary() + .borrow_mut() + .write_cc_frames(builder, tokens, stats); + } + + pub fn lost_retire_cid(&mut self, lost: u64) { + self.to_retire.push(lost); + } + + pub fn acked_retire_cid(&mut self, acked: u64) { + self.to_retire.retain(|&seqno| seqno != acked); + } + + pub fn lost_ack_frequency(&mut self, lost: &AckRate) { + self.primary().borrow_mut().lost_ack_frequency(lost); + } + + pub fn acked_ack_frequency(&mut self, acked: &AckRate) { + self.primary().borrow_mut().acked_ack_frequency(acked); + } + + /// Get an estimate of the RTT on the primary path. + #[cfg(test)] + pub fn rtt(&self) -> Duration { + // Rather than have this fail when there is no active path, + // make a new RTT esimate and interrogate that. + // That is more expensive, but it should be rare and breaking encapsulation + // is worse, especially as this is only used in tests. + self.primary_fallible() + .map_or(RttEstimate::default().estimate(), |p| { + p.borrow().rtt().estimate() + }) + } + + pub fn set_qlog(&mut self, qlog: NeqoQlog) { + for p in &mut self.paths { + p.borrow_mut().set_qlog(qlog.clone()); + } + self.qlog = qlog; + } +} + +/// The state of a path with respect to address validation. +#[derive(Debug)] +enum ProbeState { + /// The path was last valid at the indicated time. + Valid, + /// The path was previously valid, but a new probe is needed. + ProbeNeeded { probe_count: usize }, + /// The path hasn't been validated, but a probe has been sent. + Probing { + /// The number of probes that have been sent. + probe_count: usize, + /// The probe that was last sent. + data: [u8; 8], + /// Whether the probe was sent in a datagram padded to the path MTU. + mtu: bool, + /// When the probe was sent. + sent: Instant, + }, + /// Validation failed the last time it was attempted. + Failed, +} + +impl ProbeState { + /// Determine whether the current state requires probing. + fn probe_needed(&self) -> bool { + matches!(self, Self::ProbeNeeded { .. }) + } +} + +/// A network path. 
+/// +/// Paths are used a little bit strangely by connections: +/// they need to encapsulate all the state for a path (which +/// is normal), but that information is not propagated to the +/// `Paths` instance that holds them. This is because the packet +/// processing where changes occur can't hold a reference to the +/// `Paths` instance that owns the `Path`. Any changes to the +/// path are communicated to `Paths` afterwards. +#[derive(Debug)] +pub struct Path { + /// A local socket address. + local: SocketAddr, + /// A remote socket address. + remote: SocketAddr, + /// The connection IDs that we use when sending on this path. + /// This is only needed during the handshake. + local_cid: Option<ConnectionId>, + /// The current connection ID that we are using and its details. + remote_cid: Option<RemoteConnectionIdEntry>, + + /// Whether this is the primary path. + primary: bool, + /// Whether the current path is considered valid. + state: ProbeState, + /// For a path that is not validated, this is `None`. For a validated + /// path, the time that the path was last valid. + validated: Option<Instant>, + /// A path challenge was received and PATH_RESPONSE has not been sent. + challenge: Option<[u8; 8]>, + + /// The round trip time estimate for this path. + rtt: RttEstimate, + /// A packet sender for the path, which includes congestion control and a pacer. + sender: PacketSender, + /// The DSCP/ECN marking to use for outgoing packets on this path. + tos: IpTos, + /// The IP TTL to use for outgoing packets on this path. + ttl: u8, + + /// The number of bytes received on this path. + /// Note that this value might saturate on a long-lived connection, + /// but we only use it before the path is validated. + received_bytes: usize, + /// The number of bytes sent on this path. + sent_bytes: usize, + + /// For logging of events. + qlog: NeqoQlog, +} + +impl Path { + /// Create a path from addresses and a remote connection ID. 
+ /// This is used for migration and for new datagrams. + pub fn temporary( + local: SocketAddr, + remote: SocketAddr, + cc: CongestionControlAlgorithm, + pacing: bool, + qlog: NeqoQlog, + now: Instant, + ) -> Self { + let mut sender = PacketSender::new(cc, pacing, Self::mtu_by_addr(remote.ip()), now); + sender.set_qlog(qlog.clone()); + Self { + local, + remote, + local_cid: None, + remote_cid: None, + primary: false, + state: ProbeState::ProbeNeeded { probe_count: 0 }, + validated: None, + challenge: None, + rtt: RttEstimate::default(), + sender, + tos: IpTos::default(), // TODO: Default to Ect0 when ECN is supported. + ttl: 64, // This is the default TTL on many OSes. + received_bytes: 0, + sent_bytes: 0, + qlog, + } + } + + /// Whether this path is the primary or current path for the connection. + pub fn is_primary(&self) -> bool { + self.primary + } + + /// Whether this path is a temporary one. + pub fn is_temporary(&self) -> bool { + self.remote_cid.is_none() + } + + /// By adding a remote connection ID, we make the path permanent + /// and one that we will later send packets on. + /// If `local_cid` is `None`, the existing value will be kept. + pub(crate) fn make_permanent( + &mut self, + local_cid: Option<ConnectionId>, + remote_cid: RemoteConnectionIdEntry, + ) { + if self.local_cid.is_none() { + self.local_cid = local_cid; + } + self.remote_cid.replace(remote_cid); + } + + /// Determine if this path was the one that the provided datagram was received on. + /// This uses the full local socket address, but ignores the port number on the peer + /// if `flexible` is true, allowing for NAT rebinding that retains the same IP. + fn received_on(&self, local: SocketAddr, remote: SocketAddr, flexible: bool) -> bool { + self.local == local + && self.remote.ip() == remote.ip() + && (flexible || self.remote.port() == remote.port()) + } + + /// Update the remote port number. Any flexibility we allow in `received_on` + /// need to be adjusted at this point. 
+ fn update_port(&mut self, port: u16) { + self.remote.set_port(port); + } + + /// Set whether this path is primary. + pub(crate) fn set_primary(&mut self, primary: bool) { + qtrace!([self], "Make primary {}", primary); + debug_assert!(self.remote_cid.is_some()); + self.primary = primary; + if !primary { + self.sender.discard_in_flight(); + } + } + + /// Set the current path as valid. This updates the time that the path was + /// last validated and cancels any path validation. + pub fn set_valid(&mut self, now: Instant) { + qdebug!([self], "Path validated {:?}", now); + self.state = ProbeState::Valid; + self.validated = Some(now); + } + + /// Update the last use of this path, if it is valid. + /// This will keep the path active slightly longer. + pub fn update(&mut self, now: Instant) { + if self.validated.is_some() { + self.validated = Some(now); + } + } + + fn mtu_by_addr(addr: IpAddr) -> usize { + match addr { + IpAddr::V4(_) => PATH_MTU_V4, + IpAddr::V6(_) => PATH_MTU_V6, + } + } + + /// Get the path MTU. This is currently fixed based on IP version. + pub fn mtu(&self) -> usize { + Self::mtu_by_addr(self.remote.ip()) + } + + /// Get the first local connection ID. + /// Only do this for the primary path during the handshake. + pub fn local_cid(&self) -> &ConnectionId { + self.local_cid.as_ref().unwrap() + } + + /// Set the remote connection ID based on the peer's choice. + /// This is only valid during the handshake. + pub fn set_remote_cid(&mut self, cid: ConnectionIdRef) { + self.remote_cid + .as_mut() + .unwrap() + .update_cid(ConnectionId::from(cid)); + } + + /// Access the remote connection ID. + pub fn remote_cid(&self) -> &ConnectionId { + self.remote_cid.as_ref().unwrap().connection_id() + } + + /// Set the stateless reset token for the connection ID that is currently in use. + /// Panics if the sequence number is non-zero as this is only necessary during + /// the handshake; all other connection IDs are initialized with a token. 
+ pub fn set_reset_token(&mut self, token: [u8; 16]) { + self.remote_cid + .as_mut() + .unwrap() + .set_stateless_reset_token(token); + } + + /// Determine if the provided token is a stateless reset token. + pub fn is_stateless_reset(&self, token: &[u8; 16]) -> bool { + self.remote_cid + .as_ref() + .map_or(false, |rcid| rcid.is_stateless_reset(token)) + } + + /// Make a datagram. + pub fn datagram<V: Into<Vec<u8>>>(&self, payload: V) -> Datagram { + Datagram::new(self.local, self.remote, self.tos, Some(self.ttl), payload) + } + + /// Get local address as `SocketAddr` + pub fn local_address(&self) -> SocketAddr { + self.local + } + + /// Get remote address as `SocketAddr` + pub fn remote_address(&self) -> SocketAddr { + self.remote + } + + /// Whether the path has been validated. + pub fn is_valid(&self) -> bool { + self.validated.is_some() + } + + /// Handle a `PATH_RESPONSE` frame. Returns true if the response was accepted. + pub fn path_response(&mut self, response: [u8; 8], now: Instant) -> bool { + if let ProbeState::Probing { data, mtu, .. } = &mut self.state { + if response == *data { + let need_full_probe = !*mtu; + self.set_valid(now); + if need_full_probe { + qdebug!([self], "Sub-MTU probe successful, reset probe count"); + self.probe(); + } + true + } else { + false + } + } else { + false + } + } + + /// The path has been challenged. This generates a response. + /// This only generates a single response at a time. + pub fn challenged(&mut self, challenge: [u8; 8]) { + self.challenge = Some(challenge.to_owned()); + } + + /// At the next opportunity, send a probe. + /// If the probe count has been exhausted already, marks the path as failed. + fn probe(&mut self) { + let probe_count = match &self.state { + ProbeState::Probing { probe_count, .. } => *probe_count + 1, + ProbeState::ProbeNeeded { probe_count, .. 
} => *probe_count, + _ => 0, + }; + self.state = if probe_count >= MAX_PATH_PROBES { + qinfo!([self], "Probing failed"); + ProbeState::Failed + } else { + qdebug!([self], "Initiating probe"); + ProbeState::ProbeNeeded { probe_count } + }; + } + + /// Returns true if this path have any probing frames to send. + pub fn has_probe(&self) -> bool { + self.challenge.is_some() || self.state.probe_needed() + } + + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + stats: &mut FrameStats, + mtu: bool, // Whether the packet we're writing into will be a full MTU. + now: Instant, + ) -> bool { + if builder.remaining() < 9 { + return false; + } + + // Send PATH_RESPONSE. + let resp_sent = if let Some(challenge) = self.challenge.take() { + qtrace!([self], "Responding to path challenge {}", hex(challenge)); + builder.encode_varint(FRAME_TYPE_PATH_RESPONSE); + builder.encode(&challenge[..]); + + // These frames are not retransmitted in the usual fashion. + // There is no token, therefore we need to count `all` specially. + stats.path_response += 1; + stats.all += 1; + + if builder.remaining() < 9 { + return true; + } + true + } else { + false + }; + + // Send PATH_CHALLENGE. + if let ProbeState::ProbeNeeded { probe_count } = self.state { + qtrace!([self], "Initiating path challenge {}", probe_count); + let data = <[u8; 8]>::try_from(&random(8)[..]).unwrap(); + builder.encode_varint(FRAME_TYPE_PATH_CHALLENGE); + builder.encode(&data); + + // As above, no recovery token. + stats.path_challenge += 1; + stats.all += 1; + + self.state = ProbeState::Probing { + probe_count, + data, + mtu, + sent: now, + }; + true + } else { + resp_sent + } + } + + /// Write `ACK_FREQUENCY` frames. 
+ pub fn write_cc_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + self.rtt.write_frames(builder, tokens, stats); + } + + pub fn lost_ack_frequency(&mut self, lost: &AckRate) { + self.rtt.frame_lost(lost); + } + + pub fn acked_ack_frequency(&mut self, acked: &AckRate) { + self.rtt.frame_acked(acked); + } + + /// Process a timer for this path. + /// This returns true if the path is viable and can be kept alive. + pub fn process_timeout(&mut self, now: Instant, pto: Duration) -> bool { + if let ProbeState::Probing { sent, .. } = &self.state { + if now >= *sent + pto { + self.probe(); + } + } + if let ProbeState::Failed = self.state { + // Retire failed paths immediately. + false + } else if self.primary { + // Keep valid primary paths otherwise. + true + } else if let ProbeState::Valid = self.state { + // Retire validated, non-primary paths. + // Allow more than `MAX_PATH_PROBES` times the PTO so that an old + // path remains around until after a previous path fails. + let count = u32::try_from(MAX_PATH_PROBES + 1).unwrap(); + self.validated.unwrap() + (pto * count) > now + } else { + // Keep paths that are being actively probed. + true + } + } + + /// Return the next time that this path needs servicing. + /// This only considers retransmissions of probes, not cleanup of the path. + /// If there is no other activity, then there is no real need to schedule a + /// timer to cleanup old paths. + pub fn next_timeout(&self, pto: Duration) -> Option<Instant> { + if let ProbeState::Probing { sent, .. } = &self.state { + Some(*sent + pto) + } else { + None + } + } + + /// Get the RTT estimator for this path. + pub fn rtt(&self) -> &RttEstimate { + &self.rtt + } + + /// Mutably borrow the RTT estimator for this path. + pub fn rtt_mut(&mut self) -> &mut RttEstimate { + &mut self.rtt + } + + /// Read-only access to the owned sender. 
+ pub fn sender(&self) -> &PacketSender { + &self.sender + } + + /// Pass on RTT configuration: the maximum acknowledgment delay of the peer, + /// and maybe the minimum delay. + pub fn set_ack_delay( + &mut self, + max_ack_delay: Duration, + min_ack_delay: Option<Duration>, + ack_ratio: u8, + ) { + let ack_delay = min_ack_delay.map_or_else( + || PeerAckDelay::fixed(max_ack_delay), + |m| { + PeerAckDelay::flexible( + max_ack_delay, + m, + ack_ratio, + self.sender.cwnd(), + self.mtu(), + self.rtt.estimate(), + ) + }, + ); + self.rtt.set_ack_delay(ack_delay); + } + + /// Initialize the RTT for the path based on an existing estimate. + pub fn prime_rtt(&mut self, rtt: &RttEstimate) { + self.rtt.prime_rtt(rtt); + } + + /// Record received bytes for the path. + pub fn add_received(&mut self, count: usize) { + self.received_bytes = self.received_bytes.saturating_add(count); + } + + /// Record sent bytes for the path. + pub fn add_sent(&mut self, count: usize) { + self.sent_bytes = self.sent_bytes.saturating_add(count); + } + + /// Record a packet as having been sent on this path. + pub fn packet_sent(&mut self, sent: &mut SentPacket) { + if !self.is_primary() { + sent.clear_primary_path(); + } + self.sender.on_packet_sent(sent, self.rtt.estimate()); + } + + /// Discard a packet that previously might have been in-flight. + pub fn discard_packet(&mut self, sent: &SentPacket, now: Instant, stats: &mut Stats) { + if self.rtt.first_sample_time().is_none() { + // When discarding a packet there might not be a good RTT estimate. + // But discards only occur after receiving something, so that means + // that there is some RTT information, which is better than nothing. + // Two cases: 1. at the client when handling a Retry and + // 2. at the server when disposing the Initial packet number space. 
+ qinfo!( + [self], + "discarding a packet without an RTT estimate; guessing RTT={:?}", + now - sent.time_sent + ); + stats.rtt_init_guess = true; + self.rtt.update( + &mut self.qlog, + now - sent.time_sent, + Duration::new(0, 0), + false, + now, + ); + } + + self.sender.discard(sent); + } + + /// Record packets as acknowledged with the sender. + pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], now: Instant) { + debug_assert!(self.is_primary()); + self.sender.on_packets_acked(acked_pkts, &self.rtt, now); + } + + /// Record packets as lost with the sender. + pub fn on_packets_lost( + &mut self, + prev_largest_acked_sent: Option<Instant>, + space: PacketNumberSpace, + lost_packets: &[SentPacket], + ) { + debug_assert!(self.is_primary()); + let cwnd_reduced = self.sender.on_packets_lost( + self.rtt.first_sample_time(), + prev_largest_acked_sent, + self.rtt.pto(space), // Important: the base PTO, not adjusted. + lost_packets, + ); + if cwnd_reduced { + self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu()); + } + } + + /// Get the number of bytes that can be written to this path. + pub fn amplification_limit(&self) -> usize { + if matches!(self.state, ProbeState::Failed) { + 0 + } else if self.is_valid() { + usize::MAX + } else { + self.received_bytes + .checked_mul(3) + .map_or(usize::MAX, |limit| { + let budget = if limit == 0 { + // If we have received absolutely nothing thus far, then this endpoint + // is the one initiating communication on this path. Allow enough space for + // probing. + self.mtu() * 5 + } else { + limit + }; + budget.saturating_sub(self.sent_bytes) + }) + } + } + + /// Update the `NeqoQLog` instance. 
pub fn set_qlog(&mut self, qlog: NeqoQlog) {
    self.sender.set_qlog(qlog);
}
}

impl Display for Path {
    /// Format as `[pri-][unv-]path[:<remote cid>] <local>-><remote>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // "pri-" marks the primary path, "unv-" an unvalidated one.
        let primary = if self.is_primary() { "pri-" } else { "" };
        let unvalidated = if self.is_valid() { "" } else { "unv-" };
        write!(f, "{primary}{unvalidated}path")?;
        if let Some(entry) = self.remote_cid.as_ref() {
            write!(f, ":{}", entry.connection_id())?;
        }
        write!(f, " {}->{}", self.local, self.remote)
    }
}
diff --git a/third_party/rust/neqo-transport/src/qlog.rs b/third_party/rust/neqo-transport/src/qlog.rs
new file mode 100644
index 0000000000..434395fd23
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/qlog.rs
@@ -0,0 +1,563 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// Functions that handle capturing QLOG traces.
+ +use std::{ + convert::TryFrom, + ops::{Deref, RangeInclusive}, + string::String, + time::Duration, +}; + +use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder}; +use qlog::events::{ + connectivity::{ConnectionStarted, ConnectionState, ConnectionStateUpdated}, + quic::{ + AckedRanges, ErrorSpace, MetricsUpdated, PacketDropped, PacketHeader, PacketLost, + PacketReceived, PacketSent, QuicFrame, StreamType, VersionInformation, + }, + EventData, RawInfo, +}; +use smallvec::SmallVec; + +use crate::{ + connection::State, + frame::{CloseError, Frame}, + packet::{DecryptedPacket, PacketNumber, PacketType, PublicPacket}, + path::PathRef, + stream_id::StreamType as NeqoStreamType, + tparams::{self, TransportParametersHandler}, + tracking::SentPacket, + version::{Version, VersionConfig, WireVersion}, +}; + +pub fn connection_tparams_set(qlog: &mut NeqoQlog, tph: &TransportParametersHandler) { + qlog.add_event_data(|| { + let remote = tph.remote(); + let ev_data = EventData::TransportParametersSet( + qlog::events::quic::TransportParametersSet { + owner: None, + resumption_allowed: None, + early_data_enabled: None, + tls_cipher: None, + aead_tag_length: None, + original_destination_connection_id: remote + .get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID) + .map(hex), + initial_source_connection_id: None, + retry_source_connection_id: None, + stateless_reset_token: remote.get_bytes(tparams::STATELESS_RESET_TOKEN).map(hex), + disable_active_migration: if remote.get_empty(tparams::DISABLE_MIGRATION) { + Some(true) + } else { + None + }, + max_idle_timeout: Some(remote.get_integer(tparams::IDLE_TIMEOUT)), + max_udp_payload_size: Some(remote.get_integer(tparams::MAX_UDP_PAYLOAD_SIZE) as u32), + ack_delay_exponent: Some(remote.get_integer(tparams::ACK_DELAY_EXPONENT) as u16), + max_ack_delay: Some(remote.get_integer(tparams::MAX_ACK_DELAY) as u16), + active_connection_id_limit: Some(remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT) as u32), + initial_max_data: 
Some(remote.get_integer(tparams::INITIAL_MAX_DATA)), + initial_max_stream_data_bidi_local: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL)), + initial_max_stream_data_bidi_remote: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)), + initial_max_stream_data_uni: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)), + initial_max_streams_bidi: Some(remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)), + initial_max_streams_uni: Some(remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI)), + preferred_address: remote.get_preferred_address().and_then(|(paddr, cid)| { + Some(qlog::events::quic::PreferredAddress { + ip_v4: paddr.ipv4()?.ip().to_string(), + ip_v6: paddr.ipv6()?.ip().to_string(), + port_v4: paddr.ipv4()?.port(), + port_v6: paddr.ipv6()?.port(), + connection_id: cid.connection_id().to_string(), + stateless_reset_token: hex(cid.reset_token()), + }) + }), + }); + + Some(ev_data) + }); +} + +pub fn server_connection_started(qlog: &mut NeqoQlog, path: &PathRef) { + connection_started(qlog, path); +} + +pub fn client_connection_started(qlog: &mut NeqoQlog, path: &PathRef) { + connection_started(qlog, path); +} + +fn connection_started(qlog: &mut NeqoQlog, path: &PathRef) { + qlog.add_event_data(|| { + let p = path.deref().borrow(); + let ev_data = EventData::ConnectionStarted(ConnectionStarted { + ip_version: if p.local_address().ip().is_ipv4() { + Some("ipv4".into()) + } else { + Some("ipv6".into()) + }, + src_ip: format!("{}", p.local_address().ip()), + dst_ip: format!("{}", p.remote_address().ip()), + protocol: Some("QUIC".into()), + src_port: p.local_address().port().into(), + dst_port: p.remote_address().port().into(), + src_cid: Some(format!("{}", p.local_cid())), + dst_cid: Some(format!("{}", p.remote_cid())), + }); + + Some(ev_data) + }); +} + +pub fn connection_state_updated(qlog: &mut NeqoQlog, new: &State) { + qlog.add_event_data(|| { + let ev_data = 
EventData::ConnectionStateUpdated(ConnectionStateUpdated { + old: None, + new: match new { + State::Init | State::WaitInitial => ConnectionState::Attempted, + State::WaitVersion | State::Handshaking => ConnectionState::HandshakeStarted, + State::Connected => ConnectionState::HandshakeCompleted, + State::Confirmed => ConnectionState::HandshakeConfirmed, + State::Closing { .. } => ConnectionState::Closing, + State::Draining { .. } => ConnectionState::Draining, + State::Closed { .. } => ConnectionState::Closed, + }, + }); + + Some(ev_data) + }); +} + +pub fn client_version_information_initiated(qlog: &mut NeqoQlog, version_config: &VersionConfig) { + qlog.add_event_data(|| { + Some(EventData::VersionInformation(VersionInformation { + client_versions: Some( + version_config + .all() + .iter() + .map(|v| format!("{:02x}", v.wire_version())) + .collect(), + ), + server_versions: None, + chosen_version: Some(format!("{:02x}", version_config.initial().wire_version())), + })) + }); +} + +pub fn client_version_information_negotiated( + qlog: &mut NeqoQlog, + client: &[Version], + server: &[WireVersion], + chosen: Version, +) { + qlog.add_event_data(|| { + Some(EventData::VersionInformation(VersionInformation { + client_versions: Some( + client + .iter() + .map(|v| format!("{:02x}", v.wire_version())) + .collect(), + ), + server_versions: Some(server.iter().map(|v| format!("{v:02x}")).collect()), + chosen_version: Some(format!("{:02x}", chosen.wire_version())), + })) + }); +} + +pub fn server_version_information_failed( + qlog: &mut NeqoQlog, + server: &[Version], + client: WireVersion, +) { + qlog.add_event_data(|| { + Some(EventData::VersionInformation(VersionInformation { + client_versions: Some(vec![format!("{client:02x}")]), + server_versions: Some( + server + .iter() + .map(|v| format!("{:02x}", v.wire_version())) + .collect(), + ), + chosen_version: None, + })) + }); +} + +pub fn packet_sent( + qlog: &mut NeqoQlog, + pt: PacketType, + pn: PacketNumber, + plen: usize, + 
body: &[u8], +) { + qlog.add_event_with_stream(|stream| { + let mut d = Decoder::from(body); + let header = PacketHeader::with_type(to_qlog_pkt_type(pt), Some(pn), None, None, None); + let raw = RawInfo { + length: Some(plen as u64), + payload_length: None, + data: None, + }; + + let mut frames = SmallVec::new(); + while d.remaining() > 0 { + if let Ok(f) = Frame::decode(&mut d) { + frames.push(frame_to_qlogframe(&f)) + } else { + qinfo!("qlog: invalid frame"); + break; + } + } + + let ev_data = EventData::PacketSent(PacketSent { + header, + frames: Some(frames), + is_coalesced: None, + retry_token: None, + stateless_reset_token: None, + supported_versions: None, + raw: Some(raw), + datagram_id: None, + send_at_time: None, + trigger: None, + }); + + stream.add_event_data_now(ev_data) + }); +} + +pub fn packet_dropped(qlog: &mut NeqoQlog, public_packet: &PublicPacket) { + qlog.add_event_data(|| { + let header = PacketHeader::with_type( + to_qlog_pkt_type(public_packet.packet_type()), + None, + None, + None, + None, + ); + let raw = RawInfo { + length: Some(public_packet.len() as u64), + payload_length: None, + data: None, + }; + + let ev_data = EventData::PacketDropped(PacketDropped { + header: Some(header), + raw: Some(raw), + datagram_id: None, + details: None, + trigger: None, + }); + + Some(ev_data) + }); +} + +pub fn packets_lost(qlog: &mut NeqoQlog, pkts: &[SentPacket]) { + qlog.add_event_with_stream(|stream| { + for pkt in pkts { + let header = + PacketHeader::with_type(to_qlog_pkt_type(pkt.pt), Some(pkt.pn), None, None, None); + + let ev_data = EventData::PacketLost(PacketLost { + header: Some(header), + trigger: None, + frames: None, + }); + + stream.add_event_data_now(ev_data)?; + } + Ok(()) + }); +} + +pub fn packet_received( + qlog: &mut NeqoQlog, + public_packet: &PublicPacket, + payload: &DecryptedPacket, +) { + qlog.add_event_with_stream(|stream| { + let mut d = Decoder::from(&payload[..]); + + let header = PacketHeader::with_type( + 
to_qlog_pkt_type(public_packet.packet_type()), + Some(payload.pn()), + None, + None, + None, + ); + let raw = RawInfo { + length: Some(public_packet.len() as u64), + payload_length: None, + data: None, + }; + + let mut frames = Vec::new(); + + while d.remaining() > 0 { + if let Ok(f) = Frame::decode(&mut d) { + frames.push(frame_to_qlogframe(&f)) + } else { + qinfo!("qlog: invalid frame"); + break; + } + } + + let ev_data = EventData::PacketReceived(PacketReceived { + header, + frames: Some(frames), + is_coalesced: None, + retry_token: None, + stateless_reset_token: None, + supported_versions: None, + raw: Some(raw), + datagram_id: None, + trigger: None, + }); + + stream.add_event_data_now(ev_data) + }); +} + +#[allow(dead_code)] +pub enum QlogMetric { + MinRtt(Duration), + SmoothedRtt(Duration), + LatestRtt(Duration), + RttVariance(u64), + MaxAckDelay(u64), + PtoCount(usize), + CongestionWindow(usize), + BytesInFlight(usize), + SsThresh(usize), + PacketsInFlight(u64), + InRecovery(bool), + PacingRate(u64), +} + +pub fn metrics_updated(qlog: &mut NeqoQlog, updated_metrics: &[QlogMetric]) { + debug_assert!(!updated_metrics.is_empty()); + + qlog.add_event_data(|| { + let mut min_rtt: Option<f32> = None; + let mut smoothed_rtt: Option<f32> = None; + let mut latest_rtt: Option<f32> = None; + let mut rtt_variance: Option<f32> = None; + let mut pto_count: Option<u16> = None; + let mut congestion_window: Option<u64> = None; + let mut bytes_in_flight: Option<u64> = None; + let mut ssthresh: Option<u64> = None; + let mut packets_in_flight: Option<u64> = None; + let mut pacing_rate: Option<u64> = None; + + for metric in updated_metrics { + match metric { + QlogMetric::MinRtt(v) => min_rtt = Some(v.as_secs_f32() * 1000.0), + QlogMetric::SmoothedRtt(v) => smoothed_rtt = Some(v.as_secs_f32() * 1000.0), + QlogMetric::LatestRtt(v) => latest_rtt = Some(v.as_secs_f32() * 1000.0), + QlogMetric::RttVariance(v) => rtt_variance = Some(*v as f32), + QlogMetric::PtoCount(v) => pto_count 
= Some(u16::try_from(*v).unwrap()), + QlogMetric::CongestionWindow(v) => { + congestion_window = Some(u64::try_from(*v).unwrap()); + } + QlogMetric::BytesInFlight(v) => bytes_in_flight = Some(u64::try_from(*v).unwrap()), + QlogMetric::SsThresh(v) => ssthresh = Some(u64::try_from(*v).unwrap()), + QlogMetric::PacketsInFlight(v) => packets_in_flight = Some(*v), + QlogMetric::PacingRate(v) => pacing_rate = Some(*v), + _ => (), + } + } + + let ev_data = EventData::MetricsUpdated(MetricsUpdated { + min_rtt, + smoothed_rtt, + latest_rtt, + rtt_variance, + pto_count, + congestion_window, + bytes_in_flight, + ssthresh, + packets_in_flight, + pacing_rate, + }); + + Some(ev_data) + }); +} + +// Helper functions + +fn frame_to_qlogframe(frame: &Frame) -> QuicFrame { + match frame { + Frame::Padding => QuicFrame::Padding, + Frame::Ping => QuicFrame::Ping, + Frame::Ack { + largest_acknowledged, + ack_delay, + first_ack_range, + ack_ranges, + } => { + let ranges = + Frame::decode_ack_frame(*largest_acknowledged, *first_ack_range, ack_ranges).ok(); + + let acked_ranges = ranges.map(|all| { + AckedRanges::Double( + all.into_iter() + .map(RangeInclusive::into_inner) + .collect::<Vec<_>>(), + ) + }); + + QuicFrame::Ack { + ack_delay: Some(*ack_delay as f32 / 1000.0), + acked_ranges, + ect1: None, + ect0: None, + ce: None, + } + } + Frame::ResetStream { + stream_id, + application_error_code, + final_size, + } => QuicFrame::ResetStream { + stream_id: stream_id.as_u64(), + error_code: *application_error_code, + final_size: *final_size, + }, + Frame::StopSending { + stream_id, + application_error_code, + } => QuicFrame::StopSending { + stream_id: stream_id.as_u64(), + error_code: *application_error_code, + }, + Frame::Crypto { offset, data } => QuicFrame::Crypto { + offset: *offset, + length: data.len() as u64, + }, + Frame::NewToken { token } => QuicFrame::NewToken { + token: qlog::Token { + ty: Some(qlog::TokenType::Retry), + details: None, + raw: Some(RawInfo { + data: 
Some(hex(token)), + length: Some(token.len() as u64), + payload_length: None, + }), + }, + }, + Frame::Stream { + fin, + stream_id, + offset, + data, + .. + } => QuicFrame::Stream { + stream_id: stream_id.as_u64(), + offset: *offset, + length: data.len() as u64, + fin: Some(*fin), + raw: None, + }, + Frame::MaxData { maximum_data } => QuicFrame::MaxData { + maximum: *maximum_data, + }, + Frame::MaxStreamData { + stream_id, + maximum_stream_data, + } => QuicFrame::MaxStreamData { + stream_id: stream_id.as_u64(), + maximum: *maximum_stream_data, + }, + Frame::MaxStreams { + stream_type, + maximum_streams, + } => QuicFrame::MaxStreams { + stream_type: match stream_type { + NeqoStreamType::BiDi => StreamType::Bidirectional, + NeqoStreamType::UniDi => StreamType::Unidirectional, + }, + maximum: *maximum_streams, + }, + Frame::DataBlocked { data_limit } => QuicFrame::DataBlocked { limit: *data_limit }, + Frame::StreamDataBlocked { + stream_id, + stream_data_limit, + } => QuicFrame::StreamDataBlocked { + stream_id: stream_id.as_u64(), + limit: *stream_data_limit, + }, + Frame::StreamsBlocked { + stream_type, + stream_limit, + } => QuicFrame::StreamsBlocked { + stream_type: match stream_type { + NeqoStreamType::BiDi => StreamType::Bidirectional, + NeqoStreamType::UniDi => StreamType::Unidirectional, + }, + limit: *stream_limit, + }, + Frame::NewConnectionId { + sequence_number, + retire_prior, + connection_id, + stateless_reset_token, + } => QuicFrame::NewConnectionId { + sequence_number: *sequence_number as u32, + retire_prior_to: *retire_prior as u32, + connection_id_length: Some(connection_id.len() as u8), + connection_id: hex(connection_id), + stateless_reset_token: Some(hex(stateless_reset_token)), + }, + Frame::RetireConnectionId { sequence_number } => QuicFrame::RetireConnectionId { + sequence_number: *sequence_number as u32, + }, + Frame::PathChallenge { data } => QuicFrame::PathChallenge { + data: Some(hex(data)), + }, + Frame::PathResponse { data } => 
QuicFrame::PathResponse { + data: Some(hex(data)), + }, + Frame::ConnectionClose { + error_code, + frame_type, + reason_phrase, + } => QuicFrame::ConnectionClose { + error_space: match error_code { + CloseError::Transport(_) => Some(ErrorSpace::TransportError), + CloseError::Application(_) => Some(ErrorSpace::ApplicationError), + }, + error_code: Some(error_code.code()), + error_code_value: Some(0), + reason: Some(String::from_utf8_lossy(reason_phrase).to_string()), + trigger_frame_type: Some(*frame_type), + }, + Frame::HandshakeDone => QuicFrame::HandshakeDone, + Frame::AckFrequency { .. } => QuicFrame::Unknown { + frame_type_value: None, + raw_frame_type: frame.get_type(), + raw: None, + }, + Frame::Datagram { data, .. } => QuicFrame::Datagram { + length: data.len() as u64, + raw: None, + }, + } +} + +fn to_qlog_pkt_type(ptype: PacketType) -> qlog::events::quic::PacketType { + match ptype { + PacketType::Initial => qlog::events::quic::PacketType::Initial, + PacketType::Handshake => qlog::events::quic::PacketType::Handshake, + PacketType::ZeroRtt => qlog::events::quic::PacketType::ZeroRtt, + PacketType::Short => qlog::events::quic::PacketType::OneRtt, + PacketType::Retry => qlog::events::quic::PacketType::Retry, + PacketType::VersionNegotiation => qlog::events::quic::PacketType::VersionNegotiation, + PacketType::OtherVersion => qlog::events::quic::PacketType::Unknown, + } +} diff --git a/third_party/rust/neqo-transport/src/quic_datagrams.rs b/third_party/rust/neqo-transport/src/quic_datagrams.rs new file mode 100644 index 0000000000..07f3594768 --- /dev/null +++ b/third_party/rust/neqo-transport/src/quic_datagrams.rs @@ -0,0 +1,185 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// https://datatracker.ietf.org/doc/html/draft-ietf-quic-datagram + +use std::{cmp::min, collections::VecDeque, convert::TryFrom}; + +use neqo_common::Encoder; + +use crate::{ + events::OutgoingDatagramOutcome, + frame::{FRAME_TYPE_DATAGRAM, FRAME_TYPE_DATAGRAM_WITH_LEN}, + packet::PacketBuilder, + recovery::RecoveryToken, + ConnectionEvents, Error, Res, Stats, +}; + +pub const MAX_QUIC_DATAGRAM: u64 = 65535; + +#[derive(Debug, Clone, Copy)] +pub enum DatagramTracking { + None, + Id(u64), +} + +impl From<Option<u64>> for DatagramTracking { + fn from(v: Option<u64>) -> Self { + match v { + Some(id) => Self::Id(id), + None => Self::None, + } + } +} + +impl From<DatagramTracking> for Option<u64> { + fn from(v: DatagramTracking) -> Self { + match v { + DatagramTracking::Id(id) => Some(id), + DatagramTracking::None => None, + } + } +} + +struct QuicDatagram { + data: Vec<u8>, + tracking: DatagramTracking, +} + +impl QuicDatagram { + fn tracking(&self) -> &DatagramTracking { + &self.tracking + } +} + +impl AsRef<[u8]> for QuicDatagram { + #[must_use] + fn as_ref(&self) -> &[u8] { + &self.data[..] + } +} + +pub struct QuicDatagrams { + /// The max size of a datagram that would be acceptable. + local_datagram_size: u64, + /// The max size of a datagram that would be acceptable by the peer. + remote_datagram_size: u64, + max_queued_outgoing_datagrams: usize, + /// The max number of datagrams that will be queued in connection events. + /// If the number is exceeded, the oldest datagram will be dropped. + max_queued_incoming_datagrams: usize, + /// Datagram queued for sending. 
+ datagrams: VecDeque<QuicDatagram>, + conn_events: ConnectionEvents, +} + +impl QuicDatagrams { + pub fn new( + local_datagram_size: u64, + max_queued_outgoing_datagrams: usize, + max_queued_incoming_datagrams: usize, + conn_events: ConnectionEvents, + ) -> Self { + Self { + local_datagram_size, + remote_datagram_size: 0, + max_queued_outgoing_datagrams, + max_queued_incoming_datagrams, + datagrams: VecDeque::with_capacity(max_queued_outgoing_datagrams), + conn_events, + } + } + + pub fn remote_datagram_size(&self) -> u64 { + self.remote_datagram_size + } + + pub fn set_remote_datagram_size(&mut self, v: u64) { + self.remote_datagram_size = min(v, MAX_QUIC_DATAGRAM); + } + + /// This function tries to write a datagram frame into a packet. + /// If the frame does not fit into the packet, the datagram will + /// be dropped and a DatagramLost event will be posted. + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut Stats, + ) { + while let Some(dgram) = self.datagrams.pop_front() { + let len = dgram.as_ref().len(); + if builder.remaining() > len { + // We need 1 more than `len` for the Frame type. + let length_len = Encoder::varint_len(u64::try_from(len).unwrap()); + // Include a length if there is space for another frame after this one. + if builder.remaining() >= 1 + length_len + len + PacketBuilder::MINIMUM_FRAME_SIZE { + builder.encode_varint(FRAME_TYPE_DATAGRAM_WITH_LEN); + builder.encode_vvec(dgram.as_ref()); + } else { + builder.encode_varint(FRAME_TYPE_DATAGRAM); + builder.encode(dgram.as_ref()); + builder.mark_full(); + } + debug_assert!(builder.len() <= builder.limit()); + stats.frame_tx.datagram += 1; + tokens.push(RecoveryToken::Datagram(*dgram.tracking())); + } else if tokens.is_empty() { + // If the packet is empty, except packet headers, and the + // datagram cannot fit, drop it. + // Also continue trying to write the next QuicDatagram. 
+ self.conn_events + .datagram_outcome(dgram.tracking(), OutgoingDatagramOutcome::DroppedTooBig); + stats.datagram_tx.dropped_too_big += 1; + } else { + self.datagrams.push_front(dgram); + // Try later on an empty packet. + return; + } + } + } + + /// Returns true if there was an unsent datagram that has been dismissed. + /// + /// # Error + /// + /// The function returns `TooMuchData` if the supply buffer is bigger than + /// the allowed remote datagram size. The funcion does not check if the + /// datagram can fit into a packet (i.e. MTU limit). This is checked during + /// creation of an actual packet and the datagram will be dropped if it does + /// not fit into the packet. + pub fn add_datagram( + &mut self, + buf: &[u8], + tracking: DatagramTracking, + stats: &mut Stats, + ) -> Res<()> { + if u64::try_from(buf.len()).unwrap() > self.remote_datagram_size { + return Err(Error::TooMuchData); + } + if self.datagrams.len() == self.max_queued_outgoing_datagrams { + self.conn_events.datagram_outcome( + self.datagrams.pop_front().unwrap().tracking(), + OutgoingDatagramOutcome::DroppedQueueFull, + ); + stats.datagram_tx.dropped_queue_full += 1; + } + self.datagrams.push_back(QuicDatagram { + data: buf.to_vec(), + tracking, + }); + Ok(()) + } + + pub fn handle_datagram(&self, data: &[u8], stats: &mut Stats) -> Res<()> { + if self.local_datagram_size < u64::try_from(data.len()).unwrap() { + return Err(Error::ProtocolViolation); + } + self.conn_events + .add_datagram(self.max_queued_incoming_datagrams, data, stats); + Ok(()) + } +} diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs new file mode 100644 index 0000000000..d90989b486 --- /dev/null +++ b/third_party/rust/neqo-transport/src/recovery.rs @@ -0,0 +1,1610 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at 
your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Tracking of sent packets and detecting their loss. + +#![deny(clippy::pedantic)] + +use std::{ + cmp::{max, min}, + collections::BTreeMap, + convert::TryFrom, + mem, + ops::RangeInclusive, + time::{Duration, Instant}, +}; + +use neqo_common::{qdebug, qinfo, qlog::NeqoQlog, qtrace, qwarn}; +use smallvec::{smallvec, SmallVec}; + +use crate::{ + ackrate::AckRate, + cid::ConnectionIdEntry, + crypto::CryptoRecoveryToken, + packet::PacketNumber, + path::{Path, PathRef}, + qlog::{self, QlogMetric}, + quic_datagrams::DatagramTracking, + rtt::RttEstimate, + send_stream::SendStreamRecoveryToken, + stats::{Stats, StatsCell}, + stream_id::{StreamId, StreamType}, + tracking::{AckToken, PacketNumberSpace, PacketNumberSpaceSet, SentPacket}, +}; + +pub(crate) const PACKET_THRESHOLD: u64 = 3; +/// `ACK_ONLY_SIZE_LIMIT` is the minimum size of the congestion window. +/// If the congestion window is this small, we will only send ACK frames. +pub(crate) const ACK_ONLY_SIZE_LIMIT: usize = 256; +/// The maximum number of packets we send on a PTO. +/// And the maximum number to declare lost when the PTO timer is hit. +pub const MAX_PTO_PACKET_COUNT: usize = 2; +/// The preferred limit on the number of packets that are tracked. +/// If we exceed this number, we start sending `PING` frames sooner to +/// force the peer to acknowledge some of them. +pub(crate) const MAX_OUTSTANDING_UNACK: usize = 200; +/// Disable PING until this many packets are outstanding. +pub(crate) const MIN_OUTSTANDING_UNACK: usize = 16; +/// The scale we use for the fast PTO feature. 
pub const FAST_PTO_SCALE: u8 = 100;

/// Recovery information for the stream-related frames that were sent in a packet.
#[derive(Debug, Clone)]
#[allow(clippy::module_name_repetitions)]
pub enum StreamRecoveryToken {
    Stream(SendStreamRecoveryToken),
    ResetStream {
        stream_id: StreamId,
    },
    StopSending {
        stream_id: StreamId,
    },

    MaxData(u64),
    DataBlocked(u64),

    MaxStreamData {
        stream_id: StreamId,
        max_data: u64,
    },
    StreamDataBlocked {
        stream_id: StreamId,
        limit: u64,
    },

    MaxStreams {
        stream_type: StreamType,
        max_streams: u64,
    },
    StreamsBlocked {
        stream_type: StreamType,
        limit: u64,
    },
}

/// Recovery information for a frame that was sent in a packet.
#[derive(Debug, Clone)]
#[allow(clippy::module_name_repetitions)]
pub enum RecoveryToken {
    Stream(StreamRecoveryToken),
    Ack(AckToken),
    Crypto(CryptoRecoveryToken),
    HandshakeDone,
    KeepAlive, // Special PING.
    NewToken(usize),
    NewConnectionId(ConnectionIdEntry<[u8; 16]>),
    RetireConnectionId(u64),
    AckFrequency(AckRate),
    Datagram(DatagramTracking),
}

/// `SendProfile` tells a sender how to send packets.
#[derive(Debug)]
pub struct SendProfile {
    /// The limit on the size of the packet.
    limit: usize,
    /// Whether this is a PTO, and what space the PTO is for.
    pto: Option<PacketNumberSpace>,
    /// What spaces should be probed.
    probe: PacketNumberSpaceSet,
    /// Whether pacing is active.
    paced: bool,
}

impl SendProfile {
    /// Create a profile with a size limit that is not a PTO and not paced.
    pub fn new_limited(limit: usize) -> Self {
        // When the limit is too low, we only send ACK frames.
        // Set the limit to `ACK_ONLY_SIZE_LIMIT - 1` to ensure that
        // ACK-only packets are still limited in size.
        Self {
            limit: max(ACK_ONLY_SIZE_LIMIT - 1, limit),
            pto: None,
            probe: PacketNumberSpaceSet::default(),
            paced: false,
        }
    }

    /// Create a profile for when sending is held back by pacing.
    pub fn new_paced() -> Self {
        // When pacing, we still allow ACK frames to be sent.
        Self {
            limit: ACK_ONLY_SIZE_LIMIT - 1,
            pto: None,
            probe: PacketNumberSpaceSet::default(),
            paced: true,
        }
    }

    /// Create a profile for a PTO in `pn_space`, with `mtu` as the size limit.
    /// `probe` is the set of spaces to probe and has to include `pn_space`.
    pub fn new_pto(pn_space: PacketNumberSpace, mtu: usize, probe: PacketNumberSpaceSet) -> Self {
        debug_assert!(mtu > ACK_ONLY_SIZE_LIMIT);
        debug_assert!(probe[pn_space]);
        Self {
            limit: mtu,
            pto: Some(pn_space),
            probe,
            paced: false,
        }
    }

    /// Whether probing this space is helpful. This isn't necessarily the space
    /// that caused the timer to pop, but it is helpful to send a PING in a space
    /// that has the PTO timer armed.
    pub fn should_probe(&self, space: PacketNumberSpace) -> bool {
        self.probe[space]
    }

    /// Determine whether an ACK-only packet should be sent for the given packet
    /// number space.
    /// Send only ACKs either: when the space available is too small, or when a PTO
    /// exists for a later packet number space (which should get the most space).
    pub fn ack_only(&self, space: PacketNumberSpace) -> bool {
        self.limit < ACK_ONLY_SIZE_LIMIT || self.pto.map_or(false, |sp| space < sp)
    }

    /// Whether pacing is active.
    pub fn paced(&self) -> bool {
        self.paced
    }

    /// The limit on the size of the packet.
    pub fn limit(&self) -> usize {
        self.limit
    }
}

/// Loss recovery state for a single packet number space.
#[derive(Debug)]
pub(crate) struct LossRecoverySpace {
    space: PacketNumberSpace,
    largest_acked: Option<PacketNumber>,
    largest_acked_sent_time: Option<Instant>,
    /// The time used to calculate the PTO timer for this space.
    /// This is the time that the last ACK-eliciting packet in this space
    /// was sent. This might be the time that a probe was sent.
    last_ack_eliciting: Option<Instant>,
    /// The number of outstanding packets in this space that are in flight.
    /// This might be less than the number of ACK-eliciting packets,
    /// because PTO packets don't count.
    in_flight_outstanding: usize,
    /// The tracked packets, indexed by packet number.
    sent_packets: BTreeMap<u64, SentPacket>,
    /// The time that the first out-of-order packet was sent.
    /// This is `None` if there were no out-of-order packets detected.
    /// When set to `Some(T)`, time-based loss detection should be enabled.
    first_ooo_time: Option<Instant>,
}

impl LossRecoverySpace {
    /// Create empty state for `space`: nothing sent and nothing acknowledged.
    pub fn new(space: PacketNumberSpace) -> Self {
        Self {
            space,
            largest_acked: None,
            largest_acked_sent_time: None,
            last_ack_eliciting: None,
            in_flight_outstanding: 0,
            sent_packets: BTreeMap::default(),
            first_ooo_time: None,
        }
    }

    /// The packet number space that this state tracks.
    #[must_use]
    pub fn space(&self) -> PacketNumberSpace {
        self.space
    }

    /// Find the time we sent the first packet that is lower than the
    /// largest acknowledged and that isn't yet declared lost.
    /// Use the value we prepared earlier in `detect_lost_packets`.
    #[must_use]
    pub fn loss_recovery_timer_start(&self) -> Option<Instant> {
        self.first_ooo_time
    }

    /// Whether there are any outstanding packets in flight in this space.
    pub fn in_flight_outstanding(&self) -> bool {
        self.in_flight_outstanding > 0
    }

    /// Iterate over at most `count` tracked packets for which `SentPacket::pto()`
    /// returns true, in order of ascending packet number.
    // NOTE(review): `pto()` presumably marks the packet as well (the map is
    // iterated mutably and the trace says "marking"); confirm in `SentPacket`.
    pub fn pto_packets(&mut self, count: usize) -> impl Iterator<Item = &SentPacket> {
        self.sent_packets
            .iter_mut()
            .filter_map(|(pn, sent)| {
                if sent.pto() {
                    qtrace!("PTO: marking packet {} lost ", pn);
                    Some(&*sent)
                } else {
                    None
                }
            })
            .take(count)
    }

    /// The time from which the PTO timer for this space is calculated, if the
    /// timer should be armed at all.
    pub fn pto_base_time(&self) -> Option<Instant> {
        if self.in_flight_outstanding() {
            debug_assert!(self.last_ack_eliciting.is_some());
            self.last_ack_eliciting
        } else if self.space == PacketNumberSpace::ApplicationData {
            None
        } else {
            // Nasty special case to prevent handshake deadlocks.
            // A client needs to keep the PTO timer armed to prevent a stall
            // of the handshake. Technically, this has to stop once we receive
            // an ACK of Handshake or 1-RTT, or when we receive HANDSHAKE_DONE,
            // but a few extra probes won't hurt.
            // It only means that we fail anti-amplification tests.
            // A server shouldn't arm its PTO timer this way. The server sends
            // ack-eliciting, in-flight packets immediately so this only
            // happens when the server has nothing outstanding. If we had
            // client authentication, this might cause some extra probes,
            // but they would be harmless anyway.
            self.last_ack_eliciting
        }
    }

    /// Start tracking a packet that was sent in this space.
    pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
        if sent_packet.ack_eliciting() {
            self.last_ack_eliciting = Some(sent_packet.time_sent);
            self.in_flight_outstanding += 1;
        } else if self.space != PacketNumberSpace::ApplicationData
            && self.last_ack_eliciting.is_none()
        {
            // For Initial and Handshake spaces, make sure that we have a PTO baseline
            // always. See `LossRecoverySpace::pto_base_time()` for details.
            self.last_ack_eliciting = Some(sent_packet.time_sent);
        }
        self.sent_packets.insert(sent_packet.pn, sent_packet);
    }

    /// If we are only sending ACK frames, send a PING frame after 2 PTOs so that
    /// the peer sends an ACK frame. If we have received lots of packets and no ACK,
    /// send a PING frame after 1 PTO. Note that this can't be within a PTO, or
    /// we would risk setting up a feedback loop; having this many packets
    /// outstanding can be normal and we don't want to PING too often.
    pub fn should_probe(&self, pto: Duration, now: Instant) -> bool {
        // With fewer than `MIN_OUTSTANDING_UNACK` tracked packets, never probe.
        let n_pto = if self.sent_packets.len() >= MAX_OUTSTANDING_UNACK {
            1
        } else if self.sent_packets.len() >= MIN_OUTSTANDING_UNACK {
            2
        } else {
            return false;
        };
        self.last_ack_eliciting
            .map_or(false, |t| now > t + (pto * n_pto))
    }

    /// Update the in-flight count for a packet that is no longer tracked.
    /// This does not touch `sent_packets`; callers remove the packet from the map.
    fn remove_packet(&mut self, p: &SentPacket) {
        if p.ack_eliciting() {
            debug_assert!(self.in_flight_outstanding > 0);
            self.in_flight_outstanding -= 1;
            if self.in_flight_outstanding == 0 {
                qtrace!("remove_packet outstanding == 0 for space {}", self.space);
            }
        }
    }

    /// Remove all acknowledged packets.
    /// Returns all the acknowledged packets, with the largest packet number first.
    /// ...and a boolean indicating if any of those packets were ack-eliciting.
    /// This operates more efficiently because it assumes that the input is sorted
    /// in the order that an ACK frame is (from the top).
    fn remove_acked<R>(&mut self, acked_ranges: R, stats: &mut Stats) -> (Vec<SentPacket>, bool)
    where
        R: IntoIterator<Item = RangeInclusive<u64>>,
        R::IntoIter: ExactSizeIterator,
    {
        let acked_ranges = acked_ranges.into_iter();
        // Chunks of packets above an acknowledged range that have to be retained.
        let mut keep = Vec::with_capacity(acked_ranges.len());

        let mut acked = Vec::new();
        let mut eliciting = false;
        for range in acked_ranges {
            let first_keep = *range.end() + 1;
            if let Some((&first, _)) = self.sent_packets.range(range).next() {
                // Split at the lowest tracked packet in the range:
                // `tail` takes that packet and everything above it.
                let mut tail = self.sent_packets.split_off(&first);
                if let Some((&next, _)) = tail.range(first_keep..).next() {
                    // Packets above the range are not acknowledged by it; set them
                    // aside, leaving only acknowledged packets in `tail`.
                    keep.push(tail.split_off(&next));
                }
                // Walk in descending order so that the largest packet number is first.
                for (_, p) in tail.into_iter().rev() {
                    self.remove_packet(&p);
                    eliciting |= p.ack_eliciting();
                    if p.lost() {
                        stats.late_ack += 1;
                    }
                    if p.pto_fired() {
                        stats.pto_ack += 1;
                    }
                    acked.push(p);
                }
            }
        }

        // Put back the packets that were set aside.
        for mut k in keep.into_iter().rev() {
            self.sent_packets.append(&mut k);
        }

        (acked, eliciting)
    }

    /// Remove all tracked packets from the space.
    /// This is called by a client when 0-RTT packets are dropped, when a Retry is received
    /// and when keys are dropped.
    fn remove_ignored(&mut self) -> impl Iterator<Item = SentPacket> {
        self.in_flight_outstanding = 0;
        mem::take(&mut self.sent_packets).into_values()
    }

    /// Remove the primary path marking on any packets this is tracking.
    fn migrate(&mut self) {
        for pkt in self.sent_packets.values_mut() {
            pkt.clear_primary_path();
        }
    }

    /// Remove old packets that we've been tracking in case they get acknowledged.
    /// We try to keep these around until a probe is sent for them, so it is
    /// important that `cd` is set to at least the current PTO time; otherwise we
    /// might remove all in-flight packets and stop sending probes.
    #[allow(clippy::option_if_let_else)] // Hard enough to read as-is.
+ fn remove_old_lost(&mut self, now: Instant, cd: Duration) { + let mut it = self.sent_packets.iter(); + // If the first item is not expired, do nothing. + if it.next().map_or(false, |(_, p)| p.expired(now, cd)) { + // Find the index of the first unexpired packet. + let to_remove = if let Some(first_keep) = + it.find_map(|(i, p)| if p.expired(now, cd) { None } else { Some(*i) }) + { + // Some packets haven't expired, so keep those. + let keep = self.sent_packets.split_off(&first_keep); + mem::replace(&mut self.sent_packets, keep) + } else { + // All packets are expired. + mem::take(&mut self.sent_packets) + }; + for (_, p) in to_remove { + self.remove_packet(&p); + } + } + } + + /// Detect lost packets. + /// `loss_delay` is the time we will wait before declaring something lost. + /// `cleanup_delay` is the time we will wait before cleaning up a lost packet. + pub fn detect_lost_packets( + &mut self, + now: Instant, + loss_delay: Duration, + cleanup_delay: Duration, + lost_packets: &mut Vec<SentPacket>, + ) { + // Housekeeping. + self.remove_old_lost(now, cleanup_delay); + + qtrace!( + "detect lost {}: now={:?} delay={:?}", + self.space, + now, + loss_delay, + ); + self.first_ooo_time = None; + + let largest_acked = self.largest_acked; + + // Lost for retrans/CC purposes + let mut lost_pns = SmallVec::<[_; 8]>::new(); + + for (pn, packet) in self + .sent_packets + .iter_mut() + // BTreeMap iterates in order of ascending PN + .take_while(|(&k, _)| k < largest_acked.unwrap_or(PacketNumber::MAX)) + { + // Packets sent before now - loss_delay are deemed lost. 
+ if packet.time_sent + loss_delay <= now { + qtrace!( + "lost={}, time sent {:?} is before lost_delay {:?}", + pn, + packet.time_sent, + loss_delay + ); + } else if largest_acked >= Some(*pn + PACKET_THRESHOLD) { + qtrace!( + "lost={}, is >= {} from largest acked {:?}", + pn, + PACKET_THRESHOLD, + largest_acked + ); + } else { + if largest_acked.is_some() { + self.first_ooo_time = Some(packet.time_sent); + } + // No more packets can be declared lost after this one. + break; + }; + + if packet.declare_lost(now) { + lost_pns.push(*pn); + } + } + + lost_packets.extend(lost_pns.iter().map(|pn| self.sent_packets[pn].clone())); + } +} + +#[derive(Debug)] +pub(crate) struct LossRecoverySpaces { + /// When we have all of the loss recovery spaces, this will use a separate + /// allocation, but this is reduced once the handshake is done. + spaces: SmallVec<[LossRecoverySpace; 1]>, +} + +impl LossRecoverySpaces { + fn idx(space: PacketNumberSpace) -> usize { + match space { + PacketNumberSpace::ApplicationData => 0, + PacketNumberSpace::Handshake => 1, + PacketNumberSpace::Initial => 2, + } + } + + /// Drop a packet number space and return all the packets that were + /// outstanding, so that those can be marked as lost. + /// + /// # Panics + /// + /// If the space has already been removed. 
+ pub fn drop_space(&mut self, space: PacketNumberSpace) -> impl IntoIterator<Item = SentPacket> { + let sp = match space { + PacketNumberSpace::Initial => self.spaces.pop(), + PacketNumberSpace::Handshake => { + let sp = self.spaces.pop(); + self.spaces.shrink_to_fit(); + sp + } + PacketNumberSpace::ApplicationData => panic!("discarding application space"), + }; + let mut sp = sp.unwrap(); + assert_eq!(sp.space(), space, "dropping spaces out of order"); + sp.remove_ignored() + } + + pub fn get(&self, space: PacketNumberSpace) -> Option<&LossRecoverySpace> { + self.spaces.get(Self::idx(space)) + } + + pub fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut LossRecoverySpace> { + self.spaces.get_mut(Self::idx(space)) + } + + fn iter(&self) -> impl Iterator<Item = &LossRecoverySpace> { + self.spaces.iter() + } + + fn iter_mut(&mut self) -> impl Iterator<Item = &mut LossRecoverySpace> { + self.spaces.iter_mut() + } +} + +impl Default for LossRecoverySpaces { + fn default() -> Self { + Self { + spaces: smallvec![ + LossRecoverySpace::new(PacketNumberSpace::ApplicationData), + LossRecoverySpace::new(PacketNumberSpace::Handshake), + LossRecoverySpace::new(PacketNumberSpace::Initial), + ], + } + } +} + +#[derive(Debug)] +struct PtoState { + /// The packet number space that caused the PTO to fire. + space: PacketNumberSpace, + /// The number of probes that we have sent. + count: usize, + packets: usize, + /// The complete set of packet number spaces that can have probes sent. + probe: PacketNumberSpaceSet, +} + +impl PtoState { + /// The number of packets we send on a PTO. + /// And the number to declare lost when the PTO timer is hit. + fn pto_packet_count(space: PacketNumberSpace, rx_count: usize) -> usize { + if space == PacketNumberSpace::Initial && rx_count == 0 { + // For the Initial space, we only send one packet on PTO if we have not received any + // packets from the peer yet. 
This avoids sending useless PING-only packets + // when the Client Initial is deemed lost. + 1 + } else { + MAX_PTO_PACKET_COUNT + } + } + + pub fn new(space: PacketNumberSpace, probe: PacketNumberSpaceSet, rx_count: usize) -> Self { + debug_assert!(probe[space]); + Self { + space, + count: 1, + packets: Self::pto_packet_count(space, rx_count), + probe, + } + } + + pub fn pto(&mut self, space: PacketNumberSpace, probe: PacketNumberSpaceSet, rx_count: usize) { + debug_assert!(probe[space]); + self.space = space; + self.count += 1; + self.packets = Self::pto_packet_count(space, rx_count); + self.probe = probe; + } + + pub fn count(&self) -> usize { + self.count + } + + pub fn count_pto(&self, stats: &mut Stats) { + stats.add_pto_count(self.count); + } + + /// Generate a sending profile, indicating what space it should be from. + /// This takes a packet from the supply if one remains, or returns `None`. + pub fn send_profile(&mut self, mtu: usize) -> Option<SendProfile> { + if self.packets > 0 { + // This is a PTO, so ignore the limit. + self.packets -= 1; + Some(SendProfile::new_pto(self.space, mtu, self.probe)) + } else { + None + } + } +} + +#[derive(Debug)] +pub(crate) struct LossRecovery { + /// When the handshake was confirmed, if it has been. + confirmed_time: Option<Instant>, + pto_state: Option<PtoState>, + spaces: LossRecoverySpaces, + qlog: NeqoQlog, + stats: StatsCell, + /// The factor by which the PTO period is reduced. + /// This enables faster probing at a cost in additional lost packets. 
+ fast_pto: u8, +} + +impl LossRecovery { + pub fn new(stats: StatsCell, fast_pto: u8) -> Self { + Self { + confirmed_time: None, + pto_state: None, + spaces: LossRecoverySpaces::default(), + qlog: NeqoQlog::default(), + stats, + fast_pto, + } + } + + pub fn largest_acknowledged_pn(&self, pn_space: PacketNumberSpace) -> Option<PacketNumber> { + self.spaces.get(pn_space).and_then(|sp| sp.largest_acked) + } + + pub fn set_qlog(&mut self, qlog: NeqoQlog) { + self.qlog = qlog; + } + + pub fn drop_0rtt(&mut self, primary_path: &PathRef, now: Instant) -> Vec<SentPacket> { + // The largest acknowledged or loss_time should still be unset. + // The client should not have received any ACK frames when it drops 0-RTT. + assert!(self + .spaces + .get(PacketNumberSpace::ApplicationData) + .unwrap() + .largest_acked + .is_none()); + let mut dropped = self + .spaces + .get_mut(PacketNumberSpace::ApplicationData) + .unwrap() + .remove_ignored() + .collect::<Vec<_>>(); + let mut path = primary_path.borrow_mut(); + for p in &mut dropped { + path.discard_packet(p, now, &mut self.stats.borrow_mut()); + } + dropped + } + + pub fn on_packet_sent(&mut self, path: &PathRef, mut sent_packet: SentPacket) { + let pn_space = PacketNumberSpace::from(sent_packet.pt); + qdebug!([self], "packet {}-{} sent", pn_space, sent_packet.pn); + if let Some(space) = self.spaces.get_mut(pn_space) { + path.borrow_mut().packet_sent(&mut sent_packet); + space.on_packet_sent(sent_packet); + } else { + qwarn!( + [self], + "ignoring {}-{} from dropped space", + pn_space, + sent_packet.pn + ); + } + } + + pub fn should_probe(&self, pto: Duration, now: Instant) -> bool { + self.spaces + .get(PacketNumberSpace::ApplicationData) + .unwrap() + .should_probe(pto, now) + } + + /// Record an RTT sample. 
+ fn rtt_sample( + &mut self, + rtt: &mut RttEstimate, + send_time: Instant, + now: Instant, + ack_delay: Duration, + ) { + let confirmed = self.confirmed_time.map_or(false, |t| t < send_time); + if let Some(sample) = now.checked_duration_since(send_time) { + rtt.update(&mut self.qlog, sample, ack_delay, confirmed, now); + } + } + + /// Returns (acked packets, lost packets) + pub fn on_ack_received<R>( + &mut self, + primary_path: &PathRef, + pn_space: PacketNumberSpace, + largest_acked: u64, + acked_ranges: R, + ack_delay: Duration, + now: Instant, + ) -> (Vec<SentPacket>, Vec<SentPacket>) + where + R: IntoIterator<Item = RangeInclusive<u64>>, + R::IntoIter: ExactSizeIterator, + { + qdebug!( + [self], + "ACK for {} - largest_acked={}.", + pn_space, + largest_acked + ); + + let Some(space) = self.spaces.get_mut(pn_space) else { + qinfo!("ACK on discarded space"); + return (Vec::new(), Vec::new()); + }; + + let (acked_packets, any_ack_eliciting) = + space.remove_acked(acked_ranges, &mut self.stats.borrow_mut()); + if acked_packets.is_empty() { + // No new information. + return (Vec::new(), Vec::new()); + } + + // Track largest PN acked per space + let prev_largest_acked = space.largest_acked_sent_time; + if Some(largest_acked) > space.largest_acked { + space.largest_acked = Some(largest_acked); + + // If the largest acknowledged is newly acked and any newly acked + // packet was ack-eliciting, update the RTT. (-recovery 5.1) + let largest_acked_pkt = acked_packets.first().expect("must be there"); + space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent); + if any_ack_eliciting && largest_acked_pkt.on_primary_path() { + self.rtt_sample( + primary_path.borrow_mut().rtt_mut(), + largest_acked_pkt.time_sent, + now, + ack_delay, + ); + } + } + + // Perform loss detection. + // PTO is used to remove lost packets from in-flight accounting. 
+ // We need to ensure that we have sent any PTO probes before they are removed + // as we rely on the count of in-flight packets to determine whether to send + // another probe. Removing them too soon would result in not sending on PTO. + let loss_delay = primary_path.borrow().rtt().loss_delay(); + let cleanup_delay = self.pto_period(primary_path.borrow().rtt(), pn_space); + let mut lost = Vec::new(); + self.spaces.get_mut(pn_space).unwrap().detect_lost_packets( + now, + loss_delay, + cleanup_delay, + &mut lost, + ); + self.stats.borrow_mut().lost += lost.len(); + + // Tell the congestion controller about any lost packets. + // The PTO for congestion control is the raw number, without exponential + // backoff, so that we can determine persistent congestion. + primary_path + .borrow_mut() + .on_packets_lost(prev_largest_acked, pn_space, &lost); + + // This must happen after on_packets_lost. If in recovery, this could + // take us out, and then lost packets will start a new recovery period + // when it shouldn't. + primary_path + .borrow_mut() + .on_packets_acked(&acked_packets, now); + + self.pto_state = None; + + (acked_packets, lost) + } + + /// When receiving a retry, get all the sent packets so that they can be flushed. + /// We also need to pretend that they never happened for the purposes of congestion control. + pub fn retry(&mut self, primary_path: &PathRef, now: Instant) -> Vec<SentPacket> { + self.pto_state = None; + let mut dropped = self + .spaces + .iter_mut() + .flat_map(LossRecoverySpace::remove_ignored) + .collect::<Vec<_>>(); + let mut path = primary_path.borrow_mut(); + for p in &mut dropped { + path.discard_packet(p, now, &mut self.stats.borrow_mut()); + } + dropped + } + + fn confirmed(&mut self, rtt: &RttEstimate, now: Instant) { + debug_assert!(self.confirmed_time.is_none()); + self.confirmed_time = Some(now); + // Up until now, the ApplicationData space has been ignored for PTO. + // So maybe fire a PTO. 
+ if let Some(pto) = self.pto_time(rtt, PacketNumberSpace::ApplicationData) { + if pto < now { + let probes = PacketNumberSpaceSet::from(&[PacketNumberSpace::ApplicationData]); + self.fire_pto(PacketNumberSpace::ApplicationData, probes); + } + } + } + + /// This function is called when the connection migrates. + /// It marks all packets that are outstanding as having being sent on a non-primary path. + /// This way failure to deliver on the old path doesn't count against the congestion + /// control state on the new path and the RTT measurements don't apply either. + pub fn migrate(&mut self) { + for space in self.spaces.iter_mut() { + space.migrate(); + } + } + + /// Discard state for a given packet number space. + pub fn discard(&mut self, primary_path: &PathRef, space: PacketNumberSpace, now: Instant) { + qdebug!([self], "Reset loss recovery state for {}", space); + let mut path = primary_path.borrow_mut(); + for p in self.spaces.drop_space(space) { + path.discard_packet(&p, now, &mut self.stats.borrow_mut()); + } + + // We just made progress, so discard PTO count. + // The spec says that clients should not do this until confirming that + // the server has completed address validation, but ignore that. + self.pto_state = None; + + if space == PacketNumberSpace::Handshake { + self.confirmed(path.rtt(), now); + } + } + + /// Calculate when the next timeout is likely to be. This is the earlier of the loss timer + /// and the PTO timer; either or both might be disabled, so this can return `None`. 
+ pub fn next_timeout(&mut self, rtt: &RttEstimate) -> Option<Instant> { + let loss_time = self.earliest_loss_time(rtt); + let pto_time = self.earliest_pto(rtt); + qtrace!( + [self], + "next_timeout loss={:?} pto={:?}", + loss_time, + pto_time + ); + match (loss_time, pto_time) { + (Some(loss_time), Some(pto_time)) => Some(min(loss_time, pto_time)), + (Some(loss_time), None) => Some(loss_time), + (None, Some(pto_time)) => Some(pto_time), + (None, None) => None, + } + } + + /// Find when the earliest sent packet should be considered lost. + fn earliest_loss_time(&self, rtt: &RttEstimate) -> Option<Instant> { + self.spaces + .iter() + .filter_map(LossRecoverySpace::loss_recovery_timer_start) + .min() + .map(|val| val + rtt.loss_delay()) + } + + /// Simple wrapper for the PTO calculation that avoids borrow check rules. + fn pto_period_inner( + rtt: &RttEstimate, + pto_state: Option<&PtoState>, + pn_space: PacketNumberSpace, + fast_pto: u8, + ) -> Duration { + // This is a complicated (but safe) way of calculating: + // base_pto * F * 2^pto_count + // where F = fast_pto / FAST_PTO_SCALE (== 1 by default) + let pto_count = pto_state.map_or(0, |p| u32::try_from(p.count).unwrap_or(0)); + rtt.pto(pn_space) + .checked_mul(u32::from(fast_pto) << min(pto_count, u32::BITS - u8::BITS)) + .map_or(Duration::from_secs(3600), |p| p / u32::from(FAST_PTO_SCALE)) + } + + /// Get the current PTO period for the given packet number space. + /// Unlike calling `RttEstimate::pto` directly, this includes exponential backoff. + fn pto_period(&self, rtt: &RttEstimate, pn_space: PacketNumberSpace) -> Duration { + Self::pto_period_inner(rtt, self.pto_state.as_ref(), pn_space, self.fast_pto) + } + + // Calculate PTO time for the given space. 
+ fn pto_time(&self, rtt: &RttEstimate, pn_space: PacketNumberSpace) -> Option<Instant> { + if self.confirmed_time.is_none() && pn_space == PacketNumberSpace::ApplicationData { + None + } else { + self.spaces.get(pn_space).and_then(|space| { + space + .pto_base_time() + .map(|t| t + self.pto_period(rtt, pn_space)) + }) + } + } + + /// Find the earliest PTO time for all active packet number spaces. + /// Ignore Application if either Initial or Handshake have an active PTO. + fn earliest_pto(&self, rtt: &RttEstimate) -> Option<Instant> { + if self.confirmed_time.is_some() { + self.pto_time(rtt, PacketNumberSpace::ApplicationData) + } else { + self.pto_time(rtt, PacketNumberSpace::Initial) + .iter() + .chain(self.pto_time(rtt, PacketNumberSpace::Handshake).iter()) + .min() + .copied() + } + } + + fn fire_pto(&mut self, pn_space: PacketNumberSpace, allow_probes: PacketNumberSpaceSet) { + let rx_count = self.stats.borrow().packets_rx; + if let Some(st) = &mut self.pto_state { + st.pto(pn_space, allow_probes, rx_count); + } else { + self.pto_state = Some(PtoState::new(pn_space, allow_probes, rx_count)); + } + + self.pto_state + .as_mut() + .unwrap() + .count_pto(&mut self.stats.borrow_mut()); + + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::PtoCount( + self.pto_state.as_ref().unwrap().count(), + )], + ); + } + + /// This checks whether the PTO timer has fired and fires it if needed. + /// When it has, mark a few packets as "lost" for the purposes of having frames + /// regenerated in subsequent packets. The packets aren't truly lost, so + /// we have to clone the `SentPacket` instance. + fn maybe_fire_pto(&mut self, rtt: &RttEstimate, now: Instant, lost: &mut Vec<SentPacket>) { + let mut pto_space = None; + // The spaces in which we will allow probing. 
+ let mut allow_probes = PacketNumberSpaceSet::default(); + for pn_space in PacketNumberSpace::iter() { + if let Some(t) = self.pto_time(rtt, *pn_space) { + allow_probes[*pn_space] = true; + if t <= now { + qdebug!([self], "PTO timer fired for {}", pn_space); + let space = self.spaces.get_mut(*pn_space).unwrap(); + lost.extend( + space + .pto_packets(PtoState::pto_packet_count( + *pn_space, + self.stats.borrow().packets_rx, + )) + .cloned(), + ); + + pto_space = pto_space.or(Some(*pn_space)); + } + } + } + + // This has to happen outside the loop. Increasing the PTO count here causes the + // pto_time to increase which might cause PTO for later packet number spaces to not fire. + if let Some(pn_space) = pto_space { + qtrace!([self], "PTO {}, probing {:?}", pn_space, allow_probes); + self.fire_pto(pn_space, allow_probes); + } + } + + pub fn timeout(&mut self, primary_path: &PathRef, now: Instant) -> Vec<SentPacket> { + qtrace!([self], "timeout {:?}", now); + + let loss_delay = primary_path.borrow().rtt().loss_delay(); + + let mut lost_packets = Vec::new(); + for space in self.spaces.iter_mut() { + let first = lost_packets.len(); // The first packet lost in this space. + let pto = Self::pto_period_inner( + primary_path.borrow().rtt(), + self.pto_state.as_ref(), + space.space(), + self.fast_pto, + ); + space.detect_lost_packets(now, loss_delay, pto, &mut lost_packets); + + primary_path.borrow_mut().on_packets_lost( + space.largest_acked_sent_time, + space.space(), + &lost_packets[first..], + ); + } + self.stats.borrow_mut().lost += lost_packets.len(); + + self.maybe_fire_pto(primary_path.borrow().rtt(), now, &mut lost_packets); + lost_packets + } + + /// Check how packets should be sent, based on whether there is a PTO, + /// what the current congestion window is, and what the pacer says. 
+ #[allow(clippy::option_if_let_else)] + pub fn send_profile(&mut self, path: &Path, now: Instant) -> SendProfile { + qdebug!([self], "get send profile {:?}", now); + let sender = path.sender(); + let mtu = path.mtu(); + if let Some(profile) = self + .pto_state + .as_mut() + .and_then(|pto| pto.send_profile(mtu)) + { + profile + } else { + let limit = min(sender.cwnd_avail(), path.amplification_limit()); + if limit > mtu { + // More than an MTU available; we might need to pace. + if sender + .next_paced(path.rtt().estimate()) + .map_or(false, |t| t > now) + { + SendProfile::new_paced() + } else { + SendProfile::new_limited(mtu) + } + } else if sender.recovery_packet() { + // After entering recovery, allow a packet to be sent immediately. + // This uses the PTO machinery, probing in all spaces. This will + // result in a PING being sent in every active space. + SendProfile::new_pto(PacketNumberSpace::Initial, mtu, PacketNumberSpaceSet::all()) + } else { + SendProfile::new_limited(limit) + } + } + } +} + +impl ::std::fmt::Display for LossRecovery { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "LossRecovery") + } +} + +#[cfg(test)] +mod tests { + use std::{ + cell::RefCell, + convert::TryInto, + ops::{Deref, DerefMut, RangeInclusive}, + rc::Rc, + time::{Duration, Instant}, + }; + + use neqo_common::qlog::NeqoQlog; + use test_fixture::{addr, now}; + + use super::{ + LossRecovery, LossRecoverySpace, PacketNumberSpace, SendProfile, SentPacket, FAST_PTO_SCALE, + }; + use crate::{ + cc::CongestionControlAlgorithm, + cid::{ConnectionId, ConnectionIdEntry}, + packet::PacketType, + path::{Path, PathRef}, + rtt::RttEstimate, + stats::{Stats, StatsCell}, + }; + + // Shorthand for a time in milliseconds. + const fn ms(t: u64) -> Duration { + Duration::from_millis(t) + } + + const ON_SENT_SIZE: usize = 100; + /// An initial RTT for using with `setup_lr`. 
+ const TEST_RTT: Duration = ms(80); + const TEST_RTTVAR: Duration = ms(40); + + struct Fixture { + lr: LossRecovery, + path: PathRef, + } + + // This shadows functions on the base object so that the path and RTT estimator + // is used consistently in the tests. It also simplifies the function signatures. + impl Fixture { + pub fn on_ack_received( + &mut self, + pn_space: PacketNumberSpace, + largest_acked: u64, + acked_ranges: Vec<RangeInclusive<u64>>, + ack_delay: Duration, + now: Instant, + ) -> (Vec<SentPacket>, Vec<SentPacket>) { + self.lr.on_ack_received( + &self.path, + pn_space, + largest_acked, + acked_ranges, + ack_delay, + now, + ) + } + + pub fn on_packet_sent(&mut self, sent_packet: SentPacket) { + self.lr.on_packet_sent(&self.path, sent_packet); + } + + pub fn timeout(&mut self, now: Instant) -> Vec<SentPacket> { + self.lr.timeout(&self.path, now) + } + + pub fn next_timeout(&mut self) -> Option<Instant> { + self.lr.next_timeout(self.path.borrow().rtt()) + } + + pub fn discard(&mut self, space: PacketNumberSpace, now: Instant) { + self.lr.discard(&self.path, space, now); + } + + pub fn pto_time(&self, space: PacketNumberSpace) -> Option<Instant> { + self.lr.pto_time(self.path.borrow().rtt(), space) + } + + pub fn send_profile(&mut self, now: Instant) -> SendProfile { + self.lr.send_profile(&self.path.borrow(), now) + } + } + + impl Default for Fixture { + fn default() -> Self { + const CC: CongestionControlAlgorithm = CongestionControlAlgorithm::NewReno; + let mut path = Path::temporary(addr(), addr(), CC, true, NeqoQlog::default(), now()); + path.make_permanent( + None, + ConnectionIdEntry::new(0, ConnectionId::from(&[1, 2, 3]), [0; 16]), + ); + path.set_primary(true); + Self { + lr: LossRecovery::new(StatsCell::default(), FAST_PTO_SCALE), + path: Rc::new(RefCell::new(path)), + } + } + } + + // Most uses of the fixture only care about the loss recovery piece, + // but the internal functions need the other bits. 
+ impl Deref for Fixture { + type Target = LossRecovery; + #[must_use] + fn deref(&self) -> &Self::Target { + &self.lr + } + } + + impl DerefMut for Fixture { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.lr + } + } + + fn assert_rtts( + lr: &Fixture, + latest_rtt: Duration, + smoothed_rtt: Duration, + rttvar: Duration, + min_rtt: Duration, + ) { + let p = lr.path.borrow(); + let rtt = p.rtt(); + println!( + "rtts: {:?} {:?} {:?} {:?}", + rtt.latest(), + rtt.estimate(), + rtt.rttvar(), + rtt.minimum(), + ); + assert_eq!(rtt.latest(), latest_rtt, "latest RTT"); + assert_eq!(rtt.estimate(), smoothed_rtt, "smoothed RTT"); + assert_eq!(rtt.rttvar(), rttvar, "RTT variance"); + assert_eq!(rtt.minimum(), min_rtt, "min RTT"); + } + + fn assert_sent_times( + lr: &Fixture, + initial: Option<Instant>, + handshake: Option<Instant>, + app_data: Option<Instant>, + ) { + let est = |sp| { + lr.spaces + .get(sp) + .and_then(LossRecoverySpace::loss_recovery_timer_start) + }; + println!( + "loss times: {:?} {:?} {:?}", + est(PacketNumberSpace::Initial), + est(PacketNumberSpace::Handshake), + est(PacketNumberSpace::ApplicationData), + ); + assert_eq!( + est(PacketNumberSpace::Initial), + initial, + "Initial earliest sent time" + ); + assert_eq!( + est(PacketNumberSpace::Handshake), + handshake, + "Handshake earliest sent time" + ); + assert_eq!( + est(PacketNumberSpace::ApplicationData), + app_data, + "AppData earliest sent time" + ); + } + + fn assert_no_sent_times(lr: &Fixture) { + assert_sent_times(lr, None, None, None); + } + + // In most of the tests below, packets are sent at a fixed cadence, with PACING between each. 
+ const PACING: Duration = ms(7); + fn pn_time(pn: u64) -> Instant { + now() + (PACING * pn.try_into().unwrap()) + } + + fn pace(lr: &mut Fixture, count: u64) { + for pn in 0..count { + lr.on_packet_sent(SentPacket::new( + PacketType::Short, + pn, + pn_time(pn), + true, + Vec::new(), + ON_SENT_SIZE, + )); + } + } + + const ACK_DELAY: Duration = ms(24); + /// Acknowledge PN with the identified delay. + fn ack(lr: &mut Fixture, pn: u64, delay: Duration) { + lr.on_ack_received( + PacketNumberSpace::ApplicationData, + pn, + vec![pn..=pn], + ACK_DELAY, + pn_time(pn) + delay, + ); + } + + fn add_sent(lrs: &mut LossRecoverySpace, packet_numbers: &[u64]) { + for &pn in packet_numbers { + lrs.on_packet_sent(SentPacket::new( + PacketType::Short, + pn, + pn_time(pn), + true, + Vec::new(), + ON_SENT_SIZE, + )); + } + } + + fn match_acked(acked: &[SentPacket], expected: &[u64]) { + assert!(acked.iter().map(|p| &p.pn).eq(expected)); + } + + #[test] + fn remove_acked() { + let mut lrs = LossRecoverySpace::new(PacketNumberSpace::ApplicationData); + let mut stats = Stats::default(); + add_sent(&mut lrs, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let (acked, _) = lrs.remove_acked(vec![], &mut stats); + assert!(acked.is_empty()); + let (acked, _) = lrs.remove_acked(vec![7..=8, 2..=4], &mut stats); + match_acked(&acked, &[8, 7, 4, 3, 2]); + let (acked, _) = lrs.remove_acked(vec![8..=11], &mut stats); + match_acked(&acked, &[10, 9]); + let (acked, _) = lrs.remove_acked(vec![0..=2], &mut stats); + match_acked(&acked, &[1]); + let (acked, _) = lrs.remove_acked(vec![5..=6], &mut stats); + match_acked(&acked, &[6, 5]); + } + + #[test] + fn initial_rtt() { + let mut lr = Fixture::default(); + pace(&mut lr, 1); + let rtt = ms(100); + ack(&mut lr, 0, rtt); + assert_rtts(&lr, rtt, rtt, rtt / 2, rtt); + assert_no_sent_times(&lr); + } + + /// Send `n` packets (using PACING), then acknowledge the first. 
+ fn setup_lr(n: u64) -> Fixture { + let mut lr = Fixture::default(); + pace(&mut lr, n); + ack(&mut lr, 0, TEST_RTT); + assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT); + assert_no_sent_times(&lr); + lr + } + + // The ack delay is removed from any RTT estimate. + #[test] + fn ack_delay_adjusted() { + let mut lr = setup_lr(2); + ack(&mut lr, 1, TEST_RTT + ACK_DELAY); + // RTT stays the same, but the RTTVAR is adjusted downwards. + assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR * 3 / 4, TEST_RTT); + assert_no_sent_times(&lr); + } + + // The ack delay is ignored when it would cause a sample to be less than min_rtt. + #[test] + fn ack_delay_ignored() { + let mut lr = setup_lr(2); + let extra = ms(8); + assert!(extra < ACK_DELAY); + ack(&mut lr, 1, TEST_RTT + extra); + let expected_rtt = TEST_RTT + (extra / 8); + let expected_rttvar = (TEST_RTTVAR * 3 + extra) / 4; + assert_rtts( + &lr, + TEST_RTT + extra, + expected_rtt, + expected_rttvar, + TEST_RTT, + ); + assert_no_sent_times(&lr); + } + + // A lower observed RTT is used as min_rtt (and ack delay is ignored). + #[test] + fn reduce_min_rtt() { + let mut lr = setup_lr(2); + let delta = ms(4); + let reduced_rtt = TEST_RTT - delta; + ack(&mut lr, 1, reduced_rtt); + let expected_rtt = TEST_RTT - (delta / 8); + let expected_rttvar = (TEST_RTTVAR * 3 + delta) / 4; + assert_rtts(&lr, reduced_rtt, expected_rtt, expected_rttvar, reduced_rtt); + assert_no_sent_times(&lr); + } + + // Acknowledging something again has no effect. + #[test] + fn no_new_acks() { + let mut lr = setup_lr(1); + let check = |lr: &Fixture| { + assert_rtts(lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT); + assert_no_sent_times(lr); + }; + check(&lr); + + ack(&mut lr, 0, ms(1339)); // much delayed ACK + check(&lr); + + ack(&mut lr, 0, ms(3)); // time travel! + check(&lr); + } + + // Test time loss detection as part of handling a regular ACK. 
+ #[test] + fn time_loss_detection_gap() { + let mut lr = Fixture::default(); + // Create a single packet gap, and have pn 0 time out. + // This can't use the default pacing, which is too tight. + // So send two packets with 1/4 RTT between them. Acknowledge pn 1 after 1 RTT. + // pn 0 should then be marked lost because it is then outstanding for 5RTT/4 + // the loss time for packets is 9RTT/8. + lr.on_packet_sent(SentPacket::new( + PacketType::Short, + 0, + pn_time(0), + true, + Vec::new(), + ON_SENT_SIZE, + )); + lr.on_packet_sent(SentPacket::new( + PacketType::Short, + 1, + pn_time(0) + TEST_RTT / 4, + true, + Vec::new(), + ON_SENT_SIZE, + )); + let (_, lost) = lr.on_ack_received( + PacketNumberSpace::ApplicationData, + 1, + vec![1..=1], + ACK_DELAY, + pn_time(0) + (TEST_RTT * 5 / 4), + ); + assert_eq!(lost.len(), 1); + assert_no_sent_times(&lr); + } + + // Test time loss detection as part of an explicit timeout. + #[test] + fn time_loss_detection_timeout() { + let mut lr = setup_lr(3); + + // We want to declare PN 2 as acknowledged before we declare PN 1 as lost. + // For this to work, we need PACING above to be less than 1/8 of an RTT. + let pn1_sent_time = pn_time(1); + let pn1_loss_time = pn1_sent_time + (TEST_RTT * 9 / 8); + let pn2_ack_time = pn_time(2) + TEST_RTT; + assert!(pn1_loss_time > pn2_ack_time); + + let (_, lost) = lr.on_ack_received( + PacketNumberSpace::ApplicationData, + 2, + vec![2..=2], + ACK_DELAY, + pn2_ack_time, + ); + assert!(lost.is_empty()); + // Run the timeout function here to force time-based loss recovery to be enabled. + let lost = lr.timeout(pn2_ack_time); + assert!(lost.is_empty()); + assert_sent_times(&lr, None, None, Some(pn1_sent_time)); + + // After time elapses, pn 1 is marked lost. + let callback_time = lr.next_timeout(); + assert_eq!(callback_time, Some(pn1_loss_time)); + let packets = lr.timeout(pn1_loss_time); + assert_eq!(packets.len(), 1); + // Checking for expiration with zero delay lets us check the loss time. 
+ assert!(packets[0].expired(pn1_loss_time, Duration::new(0, 0))); + assert_no_sent_times(&lr); + } + + #[test] + fn big_gap_loss() { + let mut lr = setup_lr(5); // This sends packets 0-4 and acknowledges pn 0. + + // Acknowledge just 2-4, which will cause pn 1 to be marked as lost. + assert_eq!(super::PACKET_THRESHOLD, 3); + let (_, lost) = lr.on_ack_received( + PacketNumberSpace::ApplicationData, + 4, + vec![2..=4], + ACK_DELAY, + pn_time(4), + ); + assert_eq!(lost.len(), 1); + } + + #[test] + #[should_panic(expected = "discarding application space")] + fn drop_app() { + let mut lr = Fixture::default(); + lr.discard(PacketNumberSpace::ApplicationData, now()); + } + + #[test] + #[should_panic(expected = "dropping spaces out of order")] + fn drop_out_of_order() { + let mut lr = Fixture::default(); + lr.discard(PacketNumberSpace::Handshake, now()); + } + + #[test] + fn ack_after_drop() { + let mut lr = Fixture::default(); + lr.discard(PacketNumberSpace::Initial, now()); + let (acked, lost) = lr.on_ack_received( + PacketNumberSpace::Initial, + 0, + vec![], + Duration::from_millis(0), + pn_time(0), + ); + assert!(acked.is_empty()); + assert!(lost.is_empty()); + } + + #[test] + fn drop_spaces() { + let mut lr = Fixture::default(); + lr.on_packet_sent(SentPacket::new( + PacketType::Initial, + 0, + pn_time(0), + true, + Vec::new(), + ON_SENT_SIZE, + )); + lr.on_packet_sent(SentPacket::new( + PacketType::Handshake, + 0, + pn_time(1), + true, + Vec::new(), + ON_SENT_SIZE, + )); + lr.on_packet_sent(SentPacket::new( + PacketType::Short, + 0, + pn_time(2), + true, + Vec::new(), + ON_SENT_SIZE, + )); + + // Now put all spaces on the LR timer so we can see them. 
+ for sp in &[ + PacketType::Initial, + PacketType::Handshake, + PacketType::Short, + ] { + let sent_pkt = SentPacket::new(*sp, 1, pn_time(3), true, Vec::new(), ON_SENT_SIZE); + let pn_space = PacketNumberSpace::from(sent_pkt.pt); + lr.on_packet_sent(sent_pkt); + lr.on_ack_received(pn_space, 1, vec![1..=1], Duration::from_secs(0), pn_time(3)); + let mut lost = Vec::new(); + lr.spaces.get_mut(pn_space).unwrap().detect_lost_packets( + pn_time(3), + TEST_RTT, + TEST_RTT * 3, // unused + &mut lost, + ); + assert!(lost.is_empty()); + } + + lr.discard(PacketNumberSpace::Initial, pn_time(3)); + assert_sent_times(&lr, None, Some(pn_time(1)), Some(pn_time(2))); + + lr.discard(PacketNumberSpace::Handshake, pn_time(3)); + assert_sent_times(&lr, None, None, Some(pn_time(2))); + + // There are cases where we send a packet that is not subsequently tracked. + // So check that this works. + lr.on_packet_sent(SentPacket::new( + PacketType::Initial, + 0, + pn_time(3), + true, + Vec::new(), + ON_SENT_SIZE, + )); + assert_sent_times(&lr, None, None, Some(pn_time(2))); + } + + #[test] + fn rearm_pto_after_confirmed() { + let mut lr = Fixture::default(); + lr.on_packet_sent(SentPacket::new( + PacketType::Initial, + 0, + now(), + true, + Vec::new(), + ON_SENT_SIZE, + )); + // Set the RTT to the initial value so that discarding doesn't + // alter the estimate. 
+ let rtt = lr.path.borrow().rtt().estimate(); + lr.on_ack_received( + PacketNumberSpace::Initial, + 0, + vec![0..=0], + Duration::new(0, 0), + now() + rtt, + ); + + lr.on_packet_sent(SentPacket::new( + PacketType::Handshake, + 0, + now(), + true, + Vec::new(), + ON_SENT_SIZE, + )); + lr.on_packet_sent(SentPacket::new( + PacketType::Short, + 0, + now(), + true, + Vec::new(), + ON_SENT_SIZE, + )); + + assert_eq!(lr.pto_time(PacketNumberSpace::ApplicationData), None); + lr.discard(PacketNumberSpace::Initial, pn_time(1)); + assert_eq!(lr.pto_time(PacketNumberSpace::ApplicationData), None); + + // Expiring state after the PTO on the ApplicationData space has + // expired should result in setting a PTO state. + let default_pto = RttEstimate::default().pto(PacketNumberSpace::ApplicationData); + let expected_pto = pn_time(2) + default_pto; + lr.discard(PacketNumberSpace::Handshake, expected_pto); + let profile = lr.send_profile(now()); + assert!(profile.pto.is_some()); + assert!(!profile.should_probe(PacketNumberSpace::Initial)); + assert!(!profile.should_probe(PacketNumberSpace::Handshake)); + assert!(profile.should_probe(PacketNumberSpace::ApplicationData)); + } + + #[test] + fn no_pto_if_amplification_limited() { + let mut lr = Fixture::default(); + // Eat up the amplification limit by telling the path that we've sent a giant packet. 
+ { + const SPARE: usize = 10; + let mut path = lr.path.borrow_mut(); + let limit = path.amplification_limit(); + path.add_sent(limit - SPARE); + assert_eq!(path.amplification_limit(), SPARE); + } + + lr.on_packet_sent(SentPacket::new( + PacketType::Initial, + 1, + now(), + true, + Vec::new(), + ON_SENT_SIZE, + )); + + let handshake_pto = RttEstimate::default().pto(PacketNumberSpace::Handshake); + let expected_pto = now() + handshake_pto; + assert_eq!(lr.pto_time(PacketNumberSpace::Initial), Some(expected_pto)); + let profile = lr.send_profile(now()); + assert!(profile.ack_only(PacketNumberSpace::Initial)); + assert!(profile.pto.is_none()); + assert!(!profile.should_probe(PacketNumberSpace::Initial)); + assert!(!profile.should_probe(PacketNumberSpace::Handshake)); + assert!(!profile.should_probe(PacketNumberSpace::ApplicationData)); + } +} diff --git a/third_party/rust/neqo-transport/src/recv_stream.rs b/third_party/rust/neqo-transport/src/recv_stream.rs new file mode 100644 index 0000000000..06ca59685d --- /dev/null +++ b/third_party/rust/neqo-transport/src/recv_stream.rs @@ -0,0 +1,2149 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Building a stream of ordered bytes to give the application from a series of +// incoming STREAM frames. 
+ +use std::{ + cell::RefCell, + cmp::max, + collections::BTreeMap, + convert::TryFrom, + mem, + rc::{Rc, Weak}, +}; + +use neqo_common::{qtrace, Role}; +use smallvec::SmallVec; + +use crate::{ + events::ConnectionEvents, + fc::ReceiverFlowControl, + frame::FRAME_TYPE_STOP_SENDING, + packet::PacketBuilder, + recovery::{RecoveryToken, StreamRecoveryToken}, + send_stream::SendStreams, + stats::FrameStats, + stream_id::StreamId, + AppError, Error, Res, +}; + +const RX_STREAM_DATA_WINDOW: u64 = 0x10_0000; // 1MiB + +// Export as usize for consistency with SEND_BUFFER_SIZE +pub const RECV_BUFFER_SIZE: usize = RX_STREAM_DATA_WINDOW as usize; + +#[derive(Debug, Default)] +pub(crate) struct RecvStreams { + streams: BTreeMap<StreamId, RecvStream>, + keep_alive: Weak<()>, +} + +impl RecvStreams { + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + for stream in self.streams.values_mut() { + stream.write_frame(builder, tokens, stats); + if builder.is_full() { + return; + } + } + } + + pub fn insert(&mut self, id: StreamId, stream: RecvStream) { + self.streams.insert(id, stream); + } + + pub fn get_mut(&mut self, id: StreamId) -> Res<&mut RecvStream> { + self.streams.get_mut(&id).ok_or(Error::InvalidStreamId) + } + + pub fn keep_alive(&mut self, id: StreamId, k: bool) -> Res<()> { + let self_ka = &mut self.keep_alive; + let s = self.streams.get_mut(&id).ok_or(Error::InvalidStreamId)?; + s.keep_alive = if k { + Some(self_ka.upgrade().unwrap_or_else(|| { + let r = Rc::new(()); + *self_ka = Rc::downgrade(&r); + r + })) + } else { + None + }; + Ok(()) + } + + pub fn need_keep_alive(&mut self) -> bool { + self.keep_alive.strong_count() > 0 + } + + pub fn clear(&mut self) { + self.streams.clear(); + } + + pub fn clear_terminal(&mut self, send_streams: &SendStreams, role: Role) -> (u64, u64) { + let recv_to_remove = self + .streams + .iter() + .filter_map(|(id, stream)| { + // Remove all streams for 
which the receiving is done (or aborted). + // But only if they are unidirectional, or we have finished sending. + if stream.is_terminal() && (id.is_uni() || !send_streams.exists(*id)) { + Some(*id) + } else { + None + } + }) + .collect::<Vec<_>>(); + + let mut removed_bidi = 0; + let mut removed_uni = 0; + for id in &recv_to_remove { + self.streams.remove(id); + if id.is_remote_initiated(role) { + if id.is_bidi() { + removed_bidi += 1; + } else { + removed_uni += 1; + } + } + } + + (removed_bidi, removed_uni) + } +} + +/// Holds data not yet read by application. Orders and dedupes data ranges +/// from incoming STREAM frames. +#[derive(Debug, Default)] +pub struct RxStreamOrderer { + data_ranges: BTreeMap<u64, Vec<u8>>, // (start_offset, data) + retired: u64, // Number of bytes the application has read + received: u64, // The number of bytes has stored in `data_ranges` +} + +impl RxStreamOrderer { + pub fn new() -> Self { + Self::default() + } + + /// Process an incoming stream frame off the wire. This may result in data + /// being available to upper layers if frame is not out of order (ooo) or + /// if the frame fills a gap. + pub fn inbound_frame(&mut self, mut new_start: u64, mut new_data: &[u8]) { + qtrace!("Inbound data offset={} len={}", new_start, new_data.len()); + + // Get entry before where new entry would go, so we can see if we already + // have the new bytes. + // Avoid copies and duplicated data. + let new_end = new_start + u64::try_from(new_data.len()).unwrap(); + + if new_end <= self.retired { + // Range already read by application, this frame is very late and unneeded. 
+ return; + } + + if new_start < self.retired { + new_data = &new_data[usize::try_from(self.retired - new_start).unwrap()..]; + new_start = self.retired; + } + + if new_data.is_empty() { + // No data to insert + return; + } + + let extend = if let Some((&prev_start, prev_vec)) = + self.data_ranges.range_mut(..=new_start).next_back() + { + let prev_end = prev_start + u64::try_from(prev_vec.len()).unwrap(); + if new_end > prev_end { + // PPPPPP -> PPPPPP + // NNNNNN NN + // NNNNNNNN NN + // Add a range containing only new data + // (In-order frames will take this path, with no overlap) + let overlap = prev_end.saturating_sub(new_start); + qtrace!( + "New frame {}-{} received, overlap: {}", + new_start, + new_end, + overlap + ); + new_start += overlap; + new_data = &new_data[usize::try_from(overlap).unwrap()..]; + // If it is small enough, extend the previous buffer. + // This can't always extend, because otherwise the buffer could end up + // growing indefinitely without being released. + prev_vec.len() < 4096 && prev_end == new_start + } else { + // PPPPPP -> PPPPPP + // NNNN + // NNNN + // Do nothing + qtrace!( + "Dropping frame with already-received range {}-{}", + new_start, + new_end + ); + return; + } + } else { + qtrace!("New frame {}-{} received", new_start, new_end); + false + }; + + let mut to_add = new_data; + if self + .data_ranges + .last_entry() + .map_or(false, |e| *e.key() >= new_start) + { + // Is this at the end (common case)? 
If so, nothing to do in this block + // Common case: + // PPPPPP -> PPPPPP + // NNNNNNN NNNNNNN + // or + // PPPPPP -> PPPPPP + // NNNNNNN NNNNNNN + // + // Not the common case, handle possible overlap with next entries + // PPPPPP AAA -> PPPPPP + // NNNNNNN NNNNNNN + // or + // PPPPPP AAAA -> PPPPPP AAAA + // NNNNNNN NNNNN + // or (this is where to_remove is used) + // PPPPPP AA -> PPPPPP + // NNNNNNN NNNNNNN + + let mut to_remove = SmallVec::<[_; 8]>::new(); + + for (&next_start, next_data) in self.data_ranges.range_mut(new_start..) { + let next_end = next_start + u64::try_from(next_data.len()).unwrap(); + let overlap = new_end.saturating_sub(next_start); + if overlap == 0 { + // Fills in the hole, exactly (probably common) + break; + } else if next_end >= new_end { + qtrace!( + "New frame {}-{} overlaps with next frame by {}, truncating", + new_start, + new_end, + overlap + ); + let truncate_to = new_data.len() - usize::try_from(overlap).unwrap(); + to_add = &new_data[..truncate_to]; + break; + } + qtrace!( + "New frame {}-{} spans entire next frame {}-{}, replacing", + new_start, + new_end, + next_start, + next_end + ); + to_remove.push(next_start); + // Continue, since we may have more overlaps + } + + for start in to_remove { + self.data_ranges.remove(&start); + } + } + + if !to_add.is_empty() { + self.received += u64::try_from(to_add.len()).unwrap(); + if extend { + let (_, buf) = self + .data_ranges + .range_mut(..=new_start) + .next_back() + .unwrap(); + buf.extend_from_slice(to_add); + } else { + self.data_ranges.insert(new_start, to_add.to_vec()); + } + } + } + + /// Are any bytes readable? + pub fn data_ready(&self) -> bool { + self.data_ranges + .keys() + .next() + .map_or(false, |&start| start <= self.retired) + } + + /// How many bytes are readable? 
+ fn bytes_ready(&self) -> usize { + let mut prev_end = self.retired; + self.data_ranges + .iter() + .map(|(start_offset, data)| { + // All ranges don't overlap but we could have partially + // retired some of the first entry's data. + let data_len = data.len() as u64 - self.retired.saturating_sub(*start_offset); + (start_offset, data_len) + }) + .take_while(|(start_offset, data_len)| { + if **start_offset <= prev_end { + prev_end += data_len; + true + } else { + false + } + }) + .map(|(_, data_len)| data_len as usize) + .sum() + } + + /// Bytes read by the application. + pub fn retired(&self) -> u64 { + self.retired + } + + pub fn received(&self) -> u64 { + self.received + } + + /// Data bytes buffered. Could be more than bytes_readable if there are + /// ranges missing. + fn buffered(&self) -> u64 { + self.data_ranges + .iter() + .map(|(&start, data)| data.len() as u64 - (self.retired.saturating_sub(start))) + .sum() + } + + /// Copy received data (if any) into the buffer. Returns bytes copied. + fn read(&mut self, buf: &mut [u8]) -> usize { + qtrace!("Reading {} bytes, {} available", buf.len(), self.buffered()); + let mut copied = 0; + + for (&range_start, range_data) in &mut self.data_ranges { + let mut keep = false; + if self.retired >= range_start { + // Frame data has new contiguous bytes. + let copy_offset = + usize::try_from(max(range_start, self.retired) - range_start).unwrap(); + assert!(range_data.len() >= copy_offset); + let available = range_data.len() - copy_offset; + let space = buf.len() - copied; + let copy_bytes = if available > space { + keep = true; + space + } else { + available + }; + + if copy_bytes > 0 { + let copy_slc = &range_data[copy_offset..copy_offset + copy_bytes]; + buf[copied..copied + copy_bytes].copy_from_slice(copy_slc); + copied += copy_bytes; + self.retired += u64::try_from(copy_bytes).unwrap(); + } + } else { + // The data in the buffer isn't contiguous. 
+ keep = true; + } + if keep { + let mut keep = self.data_ranges.split_off(&range_start); + mem::swap(&mut self.data_ranges, &mut keep); + return copied; + } + } + + self.data_ranges.clear(); + copied + } + + /// Extend the given Vector with any available data. + pub fn read_to_end(&mut self, buf: &mut Vec<u8>) -> usize { + let orig_len = buf.len(); + buf.resize(orig_len + self.bytes_ready(), 0); + self.read(&mut buf[orig_len..]) + } +} + +/// QUIC receiving states, based on -transport 3.2. +#[derive(Debug)] +#[allow(dead_code)] +// Because a dead_code warning is easier than clippy::unused_self, see https://github.com/rust-lang/rust/issues/68408 +enum RecvStreamState { + Recv { + fc: ReceiverFlowControl<StreamId>, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + recv_buf: RxStreamOrderer, + }, + SizeKnown { + fc: ReceiverFlowControl<StreamId>, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + recv_buf: RxStreamOrderer, + }, + DataRecvd { + fc: ReceiverFlowControl<StreamId>, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + recv_buf: RxStreamOrderer, + }, + DataRead { + final_received: u64, + final_read: u64, + }, + AbortReading { + fc: ReceiverFlowControl<StreamId>, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + final_size_reached: bool, + frame_needed: bool, + err: AppError, + final_received: u64, + final_read: u64, + }, + WaitForReset { + fc: ReceiverFlowControl<StreamId>, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + final_received: u64, + final_read: u64, + }, + ResetRecvd { + final_received: u64, + final_read: u64, + }, + // Defined by spec but we don't use it: ResetRead +} + +impl RecvStreamState { + fn new( + max_bytes: u64, + stream_id: StreamId, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + ) -> Self { + Self::Recv { + fc: ReceiverFlowControl::new(stream_id, max_bytes), + recv_buf: RxStreamOrderer::new(), + session_fc, + } + } + + fn name(&self) -> &str { + match self { + Self::Recv { .. 
} => "Recv", + Self::SizeKnown { .. } => "SizeKnown", + Self::DataRecvd { .. } => "DataRecvd", + Self::DataRead { .. } => "DataRead", + Self::AbortReading { .. } => "AbortReading", + Self::WaitForReset { .. } => "WaitForReset", + Self::ResetRecvd { .. } => "ResetRecvd", + } + } + + fn recv_buf(&self) -> Option<&RxStreamOrderer> { + match self { + Self::Recv { recv_buf, .. } + | Self::SizeKnown { recv_buf, .. } + | Self::DataRecvd { recv_buf, .. } => Some(recv_buf), + Self::DataRead { .. } + | Self::AbortReading { .. } + | Self::WaitForReset { .. } + | Self::ResetRecvd { .. } => None, + } + } + + fn flow_control_consume_data(&mut self, consumed: u64, fin: bool) -> Res<()> { + let (fc, session_fc, final_size_reached, retire_data) = match self { + Self::Recv { fc, session_fc, .. } => (fc, session_fc, false, false), + Self::WaitForReset { fc, session_fc, .. } => (fc, session_fc, false, true), + Self::SizeKnown { fc, session_fc, .. } | Self::DataRecvd { fc, session_fc, .. } => { + (fc, session_fc, true, false) + } + Self::AbortReading { + fc, + session_fc, + final_size_reached, + .. + } => { + let old_final_size_reached = *final_size_reached; + *final_size_reached |= fin; + (fc, session_fc, old_final_size_reached, true) + } + Self::DataRead { .. } | Self::ResetRecvd { .. 
} => { + return Ok(()); + } + }; + + // Check final size: + let final_size_ok = match (fin, final_size_reached) { + (true, true) => consumed == fc.consumed(), + (false, true) => consumed <= fc.consumed(), + (true, false) => consumed >= fc.consumed(), + (false, false) => true, + }; + + if !final_size_ok { + return Err(Error::FinalSizeError); + } + + let new_bytes_consumed = fc.set_consumed(consumed)?; + session_fc.borrow_mut().consume(new_bytes_consumed)?; + if retire_data { + // Let's also retire this data since the stream has been aborted + RecvStream::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc); + } + Ok(()) + } +} + +// See https://www.w3.org/TR/webtransport/#receive-stream-stats +#[derive(Debug, Clone, Copy)] +pub struct RecvStreamStats { + // An indicator of progress on how many of the server application’s bytes + // intended for this stream have been received so far. + // Only sequential bytes up to, but not including, the first missing byte, + // are counted. This number can only increase. + pub bytes_received: u64, + // The total number of bytes the application has successfully read from this + // stream. This number can only increase, and is always less than or equal + // to bytes_received. + pub bytes_read: u64, +} + +impl RecvStreamStats { + #[must_use] + pub fn new(bytes_received: u64, bytes_read: u64) -> Self { + Self { + bytes_received, + bytes_read, + } + } + + #[must_use] + pub fn bytes_received(&self) -> u64 { + self.bytes_received + } + + #[must_use] + pub fn bytes_read(&self) -> u64 { + self.bytes_read + } +} + +/// Implement a QUIC receive stream. 
+#[derive(Debug)] +pub struct RecvStream { + stream_id: StreamId, + state: RecvStreamState, + conn_events: ConnectionEvents, + keep_alive: Option<Rc<()>>, +} + +impl RecvStream { + pub fn new( + stream_id: StreamId, + max_stream_data: u64, + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + conn_events: ConnectionEvents, + ) -> Self { + Self { + stream_id, + state: RecvStreamState::new(max_stream_data, stream_id, session_fc), + conn_events, + keep_alive: None, + } + } + + fn set_state(&mut self, new_state: RecvStreamState) { + debug_assert_ne!( + mem::discriminant(&self.state), + mem::discriminant(&new_state) + ); + qtrace!( + "RecvStream {} state {} -> {}", + self.stream_id.as_u64(), + self.state.name(), + new_state.name() + ); + + match new_state { + // Receiving all data, or receiving or requesting RESET_STREAM + // is cause to stop keep-alives. + RecvStreamState::DataRecvd { .. } + | RecvStreamState::AbortReading { .. } + | RecvStreamState::ResetRecvd { .. } => { + self.keep_alive = None; + } + // Once all the data is read, generate an event. + RecvStreamState::DataRead { .. } => { + self.conn_events.recv_stream_complete(self.stream_id); + } + _ => {} + } + + self.state = new_state; + } + + pub fn stats(&self) -> RecvStreamStats { + match &self.state { + RecvStreamState::Recv { recv_buf, .. } + | RecvStreamState::SizeKnown { recv_buf, .. } + | RecvStreamState::DataRecvd { recv_buf, .. } => { + let received = recv_buf.received(); + let read = recv_buf.retired(); + RecvStreamStats::new(received, read) + } + RecvStreamState::AbortReading { + final_received, + final_read, + .. + } + | RecvStreamState::WaitForReset { + final_received, + final_read, + .. 
+ } + | RecvStreamState::DataRead { + final_received, + final_read, + } + | RecvStreamState::ResetRecvd { + final_received, + final_read, + } => { + let received = *final_received; + let read = *final_read; + RecvStreamStats::new(received, read) + } + } + } + + pub fn inbound_stream_frame(&mut self, fin: bool, offset: u64, data: &[u8]) -> Res<()> { + // We should post a DataReadable event only once when we change from no-data-ready to + // data-ready. Therefore remember the state before processing a new frame. + let already_data_ready = self.data_ready(); + let new_end = offset + u64::try_from(data.len()).unwrap(); + + self.state.flow_control_consume_data(new_end, fin)?; + + match &mut self.state { + RecvStreamState::Recv { + recv_buf, + fc, + session_fc, + } => { + recv_buf.inbound_frame(offset, data); + if fin { + let all_recv = + fc.consumed() == recv_buf.retired() + recv_buf.bytes_ready() as u64; + let buf = mem::replace(recv_buf, RxStreamOrderer::new()); + let fc_copy = mem::take(fc); + let session_fc_copy = mem::take(session_fc); + if all_recv { + self.set_state(RecvStreamState::DataRecvd { + fc: fc_copy, + session_fc: session_fc_copy, + recv_buf: buf, + }); + } else { + self.set_state(RecvStreamState::SizeKnown { + fc: fc_copy, + session_fc: session_fc_copy, + recv_buf: buf, + }); + } + } + } + RecvStreamState::SizeKnown { + recv_buf, + fc, + session_fc, + } => { + recv_buf.inbound_frame(offset, data); + if fc.consumed() == recv_buf.retired() + recv_buf.bytes_ready() as u64 { + let buf = mem::replace(recv_buf, RxStreamOrderer::new()); + let fc_copy = mem::take(fc); + let session_fc_copy = mem::take(session_fc); + self.set_state(RecvStreamState::DataRecvd { + fc: fc_copy, + session_fc: session_fc_copy, + recv_buf: buf, + }); + } + } + RecvStreamState::DataRecvd { .. } + | RecvStreamState::DataRead { .. } + | RecvStreamState::AbortReading { .. } + | RecvStreamState::WaitForReset { .. } + | RecvStreamState::ResetRecvd { .. 
} => { + qtrace!("data received when we are in state {}", self.state.name()); + } + } + + if !already_data_ready && (self.data_ready() || self.needs_to_inform_app_about_fin()) { + self.conn_events.recv_stream_readable(self.stream_id); + } + + Ok(()) + } + + pub fn reset(&mut self, application_error_code: AppError, final_size: u64) -> Res<()> { + self.state.flow_control_consume_data(final_size, true)?; + match &mut self.state { + RecvStreamState::Recv { + fc, + session_fc, + recv_buf, + } + | RecvStreamState::SizeKnown { + fc, + session_fc, + recv_buf, + } => { + // make flow control consumes new data that not really exist. + Self::flow_control_retire_data(final_size - fc.retired(), fc, session_fc); + self.conn_events + .recv_stream_reset(self.stream_id, application_error_code); + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::ResetRecvd { + final_received: received, + final_read: read, + }); + } + RecvStreamState::AbortReading { + fc, + session_fc, + final_received, + final_read, + .. + } + | RecvStreamState::WaitForReset { + fc, + session_fc, + final_received, + final_read, + } => { + // make flow control consumes new data that not really exist. 
+ Self::flow_control_retire_data(final_size - fc.retired(), fc, session_fc); + self.conn_events + .recv_stream_reset(self.stream_id, application_error_code); + let received = *final_received; + let read = *final_read; + self.set_state(RecvStreamState::ResetRecvd { + final_received: received, + final_read: read, + }); + } + _ => { + // Ignore reset if in DataRecvd, DataRead, or ResetRecvd + } + } + Ok(()) + } + + /// If we should tell the sender they have more credit, return an offset + fn flow_control_retire_data( + new_read: u64, + fc: &mut ReceiverFlowControl<StreamId>, + session_fc: &mut Rc<RefCell<ReceiverFlowControl<()>>>, + ) { + if new_read > 0 { + fc.add_retired(new_read); + session_fc.borrow_mut().add_retired(new_read); + } + } + + /// Send a flow control update. + /// This is used when a peer declares that they are blocked. + /// This sends `MAX_STREAM_DATA` if there is any increase possible. + pub fn send_flowc_update(&mut self) { + if let RecvStreamState::Recv { fc, .. } = &mut self.state { + fc.send_flowc_update(); + } + } + + pub fn set_stream_max_data(&mut self, max_data: u64) { + if let RecvStreamState::Recv { fc, .. } = &mut self.state { + fc.set_max_active(max_data); + } + } + + pub fn is_terminal(&self) -> bool { + matches!( + self.state, + RecvStreamState::ResetRecvd { .. } | RecvStreamState::DataRead { .. } + ) + } + + // App got all data but did not get the fin signal. + fn needs_to_inform_app_about_fin(&self) -> bool { + matches!(self.state, RecvStreamState::DataRecvd { .. }) + } + + fn data_ready(&self) -> bool { + self.state + .recv_buf() + .map_or(false, RxStreamOrderer::data_ready) + } + + /// # Errors + /// + /// `NoMoreData` if data and fin bit were previously read by the application. + pub fn read(&mut self, buf: &mut [u8]) -> Res<(usize, bool)> { + let data_recvd_state = matches!(self.state, RecvStreamState::DataRecvd { .. 
}); + match &mut self.state { + RecvStreamState::Recv { + recv_buf, + fc, + session_fc, + } + | RecvStreamState::SizeKnown { + recv_buf, + fc, + session_fc, + .. + } + | RecvStreamState::DataRecvd { + recv_buf, + fc, + session_fc, + } => { + let bytes_read = recv_buf.read(buf); + Self::flow_control_retire_data(u64::try_from(bytes_read).unwrap(), fc, session_fc); + let fin_read = if data_recvd_state { + if recv_buf.buffered() == 0 { + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::DataRead { + final_received: received, + final_read: read, + }); + true + } else { + false + } + } else { + false + }; + Ok((bytes_read, fin_read)) + } + RecvStreamState::DataRead { .. } + | RecvStreamState::AbortReading { .. } + | RecvStreamState::WaitForReset { .. } + | RecvStreamState::ResetRecvd { .. } => Err(Error::NoMoreData), + } + } + + pub fn stop_sending(&mut self, err: AppError) { + qtrace!("stop_sending called when in state {}", self.state.name()); + match &mut self.state { + RecvStreamState::Recv { + fc, + session_fc, + recv_buf, + } + | RecvStreamState::SizeKnown { + fc, + session_fc, + recv_buf, + } => { + // Retire data + Self::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc); + let fc_copy = mem::take(fc); + let session_fc_copy = mem::take(session_fc); + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::AbortReading { + fc: fc_copy, + session_fc: session_fc_copy, + final_size_reached: matches!(self.state, RecvStreamState::SizeKnown { .. 
}), + frame_needed: true, + err, + final_received: received, + final_read: read, + }); + } + RecvStreamState::DataRecvd { + fc, + session_fc, + recv_buf, + } => { + Self::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc); + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::DataRead { + final_received: received, + final_read: read, + }); + } + RecvStreamState::DataRead { .. } + | RecvStreamState::AbortReading { .. } + | RecvStreamState::WaitForReset { .. } + | RecvStreamState::ResetRecvd { .. } => { + // Already in terminal state + } + } + } + + /// Maybe write a `MAX_STREAM_DATA` frame. + pub fn write_frame( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + match &mut self.state { + // Maybe send MAX_STREAM_DATA + RecvStreamState::Recv { fc, .. } => fc.write_frames(builder, tokens, stats), + // Maybe send STOP_SENDING + RecvStreamState::AbortReading { + frame_needed, err, .. + } => { + if *frame_needed + && builder.write_varint_frame(&[ + FRAME_TYPE_STOP_SENDING, + self.stream_id.as_u64(), + *err, + ]) + { + tokens.push(RecoveryToken::Stream(StreamRecoveryToken::StopSending { + stream_id: self.stream_id, + })); + stats.stop_sending += 1; + *frame_needed = false; + } + } + _ => {} + } + } + + pub fn max_stream_data_lost(&mut self, maximum_data: u64) { + if let RecvStreamState::Recv { fc, .. } = &mut self.state { + fc.frame_lost(maximum_data); + } + } + + pub fn stop_sending_lost(&mut self) { + if let RecvStreamState::AbortReading { frame_needed, .. } = &mut self.state { + *frame_needed = true; + } + } + + pub fn stop_sending_acked(&mut self) { + if let RecvStreamState::AbortReading { + fc, + session_fc, + final_size_reached, + final_received, + final_read, + .. 
+ } = &mut self.state + { + let received = *final_received; + let read = *final_read; + if *final_size_reached { + // We already know the final_size of the stream therefore we + // do not need to wait for RESET. + self.set_state(RecvStreamState::ResetRecvd { + final_received: received, + final_read: read, + }); + } else { + let fc_copy = mem::take(fc); + let session_fc_copy = mem::take(session_fc); + self.set_state(RecvStreamState::WaitForReset { + fc: fc_copy, + session_fc: session_fc_copy, + final_received: received, + final_read: read, + }); + } + } + } + + #[cfg(test)] + pub fn has_frames_to_write(&self) -> bool { + if let RecvStreamState::Recv { fc, .. } = &self.state { + fc.frame_needed() + } else { + false + } + } + + #[cfg(test)] + pub fn fc(&self) -> Option<&ReceiverFlowControl<StreamId>> { + match &self.state { + RecvStreamState::Recv { fc, .. } + | RecvStreamState::SizeKnown { fc, .. } + | RecvStreamState::DataRecvd { fc, .. } + | RecvStreamState::AbortReading { fc, .. } + | RecvStreamState::WaitForReset { fc, .. } => Some(fc), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use std::ops::Range; + + use neqo_common::Encoder; + + use super::*; + + const SESSION_WINDOW: usize = 1024; + + fn recv_ranges(ranges: &[Range<u64>], available: usize) { + const ZEROES: &[u8] = &[0; 100]; + qtrace!("recv_ranges {:?}", ranges); + + let mut s = RxStreamOrderer::default(); + for r in ranges { + let data = &ZEROES[..usize::try_from(r.end - r.start).unwrap()]; + s.inbound_frame(r.start, data); + } + + let mut buf = [0xff; 100]; + let mut total_recvd = 0; + loop { + let recvd = s.read(&mut buf[..]); + qtrace!("recv_ranges read {}", recvd); + total_recvd += recvd; + if recvd == 0 { + assert_eq!(total_recvd, available); + break; + } + } + } + + #[test] + #[allow(unknown_lints, clippy::single_range_in_vec_init)] // Because that lint makes no sense here. + fn recv_noncontiguous() { + // Non-contiguous with the start, no data available. 
+ recv_ranges(&[10..20], 0); + } + + /// Overlaps with the start of a 10..20 range of bytes. + #[test] + fn recv_overlap_start() { + // Overlap the start, with a larger new value. + // More overlap than not. + recv_ranges(&[10..20, 4..18, 0..4], 20); + // Overlap the start, with a larger new value. + // Less overlap than not. + recv_ranges(&[10..20, 2..15, 0..2], 20); + // Overlap the start, with a smaller new value. + // More overlap than not. + recv_ranges(&[10..20, 8..14, 0..8], 20); + // Overlap the start, with a smaller new value. + // Less overlap than not. + recv_ranges(&[10..20, 6..13, 0..6], 20); + + // Again with some of the first range split in two. + recv_ranges(&[10..11, 11..20, 4..18, 0..4], 20); + recv_ranges(&[10..11, 11..20, 2..15, 0..2], 20); + recv_ranges(&[10..11, 11..20, 8..14, 0..8], 20); + recv_ranges(&[10..11, 11..20, 6..13, 0..6], 20); + + // Again with a gap in the first range. + recv_ranges(&[10..11, 12..20, 4..18, 0..4], 20); + recv_ranges(&[10..11, 12..20, 2..15, 0..2], 20); + recv_ranges(&[10..11, 12..20, 8..14, 0..8], 20); + recv_ranges(&[10..11, 12..20, 6..13, 0..6], 20); + } + + /// Overlaps with the end of a 10..20 range of bytes. + #[test] + fn recv_overlap_end() { + // Overlap the end, with a larger new value. + // More overlap than not. + recv_ranges(&[10..20, 12..25, 0..10], 25); + // Overlap the end, with a larger new value. + // Less overlap than not. + recv_ranges(&[10..20, 17..33, 0..10], 33); + // Overlap the end, with a smaller new value. + // More overlap than not. + recv_ranges(&[10..20, 15..21, 0..10], 21); + // Overlap the end, with a smaller new value. + // Less overlap than not. + recv_ranges(&[10..20, 17..25, 0..10], 25); + + // Again with some of the first range split in two. 
+ recv_ranges(&[10..19, 19..20, 12..25, 0..10], 25); + recv_ranges(&[10..19, 19..20, 17..33, 0..10], 33); + recv_ranges(&[10..19, 19..20, 15..21, 0..10], 21); + recv_ranges(&[10..19, 19..20, 17..25, 0..10], 25); + + // Again with a gap in the first range. + recv_ranges(&[10..18, 19..20, 12..25, 0..10], 25); + recv_ranges(&[10..18, 19..20, 17..33, 0..10], 33); + recv_ranges(&[10..18, 19..20, 15..21, 0..10], 21); + recv_ranges(&[10..18, 19..20, 17..25, 0..10], 25); + } + + /// Complete overlaps with the start of a 10..20 range of bytes. + #[test] + fn recv_overlap_complete() { + // Complete overlap, more at the end. + recv_ranges(&[10..20, 9..23, 0..9], 23); + // Complete overlap, more at the start. + recv_ranges(&[10..20, 3..23, 0..3], 23); + // Complete overlap, to end. + recv_ranges(&[10..20, 5..20, 0..5], 20); + // Complete overlap, from start. + recv_ranges(&[10..20, 10..27, 0..10], 27); + // Complete overlap, from 0 and more. + recv_ranges(&[10..20, 0..23], 23); + + // Again with the first range split in two. + recv_ranges(&[10..14, 14..20, 9..23, 0..9], 23); + recv_ranges(&[10..14, 14..20, 3..23, 0..3], 23); + recv_ranges(&[10..14, 14..20, 5..20, 0..5], 20); + recv_ranges(&[10..14, 14..20, 10..27, 0..10], 27); + recv_ranges(&[10..14, 14..20, 0..23], 23); + + // Again with a gap in the first range. + recv_ranges(&[10..13, 14..20, 9..23, 0..9], 23); + recv_ranges(&[10..13, 14..20, 3..23, 0..3], 23); + recv_ranges(&[10..13, 14..20, 5..20, 0..5], 20); + recv_ranges(&[10..13, 14..20, 10..27, 0..10], 27); + recv_ranges(&[10..13, 14..20, 0..23], 23); + } + + /// An overlap with no new bytes. + #[test] + fn recv_overlap_duplicate() { + recv_ranges(&[10..20, 11..12, 0..10], 20); + recv_ranges(&[10..20, 10..15, 0..10], 20); + recv_ranges(&[10..20, 14..20, 0..10], 20); + // Now with the first range split. 
+ recv_ranges(&[10..14, 14..20, 10..15, 0..10], 20); + recv_ranges(&[10..15, 16..20, 21..25, 10..25, 0..10], 25); + } + + /// Reading exactly one chunk works, when the next chunk starts immediately. + #[test] + fn stop_reading_at_chunk() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add three chunks. + s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + // Read, providing only enough space for the first. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..CHUNK_SIZE]); + assert_eq!(count, CHUNK_SIZE); + let count = s.read(&mut buf[..]); + assert_eq!(count, EXTRA_SIZE * 2); + } + + #[test] + fn recv_overlap_while_reading() { + let mut s = RxStreamOrderer::new(); + + // Add a chunk + s.inbound_frame(0, &[0; 150]); + assert_eq!(s.data_ranges.get(&0).unwrap().len(), 150); + // Read, providing only enough space for the first 100. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..]); + assert_eq!(count, 100); + assert_eq!(s.retired, 100); + + // Add a second frame that overlaps. + // This shouldn't truncate the first frame, as we're already + // Reading from it. + s.inbound_frame(120, &[0; 60]); + assert_eq!(s.data_ranges.get(&0).unwrap().len(), 180); + // Read second part of first frame and all of the second frame + let count = s.read(&mut buf[..]); + assert_eq!(count, 80); + } + + /// Reading exactly one chunk works, when there is a gap. + #[test] + fn stop_reading_at_gap() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add three chunks. 
+ s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + // Read, providing only enough space for the first chunk. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..CHUNK_SIZE]); + assert_eq!(count, CHUNK_SIZE); + + // Now fill the gap and ensure that everything can be read. + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + let count = s.read(&mut buf[..]); + assert_eq!(count, EXTRA_SIZE * 2); + } + + /// Reading can stop partway through a chunk; the next read returns the rest. + #[test] + fn stop_reading_in_chunk() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add two chunks. + s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + // Read, providing only enough space for some of the first chunk. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..CHUNK_SIZE - EXTRA_SIZE]); + assert_eq!(count, CHUNK_SIZE - EXTRA_SIZE); + + let count = s.read(&mut buf[..]); + assert_eq!(count, EXTRA_SIZE * 2); + } + + /// Read one byte at a time. + #[test] + fn read_byte_at_a_time() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add two chunks. 
+ s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + let mut buf = [0; 1]; + for _ in 0..CHUNK_SIZE + EXTRA_SIZE { + let count = s.read(&mut buf[..]); + assert_eq!(count, 1); + } + assert_eq!(0, s.read(&mut buf[..])); + } + + fn check_stats(stream: &RecvStream, expected_received: u64, expected_read: u64) { + let stream_stats = stream.stats(); + assert_eq!(expected_received, stream_stats.bytes_received()); + assert_eq!(expected_read, stream_stats.bytes_read()); + } + + #[test] + fn stream_rx() { + let conn_events = ConnectionEvents::default(); + + let mut s = RecvStream::new( + StreamId::from(567), + 1024, + Rc::new(RefCell::new(ReceiverFlowControl::new((), 1024 * 1024))), + conn_events, + ); + + // test receiving a contig frame and reading it works + s.inbound_stream_frame(false, 0, &[1; 10]).unwrap(); + assert!(s.data_ready()); + check_stats(&s, 10, 0); + + let mut buf = vec![0u8; 100]; + assert_eq!(s.read(&mut buf).unwrap(), (10, false)); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 0); + + check_stats(&s, 10, 10); + + // test receiving a noncontig frame + s.inbound_stream_frame(false, 12, &[2; 12]).unwrap(); + assert!(!s.data_ready()); + assert_eq!(s.read(&mut buf).unwrap(), (0, false)); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + + check_stats(&s, 22, 10); + + // another frame that overlaps the first + s.inbound_stream_frame(false, 14, &[3; 8]).unwrap(); + assert!(!s.data_ready()); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + + check_stats(&s, 22, 10); + + // fill in the gap, but with a FIN + s.inbound_stream_frame(true, 10, &[4; 6]).unwrap_err(); + assert!(!s.data_ready()); + assert_eq!(s.read(&mut buf).unwrap(), (0, false)); + 
assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + + check_stats(&s, 22, 10); + + // fill in the gap + s.inbound_stream_frame(false, 10, &[5; 10]).unwrap(); + assert!(s.data_ready()); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 14); + + check_stats(&s, 24, 10); + + // a legit FIN + s.inbound_stream_frame(true, 24, &[6; 18]).unwrap(); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 32); + assert!(s.data_ready()); + assert_eq!(s.read(&mut buf).unwrap(), (32, true)); + + check_stats(&s, 42, 42); + + // Stream now no longer readable (is in DataRead state) + s.read(&mut buf).unwrap_err(); + } + + fn check_chunks(s: &mut RxStreamOrderer, expected: &[(u64, usize)]) { + assert_eq!(s.data_ranges.len(), expected.len()); + for ((start, buf), (expected_start, expected_len)) in s.data_ranges.iter().zip(expected) { + assert_eq!((*start, buf.len()), (*expected_start, *expected_len)); + } + } + + // Test deduplication when the new data is at the end. + #[test] + fn stream_rx_dedupe_tail() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(0, &[1; 6]); + check_chunks(&mut s, &[(0, 6)]); + + // New data that overlaps entirely (starting from the head), is ignored. + s.inbound_frame(0, &[2; 3]); + check_chunks(&mut s, &[(0, 6)]); + + // New data that overlaps at the tail has any new data appended. + s.inbound_frame(2, &[3; 6]); + check_chunks(&mut s, &[(0, 8)]); + + // New data that overlaps entirely (up to the tail), is ignored. + s.inbound_frame(4, &[4; 4]); + check_chunks(&mut s, &[(0, 8)]); + + // New data that overlaps, starting from the beginning is appended too. + s.inbound_frame(0, &[5; 10]); + check_chunks(&mut s, &[(0, 10)]); + + // New data that is entirely subsumed is ignored. 
+ s.inbound_frame(2, &[6; 2]); + check_chunks(&mut s, &[(0, 10)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 10); + assert_eq!(buf[..10], [1, 1, 1, 1, 1, 1, 3, 3, 5, 5]); + } + + /// When chunks are added before existing data, they aren't merged. + #[test] + fn stream_rx_dedupe_head() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(1, &[6; 6]); + check_chunks(&mut s, &[(1, 6)]); + + // Insertion before an existing chunk causes truncation of the new chunk. + s.inbound_frame(0, &[7; 6]); + check_chunks(&mut s, &[(0, 1), (1, 6)]); + + // Perfect overlap with existing slices has no effect. + s.inbound_frame(0, &[8; 7]); + check_chunks(&mut s, &[(0, 1), (1, 6)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 7); + assert_eq!(buf[..7], [7, 6, 6, 6, 6, 6, 6]); + } + + #[test] + fn stream_rx_dedupe_new_tail() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(1, &[6; 6]); + check_chunks(&mut s, &[(1, 6)]); + + // Insertion before an existing chunk causes truncation of the new chunk. + s.inbound_frame(0, &[7; 6]); + check_chunks(&mut s, &[(0, 1), (1, 6)]); + + // New data at the end causes the tail to be added to the first chunk, + // replacing later chunks entirely. + s.inbound_frame(0, &[9; 8]); + check_chunks(&mut s, &[(0, 8)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 8); + assert_eq!(buf[..8], [7, 9, 9, 9, 9, 9, 9, 9]); + } + + #[test] + fn stream_rx_dedupe_replace() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(2, &[6; 6]); + check_chunks(&mut s, &[(2, 6)]); + + // Insertion before an existing chunk causes truncation of the new chunk. + s.inbound_frame(1, &[7; 6]); + check_chunks(&mut s, &[(1, 1), (2, 6)]); + + // New data at the start and end replaces all the slices. 
+ s.inbound_frame(0, &[9; 10]); + check_chunks(&mut s, &[(0, 10)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 10); + assert_eq!(buf[..10], [9; 10]); + } + + #[test] + fn trim_retired() { + let mut s = RxStreamOrderer::new(); + + let mut buf = [0; 18]; + s.inbound_frame(0, &[1; 10]); + + // Partially read slices are retained. + assert_eq!(s.read(&mut buf[..6]), 6); + check_chunks(&mut s, &[(0, 10)]); + + // Partially read slices are kept and so are added to. + s.inbound_frame(3, &buf[..10]); + check_chunks(&mut s, &[(0, 13)]); + + // Wholly read pieces are dropped. + assert_eq!(s.read(&mut buf[..]), 7); + assert!(s.data_ranges.is_empty()); + + // New data that overlaps with retired data is trimmed. + s.inbound_frame(0, &buf[..]); + check_chunks(&mut s, &[(13, 5)]); + } + + #[test] + fn stream_flowc_update() { + let mut s = create_stream(1024 * RX_STREAM_DATA_WINDOW); + let mut buf = vec![0u8; RECV_BUFFER_SIZE + 100]; // Make it overlarge + + assert!(!s.has_frames_to_write()); + s.inbound_stream_frame(false, 0, &[0; RECV_BUFFER_SIZE]) + .unwrap(); + assert!(!s.has_frames_to_write()); + assert_eq!(s.read(&mut buf).unwrap(), (RECV_BUFFER_SIZE, false)); + assert!(!s.data_ready()); + + // flow msg generated! 
+ assert!(s.has_frames_to_write()); + + // consume it + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut token = Vec::new(); + s.write_frame(&mut builder, &mut token, &mut FrameStats::default()); + + // it should be gone + assert!(!s.has_frames_to_write()); + } + + fn create_stream(session_fc: u64) -> RecvStream { + let conn_events = ConnectionEvents::default(); + RecvStream::new( + StreamId::from(67), + RX_STREAM_DATA_WINDOW, + Rc::new(RefCell::new(ReceiverFlowControl::new((), session_fc))), + conn_events, + ) + } + + #[test] + fn stream_max_stream_data() { + let mut s = create_stream(1024 * RX_STREAM_DATA_WINDOW); + assert!(!s.has_frames_to_write()); + s.inbound_stream_frame(false, 0, &[0; RECV_BUFFER_SIZE]) + .unwrap(); + s.inbound_stream_frame(false, RX_STREAM_DATA_WINDOW, &[1; 1]) + .unwrap_err(); + } + + #[test] + fn stream_orderer_bytes_ready() { + let mut rx_ord = RxStreamOrderer::new(); + + rx_ord.inbound_frame(0, &[1; 6]); + assert_eq!(rx_ord.bytes_ready(), 6); + assert_eq!(rx_ord.buffered(), 6); + assert_eq!(rx_ord.retired(), 0); + + // read some so there's an offset into the first frame + let mut buf = [0u8; 10]; + rx_ord.read(&mut buf[..2]); + assert_eq!(rx_ord.bytes_ready(), 4); + assert_eq!(rx_ord.buffered(), 4); + assert_eq!(rx_ord.retired(), 2); + + // an overlapping frame + rx_ord.inbound_frame(5, &[2; 6]); + assert_eq!(rx_ord.bytes_ready(), 9); + assert_eq!(rx_ord.buffered(), 9); + assert_eq!(rx_ord.retired(), 2); + + // a noncontig frame + rx_ord.inbound_frame(20, &[3; 6]); + assert_eq!(rx_ord.bytes_ready(), 9); + assert_eq!(rx_ord.buffered(), 15); + assert_eq!(rx_ord.retired(), 2); + + // an old frame + rx_ord.inbound_frame(0, &[4; 2]); + assert_eq!(rx_ord.bytes_ready(), 9); + assert_eq!(rx_ord.buffered(), 15); + assert_eq!(rx_ord.retired(), 2); + } + + #[test] + fn no_stream_flowc_event_after_exiting_recv() { + let mut s = create_stream(1024 * RX_STREAM_DATA_WINDOW); + s.inbound_stream_frame(false, 0, &[0; 
RECV_BUFFER_SIZE]) + .unwrap(); + let mut buf = [0; RECV_BUFFER_SIZE]; + s.read(&mut buf).unwrap(); + assert!(s.has_frames_to_write()); + s.inbound_stream_frame(true, RX_STREAM_DATA_WINDOW, &[]) + .unwrap(); + assert!(!s.has_frames_to_write()); + } + + fn create_stream_with_fc( + session_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + fc_limit: u64, + ) -> RecvStream { + RecvStream::new( + StreamId::from(567), + fc_limit, + session_fc, + ConnectionEvents::default(), + ) + } + + fn create_stream_session_flow_control() -> (RecvStream, Rc<RefCell<ReceiverFlowControl<()>>>) { + assert!(RX_STREAM_DATA_WINDOW > u64::try_from(SESSION_WINDOW).unwrap()); + let session_fc = Rc::new(RefCell::new(ReceiverFlowControl::new( + (), + u64::try_from(SESSION_WINDOW).unwrap(), + ))); + ( + create_stream_with_fc(Rc::clone(&session_fc), RX_STREAM_DATA_WINDOW), + session_fc, + ) + } + + #[test] + fn session_flow_control() { + let (mut s, session_fc) = create_stream_session_flow_control(); + + s.inbound_stream_frame(false, 0, &[0; SESSION_WINDOW]) + .unwrap(); + assert!(!session_fc.borrow().frame_needed()); + // The buffer is big enough to hold SESSION_WINDOW, this will make sure that we always + // read everything from the stream. + let mut buf = [0; 2 * SESSION_WINDOW]; + s.read(&mut buf).unwrap(); + assert!(session_fc.borrow().frame_needed()); + // consume it + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut token = Vec::new(); + session_fc + .borrow_mut() + .write_frames(&mut builder, &mut token, &mut FrameStats::default()); + + // Switch to SizeKnown state + s.inbound_stream_frame(true, 2 * u64::try_from(SESSION_WINDOW).unwrap() - 1, &[0]) + .unwrap(); + assert!(!session_fc.borrow().frame_needed()); + // Receive new data that can be read. 
+ s.inbound_stream_frame( + false, + u64::try_from(SESSION_WINDOW).unwrap(), + &[0; SESSION_WINDOW / 2 + 1], + ) + .unwrap(); + assert!(!session_fc.borrow().frame_needed()); + s.read(&mut buf).unwrap(); + assert!(session_fc.borrow().frame_needed()); + // consume it + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut token = Vec::new(); + session_fc + .borrow_mut() + .write_frames(&mut builder, &mut token, &mut FrameStats::default()); + + // Test DataRecvd state + let session_fc = Rc::new(RefCell::new(ReceiverFlowControl::new( + (), + u64::try_from(SESSION_WINDOW).unwrap(), + ))); + let mut s = RecvStream::new( + StreamId::from(567), + RX_STREAM_DATA_WINDOW, + Rc::clone(&session_fc), + ConnectionEvents::default(), + ); + + s.inbound_stream_frame(true, 0, &[0; SESSION_WINDOW]) + .unwrap(); + assert!(!session_fc.borrow().frame_needed()); + s.read(&mut buf).unwrap(); + assert!(session_fc.borrow().frame_needed()); + } + + #[test] + fn session_flow_control_reset() { + let (mut s, session_fc) = create_stream_session_flow_control(); + + s.inbound_stream_frame(false, 0, &[0; SESSION_WINDOW / 2]) + .unwrap(); + assert!(!session_fc.borrow().frame_needed()); + + s.reset( + Error::NoError.code(), + u64::try_from(SESSION_WINDOW).unwrap(), + ) + .unwrap(); + assert!(session_fc.borrow().frame_needed()); + } + + fn check_fc<T: std::fmt::Debug>(fc: &ReceiverFlowControl<T>, consumed: u64, retired: u64) { + assert_eq!(fc.consumed(), consumed); + assert_eq!(fc.retired(), retired); + } + + /// Test consuming the flow control in RecvStreamState::Recv + #[test] + fn fc_state_recv_1() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + + check_fc(&fc.borrow(), SW / 4, 0); + 
check_fc(s.fc().unwrap(), SW / 4, 0); + } + + /// Test consuming the flow control in RecvStreamState::Recv + /// with multiple streams + #[test] + fn fc_state_recv_2() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s1 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + let mut s2 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s1.fc().unwrap(), 0, 0); + check_fc(s2.fc().unwrap(), 0, 0); + + s1.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + + check_fc(&fc.borrow(), SW / 4, 0); + check_fc(s1.fc().unwrap(), SW / 4, 0); + check_fc(s2.fc().unwrap(), 0, 0); + + s2.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s1.fc().unwrap(), SW / 4, 0); + check_fc(s2.fc().unwrap(), SW / 4, 0); + } + + /// Test retiring the flow control in RecvStreamState::Recv + /// with multiple streams + #[test] + fn fc_state_recv_3() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s1 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + let mut s2 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s1.fc().unwrap(), 0, 0); + check_fc(s2.fc().unwrap(), 0, 0); + + s1.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + s2.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s1.fc().unwrap(), SW / 4, 0); + check_fc(s2.fc().unwrap(), SW / 4, 0); + + // Read data + let mut buf = [1; SW_US]; + assert_eq!(s1.read(&mut buf).unwrap(), (SW_US / 4, false)); + check_fc(&fc.borrow(), SW / 2, SW / 4); + check_fc(s1.fc().unwrap(), SW / 4, SW / 4); + check_fc(s2.fc().unwrap(), SW / 4, 0); + + assert_eq!(s2.read(&mut buf).unwrap(), (SW_US / 4, false)); + check_fc(&fc.borrow(), SW / 2, SW / 2); + 
check_fc(s1.fc().unwrap(), SW / 4, SW / 4); + check_fc(s2.fc().unwrap(), SW / 4, SW / 4); + + // Read when there is no more data to be read will not change fc. + assert_eq!(s1.read(&mut buf).unwrap(), (0, false)); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s1.fc().unwrap(), SW / 4, SW / 4); + check_fc(s2.fc().unwrap(), SW / 4, SW / 4); + + // Receiving more data on a stream. + s1.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW * 3 / 4, SW / 2); + check_fc(s1.fc().unwrap(), SW / 2, SW / 4); + check_fc(s2.fc().unwrap(), SW / 4, SW / 4); + + // Read data + assert_eq!(s1.read(&mut buf).unwrap(), (SW_US / 4, false)); + check_fc(&fc.borrow(), SW * 3 / 4, SW * 3 / 4); + check_fc(s1.fc().unwrap(), SW / 2, SW / 2); + check_fc(s2.fc().unwrap(), SW / 4, SW / 4); + } + + /// Test consuming the flow control in RecvStreamState::Recv - duplicate data + #[test] + fn fc_state_recv_4() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + + check_fc(&fc.borrow(), SW / 4, 0); + check_fc(s.fc().unwrap(), SW / 4, 0); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(false, 0, &[0; SW_US / 8]).unwrap(); + check_fc(&fc.borrow(), SW / 4, 0); + check_fc(s.fc().unwrap(), SW / 4, 0); + } + + /// Test consuming the flow control in RecvStreamState::Recv - filling a gap in the + /// data stream. + #[test] + fn fc_state_recv_5() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + + // Receive out of order data. 
+ s.inbound_stream_frame(false, SW / 8, &[0; SW_US / 8]) + .unwrap(); + check_fc(&fc.borrow(), SW / 4, 0); + check_fc(s.fc().unwrap(), SW / 4, 0); + + // Filling in the gap will not change fc. + s.inbound_stream_frame(false, 0, &[0; SW_US / 8]).unwrap(); + check_fc(&fc.borrow(), SW / 4, 0); + check_fc(s.fc().unwrap(), SW / 4, 0); + } + + /// Test consuming the flow control in RecvStreamState::Recv - receiving frame past + /// the flow control will cause an error. + #[test] + fn fc_state_recv_6() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + + // Receiving frame past the flow control will cause an error. + assert_eq!( + s.inbound_stream_frame(false, 0, &[0; SW_US * 3 / 4 + 1]), + Err(Error::FlowControlError) + ); + } + + /// Test that the flow controls will send updates. + #[test] + fn fc_state_recv_7() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + let mut s = create_stream_with_fc(Rc::clone(&fc), SW / 2); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap(); + let mut buf = [1; SW_US]; + assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 4, false)); + check_fc(&fc.borrow(), SW / 4, SW / 4); + check_fc(s.fc().unwrap(), SW / 4, SW / 4); + + // Still no fc update needed. + assert!(!fc.borrow().frame_needed()); + assert!(!s.fc().unwrap().frame_needed()); + + // Receive one more byte that will cause a fc update after it is read. + s.inbound_stream_frame(false, SW / 4, &[0]).unwrap(); + check_fc(&fc.borrow(), SW / 4 + 1, SW / 4); + check_fc(s.fc().unwrap(), SW / 4 + 1, SW / 4); + // Only consuming data does not cause a fc update to be sent. 
+ assert!(!fc.borrow().frame_needed()); + assert!(!s.fc().unwrap().frame_needed()); + + assert_eq!(s.read(&mut buf).unwrap(), (1, false)); + check_fc(&fc.borrow(), SW / 4 + 1, SW / 4 + 1); + check_fc(s.fc().unwrap(), SW / 4 + 1, SW / 4 + 1); + // Data are retired and the stream fc will send an update. + assert!(!fc.borrow().frame_needed()); + assert!(s.fc().unwrap().frame_needed()); + + // Receive more data to increase fc further. + s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 4 - 1, false)); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + assert!(!fc.borrow().frame_needed()); + assert!(s.fc().unwrap().frame_needed()); + + // Write the fc update frame + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut token = Vec::new(); + let mut stats = FrameStats::default(); + fc.borrow_mut() + .write_frames(&mut builder, &mut token, &mut stats); + assert_eq!(stats.max_data, 0); + s.write_frame(&mut builder, &mut token, &mut stats); + assert_eq!(stats.max_stream_data, 1); + + // Receive 1 byte that will cause a session fc update after it is read. 
+ s.inbound_stream_frame(false, SW / 2, &[0]).unwrap(); + assert_eq!(s.read(&mut buf).unwrap(), (1, false)); + check_fc(&fc.borrow(), SW / 2 + 1, SW / 2 + 1); + check_fc(s.fc().unwrap(), SW / 2 + 1, SW / 2 + 1); + assert!(fc.borrow().frame_needed()); + assert!(!s.fc().unwrap().frame_needed()); + fc.borrow_mut() + .write_frames(&mut builder, &mut token, &mut stats); + assert_eq!(stats.max_data, 1); + s.write_frame(&mut builder, &mut token, &mut stats); + assert_eq!(stats.max_stream_data, 1); + } + + /// Test flow control in RecvStreamState::SizeKnown + #[test] + fn fc_state_size_known() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + + let mut s = create_stream_with_fc(Rc::clone(&fc), SW); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // The stream can still receive duplicate data without a fin bit. + s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Receiving frame past the final size of a stream will return an error. + assert_eq!( + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4 + 1]), + Err(Error::FinalSizeError) + ); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Add new data to the gap will not change fc. 
+ s.inbound_stream_frame(false, SW / 8, &[0; SW_US / 8]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Fill the gap + s.inbound_stream_frame(false, 0, &[0; SW_US / 8]).unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Read all data + let mut buf = [1; SW_US]; + assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 2, true)); + // the stream does not have fc any more. We can only check the session fc. + check_fc(&fc.borrow(), SW / 2, SW / 2); + assert!(s.fc().is_none()); + } + + /// Test flow control in RecvStreamState::DataRecvd + #[test] + fn fc_state_data_recv() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + + let mut s = create_stream_with_fc(Rc::clone(&fc), SW); + + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // The stream can still receive duplicate data without a fin bit. + s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Receiving frame past the final size of a stream will return an error. + assert_eq!( + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4 + 1]), + Err(Error::FinalSizeError) + ); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + // Read all data + let mut buf = [1; SW_US]; + assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 2, true)); + // the stream does not have fc any more. We can only check the session fc. 
+ check_fc(&fc.borrow(), SW / 2, SW / 2); + assert!(s.fc().is_none()); + } + + /// Test flow control in RecvStreamState::DataRead + #[test] + fn fc_state_data_read() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + let mut buf = [1; SW_US]; + assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 2, true)); + // the stream does not have fc any more. We can only check the session fc. + check_fc(&fc.borrow(), SW / 2, SW / 2); + assert!(s.fc().is_none()); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap(); + // the stream does not have fc any more. We can only check the session fc. + check_fc(&fc.borrow(), SW / 2, SW / 2); + assert!(s.fc().is_none()); + + // Receiving frame past the final size of a stream or the stream's fc limit + // will NOT return an error. 
+ s.inbound_stream_frame(true, 0, &[0; SW_US / 2 + 1]) + .unwrap(); + s.inbound_stream_frame(true, 0, &[0; SW_US * 3 / 4 + 1]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + assert!(s.fc().is_none()); + } + + /// Test flow control in RecvStreamState::AbortReading and final size is known + #[test] + fn fc_state_abort_reading_1() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + s.stop_sending(Error::NoError.code()); + // All data will be retired + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // The stream can still receive duplicate data without a fin bit. + s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving frame past the final size of a stream will return an error. 
+ assert_eq!( + s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4 + 1]), + Err(Error::FinalSizeError) + ); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + } + + /// Test flow control in RecvStreamState::AbortReading and final size is unknown + #[test] + fn fc_state_abort_reading_2() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + s.stop_sending(Error::NoError.code()); + // All data will be retired + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving data past the flow control limit will cause an error. + assert_eq!( + s.inbound_stream_frame(false, 0, &[0; SW_US * 3 / 4 + 1]), + Err(Error::FlowControlError) + ); + + // The stream can still receive duplicate data without a fin bit. + s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving more data will cause the data to be retired. + // The stream can still receive duplicate data without a fin bit. + s.inbound_stream_frame(false, SW / 2, &[0; 10]).unwrap(); + check_fc(&fc.borrow(), SW / 2 + 10, SW / 2 + 10); + check_fc(s.fc().unwrap(), SW / 2 + 10, SW / 2 + 10); + + // We can still receive the final size. 
+ s.inbound_stream_frame(true, SW / 2, &[0; 20]).unwrap(); + check_fc(&fc.borrow(), SW / 2 + 20, SW / 2 + 20); + check_fc(s.fc().unwrap(), SW / 2 + 20, SW / 2 + 20); + + // Receiving frame past the final size of a stream will return an error. + assert_eq!( + s.inbound_stream_frame(true, SW / 2, &[0; 21]), + Err(Error::FinalSizeError) + ); + check_fc(&fc.borrow(), SW / 2 + 20, SW / 2 + 20); + check_fc(s.fc().unwrap(), SW / 2 + 20, SW / 2 + 20); + } + + /// Test flow control in RecvStreamState::WaitForReset + #[test] + fn fc_state_wait_for_reset() { + const SW: u64 = 1024; + const SW_US: usize = 1024; + let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW))); + + let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4); + check_fc(&fc.borrow(), 0, 0); + check_fc(s.fc().unwrap(), 0, 0); + + s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, 0); + check_fc(s.fc().unwrap(), SW / 2, 0); + + s.stop_sending(Error::NoError.code()); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + s.stop_sending_acked(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving duplicate frames (already consumed data) will not cause an error or + // change fc. + s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving data past the flow control limit will cause an error. + assert_eq!( + s.inbound_stream_frame(false, 0, &[0; SW_US * 3 / 4 + 1]), + Err(Error::FlowControlError) + ); + + // The stream can still receive duplicate data without a fin bit. + s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4]) + .unwrap(); + check_fc(&fc.borrow(), SW / 2, SW / 2); + check_fc(s.fc().unwrap(), SW / 2, SW / 2); + + // Receiving more data will case the data to be retired. + // The stream can still receive duplicate data without a fin bit. 
+ s.inbound_stream_frame(false, SW / 2, &[0; 10]).unwrap(); + check_fc(&fc.borrow(), SW / 2 + 10, SW / 2 + 10); + check_fc(s.fc().unwrap(), SW / 2 + 10, SW / 2 + 10); + } +} diff --git a/third_party/rust/neqo-transport/src/rtt.rs b/third_party/rust/neqo-transport/src/rtt.rs new file mode 100644 index 0000000000..4b05198bc9 --- /dev/null +++ b/third_party/rust/neqo-transport/src/rtt.rs @@ -0,0 +1,211 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Tracking of sent packets and detecting their loss. + +#![deny(clippy::pedantic)] + +use std::{ + cmp::{max, min}, + time::{Duration, Instant}, +}; + +use neqo_common::{qlog::NeqoQlog, qtrace}; + +use crate::{ + ackrate::{AckRate, PeerAckDelay}, + packet::PacketBuilder, + qlog::{self, QlogMetric}, + recovery::RecoveryToken, + stats::FrameStats, + tracking::PacketNumberSpace, +}; + +/// The smallest time that the system timer (via `sleep()`, `nanosleep()`, +/// `select()`, or similar) can reliably deliver; see `neqo_common::hrtime`. +pub const GRANULARITY: Duration = Duration::from_millis(1); +// Defined in -recovery 6.2 as 333ms but using lower value. +pub(crate) const INITIAL_RTT: Duration = Duration::from_millis(100); + +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct RttEstimate { + first_sample_time: Option<Instant>, + latest_rtt: Duration, + smoothed_rtt: Duration, + rttvar: Duration, + min_rtt: Duration, + ack_delay: PeerAckDelay, +} + +impl RttEstimate { + fn init(&mut self, rtt: Duration) { + // Only allow this when there are no samples. 
+ debug_assert!(self.first_sample_time.is_none()); + self.latest_rtt = rtt; + self.min_rtt = rtt; + self.smoothed_rtt = rtt; + self.rttvar = rtt / 2; + } + + #[cfg(test)] + pub const fn from_duration(rtt: Duration) -> Self { + Self { + first_sample_time: None, + latest_rtt: rtt, + smoothed_rtt: rtt, + rttvar: Duration::from_millis(0), + min_rtt: rtt, + ack_delay: PeerAckDelay::Fixed(Duration::from_millis(25)), + } + } + + pub fn set_initial(&mut self, rtt: Duration) { + qtrace!("initial RTT={:?}", rtt); + if rtt >= GRANULARITY { + // Ignore if the value is too small. + self.init(rtt); + } + } + + /// For a new path, prime the RTT based on the state of another path. + pub fn prime_rtt(&mut self, other: &Self) { + self.set_initial(other.smoothed_rtt + other.rttvar); + self.ack_delay = other.ack_delay.clone(); + } + + pub fn set_ack_delay(&mut self, ack_delay: PeerAckDelay) { + self.ack_delay = ack_delay; + } + + pub fn update_ack_delay(&mut self, cwnd: usize, mtu: usize) { + self.ack_delay.update(cwnd, mtu, self.smoothed_rtt); + } + + pub fn update( + &mut self, + qlog: &mut NeqoQlog, + mut rtt_sample: Duration, + ack_delay: Duration, + confirmed: bool, + now: Instant, + ) { + // Limit ack delay by max_ack_delay if confirmed. + let mad = self.ack_delay.max(); + let ack_delay = if confirmed && ack_delay > mad { + mad + } else { + ack_delay + }; + + // min_rtt ignores ack delay. + self.min_rtt = min(self.min_rtt, rtt_sample); + // Adjust for ack delay unless it goes below `min_rtt`. + if rtt_sample - self.min_rtt >= ack_delay { + rtt_sample -= ack_delay; + } + + if self.first_sample_time.is_none() { + self.init(rtt_sample); + self.first_sample_time = Some(now); + } else { + // Calculate EWMA RTT (based on {{?RFC6298}}). 
+ let rttvar_sample = if self.smoothed_rtt > rtt_sample { + self.smoothed_rtt - rtt_sample + } else { + rtt_sample - self.smoothed_rtt + }; + + self.latest_rtt = rtt_sample; + self.rttvar = (self.rttvar * 3 + rttvar_sample) / 4; + self.smoothed_rtt = (self.smoothed_rtt * 7 + rtt_sample) / 8; + } + qtrace!( + "RTT latest={:?} -> estimate={:?}~{:?}", + self.latest_rtt, + self.smoothed_rtt, + self.rttvar + ); + qlog::metrics_updated( + qlog, + &[ + QlogMetric::LatestRtt(self.latest_rtt), + QlogMetric::MinRtt(self.min_rtt), + QlogMetric::SmoothedRtt(self.smoothed_rtt), + ], + ); + } + + /// Get the estimated value. + pub fn estimate(&self) -> Duration { + self.smoothed_rtt + } + + pub fn pto(&self, pn_space: PacketNumberSpace) -> Duration { + let mut t = self.estimate() + max(4 * self.rttvar, GRANULARITY); + if pn_space == PacketNumberSpace::ApplicationData { + t += self.ack_delay.max(); + } + t + } + + /// Calculate the loss delay based on the current estimate and the last + /// RTT measurement received. 
+ pub fn loss_delay(&self) -> Duration { + // kTimeThreshold = 9/8 + // loss_delay = kTimeThreshold * max(latest_rtt, smoothed_rtt) + // loss_delay = max(loss_delay, kGranularity) + let rtt = max(self.latest_rtt, self.smoothed_rtt); + max(rtt * 9 / 8, GRANULARITY) + } + + pub fn first_sample_time(&self) -> Option<Instant> { + self.first_sample_time + } + + #[cfg(test)] + pub fn latest(&self) -> Duration { + self.latest_rtt + } + + pub fn rttvar(&self) -> Duration { + self.rttvar + } + + pub fn minimum(&self) -> Duration { + self.min_rtt + } + + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + self.ack_delay.write_frames(builder, tokens, stats); + } + + pub fn frame_lost(&mut self, lost: &AckRate) { + self.ack_delay.frame_lost(lost); + } + + pub fn frame_acked(&mut self, acked: &AckRate) { + self.ack_delay.frame_acked(acked); + } +} + +impl Default for RttEstimate { + fn default() -> Self { + Self { + first_sample_time: None, + latest_rtt: INITIAL_RTT, + smoothed_rtt: INITIAL_RTT, + rttvar: INITIAL_RTT / 2, + min_rtt: INITIAL_RTT, + ack_delay: PeerAckDelay::default(), + } + } +} diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs new file mode 100644 index 0000000000..5feb785ac6 --- /dev/null +++ b/third_party/rust/neqo-transport/src/send_stream.rs @@ -0,0 +1,2636 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Buffering data to send until it is acked. 
/// The priority that is assigned to sending data for the stream.
///
/// Priorities are totally ordered by urgency:
/// `Critical > Important > High > Normal > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransmissionPriority {
    /// This stream is more important than the functioning of the connection.
    /// Don't use this priority unless the stream really is that important.
    /// A stream at this priority can starve out other connection functions,
    /// including flow control, which could be very bad.
    Critical,
    /// The stream is very important. Stream data will be written ahead of
    /// some of the less critical connection functions, like path validation,
    /// connection ID management, and session tickets.
    Important,
    /// High priority streams are important, but not enough to disrupt
    /// connection operation. They go ahead of session tickets though.
    High,
    /// The default priority.
    Normal,
    /// Low priority streams get sent last.
    Low,
}

impl TransmissionPriority {
    /// Numeric urgency used for ordering: a larger value is more urgent.
    ///
    /// The variants are declared from most to least urgent, so a derived `Ord`
    /// would sort them the wrong way around. The match is exhaustive on purpose:
    /// adding a variant without assigning it a rank is a compile error.
    const fn rank(self) -> u8 {
        match self {
            Self::Critical => 4,
            Self::Important => 3,
            Self::High => 2,
            Self::Normal => 1,
            Self::Low => 0,
        }
    }
}

impl Default for TransmissionPriority {
    fn default() -> Self {
        Self::Normal
    }
}

impl PartialOrd for TransmissionPriority {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TransmissionPriority {
    /// More urgent priorities compare as greater.
    fn cmp(&self, other: &Self) -> Ordering {
        self.rank().cmp(&other.rank())
    }
}

impl Add<RetransmissionPriority> for TransmissionPriority {
    type Output = Self;

    /// Compute the priority that retransmissions get when the stream's
    /// transmission priority is `self` and its retransmission policy is `rhs`.
    fn add(self, rhs: RetransmissionPriority) -> Self::Output {
        match rhs {
            RetransmissionPriority::Fixed(fixed) => fixed,
            RetransmissionPriority::Same => self,
            RetransmissionPriority::Higher => match self {
                Self::Critical => Self::Critical,
                Self::Important | Self::High => Self::Important,
                Self::Normal => Self::High,
                Self::Low => Self::Normal,
            },
            RetransmissionPriority::MuchHigher => match self {
                Self::Critical | Self::Important => Self::Critical,
                Self::High | Self::Normal => Self::Important,
                Self::Low => Self::High,
            },
        }
    }
}

/// If data is lost, this determines the priority that applies to retransmissions
/// of that data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetransmissionPriority {
    /// Prioritize retransmission at a fixed priority.
    /// With this, it is possible to prioritize retransmissions lower than transmissions.
    /// Doing that can create a deadlock with flow control which might cause the connection
    /// to stall unless new data stops arriving fast enough that retransmissions can complete.
    Fixed(TransmissionPriority),
    /// Don't increase priority for retransmission. This is probably not a good idea
    /// as it could mean starving flow control.
    Same,
    /// Increase the priority of retransmissions (the default).
    /// Retransmissions of `Critical` or `Important` aren't elevated at all.
    Higher,
    /// Increase the priority of retransmissions a lot.
    /// This is useful for streams that are particularly exposed to head-of-line blocking.
    MuchHigher,
}

impl Default for RetransmissionPriority {
    fn default() -> Self {
        Self::Higher
    }
}
// Example: given N is new and ABC are existing:
//             NNNNNNNNNNNNNNNN
//               AAAAA   BBBCCCCC  ...then we want 5 chunks:
//             1122222333444555
//
// but also if we have this:
//             NNNNNNNNNNNNNNNN
//           AAAAAAAAAA      BBBB  ...then break existing A and B ranges up:
//
//             1111111122222233
//           aaAAAAAAAA      BBbb
//
// Doing all this work up front should make handling each chunk much
// easier.
/// Split the range `new_off .. new_off + new_len` into chunks whose edges line up
/// with the entries already stored in `self.used`.
///
/// Existing entries that straddle the start or the end of the new range are split
/// in place, so that every returned chunk either covers exactly one existing entry
/// or covers a gap between entries.
///
/// Returns `(offset, length, new_state)` tuples in increasing offset order. A chunk
/// that would overwrite an `Acked` entry with `Sent` is logged and left out of the
/// result. This function does not store the chunks; the caller is expected to
/// insert them.
fn chunk_range_on_edges(
    &mut self,
    new_off: u64,
    new_len: u64,
    new_state: RangeState,
) -> Vec<(u64, u64, RangeState)> {
    // Start and length of the part of the new range not yet turned into chunks.
    let mut tmp_off = new_off;
    let mut tmp_len = new_len;
    let mut v = Vec::new();

    // cut previous overlapping range if needed
    // (the last entry that starts strictly before the new range)
    let prev = self.used.range_mut(..tmp_off).next_back();
    if let Some((prev_off, (prev_len, prev_state))) = prev {
        let prev_state = *prev_state;
        // How far the previous entry reaches into (or past) the new range start.
        let overlap = (*prev_off + *prev_len).saturating_sub(new_off);
        *prev_len -= overlap;
        if overlap > 0 {
            // Re-insert the cut-off tail as its own entry starting at `new_off`,
            // so the loop below sees it as an ordinary existing range.
            self.used.insert(new_off, (overlap, prev_state));
        }
    }

    // Set when an existing entry extends beyond the end of the new range:
    // (entry offset, covered length, uncovered length, entry state).
    let mut last_existing_remaining = None;
    for (off, (len, state)) in self.used.range(tmp_off..tmp_off + tmp_len) {
        // Create chunk for "overhang" before an existing range
        if tmp_off < *off {
            let sub_len = off - tmp_off;
            v.push((tmp_off, sub_len, new_state));
            tmp_off += sub_len;
            tmp_len -= sub_len;
        }

        // Create chunk to match existing range
        let sub_len = min(*len, tmp_len);
        let remaining_len = len - sub_len;
        if new_state == RangeState::Sent && *state == RangeState::Acked {
            // Acked must never be downgraded to Sent: skip this chunk.
            // NOTE(review): the message formats offset and *length* as "{}-{}",
            // not start-end.
            qinfo!(
                "Attempted to downgrade overlapping range Acked range {}-{} with Sent {}-{}",
                off,
                len,
                new_off,
                new_len
            );
        } else {
            v.push((tmp_off, sub_len, new_state));
        }
        tmp_off += sub_len;
        tmp_len -= sub_len;

        if remaining_len > 0 {
            last_existing_remaining = Some((*off, sub_len, remaining_len, *state));
        }
    }

    // Maybe break last existing range in two so that a final chunk will
    // have the same length as an existing range entry
    // (done after the loop because `self.used` is borrowed while iterating)
    if let Some((off, sub_len, remaining_len, state)) = last_existing_remaining {
        *self.used.get_mut(&off).expect("must be there") = (sub_len, state);
        self.used.insert(off + sub_len, (remaining_len, state));
    }

    // Create final chunk if anything remains of the new range
    if tmp_len > 0 {
        v.push((tmp_off, tmp_len, new_state));
    }

    v
}
Check for overlap + if *cur_off + *cur_len > off { + if *cur_state == RangeState::Acked { + qdebug!( + "Attempted to unmark Acked range {}-{} with unmark_range {}-{}", + cur_off, + cur_len, + off, + off + len + ); + } else { + *cur_len = off - cur_off; + } + } + break; + } + + if *cur_state == RangeState::Acked { + qdebug!( + "Attempted to unmark Acked range {}-{} with unmark_range {}-{}", + cur_off, + cur_len, + off, + off + len + ); + continue; + } + + // Add a new range for old subrange extending beyond + // to-be-unmarked range + let cur_end_off = cur_off + *cur_len; + if cur_end_off > end_off { + let new_cur_off = off + len; + let new_cur_len = cur_end_off - end_off; + assert_eq!(to_add, None); + to_add = Some((new_cur_off, new_cur_len, *cur_state)); + } + + to_remove.push(*cur_off); + } + + for remove_off in to_remove { + self.used.remove(&remove_off); + } + + if let Some((new_cur_off, new_cur_len, cur_state)) = to_add { + self.used.insert(new_cur_off, (new_cur_len, cur_state)); + } + } + + /// Unmark all sent ranges. + pub fn unmark_sent(&mut self) { + self.unmark_range(0, usize::try_from(self.highest_offset()).unwrap()); + } +} + +/// Buffer to contain queued bytes and track their state. +#[derive(Debug, Default, PartialEq)] +pub struct TxBuffer { + retired: u64, // contig acked bytes, no longer in buffer + send_buf: VecDeque<u8>, // buffer of not-acked bytes + ranges: RangeTracker, // ranges in buffer that have been sent or acked +} + +impl TxBuffer { + pub fn new() -> Self { + Self::default() + } + + /// Attempt to add some or all of the passed-in buffer to the TxBuffer. 
+ pub fn send(&mut self, buf: &[u8]) -> usize { + let can_buffer = min(SEND_BUFFER_SIZE - self.buffered(), buf.len()); + if can_buffer > 0 { + self.send_buf.extend(&buf[..can_buffer]); + assert!(self.send_buf.len() <= SEND_BUFFER_SIZE); + } + can_buffer + } + + pub fn next_bytes(&self) -> Option<(u64, &[u8])> { + let (start, maybe_len) = self.ranges.first_unmarked_range(); + + if start == self.retired + u64::try_from(self.buffered()).unwrap() { + return None; + } + + // Convert from ranges-relative-to-zero to + // ranges-relative-to-buffer-start + let buff_off = usize::try_from(start - self.retired).unwrap(); + + // Deque returns two slices. Create a subslice from whichever + // one contains the first unmarked data. + let slc = if buff_off < self.send_buf.as_slices().0.len() { + &self.send_buf.as_slices().0[buff_off..] + } else { + &self.send_buf.as_slices().1[buff_off - self.send_buf.as_slices().0.len()..] + }; + + let len = if let Some(range_len) = maybe_len { + // Truncate if range crosses deque slices + min(usize::try_from(range_len).unwrap(), slc.len()) + } else { + slc.len() + }; + + debug_assert!(len > 0); + debug_assert!(len <= slc.len()); + + Some((start, &slc[..len])) + } + + pub fn mark_as_sent(&mut self, offset: u64, len: usize) { + self.ranges.mark_range(offset, len, RangeState::Sent); + } + + pub fn mark_as_acked(&mut self, offset: u64, len: usize) { + self.ranges.mark_range(offset, len, RangeState::Acked); + + // We can drop contig acked range from the buffer + let new_retirable = self.ranges.acked_from_zero() - self.retired; + debug_assert!(new_retirable <= self.buffered() as u64); + let keep_len = + self.buffered() - usize::try_from(new_retirable).expect("should fit in usize"); + + // Truncate front + self.send_buf.rotate_left(self.buffered() - keep_len); + self.send_buf.truncate(keep_len); + + self.retired += new_retirable; + } + + pub fn mark_as_lost(&mut self, offset: u64, len: usize) { + self.ranges.unmark_range(offset, len); + } + + /// Forget 
about anything that was marked as sent. + pub fn unmark_sent(&mut self) { + self.ranges.unmark_sent(); + } + + pub fn retired(&self) -> u64 { + self.retired + } + + fn buffered(&self) -> usize { + self.send_buf.len() + } + + fn avail(&self) -> usize { + SEND_BUFFER_SIZE - self.buffered() + } + + fn used(&self) -> u64 { + self.retired + u64::try_from(self.buffered()).unwrap() + } +} + +/// QUIC sending stream states, based on -transport 3.1. +#[derive(Debug)] +pub(crate) enum SendStreamState { + Ready { + fc: SenderFlowControl<StreamId>, + conn_fc: Rc<RefCell<SenderFlowControl<()>>>, + }, + Send { + fc: SenderFlowControl<StreamId>, + conn_fc: Rc<RefCell<SenderFlowControl<()>>>, + send_buf: TxBuffer, + }, + // Note: `DataSent` is entered when the stream is closed, not when all data has been + // sent for the first time. + DataSent { + send_buf: TxBuffer, + fin_sent: bool, + fin_acked: bool, + }, + DataRecvd { + retired: u64, + written: u64, + }, + ResetSent { + err: AppError, + final_size: u64, + priority: Option<TransmissionPriority>, + final_retired: u64, + final_written: u64, + }, + ResetRecvd { + final_retired: u64, + final_written: u64, + }, +} + +impl SendStreamState { + fn tx_buf_mut(&mut self) -> Option<&mut TxBuffer> { + match self { + Self::Send { send_buf, .. } | Self::DataSent { send_buf, .. } => Some(send_buf), + Self::Ready { .. } + | Self::DataRecvd { .. } + | Self::ResetSent { .. } + | Self::ResetRecvd { .. } => None, + } + } + + fn tx_avail(&self) -> usize { + match self { + // In Ready, TxBuffer not yet allocated but size is known + Self::Ready { .. } => SEND_BUFFER_SIZE, + Self::Send { send_buf, .. } | Self::DataSent { send_buf, .. } => send_buf.avail(), + Self::DataRecvd { .. } | Self::ResetSent { .. } | Self::ResetRecvd { .. } => 0, + } + } + + fn name(&self) -> &str { + match self { + Self::Ready { .. } => "Ready", + Self::Send { .. } => "Send", + Self::DataSent { .. } => "DataSent", + Self::DataRecvd { .. 
/// Progress counters for a send stream.
///
/// See <https://www.w3.org/TR/webtransport/#send-stream-stats>.
#[derive(Debug, Clone, Copy)]
pub struct SendStreamStats {
    /// The total number of bytes the consumer has successfully written to
    /// this stream. This number can only increase.
    pub bytes_written: u64,
    /// An indicator of progress on how many of the consumer bytes written to
    /// this stream have been sent at least once. This number can only increase,
    /// and is always less than or equal to `bytes_written`.
    pub bytes_sent: u64,
    /// An indicator of progress on how many of the consumer bytes written to
    /// this stream have been sent and acknowledged as received by the server
    /// using QUIC's ACK mechanism. Only sequential bytes up to,
    /// but not including, the first non-acknowledged byte, are counted.
    /// This number can only increase and is always less than or equal to
    /// `bytes_sent`.
    pub bytes_acked: u64,
}

impl SendStreamStats {
    /// Bundle the three counters. No relationship between them is checked here.
    #[must_use]
    pub const fn new(bytes_written: u64, bytes_sent: u64, bytes_acked: u64) -> Self {
        Self {
            bytes_written,
            bytes_sent,
            bytes_acked,
        }
    }

    /// Bytes written to the stream by the consumer.
    #[must_use]
    pub const fn bytes_written(&self) -> u64 {
        self.bytes_written
    }

    /// Bytes that have been sent at least once.
    #[must_use]
    pub const fn bytes_sent(&self) -> u64 {
        self.bytes_sent
    }

    /// Bytes that have been contiguously acknowledged.
    #[must_use]
    pub const fn bytes_acked(&self) -> u64 {
        self.bytes_acked
    }
}
+#[derive(Debug)] +pub struct SendStream { + stream_id: StreamId, + state: SendStreamState, + conn_events: ConnectionEvents, + priority: TransmissionPriority, + retransmission_priority: RetransmissionPriority, + retransmission_offset: u64, + sendorder: Option<SendOrder>, + bytes_sent: u64, + fair: bool, +} + +impl Hash for SendStream { + fn hash<H: Hasher>(&self, state: &mut H) { + self.stream_id.hash(state); + } +} + +impl PartialEq for SendStream { + fn eq(&self, other: &Self) -> bool { + self.stream_id == other.stream_id + } +} +impl Eq for SendStream {} + +impl SendStream { + pub fn new( + stream_id: StreamId, + max_stream_data: u64, + conn_fc: Rc<RefCell<SenderFlowControl<()>>>, + conn_events: ConnectionEvents, + ) -> Self { + let ss = Self { + stream_id, + state: SendStreamState::Ready { + fc: SenderFlowControl::new(stream_id, max_stream_data), + conn_fc, + }, + conn_events, + priority: TransmissionPriority::default(), + retransmission_priority: RetransmissionPriority::default(), + retransmission_offset: 0, + sendorder: None, + bytes_sent: 0, + fair: false, + }; + if ss.avail() > 0 { + ss.conn_events.send_stream_writable(stream_id); + } + ss + } + + pub fn write_frames( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + qtrace!("write STREAM frames at priority {:?}", priority); + if !self.write_reset_frame(priority, builder, tokens, stats) { + self.write_blocked_frame(priority, builder, tokens, stats); + self.write_stream_frame(priority, builder, tokens, stats); + } + } + + // return false if the builder is full and the caller should stop iterating + pub fn write_frames_with_early_return( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> bool { + if !self.write_reset_frame(priority, builder, tokens, stats) { + self.write_blocked_frame(priority, builder, tokens, stats); + 
if builder.is_full() { + return false; + } + self.write_stream_frame(priority, builder, tokens, stats); + if builder.is_full() { + return false; + } + } + true + } + + pub fn set_fairness(&mut self, make_fair: bool) { + self.fair = make_fair; + } + + pub fn is_fair(&self) -> bool { + self.fair + } + + pub fn set_priority( + &mut self, + transmission: TransmissionPriority, + retransmission: RetransmissionPriority, + ) { + self.priority = transmission; + self.retransmission_priority = retransmission; + } + + pub fn sendorder(&self) -> Option<SendOrder> { + self.sendorder + } + + pub fn set_sendorder(&mut self, sendorder: Option<SendOrder>) { + self.sendorder = sendorder; + } + + /// If all data has been buffered or written, how much was sent. + pub fn final_size(&self) -> Option<u64> { + match &self.state { + SendStreamState::DataSent { send_buf, .. } => Some(send_buf.used()), + SendStreamState::ResetSent { final_size, .. } => Some(*final_size), + _ => None, + } + } + + pub fn stats(&self) -> SendStreamStats { + SendStreamStats::new(self.bytes_written(), self.bytes_sent, self.bytes_acked()) + } + + pub fn bytes_written(&self) -> u64 { + match &self.state { + SendStreamState::Send { send_buf, .. } | SendStreamState::DataSent { send_buf, .. } => { + send_buf.retired() + u64::try_from(send_buf.buffered()).unwrap() + } + SendStreamState::DataRecvd { + retired, written, .. + } => *retired + *written, + SendStreamState::ResetSent { + final_retired, + final_written, + .. + } + | SendStreamState::ResetRecvd { + final_retired, + final_written, + .. + } => *final_retired + *final_written, + SendStreamState::Ready { .. } => 0, + } + } + + pub fn bytes_acked(&self) -> u64 { + match &self.state { + SendStreamState::Send { send_buf, .. } | SendStreamState::DataSent { send_buf, .. } => { + send_buf.retired() + } + SendStreamState::DataRecvd { retired, .. } => *retired, + SendStreamState::ResetSent { final_retired, .. } + | SendStreamState::ResetRecvd { final_retired, .. 
} => *final_retired, + SendStreamState::Ready { .. } => 0, + } + } + + /// Return the next range to be sent, if any. + /// If this is a retransmission, cut off what is sent at the retransmission + /// offset. + fn next_bytes(&mut self, retransmission_only: bool) -> Option<(u64, &[u8])> { + match self.state { + SendStreamState::Send { ref send_buf, .. } => { + send_buf.next_bytes().and_then(|(offset, slice)| { + if retransmission_only { + qtrace!( + [self], + "next_bytes apply retransmission limit at {}", + self.retransmission_offset + ); + if self.retransmission_offset > offset { + let len = min( + usize::try_from(self.retransmission_offset - offset).unwrap(), + slice.len(), + ); + Some((offset, &slice[..len])) + } else { + None + } + } else { + Some((offset, slice)) + } + }) + } + SendStreamState::DataSent { + ref send_buf, + fin_sent, + .. + } => { + let bytes = send_buf.next_bytes(); + if bytes.is_some() { + bytes + } else if fin_sent { + None + } else { + // Send empty stream frame with fin set + Some((send_buf.used(), &[])) + } + } + SendStreamState::Ready { .. } + | SendStreamState::DataRecvd { .. } + | SendStreamState::ResetSent { .. } + | SendStreamState::ResetRecvd { .. } => None, + } + } + + /// Calculate how many bytes (length) can fit into available space and whether + /// the remainder of the space can be filled (or if a length field is needed). + fn length_and_fill(data_len: usize, space: usize) -> (usize, bool) { + if data_len >= space { + // More data than space allows, or an exact fit => fast path. + qtrace!("SendStream::length_and_fill fill {}", space); + return (space, true); + } + + // Estimate size of the length field based on the available space, + // less 1, which is the worst case. + let length = min(space.saturating_sub(1), data_len); + let length_len = Encoder::varint_len(u64::try_from(length).unwrap()); + debug_assert!(length_len <= space); // We don't depend on this being true, but it is true. 
+ + // From here we can always fit `data_len`, but we might as well fill + // if there is no space for the length field plus another frame. + let fill = data_len + length_len + PacketBuilder::MINIMUM_FRAME_SIZE > space; + qtrace!("SendStream::length_and_fill {} fill {}", data_len, fill); + (data_len, fill) + } + + /// Maybe write a `STREAM` frame. + pub fn write_stream_frame( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + let retransmission = if priority == self.priority { + false + } else if priority == self.priority + self.retransmission_priority { + true + } else { + return; + }; + + let id = self.stream_id; + let final_size = self.final_size(); + if let Some((offset, data)) = self.next_bytes(retransmission) { + let overhead = 1 // Frame type + + Encoder::varint_len(id.as_u64()) + + if offset > 0 { + Encoder::varint_len(offset) + } else { + 0 + }; + if overhead > builder.remaining() { + qtrace!([self], "write_frame no space for header"); + return; + } + + let (length, fill) = Self::length_and_fill(data.len(), builder.remaining() - overhead); + let fin = final_size.map_or(false, |fs| fs == offset + u64::try_from(length).unwrap()); + if length == 0 && !fin { + qtrace!([self], "write_frame no data, no fin"); + return; + } + + // Write the stream out. 
+ builder.encode_varint(Frame::stream_type(fin, offset > 0, fill)); + builder.encode_varint(id.as_u64()); + if offset > 0 { + builder.encode_varint(offset); + } + if fill { + builder.encode(&data[..length]); + builder.mark_full(); + } else { + builder.encode_vvec(&data[..length]); + } + debug_assert!(builder.len() <= builder.limit()); + + self.mark_as_sent(offset, length, fin); + tokens.push(RecoveryToken::Stream(StreamRecoveryToken::Stream( + SendStreamRecoveryToken { + id, + offset, + length, + fin, + }, + ))); + stats.stream += 1; + } + } + + pub fn reset_acked(&mut self) { + match self.state { + SendStreamState::Ready { .. } + | SendStreamState::Send { .. } + | SendStreamState::DataSent { .. } + | SendStreamState::DataRecvd { .. } => { + qtrace!([self], "Reset acked while in {} state?", self.state.name()); + } + SendStreamState::ResetSent { + final_retired, + final_written, + .. + } => self.state.transition(SendStreamState::ResetRecvd { + final_retired, + final_written, + }), + SendStreamState::ResetRecvd { .. } => qtrace!([self], "already in ResetRecvd state"), + }; + } + + pub fn reset_lost(&mut self) { + match self.state { + SendStreamState::ResetSent { + ref mut priority, .. + } => { + *priority = Some(self.priority + self.retransmission_priority); + } + SendStreamState::ResetRecvd { .. } => (), + _ => unreachable!(), + } + } + + /// Maybe write a `RESET_STREAM` frame. + pub fn write_reset_frame( + &mut self, + p: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> bool { + if let SendStreamState::ResetSent { + final_size, + err, + ref mut priority, + .. 
+ } = self.state + { + if *priority != Some(p) { + return false; + } + if builder.write_varint_frame(&[ + FRAME_TYPE_RESET_STREAM, + self.stream_id.as_u64(), + err, + final_size, + ]) { + tokens.push(RecoveryToken::Stream(StreamRecoveryToken::ResetStream { + stream_id: self.stream_id, + })); + stats.reset_stream += 1; + *priority = None; + true + } else { + false + } + } else { + false + } + } + + pub fn blocked_lost(&mut self, limit: u64) { + if let SendStreamState::Ready { fc, .. } | SendStreamState::Send { fc, .. } = + &mut self.state + { + fc.frame_lost(limit); + } else { + qtrace!([self], "Ignoring lost STREAM_DATA_BLOCKED({})", limit); + } + } + + /// Maybe write a `STREAM_DATA_BLOCKED` frame. + pub fn write_blocked_frame( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + // Send STREAM_DATA_BLOCKED at normal priority always. + if priority == self.priority { + if let SendStreamState::Ready { fc, .. } | SendStreamState::Send { fc, .. } = + &mut self.state + { + fc.write_frames(builder, tokens, stats); + } + } + } + + pub fn mark_as_sent(&mut self, offset: u64, len: usize, fin: bool) { + self.bytes_sent = max(self.bytes_sent, offset + u64::try_from(len).unwrap()); + + if let Some(buf) = self.state.tx_buf_mut() { + buf.mark_as_sent(offset, len); + self.send_blocked_if_space_needed(0); + }; + + if fin { + if let SendStreamState::DataSent { fin_sent, .. } = &mut self.state { + *fin_sent = true; + } + } + } + + pub fn mark_as_acked(&mut self, offset: u64, len: usize, fin: bool) { + match self.state { + SendStreamState::Send { + ref mut send_buf, .. + } => { + send_buf.mark_as_acked(offset, len); + if self.avail() > 0 { + self.conn_events.send_stream_writable(self.stream_id); + } + } + SendStreamState::DataSent { + ref mut send_buf, + ref mut fin_acked, + .. 
+ } => { + send_buf.mark_as_acked(offset, len); + if fin { + *fin_acked = true; + } + if *fin_acked && send_buf.buffered() == 0 { + self.conn_events.send_stream_complete(self.stream_id); + let retired = send_buf.retired(); + let buffered = u64::try_from(send_buf.buffered()).unwrap(); + self.state.transition(SendStreamState::DataRecvd { + retired, + written: buffered, + }); + } + } + _ => qtrace!( + [self], + "mark_as_acked called from state {}", + self.state.name() + ), + } + } + + pub fn mark_as_lost(&mut self, offset: u64, len: usize, fin: bool) { + self.retransmission_offset = max( + self.retransmission_offset, + offset + u64::try_from(len).unwrap(), + ); + qtrace!( + [self], + "mark_as_lost retransmission offset={}", + self.retransmission_offset + ); + if let Some(buf) = self.state.tx_buf_mut() { + buf.mark_as_lost(offset, len); + } + + if fin { + if let SendStreamState::DataSent { + fin_sent, + fin_acked, + .. + } = &mut self.state + { + *fin_sent = *fin_acked; + } + } + } + + /// Bytes sendable on stream. Constrained by stream credit available, + /// connection credit available, and space in the tx buffer. + pub fn avail(&self) -> usize { + if let SendStreamState::Ready { fc, conn_fc } | SendStreamState::Send { fc, conn_fc, .. } = + &self.state + { + min( + min(fc.available(), conn_fc.borrow().available()), + self.state.tx_avail(), + ) + } else { + 0 + } + } + + pub fn set_max_stream_data(&mut self, limit: u64) { + if let SendStreamState::Ready { fc, .. } | SendStreamState::Send { fc, .. } = + &mut self.state + { + let stream_was_blocked = fc.available() == 0; + fc.update(limit); + if stream_was_blocked && self.avail() > 0 { + self.conn_events.send_stream_writable(self.stream_id); + } + } + } + + pub fn is_terminal(&self) -> bool { + matches!( + self.state, + SendStreamState::DataRecvd { .. } | SendStreamState::ResetRecvd { .. 
} + ) + } + + pub fn send(&mut self, buf: &[u8]) -> Res<usize> { + self.send_internal(buf, false) + } + + pub fn send_atomic(&mut self, buf: &[u8]) -> Res<usize> { + self.send_internal(buf, true) + } + + fn send_blocked_if_space_needed(&mut self, needed_space: usize) { + if let SendStreamState::Ready { fc, conn_fc } | SendStreamState::Send { fc, conn_fc, .. } = + &mut self.state + { + if fc.available() <= needed_space { + fc.blocked(); + } + + if conn_fc.borrow().available() <= needed_space { + conn_fc.borrow_mut().blocked(); + } + } + } + + fn send_internal(&mut self, buf: &[u8], atomic: bool) -> Res<usize> { + if buf.is_empty() { + qerror!([self], "zero-length send on stream"); + return Err(Error::InvalidInput); + } + + if let SendStreamState::Ready { fc, conn_fc } = &mut self.state { + let owned_fc = mem::replace(fc, SenderFlowControl::new(self.stream_id, 0)); + let owned_conn_fc = Rc::clone(conn_fc); + self.state.transition(SendStreamState::Send { + fc: owned_fc, + conn_fc: owned_conn_fc, + send_buf: TxBuffer::new(), + }); + } + + if !matches!(self.state, SendStreamState::Send { .. }) { + return Err(Error::FinalSizeError); + } + + let buf = if buf.is_empty() || (self.avail() == 0) { + return Ok(0); + } else if self.avail() < buf.len() { + if atomic { + self.send_blocked_if_space_needed(buf.len()); + return Ok(0); + } else { + &buf[..self.avail()] + } + } else { + buf + }; + + match &mut self.state { + SendStreamState::Ready { .. } => unreachable!(), + SendStreamState::Send { + fc, + conn_fc, + send_buf, + } => { + let sent = send_buf.send(buf); + fc.consume(sent); + conn_fc.borrow_mut().consume(sent); + Ok(sent) + } + _ => Err(Error::FinalSizeError), + } + } + + pub fn close(&mut self) { + match &mut self.state { + SendStreamState::Ready { .. } => { + self.state.transition(SendStreamState::DataSent { + send_buf: TxBuffer::new(), + fin_sent: false, + fin_acked: false, + }); + } + SendStreamState::Send { send_buf, .. 
} => { + let owned_buf = mem::replace(send_buf, TxBuffer::new()); + self.state.transition(SendStreamState::DataSent { + send_buf: owned_buf, + fin_sent: false, + fin_acked: false, + }); + } + SendStreamState::DataSent { .. } => qtrace!([self], "already in DataSent state"), + SendStreamState::DataRecvd { .. } => qtrace!([self], "already in DataRecvd state"), + SendStreamState::ResetSent { .. } => qtrace!([self], "already in ResetSent state"), + SendStreamState::ResetRecvd { .. } => qtrace!([self], "already in ResetRecvd state"), + } + } + + pub fn reset(&mut self, err: AppError) { + match &self.state { + SendStreamState::Ready { fc, .. } => { + let final_size = fc.used(); + self.state.transition(SendStreamState::ResetSent { + err, + final_size, + priority: Some(self.priority), + final_retired: 0, + final_written: 0, + }); + } + SendStreamState::Send { fc, send_buf, .. } => { + let final_size = fc.used(); + let final_retired = send_buf.retired(); + let buffered = u64::try_from(send_buf.buffered()).unwrap(); + self.state.transition(SendStreamState::ResetSent { + err, + final_size, + priority: Some(self.priority), + final_retired, + final_written: buffered, + }); + } + SendStreamState::DataSent { send_buf, .. } => { + let final_size = send_buf.used(); + let final_retired = send_buf.retired(); + let buffered = u64::try_from(send_buf.buffered()).unwrap(); + self.state.transition(SendStreamState::ResetSent { + err, + final_size, + priority: Some(self.priority), + final_retired, + final_written: buffered, + }); + } + SendStreamState::DataRecvd { .. } => qtrace!([self], "already in DataRecvd state"), + SendStreamState::ResetSent { .. } => qtrace!([self], "already in ResetSent state"), + SendStreamState::ResetRecvd { .. 
} => qtrace!([self], "already in ResetRecvd state"), + }; + } + + #[cfg(test)] + pub(crate) fn state(&mut self) -> &mut SendStreamState { + &mut self.state + } +} + +impl ::std::fmt::Display for SendStream { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "SendStream {}", self.stream_id) + } +} + +#[derive(Debug, Default)] +pub struct OrderGroup { + // Stream ids in this group. This vector is sorted by StreamId; insert() and + // remove() rely on that when they binary-search it. + vec: Vec<StreamId>, + + // Index into `vec` of the entry that the next iteration step will return. + // Since we need to remember where we were between iterations, we store the + // position in the group itself rather than in the iterator. This means there + // can only be a single iterator active at a time! + // + // An iteration starts at this entry, runs to the end, and then wraps and + // continues from 0 until just before the entry it started at. + // This value may need to be updated after insertion and removal; in theory we should + // track the target entry across modifications, but in practice it should be good + // enough to simply leave it alone unless it points past the end of the + // Vec, and re-initialize to 0 in that case (iter() does this). + next: usize, +} + +pub struct OrderGroupIter<'a> { + // The group being walked. The current position is kept in `group.next`; + // otherwise we'd need an explicit "done iterating" call to be made, or implement Drop to + // copy the value back. + group: &'a mut OrderGroup, + // This is where `group.next` was when we iterated for the first time (None until + // then); when we get back to that we + // stop. 
+ started_at: Option<usize>, +} + +impl OrderGroup { + pub fn iter(&mut self) -> OrderGroupIter { + // Ids may have been deleted since we last iterated, which can leave `next` + // pointing past the end; restart from the front in that case. + if self.next >= self.vec.len() { + self.next = 0; + } + OrderGroupIter { + started_at: None, + group: self, + } + } + + /// The stream ids in this group, in stored order. + pub fn stream_ids(&self) -> &Vec<StreamId> { + &self.vec + } + + /// Remove every stream id. `next` is not reset here; iter() resets it when it + /// is out of range. + pub fn clear(&mut self) { + self.vec.clear(); + } + + /// Append `stream_id` without looking at the existing entries. The caller must + /// make sure it sorts after everything already present, since insert() and + /// remove() binary-search the vector. + pub fn push(&mut self, stream_id: StreamId) { + self.vec.push(stream_id); + } + + /// Test-only: keep just the first `position` entries. + #[cfg(test)] + pub fn truncate(&mut self, position: usize) { + self.vec.truncate(position); + } + + /// Return the current position and advance `next` by one, wrapping to 0 at the + /// end of the vector. Must not be called on an empty group: the modulo would + /// divide by zero. + fn update_next(&mut self) -> usize { + let next = self.next; + self.next = (self.next + 1) % self.vec.len(); + next + } + + /// Insert `stream_id` at its sorted position. + /// + /// # Panics + /// + /// If `stream_id` is already in the group. + pub fn insert(&mut self, stream_id: StreamId) { + match self.vec.binary_search(&stream_id) { + Ok(_) => { + // element is already in the vector + panic!("Duplicate stream_id {}", stream_id) + } + Err(pos) => self.vec.insert(pos, stream_id), + } + } + + /// Remove `stream_id` from the group. + /// + /// # Panics + /// + /// If `stream_id` is not in the group. + pub fn remove(&mut self, stream_id: StreamId) { + match self.vec.binary_search(&stream_id) { + Ok(pos) => { + self.vec.remove(pos); + } + Err(_) => { + // element is not in the vector + panic!("Missing stream_id {}", stream_id) + } + } + } +} + +impl<'a> Iterator for OrderGroupIter<'a> { + type Item = StreamId; + fn next(&mut self) -> Option<Self::Item> { + // Stop when we would return the started_at element on the next + // call. Note that this must take into account wrapping. + // An empty group yields nothing (and update_next() must not run on it). + if self.started_at == Some(self.group.next) || self.group.vec.is_empty() { + return None; + } + // Record the starting position on the first call only. + self.started_at = self.started_at.or(Some(self.group.next)); + let orig = self.group.update_next(); + Some(self.group.vec[orig]) + } +} + +#[derive(Debug, Default)] +pub(crate) struct SendStreams { + map: IndexMap<StreamId, SendStream>, + + // What we really want is a Priority Queue that we can do arbitrary + // removes from (so we can reprioritize). BinaryHeap doesn't work, + // because there's no remove(). 
BTreeMap doesn't work, since you can't + // have duplicate keys. PriorityQueue does have what we need, except for an + // ordered iterator that doesn't consume the queue. So we roll our own. + + // Added complication: We want to have Fairness for streams of the same + // 'group' (for WebTransport), but for H3 (and other non-WT streams) we + // tend to get better pageload performance by prioritizing by creation order. + // + // There are two options. The first is to walk the 'map' first, ignoring + // WebTransport streams, then process the unordered and ordered WebTransport + // streams. The second is to have a sorted Vec for unfair streams (and + // use a normal iterator for that), and then chain the iterators for + // the unordered and ordered WebTransport streams. The first works very + // well for H3, and for WebTransport nodes are visited twice on every + // processing loop. The second adds insertion and removal costs, but + // avoids a CPU penalty for WebTransport streams. For now we'll do #1. + // + // So we use a sorted Vec<> for the regular streams (that's usually all of + // them), and then a BTreeMap of an entry for each SendOrder value, and + // for each of those entries a Vec of the stream_ids at that + // sendorder. In most cases (such as stream-per-frame), there will be + // a single stream at a given sendorder. + + // These both store stream_ids, which need to be looked up in 'map'. + // This avoids the complexity of trying to hold references to the + // Streams which are owned by the IndexMap. 
+ sendordered: BTreeMap<SendOrder, OrderGroup>, + regular: OrderGroup, // streams with no SendOrder set, sorted in stream_id order +} + +impl SendStreams { + pub fn get(&self, id: StreamId) -> Res<&SendStream> { + self.map.get(&id).ok_or(Error::InvalidStreamId) + } + + pub fn get_mut(&mut self, id: StreamId) -> Res<&mut SendStream> { + self.map.get_mut(&id).ok_or(Error::InvalidStreamId) + } + + pub fn exists(&self, id: StreamId) -> bool { + self.map.contains_key(&id) + } + + pub fn insert(&mut self, id: StreamId, stream: SendStream) { + self.map.insert(id, stream); + } + + fn group_mut(&mut self, sendorder: Option<SendOrder>) -> &mut OrderGroup { + if let Some(order) = sendorder { + self.sendordered.entry(order).or_default() + } else { + &mut self.regular + } + } + + pub fn set_sendorder(&mut self, stream_id: StreamId, sendorder: Option<SendOrder>) -> Res<()> { + self.set_fairness(stream_id, true)?; + if let Some(stream) = self.map.get_mut(&stream_id) { + // don't grab stream here; causes borrow errors + let old_sendorder = stream.sendorder(); + if old_sendorder != sendorder { + // we have to remove it from the list it was in, and reinsert it with the new + // sendorder key + let mut group = self.group_mut(old_sendorder); + group.remove(stream_id); + self.get_mut(stream_id).unwrap().set_sendorder(sendorder); + group = self.group_mut(sendorder); + group.insert(stream_id); + qtrace!( + "ordering of stream_ids: {:?}", + self.sendordered.values().collect::<Vec::<_>>() + ); + } + Ok(()) + } else { + Err(Error::InvalidStreamId) + } + } + + pub fn set_fairness(&mut self, stream_id: StreamId, make_fair: bool) -> Res<()> { + let stream: &mut SendStream = self.map.get_mut(&stream_id).ok_or(Error::InvalidStreamId)?; + let was_fair = stream.fair; + stream.set_fairness(make_fair); + if !was_fair && make_fair { + // Move to the regular OrderGroup. + + // We know sendorder can't have been set, since + // set_sendorder() will call this routine if it's not + // already set as fair. 
+ + // This normally is only called when a new stream is created. If + // so, because of how we allocate StreamIds, it should always have + // the largest value. This means we can just append it to the + // regular vector. However, if we were ever to change this + // invariant, things would break subtly. + + // To be safe we can try to insert at the end and if not + // fall back to binary-search insertion + if matches!(self.regular.stream_ids().last(), Some(last) if stream_id > *last) { + self.regular.push(stream_id); + } else { + self.regular.insert(stream_id); + } + } else if was_fair && !make_fair { + // remove from the OrderGroup + let group = if let Some(sendorder) = stream.sendorder { + self.sendordered.get_mut(&sendorder).unwrap() + } else { + &mut self.regular + }; + group.remove(stream_id); + } + Ok(()) + } + + pub fn acked(&mut self, token: &SendStreamRecoveryToken) { + if let Some(ss) = self.map.get_mut(&token.id) { + ss.mark_as_acked(token.offset, token.length, token.fin); + } + } + + pub fn reset_acked(&mut self, id: StreamId) { + if let Some(ss) = self.map.get_mut(&id) { + ss.reset_acked(); + } + } + + pub fn lost(&mut self, token: &SendStreamRecoveryToken) { + if let Some(ss) = self.map.get_mut(&token.id) { + ss.mark_as_lost(token.offset, token.length, token.fin); + } + } + + pub fn reset_lost(&mut self, stream_id: StreamId) { + if let Some(ss) = self.map.get_mut(&stream_id) { + ss.reset_lost(); + } + } + + pub fn blocked_lost(&mut self, stream_id: StreamId, limit: u64) { + if let Some(ss) = self.map.get_mut(&stream_id) { + ss.blocked_lost(limit); + } + } + + pub fn clear(&mut self) { + self.map.clear(); + self.sendordered.clear(); + self.regular.clear(); + } + + pub fn remove_terminal(&mut self) { + let map: &mut IndexMap<StreamId, SendStream> = &mut self.map; + let regular: &mut OrderGroup = &mut self.regular; + let sendordered: &mut BTreeMap<SendOrder, OrderGroup> = &mut self.sendordered; + + // Take refs to all the items we need to modify instead 
of &mut + // self to keep the compiler happy (if we use self.map.retain it + // gets upset due to borrows) + map.retain(|stream_id, stream| { + if stream.is_terminal() { + if stream.is_fair() { + match stream.sendorder() { + None => regular.remove(*stream_id), + Some(sendorder) => { + sendordered.get_mut(&sendorder).unwrap().remove(*stream_id); + } + }; + } + // if unfair, we're done + return false; + } + true + }); + } + + pub(crate) fn write_frames( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + qtrace!("write STREAM frames at priority {:?}", priority); + // WebTransport data (which is Normal) may have a SendOrder + // priority attached. The spec states (6.3 write-chunk 6.1): + + // First, we send any streams without Fairness defined, with + // ordering defined by StreamId. (Http3 streams used for + // e.g. pageload benefit from being processed in order of creation + // so the far side can start acting on a datum/request sooner. All + // WebTransport streams MUST have fairness set.) Then we send + // streams with fairness set (including all WebTransport streams) + // as follows: + + // If stream.[[SendOrder]] is null then this sending MUST NOT + // starve except for flow control reasons or error. If + // stream.[[SendOrder]] is not null then this sending MUST starve + // until all bytes queued for sending on WebTransportSendStreams + // with a non-null and higher [[SendOrder]], that are neither + // errored nor blocked by flow control, have been sent. + + // So data without SendOrder goes first. Then the highest priority + // SendOrdered streams. + // + // Fairness is implemented by a round-robining or "statefully + // iterating" within a single sendorder/unordered vector. We do + // this by recording where we stopped in the previous pass, and + // starting there the next pass. 
If we store an index into the + // vec, this means we can't use a chained iterator, since we want + // to retain our place-in-the-vector. If we rotate the vector, + // that would let us use the chained iterator, but would require + // more expensive searches for insertion and removal (since the + // sorted order would be lost). + + // Iterate the map, but only those without fairness, then iterate + // OrderGroups, then iterate each group + qdebug!("processing streams... unfair:"); + for stream in self.map.values_mut() { + if !stream.is_fair() { + qdebug!(" {}", stream); + if !stream.write_frames_with_early_return(priority, builder, tokens, stats) { + break; + } + } + } + qdebug!("fair streams:"); + let stream_ids = self.regular.iter().chain( + self.sendordered + .values_mut() + .rev() + .flat_map(|group| group.iter()), + ); + for stream_id in stream_ids { + let stream = self.map.get_mut(&stream_id).unwrap(); + if let Some(order) = stream.sendorder() { + qdebug!(" {} ({})", stream_id, order) + } else { + qdebug!(" None") + } + if !stream.write_frames_with_early_return(priority, builder, tokens, stats) { + break; + } + } + } + + pub fn update_initial_limit(&mut self, remote: &TransportParameters) { + for (id, ss) in self.map.iter_mut() { + let limit = if id.is_bidi() { + assert!(!id.is_remote_initiated(Role::Client)); + remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE) + } else { + remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI) + }; + ss.set_max_stream_data(limit); + } + } +} + +impl<'a> IntoIterator for &'a mut SendStreams { + type Item = (&'a StreamId, &'a mut SendStream); + type IntoIter = indexmap::map::IterMut<'a, StreamId, SendStream>; + + fn into_iter(self) -> indexmap::map::IterMut<'a, StreamId, SendStream> { + self.map.iter_mut() + } +} + +#[derive(Debug, Clone)] +pub struct SendStreamRecoveryToken { + pub(crate) id: StreamId, + offset: u64, + length: usize, + fin: bool, +} + +#[cfg(test)] +mod tests { + use 
neqo_common::{event::Provider, hex_with_len, qtrace}; + + use super::*; + use crate::events::ConnectionEvent; + + fn connection_fc(limit: u64) -> Rc<RefCell<SenderFlowControl<()>>> { + Rc::new(RefCell::new(SenderFlowControl::new((), limit))) + } + + #[test] + fn test_mark_range() { + let mut rt = RangeTracker::default(); + + // ranges can go from nothing->Sent if queued for retrans and then + // acks arrive + rt.mark_range(5, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 10); + assert_eq!(rt.acked_from_zero(), 0); + rt.mark_range(10, 4, RangeState::Acked); + assert_eq!(rt.highest_offset(), 14); + assert_eq!(rt.acked_from_zero(), 0); + + rt.mark_range(0, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 14); + assert_eq!(rt.acked_from_zero(), 0); + rt.mark_range(0, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 14); + assert_eq!(rt.acked_from_zero(), 14); + + rt.mark_range(12, 20, RangeState::Acked); + assert_eq!(rt.highest_offset(), 32); + assert_eq!(rt.acked_from_zero(), 32); + + // ack the lot + rt.mark_range(0, 400, RangeState::Acked); + assert_eq!(rt.highest_offset(), 400); + assert_eq!(rt.acked_from_zero(), 400); + + // acked trumps sent + rt.mark_range(0, 200, RangeState::Sent); + assert_eq!(rt.highest_offset(), 400); + assert_eq!(rt.acked_from_zero(), 400); + } + + #[test] + fn unmark_sent_start() { + let mut rt = RangeTracker::default(); + + rt.mark_range(0, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 0); + + rt.unmark_sent(); + assert_eq!(rt.highest_offset(), 0); + assert_eq!(rt.acked_from_zero(), 0); + assert_eq!(rt.first_unmarked_range(), (0, None)); + } + + #[test] + fn unmark_sent_middle() { + let mut rt = RangeTracker::default(); + + rt.mark_range(0, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 5); + rt.mark_range(5, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 10); + assert_eq!(rt.acked_from_zero(), 5); + 
rt.mark_range(10, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 15); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (15, None)); + + rt.unmark_sent(); + assert_eq!(rt.highest_offset(), 15); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (5, Some(5))); + } + + #[test] + fn unmark_sent_end() { + let mut rt = RangeTracker::default(); + + rt.mark_range(0, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 5); + rt.mark_range(5, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 10); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (10, None)); + + rt.unmark_sent(); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (5, None)); + } + + #[test] + fn truncate_front() { + let mut v = VecDeque::new(); + v.push_back(5); + v.push_back(6); + v.push_back(7); + v.push_front(4usize); + + v.rotate_left(1); + v.truncate(3); + assert_eq!(*v.front().unwrap(), 5); + assert_eq!(*v.back().unwrap(), 7); + } + + #[test] + fn test_unmark_range() { + let mut rt = RangeTracker::default(); + + rt.mark_range(5, 5, RangeState::Acked); + rt.mark_range(10, 5, RangeState::Sent); + + // Should unmark sent but not acked range + rt.unmark_range(7, 6); + + let res = rt.first_unmarked_range(); + assert_eq!(res, (0, Some(5))); + assert_eq!( + rt.used.iter().next().unwrap(), + (&5, &(5, RangeState::Acked)) + ); + assert_eq!( + rt.used.iter().nth(1).unwrap(), + (&13, &(2, RangeState::Sent)) + ); + assert!(rt.used.iter().nth(2).is_none()); + rt.mark_range(0, 5, RangeState::Sent); + + let res = rt.first_unmarked_range(); + assert_eq!(res, (10, Some(3))); + rt.mark_range(10, 3, RangeState::Sent); + + let res = rt.first_unmarked_range(); + assert_eq!(res, (15, None)); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn tx_buffer_next_bytes_1() { + let mut txb = TxBuffer::new(); + + 
assert_eq!(txb.avail(), SEND_BUFFER_SIZE); + + // Fill the buffer + assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE); + assert!(matches!(txb.next_bytes(), + Some((0, x)) if x.len()==SEND_BUFFER_SIZE + && x.iter().all(|ch| *ch == 1))); + + // Mark almost all as sent. Get what's left + let one_byte_from_end = SEND_BUFFER_SIZE as u64 - 1; + txb.mark_as_sent(0, one_byte_from_end as usize); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 1 + && start == one_byte_from_end + && x.iter().all(|ch| *ch == 1))); + + // Mark all as sent. Get nothing + txb.mark_as_sent(0, SEND_BUFFER_SIZE); + assert!(txb.next_bytes().is_none()); + + // Mark as lost. Get it again + txb.mark_as_lost(one_byte_from_end, 1); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 1 + && start == one_byte_from_end + && x.iter().all(|ch| *ch == 1))); + + // Mark a larger range lost, including beyond what's in the buffer even. + // Get a little more + let five_bytes_from_end = SEND_BUFFER_SIZE as u64 - 5; + txb.mark_as_lost(five_bytes_from_end, 100); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 5 + && start == five_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + + // Contig acked range at start means it can be removed from buffer + // Impl of vecdeque should now result in a split buffer when more data + // is sent + txb.mark_as_acked(0, five_bytes_from_end as usize); + assert_eq!(txb.send(&[2; 30]), 30); + // Just get 5 even though there is more + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 5 + && start == five_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + assert_eq!(txb.retired, five_bytes_from_end); + assert_eq!(txb.buffered(), 35); + + // Marking that bit as sent should let the last contig bit be returned + // when called again + txb.mark_as_sent(five_bytes_from_end, 5); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 30 + && start == SEND_BUFFER_SIZE as u64 + && 
x.iter().all(|ch| *ch == 2))); + } + + #[test] + fn tx_buffer_next_bytes_2() { + let mut txb = TxBuffer::new(); + + assert_eq!(txb.avail(), SEND_BUFFER_SIZE); + + // Fill the buffer + assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE); + assert!(matches!(txb.next_bytes(), + Some((0, x)) if x.len()==SEND_BUFFER_SIZE + && x.iter().all(|ch| *ch == 1))); + + // As above + let forty_bytes_from_end = SEND_BUFFER_SIZE as u64 - 40; + + txb.mark_as_acked(0, forty_bytes_from_end as usize); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 40 + && start == forty_bytes_from_end + )); + + // Valid new data placed in split locations + assert_eq!(txb.send(&[2; 100]), 100); + + // Mark a little more as sent + txb.mark_as_sent(forty_bytes_from_end, 10); + let thirty_bytes_from_end = forty_bytes_from_end + 10; + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 30 + && start == thirty_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + + // Mark a range 'A' in second slice as sent. Should still return the same + let range_a_start = SEND_BUFFER_SIZE as u64 + 30; + let range_a_end = range_a_start + 10; + txb.mark_as_sent(range_a_start, 10); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 30 + && start == thirty_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + + // Ack entire first slice and into second slice + let ten_bytes_past_end = SEND_BUFFER_SIZE as u64 + 10; + txb.mark_as_acked(0, ten_bytes_past_end as usize); + + // Get up to marked range A + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 20 + && start == ten_bytes_past_end + && x.iter().all(|ch| *ch == 2))); + + txb.mark_as_sent(ten_bytes_past_end, 20); + + // Get bit after earlier marked range A + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 60 + && start == range_a_end + && x.iter().all(|ch| *ch == 2))); + + // No more bytes. 
+ txb.mark_as_sent(range_a_end, 60); + assert!(txb.next_bytes().is_none()); + } + + #[test] + fn test_stream_tx() { + let conn_fc = connection_fc(4096); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(4.into(), 1024, Rc::clone(&conn_fc), conn_events); + + let res = s.send(&[4; 100]).unwrap(); + assert_eq!(res, 100); + s.mark_as_sent(0, 50, false); + if let SendStreamState::Send { fc, .. } = s.state() { + assert_eq!(fc.used(), 100); + } else { + panic!("unexpected stream state"); + } + + // Should hit stream flow control limit before filling up send buffer + let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap(); + assert_eq!(res, 1024 - 100); + + // should do nothing, max stream data already 1024 + s.set_max_stream_data(1024); + let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap(); + assert_eq!(res, 0); + + // should now hit the conn flow control (4096) + s.set_max_stream_data(1_048_576); + let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap(); + assert_eq!(res, 3072); + + // should now hit the tx buffer size + conn_fc.borrow_mut().update(SEND_BUFFER_SIZE as u64); + let res = s.send(&[4; SEND_BUFFER_SIZE + 100]).unwrap(); + assert_eq!(res, SEND_BUFFER_SIZE - 4096); + + // TODO(agrover@mozilla.com): test ooo acks somehow + s.mark_as_acked(0, 40, false); + } + + #[test] + fn test_tx_buffer_acks() { + let mut tx = TxBuffer::new(); + assert_eq!(tx.send(&[4; 100]), 100); + let res = tx.next_bytes().unwrap(); + assert_eq!(res.0, 0); + assert_eq!(res.1.len(), 100); + tx.mark_as_sent(0, 100); + let res = tx.next_bytes(); + assert_eq!(res, None); + + tx.mark_as_acked(0, 100); + let res = tx.next_bytes(); + assert_eq!(res, None); + } + + #[test] + fn send_stream_writable_event_gen() { + let conn_fc = connection_fc(2); + let mut conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(4.into(), 0, Rc::clone(&conn_fc), conn_events.clone()); + + // Stream is initially blocked (conn:2, stream:0) + // and will not accept data. 
+ assert_eq!(s.send(b"hi").unwrap(), 0); + + // increasing to (conn:2, stream:2) will allow 2 bytes, and also + // generate a SendStreamWritable event. + s.set_max_stream_data(2); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 1); + assert!(matches!( + evts[0], + ConnectionEvent::SendStreamWritable { .. } + )); + assert_eq!(s.send(b"hello").unwrap(), 2); + + // increasing to (conn:2, stream:4) will not generate an event or allow + // sending anything. + s.set_max_stream_data(4); + assert_eq!(conn_events.events().count(), 0); + assert_eq!(s.send(b"hello").unwrap(), 0); + + // Increasing conn max (conn:4, stream:4) will unblock but not emit + // event b/c that happens in Connection::emit_frame() (tested in + // connection.rs) + assert!(conn_fc.borrow_mut().update(4)); + assert_eq!(conn_events.events().count(), 0); + assert_eq!(s.avail(), 2); + assert_eq!(s.send(b"hello").unwrap(), 2); + + // No event because still blocked by conn + s.set_max_stream_data(1_000_000_000); + assert_eq!(conn_events.events().count(), 0); + + // No event because happens in emit_frame() + conn_fc.borrow_mut().update(1_000_000_000); + assert_eq!(conn_events.events().count(), 0); + + // Unblocking both by a large amount will cause avail() to be limited by + // tx buffer size. + assert_eq!(s.avail(), SEND_BUFFER_SIZE - 4); + + assert_eq!( + s.send(&[b'a'; SEND_BUFFER_SIZE]).unwrap(), + SEND_BUFFER_SIZE - 4 + ); + + // No event because still blocked by tx buffer full + s.set_max_stream_data(2_000_000_000); + assert_eq!(conn_events.events().count(), 0); + assert_eq!(s.send(b"hello").unwrap(), 0); + } + + #[test] + fn send_stream_writable_event_new_stream() { + let conn_fc = connection_fc(2); + let mut conn_events = ConnectionEvents::default(); + + let _s = SendStream::new(4.into(), 100, conn_fc, conn_events.clone()); + + // Creating a new stream with conn and stream credits should result in + // an event. 
+ let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 1); + assert!(matches!( + evts[0], + ConnectionEvent::SendStreamWritable { .. } + )); + } + + fn as_stream_token(t: &RecoveryToken) -> &SendStreamRecoveryToken { + if let RecoveryToken::Stream(StreamRecoveryToken::Stream(rt)) = &t { + rt + } else { + panic!(); + } + } + + #[test] + // Verify lost frames handle fin properly + fn send_stream_get_frame_data() { + let conn_fc = connection_fc(100); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(0.into(), 100, conn_fc, conn_events); + s.send(&[0; 10]).unwrap(); + s.close(); + + let mut ss = SendStreams::default(); + ss.insert(StreamId::from(0), s); + + let mut tokens = Vec::new(); + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + + // Write a small frame: no fin. + let written = builder.len(); + builder.set_limit(written + 6); + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + assert_eq!(builder.len(), written + 6); + assert_eq!(tokens.len(), 1); + let f1_token = tokens.remove(0); + assert!(!as_stream_token(&f1_token).fin); + + // Write the rest: fin. + let written = builder.len(); + builder.set_limit(written + 200); + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + assert_eq!(builder.len(), written + 10); + assert_eq!(tokens.len(), 1); + let f2_token = tokens.remove(0); + assert!(as_stream_token(&f2_token).fin); + + // Should be no more data to frame. 
+ let written = builder.len(); + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + assert_eq!(builder.len(), written); + assert!(tokens.is_empty()); + + // Mark frame 1 as lost + ss.lost(as_stream_token(&f1_token)); + + // Next frame should not set fin even though stream has fin but frame + // does not include end of stream + let written = builder.len(); + ss.write_frames( + TransmissionPriority::default() + RetransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + assert_eq!(builder.len(), written + 7); // Needs a length this time. + assert_eq!(tokens.len(), 1); + let f4_token = tokens.remove(0); + assert!(!as_stream_token(&f4_token).fin); + + // Mark frame 2 as lost + ss.lost(as_stream_token(&f2_token)); + + // Next frame should set fin because it includes end of stream + let written = builder.len(); + ss.write_frames( + TransmissionPriority::default() + RetransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + assert_eq!(builder.len(), written + 10); + assert_eq!(tokens.len(), 1); + let f5_token = tokens.remove(0); + assert!(as_stream_token(&f5_token).fin); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + // Verify lost frames handle fin properly with zero length fin + fn send_stream_get_frame_zerolength_fin() { + let conn_fc = connection_fc(100); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(0.into(), 100, conn_fc, conn_events); + s.send(&[0; 10]).unwrap(); + + let mut ss = SendStreams::default(); + ss.insert(StreamId::from(0), s); + + let mut tokens = Vec::new(); + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + let f1_token = tokens.remove(0); + assert_eq!(as_stream_token(&f1_token).offset, 0); + 
assert_eq!(as_stream_token(&f1_token).length, 10); + assert!(!as_stream_token(&f1_token).fin); + + // Should be no more data to frame + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + assert!(tokens.is_empty()); + + ss.get_mut(StreamId::from(0)).unwrap().close(); + + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + let f2_token = tokens.remove(0); + assert_eq!(as_stream_token(&f2_token).offset, 10); + assert_eq!(as_stream_token(&f2_token).length, 0); + assert!(as_stream_token(&f2_token).fin); + + // Mark frame 2 as lost + ss.lost(as_stream_token(&f2_token)); + + // Next frame should set fin + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + let f3_token = tokens.remove(0); + assert_eq!(as_stream_token(&f3_token).offset, 10); + assert_eq!(as_stream_token(&f3_token).length, 0); + assert!(as_stream_token(&f3_token).fin); + + // Mark frame 1 as lost + ss.lost(as_stream_token(&f1_token)); + + // Next frame should set fin and include all data + ss.write_frames( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut FrameStats::default(), + ); + let f4_token = tokens.remove(0); + assert_eq!(as_stream_token(&f4_token).offset, 0); + assert_eq!(as_stream_token(&f4_token).length, 10); + assert!(as_stream_token(&f4_token).fin); + } + + #[test] + fn data_blocked() { + let conn_fc = connection_fc(5); + let conn_events = ConnectionEvents::default(); + + let stream_id = StreamId::from(4); + let mut s = SendStream::new(stream_id, 2, Rc::clone(&conn_fc), conn_events); + + // Only two bytes can be sent due to the stream limit. + assert_eq!(s.send(b"abc").unwrap(), 2); + assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..]))); + + // This doesn't report blocking yet. 
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut tokens = Vec::new(); + let mut stats = FrameStats::default(); + s.write_blocked_frame( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.stream_data_blocked, 0); + + // Blocking is reported after sending the last available credit. + s.mark_as_sent(0, 2, false); + s.write_blocked_frame( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.stream_data_blocked, 1); + + // Now increase the stream limit and test the connection limit. + s.set_max_stream_data(10); + + assert_eq!(s.send(b"abcd").unwrap(), 3); + assert_eq!(s.next_bytes(false), Some((2, &b"abc"[..]))); + // DATA_BLOCKED is not sent yet. + conn_fc + .borrow_mut() + .write_frames(&mut builder, &mut tokens, &mut stats); + assert_eq!(stats.data_blocked, 0); + + // DATA_BLOCKED is queued once bytes using all credit are sent. + s.mark_as_sent(2, 3, false); + conn_fc + .borrow_mut() + .write_frames(&mut builder, &mut tokens, &mut stats); + assert_eq!(stats.data_blocked, 1); + } + + #[test] + fn data_blocked_atomic() { + let conn_fc = connection_fc(5); + let conn_events = ConnectionEvents::default(); + + let stream_id = StreamId::from(4); + let mut s = SendStream::new(stream_id, 2, Rc::clone(&conn_fc), conn_events); + + // Stream is initially blocked (conn:5, stream:2) + // and will not accept atomic write of 3 bytes. + assert_eq!(s.send_atomic(b"abc").unwrap(), 0); + + // Assert that STREAM_DATA_BLOCKED is sent. + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut tokens = Vec::new(); + let mut stats = FrameStats::default(); + s.write_blocked_frame( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.stream_data_blocked, 1); + + // Assert that a non-atomic write works. 
+ assert_eq!(s.send(b"abc").unwrap(), 2); + assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..]))); + s.mark_as_sent(0, 2, false); + + // Set limits to (conn:5, stream:10). + s.set_max_stream_data(10); + + // An atomic write of 4 bytes exceeds the remaining limit of 3. + assert_eq!(s.send_atomic(b"abcd").unwrap(), 0); + + // Assert that DATA_BLOCKED is sent. + conn_fc + .borrow_mut() + .write_frames(&mut builder, &mut tokens, &mut stats); + assert_eq!(stats.data_blocked, 1); + + // Check that a non-atomic write works. + assert_eq!(s.send(b"abcd").unwrap(), 3); + assert_eq!(s.next_bytes(false), Some((2, &b"abc"[..]))); + s.mark_as_sent(2, 3, false); + + // Increase limits to (conn:15, stream:15). + s.set_max_stream_data(15); + conn_fc.borrow_mut().update(15); + + // Check that atomic writing right up to the limit works. + assert_eq!(s.send_atomic(b"abcdefghij").unwrap(), 10); + } + + #[test] + fn ack_fin_first() { + const MESSAGE: &[u8] = b"hello"; + let len_u64 = u64::try_from(MESSAGE.len()).unwrap(); + + let conn_fc = connection_fc(len_u64); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(StreamId::new(100), 0, conn_fc, conn_events); + s.set_max_stream_data(len_u64); + + // Send all the data, then the fin. + _ = s.send(MESSAGE).unwrap(); + s.mark_as_sent(0, MESSAGE.len(), false); + s.close(); + s.mark_as_sent(len_u64, 0, true); + + // Ack the fin, then the data. + s.mark_as_acked(len_u64, 0, true); + s.mark_as_acked(0, MESSAGE.len(), false); + assert!(s.is_terminal()); + } + + #[test] + fn ack_then_lose_fin() { + const MESSAGE: &[u8] = b"hello"; + let len_u64 = u64::try_from(MESSAGE.len()).unwrap(); + + let conn_fc = connection_fc(len_u64); + let conn_events = ConnectionEvents::default(); + + let id = StreamId::new(100); + let mut s = SendStream::new(id, 0, conn_fc, conn_events); + s.set_max_stream_data(len_u64); + + // Send all the data, then the fin. 
+ _ = s.send(MESSAGE).unwrap(); + s.mark_as_sent(0, MESSAGE.len(), false); + s.close(); + s.mark_as_sent(len_u64, 0, true); + + // Ack the fin, then mark it lost. + s.mark_as_acked(len_u64, 0, true); + s.mark_as_lost(len_u64, 0, true); + + // No frame should be sent here. + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut tokens = Vec::new(); + let mut stats = FrameStats::default(); + s.write_stream_frame( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.stream, 0); + } + + /// Create a `SendStream` and force it into a state where it believes that + /// `offset` bytes have already been sent and acknowledged. + fn stream_with_sent(stream: u64, offset: usize) -> SendStream { + const MAX_VARINT: u64 = (1 << 62) - 1; + + let conn_fc = connection_fc(MAX_VARINT); + let mut s = SendStream::new( + StreamId::from(stream), + MAX_VARINT, + conn_fc, + ConnectionEvents::default(), + ); + + let mut send_buf = TxBuffer::new(); + send_buf.retired = u64::try_from(offset).unwrap(); + send_buf.ranges.mark_range(0, offset, RangeState::Acked); + let mut fc = SenderFlowControl::new(StreamId::from(stream), MAX_VARINT); + fc.consume(offset); + let conn_fc = Rc::new(RefCell::new(SenderFlowControl::new((), MAX_VARINT))); + s.state = SendStreamState::Send { + fc, + conn_fc, + send_buf, + }; + s + } + + fn frame_sent_sid(stream: u64, offset: usize, len: usize, fin: bool, space: usize) -> bool { + const BUF: &[u8] = &[0x42; 128]; + + qtrace!( + "frame_sent stream={} offset={} len={} fin={}, space={}", + stream, + offset, + len, + fin, + space + ); + + let mut s = stream_with_sent(stream, offset); + + // Now write out the proscribed data and maybe close. 
+ if len > 0 { + s.send(&BUF[..len]).unwrap(); + } + if fin { + s.close(); + } + + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let header_len = builder.len(); + builder.set_limit(header_len + space); + + let mut tokens = Vec::new(); + let mut stats = FrameStats::default(); + s.write_stream_frame( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut stats, + ); + qtrace!( + "STREAM frame: {}", + hex_with_len(&builder.as_ref()[header_len..]) + ); + stats.stream > 0 + } + + fn frame_sent(offset: usize, len: usize, fin: bool, space: usize) -> bool { + frame_sent_sid(0, offset, len, fin, space) + } + + #[test] + fn stream_frame_empty() { + // Stream frames with empty data and no fin never work. + assert!(!frame_sent(10, 0, false, 2)); + assert!(!frame_sent(10, 0, false, 3)); + assert!(!frame_sent(10, 0, false, 4)); + assert!(!frame_sent(10, 0, false, 5)); + assert!(!frame_sent(10, 0, false, 100)); + + // Empty data with fin is only a problem if there is no space. 
+ assert!(!frame_sent(0, 0, true, 1)); + assert!(frame_sent(0, 0, true, 2)); + assert!(!frame_sent(10, 0, true, 2)); + assert!(frame_sent(10, 0, true, 3)); + assert!(frame_sent(10, 0, true, 4)); + assert!(frame_sent(10, 0, true, 5)); + assert!(frame_sent(10, 0, true, 100)); + } + + #[test] + fn stream_frame_minimum() { + // Add minimum data + assert!(!frame_sent(10, 1, false, 3)); + assert!(!frame_sent(10, 1, true, 3)); + assert!(frame_sent(10, 1, false, 4)); + assert!(frame_sent(10, 1, true, 4)); + assert!(frame_sent(10, 1, false, 5)); + assert!(frame_sent(10, 1, true, 5)); + assert!(frame_sent(10, 1, false, 100)); + assert!(frame_sent(10, 1, true, 100)); + } + + #[test] + fn stream_frame_more() { + // Try more data + assert!(!frame_sent(10, 100, false, 3)); + assert!(!frame_sent(10, 100, true, 3)); + assert!(frame_sent(10, 100, false, 4)); + assert!(frame_sent(10, 100, true, 4)); + assert!(frame_sent(10, 100, false, 5)); + assert!(frame_sent(10, 100, true, 5)); + assert!(frame_sent(10, 100, false, 100)); + assert!(frame_sent(10, 100, true, 100)); + + assert!(frame_sent(10, 100, false, 1000)); + assert!(frame_sent(10, 100, true, 1000)); + } + + #[test] + fn stream_frame_big_id() { + // A value that encodes to the largest varint. 
+ const BIG: u64 = 1 << 30; + const BIGSZ: usize = 1 << 30; + + assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 16)); + assert!(!frame_sent_sid(BIG, BIGSZ, 0, true, 16)); + assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 17)); + assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 17)); + assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 18)); + assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 18)); + + assert!(!frame_sent_sid(BIG, BIGSZ, 1, false, 17)); + assert!(!frame_sent_sid(BIG, BIGSZ, 1, true, 17)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 18)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 18)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 19)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 19)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 100)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 100)); + } + + fn stream_frame_at_boundary(data: &[u8]) { + fn send_with_extra_capacity(data: &[u8], extra: usize, expect_full: bool) -> Vec<u8> { + qtrace!("send_with_extra_capacity {} + {}", data.len(), extra); + let mut s = stream_with_sent(0, 0); + s.send(data).unwrap(); + s.close(); + + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let header_len = builder.len(); + // Add 2 for the frame type and stream ID, then add the extra. + builder.set_limit(header_len + data.len() + 2 + extra); + let mut tokens = Vec::new(); + let mut stats = FrameStats::default(); + s.write_stream_frame( + TransmissionPriority::default(), + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.stream, 1); + assert_eq!(builder.is_full(), expect_full); + Vec::from(Encoder::from(builder)).split_off(header_len) + } + + // The minimum amount of extra space for getting another frame in. + let mut enc = Encoder::new(); + enc.encode_varint(u64::try_from(data.len()).unwrap()); + let len_buf = Vec::from(enc); + let minimum_extra = len_buf.len() + PacketBuilder::MINIMUM_FRAME_SIZE; + + // For anything short of the minimum extra, the frame should fill the packet. 
+ for i in 0..minimum_extra { + let frame = send_with_extra_capacity(data, i, true); + let (header, body) = frame.split_at(2); + assert_eq!(header, &[0b1001, 0]); + assert_eq!(body, data); + } + + // Once there is space for another packet AND a length field, + // then a length will be added. + let frame = send_with_extra_capacity(data, minimum_extra, false); + let (header, rest) = frame.split_at(2); + assert_eq!(header, &[0b1011, 0]); + let (len, body) = rest.split_at(len_buf.len()); + assert_eq!(len, &len_buf); + assert_eq!(body, data); + } + + /// 16383/16384 is an odd boundary in STREAM frame construction. + /// That is the boundary where a length goes from 2 bytes to 4 bytes. + /// Test that we correctly add a length field to the frame; and test + /// that if we don't, then we don't allow other frames to be added. + #[test] + fn stream_frame_16384() { + stream_frame_at_boundary(&[4; 16383]); + stream_frame_at_boundary(&[4; 16384]); + } + + /// 63/64 is the other odd boundary. + #[test] + fn stream_frame_64() { + stream_frame_at_boundary(&[2; 63]); + stream_frame_at_boundary(&[2; 64]); + } + + fn check_stats( + stream: &SendStream, + expected_written: u64, + expected_sent: u64, + expected_acked: u64, + ) { + let stream_stats = stream.stats(); + assert_eq!(stream_stats.bytes_written(), expected_written); + assert_eq!(stream_stats.bytes_sent(), expected_sent); + assert_eq!(stream_stats.bytes_acked(), expected_acked); + } + + #[test] + fn send_stream_stats() { + const MESSAGE: &[u8] = b"hello"; + let len_u64 = u64::try_from(MESSAGE.len()).unwrap(); + + let conn_fc = connection_fc(len_u64); + let conn_events = ConnectionEvents::default(); + + let id = StreamId::new(100); + let mut s = SendStream::new(id, 0, conn_fc, conn_events); + s.set_max_stream_data(len_u64); + + // Initial stats should be all 0. + check_stats(&s, 0, 0, 0); + // Adter sending the data, bytes_written should be increased. 
+ _ = s.send(MESSAGE).unwrap(); + check_stats(&s, len_u64, 0, 0); + + // Adter calling mark_as_sent, bytes_sent should be increased. + s.mark_as_sent(0, MESSAGE.len(), false); + check_stats(&s, len_u64, len_u64, 0); + + s.close(); + s.mark_as_sent(len_u64, 0, true); + + // In the end, check bytes_acked. + s.mark_as_acked(0, MESSAGE.len(), false); + check_stats(&s, len_u64, len_u64, len_u64); + + s.mark_as_acked(len_u64, 0, true); + assert!(s.is_terminal()); + } +} diff --git a/third_party/rust/neqo-transport/src/sender.rs b/third_party/rust/neqo-transport/src/sender.rs new file mode 100644 index 0000000000..9a00dfc7a7 --- /dev/null +++ b/third_party/rust/neqo-transport/src/sender.rs @@ -0,0 +1,130 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +use std::{ + fmt::{self, Debug, Display}, + time::{Duration, Instant}, +}; + +use neqo_common::qlog::NeqoQlog; + +use crate::{ + cc::{ClassicCongestionControl, CongestionControl, CongestionControlAlgorithm, Cubic, NewReno}, + pace::Pacer, + rtt::RttEstimate, + tracking::SentPacket, +}; + +/// The number of packets we allow to burst from the pacer. 
+pub const PACING_BURST_SIZE: usize = 2; + +#[derive(Debug)] +pub struct PacketSender { + cc: Box<dyn CongestionControl>, + pacer: Pacer, +} + +impl Display for PacketSender { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} {}", self.cc, self.pacer) + } +} + +impl PacketSender { + #[must_use] + pub fn new( + alg: CongestionControlAlgorithm, + pacing_enabled: bool, + mtu: usize, + now: Instant, + ) -> Self { + Self { + cc: match alg { + CongestionControlAlgorithm::NewReno => { + Box::new(ClassicCongestionControl::new(NewReno::default())) + } + CongestionControlAlgorithm::Cubic => { + Box::new(ClassicCongestionControl::new(Cubic::default())) + } + }, + pacer: Pacer::new(pacing_enabled, now, mtu * PACING_BURST_SIZE, mtu), + } + } + + pub fn set_qlog(&mut self, qlog: NeqoQlog) { + self.cc.set_qlog(qlog); + } + + #[must_use] + pub fn cwnd(&self) -> usize { + self.cc.cwnd() + } + + #[must_use] + pub fn cwnd_avail(&self) -> usize { + self.cc.cwnd_avail() + } + + pub fn on_packets_acked( + &mut self, + acked_pkts: &[SentPacket], + rtt_est: &RttEstimate, + now: Instant, + ) { + self.cc.on_packets_acked(acked_pkts, rtt_est, now); + } + + /// Called when packets are lost. Returns true if the congestion window was reduced. + pub fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) -> bool { + self.cc.on_packets_lost( + first_rtt_sample_time, + prev_largest_acked_sent, + pto, + lost_packets, + ) + } + + pub fn discard(&mut self, pkt: &SentPacket) { + self.cc.discard(pkt); + } + + /// When we migrate, the congestion controller for the previously active path drops + /// all bytes in flight. 
+ pub fn discard_in_flight(&mut self) { + self.cc.discard_in_flight(); + } + + pub fn on_packet_sent(&mut self, pkt: &SentPacket, rtt: Duration) { + self.pacer + .spend(pkt.time_sent, rtt, self.cc.cwnd(), pkt.size); + self.cc.on_packet_sent(pkt); + } + + #[must_use] + pub fn next_paced(&self, rtt: Duration) -> Option<Instant> { + // Only pace if there are bytes in flight. + if self.cc.bytes_in_flight() > 0 { + Some(self.pacer.next(rtt, self.cc.cwnd())) + } else { + None + } + } + + #[must_use] + pub fn recovery_packet(&self) -> bool { + self.cc.recovery_packet() + } +} diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs new file mode 100644 index 0000000000..12a7d2f9e0 --- /dev/null +++ b/third_party/rust/neqo-transport/src/server.rs @@ -0,0 +1,782 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// This file implements a server that can handle multiple connections. 
use std::{
    cell::RefCell,
    collections::{HashMap, HashSet, VecDeque},
    fs::OpenOptions,
    mem,
    net::SocketAddr,
    ops::{Deref, DerefMut},
    path::PathBuf,
    rc::{Rc, Weak},
    time::{Duration, Instant},
};

use neqo_common::{
    self as common, event::Provider, hex, qdebug, qerror, qinfo, qlog::NeqoQlog, qtrace, qwarn,
    timer::Timer, Datagram, Decoder, Role,
};
use neqo_crypto::{
    encode_ech_config, AntiReplay, Cipher, PrivateKey, PublicKey, ZeroRttCheckResult,
    ZeroRttChecker,
};
use qlog::streamer::QlogStreamer;

pub use crate::addr_valid::ValidateAddress;
use crate::{
    addr_valid::{AddressValidation, AddressValidationResult},
    cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef},
    connection::{Connection, Output, State},
    packet::{PacketBuilder, PacketType, PublicPacket},
    ConnectionParameters, Res, Version,
};

/// The possible outcomes of handling an Initial packet.
// NOTE(review): not referenced in the visible part of this file; the payload
// of `Retry` is presumably the retry token. Confirm against callers.
pub enum InitialResult {
    Accept,
    Drop,
    Retry(Vec<u8>),
}

/// `MIN_INITIAL_PACKET_SIZE` is the smallest packet that can be used to establish
/// a new connection across all QUIC versions this server supports.
const MIN_INITIAL_PACKET_SIZE: usize = 1200;
/// The size of timer buckets.  This is higher than the actual timer granularity
/// as this depends on there being some distribution of events.
const TIMER_GRANULARITY: Duration = Duration::from_millis(4);
/// The number of buckets in the timer.  As mentioned in the definition of `Timer`,
/// the granularity and capacity need to multiply to be larger than the largest
/// delay that might be used.  That's the idle timeout (currently 30s).
const TIMER_CAPACITY: usize = 16384;

/// A shared, mutable handle to the per-connection state kept by the server.
type StateRef = Rc<RefCell<ServerConnectionState>>;
/// A shared table that maps connection IDs to connection state.
type ConnectionTableRef = Rc<RefCell<HashMap<ConnectionId, StateRef>>>;

/// A `Connection` together with the bookkeeping the server keeps for it.
/// It dereferences to the wrapped `Connection`.
#[derive(Debug)]
pub struct ServerConnectionState {
    /// The wrapped connection.
    c: Connection,
    /// The attempt key for this connection; `Server::process_connection`
    /// takes it once the connection state is past `State::Handshaking`.
    active_attempt: Option<AttemptKey>,
    /// The time of the timer most recently recorded for this connection.
    last_timer: Instant,
}

impl Deref for ServerConnectionState {
    type Target = Connection;
    fn deref(&self) -> &Self::Target {
        &self.c
    }
}

impl DerefMut for ServerConnectionState {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.c
    }
}

/// An `AttemptKey` is used to disambiguate connection attempts.
/// Multiple connection attempts with the same key won't produce multiple connections.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct AttemptKey {
    // Using the remote address is sufficient for disambiguation,
    // until we support multiple local socket addresses.
    remote_address: SocketAddr,
    // The original destination connection ID of the attempt.
    odcid: ConnectionId,
}

/// A `ServerZeroRttChecker` is a simple wrapper around a single checker.
/// It uses `RefCell` so that the wrapped checker can be shared between
/// multiple connections created by the server.
#[derive(Clone, Debug)]
struct ServerZeroRttChecker {
    /// The wrapped checker; cloning this struct clones the `Rc`, not the checker.
    checker: Rc<RefCell<Box<dyn ZeroRttChecker>>>,
}

impl ServerZeroRttChecker {
    /// Wraps `checker` so that it can be shared.
    pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self {
        Self {
            checker: Rc::new(RefCell::new(checker)),
        }
    }
}

impl ZeroRttChecker for ServerZeroRttChecker {
    /// Delegates to the wrapped checker.
    fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
        self.checker.borrow().check(token)
    }
}

/// `InitialDetails` holds important information for processing `Initial` packets.
struct InitialDetails {
    /// The source connection ID of the packet.
    src_cid: ConnectionId,
    /// The destination connection ID of the packet.
    dst_cid: ConnectionId,
    /// A copy of the token carried by the packet.
    token: Vec<u8>,
    /// The QUIC version of the packet.
    version: Version,
}

impl InitialDetails {
    /// Copies the connection IDs, token and version out of `packet`.
    ///
    /// # Panics
    ///
    /// Panics if `packet.version()` is `None`.
    fn new(packet: &PublicPacket) -> Self {
        let src_cid = ConnectionId::from(packet.scid());
        let dst_cid = ConnectionId::from(packet.dcid());
        let token = packet.token().to_vec();
        let version = packet.version().unwrap();
        Self {
            src_cid,
            dst_cid,
            token,
            version,
        }
    }
}

/// The ECH settings passed to `Server::enable_ech`, plus their encoded form.
struct EchConfig {
    config: u8,
    public_name: String,
    sk: PrivateKey,
    pk: PublicKey,
    /// The output of `encode_ech_config` for the values above.
    encoded: Vec<u8>,
}

impl EchConfig {
    /// Encodes the configuration and stores owned copies of the inputs.
    ///
    /// # Errors
    ///
    /// Returns an error if `encode_ech_config` fails.
    fn new(config: u8, public_name: &str, sk: &PrivateKey, pk: &PublicKey) -> Res<Self> {
        let encoded = encode_ech_config(config, public_name, pk)?;
        let cfg = Self {
            config,
            public_name: public_name.to_owned(),
            sk: sk.clone(),
            pk: pk.clone(),
            encoded,
        };
        Ok(cfg)
    }
}

pub struct Server {
    /// The names of certificates.
    certs: Vec<String>,
    /// The ALPN values that the server supports.
    protocols: Vec<String>,
    /// The cipher suites that the server supports.
    ciphers: Vec<Cipher>,
    /// Anti-replay configuration for 0-RTT.
    anti_replay: AntiReplay,
    /// A function for determining if 0-RTT can be accepted.
    zero_rtt_checker: ServerZeroRttChecker,
    /// A connection ID generator.
    cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
    /// Connection parameters.
    conn_params: ConnectionParameters,
    /// Active connection attempts, keyed by `AttemptKey`.  Initial packets with
    /// the same key are routed to the connection that was first accepted.
    /// This is cleared out when the connection is closed or established.
    active_attempts: HashMap<AttemptKey, StateRef>,
    /// All connections, keyed by `ConnectionId`.
    connections: ConnectionTableRef,
    /// The connections that have new events.
    active: HashSet<ActiveConnectionRef>,
    /// The set of connections that need immediate processing.
    waiting: VecDeque<StateRef>,
    /// Outstanding timers for connections.
    timers: Timer<StateRef>,
    /// Address validation logic, which determines whether we send a Retry.
    address_validation: Rc<RefCell<AddressValidation>>,
    /// Directory to create qlog traces in
    qlog_dir: Option<PathBuf>,
    /// Encrypted client hello (ECH) configuration.
    ech_config: Option<EchConfig>,
}

impl Server {
    /// Construct a new server.
    /// * `now` is the time that the server is instantiated.
    /// * `certs` is a list of the certificates that should be configured.
    /// * `protocols` is the preference list of ALPN values.
    /// * `anti_replay` is an anti-replay context.
    /// * `zero_rtt_checker` determines whether 0-RTT should be accepted. This will be passed the
    ///   value of the `extra` argument that was passed to `Connection::send_ticket` to see if it is
    ///   OK.
    /// * `cid_generator` is responsible for generating connection IDs and parsing them; connection
    ///   IDs produced by the manager cannot be zero-length.
    /// * `conn_params` are the connection parameters stored on the server.
    ///
    /// The server starts with address validation set to `ValidateAddress::Never`,
    /// no cipher suites configured, no qlog directory and no ECH configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if `AddressValidation::new` fails.
    pub fn new(
        now: Instant,
        certs: &[impl AsRef<str>],
        protocols: &[impl AsRef<str>],
        anti_replay: AntiReplay,
        zero_rtt_checker: Box<dyn ZeroRttChecker>,
        cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
        conn_params: ConnectionParameters,
    ) -> Res<Self> {
        let validation = AddressValidation::new(now, ValidateAddress::Never)?;
        let certs: Vec<String> = certs.iter().map(|c| c.as_ref().to_owned()).collect();
        let protocols: Vec<String> = protocols.iter().map(|p| p.as_ref().to_owned()).collect();
        let server = Self {
            certs,
            protocols,
            ciphers: Vec::new(),
            anti_replay,
            zero_rtt_checker: ServerZeroRttChecker::new(zero_rtt_checker),
            cid_generator,
            conn_params,
            active_attempts: HashMap::default(),
            connections: Rc::default(),
            active: HashSet::default(),
            waiting: VecDeque::default(),
            timers: Timer::new(now, TIMER_GRANULARITY, TIMER_CAPACITY),
            address_validation: Rc::new(RefCell::new(validation)),
            qlog_dir: None,
            ech_config: None,
        };
        Ok(server)
    }

    /// Set or clear directory to create logs of connection events in QLOG format.
+ pub fn set_qlog_dir(&mut self, dir: Option<PathBuf>) { + self.qlog_dir = dir; + } + + /// Set the policy for address validation. + pub fn set_validation(&mut self, v: ValidateAddress) { + self.address_validation.borrow_mut().set_validation(v); + } + + /// Set the cipher suites that should be used. Set an empty value to use + /// default values. + pub fn set_ciphers(&mut self, ciphers: impl AsRef<[Cipher]>) { + self.ciphers = Vec::from(ciphers.as_ref()); + } + + pub fn enable_ech( + &mut self, + config: u8, + public_name: &str, + sk: &PrivateKey, + pk: &PublicKey, + ) -> Res<()> { + self.ech_config = Some(EchConfig::new(config, public_name, sk, pk)?); + Ok(()) + } + + pub fn ech_config(&self) -> &[u8] { + self.ech_config.as_ref().map_or(&[], |cfg| &cfg.encoded) + } + + fn remove_timer(&mut self, c: &StateRef) { + let last = c.borrow().last_timer; + self.timers.remove(last, |t| Rc::ptr_eq(t, c)); + } + + fn process_connection( + &mut self, + c: StateRef, + dgram: Option<&Datagram>, + now: Instant, + ) -> Option<Datagram> { + qtrace!([self], "Process connection {:?}", c); + let out = c.borrow_mut().process(dgram, now); + match out { + Output::Datagram(_) => { + qtrace!([self], "Sending packet, added to waiting connections"); + self.waiting.push_back(Rc::clone(&c)); + } + Output::Callback(delay) => { + let next = now + delay; + if next != c.borrow().last_timer { + qtrace!([self], "Change timer to {:?}", next); + self.remove_timer(&c); + c.borrow_mut().last_timer = next; + self.timers.add(next, Rc::clone(&c)); + } + } + Output::None => { + self.remove_timer(&c); + } + } + if c.borrow().has_events() { + qtrace!([self], "Connection active: {:?}", c); + self.active.insert(ActiveConnectionRef { c: Rc::clone(&c) }); + } + + if *c.borrow().state() > State::Handshaking { + // Remove any active connection attempt now that this is no longer handshaking. 
+ if let Some(k) = c.borrow_mut().active_attempt.take() { + self.active_attempts.remove(&k); + } + } + + if matches!(c.borrow().state(), State::Closed(_)) { + c.borrow_mut().set_qlog(NeqoQlog::disabled()); + self.connections + .borrow_mut() + .retain(|_, v| !Rc::ptr_eq(v, &c)); + } + out.dgram() + } + + fn connection(&self, cid: ConnectionIdRef) -> Option<StateRef> { + self.connections.borrow().get(&cid[..]).map(Rc::clone) + } + + fn handle_initial( + &mut self, + initial: InitialDetails, + dgram: &Datagram, + now: Instant, + ) -> Option<Datagram> { + qdebug!([self], "Handle initial"); + let res = self + .address_validation + .borrow() + .validate(&initial.token, dgram.source(), now); + match res { + AddressValidationResult::Invalid => None, + AddressValidationResult::Pass => self.connection_attempt(initial, dgram, None, now), + AddressValidationResult::ValidRetry(orig_dcid) => { + self.connection_attempt(initial, dgram, Some(orig_dcid), now) + } + AddressValidationResult::Validate => { + qinfo!([self], "Send retry for {:?}", initial.dst_cid); + + let res = self.address_validation.borrow().generate_retry_token( + &initial.dst_cid, + dgram.source(), + now, + ); + let Ok(token) = res else { + qerror!([self], "unable to generate token, dropping packet"); + return None; + }; + if let Some(new_dcid) = self.cid_generator.borrow_mut().generate_cid() { + let packet = PacketBuilder::retry( + initial.version, + &initial.src_cid, + &new_dcid, + &token, + &initial.dst_cid, + ); + if let Ok(p) = packet { + let retry = Datagram::new( + dgram.destination(), + dgram.source(), + dgram.tos(), + dgram.ttl(), + p, + ); + Some(retry) + } else { + qerror!([self], "unable to encode retry, dropping packet"); + None + } + } else { + qerror!([self], "no connection ID for retry, dropping packet"); + None + } + } + } + } + + fn connection_attempt( + &mut self, + initial: InitialDetails, + dgram: &Datagram, + orig_dcid: Option<ConnectionId>, + now: Instant, + ) -> Option<Datagram> { + let 
attempt_key = AttemptKey { + remote_address: dgram.source(), + odcid: orig_dcid.as_ref().unwrap_or(&initial.dst_cid).clone(), + }; + if let Some(c) = self.active_attempts.get(&attempt_key) { + qdebug!( + [self], + "Handle Initial for existing connection attempt {:?}", + attempt_key + ); + let c = Rc::clone(c); + self.process_connection(c, Some(dgram), now) + } else { + self.accept_connection(attempt_key, initial, dgram, orig_dcid, now) + } + } + + fn create_qlog_trace(&self, odcid: ConnectionIdRef<'_>) -> NeqoQlog { + if let Some(qlog_dir) = &self.qlog_dir { + let mut qlog_path = qlog_dir.to_path_buf(); + + qlog_path.push(format!("{}.qlog", odcid)); + + // The original DCID is chosen by the client. Using create_new() + // prevents attackers from overwriting existing logs. + match OpenOptions::new() + .write(true) + .create_new(true) + .open(&qlog_path) + { + Ok(f) => { + qinfo!("Qlog output to {}", qlog_path.display()); + + let streamer = QlogStreamer::new( + qlog::QLOG_VERSION.to_string(), + Some("Neqo server qlog".to_string()), + Some("Neqo server qlog".to_string()), + None, + std::time::Instant::now(), + common::qlog::new_trace(Role::Server), + qlog::events::EventImportance::Base, + Box::new(f), + ); + let n_qlog = NeqoQlog::enabled(streamer, qlog_path); + match n_qlog { + Ok(nql) => nql, + Err(e) => { + // Keep going but w/o qlogging + qerror!("NeqoQlog error: {}", e); + NeqoQlog::disabled() + } + } + } + Err(e) => { + qerror!( + "Could not open file {} for qlog output: {}", + qlog_path.display(), + e + ); + NeqoQlog::disabled() + } + } + } else { + NeqoQlog::disabled() + } + } + + fn setup_connection( + &mut self, + c: &mut Connection, + attempt_key: &AttemptKey, + initial: InitialDetails, + orig_dcid: Option<ConnectionId>, + ) { + let zcheck = self.zero_rtt_checker.clone(); + if c.server_enable_0rtt(&self.anti_replay, zcheck).is_err() { + qwarn!([self], "Unable to enable 0-RTT"); + } + if let Some(odcid) = orig_dcid { + // There was a retry, so set the 
connection IDs for. + c.set_retry_cids(odcid, initial.src_cid, initial.dst_cid); + } + c.set_validation(Rc::clone(&self.address_validation)); + c.set_qlog(self.create_qlog_trace(attempt_key.odcid.as_cid_ref())); + if let Some(cfg) = &self.ech_config { + if c.server_enable_ech(cfg.config, &cfg.public_name, &cfg.sk, &cfg.pk) + .is_err() + { + qwarn!([self], "Unable to enable ECH"); + } + } + } + + fn accept_connection( + &mut self, + attempt_key: AttemptKey, + initial: InitialDetails, + dgram: &Datagram, + orig_dcid: Option<ConnectionId>, + now: Instant, + ) -> Option<Datagram> { + qinfo!([self], "Accept connection {:?}", attempt_key); + // The internal connection ID manager that we use is not used directly. + // Instead, wrap it so that we can save connection IDs. + + let cid_mgr = Rc::new(RefCell::new(ServerConnectionIdGenerator { + c: Weak::new(), + cid_generator: Rc::clone(&self.cid_generator), + connections: Rc::clone(&self.connections), + saved_cids: Vec::new(), + })); + + let mut params = self.conn_params.clone(); + params.get_versions_mut().set_initial(initial.version); + let sconn = Connection::new_server( + &self.certs, + &self.protocols, + Rc::clone(&cid_mgr) as _, + params, + ); + + match sconn { + Ok(mut c) => { + self.setup_connection(&mut c, &attempt_key, initial, orig_dcid); + let c = Rc::new(RefCell::new(ServerConnectionState { + c, + last_timer: now, + active_attempt: Some(attempt_key.clone()), + })); + cid_mgr.borrow_mut().set_connection(Rc::clone(&c)); + let previous_attempt = self.active_attempts.insert(attempt_key, Rc::clone(&c)); + debug_assert!(previous_attempt.is_none()); + self.process_connection(c, Some(dgram), now) + } + Err(e) => { + qwarn!([self], "Unable to create connection"); + if e == crate::Error::VersionNegotiation { + crate::qlog::server_version_information_failed( + &mut self.create_qlog_trace(attempt_key.odcid.as_cid_ref()), + self.conn_params.get_versions().all(), + initial.version.wire_version(), + ) + } + None + } + } + } + + 
/// Handle 0-RTT packets that were sent with the client's choice of connection ID. + /// Most 0-RTT will arrive this way. A client can usually send 1-RTT after it + /// receives a connection ID from the server. + fn handle_0rtt( + &mut self, + dgram: &Datagram, + dcid: ConnectionId, + now: Instant, + ) -> Option<Datagram> { + let attempt_key = AttemptKey { + remote_address: dgram.source(), + odcid: dcid, + }; + if let Some(c) = self.active_attempts.get(&attempt_key) { + qdebug!( + [self], + "Handle 0-RTT for existing connection attempt {:?}", + attempt_key + ); + let c = Rc::clone(c); + self.process_connection(c, Some(dgram), now) + } else { + qdebug!([self], "Dropping 0-RTT for unknown connection"); + None + } + } + + fn process_input(&mut self, dgram: &Datagram, now: Instant) -> Option<Datagram> { + qtrace!("Process datagram: {}", hex(&dgram[..])); + + // This is only looking at the first packet header in the datagram. + // All packets in the datagram are routed to the same connection. + let res = PublicPacket::decode(&dgram[..], self.cid_generator.borrow().as_decoder()); + let Ok((packet, _remainder)) = res else { + qtrace!([self], "Discarding {:?}", dgram); + return None; + }; + + // Finding an existing connection. Should be the most common case. + if let Some(c) = self.connection(packet.dcid()) { + return self.process_connection(c, Some(dgram), now); + } + + if packet.packet_type() == PacketType::Short { + // TODO send a stateless reset here. 
+ qtrace!([self], "Short header packet for an unknown connection"); + return None; + } + + if packet.packet_type() == PacketType::OtherVersion + || (packet.packet_type() == PacketType::Initial + && !self + .conn_params + .get_versions() + .all() + .contains(&packet.version().unwrap())) + { + if dgram.len() < MIN_INITIAL_PACKET_SIZE { + qdebug!([self], "Unsupported version: too short"); + return None; + } + + qdebug!([self], "Unsupported version: {:x}", packet.wire_version()); + let vn = PacketBuilder::version_negotiation( + &packet.scid()[..], + &packet.dcid()[..], + packet.wire_version(), + self.conn_params.get_versions().all(), + ); + + crate::qlog::server_version_information_failed( + &mut self.create_qlog_trace(packet.dcid()), + self.conn_params.get_versions().all(), + packet.wire_version(), + ); + + return Some(Datagram::new( + dgram.destination(), + dgram.source(), + dgram.tos(), + dgram.ttl(), + vn, + )); + } + + match packet.packet_type() { + PacketType::Initial => { + if dgram.len() < MIN_INITIAL_PACKET_SIZE { + qdebug!([self], "Drop initial: too short"); + return None; + } + // Copy values from `packet` because they are currently still borrowing from + // `dgram`. + let initial = InitialDetails::new(&packet); + self.handle_initial(initial, dgram, now) + } + PacketType::ZeroRtt => { + let dcid = ConnectionId::from(packet.dcid()); + self.handle_0rtt(dgram, dcid, now) + } + PacketType::OtherVersion => unreachable!(), + _ => { + qtrace!([self], "Not an initial packet"); + None + } + } + } + + /// Iterate through the pending connections looking for any that might want + /// to send a datagram. Stop at the first one that does. 
+ fn process_next_output(&mut self, now: Instant) -> Option<Datagram> { + qtrace!([self], "No packet to send, look at waiting connections"); + while let Some(c) = self.waiting.pop_front() { + if let Some(d) = self.process_connection(c, None, now) { + return Some(d); + } + } + qtrace!([self], "No packet to send still, run timers"); + while let Some(c) = self.timers.take_next(now) { + if let Some(d) = self.process_connection(c, None, now) { + return Some(d); + } + } + None + } + + fn next_time(&mut self, now: Instant) -> Option<Duration> { + if self.waiting.is_empty() { + self.timers.next_time().map(|x| x - now) + } else { + Some(Duration::new(0, 0)) + } + } + + pub fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output { + dgram + .and_then(|d| self.process_input(d, now)) + .or_else(|| self.process_next_output(now)) + .map(|d| { + qtrace!([self], "Send packet: {:?}", d); + Output::Datagram(d) + }) + .or_else(|| { + self.next_time(now).map(|delay| { + qtrace!([self], "Wait: {:?}", delay); + Output::Callback(delay) + }) + }) + .unwrap_or_else(|| { + qtrace!([self], "Go dormant"); + Output::None + }) + } + + /// This lists the connections that have received new events + /// as a result of calling `process()`. 
pub fn active_connections(&mut self) -> Vec<ActiveConnectionRef> {
    // Taking the set empties it, so each connection is reported once per
    // batch of events.
    mem::take(&mut self.active).into_iter().collect()
}

/// Queue a connection so that the next `process()` call runs it.
pub fn add_to_waiting(&mut self, c: ActiveConnectionRef) {
    self.waiting.push_back(c.connection());
}
}

/// A handle to a connection owned by the server.
///
/// Equality and hashing are by identity of the shared connection state,
/// not by its contents.
#[derive(Clone, Debug)]
pub struct ActiveConnectionRef {
    c: StateRef,
}

impl ActiveConnectionRef {
    /// Borrow the underlying `Connection` immutably.
    pub fn borrow(&self) -> impl Deref<Target = Connection> + '_ {
        std::cell::Ref::map(self.c.borrow(), |c| &c.c)
    }

    /// Borrow the underlying `Connection` mutably.
    pub fn borrow_mut(&mut self) -> impl DerefMut<Target = Connection> + '_ {
        std::cell::RefMut::map(self.c.borrow_mut(), |c| &mut c.c)
    }

    /// Another reference to the shared connection state.
    pub fn connection(&self) -> StateRef {
        Rc::clone(&self.c)
    }
}

impl std::hash::Hash for ActiveConnectionRef {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the address of the shared state, matching the pointer
        // equality used by `PartialEq`.
        Rc::as_ptr(&self.c).hash(state);
    }
}

impl PartialEq for ActiveConnectionRef {
    fn eq(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.c, &other.c)
    }
}

impl Eq for ActiveConnectionRef {}

/// Wraps the server's connection ID generator so that every connection ID
/// handed to a connection is also entered into the server's connection table.
struct ServerConnectionIdGenerator {
    /// The connection that generated IDs are registered for; unset until
    /// `set_connection` is called.
    c: Weak<RefCell<ServerConnectionState>>,
    /// The server's table mapping connection IDs to connections.
    connections: ConnectionTableRef,
    /// The generator shared with the server.
    cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
    /// Connection IDs generated before the connection was set.
    saved_cids: Vec<ConnectionId>,
}

impl ServerConnectionIdGenerator {
    /// Attach the connection and register every connection ID that was
    /// generated before it was available.
    pub fn set_connection(&mut self, c: StateRef) {
        for cid in std::mem::take(&mut self.saved_cids) {
            qtrace!("ServerConnectionIdGenerator inserting saved cid {}", cid);
            self.insert_cid(cid, Rc::clone(&c));
        }
        self.c = Rc::downgrade(&c);
    }

    fn insert_cid(&mut self, cid: ConnectionId, rc: StateRef) {
        debug_assert!(!cid.is_empty());
        self.connections.borrow_mut().insert(cid, rc);
    }
}

impl ConnectionIdDecoder for ServerConnectionIdGenerator {
    fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
        // Decoding only needs shared access. A mutable borrow here would
        // panic whenever the shared generator is already borrowed.
        self.cid_generator.borrow().decode_cid(dec)
    }
}

impl ConnectionIdGenerator for ServerConnectionIdGenerator {
    fn generate_cid(&mut self) -> Option<ConnectionId> {
        let cid = self.cid_generator.borrow_mut().generate_cid()?;
        if let Some(rc) = self.c.upgrade() {
            self.insert_cid(cid.clone(), rc);
        } else {
            // This function can be called before the connection is set.
            // So save any connection IDs until that hookup happens.
            qtrace!("ServerConnectionIdGenerator saving cid {}", cid);
            self.saved_cids.push(cid.clone());
        }
        Some(cid)
    }

    fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
        self
    }
}

impl ::std::fmt::Display for Server {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "Server")
    }
}

// diff --git a/third_party/rust/neqo-transport/src/stats.rs b/third_party/rust/neqo-transport/src/stats.rs
// new file mode 100644
// index 0000000000..d6c7a911f9
// --- /dev/null
// +++ b/third_party/rust/neqo-transport/src/stats.rs
// @@ -0,0 +1,235 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// Tracking of some useful statistics.
+#![deny(clippy::pedantic)] + +use std::{ + cell::RefCell, + fmt::{self, Debug}, + ops::Deref, + rc::Rc, + time::Duration, +}; + +use neqo_common::qinfo; + +use crate::packet::PacketNumber; + +pub(crate) const MAX_PTO_COUNTS: usize = 16; + +#[derive(Default, Clone)] +#[cfg_attr(test, derive(PartialEq, Eq))] +#[allow(clippy::module_name_repetitions)] +pub struct FrameStats { + pub all: usize, + pub ack: usize, + pub largest_acknowledged: PacketNumber, + + pub crypto: usize, + pub stream: usize, + pub reset_stream: usize, + pub stop_sending: usize, + + pub ping: usize, + pub padding: usize, + + pub max_streams: usize, + pub streams_blocked: usize, + pub max_data: usize, + pub data_blocked: usize, + pub max_stream_data: usize, + pub stream_data_blocked: usize, + + pub new_connection_id: usize, + pub retire_connection_id: usize, + + pub path_challenge: usize, + pub path_response: usize, + + pub connection_close: usize, + pub handshake_done: usize, + pub new_token: usize, + + pub ack_frequency: usize, + pub datagram: usize, +} + +impl Debug for FrameStats { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + " crypto {} done {} token {} close {}", + self.crypto, self.handshake_done, self.new_token, self.connection_close, + )?; + writeln!( + f, + " ack {} (max {}) ping {} padding {}", + self.ack, self.largest_acknowledged, self.ping, self.padding + )?; + writeln!( + f, + " stream {} reset {} stop {}", + self.stream, self.reset_stream, self.stop_sending, + )?; + writeln!( + f, + " max: stream {} data {} stream_data {}", + self.max_streams, self.max_data, self.max_stream_data, + )?; + writeln!( + f, + " blocked: stream {} data {} stream_data {}", + self.streams_blocked, self.data_blocked, self.stream_data_blocked, + )?; + writeln!(f, " datagram {}", self.datagram)?; + writeln!( + f, + " ncid {} rcid {} pchallenge {} presponse {}", + self.new_connection_id, + self.retire_connection_id, + self.path_challenge, + self.path_response, + )?; + writeln!(f, 
" ack_frequency {}", self.ack_frequency) + } +} + +/// Datagram stats +#[derive(Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct DatagramStats { + /// The number of datagrams declared lost. + pub lost: usize, + /// The number of datagrams dropped due to being too large. + pub dropped_too_big: usize, + /// The number of datagrams dropped due to reaching the limit of the + /// outgoing queue. + pub dropped_queue_full: usize, +} + +/// Connection statistics +#[derive(Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct Stats { + info: String, + + /// Total packets received, including all the bad ones. + pub packets_rx: usize, + /// Duplicate packets received. + pub dups_rx: usize, + /// Dropped packets or dropped garbage. + pub dropped_rx: usize, + /// The number of packet that were saved for later processing. + pub saved_datagrams: usize, + + /// Total packets sent. + pub packets_tx: usize, + /// Total number of packets that are declared lost. + pub lost: usize, + /// Late acknowledgments, for packets that were declared lost already. + pub late_ack: usize, + /// Acknowledgments for packets that contained data that was marked + /// for retransmission when the PTO timer popped. + pub pto_ack: usize, + + /// Whether the connection was resumed successfully. + pub resumed: bool, + + /// The current, estimated round-trip time on the primary path. + pub rtt: Duration, + /// The current, estimated round-trip time variation on the primary path. + pub rttvar: Duration, + /// Whether the first RTT sample was guessed from a discarded packet. + pub rtt_init_guess: bool, + + /// Count PTOs. Single PTOs, 2 PTOs in a row, 3 PTOs in row, etc. are counted + /// separately. + pub pto_counts: [usize; MAX_PTO_COUNTS], + + /// Count frames received. + pub frame_rx: FrameStats, + /// Count frames sent. + pub frame_tx: FrameStats, + + /// The number of incoming datagrams dropped due to reaching the limit + /// of the incoming queue. 
+ pub incoming_datagram_dropped: usize, + + pub datagram_tx: DatagramStats, +} + +impl Stats { + pub fn init(&mut self, info: String) { + self.info = info; + } + + pub fn pkt_dropped(&mut self, reason: impl AsRef<str>) { + self.dropped_rx += 1; + qinfo!( + [self.info], + "Dropped received packet: {}; Total: {}", + reason.as_ref(), + self.dropped_rx + ); + } + + /// # Panics + /// + /// When preconditions are violated. + pub fn add_pto_count(&mut self, count: usize) { + debug_assert!(count > 0); + if count >= MAX_PTO_COUNTS { + // We can't move this count any further, so stop. + return; + } + self.pto_counts[count - 1] += 1; + if count > 1 { + debug_assert!(self.pto_counts[count - 2] > 0); + self.pto_counts[count - 2] -= 1; + } + } +} + +impl Debug for Stats { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "stats for {}", self.info)?; + writeln!( + f, + " rx: {} drop {} dup {} saved {}", + self.packets_rx, self.dropped_rx, self.dups_rx, self.saved_datagrams + )?; + writeln!( + f, + " tx: {} lost {} lateack {} ptoack {}", + self.packets_tx, self.lost, self.late_ack, self.pto_ack + )?; + writeln!(f, " resumed: {} ", self.resumed)?; + writeln!(f, " frames rx:")?; + self.frame_rx.fmt(f)?; + writeln!(f, " frames tx:")?; + self.frame_tx.fmt(f) + } +} + +#[derive(Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct StatsCell { + stats: Rc<RefCell<Stats>>, +} + +impl Deref for StatsCell { + type Target = RefCell<Stats>; + fn deref(&self) -> &Self::Target { + &self.stats + } +} + +impl Debug for StatsCell { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.stats.borrow().fmt(f) + } +} diff --git a/third_party/rust/neqo-transport/src/stream_id.rs b/third_party/rust/neqo-transport/src/stream_id.rs new file mode 100644 index 0000000000..f3b07b86a8 --- /dev/null +++ b/third_party/rust/neqo-transport/src/stream_id.rs @@ -0,0 +1,177 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// 
http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Stream ID and stream index handling. + +use neqo_common::Role; + +#[derive(PartialEq, Debug, Copy, Clone, PartialOrd, Eq, Ord, Hash)] + +/// The type of stream, either Bi-Directional or Uni-Directional. +pub enum StreamType { + BiDi, + UniDi, +} + +#[derive(Debug, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)] +pub struct StreamId(u64); + +impl StreamId { + pub const fn new(id: u64) -> Self { + Self(id) + } + + pub fn init(stream_type: StreamType, role: Role) -> Self { + let type_val = match stream_type { + StreamType::BiDi => 0, + StreamType::UniDi => 2, + }; + Self(type_val + Self::role_bit(role)) + } + + pub fn as_u64(self) -> u64 { + self.0 + } + + pub fn is_bidi(self) -> bool { + self.as_u64() & 0x02 == 0 + } + + pub fn is_uni(self) -> bool { + !self.is_bidi() + } + + pub fn stream_type(self) -> StreamType { + if self.is_bidi() { + StreamType::BiDi + } else { + StreamType::UniDi + } + } + + pub fn is_client_initiated(self) -> bool { + self.as_u64() & 0x01 == 0 + } + + pub fn is_server_initiated(self) -> bool { + !self.is_client_initiated() + } + + pub fn role(self) -> Role { + if self.is_client_initiated() { + Role::Client + } else { + Role::Server + } + } + + pub fn is_self_initiated(self, my_role: Role) -> bool { + match my_role { + Role::Client if self.is_client_initiated() => true, + Role::Server if self.is_server_initiated() => true, + _ => false, + } + } + + pub fn is_remote_initiated(self, my_role: Role) -> bool { + !self.is_self_initiated(my_role) + } + + pub fn is_send_only(self, my_role: Role) -> bool { + self.is_uni() && self.is_self_initiated(my_role) + } + + pub fn is_recv_only(self, my_role: Role) -> bool { + self.is_uni() && self.is_remote_initiated(my_role) + } + + pub fn next(&mut self) { + self.0 += 4; + 
} + + /// This returns a bit that is shared by all streams created by this role. + pub fn role_bit(role: Role) -> u64 { + match role { + Role::Server => 1, + Role::Client => 0, + } + } +} + +impl From<u64> for StreamId { + fn from(val: u64) -> Self { + Self::new(val) + } +} + +impl From<&u64> for StreamId { + fn from(val: &u64) -> Self { + Self::new(*val) + } +} + +impl PartialEq<u64> for StreamId { + fn eq(&self, other: &u64) -> bool { + self.as_u64() == *other + } +} + +impl AsRef<u64> for StreamId { + fn as_ref(&self) -> &u64 { + &self.0 + } +} + +impl ::std::fmt::Display for StreamId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}", self.as_u64()) + } +} + +#[cfg(test)] +mod test { + use neqo_common::Role; + + use super::StreamId; + + #[test] + fn bidi_stream_properties() { + let id1 = StreamId::from(16); + assert!(id1.is_bidi()); + assert!(!id1.is_uni()); + assert!(id1.is_client_initiated()); + assert!(!id1.is_server_initiated()); + assert_eq!(id1.role(), Role::Client); + assert!(id1.is_self_initiated(Role::Client)); + assert!(!id1.is_self_initiated(Role::Server)); + assert!(!id1.is_remote_initiated(Role::Client)); + assert!(id1.is_remote_initiated(Role::Server)); + assert!(!id1.is_send_only(Role::Server)); + assert!(!id1.is_send_only(Role::Client)); + assert!(!id1.is_recv_only(Role::Server)); + assert!(!id1.is_recv_only(Role::Client)); + assert_eq!(id1.as_u64(), 16); + } + + #[test] + fn uni_stream_properties() { + let id2 = StreamId::from(35); + assert!(!id2.is_bidi()); + assert!(id2.is_uni()); + assert!(!id2.is_client_initiated()); + assert!(id2.is_server_initiated()); + assert_eq!(id2.role(), Role::Server); + assert!(!id2.is_self_initiated(Role::Client)); + assert!(id2.is_self_initiated(Role::Server)); + assert!(id2.is_remote_initiated(Role::Client)); + assert!(!id2.is_remote_initiated(Role::Server)); + assert!(id2.is_send_only(Role::Server)); + assert!(!id2.is_send_only(Role::Client)); + 
assert!(!id2.is_recv_only(Role::Server)); + assert!(id2.is_recv_only(Role::Client)); + assert_eq!(id2.as_u64(), 35); + } +} diff --git a/third_party/rust/neqo-transport/src/streams.rs b/third_party/rust/neqo-transport/src/streams.rs new file mode 100644 index 0000000000..7cbb29ce02 --- /dev/null +++ b/third_party/rust/neqo-transport/src/streams.rs @@ -0,0 +1,547 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Stream management for a connection. +use std::{cell::RefCell, cmp::Ordering, rc::Rc}; + +use neqo_common::{qtrace, qwarn, Role}; + +use crate::{ + fc::{LocalStreamLimits, ReceiverFlowControl, RemoteStreamLimits, SenderFlowControl}, + frame::Frame, + packet::PacketBuilder, + recovery::{RecoveryToken, StreamRecoveryToken}, + recv_stream::{RecvStream, RecvStreams}, + send_stream::{SendStream, SendStreams, TransmissionPriority}, + stats::FrameStats, + stream_id::{StreamId, StreamType}, + tparams::{self, TransportParametersHandler}, + ConnectionEvents, Error, Res, +}; + +pub type SendOrder = i64; + +#[derive(Copy, Clone)] +pub struct StreamOrder { + pub sendorder: Option<SendOrder>, +} + +// We want highest to lowest, with None being higher than any value +impl Ord for StreamOrder { + fn cmp(&self, other: &Self) -> Ordering { + if self.sendorder.is_some() && other.sendorder.is_some() { + // We want reverse order (high to low) when both values are specified. 
+ other.sendorder.cmp(&self.sendorder) + } else { + self.sendorder.cmp(&other.sendorder) + } + } +} + +impl PartialOrd for StreamOrder { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl PartialEq for StreamOrder { + fn eq(&self, other: &Self) -> bool { + self.sendorder == other.sendorder + } +} + +impl Eq for StreamOrder {} + +pub struct Streams { + role: Role, + tps: Rc<RefCell<TransportParametersHandler>>, + events: ConnectionEvents, + sender_fc: Rc<RefCell<SenderFlowControl<()>>>, + receiver_fc: Rc<RefCell<ReceiverFlowControl<()>>>, + remote_stream_limits: RemoteStreamLimits, + local_stream_limits: LocalStreamLimits, + pub(crate) send: SendStreams, + pub(crate) recv: RecvStreams, +} + +impl Streams { + pub fn new( + tps: Rc<RefCell<TransportParametersHandler>>, + role: Role, + events: ConnectionEvents, + ) -> Self { + let limit_bidi = tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI); + let limit_uni = tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAMS_UNI); + let max_data = tps.borrow().local.get_integer(tparams::INITIAL_MAX_DATA); + Self { + role, + tps, + events, + sender_fc: Rc::new(RefCell::new(SenderFlowControl::new((), 0))), + receiver_fc: Rc::new(RefCell::new(ReceiverFlowControl::new((), max_data))), + remote_stream_limits: RemoteStreamLimits::new(limit_bidi, limit_uni, role), + local_stream_limits: LocalStreamLimits::new(role), + send: SendStreams::default(), + recv: RecvStreams::default(), + } + } + + pub fn is_stream_id_allowed(&self, stream_id: StreamId) -> bool { + self.remote_stream_limits[stream_id.stream_type()].is_allowed(stream_id) + } + + pub fn zero_rtt_rejected(&mut self) { + self.clear_streams(); + debug_assert_eq!( + self.remote_stream_limits[StreamType::BiDi].max_active(), + self.tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI) + ); + debug_assert_eq!( + self.remote_stream_limits[StreamType::UniDi].max_active(), + 
self.tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAMS_UNI) + ); + self.local_stream_limits = LocalStreamLimits::new(self.role); + } + + pub fn input_frame(&mut self, frame: Frame, stats: &mut FrameStats) -> Res<()> { + match frame { + Frame::ResetStream { + stream_id, + application_error_code, + final_size, + } => { + stats.reset_stream += 1; + if let (_, Some(rs)) = self.obtain_stream(stream_id)? { + rs.reset(application_error_code, final_size)?; + } + } + Frame::StopSending { + stream_id, + application_error_code, + } => { + stats.stop_sending += 1; + self.events + .send_stream_stop_sending(stream_id, application_error_code); + if let (Some(ss), _) = self.obtain_stream(stream_id)? { + ss.reset(application_error_code); + } + } + Frame::Stream { + fin, + stream_id, + offset, + data, + .. + } => { + stats.stream += 1; + if let (_, Some(rs)) = self.obtain_stream(stream_id)? { + rs.inbound_stream_frame(fin, offset, data)?; + } + } + Frame::MaxData { maximum_data } => { + stats.max_data += 1; + self.handle_max_data(maximum_data); + } + Frame::MaxStreamData { + stream_id, + maximum_stream_data, + } => { + qtrace!( + "Stream {} Received MaxStreamData {}", + stream_id, + maximum_stream_data + ); + stats.max_stream_data += 1; + if let (Some(ss), _) = self.obtain_stream(stream_id)? { + ss.set_max_stream_data(maximum_stream_data); + } + } + Frame::MaxStreams { + stream_type, + maximum_streams, + } => { + stats.max_streams += 1; + self.handle_max_streams(stream_type, maximum_streams); + } + Frame::DataBlocked { data_limit } => { + // Should never happen since we set data limit to max + qwarn!("Received DataBlocked with data limit {}", data_limit); + stats.data_blocked += 1; + self.handle_data_blocked(); + } + Frame::StreamDataBlocked { stream_id, .. 
} => { + qtrace!("Received StreamDataBlocked"); + stats.stream_data_blocked += 1; + // Terminate connection with STREAM_STATE_ERROR if send-only + // stream (-transport 19.13) + if stream_id.is_send_only(self.role) { + return Err(Error::StreamStateError); + } + + if let (_, Some(rs)) = self.obtain_stream(stream_id)? { + rs.send_flowc_update(); + } + } + Frame::StreamsBlocked { .. } => { + stats.streams_blocked += 1; + // We send an update evry time we retire a stream. There is no need to + // trigger flow updates here. + } + _ => unreachable!("This is not a stream Frame"), + } + Ok(()) + } + + fn write_maintenance_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + // Send `DATA_BLOCKED` as necessary. + self.sender_fc + .borrow_mut() + .write_frames(builder, tokens, stats); + if builder.is_full() { + return; + } + + // Send `MAX_DATA` as necessary. + self.receiver_fc + .borrow_mut() + .write_frames(builder, tokens, stats); + if builder.is_full() { + return; + } + + self.recv.write_frames(builder, tokens, stats); + + self.remote_stream_limits[StreamType::BiDi].write_frames(builder, tokens, stats); + if builder.is_full() { + return; + } + self.remote_stream_limits[StreamType::UniDi].write_frames(builder, tokens, stats); + if builder.is_full() { + return; + } + + self.local_stream_limits[StreamType::BiDi].write_frames(builder, tokens, stats); + if builder.is_full() { + return; + } + + self.local_stream_limits[StreamType::UniDi].write_frames(builder, tokens, stats); + } + + pub fn write_frames( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if priority == TransmissionPriority::Important { + self.write_maintenance_frames(builder, tokens, stats); + if builder.is_full() { + return; + } + } + + self.send.write_frames(priority, builder, tokens, stats); + } + + pub fn lost(&mut self, token: 
&StreamRecoveryToken) { + match token { + StreamRecoveryToken::Stream(st) => self.send.lost(st), + StreamRecoveryToken::ResetStream { stream_id } => self.send.reset_lost(*stream_id), + StreamRecoveryToken::StreamDataBlocked { stream_id, limit } => { + self.send.blocked_lost(*stream_id, *limit); + } + StreamRecoveryToken::MaxStreamData { + stream_id, + max_data, + } => { + if let Ok((_, Some(rs))) = self.obtain_stream(*stream_id) { + rs.max_stream_data_lost(*max_data); + } + } + StreamRecoveryToken::StopSending { stream_id } => { + if let Ok((_, Some(rs))) = self.obtain_stream(*stream_id) { + rs.stop_sending_lost(); + } + } + StreamRecoveryToken::StreamsBlocked { stream_type, limit } => { + self.local_stream_limits[*stream_type].frame_lost(*limit); + } + StreamRecoveryToken::MaxStreams { + stream_type, + max_streams, + } => { + self.remote_stream_limits[*stream_type].frame_lost(*max_streams); + } + StreamRecoveryToken::DataBlocked(limit) => { + self.sender_fc.borrow_mut().frame_lost(*limit); + } + StreamRecoveryToken::MaxData(maximum_data) => { + self.receiver_fc.borrow_mut().frame_lost(*maximum_data); + } + } + } + + pub fn acked(&mut self, token: &StreamRecoveryToken) { + match token { + StreamRecoveryToken::Stream(st) => self.send.acked(st), + StreamRecoveryToken::ResetStream { stream_id } => self.send.reset_acked(*stream_id), + StreamRecoveryToken::StopSending { stream_id } => { + if let Ok((_, Some(rs))) = self.obtain_stream(*stream_id) { + rs.stop_sending_acked(); + } + } + // We only worry when these are lost + StreamRecoveryToken::DataBlocked(_) + | StreamRecoveryToken::StreamDataBlocked { .. } + | StreamRecoveryToken::MaxStreamData { .. } + | StreamRecoveryToken::StreamsBlocked { .. } + | StreamRecoveryToken::MaxStreams { .. 
} + | StreamRecoveryToken::MaxData(_) => (), + } + } + + pub fn clear_streams(&mut self) { + self.send.clear(); + self.recv.clear(); + } + + pub fn cleanup_closed_streams(&mut self) { + // filter the list, removing closed streams + self.send.remove_terminal(); + + let send = &self.send; + let (removed_bidi, removed_uni) = self.recv.clear_terminal(send, self.role); + + // Send max_streams updates if we removed remote-initiated recv streams. + // The updates will be send if any steams has been removed. + self.remote_stream_limits[StreamType::BiDi].add_retired(removed_bidi); + self.remote_stream_limits[StreamType::UniDi].add_retired(removed_uni); + } + + fn ensure_created_if_remote(&mut self, stream_id: StreamId) -> Res<()> { + if !stream_id.is_remote_initiated(self.role) + || !self.remote_stream_limits[stream_id.stream_type()].is_new_stream(stream_id)? + { + // If it is not a remote stream and stream already exist. + return Ok(()); + } + + let tp = match stream_id.stream_type() { + // From the local perspective, this is a remote- originated BiDi stream. From + // the remote perspective, this is a local-originated BiDi stream. Therefore, + // look at the local transport parameters for the + // INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value to decide how much this endpoint + // will allow its peer to send. + StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, + StreamType::UniDi => tparams::INITIAL_MAX_STREAM_DATA_UNI, + }; + let recv_initial_max_stream_data = self.tps.borrow().local.get_integer(tp); + + while self.remote_stream_limits[stream_id.stream_type()].is_new_stream(stream_id)? 
{ + let next_stream_id = + self.remote_stream_limits[stream_id.stream_type()].take_stream_id(); + self.events.new_stream(next_stream_id); + + self.recv.insert( + next_stream_id, + RecvStream::new( + next_stream_id, + recv_initial_max_stream_data, + Rc::clone(&self.receiver_fc), + self.events.clone(), + ), + ); + + if next_stream_id.is_bidi() { + // From the local perspective, this is a remote- originated BiDi stream. + // From the remote perspective, this is a local-originated BiDi stream. + // Therefore, look at the remote's transport parameters for the + // INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value to decide how much this endpoint + // is allowed to send its peer. + let send_initial_max_stream_data = self + .tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL); + self.send.insert( + next_stream_id, + SendStream::new( + next_stream_id, + send_initial_max_stream_data, + Rc::clone(&self.sender_fc), + self.events.clone(), + ), + ); + } + } + Ok(()) + } + + /// Get or make a stream, and implicitly open additional streams as + /// indicated by its stream id. 
+ pub fn obtain_stream( + &mut self, + stream_id: StreamId, + ) -> Res<(Option<&mut SendStream>, Option<&mut RecvStream>)> { + self.ensure_created_if_remote(stream_id)?; + Ok(( + self.send.get_mut(stream_id).ok(), + self.recv.get_mut(stream_id).ok(), + )) + } + + pub fn set_sendorder(&mut self, stream_id: StreamId, sendorder: Option<SendOrder>) -> Res<()> { + self.send.set_sendorder(stream_id, sendorder) + } + + pub fn set_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> { + self.send.set_fairness(stream_id, fairness) + } + + pub fn stream_create(&mut self, st: StreamType) -> Res<StreamId> { + match self.local_stream_limits.take_stream_id(st) { + None => Err(Error::StreamLimitError), + Some(new_id) => { + let send_limit_tp = match st { + StreamType::UniDi => tparams::INITIAL_MAX_STREAM_DATA_UNI, + StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, + }; + let send_limit = self.tps.borrow().remote().get_integer(send_limit_tp); + let stream = SendStream::new( + new_id, + send_limit, + Rc::clone(&self.sender_fc), + self.events.clone(), + ); + self.send.insert(new_id, stream); + + if st == StreamType::BiDi { + // From the local perspective, this is a local- originated BiDi stream. From the + // remote perspective, this is a remote-originated BiDi stream. Therefore, look + // at the local transport parameters for the + // INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value to decide how + // much this endpoint will allow its peer to send. 
+ let recv_initial_max_stream_data = self + .tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL); + + self.recv.insert( + new_id, + RecvStream::new( + new_id, + recv_initial_max_stream_data, + Rc::clone(&self.receiver_fc), + self.events.clone(), + ), + ); + } + Ok(new_id) + } + } + } + + pub fn handle_max_data(&mut self, maximum_data: u64) { + let conn_was_blocked = self.sender_fc.borrow().available() == 0; + let conn_credit_increased = self.sender_fc.borrow_mut().update(maximum_data); + + if conn_was_blocked && conn_credit_increased { + for (id, ss) in &mut self.send { + if ss.avail() > 0 { + // These may not actually all be writable if one + // uses up all the conn credit. Not our fault. + self.events.send_stream_writable(*id); + } + } + } + } + + pub fn handle_data_blocked(&mut self) { + self.receiver_fc.borrow_mut().send_flowc_update(); + } + + pub fn set_initial_limits(&mut self) { + _ = self.local_stream_limits[StreamType::BiDi].update( + self.tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI), + ); + _ = self.local_stream_limits[StreamType::UniDi].update( + self.tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_STREAMS_UNI), + ); + + // As a client, there are two sets of initial limits for sending stream data. + // If the second limit is higher and streams have been created, then + // ensure that streams are not blocked on the lower limit. 
+ if self.role == Role::Client { + self.send.update_initial_limit(self.tps.borrow().remote()); + } + + self.sender_fc.borrow_mut().update( + self.tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_DATA), + ); + + if self.local_stream_limits[StreamType::BiDi].available() > 0 { + self.events.send_stream_creatable(StreamType::BiDi); + } + if self.local_stream_limits[StreamType::UniDi].available() > 0 { + self.events.send_stream_creatable(StreamType::UniDi); + } + } + + pub fn handle_max_streams(&mut self, stream_type: StreamType, maximum_streams: u64) { + if self.local_stream_limits[stream_type].update(maximum_streams) { + self.events.send_stream_creatable(stream_type); + } + } + + pub fn get_send_stream_mut(&mut self, stream_id: StreamId) -> Res<&mut SendStream> { + self.send.get_mut(stream_id) + } + + pub fn get_send_stream(&self, stream_id: StreamId) -> Res<&SendStream> { + self.send.get(stream_id) + } + + pub fn get_recv_stream_mut(&mut self, stream_id: StreamId) -> Res<&mut RecvStream> { + self.recv.get_mut(stream_id) + } + + pub fn keep_alive(&mut self, stream_id: StreamId, keep: bool) -> Res<()> { + self.recv.keep_alive(stream_id, keep) + } + + pub fn need_keep_alive(&mut self) -> bool { + self.recv.need_keep_alive() + } +} diff --git a/third_party/rust/neqo-transport/src/tparams.rs b/third_party/rust/neqo-transport/src/tparams.rs new file mode 100644 index 0000000000..1297829094 --- /dev/null +++ b/third_party/rust/neqo-transport/src/tparams.rs @@ -0,0 +1,1130 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Transport parameters. See -transport section 7.3. 
use std::{
    cell::RefCell,
    collections::HashMap,
    convert::TryFrom,
    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
    rc::Rc,
};

use neqo_common::{hex, qdebug, qinfo, qtrace, Decoder, Encoder, Role};
use neqo_crypto::{
    constants::{TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS},
    ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult},
    random, HandshakeMessage, ZeroRttCheckResult, ZeroRttChecker,
};

use crate::{
    cid::{ConnectionId, ConnectionIdEntry, CONNECTION_ID_SEQNO_PREFERRED, MAX_CONNECTION_ID_LEN},
    version::{Version, VersionConfig, WireVersion},
    Error, Res,
};

/// Numeric identifier of a transport parameter, as written on the wire (a varint).
pub type TransportParameterId = u64;

// Declares one `pub const NAME: TransportParameterId` per `NAME = value` entry and,
// outside of test builds, a slice `INTERNAL_TRANSPORT_PARAMETERS` listing all of them.
macro_rules! tpids {
        { $($n:ident = $v:expr),+ $(,)? } => {
            $(pub const $n: TransportParameterId = $v as TransportParameterId;)+

            /// A complete list of internal transport parameters.
            #[cfg(not(test))]
            pub(crate) const INTERNAL_TRANSPORT_PARAMETERS: &[TransportParameterId] = &[ $($n),+ ];
        };
    }
// NOTE(review): 0x00..=0x10 look like the RFC 9000 registry values; 0x11, 0x2ab2, 0x0020 and
// 0xff02_de1a presumably come from the version-information, grease-bit, datagram and
// ack-frequency extensions -- confirm against the IANA registry before changing any of them.
tpids! {
    ORIGINAL_DESTINATION_CONNECTION_ID = 0x00,
    IDLE_TIMEOUT = 0x01,
    STATELESS_RESET_TOKEN = 0x02,
    MAX_UDP_PAYLOAD_SIZE = 0x03,
    INITIAL_MAX_DATA = 0x04,
    INITIAL_MAX_STREAM_DATA_BIDI_LOCAL = 0x05,
    INITIAL_MAX_STREAM_DATA_BIDI_REMOTE = 0x06,
    INITIAL_MAX_STREAM_DATA_UNI = 0x07,
    INITIAL_MAX_STREAMS_BIDI = 0x08,
    INITIAL_MAX_STREAMS_UNI = 0x09,
    ACK_DELAY_EXPONENT = 0x0a,
    MAX_ACK_DELAY = 0x0b,
    DISABLE_MIGRATION = 0x0c,
    PREFERRED_ADDRESS = 0x0d,
    ACTIVE_CONNECTION_ID_LIMIT = 0x0e,
    INITIAL_SOURCE_CONNECTION_ID = 0x0f,
    RETRY_SOURCE_CONNECTION_ID = 0x10,
    VERSION_INFORMATION = 0x11,
    GREASE_QUIC_BIT = 0x2ab2,
    MIN_ACK_DELAY = 0xff02_de1a,
    MAX_DATAGRAM_FRAME_SIZE = 0x0020,
}

/// A preferred address: an optional IPv4 and an optional IPv6 socket address.
///
/// Values built through `new()` always carry at least one address.
#[derive(Clone, Debug)]
pub struct PreferredAddress {
    v4: Option<SocketAddrV4>,
    v6: Option<SocketAddrV6>,
}

impl PreferredAddress {
    /// Make a new preferred address configuration.
    ///
    /// # Panics
    ///
    /// If neither address is provided, or if a provided address has an
    /// unspecified IP or a zero port.
#[must_use]
pub fn new(v4: Option<SocketAddrV4>, v6: Option<SocketAddrV6>) -> Self {
    // At least one address family has to be present.
    assert!(v4.is_some() || v6.is_some());
    // An all-zero IP or a zero port is how "absent" is spelled on the wire
    // (see `decode_preferred_address`), so neither is allowed in a present address.
    if let Some(a) = v4 {
        assert!(!a.ip().is_unspecified());
        assert_ne!(a.port(), 0);
    }
    if let Some(a) = v6 {
        assert!(!a.ip().is_unspecified());
        assert_ne!(a.port(), 0);
    }
    Self { v4, v6 }
}

/// A generic version of `new()` for testing.
///
/// # Panics
///
/// If `v4` is not an IPv4 address or `v6` is not an IPv6 address, and for
/// every reason `new()` panics.
#[must_use]
#[cfg(test)]
pub fn new_any(v4: Option<std::net::SocketAddr>, v6: Option<std::net::SocketAddr>) -> Self {
    use std::net::SocketAddr;

    let v4 = v4.map(|v4| {
        let SocketAddr::V4(v4) = v4 else {
            panic!("not v4");
        };
        v4
    });
    let v6 = v6.map(|v6| {
        let SocketAddr::V6(v6) = v6 else {
            panic!("not v6");
        };
        v6
    });
    Self::new(v4, v6)
}

/// The IPv4 address, if one was configured.
#[must_use]
pub fn ipv4(&self) -> Option<SocketAddrV4> {
    self.v4
}
/// The IPv6 address, if one was configured.
#[must_use]
pub fn ipv6(&self) -> Option<SocketAddrV6> {
    self.v6
}
}

/// The value of a single transport parameter, in one of the shapes this module
/// knows how to encode and decode.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TransportParameter {
    /// An opaque byte string.
    Bytes(Vec<u8>),
    /// A single varint.
    Integer(u64),
    /// A zero-length value; presence is the signal.
    Empty,
    /// Preferred address: optional addresses plus a connection ID and a
    /// 16-byte stateless reset token.
    PreferredAddress {
        v4: Option<SocketAddrV4>,
        v6: Option<SocketAddrV6>,
        cid: ConnectionId,
        srt: [u8; 16],
    },
    /// Version information: the current version followed by other versions.
    Versions {
        current: WireVersion,
        other: Vec<WireVersion>,
    },
}

impl TransportParameter {
    /// Append this value to `enc` as `varint(tp)`, a varint length, then the payload.
    ///
    /// No validation is performed here; the tests rely on being able to encode
    /// values that `decode` rejects.
    fn encode(&self, enc: &mut Encoder, tp: TransportParameterId) {
        qdebug!("TP encoded; type 0x{:02x} val {:?}", tp, self);
        enc.encode_varint(tp);
        match self {
            Self::Bytes(a) => {
                enc.encode_vvec(a);
            }
            Self::Integer(a) => {
                enc.encode_vvec_with(|enc_inner| {
                    enc_inner.encode_varint(*a);
                });
            }
            Self::Empty => {
                // A zero length and no payload.
                enc.encode_varint(0_u64);
            }
            Self::PreferredAddress { v4, v6, cid, srt } => {
                enc.encode_vvec_with(|enc_inner| {
                    // An absent address is written as all zeros: 4 + 2 bytes for IPv4 ...
                    if let Some(v4) = v4 {
                        enc_inner.encode(&v4.ip().octets()[..]);
                        enc_inner.encode_uint(2, v4.port());
                    } else {
                        enc_inner.encode(&[0; 6]);
                    }
                    // ... and 16 + 2 bytes for IPv6.
                    if let Some(v6) = v6 {
                        enc_inner.encode(&v6.ip().octets()[..]);
                        enc_inner.encode_uint(2, v6.port());
                    } else {
                        enc_inner.encode(&[0; 18]);
                    }
                    // Connection ID with a one-byte length prefix, then the reset token.
                    enc_inner.encode_vec(1, &cid[..]);
                    enc_inner.encode(&srt[..]);
                });
            }
            Self::Versions { current, other } => {
                // A flat list of 4-byte versions, current version first.
                enc.encode_vvec_with(|enc_inner| {
                    enc_inner.encode_uint(4, *current);
                    for v in other {
                        enc_inner.encode_uint(4, *v);
                    }
                });
            }
        };
    }

    /// Decode the payload of a preferred address parameter.
    ///
    /// Returns `Error::NoMoreData` if the payload is too short and
    /// `Error::TransportParameterError` if an address is half-specified, both
    /// addresses are absent, or the connection ID is empty or too long.
    fn decode_preferred_address(d: &mut Decoder) -> Res<Self> {
        // IPv4 address (maybe)
        // The `unwrap()`s below cannot fail: the decoder returned exactly the
        // number of bytes requested.
        let v4ip =
            Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?).unwrap());
        let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap();
        // Can't have non-zero IP and zero port, or vice versa.
        if v4ip.is_unspecified() ^ (v4port == 0) {
            return Err(Error::TransportParameterError);
        }
        let v4 = if v4port == 0 {
            None
        } else {
            Some(SocketAddrV4::new(v4ip, v4port))
        };

        // IPv6 address (mostly the same as v4)
        let v6ip =
            Ipv6Addr::from(<[u8; 16]>::try_from(d.decode(16).ok_or(Error::NoMoreData)?).unwrap());
        let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap();
        if v6ip.is_unspecified() ^ (v6port == 0) {
            return Err(Error::TransportParameterError);
        }
        let v6 = if v6port == 0 {
            None
        } else {
            Some(SocketAddrV6::new(v6ip, v6port, 0, 0))
        };
        // Need either v4 or v6 to be present.
        if v4.is_none() && v6.is_none() {
            return Err(Error::TransportParameterError);
        }

        // Connection ID (non-zero length)
        let cid = ConnectionId::from(d.decode_vec(1).ok_or(Error::NoMoreData)?);
        if cid.len() == 0 || cid.len() > MAX_CONNECTION_ID_LEN {
            return Err(Error::TransportParameterError);
        }

        // Stateless reset token
        let srtbuf = d.decode(16).ok_or(Error::NoMoreData)?;
        let srt = <[u8; 16]>::try_from(srtbuf).unwrap();

        Ok(Self::PreferredAddress { v4, v6, cid, srt })
    }

    /// Decode the payload of a version information parameter.
    ///
    /// A version of zero anywhere in the list is a `TransportParameterError`.
    fn decode_versions(dec: &mut Decoder) -> Res<Self> {
        // Read one non-zero, 4-byte version.
        fn dv(dec: &mut Decoder) -> Res<WireVersion> {
            let v = dec.decode_uint(4).ok_or(Error::NoMoreData)?;
            if v == 0 {
                Err(Error::TransportParameterError)
            } else {
                // Lossless: only four bytes were read.
                Ok(v as WireVersion)
            }
        }

        let current = dv(dec)?;
        // This rounding down is OK because `decode` checks for left over data.
        let count = dec.remaining() / 4;
        let mut other = Vec::with_capacity(count);
        for _ in 0..count {
            other.push(dv(dec)?);
        }
        Ok(Self::Versions { current, other })
    }

    /// Decode one transport parameter from `dec`.
    ///
    /// Returns `Ok(None)` for an identifier this module does not know; its
    /// payload is consumed and dropped. For known identifiers the payload has
    /// to be consumed exactly, otherwise `Error::TooMuchData` is returned.
    fn decode(dec: &mut Decoder) -> Res<Option<(TransportParameterId, Self)>> {
        let tp = dec.decode_varint().ok_or(Error::NoMoreData)?;
        let content = dec.decode_vvec().ok_or(Error::NoMoreData)?;
        qtrace!("TP {:x} length {:x}", tp, content.len());
        // From here on, only the payload of this one parameter is visible.
        let mut d = Decoder::from(content);
        let value = match tp {
            ORIGINAL_DESTINATION_CONNECTION_ID
            | INITIAL_SOURCE_CONNECTION_ID
            | RETRY_SOURCE_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()),
            STATELESS_RESET_TOKEN => {
                if d.remaining() != 16 {
                    return Err(Error::TransportParameterError);
                }
                Self::Bytes(d.decode_remainder().to_vec())
            }
            IDLE_TIMEOUT
            | INITIAL_MAX_DATA
            | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
            | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
            | INITIAL_MAX_STREAM_DATA_UNI
            | MAX_ACK_DELAY
            | MAX_DATAGRAM_FRAME_SIZE => match d.decode_varint() {
                Some(v) => Self::Integer(v),
                None => return Err(Error::TransportParameterError),
            },

            // Note that a missing or malformed varint is also reported as
            // `StreamLimitError` here, not only a value above 2^60.
            INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI => match d.decode_varint() {
                Some(v) if v <= (1 << 60) => Self::Integer(v),
                _ => return Err(Error::StreamLimitError),
            },

            MAX_UDP_PAYLOAD_SIZE => match d.decode_varint() {
                Some(v) if v >= 1200 => Self::Integer(v),
                _ => return Err(Error::TransportParameterError),
            },

            ACK_DELAY_EXPONENT => match d.decode_varint() {
                Some(v) if v <= 20 => Self::Integer(v),
                _ => return Err(Error::TransportParameterError),
            },
            ACTIVE_CONNECTION_ID_LIMIT => match d.decode_varint() {
                Some(v) if v >= 2 => Self::Integer(v),
                _ => return Err(Error::TransportParameterError),
            },

            // A non-empty payload for these is caught by the left-over check below.
            DISABLE_MIGRATION | GREASE_QUIC_BIT => Self::Empty,

            PREFERRED_ADDRESS => Self::decode_preferred_address(&mut d)?,

            MIN_ACK_DELAY => match d.decode_varint() {
                Some(v) if v < (1 << 24) => Self::Integer(v),
                _ => return Err(Error::TransportParameterError),
            },

            VERSION_INFORMATION => Self::decode_versions(&mut d)?,

            // Skip.
            _ => return Ok(None),
        };
        if d.remaining() > 0 {
            return Err(Error::TooMuchData);
        }
        qdebug!("TP decoded; type 0x{:02x} val {:?}", tp, value);
        Ok(Some((tp, value)))
    }
}

/// A set of transport parameters, keyed by identifier.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TransportParameters {
    params: HashMap<TransportParameterId, TransportParameter>,
}

impl TransportParameters {
    /// Set a value, replacing any value already stored under `k`.
    ///
    /// No check is made that `v` has the right shape for `k`; the typed
    /// setters (`set_integer`, `set_bytes`, `set_empty`) do that.
    pub fn set(&mut self, k: TransportParameterId, v: TransportParameter) {
        self.params.insert(k, v);
    }

    /// Clear a key.
    pub fn remove(&mut self, k: TransportParameterId) {
        self.params.remove(&k);
    }

    /// Decode is a static function that parses transport parameters
    /// using the provided decoder.
+ pub(crate) fn decode(d: &mut Decoder) -> Res<Self> { + let mut tps = Self::default(); + qtrace!("Parsed fixed TP header"); + + while d.remaining() > 0 { + match TransportParameter::decode(d) { + Ok(Some((tipe, tp))) => { + tps.set(tipe, tp); + } + Ok(None) => {} + Err(e) => return Err(e), + } + } + Ok(tps) + } + + pub(crate) fn encode(&self, enc: &mut Encoder) { + for (tipe, tp) in &self.params { + tp.encode(enc, *tipe); + } + } + + // Get an integer type or a default. + pub fn get_integer(&self, tp: TransportParameterId) -> u64 { + let default = match tp { + IDLE_TIMEOUT + | INITIAL_MAX_DATA + | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL + | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE + | INITIAL_MAX_STREAM_DATA_UNI + | INITIAL_MAX_STREAMS_BIDI + | INITIAL_MAX_STREAMS_UNI + | MIN_ACK_DELAY + | MAX_DATAGRAM_FRAME_SIZE => 0, + MAX_UDP_PAYLOAD_SIZE => 65527, + ACK_DELAY_EXPONENT => 3, + MAX_ACK_DELAY => 25, + ACTIVE_CONNECTION_ID_LIMIT => 2, + _ => panic!("Transport parameter not known or not an Integer"), + }; + match self.params.get(&tp) { + None => default, + Some(TransportParameter::Integer(x)) => *x, + _ => panic!("Internal error"), + } + } + + // Set an integer type or a default. 
+ pub fn set_integer(&mut self, tp: TransportParameterId, value: u64) { + match tp { + IDLE_TIMEOUT + | INITIAL_MAX_DATA + | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL + | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE + | INITIAL_MAX_STREAM_DATA_UNI + | INITIAL_MAX_STREAMS_BIDI + | INITIAL_MAX_STREAMS_UNI + | MAX_UDP_PAYLOAD_SIZE + | ACK_DELAY_EXPONENT + | MAX_ACK_DELAY + | ACTIVE_CONNECTION_ID_LIMIT + | MIN_ACK_DELAY + | MAX_DATAGRAM_FRAME_SIZE => { + self.set(tp, TransportParameter::Integer(value)); + } + _ => panic!("Transport parameter not known"), + } + } + + pub fn get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]> { + match tp { + ORIGINAL_DESTINATION_CONNECTION_ID + | INITIAL_SOURCE_CONNECTION_ID + | RETRY_SOURCE_CONNECTION_ID + | STATELESS_RESET_TOKEN => {} + _ => panic!("Transport parameter not known or not type bytes"), + } + + match self.params.get(&tp) { + None => None, + Some(TransportParameter::Bytes(x)) => Some(x), + _ => panic!("Internal error"), + } + } + + pub fn set_bytes(&mut self, tp: TransportParameterId, value: Vec<u8>) { + match tp { + ORIGINAL_DESTINATION_CONNECTION_ID + | INITIAL_SOURCE_CONNECTION_ID + | RETRY_SOURCE_CONNECTION_ID + | STATELESS_RESET_TOKEN => { + self.set(tp, TransportParameter::Bytes(value)); + } + _ => panic!("Transport parameter not known or not type bytes"), + } + } + + pub fn set_empty(&mut self, tp: TransportParameterId) { + match tp { + DISABLE_MIGRATION | GREASE_QUIC_BIT => { + self.set(tp, TransportParameter::Empty); + } + _ => panic!("Transport parameter not known or not type empty"), + } + } + + /// Set version information. 
///
/// The list of other versions starts with a random greased value; a client
/// then lists only the versions compatible with its initial version, while a
/// server lists every configured version.
pub fn set_versions(&mut self, role: Role, versions: &VersionConfig) {
    let initial = versions.initial();

    // Four random bytes, forced into the 0x?a?a?a?a pattern.
    let rbuf = random(4);
    let raw = Decoder::new(&rbuf).decode_uint(4).unwrap() as u32;
    let grease = raw & 0xf0f0_f0f0 | 0x0a0a_0a0a;

    let mut other = Vec::with_capacity(versions.all().len() + 1);
    other.push(grease);
    other.extend(
        versions
            .all()
            .iter()
            .filter(|&&v| role != Role::Client || initial.is_compatible(v))
            .map(|&v| v.wire_version()),
    );

    self.set(
        VERSION_INFORMATION,
        TransportParameter::Versions {
            current: initial.wire_version(),
            other,
        },
    );
}

/// Overwrite the current version inside the stored version information.
///
/// Version information must already have been stored (see `set_versions`).
fn compatible_upgrade(&mut self, v: Version) {
    match self.params.get_mut(&VERSION_INFORMATION) {
        Some(TransportParameter::Versions { current, .. }) => {
            *current = v.wire_version();
        }
        _ => unreachable!("Compatible upgrade without transport parameters set!"),
    }
}

/// Whether the presence-only parameter `tipe` is set.
///
/// # Panics
///
/// If a value is stored under `tipe` that is not `TransportParameter::Empty`.
pub fn get_empty(&self, tipe: TransportParameterId) -> bool {
    self.params.get(&tipe).map_or(false, |stored| {
        assert!(
            matches!(stored, TransportParameter::Empty),
            "Internal error"
        );
        true
    })
}

/// Return true if the remembered transport parameters are OK for 0-RTT.
/// Generally this means that any value that is currently in effect is greater than
/// or equal to the promised value.
pub(crate) fn ok_for_0rtt(&self, remembered: &Self) -> bool {
    remembered.params.iter().all(|(id, promised)| {
        // Skip checks for these, which don't affect 0-RTT.
        if matches!(
            *id,
            ORIGINAL_DESTINATION_CONNECTION_ID
                | INITIAL_SOURCE_CONNECTION_ID
                | RETRY_SOURCE_CONNECTION_ID
                | STATELESS_RESET_TOKEN
                | IDLE_TIMEOUT
                | ACK_DELAY_EXPONENT
                | MAX_ACK_DELAY
                | ACTIVE_CONNECTION_ID_LIMIT
                | PREFERRED_ADDRESS
        ) {
            return true;
        }
        // Anything that was promised has to still be on offer.
        let Some(offered) = self.params.get(id) else {
            return false;
        };
        match (offered, promised) {
            (TransportParameter::Integer(now), TransportParameter::Integer(then)) => {
                if *id == MIN_ACK_DELAY {
                    // MIN_ACK_DELAY is backwards:
                    // it can only be reduced safely.
                    *now <= *then
                } else {
                    *now >= *then
                }
            }
            (TransportParameter::Empty, TransportParameter::Empty) => true,
            (
                TransportParameter::Versions { current: now, .. },
                TransportParameter::Versions { current: then, .. },
            ) => now == then,
            // Mismatched shapes, or a shape that is not comparable.
            _ => false,
        }
    })
}

/// Get the preferred address in a usable form.
#[must_use]
pub fn get_preferred_address(&self) -> Option<(PreferredAddress, ConnectionIdEntry<[u8; 16]>)> {
    match self.params.get(&PREFERRED_ADDRESS) {
        Some(TransportParameter::PreferredAddress { v4, v6, cid, srt }) => {
            let address = PreferredAddress::new(*v4, *v6);
            let entry = ConnectionIdEntry::new(CONNECTION_ID_SEQNO_PREFERRED, cid.clone(), *srt);
            Some((address, entry))
        }
        _ => None,
    }
}

/// Get the version negotiation values for validation.
#[must_use]
pub fn get_versions(&self) -> Option<(WireVersion, &[WireVersion])> {
    match self.params.get(&VERSION_INFORMATION) {
        Some(TransportParameter::Versions { current, other }) => {
            Some((*current, other.as_slice()))
        }
        _ => None,
    }
}

/// Whether any value is stored under `tp`.
#[must_use]
pub fn has_value(&self, tp: TransportParameterId) -> bool {
    self.params.contains_key(&tp)
}
}

/// Holds the local transport parameters together with those received from the
/// peer (`remote`) and those remembered for 0-RTT (`remote_0rtt`).
#[derive(Debug)]
pub struct TransportParametersHandler {
    role: Role,
    versions: VersionConfig,
    pub(crate) local: TransportParameters,
    pub(crate) remote: Option<TransportParameters>,
    pub(crate) remote_0rtt: Option<TransportParameters>,
}

impl TransportParametersHandler {
    /// Create a handler whose local parameters carry only the version
    /// information derived from `versions`; no peer parameters are known yet.
    pub fn new(role: Role, versions: VersionConfig) -> Self {
        let local = {
            let mut tps = TransportParameters::default();
            tps.set_versions(role, &versions);
            tps
        };
        Self {
            role,
            versions,
            local,
            remote: None,
            remote_0rtt: None,
        }
    }

    /// When resuming, the version is set based on the ticket.
    /// That needs to be done to override the default choice from configuration.
pub fn set_version(&mut self, version: Version) {
    debug_assert_eq!(self.role, Role::Client);
    self.versions.set_initial(version);
    // Rebuild the advertised version information from the updated configuration.
    self.local.set_versions(self.role, &self.versions);
}

/// The peer's transport parameters: those received in the handshake if
/// available, otherwise the ones remembered for 0-RTT.
///
/// # Panics
///
/// If neither set is available.
pub fn remote(&self) -> &TransportParameters {
    match (self.remote.as_ref(), self.remote_0rtt.as_ref()) {
        (Some(tp), _) | (_, Some(tp)) => tp,
        _ => panic!("no transport parameters from peer"),
    }
}

/// Get the version as set (or as determined by a compatible upgrade).
pub fn version(&self) -> Version {
    self.versions.initial()
}

/// Validate the peer's version information and, on the server, switch to the
/// preferred compatible version.
///
/// Parameters without version information are accepted unchanged.
fn compatible_upgrade(&mut self, remote_tp: &TransportParameters) -> Res<()> {
    if let Some((current, other)) = remote_tp.get_versions() {
        qtrace!(
            "Peer versions: {:x} {:x?}; config {:?}",
            current,
            other,
            self.versions,
        );

        if self.role == Role::Client {
            // Client: the version the server chose has to be one we consider compatible.
            let chosen = Version::try_from(current)?;
            if self.versions.compatible().any(|&v| v == chosen) {
                Ok(())
            } else {
                qinfo!(
                    "Chosen version {:x} is not compatible with initial version {:x}",
                    current,
                    self.versions.initial().wire_version(),
                );
                Err(Error::TransportParameterError)
            }
        } else {
            // Server: the client's current version has to match the version in use ...
            if current != self.versions.initial().wire_version() {
                qinfo!(
                    "Current version {:x} != own version {:x}",
                    current,
                    self.versions.initial().wire_version(),
                );
                return Err(Error::TransportParameterError);
            }

            // ... and then the preferred version among those the client lists is adopted.
            if let Some(preferred) = self.versions.preferred_compatible(other) {
                if preferred != self.versions.initial() {
                    qinfo!(
                        "Compatible upgrade {:?} ==> {:?}",
                        self.versions.initial(),
                        preferred
                    );
                    self.versions.set_initial(preferred);
                    self.local.compatible_upgrade(preferred);
                }
                Ok(())
            } else {
                qinfo!("Unable to find any compatible version");
                Err(Error::TransportParameterError)
            }
        }
    } else {
        Ok(())
    }
}
}

impl ExtensionHandler for TransportParametersHandler {
    /// Write the encoded local transport parameters into `d`.
    ///
    /// Only ClientHello and EncryptedExtensions carry the extension; every
    /// other message is skipped.
    ///
    /// # Panics
    ///
    /// If the encoded parameters do not fit into `d`.
    fn write(&mut self, msg: HandshakeMessage, d: &mut [u8]) -> ExtensionWriterResult {
        if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
            return ExtensionWriterResult::Skip;
        }

        qdebug!("Writing transport parameters, msg={:?}", msg);

        // TODO(ekr@rtfm.com): Modify to avoid a copy.
        let mut enc = Encoder::default();
        self.local.encode(&mut enc);
        assert!(enc.len() <= d.len());
        d[..enc.len()].copy_from_slice(enc.as_ref());
        ExtensionWriterResult::Write(enc.len())
    }

    /// Decode the peer's transport parameters from `d` and store them as `remote`.
    ///
    /// The parameters are only stored if version validation succeeds.
    fn handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult {
        qtrace!(
            "Handling transport parameters, msg={:?} value={}",
            msg,
            hex(d),
        );

        if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
            return ExtensionHandlerResult::Alert(110); // unsupported_extension
        }

        let mut dec = Decoder::from(d);
        match TransportParameters::decode(&mut dec) {
            Ok(tp) => {
                if self.compatible_upgrade(&tp).is_ok() {
                    self.remote = Some(tp);
                    ExtensionHandlerResult::Ok
                } else {
                    ExtensionHandlerResult::Alert(47) // illegal_parameter
                }
            }
            _ => ExtensionHandlerResult::Alert(47), // illegal_parameter
        }
    }
}

/// A 0-RTT checker that validates remembered transport parameters before
/// handing the rest of the token to an application-level checker.
#[derive(Debug)]
pub(crate) struct TpZeroRttChecker<T> {
    handler: Rc<RefCell<TransportParametersHandler>>,
    app_checker: T,
}

impl<T> TpZeroRttChecker<T>
where
    T: ZeroRttChecker + 'static,
{
    /// Box a checker that consults `handler` first and `app_checker` second.
    pub fn wrap(
        handler: Rc<RefCell<TransportParametersHandler>>,
        app_checker: T,
    ) -> Box<dyn ZeroRttChecker> {
        Box::new(Self {
            handler,
            app_checker,
        })
    }
}

impl<T> ZeroRttChecker for TpZeroRttChecker<T>
where
    T: ZeroRttChecker,
{
    /// Check a resumption token.
    ///
    /// The token starts with a length-prefixed block of encoded transport
    /// parameters; whatever follows is passed to the application checker.
    fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
        // Reject 0-RTT if there is no token.
        if token.is_empty() {
            qdebug!("0-RTT: no token, no 0-RTT");
            return ZeroRttCheckResult::Reject;
        }
        let mut dec = Decoder::from(token);
        let Some(tpslice) = dec.decode_vvec() else {
            qinfo!("0-RTT: token code error");
            return ZeroRttCheckResult::Fail;
        };
        let mut dec_tp = Decoder::from(tpslice);
        let Ok(remembered) = TransportParameters::decode(&mut dec_tp) else {
            qinfo!("0-RTT: transport parameter decode error");
            return ZeroRttCheckResult::Fail;
        };
        if self.handler.borrow().local.ok_for_0rtt(&remembered) {
            qinfo!("0-RTT: transport parameters OK, passing to application checker");
            self.app_checker.check(dec.decode_remainder())
        } else {
            qinfo!("0-RTT: transport parameters bad, rejecting");
            ZeroRttCheckResult::Reject
        }
    }
}

#[cfg(test)]
#[allow(unused_variables)]
mod tests {
    use super::*;

    /// Round-trips a small parameter set and checks stored values and defaults.
    #[test]
    fn basic_tps() {
        const RESET_TOKEN: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8];
        let mut tps = TransportParameters::default();
        tps.set(
            STATELESS_RESET_TOKEN,
            TransportParameter::Bytes(RESET_TOKEN.to_vec()),
        );
        tps.params
            .insert(INITIAL_MAX_STREAMS_BIDI, TransportParameter::Integer(10));

        let mut enc = Encoder::default();
        tps.encode(&mut enc);

        let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
        assert_eq!(tps, tps2);

        println!("TPS = {tps:?}");
        assert_eq!(tps2.get_integer(IDLE_TIMEOUT), 0); // Default
        assert_eq!(tps2.get_integer(MAX_ACK_DELAY), 25); // Default
        assert_eq!(tps2.get_integer(ACTIVE_CONNECTION_ID_LIMIT), 2); // Default
        assert_eq!(tps2.get_integer(INITIAL_MAX_STREAMS_BIDI), 10); // Sent
        assert_eq!(tps2.get_bytes(STATELESS_RESET_TOKEN), Some(RESET_TOKEN));
        assert_eq!(tps2.get_bytes(ORIGINAL_DESTINATION_CONNECTION_ID), None);
        assert_eq!(tps2.get_bytes(INITIAL_SOURCE_CONNECTION_ID), None);
        assert_eq!(tps2.get_bytes(RETRY_SOURCE_CONNECTION_ID), None);
        assert!(!tps2.has_value(ORIGINAL_DESTINATION_CONNECTION_ID));
        assert!(!tps2.has_value(INITIAL_SOURCE_CONNECTION_ID));
        assert!(!tps2.has_value(RETRY_SOURCE_CONNECTION_ID));
        assert!(tps2.has_value(STATELESS_RESET_TOKEN));

        // NOTE(review): this second round trip only checks that decoding succeeds;
        // the unused `tps2` is what the module-level `allow(unused_variables)` covers.
        let mut enc = Encoder::default();
        tps.encode(&mut enc);

        let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
    }

    /// A valid preferred address value with both address families present.
    fn make_spa() -> TransportParameter {
        TransportParameter::PreferredAddress {
            v4: Some(SocketAddrV4::new(Ipv4Addr::from(0xc000_0201), 443)),
            v6: Some(SocketAddrV6::new(
                Ipv6Addr::from(0xfe80_0000_0000_0000_0000_0000_0000_0001),
                443,
                0,
                0,
            )),
            cid: ConnectionId::from(&[1, 2, 3, 4, 5]),
            srt: [3; 16],
        }
    }

    /// `make_spa()` encodes to a fixed byte string and decodes back to itself.
    #[test]
    fn preferred_address_encode_decode() {
        const ENCODED: &[u8] = &[
            0x0d, 0x2e, 0xc0, 0x00, 0x02, 0x01, 0x01, 0xbb, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xbb, 0x05, 0x01,
            0x02, 0x03, 0x04, 0x05, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
            0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
        ];
        let spa = make_spa();
        let mut enc = Encoder::new();
        spa.encode(&mut enc, PREFERRED_ADDRESS);
        assert_eq!(enc.as_ref(), ENCODED);

        let mut dec = enc.as_decoder();
        let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
        assert_eq!(id, PREFERRED_ADDRESS);
        assert_eq!(decoded, spa);
    }

    /// Build `make_spa()` and let `wrecker` modify its addresses and connection ID.
    fn mutate_spa<F>(wrecker: F) -> TransportParameter
    where
        F: FnOnce(&mut Option<SocketAddrV4>, &mut Option<SocketAddrV6>, &mut ConnectionId),
    {
        let mut spa = make_spa();
        if let TransportParameter::PreferredAddress {
            ref mut v4,
            ref mut v6,
            ref mut cid,
            ..
        } = &mut spa
        {
            wrecker(v4, v6, cid);
        } else {
            unreachable!();
        }
        spa
    }

    /// This takes a `TransportParameter::PreferredAddress` that has been mutilated.
    /// It then encodes it, working from the knowledge that the `encode` function
    /// doesn't care about validity, and decodes it. The result should be failure.
fn assert_invalid_spa(spa: TransportParameter) {
    let mut enc = Encoder::new();
    spa.encode(&mut enc, PREFERRED_ADDRESS);
    assert_eq!(
        TransportParameter::decode(&mut enc.as_decoder()).unwrap_err(),
        Error::TransportParameterError
    );
}

/// This is for those rare mutations that are acceptable.
fn assert_valid_spa(spa: TransportParameter) {
    let mut enc = Encoder::new();
    spa.encode(&mut enc, PREFERRED_ADDRESS);
    let mut dec = enc.as_decoder();
    let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
    assert_eq!(id, PREFERRED_ADDRESS);
    assert_eq!(decoded, spa);
}

/// Half-specified addresses are rejected; a fully absent family is accepted
/// as long as the other one is present.
#[test]
fn preferred_address_zero_address() {
    // Either port being zero is bad.
    assert_invalid_spa(mutate_spa(|v4, _, _| {
        v4.as_mut().unwrap().set_port(0);
    }));
    assert_invalid_spa(mutate_spa(|_, v6, _| {
        v6.as_mut().unwrap().set_port(0);
    }));
    // Either IP being zero is bad.
    assert_invalid_spa(mutate_spa(|v4, _, _| {
        v4.as_mut().unwrap().set_ip(Ipv4Addr::from(0));
    }));
    assert_invalid_spa(mutate_spa(|_, v6, _| {
        v6.as_mut().unwrap().set_ip(Ipv6Addr::from(0));
    }));
    // Either address being absent is OK.
    assert_valid_spa(mutate_spa(|v4, _, _| {
        *v4 = None;
    }));
    assert_valid_spa(mutate_spa(|_, v6, _| {
        *v6 = None;
    }));
    // Both addresses being absent is bad.
    assert_invalid_spa(mutate_spa(|v4, v6, _| {
        *v4 = None;
        *v6 = None;
    }));
}

/// An empty connection ID and a 21-byte connection ID are both rejected.
#[test]
fn preferred_address_bad_cid() {
    assert_invalid_spa(mutate_spa(|_, _, cid| {
        *cid = ConnectionId::from(&[]);
    }));
    assert_invalid_spa(mutate_spa(|_, _, cid| {
        *cid = ConnectionId::from(&[0x0c; 21]);
    }));
}

/// Dropping the last byte of the encoding makes decoding run out of data.
#[test]
fn preferred_address_truncated() {
    let spa = make_spa();
    let mut enc = Encoder::new();
    spa.encode(&mut enc, PREFERRED_ADDRESS);
    let mut dec = Decoder::from(&enc.as_ref()[..enc.len() - 1]);
    assert_eq!(
        TransportParameter::decode(&mut dec).unwrap_err(),
        Error::NoMoreData
    );
}

// The `should_panic` strings below match the text of the failing assertion in
// `PreferredAddress::new`.
#[test]
#[should_panic(expected = "v4.is_some() || v6.is_some()")]
fn preferred_address_neither() {
    _ = PreferredAddress::new(None, None);
}

#[test]
#[should_panic(expected = ".is_unspecified")]
fn preferred_address_v4_unspecified() {
    _ = PreferredAddress::new(Some(SocketAddrV4::new(Ipv4Addr::from(0), 443)), None);
}

#[test]
#[should_panic(expected = "left != right")]
fn preferred_address_v4_zero_port() {
    _ = PreferredAddress::new(
        Some(SocketAddrV4::new(Ipv4Addr::from(0xc000_0201), 0)),
        None,
    );
}

#[test]
#[should_panic(expected = ".is_unspecified")]
fn preferred_address_v6_unspecified() {
    _ = PreferredAddress::new(None, Some(SocketAddrV6::new(Ipv6Addr::from(0), 443, 0, 0)));
}

#[test]
#[should_panic(expected = "left != right")]
fn preferred_address_v6_zero_port() {
    _ = PreferredAddress::new(None, Some(SocketAddrV6::new(Ipv6Addr::from(1), 0, 0, 0)));
}

/// Parameters that `ok_for_0rtt` skips may differ, or be missing, in either direction.
#[test]
fn compatible_0rtt_ignored_values() {
    let mut tps_a = TransportParameters::default();
    tps_a.set(
        STATELESS_RESET_TOKEN,
        TransportParameter::Bytes(vec![1, 2, 3]),
    );
    tps_a.set(IDLE_TIMEOUT, TransportParameter::Integer(10));
    tps_a.set(MAX_ACK_DELAY, TransportParameter::Integer(22));
    tps_a.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(33));

    let mut tps_b = TransportParameters::default();
    assert!(tps_a.ok_for_0rtt(&tps_b));
    assert!(tps_b.ok_for_0rtt(&tps_a));

    tps_b.set(
        STATELESS_RESET_TOKEN,
        TransportParameter::Bytes(vec![8, 9, 10]),
    );
    tps_b.set(IDLE_TIMEOUT, TransportParameter::Integer(100));
    tps_b.set(MAX_ACK_DELAY, TransportParameter::Integer(2));
    tps_b.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(44));
    assert!(tps_a.ok_for_0rtt(&tps_b));
    assert!(tps_b.ok_for_0rtt(&tps_a));
}

/// Integer parameters that are checked for 0-RTT: the current value must be at
/// least as permissive as the remembered one.
#[test]
fn compatible_0rtt_integers() {
    let mut tps_a = TransportParameters::default();
    const INTEGER_KEYS: &[TransportParameterId] = &[
        INITIAL_MAX_DATA,
        INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
        INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
        INITIAL_MAX_STREAM_DATA_UNI,
        INITIAL_MAX_STREAMS_BIDI,
        INITIAL_MAX_STREAMS_UNI,
        MAX_UDP_PAYLOAD_SIZE,
        MIN_ACK_DELAY,
        MAX_DATAGRAM_FRAME_SIZE,
    ];
    for i in INTEGER_KEYS {
        tps_a.set(*i, TransportParameter::Integer(12));
    }

    let tps_b = tps_a.clone();
    assert!(tps_a.ok_for_0rtt(&tps_b));
    assert!(tps_b.ok_for_0rtt(&tps_a));

    // For each integer key, choose a new value that will be accepted.
    for i in INTEGER_KEYS {
        let mut tps_b = tps_a.clone();
        // Set a safe new value; reducing MIN_ACK_DELAY instead.
        let safe_value = if *i == MIN_ACK_DELAY { 11 } else { 13 };
        tps_b.set(*i, TransportParameter::Integer(safe_value));
        // If the new value is not safe relative to the remembered value,
        // then we can't attempt 0-RTT with these parameters.
        assert!(!tps_a.ok_for_0rtt(&tps_b));
        // The opposite situation is fine.
        assert!(tps_b.ok_for_0rtt(&tps_a));
    }

    // Drop integer values and check that that is OK.
    for i in INTEGER_KEYS {
        let mut tps_b = tps_a.clone();
        tps_b.remove(*i);
        // A value that is missing from what is remembered is OK.
        assert!(tps_a.ok_for_0rtt(&tps_b));
        // A value that is remembered, but not current, is not OK.
        assert!(!tps_b.ok_for_0rtt(&tps_a));
    }
}

/// `ACTIVE_CONNECTION_ID_LIMIT` can't be less than 2.
+ #[test] + fn active_connection_id_limit_min_2() { + let mut tps = TransportParameters::default(); + + // Intentionally set an invalid value for the ACTIVE_CONNECTION_ID_LIMIT transport + // parameter. + tps.params + .insert(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(1)); + + let mut enc = Encoder::default(); + tps.encode(&mut enc); + + // When decoding a set of transport parameters with an invalid ACTIVE_CONNECTION_ID_LIMIT + // the result should be an error. + let invalid_decode_result = TransportParameters::decode(&mut enc.as_decoder()); + assert!(invalid_decode_result.is_err()); + } + + #[test] + fn versions_encode_decode() { + const ENCODED: &[u8] = &[ + 0x11, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x1a, 0x2a, 0x3a, 0x4a, 0x5a, 0x6a, 0x7a, 0x8a, + ]; + let vn = TransportParameter::Versions { + current: Version::Version1.wire_version(), + other: vec![0x1a2a_3a4a, 0x5a6a_7a8a], + }; + + let mut enc = Encoder::new(); + vn.encode(&mut enc, VERSION_INFORMATION); + assert_eq!(enc.as_ref(), ENCODED); + + let mut dec = enc.as_decoder(); + let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap(); + assert_eq!(id, VERSION_INFORMATION); + assert_eq!(decoded, vn); + } + + #[test] + fn versions_truncated() { + const TRUNCATED: &[u8] = &[ + 0x80, 0xff, 0x73, 0xdb, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x1a, 0x2a, 0x3a, 0x4a, 0x5a, + 0x6a, 0x7a, + ]; + let mut dec = Decoder::from(&TRUNCATED); + assert_eq!( + TransportParameter::decode(&mut dec).unwrap_err(), + Error::NoMoreData + ); + } + + #[test] + fn versions_zero() { + const ZERO1: &[u8] = &[0x11, 0x04, 0x00, 0x00, 0x00, 0x00]; + const ZERO2: &[u8] = &[0x11, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00]; + + let mut dec = Decoder::from(&ZERO1); + assert_eq!( + TransportParameter::decode(&mut dec).unwrap_err(), + Error::TransportParameterError + ); + let mut dec = Decoder::from(&ZERO2); + assert_eq!( + TransportParameter::decode(&mut dec).unwrap_err(), + Error::TransportParameterError + ); + } + + 
#[test] + fn versions_equal_0rtt() { + let mut current = TransportParameters::default(); + current.set( + VERSION_INFORMATION, + TransportParameter::Versions { + current: Version::Version1.wire_version(), + other: vec![0x1a2a_3a4a], + }, + ); + + let mut remembered = TransportParameters::default(); + // It's OK to not remember having versions. + assert!(current.ok_for_0rtt(&remembered)); + // But it is bad in the opposite direction. + assert!(!remembered.ok_for_0rtt(&current)); + + // If the version matches, it's OK to use 0-RTT. + remembered.set( + VERSION_INFORMATION, + TransportParameter::Versions { + current: Version::Version1.wire_version(), + other: vec![0x5a6a_7a8a, 0x9aaa_baca], + }, + ); + assert!(current.ok_for_0rtt(&remembered)); + assert!(remembered.ok_for_0rtt(&current)); + + // An apparent "upgrade" is still cause to reject 0-RTT. + remembered.set( + VERSION_INFORMATION, + TransportParameter::Versions { + current: Version::Version1.wire_version() + 1, + other: vec![], + }, + ); + assert!(!current.ok_for_0rtt(&remembered)); + assert!(!remembered.ok_for_0rtt(&current)); + } +} diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs new file mode 100644 index 0000000000..64d00257d3 --- /dev/null +++ b/third_party/rust/neqo-transport/src/tracking.rs @@ -0,0 +1,1228 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Tracking of received packets and generating acks thereof. 
+ +#![deny(clippy::pedantic)] + +use std::{ + cmp::min, + collections::VecDeque, + convert::TryFrom, + ops::{Index, IndexMut}, + time::{Duration, Instant}, +}; + +use neqo_common::{qdebug, qinfo, qtrace, qwarn}; +use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL}; +use smallvec::{smallvec, SmallVec}; + +use crate::{ + packet::{PacketBuilder, PacketNumber, PacketType}, + recovery::RecoveryToken, + stats::FrameStats, +}; + +// TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574 +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)] +pub enum PacketNumberSpace { + Initial, + Handshake, + ApplicationData, +} + +#[allow(clippy::use_self)] // https://github.com/rust-lang/rust-clippy/issues/3410 +impl PacketNumberSpace { + pub fn iter() -> impl Iterator<Item = &'static PacketNumberSpace> { + const SPACES: &[PacketNumberSpace] = &[ + PacketNumberSpace::Initial, + PacketNumberSpace::Handshake, + PacketNumberSpace::ApplicationData, + ]; + SPACES.iter() + } +} + +impl From<Epoch> for PacketNumberSpace { + fn from(epoch: Epoch) -> Self { + match epoch { + TLS_EPOCH_INITIAL => Self::Initial, + TLS_EPOCH_HANDSHAKE => Self::Handshake, + _ => Self::ApplicationData, + } + } +} + +impl From<PacketType> for PacketNumberSpace { + fn from(pt: PacketType) -> Self { + match pt { + PacketType::Initial => Self::Initial, + PacketType::Handshake => Self::Handshake, + PacketType::ZeroRtt | PacketType::Short => Self::ApplicationData, + _ => panic!("Attempted to get space from wrong packet type"), + } + } +} + +#[derive(Clone, Copy, Default)] +pub struct PacketNumberSpaceSet { + initial: bool, + handshake: bool, + application_data: bool, +} + +impl PacketNumberSpaceSet { + pub fn all() -> Self { + Self { + initial: true, + handshake: true, + application_data: true, + } + } +} + +impl Index<PacketNumberSpace> for PacketNumberSpaceSet { + type Output = bool; + + fn index(&self, space: PacketNumberSpace) -> &Self::Output { + match space { 
+ PacketNumberSpace::Initial => &self.initial, + PacketNumberSpace::Handshake => &self.handshake, + PacketNumberSpace::ApplicationData => &self.application_data, + } + } +} + +impl IndexMut<PacketNumberSpace> for PacketNumberSpaceSet { + fn index_mut(&mut self, space: PacketNumberSpace) -> &mut Self::Output { + match space { + PacketNumberSpace::Initial => &mut self.initial, + PacketNumberSpace::Handshake => &mut self.handshake, + PacketNumberSpace::ApplicationData => &mut self.application_data, + } + } +} + +impl<T: AsRef<[PacketNumberSpace]>> From<T> for PacketNumberSpaceSet { + fn from(spaces: T) -> Self { + let mut v = Self::default(); + for sp in spaces.as_ref() { + v[*sp] = true; + } + v + } +} + +impl std::fmt::Debug for PacketNumberSpaceSet { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + f.write_str("(")?; + for sp in PacketNumberSpace::iter() { + if self[*sp] { + if !first { + f.write_str("+")?; + first = false; + } + std::fmt::Display::fmt(sp, f)?; + } + } + f.write_str(")") + } +} + +#[derive(Debug, Clone)] +pub struct SentPacket { + pub pt: PacketType, + pub pn: PacketNumber, + ack_eliciting: bool, + pub time_sent: Instant, + primary_path: bool, + pub tokens: Vec<RecoveryToken>, + + time_declared_lost: Option<Instant>, + /// After a PTO, this is true when the packet has been released. + pto: bool, + + pub size: usize, +} + +impl SentPacket { + pub fn new( + pt: PacketType, + pn: PacketNumber, + time_sent: Instant, + ack_eliciting: bool, + tokens: Vec<RecoveryToken>, + size: usize, + ) -> Self { + Self { + pt, + pn, + time_sent, + ack_eliciting, + primary_path: true, + tokens, + time_declared_lost: None, + pto: false, + size, + } + } + + /// Returns `true` if the packet will elicit an ACK. + pub fn ack_eliciting(&self) -> bool { + self.ack_eliciting + } + + /// Returns `true` if the packet was sent on the primary path. 
+ pub fn on_primary_path(&self) -> bool { + self.primary_path + } + + /// Clears the flag that had this packet on the primary path. + /// Used when migrating to clear out state. + pub fn clear_primary_path(&mut self) { + self.primary_path = false; + } + + /// Whether the packet has been declared lost. + pub fn lost(&self) -> bool { + self.time_declared_lost.is_some() + } + + /// Whether accounting for the loss or acknowledgement in the + /// congestion controller is pending. + /// Returns `true` if the packet counts as being "in flight", + /// and has not previously been declared lost. + /// Note that this should count packets that contain only ACK and PADDING, + /// but we don't send PADDING, so we don't track that. + pub fn cc_outstanding(&self) -> bool { + self.ack_eliciting() && self.on_primary_path() && !self.lost() + } + + /// Whether the packet should be tracked as in-flight. + pub fn cc_in_flight(&self) -> bool { + self.ack_eliciting() && self.on_primary_path() + } + + /// Declare the packet as lost. Returns `true` if this is the first time. + pub fn declare_lost(&mut self, now: Instant) -> bool { + if self.lost() { + false + } else { + self.time_declared_lost = Some(now); + true + } + } + + /// Ask whether this tracked packet has been declared lost for long enough + /// that it can be expired and no longer tracked. + pub fn expired(&self, now: Instant, expiration_period: Duration) -> bool { + self.time_declared_lost + .map_or(false, |loss_time| (loss_time + expiration_period) <= now) + } + + /// Whether the packet contents were cleared out after a PTO. + pub fn pto_fired(&self) -> bool { + self.pto + } + + /// On PTO, we need to get the recovery tokens so that we can ensure that + /// the frames we sent can be sent again in the PTO packet(s). Do that just once. 
+ pub fn pto(&mut self) -> bool { + if self.pto || self.lost() { + false + } else { + self.pto = true; + true + } + } +} + +impl std::fmt::Display for PacketNumberSpace { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(match self { + Self::Initial => "in", + Self::Handshake => "hs", + Self::ApplicationData => "ap", + }) + } +} + +/// `InsertionResult` tracks whether something was inserted for `PacketRange::add()`. +pub enum InsertionResult { + Largest, + Smallest, + NotInserted, +} + +#[derive(Clone, Debug, Default)] +pub struct PacketRange { + largest: PacketNumber, + smallest: PacketNumber, + ack_needed: bool, +} + +impl PacketRange { + /// Make a single packet range. + pub fn new(pn: PacketNumber) -> Self { + Self { + largest: pn, + smallest: pn, + ack_needed: true, + } + } + + /// Get the number of acknowleged packets in the range. + pub fn len(&self) -> u64 { + self.largest - self.smallest + 1 + } + + /// Returns whether this needs to be sent. + pub fn ack_needed(&self) -> bool { + self.ack_needed + } + + /// Return whether the given number is in the range. + pub fn contains(&self, pn: PacketNumber) -> bool { + (pn >= self.smallest) && (pn <= self.largest) + } + + /// Maybe add a packet number to the range. Returns true if it was added + /// at the small end (which indicates that this might need merging with a + /// preceding range). + pub fn add(&mut self, pn: PacketNumber) -> InsertionResult { + assert!(!self.contains(pn)); + // Only insert if this is adjacent the current range. + if (self.largest + 1) == pn { + qtrace!([self], "Adding largest {}", pn); + self.largest += 1; + self.ack_needed = true; + InsertionResult::Largest + } else if self.smallest == (pn + 1) { + qtrace!([self], "Adding smallest {}", pn); + self.smallest -= 1; + self.ack_needed = true; + InsertionResult::Smallest + } else { + InsertionResult::NotInserted + } + } + + /// Maybe merge a higher-numbered range into this. 
+ fn merge_larger(&mut self, other: &Self) { + qinfo!([self], "Merging {}", other); + // This only works if they are immediately adjacent. + assert_eq!(self.largest + 1, other.smallest); + + self.largest = other.largest; + self.ack_needed = self.ack_needed || other.ack_needed; + } + + /// When a packet containing the range `other` is acknowledged, + /// clear the `ack_needed` attribute on this. + /// Requires that other is equal to this, or a larger range. + pub fn acknowledged(&mut self, other: &Self) { + if (other.smallest <= self.smallest) && (other.largest >= self.largest) { + self.ack_needed = false; + } + } +} + +impl ::std::fmt::Display for PacketRange { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}->{}", self.largest, self.smallest) + } +} + +/// The ACK delay we use. +pub const DEFAULT_ACK_DELAY: Duration = Duration::from_millis(20); // 20ms +/// The default number of in-order packets we will receive after +/// largest acknowledged without sending an immediate acknowledgment. +pub const DEFAULT_ACK_PACKET_TOLERANCE: PacketNumber = 1; +const MAX_TRACKED_RANGES: usize = 32; +const MAX_ACKS_PER_FRAME: usize = 32; + +/// A structure that tracks what was included in an ACK. +#[derive(Debug, Clone)] +pub struct AckToken { + space: PacketNumberSpace, + ranges: Vec<PacketRange>, +} + +/// A structure that tracks what packets have been received, +/// and what needs acknowledgement for a packet number space. +#[derive(Debug)] +pub struct RecvdPackets { + space: PacketNumberSpace, + ranges: VecDeque<PacketRange>, + /// The packet number of the lowest number packet that we are tracking. + min_tracked: PacketNumber, + /// The time we got the largest acknowledged. + largest_pn_time: Option<Instant>, + /// The time that we should be sending an ACK. + ack_time: Option<Instant>, + /// The time we last sent an ACK. + last_ack_time: Option<Instant>, + /// The current ACK frequency sequence number. 
+ ack_frequency_seqno: u64, + /// The time to delay after receiving the first packet that is + /// not immediately acknowledged. + ack_delay: Duration, + /// The number of ack-eliciting packets that have been received, but + /// not acknowledged. + unacknowledged_count: PacketNumber, + /// The number of contiguous packets that can be received without + /// acknowledging immediately. + unacknowledged_tolerance: PacketNumber, + /// Whether we are ignoring packets that arrive out of order + /// for the purposes of generating immediate acknowledgment. + ignore_order: bool, +} + +impl RecvdPackets { + /// Make a new `RecvdPackets` for the indicated packet number space. + pub fn new(space: PacketNumberSpace) -> Self { + Self { + space, + ranges: VecDeque::new(), + min_tracked: 0, + largest_pn_time: None, + ack_time: None, + last_ack_time: None, + ack_frequency_seqno: 0, + ack_delay: DEFAULT_ACK_DELAY, + unacknowledged_count: 0, + unacknowledged_tolerance: DEFAULT_ACK_PACKET_TOLERANCE, + ignore_order: false, + } + } + + /// Get the time at which the next ACK should be sent. + pub fn ack_time(&self) -> Option<Instant> { + self.ack_time + } + + /// Update acknowledgment delay parameters. + pub fn ack_freq( + &mut self, + seqno: u64, + tolerance: PacketNumber, + delay: Duration, + ignore_order: bool, + ) { + // Yes, this means that we will overwrite values if a sequence number is + // reused, but that is better than using an `Option<PacketNumber>` + // when it will always be `Some`. + if seqno >= self.ack_frequency_seqno { + self.ack_frequency_seqno = seqno; + self.unacknowledged_tolerance = tolerance; + self.ack_delay = delay; + self.ignore_order = ignore_order; + } + } + + /// Returns true if an ACK frame should be sent now. + fn ack_now(&self, now: Instant, rtt: Duration) -> bool { + // If ack_time is Some, then we have something to acknowledge. 
+ // In that case, either ack because `now >= ack_time`, or + // because it is more than an RTT since the last time we sent an ack. + self.ack_time.map_or(false, |next| { + next <= now || self.last_ack_time.map_or(false, |last| last + rtt <= now) + }) + } + + // A simple addition of a packet number to the tracked set. + // This doesn't do a binary search on the assumption that + // new packets will generally be added to the start of the list. + fn add(&mut self, pn: PacketNumber) { + for i in 0..self.ranges.len() { + match self.ranges[i].add(pn) { + InsertionResult::Largest => return, + InsertionResult::Smallest => { + // If this was the smallest, it might have filled a gap. + let nxt = i + 1; + if (nxt < self.ranges.len()) && (pn - 1 == self.ranges[nxt].largest) { + let larger = self.ranges.remove(i).unwrap(); + self.ranges[i].merge_larger(&larger); + } + return; + } + InsertionResult::NotInserted => { + if self.ranges[i].largest < pn { + self.ranges.insert(i, PacketRange::new(pn)); + return; + } + } + } + } + self.ranges.push_back(PacketRange::new(pn)); + } + + fn trim_ranges(&mut self) { + // Limit the number of ranges that are tracked to MAX_TRACKED_RANGES. + if self.ranges.len() > MAX_TRACKED_RANGES { + let oldest = self.ranges.pop_back().unwrap(); + if oldest.ack_needed { + qwarn!([self], "Dropping unacknowledged ACK range: {}", oldest); + // TODO(mt) Record some statistics about this so we can tune MAX_TRACKED_RANGES. + } else { + qdebug!([self], "Drop ACK range: {}", oldest); + } + self.min_tracked = oldest.largest + 1; + } + } + + /// Add the packet to the tracked set. + /// Return true if the packet was the largest received so far. 
+ pub fn set_received(&mut self, now: Instant, pn: PacketNumber, ack_eliciting: bool) -> bool { + let next_in_order_pn = self.ranges.front().map_or(0, |r| r.largest + 1); + qdebug!([self], "received {}, next: {}", pn, next_in_order_pn); + + self.add(pn); + self.trim_ranges(); + + // The new addition was the largest, so update the time we use for calculating ACK delay. + let largest = if pn >= next_in_order_pn { + self.largest_pn_time = Some(now); + true + } else { + false + }; + + if ack_eliciting { + self.unacknowledged_count += 1; + + let immediate_ack = self.space != PacketNumberSpace::ApplicationData + || (pn != next_in_order_pn && !self.ignore_order) + || self.unacknowledged_count > self.unacknowledged_tolerance; + + let ack_time = if immediate_ack { + now + } else { + // Note that `ack_delay` can change and that won't take effect if + // we are waiting on the previous delay timer. + // If ACK delay increases, we might send an ACK a bit early; + // if ACK delay decreases, we might send an ACK a bit later. + // We could use min() here, but change is rare and the size + // of the change is very small. + self.ack_time.unwrap_or_else(|| now + self.ack_delay) + }; + qdebug!([self], "Set ACK timer to {:?}", ack_time); + self.ack_time = Some(ack_time); + } + largest + } + + /// If we just received a PING frame, we should immediately acknowledge. + pub fn immediate_ack(&mut self, now: Instant) { + self.ack_time = Some(now); + qdebug!([self], "immediate_ack at {:?}", now); + } + + /// Check if the packet is a duplicate. + pub fn is_duplicate(&self, pn: PacketNumber) -> bool { + if pn < self.min_tracked { + return true; + } + self.ranges + .iter() + .take_while(|r| pn <= r.largest) + .any(|r| r.contains(pn)) + } + + /// Mark the given range as having been acknowledged. 
+ pub fn acknowledged(&mut self, acked: &[PacketRange]) { + let mut range_iter = self.ranges.iter_mut(); + let mut cur = range_iter.next().expect("should have at least one range"); + for ack in acked { + while cur.smallest > ack.largest { + cur = match range_iter.next() { + Some(c) => c, + None => return, + }; + } + cur.acknowledged(ack); + } + } + + /// Generate an ACK frame for this packet number space. + /// + /// Unlike other frame generators this doesn't modify the underlying instance + /// to track what has been sent. This only clears the delayed ACK timer. + /// + /// When sending ACKs, we want to always send the most recent ranges, + /// even if they have been sent in other packets. + /// + /// We don't send ranges that have been acknowledged, but they still need + /// to be tracked so that duplicates can be detected. + fn write_frame( + &mut self, + now: Instant, + rtt: Duration, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + // The worst possible ACK frame, assuming only one range. + // Note that this assumes one byte for the type and count of extra ranges. + const LONGEST_ACK_HEADER: usize = 1 + 8 + 8 + 1 + 8; + + // Check that we aren't delaying ACKs. + if !self.ack_now(now, rtt) { + return; + } + + // Drop extra ACK ranges to fit the available space. Do this based on + // a worst-case estimate of frame size for simplicity. + // + // When congestion limited, ACK-only packets are 255 bytes at most + // (`recovery::ACK_ONLY_SIZE_LIMIT - 1`). This results in limiting the + // ranges to 13 here. + let max_ranges = if let Some(avail) = builder.remaining().checked_sub(LONGEST_ACK_HEADER) { + // Apply a hard maximum to keep plenty of space for other stuff. 
+ min(1 + (avail / 16), MAX_ACKS_PER_FRAME) + } else { + return; + }; + + let ranges = self + .ranges + .iter() + .filter(|r| r.ack_needed()) + .take(max_ranges) + .cloned() + .collect::<Vec<_>>(); + + builder.encode_varint(crate::frame::FRAME_TYPE_ACK); + let mut iter = ranges.iter(); + let Some(first) = iter.next() else { return }; + builder.encode_varint(first.largest); + stats.largest_acknowledged = first.largest; + stats.ack += 1; + + let elapsed = now.duration_since(self.largest_pn_time.unwrap()); + // We use the default exponent, so delay is in multiples of 8 microseconds. + let ack_delay = u64::try_from(elapsed.as_micros() / 8).unwrap_or(u64::MAX); + let ack_delay = min((1 << 62) - 1, ack_delay); + builder.encode_varint(ack_delay); + builder.encode_varint(u64::try_from(ranges.len() - 1).unwrap()); // extra ranges + builder.encode_varint(first.len() - 1); // first range + + let mut last = first.smallest; + for r in iter { + // the difference must be at least 2 because 0-length gaps, + // (difference 1) are illegal. + builder.encode_varint(last - r.largest - 2); // Gap + builder.encode_varint(r.len() - 1); // Range + last = r.smallest; + } + + // We've sent an ACK, reset the timer. + self.ack_time = None; + self.last_ack_time = Some(now); + self.unacknowledged_count = 0; + + tokens.push(RecoveryToken::Ack(AckToken { + space: self.space, + ranges, + })); + } +} + +impl ::std::fmt::Display for RecvdPackets { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Recvd-{}", self.space) + } +} + +#[derive(Debug)] +pub struct AckTracker { + /// This stores information about received packets in *reverse* order + /// by spaces. Why reverse? Because we ultimately only want to keep + /// `ApplicationData` and this allows us to drop other spaces easily. 
+ spaces: SmallVec<[RecvdPackets; 1]>, +} + +impl AckTracker { + pub fn drop_space(&mut self, space: PacketNumberSpace) { + let sp = match space { + PacketNumberSpace::Initial => self.spaces.pop(), + PacketNumberSpace::Handshake => { + let sp = self.spaces.pop(); + self.spaces.shrink_to_fit(); + sp + } + PacketNumberSpace::ApplicationData => panic!("discarding application space"), + }; + assert_eq!(sp.unwrap().space, space, "dropping spaces out of order"); + } + + pub fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut RecvdPackets> { + self.spaces.get_mut(match space { + PacketNumberSpace::ApplicationData => 0, + PacketNumberSpace::Handshake => 1, + PacketNumberSpace::Initial => 2, + }) + } + + pub fn ack_freq( + &mut self, + seqno: u64, + tolerance: PacketNumber, + delay: Duration, + ignore_order: bool, + ) { + // Only ApplicationData ever delays ACK. + self.get_mut(PacketNumberSpace::ApplicationData) + .unwrap() + .ack_freq(seqno, tolerance, delay, ignore_order); + } + + // Force an ACK to be generated immediately (a PING was received). + pub fn immediate_ack(&mut self, now: Instant) { + self.get_mut(PacketNumberSpace::ApplicationData) + .unwrap() + .immediate_ack(now); + } + + /// Determine the earliest time that an ACK might be needed. + pub fn ack_time(&self, now: Instant) -> Option<Instant> { + for recvd in &self.spaces { + qtrace!("ack_time for {} = {:?}", recvd.space, recvd.ack_time()); + } + + if self.spaces.len() == 1 { + self.spaces[0].ack_time() + } else { + // Ignore any time that is in the past relative to `now`. + // That is something of a hack, but there are cases where we can't send ACK + // frames for all spaces, which can mean that one space is stuck in the past. + // That isn't a problem because we guarantee that earlier spaces will always + // be able to send ACK frames. 
+ self.spaces + .iter() + .filter_map(|recvd| recvd.ack_time().filter(|t| *t > now)) + .min() + } + } + + pub fn acked(&mut self, token: &AckToken) { + if let Some(space) = self.get_mut(token.space) { + space.acknowledged(&token.ranges); + } + } + + pub(crate) fn write_frame( + &mut self, + pn_space: PacketNumberSpace, + now: Instant, + rtt: Duration, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if let Some(space) = self.get_mut(pn_space) { + space.write_frame(now, rtt, builder, tokens, stats); + } + } +} + +impl Default for AckTracker { + fn default() -> Self { + Self { + spaces: smallvec![ + RecvdPackets::new(PacketNumberSpace::ApplicationData), + RecvdPackets::new(PacketNumberSpace::Handshake), + RecvdPackets::new(PacketNumberSpace::Initial), + ], + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use lazy_static::lazy_static; + use neqo_common::Encoder; + + use super::{ + AckTracker, Duration, Instant, PacketNumberSpace, PacketNumberSpaceSet, RecoveryToken, + RecvdPackets, MAX_TRACKED_RANGES, + }; + use crate::{ + frame::Frame, + packet::{PacketBuilder, PacketNumber}, + stats::FrameStats, + }; + + const RTT: Duration = Duration::from_millis(100); + lazy_static! { + static ref NOW: Instant = Instant::now(); + } + + fn test_ack_range(pns: &[PacketNumber], nranges: usize) { + let mut rp = RecvdPackets::new(PacketNumberSpace::Initial); // Any space will do. + let mut packets = HashSet::new(); + + for pn in pns { + rp.set_received(*NOW, *pn, true); + packets.insert(*pn); + } + + assert_eq!(rp.ranges.len(), nranges); + + // Check that all these packets will be detected as duplicates. + for pn in pns { + assert!(rp.is_duplicate(*pn)); + } + + // Check that the ranges decrease monotonically and don't overlap. 
+ let mut iter = rp.ranges.iter(); + let mut last = iter.next().expect("should have at least one"); + for n in iter { + assert!(n.largest + 1 < last.smallest); + last = n; + } + + // Check that the ranges include the right values. + let mut in_ranges = HashSet::new(); + for range in &rp.ranges { + for included in range.smallest..=range.largest { + in_ranges.insert(included); + } + } + assert_eq!(packets, in_ranges); + } + + #[test] + fn pn0() { + test_ack_range(&[0], 1); + } + + #[test] + fn pn1() { + test_ack_range(&[1], 1); + } + + #[test] + fn two_ranges() { + test_ack_range(&[0, 1, 2, 5, 6, 7], 2); + } + + #[test] + fn fill_in_range() { + test_ack_range(&[0, 1, 2, 5, 6, 7, 3, 4], 1); + } + + #[test] + fn too_many_ranges() { + let mut rp = RecvdPackets::new(PacketNumberSpace::Initial); // Any space will do. + + // This will add one too many disjoint ranges. + for i in 0..=MAX_TRACKED_RANGES { + rp.set_received(*NOW, (i * 2) as u64, true); + } + + assert_eq!(rp.ranges.len(), MAX_TRACKED_RANGES); + assert_eq!(rp.ranges.back().unwrap().largest, 2); + + // Even though the range was dropped, we still consider it a duplicate. + assert!(rp.is_duplicate(0)); + assert!(!rp.is_duplicate(1)); + assert!(rp.is_duplicate(2)); + } + + #[test] + fn ack_delay() { + const COUNT: PacketNumber = 9; + const DELAY: Duration = Duration::from_millis(7); + // Only application data packets are delayed. + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + assert!(rp.ack_time().is_none()); + assert!(!rp.ack_now(*NOW, RTT)); + + rp.ack_freq(0, COUNT, DELAY, false); + + // Some packets won't cause an ACK to be needed. + for i in 0..COUNT { + rp.set_received(*NOW, i, true); + assert_eq!(Some(*NOW + DELAY), rp.ack_time()); + assert!(!rp.ack_now(*NOW, RTT)); + assert!(rp.ack_now(*NOW + DELAY, RTT)); + } + + // Exceeding COUNT will move the ACK time to now. 
+ rp.set_received(*NOW, COUNT, true); + assert_eq!(Some(*NOW), rp.ack_time()); + assert!(rp.ack_now(*NOW, RTT)); + } + + #[test] + fn no_ack_delay() { + for space in &[PacketNumberSpace::Initial, PacketNumberSpace::Handshake] { + let mut rp = RecvdPackets::new(*space); + assert!(rp.ack_time().is_none()); + assert!(!rp.ack_now(*NOW, RTT)); + + // Any packet in these spaces is acknowledged straight away. + rp.set_received(*NOW, 0, true); + assert_eq!(Some(*NOW), rp.ack_time()); + assert!(rp.ack_now(*NOW, RTT)); + } + } + + #[test] + fn ooo_no_ack_delay_new() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + assert!(rp.ack_time().is_none()); + assert!(!rp.ack_now(*NOW, RTT)); + + // Anything other than packet 0 is acknowledged immediately. + rp.set_received(*NOW, 1, true); + assert_eq!(Some(*NOW), rp.ack_time()); + assert!(rp.ack_now(*NOW, RTT)); + } + + fn write_frame_at(rp: &mut RecvdPackets, now: Instant) { + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut stats = FrameStats::default(); + let mut tokens = Vec::new(); + rp.write_frame(now, RTT, &mut builder, &mut tokens, &mut stats); + assert!(!tokens.is_empty()); + assert_eq!(stats.ack, 1); + } + + fn write_frame(rp: &mut RecvdPackets) { + write_frame_at(rp, *NOW); + } + + #[test] + fn ooo_no_ack_delay_fill() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + rp.set_received(*NOW, 1, true); + write_frame(&mut rp); + + // Filling in behind the largest acknowledged causes immediate ACK. + rp.set_received(*NOW, 0, true); + write_frame(&mut rp); + + // Receiving the next packet won't elicit an ACK. + rp.set_received(*NOW, 2, true); + assert!(!rp.ack_now(*NOW, RTT)); + } + + #[test] + fn immediate_ack_after_rtt() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + rp.set_received(*NOW, 1, true); + write_frame(&mut rp); + + // Filling in behind the largest acknowledged causes immediate ACK. 
+ rp.set_received(*NOW, 0, true); + write_frame(&mut rp); + + // A new packet ordinarily doesn't result in an ACK, but this time it does. + rp.set_received(*NOW + RTT, 2, true); + write_frame_at(&mut rp, *NOW + RTT); + } + + #[test] + fn ooo_no_ack_delay_threshold_new() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + + // Set tolerance to 2 and then it takes three packets. + rp.ack_freq(0, 2, Duration::from_millis(10), true); + + rp.set_received(*NOW, 1, true); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 2, true); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 3, true); + assert_eq!(Some(*NOW), rp.ack_time()); + } + + #[test] + fn ooo_no_ack_delay_threshold_gap() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + rp.set_received(*NOW, 1, true); + write_frame(&mut rp); + + // Set tolerance to 2 and then it takes three packets. + rp.ack_freq(0, 2, Duration::from_millis(10), true); + + rp.set_received(*NOW, 3, true); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 4, true); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 5, true); + assert_eq!(Some(*NOW), rp.ack_time()); + } + + /// Test that an in-order packet that is not ack-eliciting doesn't + /// increase the number of packets needed to cause an ACK. + #[test] + fn non_ack_eliciting_skip() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + rp.ack_freq(0, 1, Duration::from_millis(10), true); + + // This should be ignored. + rp.set_received(*NOW, 0, false); + assert_ne!(Some(*NOW), rp.ack_time()); + // Skip 1 (it has no effect). + rp.set_received(*NOW, 2, true); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 3, true); + assert_eq!(Some(*NOW), rp.ack_time()); + } + + /// If a packet that is not ack-eliciting is reordered, that's fine too. 
+ #[test] + fn non_ack_eliciting_reorder() { + let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData); + rp.ack_freq(0, 1, Duration::from_millis(10), false); + + // These are out of order, but they are not ack-eliciting. + rp.set_received(*NOW, 1, false); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 0, false); + assert_ne!(Some(*NOW), rp.ack_time()); + + // These are in order. + rp.set_received(*NOW, 2, true); + assert_ne!(Some(*NOW), rp.ack_time()); + rp.set_received(*NOW, 3, true); + assert_eq!(Some(*NOW), rp.ack_time()); + } + + #[test] + fn aggregate_ack_time() { + const DELAY: Duration = Duration::from_millis(17); + let mut tracker = AckTracker::default(); + tracker.ack_freq(0, 1, DELAY, false); + // This packet won't trigger an ACK. + tracker + .get_mut(PacketNumberSpace::Handshake) + .unwrap() + .set_received(*NOW, 0, false); + assert_eq!(None, tracker.ack_time(*NOW)); + + // This should be delayed. + tracker + .get_mut(PacketNumberSpace::ApplicationData) + .unwrap() + .set_received(*NOW, 0, true); + assert_eq!(Some(*NOW + DELAY), tracker.ack_time(*NOW)); + + // This should move the time forward. 
+ let later = *NOW + (DELAY / 2); + tracker + .get_mut(PacketNumberSpace::Initial) + .unwrap() + .set_received(later, 0, true); + assert_eq!(Some(later), tracker.ack_time(*NOW)); + } + + #[test] + #[should_panic(expected = "discarding application space")] + fn drop_app() { + let mut tracker = AckTracker::default(); + tracker.drop_space(PacketNumberSpace::ApplicationData); + } + + #[test] + #[should_panic(expected = "dropping spaces out of order")] + fn drop_out_of_order() { + let mut tracker = AckTracker::default(); + tracker.drop_space(PacketNumberSpace::Handshake); + } + + #[test] + fn drop_spaces() { + let mut tracker = AckTracker::default(); + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + tracker + .get_mut(PacketNumberSpace::Initial) + .unwrap() + .set_received(*NOW, 0, true); + // The reference time for `ack_time` has to be in the past or we filter out the timer. + assert!(tracker + .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap()) + .is_some()); + + let mut tokens = Vec::new(); + let mut stats = FrameStats::default(); + tracker.write_frame( + PacketNumberSpace::Initial, + *NOW, + RTT, + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.ack, 1); + + // Mark another packet as received so we have cause to send another ACK in that space. + tracker + .get_mut(PacketNumberSpace::Initial) + .unwrap() + .set_received(*NOW, 1, true); + assert!(tracker + .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap()) + .is_some()); + + // Now drop that space. + tracker.drop_space(PacketNumberSpace::Initial); + + assert!(tracker.get_mut(PacketNumberSpace::Initial).is_none()); + assert!(tracker + .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap()) + .is_none()); + tracker.write_frame( + PacketNumberSpace::Initial, + *NOW, + RTT, + &mut builder, + &mut tokens, + &mut stats, + ); + assert_eq!(stats.ack, 1); + if let RecoveryToken::Ack(tok) = &tokens[0] { + tracker.acked(tok); // Should be a noop. 
+ } else { + panic!("not an ACK token"); + } + } + + #[test] + fn no_room_for_ack() { + let mut tracker = AckTracker::default(); + tracker + .get_mut(PacketNumberSpace::Initial) + .unwrap() + .set_received(*NOW, 0, true); + assert!(tracker + .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap()) + .is_some()); + + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + builder.set_limit(10); + + let mut stats = FrameStats::default(); + tracker.write_frame( + PacketNumberSpace::Initial, + *NOW, + RTT, + &mut builder, + &mut Vec::new(), + &mut stats, + ); + assert_eq!(stats.ack, 0); + assert_eq!(builder.len(), 1); // Only the short packet header has been added. + } + + #[test] + fn no_room_for_extra_range() { + let mut tracker = AckTracker::default(); + tracker + .get_mut(PacketNumberSpace::Initial) + .unwrap() + .set_received(*NOW, 0, true); + tracker + .get_mut(PacketNumberSpace::Initial) + .unwrap() + .set_received(*NOW, 2, true); + assert!(tracker + .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap()) + .is_some()); + + let mut builder = PacketBuilder::short(Encoder::new(), false, []); + builder.set_limit(32); + + let mut stats = FrameStats::default(); + tracker.write_frame( + PacketNumberSpace::Initial, + *NOW, + RTT, + &mut builder, + &mut Vec::new(), + &mut stats, + ); + assert_eq!(stats.ack, 1); + + let mut dec = builder.as_decoder(); + _ = dec.decode_byte().unwrap(); // Skip the short header. + let frame = Frame::decode(&mut dec).unwrap(); + if let Frame::Ack { ack_ranges, .. } = frame { + assert_eq!(ack_ranges.len(), 0); + } else { + panic!("not an ACK!"); + } + } + + #[test] + fn ack_time_elapsed() { + let mut tracker = AckTracker::default(); + + // While we have multiple PN spaces, we ignore ACK timers from the past. + // Send out of order to cause the delayed ack timer to be set to `*NOW`. 
+ tracker + .get_mut(PacketNumberSpace::ApplicationData) + .unwrap() + .set_received(*NOW, 3, true); + assert!(tracker.ack_time(*NOW + Duration::from_millis(1)).is_none()); + + // When we are reduced to one space, that filter is off. + tracker.drop_space(PacketNumberSpace::Initial); + tracker.drop_space(PacketNumberSpace::Handshake); + assert_eq!( + tracker.ack_time(*NOW + Duration::from_millis(1)), + Some(*NOW) + ); + } + + #[test] + fn pnspaceset_default() { + let set = PacketNumberSpaceSet::default(); + assert!(!set[PacketNumberSpace::Initial]); + assert!(!set[PacketNumberSpace::Handshake]); + assert!(!set[PacketNumberSpace::ApplicationData]); + } + + #[test] + fn pnspaceset_from() { + let set = PacketNumberSpaceSet::from(&[PacketNumberSpace::Initial]); + assert!(set[PacketNumberSpace::Initial]); + assert!(!set[PacketNumberSpace::Handshake]); + assert!(!set[PacketNumberSpace::ApplicationData]); + + let set = + PacketNumberSpaceSet::from(&[PacketNumberSpace::Handshake, PacketNumberSpace::Initial]); + assert!(set[PacketNumberSpace::Initial]); + assert!(set[PacketNumberSpace::Handshake]); + assert!(!set[PacketNumberSpace::ApplicationData]); + + let set = PacketNumberSpaceSet::from(&[ + PacketNumberSpace::ApplicationData, + PacketNumberSpace::ApplicationData, + ]); + assert!(!set[PacketNumberSpace::Initial]); + assert!(!set[PacketNumberSpace::Handshake]); + assert!(set[PacketNumberSpace::ApplicationData]); + } + + #[test] + fn pnspaceset_copy() { + let set = PacketNumberSpaceSet::from(&[ + PacketNumberSpace::Handshake, + PacketNumberSpace::ApplicationData, + ]); + let copy = set; + assert!(!copy[PacketNumberSpace::Initial]); + assert!(copy[PacketNumberSpace::Handshake]); + assert!(copy[PacketNumberSpace::ApplicationData]); + } +} diff --git a/third_party/rust/neqo-transport/src/version.rs b/third_party/rust/neqo-transport/src/version.rs new file mode 100644 index 0000000000..13db0bf024 --- /dev/null +++ b/third_party/rust/neqo-transport/src/version.rs @@ -0,0 +1,235 
@@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::convert::TryFrom; + +use neqo_common::qdebug; + +use crate::{Error, Res}; + +pub type WireVersion = u32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Version { + Version2, + Version1, + Draft29, + Draft30, + Draft31, + Draft32, +} + +impl Version { + pub const fn wire_version(self) -> WireVersion { + match self { + Self::Version2 => 0x6b33_43cf, + Self::Version1 => 1, + Self::Draft29 => 0xff00_0000 + 29, + Self::Draft30 => 0xff00_0000 + 30, + Self::Draft31 => 0xff00_0000 + 31, + Self::Draft32 => 0xff00_0000 + 32, + } + } + + pub(crate) fn initial_salt(self) -> &'static [u8] { + const INITIAL_SALT_V2: &[u8] = &[ + 0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, + 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9, + ]; + const INITIAL_SALT_V1: &[u8] = &[ + 0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, + 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a, + ]; + const INITIAL_SALT_29_32: &[u8] = &[ + 0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, + 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99, + ]; + match self { + Self::Version2 => INITIAL_SALT_V2, + Self::Version1 => INITIAL_SALT_V1, + Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32 => INITIAL_SALT_29_32, + } + } + + pub(crate) fn label_prefix(self) -> &'static str { + match self { + Self::Version2 => "quicv2 ", + Self::Version1 | Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32 => { + "quic " + } + } + } + + pub(crate) fn retry_secret(self) -> &'static [u8] { + const RETRY_SECRET_V2: &[u8] = &[ + 0xc4, 0xdd, 0x24, 0x84, 0xd6, 0x81, 0xae, 0xfa, 0x4f, 0xf4, 0xd6, 0x9c, 0x2c, 0x20, + 0x29, 
0x99, 0x84, 0xa7, 0x65, 0xa5, 0xd3, 0xc3, 0x19, 0x82, 0xf3, 0x8f, 0xc7, 0x41, + 0x62, 0x15, 0x5e, 0x9f, + ]; + const RETRY_SECRET_V1: &[u8] = &[ + 0xd9, 0xc9, 0x94, 0x3e, 0x61, 0x01, 0xfd, 0x20, 0x00, 0x21, 0x50, 0x6b, 0xcc, 0x02, + 0x81, 0x4c, 0x73, 0x03, 0x0f, 0x25, 0xc7, 0x9d, 0x71, 0xce, 0x87, 0x6e, 0xca, 0x87, + 0x6e, 0x6f, 0xca, 0x8e, + ]; + const RETRY_SECRET_29: &[u8] = &[ + 0x8b, 0x0d, 0x37, 0xeb, 0x85, 0x35, 0x02, 0x2e, 0xbc, 0x8d, 0x76, 0xa2, 0x07, 0xd8, + 0x0d, 0xf2, 0x26, 0x46, 0xec, 0x06, 0xdc, 0x80, 0x96, 0x42, 0xc3, 0x0a, 0x8b, 0xaa, + 0x2b, 0xaa, 0xff, 0x4c, + ]; + match self { + Self::Version2 => RETRY_SECRET_V2, + Self::Version1 => RETRY_SECRET_V1, + Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32 => RETRY_SECRET_29, + } + } + + pub(crate) fn is_draft(self) -> bool { + matches!( + self, + Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32, + ) + } + + /// Determine if `self` can be upgraded to `other` compatibly. + pub fn is_compatible(self, other: Self) -> bool { + self == other + || matches!( + (self, other), + (Self::Version1, Self::Version2) | (Self::Version2, Self::Version1) + ) + } + + pub fn all() -> Vec<Self> { + vec![ + Self::Version2, + Self::Version1, + Self::Draft32, + Self::Draft31, + Self::Draft30, + Self::Draft29, + ] + } + + pub fn compatible<'a>( + self, + all: impl IntoIterator<Item = &'a Self>, + ) -> impl Iterator<Item = &'a Self> { + all.into_iter().filter(move |&v| self.is_compatible(*v)) + } +} + +impl Default for Version { + fn default() -> Self { + Self::Version1 + } +} + +impl TryFrom<WireVersion> for Version { + type Error = Error; + + fn try_from(wire: WireVersion) -> Res<Self> { + if wire == 1 { + Ok(Self::Version1) + } else if wire == 0x6b33_43cf { + Ok(Self::Version2) + } else if wire == 0xff00_0000 + 29 { + Ok(Self::Draft29) + } else if wire == 0xff00_0000 + 30 { + Ok(Self::Draft30) + } else if wire == 0xff00_0000 + 31 { + Ok(Self::Draft31) + } else if wire == 0xff00_0000 + 32 { + 
Ok(Self::Draft32) + } else { + Err(Error::VersionNegotiation) + } + } +} + +#[derive(Debug, Clone)] +pub struct VersionConfig { + /// The version that a client uses to establish a connection. + /// + /// For a client, this is the version that is sent out in an Initial packet. + /// A client that resumes will set this to the version from the original + /// connection. + /// A client that handles a Version Negotiation packet will be initialized with + /// a version chosen from the packet, but it will then have this value overridden + /// to match the original configuration so that the version negotiation can be + /// authenticated. + /// + /// For a server `Connection`, this is the only type of Initial packet that + /// can be accepted; the correct value is set by `Server`, see below. + /// + /// For a `Server`, this value is not used; if an Initial packet is received + /// in a supported version (as listed in `versions`), new instances of + /// `Connection` will be created with this value set to match what was received. + /// + /// An invariant here is that this version is always listed in `all`. + initial: Version, + /// The set of versions that are enabled, in preference order. For a server, + /// only the relative order of compatible versions matters. + all: Vec<Version>, +} + +impl VersionConfig { + pub fn new(initial: Version, all: Vec<Version>) -> Self { + assert!(all.contains(&initial)); + Self { initial, all } + } + + pub fn initial(&self) -> Version { + self.initial + } + + pub fn all(&self) -> &[Version] { + &self.all + } + + /// Overwrite the initial value; used by the `Server` when handling new connections + /// and by the client on resumption. 
+ pub(crate) fn set_initial(&mut self, initial: Version) { + qdebug!( + "Overwrite initial version {:?} ==> {:?}", + self.initial, + initial + ); + assert!(self.all.contains(&initial)); + self.initial = initial; + } + + pub fn compatible(&self) -> impl Iterator<Item = &Version> { + self.initial.compatible(&self.all) + } + + fn find_preferred<'a>( + preferences: impl IntoIterator<Item = &'a Version>, + vn: &[WireVersion], + ) -> Option<Version> { + for v in preferences { + if vn.contains(&v.wire_version()) { + return Some(*v); + } + } + None + } + + /// Determine the preferred version based on a version negotiation packet. + pub(crate) fn preferred(&self, vn: &[WireVersion]) -> Option<Version> { + Self::find_preferred(&self.all, vn) + } + + /// Determine the preferred version based on a set of compatible versions. + pub(crate) fn preferred_compatible(&self, vn: &[WireVersion]) -> Option<Version> { + Self::find_preferred(self.compatible(), vn) + } +} + +impl Default for VersionConfig { + fn default() -> Self { + Self::new(Version::default(), Version::all()) + } +} |