diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/neqo-transport/src | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 86.0.1. upstream/86.0.1 upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-transport/src')
41 files changed, 21177 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/addr_valid.rs b/third_party/rust/neqo-transport/src/addr_valid.rs new file mode 100644 index 0000000000..a8fcd76ab9 --- /dev/null +++ b/third_party/rust/neqo-transport/src/addr_valid.rs @@ -0,0 +1,509 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// This file implements functions necessary for address validation. + +use neqo_common::{qinfo, qtrace, Decoder, Encoder, Role}; +use neqo_crypto::{ + constants::{TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3}, + selfencrypt::SelfEncrypt, +}; + +use crate::cid::ConnectionId; +use crate::packet::PacketBuilder; +use crate::recovery::RecoveryToken; +use crate::stats::FrameStats; +use crate::Res; + +use smallvec::SmallVec; +use std::convert::TryFrom; +use std::net::{IpAddr, SocketAddr}; +use std::time::{Duration, Instant}; + +/// A prefix we add to Retry tokens to distinguish them from NEW_TOKEN tokens. +const TOKEN_IDENTIFIER_RETRY: &[u8] = &[0x52, 0x65, 0x74, 0x72, 0x79]; +/// A prefix on NEW_TOKEN tokens, that is maximally Hamming distant from the Retry prefix. +/// Together, these need to have a low probability of collision, even if there is +/// corruption of individual bits in transit. +const TOKEN_IDENTIFIER_NEW_TOKEN: &[u8] = &[0xad, 0x9a, 0x8b, 0x8d, 0x86]; + +/// The maximum number of tokens we'll save from NEW_TOKEN frames. +/// This should be the same as the value of MAX_TICKETS in neqo-crypto. +const MAX_NEW_TOKEN: usize = 4; +/// The number of tokens we'll track for the purposes of looking for duplicates. +/// This is based on how many might be received over a period where there could be +/// retransmissions. It should be at least `MAX_NEW_TOKEN`. 
+const MAX_SAVED_TOKENS: usize = 8; + +/// `ValidateAddress` determines what sort of address validation is performed. +/// In short, this determines when a Retry packet is sent. +#[derive(Debug, PartialEq, Eq)] +pub enum ValidateAddress { + /// Require address validation never. + Never, + /// Require address validation unless a NEW_TOKEN token is provided. + NoToken, + /// Require address validation even if a NEW_TOKEN token is provided. + Always, +} + +pub enum AddressValidationResult { + Pass, + ValidRetry(ConnectionId), + Validate, + Invalid, +} + +pub struct AddressValidation { + /// What sort of validation is performed. + validation: ValidateAddress, + /// A self-encryption object used for protecting Retry tokens. + self_encrypt: SelfEncrypt, + /// When this object was created. + start_time: Instant, +} + +impl AddressValidation { + pub fn new(now: Instant, validation: ValidateAddress) -> Res<Self> { + Ok(Self { + validation, + self_encrypt: SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256)?, + start_time: now, + }) + } + + fn encode_aad(peer_address: SocketAddr, retry: bool) -> Encoder { + // Let's be "clever" by putting the peer's address in the AAD. + // We don't need to encode these into the token as they should be + // available when we need to check the token. 
+ let mut aad = Encoder::default(); + if retry { + aad.encode(TOKEN_IDENTIFIER_RETRY); + } else { + aad.encode(TOKEN_IDENTIFIER_NEW_TOKEN); + } + match peer_address.ip() { + IpAddr::V4(a) => { + aad.encode_byte(4); + aad.encode(&a.octets()); + } + IpAddr::V6(a) => { + aad.encode_byte(6); + aad.encode(&a.octets()); + } + } + if retry { + aad.encode_uint(2, peer_address.port()); + } + aad + } + + pub fn generate_token( + &self, + dcid: Option<&ConnectionId>, + peer_address: SocketAddr, + now: Instant, + ) -> Res<Vec<u8>> { + const EXPIRATION_RETRY: Duration = Duration::from_secs(5); + const EXPIRATION_NEW_TOKEN: Duration = Duration::from_secs(60 * 60 * 24); + + // TODO(mt) rotate keys on a fixed schedule. + let retry = dcid.is_some(); + let mut data = Encoder::default(); + let end = now + + if retry { + EXPIRATION_RETRY + } else { + EXPIRATION_NEW_TOKEN + }; + let end_millis = u32::try_from(end.duration_since(self.start_time).as_millis())?; + data.encode_uint(4, end_millis); + if let Some(dcid) = dcid { + data.encode(dcid); + } + + // Include the token identifier ("Retry"/~) in the AAD, then keep it for plaintext. + let mut buf = Self::encode_aad(peer_address, retry); + let encrypted = self.self_encrypt.seal(&buf, &data)?; + buf.truncate(TOKEN_IDENTIFIER_RETRY.len()); + buf.encode(&encrypted); + Ok(buf.into()) + } + + /// This generates a token for use with Retry. + pub fn generate_retry_token( + &self, + dcid: &ConnectionId, + peer_address: SocketAddr, + now: Instant, + ) -> Res<Vec<u8>> { + self.generate_token(Some(dcid), peer_address, now) + } + + /// This generates a token for use with NEW_TOKEN. 
+ pub fn generate_new_token(&self, peer_address: SocketAddr, now: Instant) -> Res<Vec<u8>> { + self.generate_token(None, peer_address, now) + } + + pub fn set_validation(&mut self, validation: ValidateAddress) { + qtrace!("AddressValidation {:p}: set to {:?}", self, validation); + self.validation = validation; + } + + /// Decrypts `token` and returns the connection ID it contains. + /// Returns `None` if the token cannot be opened using the AAD built + /// from `peer_address` and `retry`, if the expiry time is missing, + /// or if the token expired before `now`. + fn decrypt_token( + &self, + token: &[u8], + peer_address: SocketAddr, + retry: bool, + now: Instant, + ) -> Option<ConnectionId> { + let peer_addr = Self::encode_aad(peer_address, retry); + let data = if let Ok(d) = self.self_encrypt.open(&peer_addr, token) { + d + } else { + return None; + }; + let mut dec = Decoder::new(&data); + match dec.decode_uint(4) { + Some(d) => { + let end = self.start_time + Duration::from_millis(d); + if end < now { + qtrace!("Expired token: {:?} vs. {:?}", end, now); + return None; + } + } + _ => return None, + } + Some(ConnectionId::from(dec.decode_remainder())) + } + + /// Calculate the Hamming difference between our identifier and the target. + /// Less than one difference per byte indicates that it is likely a Retry. + /// This generous interpretation allows for a lot of damage in transit. + /// Note that if this check fails, then the token will be treated like it came + /// from NEW_TOKEN instead. If there truly is corruption of packets that causes + /// validation failure, it will be a failure that we try to recover from. 
+ fn is_likely_retry(token: &[u8]) -> bool { + let mut difference = 0; + for i in 0..TOKEN_IDENTIFIER_RETRY.len() { + difference += (token[i] ^ TOKEN_IDENTIFIER_RETRY[i]).count_ones(); + } + usize::try_from(difference).unwrap() < TOKEN_IDENTIFIER_RETRY.len() + } + + pub fn validate( + &self, + token: &[u8], + peer_address: SocketAddr, + now: Instant, + ) -> AddressValidationResult { + qtrace!( + "AddressValidation {:p}: validate {:?}", + self, + self.validation + ); + + if token.is_empty() { + if self.validation == ValidateAddress::Never { + qinfo!("AddressValidation: no token; accepting"); + return AddressValidationResult::Pass; + } else { + qinfo!("AddressValidation: no token; validating"); + return AddressValidationResult::Validate; + } + } + if token.len() <= TOKEN_IDENTIFIER_RETRY.len() { + // Treat bad tokens strictly. + qinfo!("AddressValidation: too short token"); + return AddressValidationResult::Invalid; + } + let retry = Self::is_likely_retry(token); + let enc = &token[TOKEN_IDENTIFIER_RETRY.len()..]; + // Note that this allows the token identifier part to be corrupted. + // That's OK here as we don't depend on that being authenticated. + if let Some(cid) = self.decrypt_token(enc, peer_address, retry, now) { + if retry { + // This is from Retry, so we should have an ODCID >= 8. + if cid.len() >= 8 { + qinfo!("AddressValidation: valid Retry token for {}", cid); + AddressValidationResult::ValidRetry(cid) + } else { + panic!("AddressValidation: Retry token with small CID {}", cid); + } + } else if cid.is_empty() { + // An empty connection ID means NEW_TOKEN. 
+ if self.validation == ValidateAddress::Always { + qinfo!("AddressValidation: valid NEW_TOKEN token; validating again"); + AddressValidationResult::Validate + } else { + qinfo!("AddressValidation: valid NEW_TOKEN token; accepting"); + AddressValidationResult::Pass + } + } else { + panic!("AddressValidation: NEW_TOKEN token with CID {}", cid); + } + } else { + // From here on, we have a token that we couldn't decrypt. + // We've either lost the keys or we've received junk. + if retry { + // If this looked like a Retry, treat it as being bad. + qinfo!("AddressValidation: invalid Retry token; rejecting"); + AddressValidationResult::Invalid + } else if self.validation == ValidateAddress::Never { + // We don't require validation, so OK. + qinfo!("AddressValidation: invalid NEW_TOKEN token; accepting"); + AddressValidationResult::Pass + } else { + // This might be an invalid NEW_TOKEN token, or a valid one + // for which we have since lost the keys. Check again. + qinfo!("AddressValidation: invalid NEW_TOKEN token; validating again"); + AddressValidationResult::Validate + } + } + } +} + +// Note: these lint override can be removed in later versions where the lints +// either don't trip a false positive or don't apply. rustc 1.46 is fine. +#[allow(dead_code, clippy::large_enum_variant)] +pub enum NewTokenState { + Client { + /// Tokens that haven't been taken yet. + pending: SmallVec<[Vec<u8>; MAX_NEW_TOKEN]>, + /// Tokens that have been taken, saved so that we can discard duplicates. + old: SmallVec<[Vec<u8>; MAX_SAVED_TOKENS]>, + }, + Server(NewTokenSender), +} + +impl NewTokenState { + pub fn new(role: Role) -> Self { + match role { + Role::Client => Self::Client { + pending: SmallVec::<[_; MAX_NEW_TOKEN]>::new(), + old: SmallVec::<[_; MAX_SAVED_TOKENS]>::new(), + }, + Role::Server => Self::Server(NewTokenSender::default()), + } + } + + /// Is there a token available? + pub fn has_token(&self) -> bool { + match self { + Self::Client { ref pending, .. 
} => !pending.is_empty(), + Self::Server(..) => false, + } + } + + /// If this is a client, take a token if there is one. + /// If this is a server, panic. + pub fn take_token(&mut self) -> Option<&[u8]> { + if let Self::Client { + ref mut pending, + ref mut old, + } = self + { + if let Some(t) = pending.pop() { + if old.len() >= MAX_SAVED_TOKENS { + old.remove(0); + } + old.push(t); + Some(&old[old.len() - 1]) + } else { + None + } + } else { + unreachable!(); + } + } + + /// If this is a client, save a token. + /// If this is a server, panic. + pub fn save_token(&mut self, token: Vec<u8>) { + if let Self::Client { + ref mut pending, + ref old, + } = self + { + for t in old.iter().rev().chain(pending.iter().rev()) { + if t == &token { + qinfo!("NewTokenState discarding duplicate NEW_TOKEN"); + return; + } + } + + if pending.len() >= MAX_NEW_TOKEN { + pending.remove(0); + } + pending.push(token); + } else { + unreachable!(); + } + } + + /// If this is a server, maybe send a frame. + /// If this is a client, do nothing. + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + if let Self::Server(ref mut sender) = self { + sender.write_frames(builder, tokens, stats); + } + } + + /// If this is a server, buffer a NEW_TOKEN for sending. + /// If this is a client, panic. + pub fn send_new_token(&mut self, token: Vec<u8>) { + if let Self::Server(ref mut sender) = self { + sender.send_new_token(token); + } else { + unreachable!(); + } + } + + /// If this is a server, process a lost signal for a NEW_TOKEN frame. + /// If this is a client, panic. + pub fn lost(&mut self, seqno: usize) { + if let Self::Server(ref mut sender) = self { + sender.lost(seqno); + } else { + unreachable!(); + } + } + + /// If this is a server, remove the acknowledged NEW_TOKEN frame. + /// If this is a client, panic. 
+ pub fn acked(&mut self, seqno: usize) { + if let Self::Server(ref mut sender) = self { + sender.acked(seqno); + } else { + unreachable!(); + } + } +} + +struct NewTokenFrameStatus { + seqno: usize, + token: Vec<u8>, + needs_sending: bool, +} + +impl NewTokenFrameStatus { + fn len(&self) -> usize { + 1 + Encoder::vvec_len(self.token.len()) + } +} + +#[derive(Default)] +pub struct NewTokenSender { + /// The unacknowledged NEW_TOKEN frames we are yet to send. + tokens: Vec<NewTokenFrameStatus>, + /// A sequence number that is used to track individual tokens + /// by reference (so that recovery tokens can be simple). + next_seqno: usize, +} + +impl NewTokenSender { + /// Add a token to be sent. + pub fn send_new_token(&mut self, token: Vec<u8>) { + self.tokens.push(NewTokenFrameStatus { + seqno: self.next_seqno, + token, + needs_sending: true, + }); + self.next_seqno += 1; + } + + pub fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + for t in self.tokens.iter_mut() { + if t.needs_sending && t.len() <= builder.remaining() { + t.needs_sending = false; + + builder.encode_varint(crate::frame::FRAME_TYPE_NEW_TOKEN); + builder.encode_vvec(&t.token); + + tokens.push(RecoveryToken::NewToken(t.seqno)); + stats.new_token += 1; + } + } + } + + pub fn lost(&mut self, seqno: usize) { + for t in self.tokens.iter_mut() { + if t.seqno == seqno { + t.needs_sending = true; + break; + } + } + } + + pub fn acked(&mut self, seqno: usize) { + self.tokens.retain(|i| i.seqno != seqno); + } +} + +#[cfg(test)] +mod tests { + use super::NewTokenState; + use neqo_common::Role; + + const ONE: &[u8] = &[1, 2, 3]; + const TWO: &[u8] = &[4, 5]; + + #[test] + fn duplicate_saved() { + let mut tokens = NewTokenState::new(Role::Client); + tokens.save_token(ONE.to_vec()); + tokens.save_token(TWO.to_vec()); + tokens.save_token(ONE.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably 
TWO + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably ONE + assert!(!tokens.has_token()); + assert!(tokens.take_token().is_none()); + } + + #[test] + fn duplicate_after_take() { + let mut tokens = NewTokenState::new(Role::Client); + tokens.save_token(ONE.to_vec()); + tokens.save_token(TWO.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably TWO + tokens.save_token(ONE.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably ONE + assert!(!tokens.has_token()); + assert!(tokens.take_token().is_none()); + } + + #[test] + fn duplicate_after_empty() { + let mut tokens = NewTokenState::new(Role::Client); + tokens.save_token(ONE.to_vec()); + tokens.save_token(TWO.to_vec()); + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably TWO + assert!(tokens.has_token()); + assert!(tokens.take_token().is_some()); // probably ONE + tokens.save_token(ONE.to_vec()); + assert!(!tokens.has_token()); + assert!(tokens.take_token().is_none()); + } +} diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs new file mode 100644 index 0000000000..b41969c680 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs @@ -0,0 +1,955 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// Congestion control +#![deny(clippy::pedantic)] + +use std::cmp::{max, min}; +use std::fmt::{self, Debug, Display}; +use std::time::{Duration, Instant}; + +use super::CongestionControl; + +use crate::cc::MAX_DATAGRAM_SIZE; +use crate::qlog::{self, QlogMetric}; +use crate::sender::PACING_BURST_SIZE; +use crate::tracking::SentPacket; +use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace}; + +pub const CWND_INITIAL_PKTS: usize = 10; +const CWND_INITIAL: usize = const_min( + CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE, + const_max(2 * MAX_DATAGRAM_SIZE, 14720), +); +pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2; +const PERSISTENT_CONG_THRESH: u32 = 3; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + /// In slow start, not recovery. + SlowStart, + /// In congestion avoidance. + CongestionAvoidance, + /// In a recovery period, but no packets have been sent yet. This is a + /// transient state because we want to exempt the first packet sent after + /// entering recovery from the congestion window. + RecoveryStart, + /// In a recovery period, with the first packet sent at this time. + Recovery, + /// Start of persistent congestion, which is transient, like `RecoveryStart`. + PersistentCongestion, +} + +impl State { + pub fn in_recovery(self) -> bool { + matches!(self, Self::RecoveryStart | Self::Recovery) + } + + pub fn in_slow_start(self) -> bool { + self == Self::SlowStart + } + + /// These states are transient, we tell qlog on entry, but not on exit. + pub fn transient(self) -> bool { + matches!(self, Self::RecoveryStart | Self::PersistentCongestion) + } + + /// Update a transient state to the true state. 
+ pub fn update(&mut self) { + *self = match self { + Self::PersistentCongestion => Self::SlowStart, + Self::RecoveryStart => Self::Recovery, + _ => unreachable!(), + }; + } + + pub fn to_qlog(self) -> &'static str { + match self { + Self::SlowStart | Self::PersistentCongestion => "slow_start", + Self::CongestionAvoidance => "congestion_avoidance", + Self::Recovery | Self::RecoveryStart => "recovery", + } + } +} + +pub trait WindowAdjustment: Display + Debug { + fn on_packets_acked(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize); + fn on_congestion_event(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize); +} + +#[derive(Debug)] +pub struct ClassicCongestionControl<T> { + cc_algorithm: T, + state: State, + congestion_window: usize, // = kInitialWindow + bytes_in_flight: usize, + acked_bytes: usize, + ssthresh: usize, + recovery_start: Option<Instant>, + + qlog: NeqoQlog, +} + +impl<T: WindowAdjustment> Display for ClassicCongestionControl<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} CongCtrl {}/{} ssthresh {}", + self.cc_algorithm, self.bytes_in_flight, self.congestion_window, self.ssthresh, + )?; + Ok(()) + } +} + +impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> { + fn set_qlog(&mut self, qlog: NeqoQlog) { + self.qlog = qlog; + } + + #[must_use] + fn cwnd(&self) -> usize { + self.congestion_window + } + + #[must_use] + fn bytes_in_flight(&self) -> usize { + self.bytes_in_flight + } + + #[must_use] + fn cwnd_avail(&self) -> usize { + // BIF can be higher than cwnd due to PTO packets, which are sent even + // if avail is 0, but still count towards BIF. + self.congestion_window.saturating_sub(self.bytes_in_flight) + } + + // Multi-packet version of OnPacketAckedCC + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]) { + // Check whether we are app limited before acked packets are removed + // from bytes_in_flight. 
+ let is_app_limited = self.app_limited(); + qtrace!( + [self], + "app limited={}, bytes_in_flight:{}, cwnd: {}, state: {:?} pacing_burst_size: {}", + is_app_limited, + self.bytes_in_flight, + self.congestion_window, + self.state, + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE, + ); + + let mut acked_bytes = 0; + for pkt in acked_pkts.iter().filter(|pkt| pkt.cc_outstanding()) { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + + if !self.after_recovery_start(pkt) { + // Do not increase congestion window for packets sent before + // recovery last started. + continue; + } + + if self.state.in_recovery() { + self.set_state(State::CongestionAvoidance); + qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]); + } + + acked_bytes += pkt.size; + } + + if !is_app_limited { + self.acked_bytes += acked_bytes; + } + + qtrace!([self], "ACK received, acked_bytes = {}", self.acked_bytes); + + // Slow start, up to the slow start threshold. + if self.congestion_window < self.ssthresh { + let increase = min(self.ssthresh - self.congestion_window, self.acked_bytes); + self.congestion_window += increase; + self.acked_bytes -= increase; + qinfo!([self], "slow start += {}", increase); + if self.congestion_window == self.ssthresh { + // This doesn't look like it is necessary, but it can happen + // after persistent congestion. + self.set_state(State::CongestionAvoidance); + } + } + // Congestion avoidance, above the slow start threshold. + if self.congestion_window >= self.ssthresh { + let (cwnd, acked_bytes) = self + .cc_algorithm + .on_packets_acked(self.congestion_window, self.acked_bytes); + self.congestion_window = cwnd; + self.acked_bytes = acked_bytes; + } + qlog::metrics_updated( + &mut self.qlog, + &[ + QlogMetric::CongestionWindow(self.congestion_window), + QlogMetric::BytesInFlight(self.bytes_in_flight), + ], + ); + } + + /// Update congestion controller state based on lost packets. 
+ fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) { + if lost_packets.is_empty() { + return; + } + + for pkt in lost_packets.iter().filter(|pkt| pkt.ack_eliciting()) { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + } + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + + qdebug!([self], "Pkts lost {}", lost_packets.len()); + + self.on_congestion_event(lost_packets.last().unwrap()); + self.detect_persistent_congestion( + first_rtt_sample_time, + prev_largest_acked_sent, + pto, + lost_packets, + ); + } + + fn discard(&mut self, pkt: &SentPacket) { + if pkt.cc_outstanding() { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + qtrace!([self], "Ignore pkt with size {}", pkt.size); + } + } + + fn on_packet_sent(&mut self, pkt: &SentPacket) { + // Record the recovery time and exit any transient state. + if self.state.transient() { + self.recovery_start = Some(pkt.time_sent); + self.state.update(); + } + + if !pkt.ack_eliciting() { + return; + } + + self.bytes_in_flight += pkt.size; + qdebug!( + [self], + "Pkt Sent len {}, bif {}, cwnd {}", + pkt.size, + self.bytes_in_flight, + self.congestion_window + ); + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + } + + /// Whether a packet can be sent immediately as a result of entering recovery. 
+ #[must_use] + fn recovery_packet(&self) -> bool { + self.state == State::RecoveryStart + } +} + +impl<T: WindowAdjustment> ClassicCongestionControl<T> { + pub fn new(cc_algorithm: T) -> Self { + Self { + cc_algorithm, + state: State::SlowStart, + congestion_window: CWND_INITIAL, + bytes_in_flight: 0, + acked_bytes: 0, + ssthresh: usize::MAX, + recovery_start: None, + qlog: NeqoQlog::disabled(), + } + } + + #[cfg(test)] + #[must_use] + pub fn ssthresh(&self) -> usize { + self.ssthresh + } + + fn set_state(&mut self, state: State) { + if self.state != state { + qdebug!([self], "state -> {:?}", state); + let old_state = self.state; + self.qlog.add_event(|| { + // No need to tell qlog about exit from transient states. + if old_state.transient() { + None + } else { + Some(::qlog::event::Event::congestion_state_updated( + Some(old_state.to_qlog().to_owned()), + state.to_qlog().to_owned(), + )) + } + }); + self.state = state; + } + } + + fn detect_persistent_congestion( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) { + if first_rtt_sample_time.is_none() { + return; + } + + let pc_period = pto * PERSISTENT_CONG_THRESH; + + let mut last_pn = 1 << 62; // Impossibly large, but not enough to overflow. + let mut start = None; + + // Look for the first lost packet after the previous largest acknowledged. + // Ignore packets that weren't ack-eliciting for the start of this range. + // Also, make sure to ignore any packets sent before we got an RTT estimate + // as we might not have sent PTO packets soon enough after those. + let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent); + for p in lost_packets + .iter() + .skip_while(|p| Some(p.time_sent) < cutoff) + { + if p.pn != last_pn + 1 { + // Not a contiguous range of lost packets, start over. + start = None; + } + last_pn = p.pn; + if !p.ack_eliciting() { + // Not interesting, keep looking. 
+ continue; + } + if let Some(t) = start { + if p.time_sent.duration_since(t) > pc_period { + qinfo!([self], "persistent congestion"); + self.congestion_window = CWND_MIN; + self.acked_bytes = 0; + self.set_state(State::PersistentCongestion); + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::CongestionWindow(self.congestion_window)], + ); + return; + } + } else { + start = Some(p.time_sent); + } + } + } + + #[must_use] + fn after_recovery_start(&mut self, packet: &SentPacket) -> bool { + // At the start of the first recovery period, if the state is + // transient, all packets will have been sent before recovery. + self.recovery_start + .map_or(!self.state.transient(), |t| packet.time_sent >= t) + } + + /// Handle a congestion event. + fn on_congestion_event(&mut self, last_packet: &SentPacket) { + // Start a new congestion event if lost packet was sent after the start + // of the previous congestion recovery period. + if self.after_recovery_start(last_packet) { + let (cwnd, acked_bytes) = self + .cc_algorithm + .on_congestion_event(self.congestion_window, self.acked_bytes); + self.congestion_window = max(cwnd, CWND_MIN); + self.acked_bytes = acked_bytes; + self.ssthresh = self.congestion_window; + qinfo!( + [self], + "Cong event -> recovery; cwnd {}, ssthresh {}", + self.congestion_window, + self.ssthresh + ); + qlog::metrics_updated( + &mut self.qlog, + &[ + QlogMetric::CongestionWindow(self.congestion_window), + QlogMetric::SsThresh(self.ssthresh), + QlogMetric::InRecovery(true), + ], + ); + self.set_state(State::RecoveryStart); + } + } + + #[allow(clippy::unused_self)] + fn app_limited(&self) -> bool { + if self.bytes_in_flight >= self.congestion_window { + false + } else if self.state.in_slow_start() { + // Allow for potential doubling of the congestion window during slow start. + // That is, the application might not have been able to send enough to respond + // to increases to the congestion window. 
+ self.bytes_in_flight < self.congestion_window / 2 + } else { + // We're not limited if the in-flight data is within a single burst of the + // congestion window. + (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window + } + } +} + +#[cfg(test)] +mod tests { + use super::{ClassicCongestionControl, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH}; + use crate::cc::new_reno::NewReno; + use crate::cc::{CongestionControl, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE}; + use crate::packet::{PacketNumber, PacketType}; + use crate::tracking::SentPacket; + use std::convert::TryFrom; + use std::time::{Duration, Instant}; + use test_fixture::now; + + const PTO: Duration = Duration::from_millis(100); + const RTT: Duration = Duration::from_millis(98); + const ZERO: Duration = Duration::from_secs(0); + const EPSILON: Duration = Duration::from_nanos(1); + const GAP: Duration = Duration::from_secs(1); + /// The largest time between packets without causing persistent congestion. + const SUB_PC: Duration = Duration::from_millis(100 * PERSISTENT_CONG_THRESH as u64); + /// The minimum time between packets to cause persistent congestion. + /// Uses an odd expression because `Duration` arithmetic isn't `const`. 
+ const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1); + + fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.ssthresh(), usize::MAX); + } + + fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL / 2); + assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); + } + + #[test] + fn issue_876() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let time_now = now(); + let time_before = time_now - Duration::from_millis(100); + let time_after = time_now + Duration::from_millis(150); + + let sent_packets = &[ + SentPacket::new( + PacketType::Short, + 1, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 1, // size + ), + SentPacket::new( + PacketType::Short, + 2, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 2, // size + ), + SentPacket::new( + PacketType::Short, + 3, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 4, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 5, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 6, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 7, // pn + time_after, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 3, // size + ), + ]; + + // Send some more packets so that the cc is not app-limited. 
+ for p in &sent_packets[..6] { + cc.on_packet_sent(p); + } + assert_eq!(cc.acked_bytes, 0); + cwnd_is_default(&cc); + assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 3); + + cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[0..1]); + + // We are now in recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.acked_bytes, 0); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + + // Send a packet after recovery starts + cc.on_packet_sent(&sent_packets[6]); + assert!(!cc.recovery_packet()); + cwnd_is_halved(&cc); + assert_eq!(cc.acked_bytes, 0); + assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 5); + + // and ack it. cwnd increases slightly + cc.on_packets_acked(&sent_packets[6..]); + assert_eq!(cc.acked_bytes, sent_packets[6].size); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + + // Packet from before is lost. Should not hurt cwnd. + cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]); + assert!(!cc.recovery_packet()); + assert_eq!(cc.acked_bytes, sent_packets[6].size); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE); + } + + fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket { + SentPacket::new( + PacketType::Short, + pn, + now() + t, + ack_eliciting, + Vec::new(), + 100, + ) + } + + fn persistent_congestion(lost_packets: &[SentPacket]) -> bool { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + for p in lost_packets { + cc.on_packet_sent(p); + } + + cc.on_packets_lost(Some(now()), None, PTO, lost_packets); + if cc.cwnd() == CWND_INITIAL / 2 { + false + } else if cc.cwnd() == CWND_MIN { + true + } else { + panic!("unexpected cwnd"); + } + } + + /// A span of exactly the PC threshold only reduces the window on loss. 
+ #[test] + fn persistent_congestion_none() { + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, SUB_PC), + ])); + } + + /// A span of just more than the PC threshold causes persistent congestion. + #[test] + fn persistent_congestion_simple() { + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PC), + ])); + } + + /// Both packets need to be ack-eliciting. + #[test] + fn persistent_congestion_non_ack_eliciting() { + assert!(!persistent_congestion(&[ + lost(1, false, ZERO), + lost(2, true, PC), + ])); + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, false, PC), + ])); + } + + /// Packets in the middle, of any type, are OK. + #[test] + fn persistent_congestion_middle() { + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, false, RTT), + lost(3, true, PC), + ])); + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, RTT), + lost(3, true, PC), + ])); + } + + /// Leading non-ack-eliciting packets are skipped. + #[test] + fn persistent_congestion_leading_non_ack_eliciting() { + assert!(!persistent_congestion(&[ + lost(1, false, ZERO), + lost(2, true, RTT), + lost(3, true, PC), + ])); + assert!(persistent_congestion(&[ + lost(1, false, ZERO), + lost(2, true, RTT), + lost(3, true, RTT + PC), + ])); + } + + /// Trailing non-ack-eliciting packets aren't relevant. + #[test] + fn persistent_congestion_trailing_non_ack_eliciting() { + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PC), + lost(3, false, PC + EPSILON), + ])); + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, SUB_PC), + lost(3, false, PC), + ])); + } + + /// Gaps in the middle, of any type, restart the count. 
+ #[test] + fn persistent_congestion_gap_reset() { + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(3, true, PC), + ])); + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, RTT), + lost(4, true, GAP), + lost(5, true, GAP + PTO * PERSISTENT_CONG_THRESH), + ])); + } + + /// A span either side of a gap will cause persistent congestion. + #[test] + fn persistent_congestion_gap_or() { + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PC), + lost(4, true, GAP), + lost(5, true, GAP + PTO), + ])); + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, true, GAP), + lost(5, true, GAP + PC), + ])); + } + + /// A gap only restarts after an ack-eliciting packet. + #[test] + fn persistent_congestion_gap_non_ack_eliciting() { + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + PC), + ])); + assert!(!persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + RTT), + lost(6, true, GAP + RTT + SUB_PC), + ])); + assert!(persistent_congestion(&[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + RTT), + lost(6, true, GAP + RTT + PC), + ])); + } + + /// Get a time, in multiples of `PTO`, relative to `now()`. + fn by_pto(t: u32) -> Instant { + now() + (PTO * t) + } + + /// Make packets that will be made lost. + /// `times` is the time of sending, in multiples of `PTO`, relative to `now()`. + fn make_lost(times: &[u32]) -> Vec<SentPacket> { + times + .iter() + .enumerate() + .map(|(i, &t)| { + SentPacket::new( + PacketType::Short, + u64::try_from(i).unwrap(), + by_pto(t), + true, + Vec::new(), + 1000, + ) + }) + .collect::<Vec<_>>() + } + + /// Call `detect_persistent_congestion` using times relative to now and the fixed PTO time. 
+ /// `last_ack` and `rtt_time` are times in multiples of `PTO`, relative to `now()`, + /// for the time of the largest acknowledged and the first RTT sample, respectively. + fn persistent_congestion_by_pto(last_ack: u32, rtt_time: u32, lost: &[SentPacket]) -> bool { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + assert_eq!(cc.cwnd(), CWND_INITIAL); + + let last_ack = Some(by_pto(last_ack)); + let rtt_time = Some(by_pto(rtt_time)); + + // Persistent congestion is never declared if the RTT time is `None`. + cc.detect_persistent_congestion(None, None, PTO, lost); + assert_eq!(cc.cwnd(), CWND_INITIAL); + cc.detect_persistent_congestion(None, last_ack, PTO, lost); + assert_eq!(cc.cwnd(), CWND_INITIAL); + + cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost); + cc.cwnd() == CWND_MIN + } + + /// No persistent congestion can be had if there are no lost packets. + #[test] + fn persistent_congestion_no_lost() { + let lost = make_lost(&[]); + assert!(!persistent_congestion_by_pto(0, 0, &lost)); + } + + /// No persistent congestion can be had if there is only one lost packet. + #[test] + fn persistent_congestion_one_lost() { + let lost = make_lost(&[1]); + assert!(!persistent_congestion_by_pto(0, 0, &lost)); + } + + /// Persistent congestion can't happen based on old packets. + #[test] + fn persistent_congestion_past() { + // Packets sent prior to either the last acknowledged or the first RTT + // sample are not considered. So 0 is ignored. + let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]); + assert!(!persistent_congestion_by_pto(1, 1, &lost)); + assert!(!persistent_congestion_by_pto(0, 1, &lost)); + assert!(!persistent_congestion_by_pto(1, 0, &lost)); + } + + /// Persistent congestion doesn't start unless the packet is ack-eliciting. 
+ #[test] + fn persistent_congestion_ack_eliciting() { + let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + lost[0] = SentPacket::new( + lost[0].pt, + lost[0].pn, + lost[0].time_sent, + false, + Vec::new(), + lost[0].size, + ); + assert!(!persistent_congestion_by_pto(0, 0, &lost)); + } + + /// Detect persistent congestion. Note that the first lost packet needs to have a time + /// greater than the previously acknowledged packet AND the first RTT sample. And the + /// difference in times needs to be greater than the persistent congestion threshold. + #[test] + fn persistent_congestion_min() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + assert!(persistent_congestion_by_pto(0, 0, &lost)); + } + + /// Make sure that not having a previous largest acknowledged also results + /// in detecting persistent congestion. (This is not expected to happen, but + /// the code permits it). + #[test] + fn persistent_congestion_no_prev_ack() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + let mut cc = ClassicCongestionControl::new(NewReno::default()); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); + assert_eq!(cc.cwnd(), CWND_MIN); + } + + /// The code asserts on ordering errors. 
+ #[test] + #[should_panic] + fn persistent_congestion_unsorted() { + let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); + assert!(!persistent_congestion_by_pto(0, 0, &lost)); + } + + #[test] + fn app_limited_slow_start() { + const LESS_THAN_CWND_PKTS: usize = 4; + let mut cc = ClassicCongestionControl::new(NewReno::default()); + + for i in 0..CWND_INITIAL_PKTS { + let sent = SentPacket::new( + PacketType::Short, + u64::try_from(i).unwrap(), // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&sent); + } + assert_eq!(cc.bytes_in_flight(), CWND_INITIAL); + + for i in 0..LESS_THAN_CWND_PKTS { + let acked = SentPacket::new( + PacketType::Short, + u64::try_from(i).unwrap(), // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packets_acked(&[acked]); + + assert_eq!( + cc.bytes_in_flight(), + (CWND_INITIAL_PKTS - i - 1) * MAX_DATAGRAM_SIZE + ); + assert_eq!(cc.cwnd(), (CWND_INITIAL_PKTS + i + 1) * MAX_DATAGRAM_SIZE); + } + + // Now we are app limited + for i in 4..CWND_INITIAL_PKTS { + let p = [SentPacket::new( + PacketType::Short, + u64::try_from(i).unwrap(), // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + )]; + cc.on_packets_acked(&p); + + assert_eq!( + cc.bytes_in_flight(), + (CWND_INITIAL_PKTS - i - 1) * MAX_DATAGRAM_SIZE + ); + assert_eq!(cc.cwnd(), (CWND_INITIAL_PKTS + 4) * MAX_DATAGRAM_SIZE); + } + } + + #[test] + fn app_limited_congestion_avoidance() { + const CWND_PKTS_CA: usize = CWND_INITIAL_PKTS / 2; + + let mut cc = ClassicCongestionControl::new(NewReno::default()); + + // Change state to congestion avoidance by introducing loss. 
+ + let p_lost = SentPacket::new( + PacketType::Short, + 1, // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p_lost); + cwnd_is_default(&cc); + cc.on_packets_lost(Some(now()), None, PTO, &[p_lost]); + cwnd_is_halved(&cc); + let p_not_lost = SentPacket::new( + PacketType::Short, + 1, // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p_not_lost); + cc.on_packets_acked(&[p_not_lost]); + cwnd_is_halved(&cc); + // cc is app limited therefore cwnd in not increased. + assert_eq!(cc.acked_bytes, 0); + + // Now we are in the congestion avoidance state. + let mut pkts = Vec::new(); + for i in 0..CWND_PKTS_CA { + let p = SentPacket::new( + PacketType::Short, + u64::try_from(i + 3).unwrap(), // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!(cc.bytes_in_flight(), CWND_INITIAL / 2); + + for i in 0..CWND_PKTS_CA - 2 { + cc.on_packets_acked(&pkts[i..=i]); + + assert_eq!( + cc.bytes_in_flight(), + (CWND_PKTS_CA - i - 1) * MAX_DATAGRAM_SIZE + ); + assert_eq!(cc.cwnd(), CWND_PKTS_CA * MAX_DATAGRAM_SIZE); + assert_eq!(cc.acked_bytes, MAX_DATAGRAM_SIZE * (i + 1)); + } + + // Now we are app limited + for i in CWND_PKTS_CA - 2..CWND_PKTS_CA { + cc.on_packets_acked(&pkts[i..=i]); + + assert_eq!( + cc.bytes_in_flight(), + (CWND_PKTS_CA - i - 1) * MAX_DATAGRAM_SIZE + ); + assert_eq!(cc.cwnd(), CWND_PKTS_CA * MAX_DATAGRAM_SIZE); + assert_eq!(cc.acked_bytes, MAX_DATAGRAM_SIZE * 3); + } + } +} diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs new file mode 100644 index 0000000000..996054ad08 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/mod.rs @@ -0,0 +1,55 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// 
http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] + +use crate::path::PATH_MTU_V6; +use crate::tracking::SentPacket; +use neqo_common::qlog::NeqoQlog; + +use std::fmt::{Debug, Display}; +use std::time::{Duration, Instant}; + +mod classic_cc; +mod new_reno; + +pub use classic_cc::ClassicCongestionControl; +pub use classic_cc::{CWND_INITIAL_PKTS, CWND_MIN}; +pub use new_reno::NewReno; + +pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6; + +pub trait CongestionControl: Display + Debug { + fn set_qlog(&mut self, qlog: NeqoQlog); + + fn cwnd(&self) -> usize; + + fn bytes_in_flight(&self) -> usize; + + fn cwnd_avail(&self) -> usize; + + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]); + + fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ); + + fn recovery_packet(&self) -> bool; + + fn discard(&mut self, pkt: &SentPacket); + + fn on_packet_sent(&mut self, pkt: &SentPacket); +} + +#[derive(Copy, Clone)] +pub enum CongestionControlAlgorithm { + NewReno, +} diff --git a/third_party/rust/neqo-transport/src/cc/new_reno.rs b/third_party/rust/neqo-transport/src/cc/new_reno.rs new file mode 100644 index 0000000000..a398887d61 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/new_reno.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// Congestion control +#![deny(clippy::pedantic)] + +use std::fmt::{self, Display}; + +use crate::cc::{classic_cc::WindowAdjustment, MAX_DATAGRAM_SIZE}; +use neqo_common::qinfo; + +#[derive(Debug)] +pub struct NewReno {} + +impl Default for NewReno { + fn default() -> Self { + Self {} + } +} + +impl Display for NewReno { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NewReno")?; + Ok(()) + } +} + +impl WindowAdjustment for NewReno { + fn on_packets_acked(&mut self, mut curr_cwnd: usize, mut acked_bytes: usize) -> (usize, usize) { + if acked_bytes >= curr_cwnd { + acked_bytes -= curr_cwnd; + curr_cwnd += MAX_DATAGRAM_SIZE; + qinfo!([self], "congestion avoidance += {}", MAX_DATAGRAM_SIZE); + } + (curr_cwnd, acked_bytes) + } + + fn on_congestion_event(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) { + (curr_cwnd / 2, acked_bytes / 2) + } +} diff --git a/third_party/rust/neqo-transport/src/cid.rs b/third_party/rust/neqo-transport/src/cid.rs new file mode 100644 index 0000000000..ef2b938c28 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cid.rs @@ -0,0 +1,157 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Encoding and decoding packets off the wire. 
+ +use neqo_common::{hex, hex_with_len, Decoder}; +use neqo_crypto::random; + +use std::borrow::Borrow; +use std::cmp::max; +use std::convert::AsRef; + +pub const MAX_CONNECTION_ID_LEN: usize = 20; + +#[derive(Clone, Default, Eq, Hash, PartialEq)] +pub struct ConnectionId { + pub(crate) cid: Vec<u8>, +} + +impl ConnectionId { + pub fn generate(len: usize) -> Self { + assert!(matches!(len, 0..=MAX_CONNECTION_ID_LEN)); + Self { cid: random(len) } + } + + // Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long. + pub fn generate_initial() -> Self { + let v = random(1); + // Bias selection toward picking 8 (>50% of the time). + let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into(); + Self::generate(len) + } + + pub fn as_cid_ref(&self) -> ConnectionIdRef { + ConnectionIdRef::from(&self.cid[..]) + } +} + +impl AsRef<[u8]> for ConnectionId { + fn as_ref(&self) -> &[u8] { + self.borrow() + } +} + +impl Borrow<[u8]> for ConnectionId { + fn borrow(&self) -> &[u8] { + &self.cid + } +} + +impl From<&[u8]> for ConnectionId { + fn from(buf: &[u8]) -> Self { + Self { + cid: Vec::from(buf), + } + } +} + +impl<'a> From<&ConnectionIdRef<'a>> for ConnectionId { + fn from(cidref: &ConnectionIdRef<'a>) -> Self { + Self { + cid: Vec::from(cidref.cid), + } + } +} + +impl std::ops::Deref for ConnectionId { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.cid + } +} + +impl ::std::fmt::Debug for ConnectionId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CID {}", hex_with_len(&self.cid)) + } +} + +impl ::std::fmt::Display for ConnectionId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}", hex(&self.cid)) + } +} + +impl<'a> PartialEq<ConnectionIdRef<'a>> for ConnectionId { + fn eq(&self, other: &ConnectionIdRef<'a>) -> bool { + &self.cid[..] 
== other.cid + } +} + +#[derive(Hash, Eq, PartialEq)] +pub struct ConnectionIdRef<'a> { + cid: &'a [u8], +} + +impl<'a> ::std::fmt::Debug for ConnectionIdRef<'a> { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CID {}", hex_with_len(&self.cid)) + } +} + +impl<'a> ::std::fmt::Display for ConnectionIdRef<'a> { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}", hex(&self.cid)) + } +} + +impl<'a> From<&'a [u8]> for ConnectionIdRef<'a> { + fn from(cid: &'a [u8]) -> Self { + Self { cid } + } +} + +impl<'a> std::ops::Deref for ConnectionIdRef<'a> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.cid + } +} + +impl<'a> PartialEq<ConnectionId> for ConnectionIdRef<'a> { + fn eq(&self, other: &ConnectionId) -> bool { + self.cid == &other.cid[..] + } +} + +pub trait ConnectionIdDecoder { + fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>>; +} + +pub trait ConnectionIdManager: ConnectionIdDecoder { + fn generate_cid(&mut self) -> ConnectionId; + fn as_decoder(&self) -> &dyn ConnectionIdDecoder; +} + +#[cfg(test)] +mod tests { + use super::*; + use test_fixture::fixture_init; + + #[test] + fn generate_initial_cid() { + fixture_init(); + for _ in 0..100 { + let cid = ConnectionId::generate_initial(); + if !matches!(cid.len(), 8..=MAX_CONNECTION_ID_LEN) { + panic!("connection ID {:?}", cid); + } + } + } +} diff --git a/third_party/rust/neqo-transport/src/connection/idle.rs b/third_party/rust/neqo-transport/src/connection/idle.rs new file mode 100644 index 0000000000..9cf3be20a1 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/idle.rs @@ -0,0 +1,90 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +use std::cmp::{max, min}; +use std::time::{Duration, Instant}; + +pub const LOCAL_IDLE_TIMEOUT: Duration = Duration::from_secs(30); + +#[derive(Debug, Clone)] +/// There's a little bit of different behavior for resetting idle timeout. See +/// -transport 10.2 ("Idle Timeout"). +enum IdleTimeoutState { + Init, + PacketReceived(Instant), + AckElicitingPacketSent(Instant), +} + +#[derive(Debug, Clone)] +/// There's a little bit of different behavior for resetting idle timeout. See +/// -transport 10.2 ("Idle Timeout"). +pub struct IdleTimeout { + timeout: Duration, + state: IdleTimeoutState, +} + +#[cfg(test)] +impl IdleTimeout { + pub fn new(timeout: Duration) -> Self { + Self { + timeout, + state: IdleTimeoutState::Init, + } + } +} + +impl Default for IdleTimeout { + fn default() -> Self { + Self { + timeout: LOCAL_IDLE_TIMEOUT, + state: IdleTimeoutState::Init, + } + } +} + +impl IdleTimeout { + pub fn set_peer_timeout(&mut self, peer_timeout: Duration) { + self.timeout = min(self.timeout, peer_timeout); + } + + pub fn expiry(&self, now: Instant, pto: Duration) -> Instant { + let start = match self.state { + IdleTimeoutState::Init => now, + IdleTimeoutState::PacketReceived(t) | IdleTimeoutState::AckElicitingPacketSent(t) => t, + }; + start + max(self.timeout, pto * 3) + } + + pub fn on_packet_sent(&mut self, now: Instant) { + // Only reset idle timeout if we've received a packet since the last + // time we reset the timeout here. + match self.state { + IdleTimeoutState::AckElicitingPacketSent(_) => {} + IdleTimeoutState::Init | IdleTimeoutState::PacketReceived(_) => { + self.state = IdleTimeoutState::AckElicitingPacketSent(now); + } + } + } + + pub fn on_packet_received(&mut self, now: Instant) { + // Only update if this doesn't rewind the idle timeout. + // We sometimes process packets after caching them, which uses + // the time the packet was received. 
That could be in the past. + let update = match self.state { + IdleTimeoutState::Init => true, + IdleTimeoutState::AckElicitingPacketSent(t) | IdleTimeoutState::PacketReceived(t) => { + t <= now + } + }; + if update { + self.state = IdleTimeoutState::PacketReceived(now); + } + } + + pub fn expired(&self, now: Instant, pto: Duration) -> bool { + now >= self.expiry(now, pto) + } +} diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs new file mode 100644 index 0000000000..ba6e628809 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/mod.rs @@ -0,0 +1,2768 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// The class implementing a QUIC connection. 
+ +use std::cell::RefCell; +use std::cmp::{max, min}; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::fmt::{self, Debug}; +use std::mem; +use std::net::SocketAddr; +use std::rc::{Rc, Weak}; +use std::time::{Duration, Instant}; + +use smallvec::SmallVec; + +use neqo_common::{ + event::Provider as EventProvider, hex, hex_snip_middle, qdebug, qerror, qinfo, qlog::NeqoQlog, + qtrace, qwarn, Datagram, Decoder, Encoder, Role, +}; +use neqo_crypto::agent::CertificateInfo; +use neqo_crypto::{ + Agent, AntiReplay, AuthenticationStatus, Cipher, Client, HandshakeState, ResumptionToken, + SecretAgentInfo, Server, ZeroRttChecker, +}; + +use crate::addr_valid::{AddressValidation, NewTokenState}; +use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdManager, ConnectionIdRef}; +use crate::crypto::{Crypto, CryptoDxState, CryptoSpace}; +use crate::dump::*; +use crate::events::{ConnectionEvent, ConnectionEvents}; +use crate::flow_mgr::FlowMgr; +use crate::frame::{ + AckRange, CloseError, Frame, FrameType, StreamType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT, +}; +use crate::packet::{ + DecryptedPacket, PacketBuilder, PacketNumber, PacketType, PublicPacket, QuicVersion, +}; +use crate::path::Path; +use crate::qlog; +use crate::recovery::{LossRecovery, RecoveryToken, SendProfile, GRANULARITY}; +use crate::recv_stream::{RecvStream, RecvStreams, RECV_BUFFER_SIZE}; +use crate::send_stream::{SendStream, SendStreams}; +use crate::stats::{Stats, StatsCell}; +use crate::stream_id::{StreamId, StreamIndex, StreamIndexes}; +use crate::tparams::{ + self, TransportParameter, TransportParameterId, TransportParameters, TransportParametersHandler, +}; +use crate::tracking::{AckTracker, PNSpace, SentPacket}; +use crate::ConnectionParameters; +use crate::{AppError, ConnectionError, Error, Res}; + +mod idle; +pub mod params; +mod saved; +mod state; + +use idle::IdleTimeout; +pub use idle::LOCAL_IDLE_TIMEOUT; +use 
saved::SavedDatagrams; +pub use state::State; +use state::StateSignaling; + +#[derive(Debug, Default)] +struct Packet(Vec<u8>); + +pub const LOCAL_STREAM_LIMIT_BIDI: u64 = 16; +pub const LOCAL_STREAM_LIMIT_UNI: u64 = 16; + +/// The number of Initial packets that the client will send in response +/// to receiving an undecryptable packet during the early part of the +/// handshake. This is a hack, but a useful one. +const EXTRA_INITIALS: usize = 4; +const LOCAL_MAX_DATA: u64 = 0x3FFF_FFFF_FFFF_FFFF; // 2^62-1 + +#[derive(Debug, PartialEq, Eq)] +pub enum ZeroRttState { + Init, + Sending, + AcceptedClient, + AcceptedServer, + Rejected, +} + +#[derive(Clone, Debug, PartialEq)] +/// Type returned from process() and `process_output()`. Users are required to +/// call these repeatedly until `Callback` or `None` is returned. +pub enum Output { + /// Connection requires no action. + None, + /// Connection requires the datagram be sent. + Datagram(Datagram), + /// Connection requires `process_input()` be called when the `Duration` + /// elapses. + Callback(Duration), +} + +impl Output { + /// Convert into an `Option<Datagram>`. + #[must_use] + pub fn dgram(self) -> Option<Datagram> { + match self { + Self::Datagram(dg) => Some(dg), + _ => None, + } + } + + /// Get a reference to the Datagram, if any. + pub fn as_dgram_ref(&self) -> Option<&Datagram> { + match self { + Self::Datagram(dg) => Some(dg), + _ => None, + } + } + + /// Ask how long the caller should wait before calling back. + #[must_use] + pub fn callback(&self) -> Duration { + match self { + Self::Callback(t) => *t, + _ => Duration::new(0, 0), + } + } +} + +/// Used by inner functions like Connection::output. +enum SendOption { + /// Yes, please send this datagram. + Yes(Datagram), + /// Don't send. If this was blocked on the pacer (the arg is true). 
+ No(bool), +} + +impl Default for SendOption { + fn default() -> Self { + Self::No(false) + } +} + +/// Used by `Connection::preprocess` to determine what to do +/// with an packet before attempting to remove protection. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PreprocessResult { + /// End processing and return successfully. + End, + /// Stop processing this datagram and move on to the next. + Next, + /// Continue and process this packet. + Continue, +} + +/// Alias the common form for ConnectionIdManager. +type CidMgr = Rc<RefCell<dyn ConnectionIdManager>>; + +/// An FixedConnectionIdManager produces random connection IDs of a fixed length. +pub struct FixedConnectionIdManager { + len: usize, +} +impl FixedConnectionIdManager { + pub fn new(len: usize) -> Self { + Self { len } + } +} +impl ConnectionIdDecoder for FixedConnectionIdManager { + fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> { + dec.decode(self.len).map(ConnectionIdRef::from) + } +} +impl ConnectionIdManager for FixedConnectionIdManager { + fn generate_cid(&mut self) -> ConnectionId { + ConnectionId::generate(self.len) + } + fn as_decoder(&self) -> &dyn ConnectionIdDecoder { + self + } +} + +/// `AddressValidationInfo` holds information relevant to either +/// responding to address validation (`NewToken`, `Retry`) or generating +/// tokens for address validation (`Server`). +enum AddressValidationInfo { + None, + // We are a client and have information from `NEW_TOKEN`. + NewToken(Vec<u8>), + // We are a client and have received a `Retry` packet. + Retry { + token: Vec<u8>, + retry_source_cid: ConnectionId, + }, + // We are a server and can generate tokens. + Server(Weak<RefCell<AddressValidation>>), +} + +impl AddressValidationInfo { + pub fn token(&self) -> &[u8] { + match self { + Self::NewToken(token) | Self::Retry { token, .. 
} => &token, + _ => &[], + } + } + + pub fn generate_new_token( + &mut self, + peer_address: SocketAddr, + now: Instant, + ) -> Option<Vec<u8>> { + match self { + Self::Server(ref w) => { + if let Some(validation) = w.upgrade() { + validation + .borrow() + .generate_new_token(peer_address, now) + .ok() + } else { + None + } + } + Self::None => None, + _ => unreachable!("called a server function on a client"), + } + } +} + +/// A QUIC Connection +/// +/// First, create a new connection using `new_client()` or `new_server()`. +/// +/// For the life of the connection, handle activity in the following manner: +/// 1. Perform operations using the `stream_*()` methods. +/// 1. Call `process_input()` when a datagram is received or the timer +/// expires. Obtain information on connection state changes by checking +/// `events()`. +/// 1. Having completed handling current activity, repeatedly call +/// `process_output()` for packets to send, until it returns `Output::Callback` +/// or `Output::None`. +/// +/// After the connection is closed (either by calling `close()` or by the +/// remote) continue processing until `state()` returns `Closed`. +pub struct Connection { + role: Role, + state: State, + tps: Rc<RefCell<TransportParametersHandler>>, + /// What we are doing with 0-RTT. + zero_rtt_state: ZeroRttState, + /// This object will generate connection IDs for the connection. + cid_manager: CidMgr, + /// Network paths. Right now, this tracks at most one path, so it uses `Option`. + path: Option<Path>, + /// The connection IDs that we will accept. + /// This includes any we advertise in NEW_CONNECTION_ID that haven't been bound to a path yet. + /// During the handshake at the server, it also includes the randomized DCID pick by the client. + valid_cids: Vec<ConnectionId>, + address_validation: AddressValidationInfo, + + /// The source connection ID that this endpoint uses for the handshake. 
+ /// Since we need to communicate this to our peer in tparams, setting this + /// value is part of constructing the struct. + local_initial_source_cid: ConnectionId, + /// The source connection ID from the first packet from the other end. + /// This is checked against the peer's transport parameters. + remote_initial_source_cid: Option<ConnectionId>, + /// The destination connection ID from the first packet from the client. + /// This is checked by the client against the server's transport parameters. + original_destination_cid: Option<ConnectionId>, + + /// We sometimes save a datagram against the possibility that keys will later + /// become available. This avoids reporting packets as dropped during the handshake + /// when they are either just reordered or we haven't been able to install keys yet. + /// In particular, this occurs when asynchronous certificate validation happens. + saved_datagrams: SavedDatagrams, + + pub(crate) crypto: Crypto, + pub(crate) acks: AckTracker, + idle_timeout: IdleTimeout, + pub(crate) indexes: StreamIndexes, + connection_ids: HashMap<u64, (ConnectionId, [u8; 16])>, // (sequence number, (connection id, reset token)) + pub(crate) send_streams: SendStreams, + pub(crate) recv_streams: RecvStreams, + pub(crate) flow_mgr: Rc<RefCell<FlowMgr>>, + state_signaling: StateSignaling, + loss_recovery: LossRecovery, + events: ConnectionEvents, + new_token: NewTokenState, + stats: StatsCell, + qlog: NeqoQlog, + /// A session ticket was received without NEW_TOKEN, + /// this is when that turns into an event without NEW_TOKEN. + release_resumption_token_timer: Option<Instant>, + quic_version: QuicVersion, +} + +impl Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{:?} Connection: {:?} {:?}", + self.role, self.state, self.path + ) + } +} + +impl Connection { + /// Create a new QUIC connection with Client role. 
+ pub fn new_client( + server_name: &str, + protocols: &[impl AsRef<str>], + cid_manager: CidMgr, + local_addr: SocketAddr, + remote_addr: SocketAddr, + conn_params: &ConnectionParameters, + ) -> Res<Self> { + let dcid = ConnectionId::generate_initial(); + let mut c = Self::new( + Role::Client, + Client::new(server_name)?.into(), + cid_manager, + protocols, + None, + conn_params, + )?; + c.crypto + .states + .init(conn_params.get_quic_version(), Role::Client, &dcid); + c.original_destination_cid = Some(dcid); + c.initialize_path(local_addr, remote_addr); + Ok(c) + } + + /// Create a new QUIC connection with Server role. + pub fn new_server( + certs: &[impl AsRef<str>], + protocols: &[impl AsRef<str>], + cid_manager: CidMgr, + conn_params: &ConnectionParameters, + ) -> Res<Self> { + Self::new( + Role::Server, + Server::new(certs)?.into(), + cid_manager, + protocols, + None, + conn_params, + ) + } + + pub fn server_enable_0rtt( + &mut self, + anti_replay: &AntiReplay, + zero_rtt_checker: impl ZeroRttChecker + 'static, + ) -> Res<()> { + self.crypto + .server_enable_0rtt(self.tps.clone(), anti_replay, zero_rtt_checker) + } + + fn set_tp_defaults(tps: &mut TransportParameters) { + tps.set_integer( + tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, + u64::try_from(RECV_BUFFER_SIZE).unwrap(), + ); + tps.set_integer( + tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, + u64::try_from(RECV_BUFFER_SIZE).unwrap(), + ); + tps.set_integer( + tparams::INITIAL_MAX_STREAM_DATA_UNI, + u64::try_from(RECV_BUFFER_SIZE).unwrap(), + ); + tps.set_integer(tparams::INITIAL_MAX_STREAMS_BIDI, LOCAL_STREAM_LIMIT_BIDI); + tps.set_integer(tparams::INITIAL_MAX_STREAMS_UNI, LOCAL_STREAM_LIMIT_UNI); + tps.set_integer(tparams::INITIAL_MAX_DATA, LOCAL_MAX_DATA); + tps.set_integer( + tparams::IDLE_TIMEOUT, + u64::try_from(LOCAL_IDLE_TIMEOUT.as_millis()).unwrap(), + ); + tps.set_empty(tparams::DISABLE_MIGRATION); + tps.set_empty(tparams::GREASE_QUIC_BIT); + } + + fn new( + role: Role, + agent: Agent, + 
cid_manager: CidMgr, + protocols: &[impl AsRef<str>], + path: Option<Path>, + conn_params: &ConnectionParameters, + ) -> Res<Self> { + let tphandler = Rc::new(RefCell::new(TransportParametersHandler::default())); + Self::set_tp_defaults(&mut tphandler.borrow_mut().local); + tphandler.borrow_mut().local.set_integer( + tparams::INITIAL_MAX_STREAMS_BIDI, + conn_params.get_max_streams(StreamType::BiDi), + ); + tphandler.borrow_mut().local.set_integer( + tparams::INITIAL_MAX_STREAMS_UNI, + conn_params.get_max_streams(StreamType::UniDi), + ); + let local_initial_source_cid = cid_manager.borrow_mut().generate_cid(); + tphandler.borrow_mut().local.set_bytes( + tparams::INITIAL_SOURCE_CONNECTION_ID, + local_initial_source_cid.to_vec(), + ); + + let crypto = Crypto::new(agent, protocols, tphandler.clone())?; + + let stats = StatsCell::default(); + let c = Self { + role, + state: State::Init, + cid_manager, + path, + valid_cids: Vec::new(), + tps: tphandler, + zero_rtt_state: ZeroRttState::Init, + address_validation: AddressValidationInfo::None, + local_initial_source_cid, + remote_initial_source_cid: None, + original_destination_cid: None, + saved_datagrams: SavedDatagrams::default(), + crypto, + acks: AckTracker::default(), + idle_timeout: IdleTimeout::default(), + indexes: StreamIndexes::new(), + connection_ids: HashMap::new(), + send_streams: SendStreams::default(), + recv_streams: RecvStreams::default(), + flow_mgr: Rc::new(RefCell::new(FlowMgr::default())), + state_signaling: StateSignaling::Idle, + loss_recovery: LossRecovery::new(conn_params.get_cc_algorithm(), stats.clone()), + events: ConnectionEvents::default(), + new_token: NewTokenState::new(role), + stats, + qlog: NeqoQlog::disabled(), + release_resumption_token_timer: None, + quic_version: conn_params.get_quic_version(), + }; + c.stats.borrow_mut().init(format!("{}", c)); + Ok(c) + } + + /// Get the local path. 
+ pub fn path(&self) -> Option<&Path> { + self.path.as_ref() + } + + /// Set or clear the qlog for this connection. + pub fn set_qlog(&mut self, qlog: NeqoQlog) { + self.loss_recovery.set_qlog(qlog.clone()); + self.qlog = qlog; + } + + /// Get the qlog (if any) for this connection. + pub fn qlog_mut(&mut self) -> &mut NeqoQlog { + &mut self.qlog + } + + /// Get the original destination connection id for this connection. This + /// will always be present for Role::Client but not if Role::Server is in + /// State::Init. + pub fn odcid(&self) -> Option<&ConnectionId> { + self.original_destination_cid.as_ref() + } + + /// Set a local transport parameter, possibly overriding a default value. + pub fn set_local_tparam(&self, tp: TransportParameterId, value: TransportParameter) -> Res<()> { + if *self.state() == State::Init { + self.tps.borrow_mut().local.set(tp, value); + Ok(()) + } else { + qerror!("Current state: {:?}", self.state()); + qerror!("Cannot set local tparam when not in an initial connection state."); + Err(Error::ConnectionState) + } + } + + /// `odcid` is their original choice for our CID, which we get from the Retry token. + /// `remote_cid` is the value from the Source Connection ID field of + /// an incoming packet: what the peer wants us to use now. + /// `retry_cid` is what we asked them to use when we sent the Retry. + pub(crate) fn set_retry_cids( + &mut self, + odcid: ConnectionId, + remote_cid: ConnectionId, + retry_cid: ConnectionId, + ) { + debug_assert_eq!(self.role, Role::Server); + qtrace!( + [self], + "Retry CIDs: odcid={} remote={} retry={}", + odcid, + remote_cid, + retry_cid + ); + // We advertise "our" choices in transport parameters. + let local_tps = &mut self.tps.borrow_mut().local; + local_tps.set_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID, odcid.to_vec()); + local_tps.set_bytes(tparams::RETRY_SOURCE_CONNECTION_ID, retry_cid.to_vec()); + + // ...and save their choices for later validation. 
+ self.remote_initial_source_cid = Some(remote_cid); + } + + fn retry_sent(&self) -> bool { + self.tps + .borrow() + .local + .get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID) + .is_some() + } + + /// Set ALPN preferences. Strings that appear earlier in the list are given + /// higher preference. + pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> { + self.crypto.tls.set_alpn(protocols)?; + Ok(()) + } + + /// Enable a set of ciphers. + pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> { + if self.state != State::Init { + qerror!([self], "Cannot enable ciphers in state {:?}", self.state); + return Err(Error::ConnectionState); + } + self.crypto.tls.set_ciphers(ciphers)?; + Ok(()) + } + + fn make_resumption_token(&mut self) -> ResumptionToken { + debug_assert_eq!(self.role, Role::Client); + debug_assert!(self.crypto.has_resumption_token()); + self.crypto + .create_resumption_token( + self.new_token.take_token(), + self.tps + .borrow() + .remote + .as_ref() + .expect("should have transport parameters"), + u64::try_from(self.loss_recovery.rtt().as_millis()).unwrap_or(0), + ) + .unwrap() + } + + fn create_resumption_token(&mut self, now: Instant) { + if self.role == Role::Server || self.state < State::Connected { + return; + } + + qtrace!( + [self], + "Maybe create resumption token: {} {}", + self.crypto.has_resumption_token(), + self.new_token.has_token() + ); + + while self.crypto.has_resumption_token() && self.new_token.has_token() { + let token = self.make_resumption_token(); + self.events.client_resumption_token(token); + } + + // If we have a resumption ticket check or set a timer. 
+ if self.crypto.has_resumption_token() { + let arm = if let Some(expiration_time) = self.release_resumption_token_timer { + if expiration_time <= now { + let token = self.make_resumption_token(); + self.events.client_resumption_token(token); + self.release_resumption_token_timer = None; + + // This means that we release one session ticket every 3 PTOs + // if no NEW_TOKEN frame is received. + self.crypto.has_resumption_token() + } else { + false + } + } else { + true + }; + + if arm { + self.release_resumption_token_timer = + Some(now + 3 * self.loss_recovery.pto_raw(PNSpace::ApplicationData)); + } + } + } + + /// Get a resumption token. The correct way to obtain a resumption token is + /// waiting for the `ConnectionEvent::ResumptionToken` event. However, some + /// servers don't send `NEW_TOKEN` frames and so that event might be slow in + /// arriving. This is especially a problem for short-lived connections, where + /// the connection is closed before any events are released. This retrieves + /// the token, without waiting for the `NEW_TOKEN` frame to arrive. + /// + /// # Panics + /// If this is called on a server. + pub fn take_resumption_token(&mut self, now: Instant) -> Option<ResumptionToken> { + assert_eq!(self.role, Role::Client); + + if self.crypto.has_resumption_token() { + let token = self.make_resumption_token(); + if self.crypto.has_resumption_token() { + self.release_resumption_token_timer = + Some(now + 3 * self.loss_recovery.pto_raw(PNSpace::ApplicationData)); + } + Some(token) + } else { + None + } + } + + /// Enable resumption, using a token previously provided. + /// This can only be called once and only on the client. + /// After calling the function, it should be possible to attempt 0-RTT + /// if the token supports that. 
+ pub fn enable_resumption(&mut self, now: Instant, token: impl AsRef<[u8]>) -> Res<()> { + if self.state != State::Init { + qerror!([self], "set token in state {:?}", self.state); + return Err(Error::ConnectionState); + } + if self.role == Role::Server { + return Err(Error::ConnectionState); + } + + qinfo!( + [self], + "resumption token {}", + hex_snip_middle(token.as_ref()) + ); + let mut dec = Decoder::from(token.as_ref()); + + let smoothed_rtt = + Duration::from_millis(dec.decode_varint().ok_or(Error::InvalidResumptionToken)?); + qtrace!([self], " RTT {:?}", smoothed_rtt); + + let tp_slice = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?; + qtrace!([self], " transport parameters {}", hex(&tp_slice)); + let mut dec_tp = Decoder::from(tp_slice); + let tp = + TransportParameters::decode(&mut dec_tp).map_err(|_| Error::InvalidResumptionToken)?; + + let init_token = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?; + qtrace!([self], " Initial token {}", hex(&init_token)); + + let tok = dec.decode_remainder(); + qtrace!([self], " TLS token {}", hex(&tok)); + match self.crypto.tls { + Agent::Client(ref mut c) => { + let res = c.enable_resumption(&tok); + if let Err(e) = res { + self.absorb_error::<Error>(now, Err(Error::from(e))); + return Ok(()); + } + } + Agent::Server(_) => return Err(Error::WrongRole), + } + + self.tps.borrow_mut().remote_0rtt = Some(tp); + if !init_token.is_empty() { + self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec()); + } + if smoothed_rtt > GRANULARITY { + self.loss_recovery.set_initial_rtt(smoothed_rtt); + } + self.set_initial_limits(); + // Start up TLS, which has the effect of setting up all the necessary + // state for 0-RTT. This only stages the CRYPTO frames. 
+ let res = self.client_start(now); + self.absorb_error(now, res); + Ok(()) + } + + pub(crate) fn set_validation(&mut self, validation: Rc<RefCell<AddressValidation>>) { + qtrace!([self], "Enabling NEW_TOKEN"); + assert_eq!(self.role, Role::Server); + self.address_validation = AddressValidationInfo::Server(Rc::downgrade(&validation)); + } + + /// Send a TLS session ticket AND a NEW_TOKEN frame (if possible). + pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<()> { + if self.role == Role::Client { + return Err(Error::WrongRole); + } + + let tps = &self.tps; + if let Agent::Server(ref mut s) = self.crypto.tls { + let mut enc = Encoder::default(); + enc.encode_vvec_with(|mut enc_inner| { + tps.borrow().local.encode(&mut enc_inner); + }); + enc.encode(extra); + let records = s.send_ticket(now, &enc)?; + qinfo!([self], "send session ticket {}", hex(&enc)); + self.crypto.buffer_records(records)?; + } else { + unreachable!(); + } + + // If we are able, also send a NEW_TOKEN frame. + // This should be recording all remote addresses that are valid, + // but there are just 0 or 1 in the current implementation. + if let Some(p) = self.path.as_ref() { + if let Some(token) = self + .address_validation + .generate_new_token(p.remote_address(), now) + { + self.new_token.send_new_token(token); + } + } + + Ok(()) + } + + pub fn tls_info(&self) -> Option<&SecretAgentInfo> { + self.crypto.tls.info() + } + + /// Get the peer's certificate chain and other info. + pub fn peer_certificate(&self) -> Option<CertificateInfo> { + self.crypto.tls.peer_certificate() + } + + /// Call by application when the peer cert has been verified. + /// + /// This panics if there is no active peer. It's OK to call this + /// when authentication isn't needed, that will likely only cause + /// the connection to fail. However, if no packets have been + /// exchanged, it's not OK. 
+ pub fn authenticated(&mut self, status: AuthenticationStatus, now: Instant) { + qinfo!([self], "Authenticated {:?}", status); + self.crypto.tls.authenticated(status); + let res = self.handshake(now, PNSpace::Handshake, None); + self.absorb_error(now, res); + self.process_saved(now); + } + + /// Get the role of the connection. + pub fn role(&self) -> Role { + self.role + } + + /// Get the state of the connection. + pub fn state(&self) -> &State { + &self.state + } + + /// Get the 0-RTT state of the connection. + pub fn zero_rtt_state(&self) -> &ZeroRttState { + &self.zero_rtt_state + } + + /// Get a snapshot of collected statistics. + pub fn stats(&self) -> Stats { + self.stats.borrow().clone() + } + + // This function wraps a call to another function and sets the connection state + // properly if that call fails. + fn capture_error<T>(&mut self, now: Instant, frame_type: FrameType, res: Res<T>) -> Res<T> { + if let Err(v) = &res { + #[cfg(debug_assertions)] + let msg = format!("{:?}", v); + #[cfg(not(debug_assertions))] + let msg = ""; + let error = ConnectionError::Transport(v.clone()); + match &self.state { + State::Closing { error: err, .. } + | State::Draining { error: err, .. } + | State::Closed(err) => { + qwarn!([self], "Closing again after error {:?}", err); + } + State::Init => { + // We have not even sent anything just close the connection without sending any error. + // This may happen when client_start fails. + self.set_state(State::Closed(error)); + } + State::WaitInitial => { + // We don't have any state yet, so don't bother with + // the closing state, just send one CONNECTION_CLOSE. 
+ self.state_signaling.close(error.clone(), frame_type, msg); + self.set_state(State::Closed(error)); + } + _ => { + self.state_signaling.close(error.clone(), frame_type, msg); + if matches!(v, Error::KeysExhausted) { + self.set_state(State::Closed(error)); + } else { + self.set_state(State::Closing { + error, + timeout: self.get_closing_period_time(now), + }); + } + } + } + } + res + } + + /// For use with process_input(). Errors there can be ignored, but this + /// needs to ensure that the state is updated. + fn absorb_error<T>(&mut self, now: Instant, res: Res<T>) -> Option<T> { + self.capture_error(now, 0, res).ok() + } + + fn process_timer(&mut self, now: Instant) { + if let State::Closing { error, timeout } | State::Draining { error, timeout } = &self.state + { + if *timeout <= now { + // Close timeout expired, move to Closed + let st = State::Closed(error.clone()); + self.set_state(st); + qinfo!("Closing timer expired"); + return; + } + } + if let State::Closed(_) = self.state { + qdebug!("Timer fired while closed"); + return; + } + + let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData); + if self.idle_timeout.expired(now, pto) { + qinfo!([self], "idle timeout expired"); + self.set_state(State::Closed(ConnectionError::Transport( + Error::IdleTimeout, + ))); + return; + } + + self.cleanup_streams(); + + let res = self.crypto.states.check_key_update(now); + self.absorb_error(now, res); + + let lost = self.loss_recovery.timeout(now); + self.handle_lost_packets(&lost); + qlog::packets_lost(&mut self.qlog, &lost); + + if self.release_resumption_token_timer.is_some() { + self.create_resumption_token(now); + } + } + + /// Process new input datagrams on the connection. + pub fn process_input(&mut self, d: Datagram, now: Instant) { + let res = self.input(d, now); + self.absorb_error(now, res); + self.process_saved(now); + self.cleanup_streams(); + } + + /// Get the time that we next need to be called back, relative to `now`. 
+ fn next_delay(&mut self, now: Instant, paced: bool) -> Duration { + qtrace!([self], "Get callback delay {:?}", now); + + // Only one timer matters when closing... + if let State::Closing { timeout, .. } | State::Draining { timeout, .. } = self.state { + return timeout.duration_since(now); + } + + let mut delays = SmallVec::<[_; 6]>::new(); + if let Some(ack_time) = self.acks.ack_time(now) { + qtrace!([self], "Delayed ACK timer {:?}", ack_time); + delays.push(ack_time); + } + + let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData); + let idle_time = self.idle_timeout.expiry(now, pto); + qtrace!([self], "Idle timer {:?}", idle_time); + delays.push(idle_time); + + if let Some(lr_time) = self.loss_recovery.next_timeout() { + qtrace!([self], "Loss recovery timer {:?}", lr_time); + delays.push(lr_time); + } + + if let Some(key_update_time) = self.crypto.states.update_time() { + qtrace!([self], "Key update timer {:?}", key_update_time); + delays.push(key_update_time); + } + + if paced { + if let Some(pace_time) = self.loss_recovery.next_paced() { + qtrace!([self], "Pacing timer {:?}", pace_time); + delays.push(pace_time); + } + } + + // `release_resumption_token_timer` is not considered here, because + // it is not important enough to force the application to set a + // timeout for it It is expected thatt other activities will + // drive it. + + let earliest = delays.into_iter().min().unwrap(); + // TODO(agrover, mt) - need to analyze and fix #47 + // rather than just clamping to zero here. + qdebug!( + [self], + "delay duration {:?}", + max(now, earliest).duration_since(now) + ); + debug_assert!(earliest > now); + max(now, earliest).duration_since(now) + } + + /// Get output packets, as a result of receiving packets, or actions taken + /// by the application. + /// Returns datagrams to send, and how long to wait before calling again + /// even if no incoming packets. 
+ #[must_use = "Output of the process_output function must be handled"] + pub fn process_output(&mut self, now: Instant) -> Output { + qtrace!([self], "process_output {:?} {:?}", self.state, now); + + if self.state == State::Init { + if self.role == Role::Client { + let res = self.client_start(now); + self.absorb_error(now, res); + } + } else { + self.process_timer(now); + } + + match self.output(now) { + SendOption::Yes(dgram) => Output::Datagram(dgram), + SendOption::No(paced) => match self.state { + State::Init | State::Closed(_) => Output::None, + State::Closing { timeout, .. } | State::Draining { timeout, .. } => { + Output::Callback(timeout.duration_since(now)) + } + _ => Output::Callback(self.next_delay(now, paced)), + }, + } + } + + /// Process input and generate output. + #[must_use = "Output of the process function must be handled"] + pub fn process(&mut self, dgram: Option<Datagram>, now: Instant) -> Output { + if let Some(d) = dgram { + let res = self.input(d, now); + self.absorb_error(now, res); + self.process_saved(now); + } + self.process_output(now) + } + + fn is_valid_cid(&self, cid: &ConnectionIdRef) -> bool { + self.valid_cids.iter().any(|c| c == cid) || self.path.iter().any(|p| p.valid_local_cid(cid)) + } + + fn handle_retry(&mut self, packet: &PublicPacket) -> Res<()> { + qinfo!([self], "received Retry"); + if matches!(self.address_validation, AddressValidationInfo::Retry { .. }) { + self.stats.borrow_mut().pkt_dropped("Extra Retry"); + return Ok(()); + } + if packet.token().is_empty() { + self.stats.borrow_mut().pkt_dropped("Retry without a token"); + return Ok(()); + } + if !packet.is_valid_retry(&self.original_destination_cid.as_ref().unwrap()) { + self.stats + .borrow_mut() + .pkt_dropped("Retry with bad integrity tag"); + return Ok(()); + } + if let Some(p) = &mut self.path { + // At this point, we shouldn't have a remote connection ID for the path. 
+ p.set_remote_cid(packet.scid()); + } else { + qinfo!([self], "No path, but we received a Retry"); + return Err(Error::InternalError); + }; + + let retry_scid = ConnectionId::from(packet.scid()); + qinfo!( + [self], + "Valid Retry received, token={} scid={}", + hex(packet.token()), + retry_scid + ); + + let lost_packets = self.loss_recovery.retry(); + self.handle_lost_packets(&lost_packets); + + self.crypto + .states + .init(self.quic_version, self.role, &retry_scid); + self.address_validation = AddressValidationInfo::Retry { + token: packet.token().to_vec(), + retry_source_cid: retry_scid, + }; + Ok(()) + } + + fn discard_keys(&mut self, space: PNSpace, now: Instant) { + if self.crypto.discard(space) { + qinfo!([self], "Drop packet number space {}", space); + self.loss_recovery.discard(space, now); + self.acks.drop_space(space); + } + } + + fn token_equal(a: &[u8; 16], b: &[u8; 16]) -> bool { + // rustc might decide to optimize this and make this non-constant-time + // with respect to `t`, but it doesn't appear to currently. + let mut c = 0; + for (&a, &b) in a.iter().zip(b) { + c |= a ^ b; + } + c == 0 + } + + fn is_stateless_reset(&self, d: &Datagram) -> bool { + if d.len() < 16 { + return false; + } + let token = <&[u8; 16]>::try_from(&d[d.len() - 16..]).unwrap(); + // TODO(mt) only check the path that matches the datagram. + self.path + .as_ref() + .map(|p| p.reset_token()) + .flatten() + .map_or(false, |t| Self::token_equal(t, token)) + } + + fn check_stateless_reset<'a, 'b>( + &'a mut self, + d: &'b Datagram, + first: bool, + now: Instant, + ) -> Res<()> { + if first && self.is_stateless_reset(d) { + // Failing to process a packet in a datagram might + // indicate that there is a stateless reset present. 
+ qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..])); + self.state_signaling.reset(); + self.set_state(State::Draining { + error: ConnectionError::Transport(Error::StatelessReset), + timeout: self.get_closing_period_time(now), + }); + Err(Error::StatelessReset) + } else { + Ok(()) + } + } + + /// Process any saved datagrams that might be available for processing. + fn process_saved(&mut self, now: Instant) { + while let Some(cspace) = self.saved_datagrams.available() { + qdebug!([self], "process saved for space {:?}", cspace); + debug_assert!(self.crypto.states.rx_hp(cspace).is_some()); + for saved in self.saved_datagrams.take_saved() { + qtrace!([self], "input saved @{:?}: {:?}", saved.t, saved.d); + let res = self.input(saved.d, saved.t); + self.absorb_error(now, res); + } + } + } + + /// In case a datagram arrives that we can only partially process, save any + /// part that we don't have keys for. + fn save_datagram(&mut self, cspace: CryptoSpace, d: Datagram, remaining: usize, now: Instant) { + let d = if remaining < d.len() { + Datagram::new(d.source(), d.destination(), &d[d.len() - remaining..]) + } else { + d + }; + self.saved_datagrams.save(cspace, d, now); + self.stats.borrow_mut().saved_datagrams += 1; + } + + /// Perform any processing that we might have to do on packets prior to + /// attempting to remove protection. 
+ fn preprocess( + &mut self, + packet: &PublicPacket, + dcid: Option<&ConnectionId>, + now: Instant, + ) -> Res<PreprocessResult> { + if dcid.map_or(false, |d| d != packet.dcid()) { + self.stats + .borrow_mut() + .pkt_dropped("Coalesced packet has different DCID"); + return Ok(PreprocessResult::Next); + } + + match (packet.packet_type(), &self.state, &self.role) { + (PacketType::Initial, State::Init, Role::Server) => { + if !packet.is_valid_initial() { + self.stats.borrow_mut().pkt_dropped("Invalid Initial"); + return Ok(PreprocessResult::Next); + } + qinfo!( + [self], + "Received valid Initial packet with scid {:?} dcid {:?}", + packet.scid(), + packet.dcid() + ); + self.set_state(State::WaitInitial); + self.loss_recovery.start_pacer(now); + self.crypto + .states + .init(self.quic_version, self.role, &packet.dcid()); + + // We need to make sure that we set this transport parameter. + // This has to happen prior to processing the packet so that + // the TLS handshake has all it needs. + if !self.retry_sent() { + self.tps.borrow_mut().local.set_bytes( + tparams::ORIGINAL_DESTINATION_CONNECTION_ID, + packet.dcid().to_vec(), + ) + } + } + (PacketType::VersionNegotiation, State::WaitInitial, Role::Client) => { + match packet.supported_versions() { + Ok(versions) => { + if versions.is_empty() + || versions.contains(&self.quic_version.as_u32()) + || packet.dcid() != self.odcid().unwrap() + || matches!(self.address_validation, AddressValidationInfo::Retry { .. }) + { + // Ignore VersionNegotiation packets that contain the current version. + // Or don't have the right connection ID. + // Or are received after a Retry. 
+ self.stats.borrow_mut().pkt_dropped("Invalid VN"); + return Ok(PreprocessResult::End); + } + + self.set_state(State::Closed(ConnectionError::Transport( + Error::VersionNegotiation, + ))); + return Err(Error::VersionNegotiation); + } + Err(_) => { + self.stats.borrow_mut().pkt_dropped("Invalid VN"); + return Ok(PreprocessResult::End); + } + } + } + (PacketType::Retry, State::WaitInitial, Role::Client) => { + self.handle_retry(packet)?; + return Ok(PreprocessResult::Next); + } + (PacketType::Handshake, State::WaitInitial, Role::Client) + | (PacketType::Short, State::WaitInitial, Role::Client) => { + // This packet can't be processed now, but it could be a sign + // that Initial packets were lost. + // Resend Initial CRYPTO frames immediately a few times just + // in case. As we don't have an RTT estimate yet, this helps + // when there is a short RTT and losses. + if dcid.is_none() + && self.is_valid_cid(packet.dcid()) + && self.stats.borrow().saved_datagrams <= EXTRA_INITIALS + { + self.crypto.resend_unacked(PNSpace::Initial); + } + } + (PacketType::VersionNegotiation, ..) + | (PacketType::Retry, ..) + | (PacketType::OtherVersion, ..) 
=> { + self.stats + .borrow_mut() + .pkt_dropped(format!("{:?}", packet.packet_type())); + return Ok(PreprocessResult::Next); + } + _ => {} + } + + let res = match self.state { + State::Init => { + self.stats + .borrow_mut() + .pkt_dropped("Received while in Init state"); + PreprocessResult::Next + } + State::WaitInitial => PreprocessResult::Continue, + State::Handshaking | State::Connected | State::Confirmed => { + if !self.is_valid_cid(packet.dcid()) { + self.stats + .borrow_mut() + .pkt_dropped(format!("Invalid DCID {:?}", packet.dcid())); + PreprocessResult::Next + } else { + if self.role == Role::Server && packet.packet_type() == PacketType::Handshake { + // Server has received a Handshake packet -> discard Initial keys and states + self.discard_keys(PNSpace::Initial, now); + } + PreprocessResult::Continue + } + } + State::Closing { .. } => { + // Don't bother processing the packet. Instead ask to get a + // new close frame. + self.state_signaling.send_close(); + PreprocessResult::Next + } + State::Draining { .. } | State::Closed(..) => { + // Do nothing. + self.stats + .borrow_mut() + .pkt_dropped(format!("State {:?}", self.state)); + PreprocessResult::Next + } + }; + Ok(res) + } + + /// Take a datagram as input. This reports an error if the packet was bad. + fn input(&mut self, d: Datagram, now: Instant) -> Res<()> { + let mut slc = &d[..]; + let mut dcid = None; + + qtrace!([self], "input {}", hex(&**d)); + + // Handle each packet in the datagram. + while !slc.is_empty() { + self.stats.borrow_mut().packets_rx += 1; + let (packet, remainder) = + match PublicPacket::decode(slc, self.cid_manager.borrow().as_decoder()) { + Ok((packet, remainder)) => (packet, remainder), + Err(e) => { + qinfo!([self], "Garbage packet: {}", e); + qtrace!([self], "Garbage packet contents: {}", hex(slc)); + self.stats.borrow_mut().pkt_dropped("Garbage packet"); + break; + } + }; + match self.preprocess(&packet, dcid.as_ref(), now)? 
{ + PreprocessResult::Continue => (), + PreprocessResult::Next => break, + PreprocessResult::End => return Ok(()), + } + + qtrace!([self], "Received unverified packet {:?}", packet); + + let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData); + match packet.decrypt(&mut self.crypto.states, now + pto) { + Ok(payload) => { + // TODO(ekr@rtfm.com): Have the server blow away the initial + // crypto state if this fails? Otherwise, we will get a panic + // on the assert for doesn't exist. + // OK, we have a valid packet. + self.idle_timeout.on_packet_received(now); + dump_packet( + self, + "-> RX", + payload.packet_type(), + payload.pn(), + &payload[..], + ); + qlog::packet_received(&mut self.qlog, &packet, &payload); + let res = self.process_packet(&payload, now); + if res.is_err() && self.path.is_none() { + // We need to make a path for sending an error message. + // But this connection is going to be closed. + self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid())); + self.initialize_path(d.destination(), d.source()); + } + res?; + if self.state == State::WaitInitial { + self.start_handshake(&packet, &d)?; + } + self.process_migrations(&d)?; + } + Err(e) => { + match e { + Error::KeysPending(cspace) => { + // This packet can't be decrypted because we don't have the keys yet. + // Don't check this packet for a stateless reset, just return. + let remaining = slc.len(); + self.save_datagram(cspace, d, remaining, now); + return Ok(()); + } + Error::KeysExhausted => { + // Exhausting read keys is fatal. + return Err(e); + } + _ => (), + } + // Decryption failure, or not having keys is not fatal. + // If the state isn't available, or we can't decrypt the packet, drop + // the rest of the datagram on the floor, but don't generate an error. 
+ self.check_stateless_reset(&d, dcid.is_none(), now)?; + self.stats.borrow_mut().pkt_dropped("Decryption failure"); + qlog::packet_dropped(&mut self.qlog, &packet); + } + } + slc = remainder; + dcid = Some(ConnectionId::from(packet.dcid())); + } + self.check_stateless_reset(&d, dcid.is_none(), now)?; + Ok(()) + } + + fn process_packet(&mut self, packet: &DecryptedPacket, now: Instant) -> Res<()> { + // TODO(ekr@rtfm.com): Have the server blow away the initial + // crypto state if this fails? Otherwise, we will get a panic + // on the assert for doesn't exist. + // OK, we have a valid packet. + + let space = PNSpace::from(packet.packet_type()); + if self.acks.get_mut(space).unwrap().is_duplicate(packet.pn()) { + qdebug!([self], "Duplicate packet from {} pn={}", space, packet.pn()); + self.stats.borrow_mut().dups_rx += 1; + return Ok(()); + } + + let mut ack_eliciting = false; + let mut d = Decoder::from(&packet[..]); + let mut consecutive_padding = 0; + while d.remaining() > 0 { + let mut f = Frame::decode(&mut d)?; + + // Skip padding + while f == Frame::Padding && d.remaining() > 0 { + consecutive_padding += 1; + f = Frame::decode(&mut d)?; + } + if consecutive_padding > 0 { + qdebug!( + [self], + "PADDING frame repeated {} times", + consecutive_padding + ); + consecutive_padding = 0; + } + + ack_eliciting |= f.ack_eliciting(); + let t = f.get_type(); + let res = self.input_frame(packet.packet_type(), f, now); + self.capture_error(now, t, res)?; + } + self.acks + .get_mut(space) + .unwrap() + .set_received(now, packet.pn(), ack_eliciting); + + Ok(()) + } + + fn initialize_path(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr) { + debug_assert!(self.path.is_none()); + self.path = Some(Path::new( + local_addr, + remote_addr, + self.local_initial_source_cid.clone(), + // Ideally we know what the peer wants us to use for the remote CID. + // But we will use our own guess if necessary. 
self.remote_initial_source_cid
    .as_ref()
    .or_else(|| self.original_destination_cid.as_ref())
    .unwrap()
    .clone(),
));
}

/// Begin the handshake in response to an Initial packet.
///
/// Records the packet's source connection ID as the peer's initial source
/// CID.  A server also accepts the packet's destination CID, installs a
/// path from the datagram's addresses and decides whether 0-RTT is
/// accepted; a client switches the path the datagram arrived on to the
/// server-chosen CID.  Finally the state moves to `Handshaking`.
fn start_handshake(&mut self, packet: &PublicPacket, d: &Datagram) -> Res<()> {
    qtrace!([self], "starting handshake");
    debug_assert_eq!(packet.packet_type(), PacketType::Initial);
    self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid()));

    if self.role == Role::Server {
        // A server needs to accept the client's selected CID during the handshake.
        self.valid_cids.push(ConnectionId::from(packet.dcid()));
        self.original_destination_cid = Some(ConnectionId::from(packet.dcid()));
        // Install a path.
        self.initialize_path(d.destination(), d.source());

        // Any error from enabling 0-RTT is treated the same as a rejection.
        self.zero_rtt_state = match self.crypto.enable_0rtt(self.role) {
            Ok(true) => {
                qdebug!([self], "Accepted 0-RTT");
                ZeroRttState::AcceptedServer
            }
            _ => ZeroRttState::Rejected,
        };
    } else {
        qdebug!([self], "Changing to use Server CID={}", packet.scid());
        let p = self
            .path
            .iter_mut()
            .find(|p| p.received_on(&d))
            .expect("should have a path for sending Initial");
        p.set_remote_cid(packet.scid());
    }

    self.set_state(State::Handshaking);
    Ok(())
}

/// Check that a datagram arrived on the known path.
///
/// # Errors
/// `InvalidMigration` if the datagram was not received on the current path.
fn process_migrations(&self, d: &Datagram) -> Res<()> {
    if self.path.iter().any(|p| p.received_on(&d)) {
        Ok(())
    } else {
        // Right now, we don't support any form of migration.
        // So generate an error if a packet is received on a new path.
        Err(Error::InvalidMigration)
    }
}

/// Produce the next thing to send for the current state.
///
/// While the connection is live this builds regular packets; once it is
/// closing, draining or closed it sends the pending close frame, if there
/// is one.  Errors are passed to `absorb_error` and result in the default
/// `SendOption`.  Without a path there is nothing to send.
fn output(&mut self, now: Instant) -> SendOption {
    qtrace!([self], "output {:?}", now);
    // The path is taken out of `self` for the duration of the call so that it
    // can be borrowed alongside `&mut self`; it is put back below.
    if let Some(mut path) = self.path.take() {
        let res = match &self.state {
            State::Init
            | State::WaitInitial
            | State::Handshaking
            | State::Connected
            | State::Confirmed => self.output_path(&mut path, now),
            State::Closing { .. } | State::Draining { .. } | State::Closed(_) => {
                if let Some(frame) = self.state_signaling.close_frame() {
                    self.output_close(&path, &frame)
                } else {
                    Ok(SendOption::default())
                }
            }
        };
        let out = self.absorb_error(now, res).unwrap_or_default();
        self.path = Some(path);
        out
    } else {
        SendOption::default()
    }
}

/// Start a packet of the type matching `cspace`, writing its header into
/// `encoder`.
///
/// Short-header packets carry only the remote CID and the key phase of `tx`;
/// long-header packets carry the version and both CIDs.  Initial packets
/// additionally get the token from `address_validation`.
/// Returns the packet type together with the builder that owns the encoder.
fn build_packet_header(
    path: &Path,
    cspace: CryptoSpace,
    encoder: Encoder,
    tx: &CryptoDxState,
    address_validation: &AddressValidationInfo,
    quic_version: QuicVersion,
    grease_quic_bit: bool,
) -> (PacketType, PacketBuilder) {
    let pt = PacketType::from(cspace);
    let mut builder = if pt == PacketType::Short {
        qdebug!("Building Short dcid {}", path.remote_cid());
        PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid())
    } else {
        qdebug!(
            "Building {:?} dcid {} scid {}",
            pt,
            path.remote_cid(),
            path.local_cid(),
        );

        PacketBuilder::long(
            encoder,
            pt,
            quic_version,
            path.remote_cid(),
            path.local_cid(),
        )
    };
    builder.scramble(grease_quic_bit);
    if pt == PacketType::Initial {
        builder.initial_token(address_validation.token());
    }

    (pt, builder)
}

/// Write the next packet number of `tx` into `builder` and return it.
///
/// The encoded length is the number of significant bytes in twice the
/// distance to `largest_acknowledged` (or in `pn + 1` when nothing has been
/// acknowledged yet).
fn add_packet_number(
    builder: &mut PacketBuilder,
    tx: &CryptoDxState,
    largest_acknowledged: Option<PacketNumber>,
) -> PacketNumber {
    // Get the packet number and work out how long it is.
    let pn = tx.next_pn();
    let unacked_range = if let Some(la) = largest_acknowledged {
        // Double the range from this to the last acknowledged in this space.
        (pn - la) << 1
    } else {
        pn + 1
    };
    // Count how many bytes in this range are non-zero.
    let pn_len = mem::size_of::<PacketNumber>()
        - usize::try_from(unacked_range.leading_zeros() / 8).unwrap();
    // pn_len can't be zero (unacked_range is > 0)
    // TODO(mt) also use `4*path CWND/path MTU` to set a minimum length.
    builder.pn(pn, pn_len);
    pn
}

/// Whether the peer advertised the `GREASE_QUIC_BIT` transport parameter,
/// looking at the current remote parameters first and the remembered 0-RTT
/// parameters second.
fn can_grease_quic_bit(&self) -> bool {
    let tph = self.tps.borrow();
    if let Some(r) = &tph.remote {
        r.get_empty(tparams::GREASE_QUIC_BIT)
    } else if let Some(r) = &tph.remote_0rtt {
        r.get_empty(tparams::GREASE_QUIC_BIT)
    } else {
        false
    }
}

/// Build a datagram carrying the CONNECTION_CLOSE `frame`, with one packet
/// for every packet number space that still has send keys.
///
/// # Errors
/// Any error from building (protecting) a packet.
fn output_close(&mut self, path: &Path, frame: &Frame) -> Res<SendOption> {
    let mut encoder = Encoder::with_capacity(path.mtu());
    let grease_quic_bit = self.can_grease_quic_bit();
    for space in PNSpace::iter() {
        let (cspace, tx) = if let Some(crypto) = self.crypto.states.select_tx(*space) {
            crypto
        } else {
            continue;
        };

        let (_, mut builder) = Self::build_packet_header(
            path,
            cspace,
            encoder,
            tx,
            &AddressValidationInfo::None,
            self.quic_version,
            grease_quic_bit,
        );
        let _ = Self::add_packet_number(
            &mut builder,
            tx,
            self.loss_recovery.largest_acknowledged_pn(*space),
        );

        // ConnectionError::Application is only allowed at 1RTT.
        let sanitized = if *space == PNSpace::ApplicationData {
            &frame
        } else {
            frame.sanitize_close()
        };
        if let Frame::ConnectionClose {
            error_code,
            frame_type,
            reason_phrase,
        } = sanitized
        {
            builder.encode_varint(sanitized.get_type());
            builder.encode_varint(error_code.code());
            if let CloseError::Transport(_) = error_code {
                builder.encode_varint(*frame_type);
            }
            // Truncate the reason to 256 bytes and to the space left in the packet.
            // `saturating_sub` keeps this from underflowing when fewer than two
            // bytes remain.
            // NOTE(review): the 2 bytes held back presumably cover the length
            // prefix written by `encode_vvec` - confirm.
            let reason_len = min(
                min(reason_phrase.len(), 256),
                builder.remaining().saturating_sub(2),
            );
            builder.encode_vvec(&reason_phrase[..reason_len]);
        } else {
            unreachable!();
        }

        encoder = builder.build(tx)?;
    }

    Ok(SendOption::Yes(path.datagram(encoder)))
}

/// Write frames to the provided builder.  Returns a list of tokens used for
/// tracking loss or acknowledgment, whether any frame was ACK eliciting, and
/// whether the packet was padded.
fn write_frames(
    &mut self,
    space: PNSpace,
    profile: &SendProfile,
    builder: &mut PacketBuilder,
    mut pad: bool,
    now: Instant,
) -> (Vec<RecoveryToken>, bool, bool) {
    let mut tokens = Vec::new();
    let stats = &mut self.stats.borrow_mut().frame_tx;

    // The ACK token is held back and appended last, so that the
    // `tokens.is_empty()` test below only sees non-ACK frames.
    let ack_token = self.acks.write_frame(space, now, builder, stats);

    if profile.ack_only(space) {
        // If we are CC limited we can only send acks!
        if let Some(t) = ack_token {
            tokens.push(t);
        }
        return (tokens, false, false);
    }

    // Only a server writes HANDSHAKE_DONE, and only in the application space.
    if space == PNSpace::ApplicationData && self.role == Role::Server {
        if let Some(t) = self.state_signaling.write_done(builder) {
            tokens.push(t);
            stats.handshake_done += 1;
        }
    }

    if let Some(t) = self.crypto.streams.write_frame(space, builder) {
        tokens.push(t);
        stats.crypto += 1;
    }

    // Flow control, stream and NEW_TOKEN frames are written only in the
    // application data space.
    if space == PNSpace::ApplicationData {
        self.flow_mgr
            .borrow_mut()
            .write_frames(builder, &mut tokens, stats);

        self.send_streams.write_frames(builder, &mut tokens, stats);
        self.new_token.write_frames(builder, &mut tokens, stats);
    }

    // Anything - other than ACK - that registered a token wants an acknowledgment.
    let ack_eliciting = !tokens.is_empty()
        || if profile.should_probe(space) {
            // Nothing ack-eliciting and we need to probe; send PING.
            debug_assert_ne!(builder.remaining(), 0);
            builder.encode_varint(crate::frame::FRAME_TYPE_PING);
            stats.ping += 1;
            stats.all += 1;
            true
        } else {
            false
        };

    // Add padding.  Only pad 1-RTT packets so that we don't prevent coalescing.
    // And avoid padding packets that otherwise only contain ACK because adding PADDING
    // causes those packets to consume congestion window, which is not tracked (yet).
    pad &= ack_eliciting && space == PNSpace::ApplicationData;
    if pad {
        builder.pad();
        stats.padding += 1;
        stats.all += 1;
    }

    if let Some(t) = ack_token {
        tokens.push(t);
    }
    // One token per frame written above (PING and PADDING were counted separately).
    stats.all += tokens.len();
    (tokens, ack_eliciting, pad)
}

/// Build a datagram, possibly from multiple packets (for different PN
/// spaces) and each containing 1+ frames.
///
/// Returns `SendOption::No` with the pacing information from the send
/// profile when no packet was produced.
///
/// # Errors
/// Errors from protecting a packet or from the automatic key update.
fn output_path(&mut self, path: &mut Path, now: Instant) -> Res<SendOption> {
    // An Initial packet whose loss-recovery registration is deferred until the
    // final datagram size (including padding) is known.
    let mut initial_sent = None;
    let mut needs_padding = false;
    let grease_quic_bit = self.can_grease_quic_bit();

    // Determine how we are sending packets (PTO, etc..).
    let profile = self.loss_recovery.send_profile(now, path.mtu());
    qdebug!([self], "output_path send_profile {:?}", profile);

    // Frames for different epochs must go in different packets, but then these
    // packets can go in a single datagram
    let mut encoder = Encoder::with_capacity(profile.limit());
    for space in PNSpace::iter() {
        // Ensure we have tx crypto state for this epoch, or skip it.
        let (cspace, tx) = if let Some(crypto) = self.crypto.states.select_tx(*space) {
            crypto
        } else {
            continue;
        };

        // Remember where this packet starts within the datagram so its size
        // can be computed once it is built.
        let header_start = encoder.len();
        let (pt, mut builder) = Self::build_packet_header(
            path,
            cspace,
            encoder,
            tx,
            &self.address_validation,
            self.quic_version,
            grease_quic_bit,
        );
        let pn = Self::add_packet_number(
            &mut builder,
            tx,
            self.loss_recovery.largest_acknowledged_pn(*space),
        );
        let payload_start = builder.len();

        // Work out if we have space left.
        let aead_expansion = tx.expansion();
        if builder.len() + aead_expansion > profile.limit() {
            // No space for a packet of this type.
            encoder = builder.abort();
            continue;
        }

        // Add frames to the packet.
        let limit = profile.limit() - aead_expansion;
        builder.set_limit(limit);
        let (tokens, ack_eliciting, padded) =
            self.write_frames(*space, &profile, &mut builder, needs_padding, now);
        if builder.packet_empty() {
            // Nothing to include in this packet.
            encoder = builder.abort();
            continue;
        }

        dump_packet(self, "TX ->", pt, pn, &builder[payload_start..]);
        qlog::packet_sent(
            &mut self.qlog,
            pt,
            pn,
            builder.len() - header_start + aead_expansion,
            &builder[payload_start..],
        );

        self.stats.borrow_mut().packets_tx += 1;
        encoder = builder.build(self.crypto.states.tx(cspace).unwrap())?;
        debug_assert!(encoder.len() <= path.mtu());
        self.crypto.states.auto_update()?;

        if ack_eliciting {
            self.idle_timeout.on_packet_sent(now);
        }
        let sent = SentPacket::new(
            pt,
            pn,
            now,
            ack_eliciting,
            tokens,
            encoder.len() - header_start,
        );
        if padded {
            // This packet already carries the padding; nothing more is needed.
            needs_padding = false;
            self.loss_recovery.on_packet_sent(sent);
        } else if pt == PacketType::Initial && (self.role == Role::Client || ack_eliciting) {
            // Packets containing Initial packets might need padding, and we want to
            // track that padding along with the Initial packet.  So defer tracking.
            initial_sent = Some(sent);
            needs_padding = true;
        } else {
            if pt == PacketType::Handshake && self.role == Role::Client {
                needs_padding = false;
            }
            self.loss_recovery.on_packet_sent(sent);
        }

        if *space == PNSpace::Handshake {
            if self.role == Role::Client {
                // Client can send Handshake packets -> discard Initial keys and states
                self.discard_keys(PNSpace::Initial, now);
            } else if self.state == State::Confirmed {
                // We could discard handshake keys in set_state, but wait until after sending an ACK.
                self.discard_keys(PNSpace::Handshake, now);
            }
        }
    }

    if encoder.is_empty() {
        Ok(SendOption::No(profile.paced()))
    } else {
        // Perform additional padding for Initial packets as necessary.
        let mut packets: Vec<u8> = encoder.into();
        if let Some(mut initial) = initial_sent.take() {
            if needs_padding {
                qdebug!([self], "pad Initial to path MTU {}", path.mtu());
                // Charge the padding to the Initial packet, then zero-fill
                // the datagram up to the path MTU.
                initial.size += path.mtu() - packets.len();
                packets.resize(path.mtu(), 0);
            }
            self.loss_recovery.on_packet_sent(initial);
        }
        Ok(SendOption::Yes(path.datagram(packets)))
    }
}

/// Start a key update, using the largest acknowledged packet number in the
/// application data space.
/// # Errors
/// `KeyUpdateBlocked` if the connection is not in the `Confirmed` state;
/// otherwise whatever the crypto state reports.
pub fn initiate_key_update(&mut self) -> Res<()> {
    if self.state == State::Confirmed {
        let la = self
            .loss_recovery
            .largest_acknowledged_pn(PNSpace::ApplicationData);
        qinfo!([self], "Initiating key update");
        self.crypto.states.initiate_key_update(la)
    } else {
        Err(Error::KeyUpdateBlocked)
    }
}

/// Test-only accessor that forwards to `get_epochs` on the crypto states.
#[cfg(test)]
pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
    self.crypto.states.get_epochs()
}

/// Start the connection as a client: start the pacer, run the first
/// handshake step in the Initial space, move to `WaitInitial` and record
/// whether 0-RTT could be enabled.
fn client_start(&mut self, now: Instant) -> Res<()> {
    qinfo!([self], "client_start");
    debug_assert_eq!(self.role, Role::Client);
    qlog::client_connection_started(&mut self.qlog, self.path.as_ref().unwrap());
    self.loss_recovery.start_pacer(now);

    self.handshake(now, PNSpace::Initial, None)?;
    self.set_state(State::WaitInitial);
    self.zero_rtt_state = if self.crypto.enable_0rtt(self.role)? {
        qdebug!([self], "Enabled 0-RTT");
        ZeroRttState::Sending
    } else {
        ZeroRttState::Init
    };
    Ok(())
}

/// The time at which a closing or draining period that starts at `now` ends.
fn get_closing_period_time(&self, now: Instant) -> Instant {
    // Spec says close time should be at least PTO times 3.
    now + (self.loss_recovery.pto_raw(PNSpace::ApplicationData) * 3)
}

/// Close the connection.
+ pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) { + let error = ConnectionError::Application(app_error); + let timeout = self.get_closing_period_time(now); + self.state_signaling.close(error.clone(), 0, msg); + self.set_state(State::Closing { error, timeout }); + } + + fn set_initial_limits(&mut self) { + let tps = self.tps.borrow(); + let remote = tps.remote(); + self.indexes.remote_max_stream_bidi = + StreamIndex::new(remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)); + self.indexes.remote_max_stream_uni = + StreamIndex::new(remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI)); + self.flow_mgr + .borrow_mut() + .conn_increase_max_credit(remote.get_integer(tparams::INITIAL_MAX_DATA)); + + let peer_timeout = remote.get_integer(tparams::IDLE_TIMEOUT); + if peer_timeout > 0 { + self.idle_timeout + .set_peer_timeout(Duration::from_millis(peer_timeout)); + } + } + + /// Process the final set of transport parameters. + fn process_tps(&mut self) -> Res<()> { + self.validate_cids()?; + { + let tps = self.tps.borrow(); + if let Some(token) = tps + .remote + .as_ref() + .unwrap() + .get_bytes(tparams::STATELESS_RESET_TOKEN) + { + let reset_token = <[u8; 16]>::try_from(token).unwrap().to_owned(); + self.path.as_mut().unwrap().set_reset_token(reset_token); + } + let mad = Duration::from_millis( + tps.remote + .as_ref() + .unwrap() + .get_integer(tparams::MAX_ACK_DELAY), + ); + self.loss_recovery.set_peer_max_ack_delay(mad); + } + self.set_initial_limits(); + qlog::connection_tparams_set(&mut self.qlog, &*self.tps.borrow()); + Ok(()) + } + + fn validate_cids(&mut self) -> Res<()> { + match self.quic_version { + QuicVersion::Draft27 => self.validate_cids_draft_27(), + _ => self.validate_cids_draft_28_plus(), + } + } + + fn validate_cids_draft_27(&mut self) -> Res<()> { + if let AddressValidationInfo::Retry { token, .. 
} = &self.address_validation { + debug_assert!(!token.is_empty()); + let tph = self.tps.borrow(); + let tp = tph + .remote + .as_ref() + .unwrap() + .get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID); + if self + .original_destination_cid + .as_ref() + .map(ConnectionId::as_cid_ref) + != tp.map(ConnectionIdRef::from) + { + return Err(Error::InvalidRetry); + } + } + Ok(()) + } + + fn validate_cids_draft_28_plus(&mut self) -> Res<()> { + let tph = self.tps.borrow(); + let remote_tps = tph.remote.as_ref().unwrap(); + + let tp = remote_tps.get_bytes(tparams::INITIAL_SOURCE_CONNECTION_ID); + if self + .remote_initial_source_cid + .as_ref() + .map(ConnectionId::as_cid_ref) + != tp.map(ConnectionIdRef::from) + { + qwarn!( + [self], + "ISCID test failed: self cid {:?} != tp cid {:?}", + self.remote_initial_source_cid, + tp.map(hex), + ); + return Err(Error::ProtocolViolation); + } + + if self.role == Role::Client { + let tp = remote_tps.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID); + if self + .original_destination_cid + .as_ref() + .map(ConnectionId::as_cid_ref) + != tp.map(ConnectionIdRef::from) + { + qwarn!( + [self], + "ODCID test failed: self cid {:?} != tp cid {:?}", + self.original_destination_cid, + tp.map(hex), + ); + return Err(Error::ProtocolViolation); + } + + let tp = remote_tps.get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID); + let expected = if let AddressValidationInfo::Retry { + retry_source_cid, .. + } = &self.address_validation + { + Some(retry_source_cid.as_cid_ref()) + } else { + None + }; + if expected != tp.map(ConnectionIdRef::from) { + qwarn!( + [self], + "RSCID test failed. 
self cid {:?} != tp cid {:?}", + expected, + tp.map(hex), + ); + return Err(Error::ProtocolViolation); + } + } + + Ok(()) + } + + fn handshake(&mut self, now: Instant, space: PNSpace, data: Option<&[u8]>) -> Res<()> { + qtrace!([self], "Handshake space={} data={:0x?}", space, data); + + let try_update = data.is_some(); + match self.crypto.handshake(now, space, data)? { + HandshakeState::Authenticated(_) | HandshakeState::InProgress => (), + HandshakeState::AuthenticationPending => self.events.authentication_needed(), + HandshakeState::Complete(_) => { + if !self.state.connected() { + self.set_connected(now)?; + } + } + _ => { + unreachable!("Crypto state should not be new or failed after successful handshake") + } + } + + // There is a chance that this could be called less often, but getting the + // conditions right is a little tricky, so call it on every CRYPTO frame. + if try_update { + // We have transport parameters, it's go time. + if self.tps.borrow().remote.is_some() { + self.set_initial_limits(); + } + if self.crypto.install_keys(self.role)? { + self.saved_datagrams.make_available(CryptoSpace::Handshake); + } + } + + Ok(()) + } + + fn handle_max_data(&mut self, maximum_data: u64) { + let conn_was_blocked = self.flow_mgr.borrow().conn_credit_avail() == 0; + let conn_credit_increased = self + .flow_mgr + .borrow_mut() + .conn_increase_max_credit(maximum_data); + + if conn_was_blocked && conn_credit_increased { + for (id, ss) in &mut self.send_streams { + if ss.avail() > 0 { + // These may not actually all be writable if one + // uses up all the conn credit. Not our fault. 
+ self.events.send_stream_writable(*id) + } + } + } + } + + fn input_frame(&mut self, ptype: PacketType, frame: Frame, now: Instant) -> Res<()> { + if !frame.is_allowed(ptype) { + qinfo!("frame not allowed: {:?} {:?}", frame, ptype); + return Err(Error::ProtocolViolation); + } + self.stats.borrow_mut().frame_rx.all += 1; + let space = PNSpace::from(ptype); + match frame { + Frame::Padding => { + // Note: This counts contiguous padding as a single frame. + self.stats.borrow_mut().frame_rx.padding += 1; + } + Frame::Ping => { + // If we get a PING and there are outstanding CRYPTO frames, + // prepare to resend them. + self.stats.borrow_mut().frame_rx.ping += 1; + self.crypto.resend_unacked(space); + } + Frame::Ack { + largest_acknowledged, + ack_delay, + first_ack_range, + ack_ranges, + } => { + self.handle_ack( + space, + largest_acknowledged, + ack_delay, + first_ack_range, + ack_ranges, + now, + )?; + } + Frame::ResetStream { + stream_id, + application_error_code, + .. + } => { + // TODO(agrover@mozilla.com): use final_size for connection MaxData calc + self.stats.borrow_mut().frame_rx.reset_stream += 1; + if let (_, Some(rs)) = self.obtain_stream(stream_id)? { + rs.reset(application_error_code); + } + } + Frame::StopSending { + stream_id, + application_error_code, + } => { + self.stats.borrow_mut().frame_rx.stop_sending += 1; + self.events + .send_stream_stop_sending(stream_id, application_error_code); + if let (Some(ss), _) = self.obtain_stream(stream_id)? 
{ + ss.reset(application_error_code); + } + } + Frame::Crypto { offset, data } => { + qtrace!( + [self], + "Crypto frame on space={} offset={}, data={:0x?}", + space, + offset, + &data + ); + self.stats.borrow_mut().frame_rx.crypto += 1; + self.crypto.streams.inbound_frame(space, offset, data); + if self.crypto.streams.data_ready(space) { + let mut buf = Vec::new(); + let read = self.crypto.streams.read_to_end(space, &mut buf); + qdebug!("Read {} bytes", read); + self.handshake(now, space, Some(&buf))?; + self.create_resumption_token(now); + } else { + // If we get a useless CRYPTO frame send outstanding CRYPTO frames again. + self.crypto.resend_unacked(space); + } + } + Frame::NewToken { token } => { + self.stats.borrow_mut().frame_rx.new_token += 1; + self.new_token.save_token(token.to_vec()); + self.create_resumption_token(now); + } + Frame::Stream { + fin, + stream_id, + offset, + data, + .. + } => { + self.stats.borrow_mut().frame_rx.stream += 1; + if let (_, Some(rs)) = self.obtain_stream(stream_id)? { + rs.inbound_stream_frame(fin, offset, data)?; + } + } + Frame::MaxData { maximum_data } => { + self.stats.borrow_mut().frame_rx.max_data += 1; + self.handle_max_data(maximum_data); + } + Frame::MaxStreamData { + stream_id, + maximum_stream_data, + } => { + self.stats.borrow_mut().frame_rx.max_stream_data += 1; + if let (Some(ss), _) = self.obtain_stream(stream_id)? 
{ + ss.set_max_stream_data(maximum_stream_data); + } + } + Frame::MaxStreams { + stream_type, + maximum_streams, + } => { + self.stats.borrow_mut().frame_rx.max_streams += 1; + let remote_max = match stream_type { + StreamType::BiDi => &mut self.indexes.remote_max_stream_bidi, + StreamType::UniDi => &mut self.indexes.remote_max_stream_uni, + }; + + if maximum_streams > *remote_max { + *remote_max = maximum_streams; + self.events.send_stream_creatable(stream_type); + } + } + Frame::DataBlocked { data_limit } => { + // Should never happen since we set data limit to max + qwarn!( + [self], + "Received DataBlocked with data limit {}", + data_limit + ); + self.stats.borrow_mut().frame_rx.data_blocked += 1; + // But if it does, open it up all the way + self.flow_mgr.borrow_mut().max_data(LOCAL_MAX_DATA); + } + Frame::StreamDataBlocked { + stream_id, + stream_data_limit, + } => { + self.stats.borrow_mut().frame_rx.stream_data_blocked += 1; + // Terminate connection with STREAM_STATE_ERROR if send-only + // stream (-transport 19.13) + if stream_id.is_send_only(self.role()) { + return Err(Error::StreamStateError); + } + + if let (_, Some(rs)) = self.obtain_stream(stream_id)? { + if let Some(msd) = rs.max_stream_data() { + qinfo!( + [self], + "Got StreamDataBlocked(id {} MSD {}); curr MSD {}", + stream_id.as_u64(), + stream_data_limit, + msd + ); + if stream_data_limit != msd { + self.flow_mgr.borrow_mut().max_stream_data(stream_id, msd) + } + } + } + } + Frame::StreamsBlocked { stream_type, .. } => { + self.stats.borrow_mut().frame_rx.streams_blocked += 1; + let local_max = match stream_type { + StreamType::BiDi => &mut self.indexes.local_max_stream_bidi, + StreamType::UniDi => &mut self.indexes.local_max_stream_uni, + }; + + self.flow_mgr + .borrow_mut() + .max_streams(*local_max, stream_type) + } + Frame::NewConnectionId { + sequence_number, + connection_id, + stateless_reset_token, + .. 
+ } => { + self.stats.borrow_mut().frame_rx.new_connection_id += 1; + let cid = ConnectionId::from(connection_id); + let srt = stateless_reset_token.to_owned(); + self.connection_ids.insert(sequence_number, (cid, srt)); + } + Frame::RetireConnectionId { sequence_number } => { + self.stats.borrow_mut().frame_rx.retire_connection_id += 1; + self.connection_ids.remove(&sequence_number); + } + Frame::PathChallenge { data } => { + self.stats.borrow_mut().frame_rx.path_challenge += 1; + self.flow_mgr.borrow_mut().path_response(data); + } + Frame::PathResponse { .. } => { + // Should never see this, we don't support migration atm and + // do not send path challenges + qwarn!([self], "Received Path Response"); + self.stats.borrow_mut().frame_rx.path_response += 1; + } + Frame::ConnectionClose { + error_code, + frame_type, + reason_phrase, + } => { + self.stats.borrow_mut().frame_rx.connection_close += 1; + let reason_phrase = String::from_utf8_lossy(&reason_phrase); + qinfo!( + [self], + "ConnectionClose received. Error code: {:?} frame type {:x} reason {}", + error_code, + frame_type, + reason_phrase + ); + let (detail, frame_type) = if let CloseError::Application(_) = error_code { + // Use a transport error here because we want to send + // NO_ERROR in this case. 
+ ( + Error::PeerApplicationError(error_code.code()), + FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, + ) + } else { + ( + Error::PeerError(error_code.code()), + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT, + ) + }; + let error = ConnectionError::Transport(detail); + self.state_signaling.drain(error.clone(), frame_type, ""); + self.set_state(State::Draining { + error, + timeout: self.get_closing_period_time(now), + }); + } + Frame::HandshakeDone => { + self.stats.borrow_mut().frame_rx.handshake_done += 1; + if self.role == Role::Server || !self.state.connected() { + return Err(Error::ProtocolViolation); + } + self.set_state(State::Confirmed); + self.discard_keys(PNSpace::Handshake, now); + } + }; + + Ok(()) + } + + /// Given a set of `SentPacket` instances, ensure that the source of the packet + /// is told that they are lost. This gives the frame generation code a chance + /// to retransmit the frame as needed. + fn handle_lost_packets(&mut self, lost_packets: &[SentPacket]) { + for lost in lost_packets { + for token in &lost.tokens { + qdebug!([self], "Lost: {:?}", token); + match token { + RecoveryToken::Ack(_) => {} + RecoveryToken::Stream(st) => self.send_streams.lost(&st), + RecoveryToken::Crypto(ct) => self.crypto.lost(&ct), + RecoveryToken::Flow(ft) => self.flow_mgr.borrow_mut().lost( + &ft, + &mut self.send_streams, + &mut self.recv_streams, + &mut self.indexes, + ), + RecoveryToken::HandshakeDone => self.state_signaling.handshake_done(), + RecoveryToken::NewToken(seqno) => self.new_token.lost(*seqno), + } + } + } + } + + fn decode_ack_delay(&self, v: u64) -> Duration { + // If we have remote transport parameters, use them. + // Otherwise, ack delay should be zero (because it's the handshake). 
+ if let Some(r) = self.tps.borrow().remote.as_ref() { + let exponent = u32::try_from(r.get_integer(tparams::ACK_DELAY_EXPONENT)).unwrap(); + Duration::from_micros(v.checked_shl(exponent).unwrap_or(u64::MAX)) + } else { + Duration::new(0, 0) + } + } + + fn handle_ack( + &mut self, + space: PNSpace, + largest_acknowledged: u64, + ack_delay: u64, + first_ack_range: u64, + ack_ranges: Vec<AckRange>, + now: Instant, + ) -> Res<()> { + qinfo!( + [self], + "Rx ACK space={}, largest_acked={}, first_ack_range={}, ranges={:?}", + space, + largest_acknowledged, + first_ack_range, + ack_ranges + ); + + let acked_ranges = + Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?; + let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received( + space, + largest_acknowledged, + acked_ranges, + self.decode_ack_delay(ack_delay), + now, + ); + for acked in acked_packets { + for token in &acked.tokens { + match token { + RecoveryToken::Ack(at) => self.acks.acked(at), + RecoveryToken::Stream(st) => self.send_streams.acked(st), + RecoveryToken::Crypto(ct) => self.crypto.acked(ct), + RecoveryToken::Flow(ft) => { + self.flow_mgr.borrow_mut().acked(ft, &mut self.send_streams) + } + RecoveryToken::HandshakeDone => (), + RecoveryToken::NewToken(seqno) => self.new_token.acked(*seqno), + } + } + } + self.handle_lost_packets(&lost_packets); + qlog::packets_lost(&mut self.qlog, &lost_packets); + let stats = &mut self.stats.borrow_mut().frame_rx; + stats.ack += 1; + stats.largest_acknowledged = max(stats.largest_acknowledged, largest_acknowledged); + Ok(()) + } + + /// When the server rejects 0-RTT we need to drop a bunch of stuff. + fn client_0rtt_rejected(&mut self) { + if !matches!(self.zero_rtt_state, ZeroRttState::Sending) { + return; + } + qdebug!([self], "0-RTT rejected"); + + // Tell 0-RTT packets that they were "lost". 
+ let dropped = self.loss_recovery.drop_0rtt(); + self.handle_lost_packets(&dropped); + + self.send_streams.clear(); + self.recv_streams.clear(); + self.indexes = StreamIndexes::new(); + self.crypto.states.discard_0rtt_keys(); + self.events.client_0rtt_rejected(); + } + + fn set_connected(&mut self, now: Instant) -> Res<()> { + qinfo!([self], "TLS connection complete"); + if self.crypto.tls.info().map(SecretAgentInfo::alpn).is_none() { + qwarn!([self], "No ALPN. Closing connection."); + // 120 = no_application_protocol + return Err(Error::CryptoAlert(120)); + } + if self.role == Role::Server { + // Remove the randomized client CID from the list of acceptable CIDs. + debug_assert_eq!(1, self.valid_cids.len()); + self.valid_cids.clear(); + // Generate a qlog event that the server connection started. + qlog::server_connection_started(&mut self.qlog, self.path.as_ref().unwrap()); + } else { + self.zero_rtt_state = if self.crypto.tls.info().unwrap().early_data_accepted() { + ZeroRttState::AcceptedClient + } else { + self.client_0rtt_rejected(); + ZeroRttState::Rejected + }; + } + + // Setting application keys has to occur after 0-RTT rejection. 
+ let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData); + self.crypto.install_application_keys(now + pto)?; + self.process_tps()?; + self.set_state(State::Connected); + self.create_resumption_token(now); + self.saved_datagrams + .make_available(CryptoSpace::ApplicationData); + self.stats.borrow_mut().resumed = self.crypto.tls.info().unwrap().resumed(); + if self.role == Role::Server { + self.state_signaling.handshake_done(); + self.set_state(State::Confirmed); + } + qinfo!([self], "Connection established"); + Ok(()) + } + + fn set_state(&mut self, state: State) { + if state > self.state { + qinfo!([self], "State change from {:?} -> {:?}", self.state, state); + self.state = state.clone(); + if self.state.closed() { + self.send_streams.clear(); + self.recv_streams.clear(); + } + self.events.connection_state_change(state); + qlog::connection_state_updated(&mut self.qlog, &self.state) + } else if mem::discriminant(&state) != mem::discriminant(&self.state) { + // Only tolerate a regression in state if the new state is closing + // and the connection is already closed. + debug_assert!(matches!(state, State::Closing { .. } | State::Draining { .. })); + debug_assert!(self.state.closed()); + } + } + + fn cleanup_streams(&mut self) { + self.send_streams.clear_terminal(); + let recv_to_remove = self + .recv_streams + .iter() + .filter_map(|(id, stream)| { + // Remove all streams for which the receiving is done (or aborted). + // But only if they are unidirectional, or we have finished sending. + if stream.is_terminal() && (id.is_uni() || !self.send_streams.exists(*id)) { + Some(*id) + } else { + None + } + }) + .collect::<Vec<_>>(); + + let mut removed_bidi = 0; + let mut removed_uni = 0; + for id in &recv_to_remove { + self.recv_streams.remove(&id); + if id.is_remote_initiated(self.role()) { + if id.is_bidi() { + removed_bidi += 1; + } else { + removed_uni += 1; + } + } + } + + // Send max_streams updates if we removed remote-initiated recv streams. 
+ if removed_bidi > 0 { + self.indexes.local_max_stream_bidi += removed_bidi; + self.flow_mgr + .borrow_mut() + .max_streams(self.indexes.local_max_stream_bidi, StreamType::BiDi) + } + if removed_uni > 0 { + self.indexes.local_max_stream_uni += removed_uni; + self.flow_mgr + .borrow_mut() + .max_streams(self.indexes.local_max_stream_uni, StreamType::UniDi) + } + } + + /// Get or make a stream, and implicitly open additional streams as + /// indicated by its stream id. + fn obtain_stream( + &mut self, + stream_id: StreamId, + ) -> Res<(Option<&mut SendStream>, Option<&mut RecvStream>)> { + if !self.state.connected() + && !matches!( + (&self.state, &self.zero_rtt_state), + (State::Handshaking, ZeroRttState::AcceptedServer) + ) + { + return Err(Error::ConnectionState); + } + + // May require creating new stream(s) + if stream_id.is_remote_initiated(self.role()) { + let next_stream_idx = if stream_id.is_bidi() { + &mut self.indexes.local_next_stream_bidi + } else { + &mut self.indexes.local_next_stream_uni + }; + let stream_idx: StreamIndex = stream_id.into(); + + if stream_idx >= *next_stream_idx { + let recv_initial_max_stream_data = if stream_id.is_bidi() { + if stream_idx > self.indexes.local_max_stream_bidi { + qwarn!( + [self], + "remote bidi stream create blocked, next={:?} max={:?}", + stream_idx, + self.indexes.local_max_stream_bidi + ); + return Err(Error::StreamLimitError); + } + // From the local perspective, this is a remote- originated BiDi stream. From + // the remote perspective, this is a local-originated BiDi stream. Therefore, + // look at the local transport parameters for the + // INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value to decide how much this endpoint + // will allow its peer to send. 
+ self.tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE) + } else { + if stream_idx > self.indexes.local_max_stream_uni { + qwarn!( + [self], + "remote uni stream create blocked, next={:?} max={:?}", + stream_idx, + self.indexes.local_max_stream_uni + ); + return Err(Error::StreamLimitError); + } + self.tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI) + }; + + loop { + let next_stream_id = + next_stream_idx.to_stream_id(stream_id.stream_type(), stream_id.role()); + self.events.new_stream(next_stream_id); + + self.recv_streams.insert( + next_stream_id, + RecvStream::new( + next_stream_id, + recv_initial_max_stream_data, + self.flow_mgr.clone(), + self.events.clone(), + ), + ); + + if next_stream_id.is_bidi() { + // From the local perspective, this is a remote- originated BiDi stream. + // From the remote perspective, this is a local-originated BiDi stream. + // Therefore, look at the remote's transport parameters for the + // INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value to decide how much this endpoint + // is allowed to send its peer. + let send_initial_max_stream_data = self + .tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL); + self.send_streams.insert( + next_stream_id, + SendStream::new( + next_stream_id, + send_initial_max_stream_data, + self.flow_mgr.clone(), + self.events.clone(), + ), + ); + } + + *next_stream_idx += 1; + if *next_stream_idx > stream_idx { + break; + } + } + } + } + + Ok(( + self.send_streams.get_mut(stream_id).ok(), + self.recv_streams.get_mut(&stream_id), + )) + } + + /// Create a stream. + /// Returns new stream id + /// # Errors + /// `ConnectionState` if the connecton stat does not allow to create streams. + /// `StreamLimitError` if we are limiied by server's stream concurence. + pub fn stream_create(&mut self, st: StreamType) -> Res<u64> { + // Can't make streams while closing, otherwise rely on the stream limits. 
+ match self.state { + State::Closing { .. } | State::Draining { .. } | State::Closed { .. } => { + return Err(Error::ConnectionState); + } + State::WaitInitial | State::Handshaking => { + if self.role == Role::Client && self.zero_rtt_state != ZeroRttState::Sending { + return Err(Error::ConnectionState); + } + } + // In all other states, trust that the stream limits are correct. + _ => (), + } + + Ok(match st { + StreamType::UniDi => { + if self.indexes.remote_next_stream_uni >= self.indexes.remote_max_stream_uni { + self.flow_mgr + .borrow_mut() + .streams_blocked(self.indexes.remote_max_stream_uni, StreamType::UniDi); + qwarn!( + [self], + "local uni stream create blocked, next={:?} max={:?}", + self.indexes.remote_next_stream_uni, + self.indexes.remote_max_stream_uni + ); + return Err(Error::StreamLimitError); + } + let new_id = self + .indexes + .remote_next_stream_uni + .to_stream_id(StreamType::UniDi, self.role); + self.indexes.remote_next_stream_uni += 1; + let initial_max_stream_data = self + .tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI); + + self.send_streams.insert( + new_id, + SendStream::new( + new_id, + initial_max_stream_data, + self.flow_mgr.clone(), + self.events.clone(), + ), + ); + new_id.as_u64() + } + StreamType::BiDi => { + if self.indexes.remote_next_stream_bidi >= self.indexes.remote_max_stream_bidi { + self.flow_mgr + .borrow_mut() + .streams_blocked(self.indexes.remote_max_stream_bidi, StreamType::BiDi); + qwarn!( + [self], + "local bidi stream create blocked, next={:?} max={:?}", + self.indexes.remote_next_stream_bidi, + self.indexes.remote_max_stream_bidi + ); + return Err(Error::StreamLimitError); + } + let new_id = self + .indexes + .remote_next_stream_bidi + .to_stream_id(StreamType::BiDi, self.role); + self.indexes.remote_next_stream_bidi += 1; + // From the local perspective, this is a local- originated BiDi stream. From the + // remote perspective, this is a remote-originated BiDi stream. 
Therefore, look at + // the remote transport parameters for the INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value + // to decide how much this endpoint is allowed to send its peer. + let send_initial_max_stream_data = self + .tps + .borrow() + .remote() + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE); + + self.send_streams.insert( + new_id, + SendStream::new( + new_id, + send_initial_max_stream_data, + self.flow_mgr.clone(), + self.events.clone(), + ), + ); + // From the local perspective, this is a local- originated BiDi stream. From the + // remote perspective, this is a remote-originated BiDi stream. Therefore, look at + // the local transport parameters for the INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value + // to decide how much this endpoint will allow its peer to send. + let recv_initial_max_stream_data = self + .tps + .borrow() + .local + .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL); + + self.recv_streams.insert( + new_id, + RecvStream::new( + new_id, + recv_initial_max_stream_data, + self.flow_mgr.clone(), + self.events.clone(), + ), + ); + new_id.as_u64() + } + }) + } + + /// Send data on a stream. + /// Returns how many bytes were successfully sent. Could be less + /// than total, based on receiver credit space available, etc. + /// # Errors + /// `InvalidStreamId` the stream does not exist, + /// `InvalidInput` if length of `data` is zero, + /// `FinalSizeError` if the stream has already been closed. + pub fn stream_send(&mut self, stream_id: u64, data: &[u8]) -> Res<usize> { + self.send_streams.get_mut(stream_id.into())?.send(data) + } + + /// Send all data or nothing on a stream. May cause DATA_BLOCKED or + /// STREAM_DATA_BLOCKED frames to be sent. + /// Returns true if data was successfully sent, otherwise false. + /// # Errors + /// `InvalidStreamId` the stream does not exist, + /// `InvalidInput` if length of `data` is zero, + /// `FinalSizeError` if the stream has already been closed. 
+ pub fn stream_send_atomic(&mut self, stream_id: u64, data: &[u8]) -> Res<bool> { + let val = self + .send_streams + .get_mut(stream_id.into())? + .send_atomic(data); + if let Ok(val) = val { + debug_assert!( + val == 0 || val == data.len(), + "Unexpected value {} when trying to send {} bytes atomically", + val, + data.len() + ); + } + val.map(|v| v == data.len()) + } + + /// Bytes that stream_send() is guaranteed to accept for sending. + /// i.e. that will not be blocked by flow credits or send buffer max + /// capacity. + pub fn stream_avail_send_space(&self, stream_id: u64) -> Res<usize> { + Ok(self.send_streams.get(stream_id.into())?.avail()) + } + + /// Close the stream. Enqueued data will be sent. + pub fn stream_close_send(&mut self, stream_id: u64) -> Res<()> { + self.send_streams.get_mut(stream_id.into())?.close(); + Ok(()) + } + + /// Abandon transmission of in-flight and future stream data. + pub fn stream_reset_send(&mut self, stream_id: u64, err: AppError) -> Res<()> { + self.send_streams.get_mut(stream_id.into())?.reset(err); + Ok(()) + } + + /// Read buffered data from stream. bool says whether read bytes includes + /// the final data on stream. + /// # Errors + /// `InvalidStreamId` if the stream does not exist. + /// `NoMoreData` if data and fin bit were previously read by the application. + pub fn stream_recv(&mut self, stream_id: u64, data: &mut [u8]) -> Res<(usize, bool)> { + let stream = self + .recv_streams + .get_mut(&stream_id.into()) + .ok_or(Error::InvalidStreamId)?; + + let rb = stream.read(data)?; + Ok((rb.0 as usize, rb.1)) + } + + /// Application is no longer interested in this stream. 
+ pub fn stream_stop_sending(&mut self, stream_id: u64, err: AppError) -> Res<()> { + let stream = self + .recv_streams + .get_mut(&stream_id.into()) + .ok_or(Error::InvalidStreamId)?; + + stream.stop_sending(err); + Ok(()) + } + + #[cfg(test)] + pub fn get_pto(&self) -> Duration { + self.loss_recovery.pto_raw(PNSpace::ApplicationData) + } +} + +impl EventProvider for Connection { + type Event = ConnectionEvent; + + /// Return true if there are outstanding events. + fn has_events(&self) -> bool { + self.events.has_events() + } + + /// Get events that indicate state changes on the connection. This method + /// correctly handles cases where handling one event can obsolete + /// previously-queued events, or cause new events to be generated. + fn next_event(&mut self) -> Option<Self::Event> { + self.events.next_event() + } +} + +impl ::std::fmt::Display for Connection { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{:?} ", self.role)?; + if let Some(cid) = self.odcid() { + std::fmt::Display::fmt(&cid, f) + } else { + write!(f, "...") + } + } +} + +#[cfg(test)] +mod tests; diff --git a/third_party/rust/neqo-transport/src/connection/params.rs b/third_party/rust/neqo-transport/src/connection/params.rs new file mode 100644 index 0000000000..c4404b54d9 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/params.rs @@ -0,0 +1,71 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::frame::StreamType; +use crate::{ + CongestionControlAlgorithm, QuicVersion, LOCAL_STREAM_LIMIT_BIDI, LOCAL_STREAM_LIMIT_UNI, +}; + +/// ConnectionParameters use for setting intitial value for QUIC parameters. 
+/// This collect like initial limits, protocol version and congestion control. +#[derive(Clone)] +pub struct ConnectionParameters { + quic_version: QuicVersion, + cc_algorithm: CongestionControlAlgorithm, + max_streams_bidi: u64, + max_streams_uni: u64, +} + +impl Default for ConnectionParameters { + fn default() -> Self { + Self { + quic_version: QuicVersion::default(), + cc_algorithm: CongestionControlAlgorithm::NewReno, + max_streams_bidi: LOCAL_STREAM_LIMIT_BIDI, + max_streams_uni: LOCAL_STREAM_LIMIT_UNI, + } + } +} + +impl ConnectionParameters { + pub fn get_quic_version(&self) -> QuicVersion { + self.quic_version + } + + pub fn quic_version(mut self, v: QuicVersion) -> Self { + self.quic_version = v; + self + } + + pub fn get_cc_algorithm(&self) -> CongestionControlAlgorithm { + self.cc_algorithm + } + + pub fn cc_algorithm(mut self, v: CongestionControlAlgorithm) -> Self { + self.cc_algorithm = v; + self + } + + pub fn get_max_streams(&self, stream_type: StreamType) -> u64 { + match stream_type { + StreamType::BiDi => self.max_streams_bidi, + StreamType::UniDi => self.max_streams_uni, + } + } + + pub fn max_streams(mut self, stream_type: StreamType, v: u64) -> Self { + assert!(v <= (1 << 60), "max_streams's parameter too big"); + match stream_type { + StreamType::BiDi => { + self.max_streams_bidi = v; + } + StreamType::UniDi => { + self.max_streams_uni = v; + } + } + self + } +} diff --git a/third_party/rust/neqo-transport/src/connection/saved.rs b/third_party/rust/neqo-transport/src/connection/saved.rs new file mode 100644 index 0000000000..b2da0c644a --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/saved.rs @@ -0,0 +1,72 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::mem; +use std::time::Instant; + +use crate::crypto::CryptoSpace; +use neqo_common::{qdebug, qinfo, Datagram}; + +/// The number of datagrams that are saved during the handshake when +/// keys to decrypt them are not yet available. +/// +/// This value exceeds what should be possible to send during the handshake. +/// Neither endpoint should have enough congestion window to send this +/// much before the handshake completes. +const MAX_SAVED_DATAGRAMS: usize = 32; + +pub struct SavedDatagram { + /// The datagram. + pub d: Datagram, + /// The time that the datagram was received. + pub t: Instant, +} + +#[derive(Default)] +pub struct SavedDatagrams { + handshake: Vec<SavedDatagram>, + application_data: Vec<SavedDatagram>, + available: Option<CryptoSpace>, +} + +impl SavedDatagrams { + fn store(&mut self, cspace: CryptoSpace) -> &mut Vec<SavedDatagram> { + match cspace { + CryptoSpace::Handshake => &mut self.handshake, + CryptoSpace::ApplicationData => &mut self.application_data, + _ => panic!("unexpected space"), + } + } + + pub fn save(&mut self, cspace: CryptoSpace, d: Datagram, t: Instant) { + let store = self.store(cspace); + + if store.len() < MAX_SAVED_DATAGRAMS { + qdebug!("saving datagram of {} bytes", d.len()); + store.push(SavedDatagram { d, t }); + } else { + qinfo!("not saving datagram of {} bytes", d.len()); + } + } + + pub fn make_available(&mut self, cspace: CryptoSpace) { + debug_assert_ne!(cspace, CryptoSpace::ZeroRtt); + debug_assert_ne!(cspace, CryptoSpace::Initial); + if !self.store(cspace).is_empty() { + self.available = Some(cspace); + } + } + + pub fn available(&self) -> Option<CryptoSpace> { + self.available + } + + pub fn take_saved(&mut self) -> Vec<SavedDatagram> { + self.available + .take() + .map_or_else(Vec::new, |cspace| mem::take(self.store(cspace))) + } +} diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs new file mode 100644 index 
0000000000..1b38c1650c --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/state.rs @@ -0,0 +1,207 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::cmp::Ordering; +use std::mem; +use std::time::Instant; + +use crate::frame::{Frame, FrameType}; +use crate::packet::PacketBuilder; +use crate::recovery::RecoveryToken; +use crate::{CloseError, ConnectionError}; + +#[derive(Clone, Debug, PartialEq, Eq)] +/// The state of the Connection. +pub enum State { + Init, + WaitInitial, + Handshaking, + Connected, + Confirmed, + Closing { + error: ConnectionError, + timeout: Instant, + }, + Draining { + error: ConnectionError, + timeout: Instant, + }, + Closed(ConnectionError), +} + +impl State { + #[must_use] + pub fn connected(&self) -> bool { + matches!(self, Self::Connected | Self::Confirmed) + } + + #[must_use] + pub fn closed(&self) -> bool { + matches!(self, Self::Closing { .. } | Self::Draining { .. } | Self::Closed(_)) + } +} + +// Implement `PartialOrd` so that we can enforce monotonic state progression. 
+impl PartialOrd for State { + #[allow(clippy::match_same_arms)] // Lint bug: rust-lang/rust-clippy#860 + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + if mem::discriminant(self) == mem::discriminant(other) { + return Some(Ordering::Equal); + } + Some(match (self, other) { + (Self::Init, _) => Ordering::Less, + (_, Self::Init) => Ordering::Greater, + (Self::WaitInitial, _) => Ordering::Less, + (_, Self::WaitInitial) => Ordering::Greater, + (Self::Handshaking, _) => Ordering::Less, + (_, Self::Handshaking) => Ordering::Greater, + (Self::Connected, _) => Ordering::Less, + (_, Self::Connected) => Ordering::Greater, + (Self::Confirmed, _) => Ordering::Less, + (_, Self::Confirmed) => Ordering::Greater, + (Self::Closing { .. }, _) => Ordering::Less, + (_, Self::Closing { .. }) => Ordering::Greater, + (Self::Draining { .. }, _) => Ordering::Less, + (_, Self::Draining { .. }) => Ordering::Greater, + (Self::Closed(_), _) => unreachable!(), + }) + } +} + +impl Ord for State { + fn cmp(&self, other: &Self) -> Ordering { + if mem::discriminant(self) == mem::discriminant(other) { + return Ordering::Equal; + } + match (self, other) { + (Self::Init, _) => Ordering::Less, + (_, Self::Init) => Ordering::Greater, + (Self::WaitInitial, _) => Ordering::Less, + (_, Self::WaitInitial) => Ordering::Greater, + (Self::Handshaking, _) => Ordering::Less, + (_, Self::Handshaking) => Ordering::Greater, + (Self::Connected, _) => Ordering::Less, + (_, Self::Connected) => Ordering::Greater, + (Self::Confirmed, _) => Ordering::Less, + (_, Self::Confirmed) => Ordering::Greater, + (Self::Closing { .. }, _) => Ordering::Less, + (_, Self::Closing { .. }) => Ordering::Greater, + (Self::Draining { .. }, _) => Ordering::Less, + (_, Self::Draining { .. }) => Ordering::Greater, + (Self::Closed(_), _) => unreachable!(), + } + } +} + +type ClosingFrame = Frame<'static>; + +/// `StateSignaling` manages whether we need to send HANDSHAKE_DONE and CONNECTION_CLOSE. 
+/// Valid state transitions are: +/// * Idle -> HandshakeDone: at the server when the handshake completes +/// * HandshakeDone -> Idle: when a HANDSHAKE_DONE frame is sent +/// * Idle/HandshakeDone -> Closing/Draining: when closing or draining +/// * Closing/Draining -> CloseSent: after sending CONNECTION_CLOSE +/// * CloseSent -> Closing: any time a new CONNECTION_CLOSE is needed +/// * -> Reset: from any state in case of a stateless reset +#[derive(Debug, Clone, PartialEq)] +pub enum StateSignaling { + Idle, + HandshakeDone, + /// These states save the frame that needs to be sent. + Closing(ClosingFrame), + Draining(ClosingFrame), + /// This state saves the frame that might need to be sent again. + /// If it is `None`, then we are draining and don't send. + CloseSent(Option<ClosingFrame>), + Reset, +} + +impl StateSignaling { + pub fn handshake_done(&mut self) { + if *self != Self::Idle { + debug_assert!(false, "StateSignaling must be in Idle state."); + return; + } + *self = Self::HandshakeDone + } + + pub fn write_done(&mut self, builder: &mut PacketBuilder) -> Option<RecoveryToken> { + if *self == Self::HandshakeDone && builder.remaining() >= 1 { + *self = Self::Idle; + builder.encode_varint(Frame::HandshakeDone.get_type()); + Some(RecoveryToken::HandshakeDone) + } else { + None + } + } + + fn make_close_frame( + error: ConnectionError, + frame_type: FrameType, + message: impl AsRef<str>, + ) -> ClosingFrame { + let reason_phrase = message.as_ref().as_bytes().to_owned(); + Frame::ConnectionClose { + error_code: CloseError::from(error), + frame_type, + reason_phrase, + } + } + + pub fn close( + &mut self, + error: ConnectionError, + frame_type: FrameType, + message: impl AsRef<str>, + ) { + if *self != Self::Reset { + *self = Self::Closing(Self::make_close_frame(error, frame_type, message)); + } + } + + pub fn drain( + &mut self, + error: ConnectionError, + frame_type: FrameType, + message: impl AsRef<str>, + ) { + if *self != Self::Reset { + *self = 
Self::Draining(Self::make_close_frame(error, frame_type, message)); + } + } + + /// If a close is pending, take a frame. + pub fn close_frame(&mut self) -> Option<ClosingFrame> { + match self { + Self::Closing(frame) => { + // When we are closing, we might need to send the close frame again. + let frame = mem::replace(frame, Frame::Padding); + *self = Self::CloseSent(Some(frame.clone())); + Some(frame) + } + Self::Draining(frame) => { + // When we are draining, just send once. + let frame = mem::replace(frame, Frame::Padding); + *self = Self::CloseSent(None); + Some(frame) + } + _ => None, + } + } + + /// If a close can be sent again, prepare to send it again. + pub fn send_close(&mut self) { + if let Self::CloseSent(Some(frame)) = self { + let frame = mem::replace(frame, Frame::Padding); + *self = Self::Closing(frame); + } + } + + /// We just got a stateless reset. Terminate. + pub fn reset(&mut self) { + *self = Self::Reset; + } +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs new file mode 100644 index 0000000000..f69234a357 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs @@ -0,0 +1,526 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use super::super::{Connection, Output}; +use super::{ + assert_full_cwnd, connect_force_idle, connect_rtt_idle, cwnd_packets, default_client, + default_server, fill_cwnd, send_something, AT_LEAST_PTO, DEFAULT_RTT, POST_HANDSHAKE_CWND, +}; +use crate::cc::{CWND_MIN, MAX_DATAGRAM_SIZE}; +use crate::frame::StreamType; +use crate::packet::PacketNumber; +use crate::recovery::{ACK_ONLY_SIZE_LIMIT, PACKET_THRESHOLD}; +use crate::sender::PACING_BURST_SIZE; +use crate::stats::MAX_PTO_COUNTS; +use crate::tparams::{self, TransportParameter}; +use crate::tracking::MAX_UNACKED_PKTS; + +use neqo_common::{qdebug, qinfo, qtrace, Datagram}; +use std::convert::TryFrom; +use std::time::{Duration, Instant}; +use test_fixture::{self, now}; + +fn induce_persistent_congestion( + client: &mut Connection, + server: &mut Connection, + mut now: Instant, +) -> Instant { + // Note: wait some arbitrary time that should be longer than pto + // timer. This is rather brittle. + now += AT_LEAST_PTO; + + let mut pto_counts = [0; MAX_PTO_COUNTS]; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qtrace!([client], "first PTO"); + let (c_tx_dgrams, next_now) = fill_cwnd(client, 0, now); + now = next_now; + assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qtrace!([client], "second PTO"); + now += AT_LEAST_PTO * 2; + let (c_tx_dgrams, next_now) = fill_cwnd(client, 0, now); + now = next_now; + assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets + + pto_counts[0] = 0; + pto_counts[1] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + qtrace!([client], "third PTO"); + now += AT_LEAST_PTO * 4; + let (c_tx_dgrams, next_now) = fill_cwnd(client, 0, now); + now = next_now; + assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets + + pto_counts[1] = 0; + pto_counts[2] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Generate ACK + let s_tx_dgram = ack_bytes(server, 0, 
c_tx_dgrams, now); + + // An ACK for the third PTO causes persistent congestion. + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + + assert_eq!(client.loss_recovery.cwnd(), CWND_MIN); + now +} + +// Receive multiple packets and generate an ack-only packet. +fn ack_bytes<D>(dest: &mut Connection, stream: u64, in_dgrams: D, now: Instant) -> Vec<Datagram> +where + D: IntoIterator<Item = Datagram>, + D::IntoIter: ExactSizeIterator, +{ + let mut srv_buf = [0; 4_096]; + + let in_dgrams = in_dgrams.into_iter(); + qdebug!([dest], "ack_bytes {} datagrams", in_dgrams.len()); + for dgram in in_dgrams { + dest.process_input(dgram, now); + } + + loop { + let (bytes_read, _fin) = dest.stream_recv(stream, &mut srv_buf).unwrap(); + qtrace!([dest], "ack_bytes read {} bytes", bytes_read); + if bytes_read == 0 { + break; + } + } + + let mut tx_dgrams = Vec::new(); + while let Output::Datagram(dg) = dest.process_output(now) { + tx_dgrams.push(dg); + } + + assert!((tx_dgrams.len() == 1) || (tx_dgrams.len() == 2)); + tx_dgrams +} + +#[test] +/// Verify initial CWND is honored. +fn cc_slow_start() { + let mut client = default_client(); + let mut server = default_server(); + + server + .set_local_tparam( + tparams::INITIAL_MAX_DATA, + TransportParameter::Integer(65536), + ) + .unwrap(); + let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + // Try to send a lot of data + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + let (c_tx_dgrams, _) = fill_cwnd(&mut client, 2, now); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + assert!(client.loss_recovery.cwnd_avail() < ACK_ONLY_SIZE_LIMIT); +} + +#[test] +/// Verify that CC moves to cong avoidance when a packet is marked lost. 
+fn cc_slow_start_to_cong_avoidance_recovery_period() { + let mut client = default_client(); + let mut server = default_server(); + let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + // Create stream 0 + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + + // Buffer up lot of data and generate packets + let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + // Predict the packet number of the last packet sent. + // We have already sent one packet in `connect_force_idle` (an ACK), + // so this will be equal to the number of packets in this flight. + let flight1_largest = PacketNumber::try_from(c_tx_dgrams.len()).unwrap(); + + // Server: Receive and generate ack + now += DEFAULT_RTT / 2; + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now); + assert_eq!( + server.stats().frame_tx.largest_acknowledged, + flight1_largest + ); + + // Client: Process ack + now += DEFAULT_RTT / 2; + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + assert_eq!( + client.stats().frame_rx.largest_acknowledged, + flight1_largest + ); + + // Client: send more + let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2); + let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap(); + + // Server: Receive and generate ack again, but drop first packet + now += DEFAULT_RTT / 2; + c_tx_dgrams.remove(0); + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now); + assert_eq!( + server.stats().frame_tx.largest_acknowledged, + flight2_largest + ); + + // Client: Process ack + now += DEFAULT_RTT / 2; + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + assert_eq!( + client.stats().frame_rx.largest_acknowledged, + flight2_largest + ); +} + +#[test] +/// Verify that CC stays in recovery period when packet sent before start of +/// recovery period is acked. 
+fn cc_cong_avoidance_recovery_period_unchanged() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // Create stream 0 + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + + // Buffer up lot of data and generate packets + let (mut c_tx_dgrams, now) = fill_cwnd(&mut client, 0, now()); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + + // Drop 0th packet. When acked, this should put client into CARP. + c_tx_dgrams.remove(0); + + let c_tx_dgrams2 = c_tx_dgrams.split_off(5); + + // Server: Receive and generate ack + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now); + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + + let cwnd1 = client.loss_recovery.cwnd(); + + // Generate ACK for more received packets + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams2, now); + + // ACK more packets but they were sent before end of recovery period + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + + // cwnd should not have changed since ACKed packets were sent before + // recovery period expired + let cwnd2 = client.loss_recovery.cwnd(); + assert_eq!(cwnd1, cwnd2); +} + +#[test] +/// Ensure that a single packet is sent after entering recovery, even +/// when that exceeds the available congestion window. +fn single_packet_on_recovery() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // Drop a few packets, up to the reordering threshold. + for _ in 0..PACKET_THRESHOLD { + let _dropped = send_something(&mut client, now()); + } + let delivered = send_something(&mut client, now()); + + // Now fill the congestion window. 
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + let _ = fill_cwnd(&mut client, 0, now()); + assert!(client.loss_recovery.cwnd_avail() < ACK_ONLY_SIZE_LIMIT); + + // Acknowledge just one packet and cause one packet to be declared lost. + // The length is the amount of credit the client should have. + let ack = server.process(Some(delivered), now()).dgram(); + assert!(ack.is_some()); + + // The client should see the loss and enter recovery. + // As there are many outstanding packets, there should be no available cwnd. + client.process_input(ack.unwrap(), now()); + assert_eq!(client.loss_recovery.cwnd_avail(), 0); + + // The client should send one packet, ignoring the cwnd. + let dgram = client.process_output(now()).dgram(); + assert!(dgram.is_some()); +} + +#[test] +/// Verify that CC moves out of recovery period when packet sent after start +/// of recovery period is acked. +fn cc_cong_avoidance_recovery_period_to_cong_avoidance() { + let mut client = default_client(); + let mut server = default_server(); + let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + // Create stream 0 + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + + // Buffer up lot of data and generate packets + let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now); + + // Drop 0th packet. When acked, this should put client into CARP. + c_tx_dgrams.remove(0); + + // Server: Receive and generate ack + now += DEFAULT_RTT / 2; + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now); + + // Client: Process ack + now += DEFAULT_RTT / 2; + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + + // Should be in CARP now. + now += DEFAULT_RTT / 2; + qinfo!( + "moving to congestion avoidance {}", + client.loss_recovery.cwnd() + ); + + // Now make sure that we increase congestion window according to the + // accurate byte counting version of congestion avoidance. + // Check over several increases to be sure. 
+ let mut expected_cwnd = client.loss_recovery.cwnd(); + // Fill cwnd. + let (mut c_tx_dgrams, next_now) = fill_cwnd(&mut client, 0, now); + now = next_now; + for i in 0..5 { + qinfo!("iteration {}", i); + + let c_tx_size: usize = c_tx_dgrams.iter().map(|d| d.len()).sum(); + qinfo!( + "client sending {} bytes into cwnd of {}", + c_tx_size, + client.loss_recovery.cwnd() + ); + assert_eq!(c_tx_size, expected_cwnd); + + // As acks arrive we will continue filling cwnd and save all packets + // from this cycle will be stored in next_c_tx_dgrams. + let mut next_c_tx_dgrams: Vec<Datagram> = Vec::new(); + + // Until we process all the packets, the congestion window remains the same. + // Note that we need the client to process ACK frames in stages, so split the + // datagrams into two, ensuring that we allow for an ACK for each batch. + let most = c_tx_dgrams.len() - MAX_UNACKED_PKTS - 1; + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams.drain(..most), now); + for dgram in s_tx_dgram { + assert_eq!(client.loss_recovery.cwnd(), expected_cwnd); + client.process_input(dgram, now); + // make sure to fill cwnd again. + let (mut new_pkts, next_now) = fill_cwnd(&mut client, 0, now); + now = next_now; + next_c_tx_dgrams.append(&mut new_pkts); + } + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now); + for dgram in s_tx_dgram { + assert_eq!(client.loss_recovery.cwnd(), expected_cwnd); + client.process_input(dgram, now); + // make sure to fill cwnd again. + let (mut new_pkts, next_now) = fill_cwnd(&mut client, 0, now); + now = next_now; + next_c_tx_dgrams.append(&mut new_pkts); + } + expected_cwnd += MAX_DATAGRAM_SIZE; + assert_eq!(client.loss_recovery.cwnd(), expected_cwnd); + c_tx_dgrams = next_c_tx_dgrams; + } +} + +#[test] +/// Verify transition to persistent congestion state if conditions are met. 
+fn cc_slow_start_to_persistent_congestion_no_acks() { + let mut client = default_client(); + let mut server = default_server(); + let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); + + // Create stream 0 + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + + // Buffer up lot of data and generate packets + let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + + // Server: Receive and generate ack + now += DEFAULT_RTT / 2; + let _ = ack_bytes(&mut server, 0, c_tx_dgrams, now); + + // ACK lost. + induce_persistent_congestion(&mut client, &mut server, now); +} + +#[test] +/// Verify transition to persistent congestion state if conditions are met. +fn cc_slow_start_to_persistent_congestion_some_acks() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // Create stream 0 + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + + // Buffer up lot of data and generate packets + let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now()); + assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND); + + // Server: Receive and generate ack + now += Duration::from_millis(100); + let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now); + + now += Duration::from_millis(100); + for dgram in s_tx_dgram { + client.process_input(dgram, now); + } + + // send bytes that will be lost + let (_, next_now) = fill_cwnd(&mut client, 0, now); + now = next_now + Duration::from_millis(100); + + induce_persistent_congestion(&mut client, &mut server, now); +} + +#[test] +/// Verify persistent congestion moves to slow start after recovery period +/// ends. 
+fn cc_persistent_congestion_to_slow_start() {
+    // Scenario: fill the congestion window, lose the ACK, induce persistent
+    // congestion, then check that ACKing the two datagrams the client can
+    // still send lets it send four.
+    let mut client = default_client();
+    let mut server = default_server();
+    connect_force_idle(&mut client, &mut server);
+
+    // Create stream 0
+    assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+    // Buffer up a lot of data and generate packets
+    let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now());
+    assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+    // Server: Receive and generate ack
+    now += Duration::from_millis(10);
+    let _ = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+    // ACK lost.
+
+    now = induce_persistent_congestion(&mut client, &mut server, now);
+
+    // New part of test starts here
+
+    now += Duration::from_millis(10);
+
+    // Send packets from after start of CARP
+    let (c_tx_dgrams, next_now) = fill_cwnd(&mut client, 0, now);
+    assert_eq!(c_tx_dgrams.len(), 2);
+
+    // Server: Receive and generate ack
+    now = next_now + Duration::from_millis(100);
+    let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+    // No longer in CARP. (pkts acked from after start of CARP)
+    // Should be in slow start now.
+    for dgram in s_tx_dgram {
+        client.process_input(dgram, now);
+    }
+
+    // ACKing 2 packets should let client send 4.
+    let (c_tx_dgrams, _) = fill_cwnd(&mut client, 0, now);
+    assert_eq!(c_tx_dgrams.len(), 4);
+}
+
+/// Checks that ACKs are not blocked by congestion control: with its
+/// congestion window full, the client must still answer an ack-eliciting
+/// server packet, and the server must count exactly one more ACK frame.
+#[test]
+fn ack_are_not_cc() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect_force_idle(&mut client, &mut server);
+
+    // Create a stream
+    assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+    // Buffer up a lot of data and generate packets, so that cc window is filled.
+    let (c_tx_dgrams, now) = fill_cwnd(&mut client, 0, now());
+    assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+    // The server hasn't received any of these packets yet, so it won't ACK
+    // them; instead, have it send an ack-eliciting packet of its own.
+    qdebug!([server], "Sending ack-eliciting");
+    assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 1);
+    server.stream_send(1, b"dropped").unwrap();
+    let dropped_packet = server.process(None, now).dgram();
+    assert!(dropped_packet.is_some()); // Now drop this one.
+
+    // Now the server sends a packet that will force an ACK,
+    // because the client will detect a gap.
+    server.stream_send(1, b"sent").unwrap();
+    let ack_eliciting_packet = server.process(None, now).dgram();
+    assert!(ack_eliciting_packet.is_some());
+
+    // The client can ack the server packet even if the cc window is full.
+    qdebug!([client], "Process ack-eliciting");
+    let ack_pkt = client.process(ack_eliciting_packet, now).dgram();
+    assert!(ack_pkt.is_some());
+    qdebug!([server], "Handle ACK");
+    let prev_ack_count = server.stats().frame_rx.ack;
+    server.process_input(ack_pkt.unwrap(), now);
+    assert_eq!(server.stats().frame_rx.ack, prev_ack_count + 1);
+}
+
+/// Checks pacing of a full congestion window: the first `PACING_BURST_SIZE`
+/// datagrams are available immediately, each later one only after waiting
+/// the same non-zero `gap`, except the last, which is sent without a wait.
+#[test]
+fn pace() {
+    const RTT: Duration = Duration::from_millis(1000);
+    const DATA: &[u8] = &[0xcc; 4_096];
+    let mut client = default_client();
+    let mut server = default_server();
+    let mut now = connect_rtt_idle(&mut client, &mut server, RTT);
+
+    // Now fill up the pipe and watch it trickle out.
+    let stream = client.stream_create(StreamType::BiDi).unwrap();
+    // Keep writing until the stream accepts less than a full `DATA` chunk.
+    loop {
+        let written = client.stream_send(stream, DATA).unwrap();
+        if written < DATA.len() {
+            break;
+        }
+    }
+    let mut count = 0;
+    // We should get a burst at first.
+    for _ in 0..PACING_BURST_SIZE {
+        let dgram = client.process_output(now).dgram();
+        assert!(dgram.is_some());
+        count += 1;
+    }
+    let gap = client.process_output(now).callback();
+    assert_ne!(gap, Duration::new(0, 0));
+    // The last one will not be paced.
+    for _ in PACING_BURST_SIZE..cwnd_packets(POST_HANDSHAKE_CWND) - 1 {
+        assert_eq!(client.process_output(now).callback(), gap);
+        now += gap;
+        let dgram = client.process_output(now).dgram();
+        assert!(dgram.is_some());
+        count += 1;
+    }
+    let dgram = client.process_output(now).dgram();
+    assert!(dgram.is_some());
+    count += 1;
+    assert_eq!(count, cwnd_packets(POST_HANDSHAKE_CWND));
+    // With the whole window sent, the next callback is still non-zero but
+    // is no longer the pacing gap.
+    let fin = client.process_output(now).callback();
+    assert_ne!(fin, Duration::new(0, 0));
+    assert_ne!(fin, gap);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs
new file mode 100644
index 0000000000..715073523b
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs
@@ -0,0 +1,206 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Connection, Output, State};
+use super::{connect, connect_force_idle, default_client, default_server, send_something};
+use crate::tparams::{self, TransportParameter};
+use crate::{AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE};
+
+use neqo_common::Datagram;
+use std::time::Duration;
+use test_fixture::{self, loopback, now};
+
+/// Asserts that `c` reports a closed state and is in `State::Draining`
+/// with a transport error equal to `expected`; panics on any other state.
+fn assert_draining(c: &Connection, expected: &Error) {
+    assert!(c.state().closed());
+    if let State::Draining {
+        error: ConnectionError::Transport(error),
+        ..
+    } = c.state()
+    {
+        assert_eq!(error, expected);
+    } else {
+        panic!();
+    }
+}
+
+/// An application close (code 42, empty reason) at the client must leave
+/// the server draining with `Error::PeerApplicationError(42)`.
+#[test]
+fn connection_close() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+
+    let now = now();
+
+    client.close(now, 42, "");
+
+    let out = client.process(None, now);
+
+    server.process_input(out.dgram().unwrap(), now);
+    assert_draining(&server, &Error::PeerApplicationError(42));
+}
+
+/// Same as `connection_close`, but with a 2048-byte reason string.
+#[test]
+fn connection_close_with_long_reason_string() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+
+    let now = now();
+    // Create a long string (2048 repetitions of 0x61, ASCII 'a') and use it
+    // as the close reason.
+    let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap();
+    client.close(now, 42, long_reason);
+
+    let out = client.process(None, now);
+
+    server.process_input(out.dgram().unwrap(), now);
+    assert_draining(&server, &Error::PeerApplicationError(42));
+}
+
+// During the handshake, an application close should be sanitized.
+#[test]
+fn early_application_close() {
+    let mut client = default_client();
+    let mut server = default_server();
+
+    // One flight each.
+    let dgram = client.process(None, now()).dgram();
+    assert!(dgram.is_some());
+    let dgram = server.process(dgram, now()).dgram();
+    assert!(dgram.is_some());
+
+    server.close(now(), 77, String::from(""));
+    assert!(server.state().closed());
+    let dgram = server.process(None, now()).dgram();
+    assert!(dgram.is_some());
+
+    // The client sees the generic `ERROR_APPLICATION_CLOSE` code rather
+    // than the application's own code (77).
+    client.process_input(dgram.unwrap(), now());
+    assert_draining(&client, &Error::PeerError(ERROR_APPLICATION_CLOSE));
+}
+
+/// Enabling the `Tls13CompatMode` option on the client must make the server
+/// close with `Error::ProtocolViolation`, which the client then observes as
+/// the corresponding peer error code.
+#[test]
+fn bad_tls_version() {
+    let mut client = default_client();
+    // Do a bad, bad thing.
+    client
+        .crypto
+        .tls
+        .set_option(neqo_crypto::Opt::Tls13CompatMode, true)
+        .unwrap();
+    let mut server = default_server();
+
+    let dgram = client.process(None, now()).dgram();
+    assert!(dgram.is_some());
+    let dgram = server.process(dgram, now()).dgram();
+    assert_eq!(
+        *server.state(),
+        State::Closed(ConnectionError::Transport(Error::ProtocolViolation))
+    );
+    assert!(dgram.is_some());
+    client.process_input(dgram.unwrap(), now());
+    assert_draining(&client, &Error::PeerError(Error::ProtocolViolation.code()));
+}
+
+/// Test the interaction between the loss recovery timer
+/// and the closing timer.
+#[test]
+fn closing_timers_interation() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+
+    let mut now = now();
+
+    // We're going to induce time-based loss recovery so that timer is set.
+    // Only the second packet reaches the server; the first is discarded.
+    let _p1 = send_something(&mut client, now);
+    let p2 = send_something(&mut client, now);
+    let ack = server.process(Some(p2), now).dgram();
+    assert!(ack.is_some()); // This is an ACK.
+
+    // After processing the ACK, we should be on the loss recovery timer.
+    let cb = client.process(ack, now).callback();
+    assert_ne!(cb, Duration::from_secs(0));
+    now += cb;
+
+    // Rather than let the timer pop, close the connection.
+    client.close(now, 0, "");
+    let client_close = client.process(None, now).dgram();
+    assert!(client_close.is_some());
+    // This should now report the end of the closing period, not a
+    // zero-duration wait driven by the (now defunct) loss recovery timer.
+    let client_close_timer = client.process(None, now).callback();
+    assert_ne!(client_close_timer, Duration::from_secs(0));
+}
+
+/// Walks both peers through an application close: the client answers any
+/// input with a close packet until its timer expires and it becomes
+/// `Closed`; the server, once it has the close, keeps the same timer
+/// regardless of further input and becomes `Closed` when that timer expires.
+#[test]
+fn closing_and_draining() {
+    const APP_ERROR: AppError = 7;
+    let mut client = default_client();
+    let mut server = default_server();
+    connect(&mut client, &mut server);
+
+    // Save a packet from the client for later.
+    let p1 = send_something(&mut client, now());
+
+    // Close the connection.
+    client.close(now(), APP_ERROR, "");
+    let client_close = client.process(None, now()).dgram();
+    assert!(client_close.is_some());
+    let client_close_timer = client.process(None, now()).callback();
+    assert_ne!(client_close_timer, Duration::from_secs(0));
+
+    // The client will spit out the same packet in response to anything it receives.
+    // (Only the datagram lengths are compared here.)
+    let p3 = send_something(&mut server, now());
+    let client_close2 = client.process(Some(p3), now()).dgram();
+    assert_eq!(
+        client_close.as_ref().unwrap().len(),
+        client_close2.as_ref().unwrap().len()
+    );
+
+    // After this time, the client should transition to closed.
+    let end = client.process(None, now() + client_close_timer);
+    assert_eq!(end, Output::None);
+    assert_eq!(
+        *client.state(),
+        State::Closed(ConnectionError::Application(APP_ERROR))
+    );
+
+    // When the server receives the close, it too should generate CONNECTION_CLOSE.
+    let server_close = server.process(client_close, now()).dgram();
+    assert!(server.state().closed());
+    assert!(server_close.is_some());
+    // .. but it ignores any further close packets.
+    let server_close_timer = server.process(client_close2, now()).callback();
+    assert_ne!(server_close_timer, Duration::from_secs(0));
+    // Even a legitimate packet without a close in it.
+    let server_close_timer2 = server.process(Some(p1), now()).callback();
+    assert_eq!(server_close_timer, server_close_timer2);
+
+    let end = server.process(None, now() + server_close_timer);
+    assert_eq!(end, Output::None);
+    assert_eq!(
+        *server.state(),
+        State::Closed(ConnectionError::Transport(Error::PeerApplicationError(
+            APP_ERROR
+        )))
+    );
+}
+
+/// Test that a client can handle a stateless reset correctly.
+#[test]
+fn stateless_reset_client() {
+    let mut client = default_client();
+    let mut server = default_server();
+    // Give the server a known stateless reset token: 16 bytes of 77.
+    server
+        .set_local_tparam(
+            tparams::STATELESS_RESET_TOKEN,
+            TransportParameter::Bytes(vec![77; 16]),
+        )
+        .unwrap();
+    connect_force_idle(&mut client, &mut server);
+
+    // Feed the client 21 bytes of 77, so the last 16 bytes equal the token;
+    // the client must end up draining with `Error::StatelessReset`.
+    client.process_input(Datagram::new(loopback(), loopback(), vec![77; 21]), now());
+    assert_draining(&client, &Error::StatelessReset);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
new file mode 100644
index 0000000000..c9769b3c3c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
@@ -0,0 +1,697 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Connection, FixedConnectionIdManager, Output, State, LOCAL_IDLE_TIMEOUT};
+use super::{
+    assert_error, connect_force_idle, connect_with_rtt, default_client, default_server, get_tokens,
+    handshake, maybe_authenticate, send_something, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA,
+};
+use crate::connection::AddressValidation;
+use crate::events::ConnectionEvent;
+use crate::frame::StreamType;
+use crate::path::PATH_MTU_V6;
+use crate::server::ValidateAddress;
+use crate::{ConnectionError, ConnectionParameters, Error};
+
+use neqo_common::{event::Provider, qdebug, Datagram};
+use neqo_crypto::{constants::TLS_CHACHA20_POLY1305_SHA256, AuthenticationStatus};
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::time::Duration;
+use test_fixture::{self, assertions, fixture_init, loopback, now, split_datagram};
+
+/// Drives a complete handshake by hand, one datagram at a time, checking
+/// that the first datagram from each side is `PATH_MTU_V6` bytes long and
+/// that both peers end in `State::Confirmed`.
+#[test]
+fn full_handshake() {
+    qdebug!("---- client: generate CH");
+    let mut client = default_client();
+    let out = client.process(None, now());
+    assert!(out.as_dgram_ref().is_some());
+    assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+
+    qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+    let mut server = default_server();
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+    assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+
+    qdebug!("---- client: cert verification");
+    let out = client.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_none());
+
+    assert!(maybe_authenticate(&mut client));
+
+    qdebug!("---- client: SH..FIN -> FIN");
+    let out = client.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+    assert_eq!(*client.state(), State::Connected);
+
+    qdebug!("---- server: FIN -> ACKS");
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+    assert_eq!(*server.state(), State::Confirmed);
+
+    qdebug!("---- client: ACKS -> 0");
+    let out = client.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_none());
+    assert_eq!(*client.state(), State::Confirmed);
+}
+
+/// The client rejects the server certificate with `CertRevoked`; the client
+/// must end with `CryptoAlert(44)` and the server with `PeerError(300)`.
+#[test]
+fn handshake_failed_authentication() {
+    qdebug!("---- client: generate CH");
+    let mut client = default_client();
+    let out = client.process(None, now());
+    assert!(out.as_dgram_ref().is_some());
+
+    qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+    let mut server = default_server();
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+
+    qdebug!("---- client: cert verification");
+    let out = client.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_none());
+
+    // Unlike `maybe_authenticate`, look for the event by hand so that the
+    // test can report a failure status instead of success.
+    let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded);
+    assert!(client.events().any(authentication_needed));
+    qdebug!("---- client: Alert(certificate_revoked)");
+    client.authenticated(AuthenticationStatus::CertRevoked, now());
+
+    qdebug!("---- client: -> Alert(certificate_revoked)");
+    let out = client.process(None, now());
+    assert!(out.as_dgram_ref().is_some());
+
+    qdebug!("---- server: Alert(certificate_revoked)");
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_some());
+    // NOTE(review): 300 == 0x100 + 44, which looks like the alert number
+    // mapped into QUIC's crypto error range -- confirm against `Error::code`.
+    assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44)));
+    assert_error(&server, &ConnectionError::Transport(Error::PeerError(300)));
+}
+
+/// A client offering only the ALPN "bad-alpn" must leave the server in the
+/// error state `CryptoAlert(120)`.
+#[test]
+fn no_alpn() {
+    fixture_init();
+    let mut client = Connection::new_client(
+        "example.com",
+        &["bad-alpn"],
+        Rc::new(RefCell::new(FixedConnectionIdManager::new(9))),
+        loopback(),
+        loopback(),
+        &ConnectionParameters::default(),
+    )
+    .unwrap();
+    let mut server = default_server();
+
+    handshake(&mut client, &mut server, now(), Duration::new(0, 0));
+    // TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE
+    // and the client never sees the server's rejection of its handshake.
+    //assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120)));
+    assert_error(
+        &server,
+        &ConnectionError::Transport(Error::CryptoAlert(120)),
+    );
+}
+
+/// Replays the server's first flight at the client after it has already
+/// been processed, and checks that the replay is counted in `dups_rx` and
+/// produces no output.
+#[test]
+fn dup_server_flight1() {
+    qdebug!("---- client: generate CH");
+    let mut client = default_client();
+    let out = client.process(None, now());
+    assert!(out.as_dgram_ref().is_some());
+    assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+    qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+    qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+    let mut server = default_server();
+    let out_to_rep = server.process(out.dgram(), now());
+    assert!(out_to_rep.as_dgram_ref().is_some());
+    qdebug!("Output={:0x?}", out_to_rep.as_dgram_ref());
+
+    qdebug!("---- client: cert verification");
+    // Hand the client a clone, keeping `out_to_rep` for the replay below.
+    let out = client.process(Some(out_to_rep.as_dgram_ref().unwrap().clone()), now());
+    assert!(out.as_dgram_ref().is_some());
+    qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+    let out = server.process(out.dgram(), now());
+    assert!(out.as_dgram_ref().is_none());
+
+    assert!(maybe_authenticate(&mut client));
+
+    qdebug!("---- client: SH..FIN -> FIN");
+    let out = client.process(None, now());
+    assert!(out.as_dgram_ref().is_some());
+    qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+    assert_eq!(3, client.stats().packets_rx);
+    assert_eq!(0, client.stats().dups_rx);
+    assert_eq!(1, client.stats().dropped_rx);
+
+    qdebug!("---- Dup, ignored");
+    let out = client.process(out_to_rep.dgram(), now());
+    assert!(out.as_dgram_ref().is_none());
+    qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+    // Four packets total received, 1 of them is a dup and one has been dropped because Initial keys
+    // are dropped. Add 2 counts of the padding that the server adds to Initial packets.
+    assert_eq!(6, client.stats().packets_rx);
+    assert_eq!(1, client.stats().dups_rx);
+    assert_eq!(3, client.stats().dropped_rx);
+}
+
+// Test that we split crypto data if they cannot fit into one packet.
+// To test this we will use a long server certificate.
+#[test]
+fn crypto_frame_split() {
+    let mut client = default_client();
+
+    let mut server = Connection::new_server(
+        test_fixture::LONG_CERT_KEYS,
+        test_fixture::DEFAULT_ALPN,
+        Rc::new(RefCell::new(FixedConnectionIdManager::new(6))),
+        &ConnectionParameters::default(),
+    )
+    .expect("create a server");
+
+    let client1 = client.process(None, now());
+    assert!(client1.as_dgram_ref().is_some());
+
+    // The entire server flight doesn't fit in a single packet because the
+    // certificate is large, therefore the server will produce 2 packets.
+    let server1 = server.process(client1.dgram(), now());
+    assert!(server1.as_dgram_ref().is_some());
+    let server2 = server.process(None, now());
+    assert!(server2.as_dgram_ref().is_some());
+
+    let client2 = client.process(server1.dgram(), now());
+    // This is an ack.
+    assert!(client2.as_dgram_ref().is_some());
+    // The client might have the certificate now, so we can't guarantee that
+    // this will work.
+    let auth1 = maybe_authenticate(&mut client);
+    assert_eq!(*client.state(), State::Handshaking);
+
+    // let server process the ack for the first packet.
+    let server3 = server.process(client2.dgram(), now());
+    assert!(server3.as_dgram_ref().is_none());
+
+    // Consume the second packet from the server.
+    let client3 = client.process(server2.dgram(), now());
+
+    // Check authentication.
+    // Exactly one of the two attempts must have authenticated.
+    let auth2 = maybe_authenticate(&mut client);
+    assert!(auth1 ^ auth2);
+    // Now client has all data to finish handshake.
+    assert_eq!(*client.state(), State::Connected);
+
+    let client4 = client.process(server3.dgram(), now());
+    // One of these will contain data depending on whether Authentication was completed
+    // after the first or second server packet.
+    assert!(client3.as_dgram_ref().is_some() ^ client4.as_dgram_ref().is_some());
+
+    // Whichever of the two is empty is passed to the server as `None`.
+    let _ = server.process(client3.dgram(), now());
+    let _ = server.process(client4.dgram(), now());
+
+    assert_eq!(*client.state(), State::Connected);
+    assert_eq!(*server.state(), State::Confirmed);
+}
+
+/// Run a single ChaCha20-Poly1305 test and get a PTO.
+#[test]
+fn chacha20poly1305() {
+    let mut server = default_server();
+    let mut client = Connection::new_client(
+        test_fixture::DEFAULT_SERVER_NAME,
+        test_fixture::DEFAULT_ALPN,
+        Rc::new(RefCell::new(FixedConnectionIdManager::new(0))),
+        loopback(),
+        loopback(),
+        &ConnectionParameters::default(),
+    )
+    .expect("create a default client");
+    // Restrict the client to the ChaCha20-Poly1305 suite before connecting.
+    client.set_ciphers(&[TLS_CHACHA20_POLY1305_SHA256]).unwrap();
+    connect_force_idle(&mut client, &mut server);
+}
+
+/// Test that a server can send 0.5 RTT application data.
+#[test]
+fn send_05rtt() {
+    let mut client = default_client();
+    let mut server = default_server();
+
+    let c1 = client.process(None, now()).dgram();
+    assert!(c1.is_some());
+    let s1 = server.process(c1, now()).dgram().unwrap();
+    assert_eq!(s1.len(), PATH_MTU_V6);
+
+    // The server should accept writes at this point.
+    let s2 = send_something(&mut server, now());
+
+    // Complete the handshake at the client.
+    client.process_input(s1, now());
+    maybe_authenticate(&mut client);
+    assert_eq!(*client.state(), State::Connected);
+
+    // The client should receive the 0.5-RTT data now.
+    client.process_input(s2, now());
+    // The buffer is one byte larger than the expected data, so a read that
+    // returned more than expected would fail the comparison below.
+    let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+    // Pick the first stream the client reports as readable.
+    let stream_id = client
+        .events()
+        .find_map(|e| {
+            if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+                Some(stream_id)
+            } else {
+                None
+            }
+        })
+        .unwrap();
+    let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap();
+    assert_eq!(&buf[..l], DEFAULT_STREAM_DATA);
+    assert!(ended);
+}
+
+/// Test that a client buffers 0.5-RTT data when it arrives early.
+#[test]
+fn reorder_05rtt() {
+    let mut client = default_client();
+    let mut server = default_server();
+
+    let c1 = client.process(None, now()).dgram();
+    assert!(c1.is_some());
+    let s1 = server.process(c1, now()).dgram().unwrap();
+
+    // The server should accept writes at this point.
+    let s2 = send_something(&mut server, now());
+
+    // We can't use the standard facility to complete the handshake, so
+    // drive it as aggressively as possible.
+    // The early 0.5-RTT datagram is counted as saved.
+    client.process_input(s2, now());
+    assert_eq!(client.stats().saved_datagrams, 1);
+
+    // After processing the first packet, the client should go back and
+    // process the 0.5-RTT packet data, which should make data available.
+    client.process_input(s1, now());
+    // We can't use `maybe_authenticate` here as that consumes events.
+    client.authenticated(AuthenticationStatus::Ok, now());
+    assert_eq!(*client.state(), State::Connected);
+
+    let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+    // Pick the first stream the client reports as readable.
+    let stream_id = client
+        .events()
+        .find_map(|e| {
+            if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+                Some(stream_id)
+            } else {
+                None
+            }
+        })
+        .unwrap();
+    let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap();
+    assert_eq!(&buf[..l], DEFAULT_STREAM_DATA);
+    assert!(ended);
+}
+
+/// Like `reorder_05rtt`, but on a resumed connection that sends 0-RTT:
+/// the saved 0.5-RTT datagram carries an ACK that is processed late, and
+/// the RTT estimate on both sides must still equal `RTT`.
+#[test]
+fn reorder_05rtt_with_0rtt() {
+    const RTT: Duration = Duration::from_millis(100);
+
+    let mut client = default_client();
+    let mut server = default_server();
+    let validation = AddressValidation::new(now(), ValidateAddress::NoToken).unwrap();
+    let validation = Rc::new(RefCell::new(validation));
+    server.set_validation(Rc::clone(&validation));
+    let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+    // Include RTT in sending the ticket or the ticket age reported by the
+    // client is wrong, which causes the server to reject 0-RTT.
+    now += RTT / 2;
+    server.send_ticket(now, &[]).unwrap();
+    let ticket = server.process_output(now).dgram().unwrap();
+    now += RTT / 2;
+    client.process_input(ticket, now);
+
+    // Start over with a fresh client and server; the new client resumes
+    // using the token obtained from the first connection.
+    let token = get_tokens(&mut client).pop().unwrap();
+    let mut client = default_client();
+    client.enable_resumption(now, token).unwrap();
+    let mut server = default_server();
+
+    // Send ClientHello and some 0-RTT.
+    let c1 = send_something(&mut client, now);
+    assertions::assert_coalesced_0rtt(&c1[..]);
+    // Drop the 0-RTT from the coalesced datagram, so that the server
+    // acknowledges the next 0-RTT packet.
+    let (c1, _) = split_datagram(&c1);
+    let c2 = send_something(&mut client, now);
+
+    // Handle the first packet and send 0.5-RTT in response. Drop the response.
+    now += RTT / 2;
+    let _ = server.process(Some(c1), now).dgram().unwrap();
+    // The gap in 0-RTT will result in this 0.5 RTT containing an ACK.
+    server.process_input(c2, now);
+    let s2 = send_something(&mut server, now);
+
+    // Save the 0.5 RTT.
+    now += RTT / 2;
+    client.process_input(s2, now);
+    assert_eq!(client.stats().saved_datagrams, 1);
+
+    // Now PTO at the client and cause the server to re-send handshake packets.
+    now += AT_LEAST_PTO;
+    let c3 = client.process(None, now).dgram();
+
+    now += RTT / 2;
+    let s3 = server.process(c3, now).dgram().unwrap();
+    assertions::assert_no_1rtt(&s3[..]);
+
+    // The client should be able to process the 0.5 RTT now.
+    // This should contain an ACK, so we are processing an ACK from the past.
+    now += RTT / 2;
+    client.process_input(s3, now);
+    maybe_authenticate(&mut client);
+    let c4 = client.process(None, now).dgram();
+    assert_eq!(*client.state(), State::Connected);
+    assert_eq!(client.loss_recovery.rtt(), RTT);
+
+    now += RTT / 2;
+    server.process_input(c4.unwrap(), now);
+    assert_eq!(*server.state(), State::Confirmed);
+    assert_eq!(server.loss_recovery.rtt(), RTT);
+}
+
+/// Test that a server that coalesces 0.5 RTT with handshake packets
+/// doesn't cause the client to drop application data.
+#[test]
+fn coalesce_05rtt() {
+    const RTT: Duration = Duration::from_millis(100);
+    let mut client = default_client();
+    let mut server = default_server();
+    let mut now = now();
+
+    // The first exchange doesn't offer a chance for the server to send.
+    // So drop the server flight and wait for the PTO.
+    let c1 = client.process(None, now).dgram();
+    assert!(c1.is_some());
+    now += RTT / 2;
+    let s1 = server.process(c1, now).dgram();
+    assert!(s1.is_some());
+
+    // Drop the server flight. Then send some data.
+    let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+    assert!(server.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok());
+    assert!(server.stream_close_send(stream_id).is_ok());
+
+    // Now after a PTO the client can send another packet.
+    // The server should then send its entire flight again,
+    // including the application data, which it sends in a 1-RTT packet.
+    now += AT_LEAST_PTO;
+    let c2 = client.process(None, now).dgram();
+    assert!(c2.is_some());
+    now += RTT / 2;
+    let s2 = server.process(c2, now).dgram();
+    // Even though there is a 1-RTT packet at the end of the datagram, the
+    // flight should be padded to full size.
+    assert_eq!(s2.as_ref().unwrap().len(), PATH_MTU_V6);
+
+    // The client should process the datagram. It can't process the 1-RTT
+    // packet until authentication completes though. So it saves it.
+    now += RTT / 2;
+    assert_eq!(client.stats().dropped_rx, 0);
+    let _ = client.process(s2, now).dgram();
+    // This packet will contain an ACK, but we can ignore it.
+    assert_eq!(client.stats().dropped_rx, 0);
+    assert_eq!(client.stats().packets_rx, 3);
+    assert_eq!(client.stats().saved_datagrams, 1);
+
+    // After (successful) authentication, the packet is processed.
+    maybe_authenticate(&mut client);
+    let c3 = client.process(None, now).dgram();
+    assert!(c3.is_some());
+    assert_eq!(client.stats().dropped_rx, 0); // No Initial padding.
+    assert_eq!(client.stats().packets_rx, 4);
+    assert_eq!(client.stats().saved_datagrams, 1);
+    assert_eq!(client.stats().frame_rx.padding, 1); // Padding uses frames.
+
+    // Allow the handshake to complete.
+    now += RTT / 2;
+    let s3 = server.process(c3, now).dgram();
+    assert!(s3.is_some());
+    assert_eq!(*server.state(), State::Confirmed);
+    now += RTT / 2;
+    let _ = client.process(s3, now).dgram();
+    assert_eq!(*client.state(), State::Confirmed);
+
+    assert_eq!(client.stats().dropped_rx, 0); // No dropped packets.
+}
+
+/// Delivers the server's Handshake packet to the client before the Initial
+/// packet. The client must save such datagrams and process them once the
+/// Initial arrives, and the late processing must not change the RTT estimate.
+#[test]
+fn reorder_handshake() {
+    const RTT: Duration = Duration::from_millis(100);
+    let mut client = default_client();
+    let mut server = default_server();
+    let mut now = now();
+
+    let c1 = client.process(None, now).dgram();
+    assert!(c1.is_some());
+
+    now += RTT / 2;
+    let s1 = server.process(c1, now).dgram();
+    assert!(s1.is_some());
+
+    // Drop the Initial packet from this.
+    let (_, s_hs) = split_datagram(&s1.unwrap());
+    assert!(s_hs.is_some());
+
+    // Pass just the handshake packet in and the client can't handle it yet.
+    // It can only send another Initial packet.
+    now += RTT / 2;
+    let dgram = client.process(s_hs, now).dgram();
+    assertions::assert_initial(&dgram.as_ref().unwrap(), false);
+    assert_eq!(client.stats().saved_datagrams, 1);
+    assert_eq!(client.stats().packets_rx, 1);
+
+    // Get the server to try again.
+    // Though we currently allow the server to arm its PTO timer, use
+    // a second client Initial packet to cause it to send again.
+    now += AT_LEAST_PTO;
+    let c2 = client.process(None, now).dgram();
+    now += RTT / 2;
+    let s2 = server.process(c2, now).dgram();
+    assert!(s2.is_some());
+
+    let (s_init, s_hs) = split_datagram(&s2.unwrap());
+    assert!(s_hs.is_some());
+
+    // Processing the Handshake packet first should save it.
+    now += RTT / 2;
+    client.process_input(s_hs.unwrap(), now);
+    assert_eq!(client.stats().saved_datagrams, 2);
+    assert_eq!(client.stats().packets_rx, 2);
+
+    client.process_input(s_init, now);
+    // Each saved packet should now be "received" again.
+    assert_eq!(client.stats().packets_rx, 7);
+    maybe_authenticate(&mut client);
+    let c3 = client.process(None, now).dgram();
+    assert!(c3.is_some());
+
+    // Note that though packets were saved and processed very late,
+    // they don't cause the RTT to change.
+    now += RTT / 2;
+    let s3 = server.process(c3, now).dgram();
+    assert_eq!(*server.state(), State::Confirmed);
+    assert_eq!(server.loss_recovery.rtt(), RTT);
+
+    now += RTT / 2;
+    client.process_input(s3.unwrap(), now);
+    assert_eq!(*client.state(), State::Confirmed);
+    assert_eq!(client.loss_recovery.rtt(), RTT);
+}
+
+/// Delivers `PACKETS` stream-carrying datagrams to the server before `c2`,
+/// the client datagram after which the server becomes `Confirmed`. The
+/// server must save them, process them once `c2` arrives, and make every
+/// stream's data readable.
+#[test]
+fn reorder_1rtt() {
+    const RTT: Duration = Duration::from_millis(100);
+    const PACKETS: usize = 6; // Many, but not enough to overflow cwnd.
+    let mut client = default_client();
+    let mut server = default_server();
+    let mut now = now();
+
+    let c1 = client.process(None, now).dgram();
+    assert!(c1.is_some());
+
+    now += RTT / 2;
+    let s1 = server.process(c1, now).dgram();
+    assert!(s1.is_some());
+
+    now += RTT / 2;
+    client.process_input(s1.unwrap(), now);
+    maybe_authenticate(&mut client);
+    let c2 = client.process(None, now).dgram();
+    assert!(c2.is_some());
+
+    // Now get a bunch of packets from the client.
+    // Give them to the server before giving it `c2`.
+    for _ in 0..PACKETS {
+        let d = send_something(&mut client, now);
+        server.process_input(d, now + RTT / 2);
+    }
+    // The server has now received those packets, and saved them.
+    // The two extra received are Initial + the junk we use for padding.
+    assert_eq!(server.stats().packets_rx, PACKETS + 2);
+    assert_eq!(server.stats().saved_datagrams, PACKETS);
+    assert_eq!(server.stats().dropped_rx, 1);
+
+    now += RTT / 2;
+    let s2 = server.process(c2, now).dgram();
+    // With `c2` processed, the saved packets are counted in `packets_rx`
+    // a second time (hence `PACKETS * 2`); `saved_datagrams` is unchanged.
+    // The two additional are an Initial ACK and Handshake.
+    assert_eq!(server.stats().packets_rx, PACKETS * 2 + 4);
+    assert_eq!(server.stats().saved_datagrams, PACKETS);
+    assert_eq!(server.stats().dropped_rx, 1);
+    assert_eq!(*server.state(), State::Confirmed);
+    assert_eq!(server.loss_recovery.rtt(), RTT);
+
+    now += RTT / 2;
+    client.process_input(s2.unwrap(), now);
+    assert_eq!(client.loss_recovery.rtt(), RTT);
+
+    // All the stream data that was sent should now be available.
+    let streams = server
+        .events()
+        .filter_map(|e| {
+            if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+                Some(stream_id)
+            } else {
+                None
+            }
+        })
+        .collect::<Vec<_>>();
+    assert_eq!(streams.len(), PACKETS);
+    for stream_id in streams {
+        let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+        let (recvd, fin) = server.stream_recv(stream_id, &mut buf).unwrap();
+        assert_eq!(recvd, DEFAULT_STREAM_DATA.len());
+        assert!(fin);
+    }
+}
+
+/// Flips bits in the last non-zero byte of the client's first datagram and
+/// checks that the server drops what it receives and saves nothing.
+#[test]
+fn corrupted_initial() {
+    let mut client = default_client();
+    let mut server = default_server();
+    let d = client.process(None, now()).dgram().unwrap();
+    let mut corrupted = Vec::from(&d[..]);
+    // Find the last non-zero value and corrupt that.
+    let (idx, _) = corrupted
+        .iter()
+        .enumerate()
+        .rev()
+        .find(|(_, &v)| v != 0)
+        .unwrap();
+    corrupted[idx] ^= 0x76;
+    let dgram = Datagram::new(d.source(), d.destination(), corrupted);
+    server.process_input(dgram, now());
+    // The server should have counted two received packets,
+    // both of which are dropped; nothing is saved.
+    assert_eq!(server.stats().packets_rx, 2);
+    assert_eq!(server.stats().dropped_rx, 2);
+    assert_eq!(server.stats().saved_datagrams, 0);
+}
+
+#[test]
+// Absent path MTU (PMTU) discovery, max v6 packet size should be PATH_MTU_V6.
+fn verify_pkt_honors_mtu() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect_force_idle(&mut client, &mut server);
+
+    let now = now();
+
+    // An idle client has nothing to send and only reports its idle timeout.
+    let res = client.process(None, now);
+    assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+    // Try to send a large stream and verify first packet is correctly sized
+    assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+    assert_eq!(client.stream_send(2, &[0xbb; 2000]).unwrap(), 2000);
+    let pkt0 = client.process(None, now);
+    assert!(matches!(pkt0, Output::Datagram(_)));
+    assert_eq!(pkt0.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+}
+
+/// Repeatedly feeds the client an undecryptable Handshake packet: each one
+/// triggers another Initial until the allowance is used up, after which the
+/// client stays silent until its PTO; the handshake then completes normally.
+#[test]
+fn extra_initial_hs() {
+    let mut client = default_client();
+    let mut server = default_server();
+    let mut now = now();
+
+    let c_init = client.process(None, now).dgram();
+    assert!(c_init.is_some());
+    now += DEFAULT_RTT / 2;
+    let s_init = server.process(c_init, now).dgram();
+    assert!(s_init.is_some());
+    now += DEFAULT_RTT / 2;
+
+    // Drop the Initial packet, keep only the Handshake.
+    let (_, undecryptable) = split_datagram(&s_init.unwrap());
+    assert!(undecryptable.is_some());
+
+    // Feed the same undecryptable packet into the client a few times.
+    // Do that EXTRA_INITIALS times and each time the client will emit
+    // another Initial packet.
+    // NOTE(review): the inclusive range below runs EXTRA_INITIALS + 1
+    // iterations, one more than the comment above states -- confirm which
+    // count is intended.
+    for _ in 0..=super::super::EXTRA_INITIALS {
+        let c_init = client.process(undecryptable.clone(), now).dgram();
+        assertions::assert_initial(&c_init.as_ref().unwrap(), false);
+        now += DEFAULT_RTT / 10;
+    }
+
+    // After EXTRA_INITIALS, the client stops sending Initial packets.
+    let nothing = client.process(undecryptable, now).dgram();
+    assert!(nothing.is_none());
+
+    // Until PTO, where another Initial can be used to complete the handshake.
+    now += AT_LEAST_PTO;
+    let c_init = client.process(None, now).dgram();
+    assertions::assert_initial(c_init.as_ref().unwrap(), false);
+    now += DEFAULT_RTT / 2;
+    let s_init = server.process(c_init, now).dgram();
+    now += DEFAULT_RTT / 2;
+    client.process_input(s_init.unwrap(), now);
+    maybe_authenticate(&mut client);
+    let c_fin = client.process_output(now).dgram();
+    assert_eq!(*client.state(), State::Connected);
+    now += DEFAULT_RTT / 2;
+    server.process_input(c_fin.unwrap(), now);
+    assert_eq!(*server.state(), State::Confirmed);
+}
+
+/// A Handshake packet whose connection ID has been corrupted must not make
+/// the client send another Initial.
+#[test]
+fn extra_initial_invalid_cid() {
+    let mut client = default_client();
+    let mut server = default_server();
+    let mut now = now();
+
+    let c_init = client.process(None, now).dgram();
+    assert!(c_init.is_some());
+    now += DEFAULT_RTT / 2;
+    let s_init = server.process(c_init, now).dgram();
+    assert!(s_init.is_some());
+    now += DEFAULT_RTT / 2;
+
+    // If the client receives a packet that contains the wrong connection
+    // ID, it won't send another Initial.
+    let (_, hs) = split_datagram(&s_init.unwrap());
+    let hs = hs.unwrap();
+    let mut copy = hs.to_vec();
+    assert_ne!(copy[5], 0); // The DCID should be non-zero length.
+    // Flip bits in the byte right after the length byte checked above.
+    copy[6] ^= 0xc4;
+    // NOTE(review): source and destination are passed in the opposite order
+    // to `corrupted_initial`; presumably harmless if both are the loopback
+    // address -- confirm.
+    let dgram_copy = Datagram::new(hs.destination(), hs.source(), copy);
+    let nothing = client.process(Some(dgram_copy), now).dgram();
+    assert!(nothing.is_none());
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/idle.rs b/third_party/rust/neqo-transport/src/connection/tests/idle.rs
new file mode 100644
index 0000000000..962645eff3
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/idle.rs
@@ -0,0 +1,274 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{IdleTimeout, Output, State, LOCAL_IDLE_TIMEOUT};
+use super::{
+    connect, connect_force_idle, connect_with_rtt, default_client, default_server,
+    maybe_authenticate, send_something, AT_LEAST_PTO,
+};
+use crate::frame::StreamType;
+use crate::packet::PacketBuilder;
+use crate::tparams::{self, TransportParameter};
+use crate::tracking::PNSpace;
+
+use neqo_common::Encoder;
+use std::time::Duration;
+use test_fixture::{self, now, split_datagram};
+
+/// An idle client stays `Confirmed` until `LOCAL_IDLE_TIMEOUT` has elapsed
+/// and is `Closed` once it has.
+#[test]
+fn idle_timeout() {
+    let mut client = default_client();
+    let mut server = default_server();
+    connect_force_idle(&mut client, &mut server);
+
+    let now = now();
+
+    let res = client.process(None, now);
+    assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+    // Still connected one second before LOCAL_IDLE_TIMEOUT elapses. Idle timer not reset
+    let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT - Duration::from_secs(1));
+    assert!(matches!(client.state(), State::Confirmed));
+
+    let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT);
+
+    // Not connected after LOCAL_IDLE_TIMEOUT seconds.
+    assert!(matches!(client.state(), State::Closed(_)));
+}
+
+/// When the server is configured with a shorter idle timeout than
+/// `LOCAL_IDLE_TIMEOUT`, both peers must report the shorter value as
+/// their callback once idle.
+#[test]
+fn asymmetric_idle_timeout() {
+    const LOWER_TIMEOUT_MS: u64 = 1000;
+    const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS);
+    // Sanity check the constant.
+    assert!(LOWER_TIMEOUT < LOCAL_IDLE_TIMEOUT);
+
+    let mut client = default_client();
+    let mut server = default_server();
+
+    // Overwrite the default at the server.
+    server
+        .tps
+        .borrow_mut()
+        .local
+        .set_integer(tparams::IDLE_TIMEOUT, LOWER_TIMEOUT_MS);
+    server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT);
+
+    // Now connect and force idleness manually.
+    connect(&mut client, &mut server);
+    let p1 = send_something(&mut server, now());
+    let p2 = send_something(&mut server, now());
+    // Deliver the two server packets to the client in reverse order.
+    client.process_input(p2, now());
+    let ack = client.process(Some(p1), now()).dgram();
+    assert!(ack.is_some());
+    // Now the server has its ACK and both should be idle.
+    assert_eq!(server.process(ack, now()), Output::Callback(LOWER_TIMEOUT));
+    assert_eq!(client.process(None, now()), Output::Callback(LOWER_TIMEOUT));
+}
+
+#[test]
+fn tiny_idle_timeout() {
+    const RTT: Duration = Duration::from_millis(500);
+    const LOWER_TIMEOUT_MS: u64 = 100;
+    const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS);
+    // We won't respect a value that is lower than 3*PTO, sanity check.
+    assert!(LOWER_TIMEOUT < 3 * RTT);
+
+    let mut client = default_client();
+    let mut server = default_server();
+
+    // Overwrite the default at the server.
+    server
+        .set_local_tparam(
+            tparams::IDLE_TIMEOUT,
+            TransportParameter::Integer(LOWER_TIMEOUT_MS),
+        )
+        .unwrap();
+    server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT);
+
+    // Now connect with an RTT and force idleness manually.
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + let p1 = send_something(&mut server, now); + let p2 = send_something(&mut server, now); + now += RTT / 2; + client.process_input(p2, now); + let ack = client.process(Some(p1), now).dgram(); + assert!(ack.is_some()); + + // The client should be idle now, but with a different timer. + if let Output::Callback(t) = client.process(None, now) { + assert!(t > LOWER_TIMEOUT); + } else { + panic!("Client not idle"); + } + + // The server should go idle after the ACK, but again with a larger timeout. + now += RTT / 2; + if let Output::Callback(t) = client.process(ack, now) { + assert!(t > LOWER_TIMEOUT); + } else { + panic!("Client not idle"); + } +} + +#[test] +fn idle_send_packet1() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT)); + + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_send(2, b"hello").unwrap(), 5); + + let out = client.process(None, now + Duration::from_secs(10)); + let out = server.process(out.dgram(), now + Duration::from_secs(10)); + + // Still connected after 39 seconds because idle timer reset by outgoing + // packet + let _ = client.process( + out.dgram(), + now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(9), + ); + assert!(matches!(client.state(), State::Confirmed)); + + // Not connected after 40 seconds. 
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(10)); + + assert!(matches!(client.state(), State::Closed(_))); +} + +#[test] +fn idle_send_packet2() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT)); + + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_send(2, b"hello").unwrap(), 5); + + let _out = client.process(None, now + Duration::from_secs(10)); + + assert_eq!(client.stream_send(2, b"there").unwrap(), 5); + let _out = client.process(None, now + Duration::from_secs(20)); + + // Still connected after 39 seconds. + let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(9)); + assert!(matches!(client.state(), State::Confirmed)); + + // Not connected after 40 seconds because timer not reset by second + // outgoing packet + let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(10)); + assert!(matches!(client.state(), State::Closed(_))); +} + +#[test] +fn idle_recv_packet() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT)); + + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + assert_eq!(client.stream_send(0, b"hello").unwrap(), 5); + + // Respond with another packet + let out = client.process(None, now + Duration::from_secs(10)); + server.process_input(out.dgram().unwrap(), now + Duration::from_secs(10)); + assert_eq!(server.stream_send(0, b"world").unwrap(), 5); + let out = server.process_output(now + Duration::from_secs(10)); + assert_ne!(out.as_dgram_ref(), None); + + let _ = client.process(out.dgram(), now + Duration::from_secs(20)); + 
assert!(matches!(client.state(), State::Confirmed)); + + // Still connected after 49 seconds because idle timer reset by received + // packet + let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(19)); + assert!(matches!(client.state(), State::Confirmed)); + + // Not connected after 50 seconds. + let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(20)); + + assert!(matches!(client.state(), State::Closed(_))); +} + +/// Caching packets should not cause the connection to become idle. +/// This requires a few tricks to keep the connection from going +/// idle while preventing any progress on the handshake. +#[test] +fn idle_caching() { + let mut client = default_client(); + let mut server = default_server(); + let start = now(); + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + + // Perform the first round trip, but drop the Initial from the server. + // The client then caches the Handshake packet. + let dgram = client.process_output(start).dgram(); + let dgram = server.process(dgram, start).dgram(); + let (_, handshake) = split_datagram(&dgram.unwrap()); + client.process_input(handshake.unwrap(), start); + + // Perform an exchange and keep the connection alive. + // Only allow a packet containing a PING to pass. + let middle = start + AT_LEAST_PTO; + let _ = client.process_output(middle); + let dgram = client.process_output(middle).dgram(); + + // Get the server to send its first probe and throw that away. + let _ = server.process_output(middle).dgram(); + // Now let the server process the client PING. This causes the server + // to send CRYPTO frames again, so manually extract and discard those. 
+ let ping_before_s = server.stats().frame_rx.ping; + server.process_input(dgram.unwrap(), middle); + assert_eq!(server.stats().frame_rx.ping, ping_before_s + 1); + let crypto = server + .crypto + .streams + .write_frame(PNSpace::Initial, &mut builder); + assert!(crypto.is_some()); + let crypto = server + .crypto + .streams + .write_frame(PNSpace::Initial, &mut builder); + assert!(crypto.is_none()); + let dgram = server.process_output(middle).dgram(); + + // Now only allow the Initial packet from the server through; + // it shouldn't contain a CRYPTO frame. + let (initial, _) = split_datagram(&dgram.unwrap()); + let ping_before_c = client.stats().frame_rx.ping; + let ack_before = client.stats().frame_rx.ack; + client.process_input(initial, middle); + assert_eq!(client.stats().frame_rx.ping, ping_before_c + 1); + assert_eq!(client.stats().frame_rx.ack, ack_before + 1); + + let end = start + LOCAL_IDLE_TIMEOUT + (AT_LEAST_PTO / 2); + // Now let the server Initial through, with the CRYPTO frame. + let dgram = server.process_output(end).dgram(); + let (initial, _) = split_datagram(&dgram.unwrap()); + let _ = client.process(Some(initial), end); + maybe_authenticate(&mut client); + let dgram = client.process_output(end).dgram(); + let dgram = server.process(dgram, end).dgram(); + client.process_input(dgram.unwrap(), end); + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(*server.state(), State::Confirmed); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs new file mode 100644 index 0000000000..cc572e85ca --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs @@ -0,0 +1,330 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +use super::super::super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED}; +use super::super::{Connection, Error, Output, State, StreamType, LOCAL_IDLE_TIMEOUT}; +use super::{ + connect, connect_force_idle, default_client, default_server, maybe_authenticate, + send_and_receive, send_something, AT_LEAST_PTO, +}; +use crate::crypto::{OVERWRITE_INVOCATIONS, UPDATE_WRITE_KEYS_AT}; +use crate::packet::PacketNumber; +use crate::path::PATH_MTU_V6; + +use neqo_common::{qdebug, Datagram}; +use test_fixture::{self, now}; + +fn check_discarded(peer: &mut Connection, pkt: Datagram, dropped: usize, dups: usize) { + // Make sure to flush any saved datagrams before doing this. + let _ = peer.process_output(now()); + + let before = peer.stats(); + let out = peer.process(Some(pkt), now()); + assert!(out.as_dgram_ref().is_none()); + let after = peer.stats(); + assert_eq!(dropped, after.dropped_rx - before.dropped_rx); + assert_eq!(dups, after.dups_rx - before.dups_rx); +} + +fn assert_update_blocked(c: &mut Connection) { + assert_eq!( + c.initiate_key_update().unwrap_err(), + Error::KeyUpdateBlocked + ) +} + +fn overwrite_invocations(n: PacketNumber) { + OVERWRITE_INVOCATIONS.with(|v| { + *v.borrow_mut() = Some(n); + }); +} + +#[test] +fn discarded_initial_keys() { + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let init_pkt_c = client.process(None, now()).dgram(); + assert!(init_pkt_c.is_some()); + assert_eq!(init_pkt_c.as_ref().unwrap().len(), PATH_MTU_V6); + + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let init_pkt_s = server.process(init_pkt_c.clone(), now()).dgram(); + assert!(init_pkt_s.is_some()); + + qdebug!("---- client: cert verification"); + let out = client.process(init_pkt_s.clone(), now()).dgram(); + assert!(out.is_some()); + + // The client has received handshake packet. It will remove the Initial keys. 
+ // We will check this by processing init_pkt_s a second time. + // The initial packet should be dropped. The packet contains a Handshake packet as well, which + // will be marked as dup. And it will contain padding, which will be "dropped". + check_discarded(&mut client, init_pkt_s.unwrap(), 2, 1); + + assert!(maybe_authenticate(&mut client)); + + // The server has not removed the Initial keys yet, because it has not yet received a Handshake + // packet from the client. + // We will check this by processing init_pkt_c a second time. + // The dropped packet is padding. The Initial packet has been mark dup. + check_discarded(&mut server, init_pkt_c.clone().unwrap(), 1, 1); + + qdebug!("---- client: SH..FIN -> FIN"); + let out = client.process(None, now()).dgram(); + assert!(out.is_some()); + + // The server will process the first Handshake packet. + // After this the Initial keys will be dropped. + let out = server.process(out, now()).dgram(); + assert!(out.is_some()); + + // Check that the Initial keys are dropped at the server + // We will check this by processing init_pkt_c a third time. + // The Initial packet has been dropped and padding that follows it. + // There is no dups, everything has been dropped. + check_discarded(&mut server, init_pkt_c.unwrap(), 1, 0); +} + +#[test] +fn key_update_client() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + let mut now = now(); + + assert_eq!(client.get_epochs(), (Some(3), Some(3))); // (write, read) + assert_eq!(server.get_epochs(), (Some(3), Some(3))); + + assert!(client.initiate_key_update().is_ok()); + assert_update_blocked(&mut client); + + // Initiating an update should only increase the write epoch. + assert_eq!( + Output::Callback(LOCAL_IDLE_TIMEOUT), + client.process(None, now) + ); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + + // Send something to propagate the update. 
+ assert!(send_and_receive(&mut client, &mut server, now).is_none()); + + // The server should now be waiting to discharge read keys. + assert_eq!(server.get_epochs(), (Some(4), Some(3))); + let res = server.process(None, now); + if let Output::Callback(t) = res { + assert!(t < LOCAL_IDLE_TIMEOUT); + } else { + panic!("server should now be waiting to clear keys"); + } + + // Without having had time to purge old keys, more updates are blocked. + // The spec would permits it at this point, but we are more conservative. + assert_update_blocked(&mut client); + // The server can't update until it receives an ACK for a packet. + assert_update_blocked(&mut server); + + // Waiting now for at least a PTO should cause the server to drop old keys. + // But at this point the client hasn't received a key update from the server. + // It will be stuck with old keys. + now += AT_LEAST_PTO; + let dgram = client.process(None, now).dgram(); + assert!(dgram.is_some()); // Drop this packet. + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + let _ = server.process(None, now); + assert_eq!(server.get_epochs(), (Some(4), Some(4))); + + // Even though the server has updated, it hasn't received an ACK yet. + assert_update_blocked(&mut server); + + // Now get an ACK from the server. + // The previous PTO packet (see above) was dropped, so we should get an ACK here. + let dgram = send_and_receive(&mut client, &mut server, now); + assert!(dgram.is_some()); + let res = client.process(dgram, now); + // This is the first packet that the client has received from the server + // with new keys, so its read timer just started. + if let Output::Callback(t) = res { + assert!(t < LOCAL_IDLE_TIMEOUT); + } else { + panic!("client should now be waiting to clear keys"); + } + + assert_update_blocked(&mut client); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + // The server can't update until it gets something from the client. 
+ assert_update_blocked(&mut server); + + now += AT_LEAST_PTO; + let _ = client.process(None, now); + assert_eq!(client.get_epochs(), (Some(4), Some(4))); +} + +#[test] +fn key_update_consecutive() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let now = now(); + + assert!(server.initiate_key_update().is_ok()); + assert_eq!(server.get_epochs(), (Some(4), Some(3))); + + // Server sends something. + // Send twice and drop the first to induce an ACK from the client. + let _ = send_something(&mut server, now); // Drop this. + + // Another packet from the server will cause the client to ACK and update keys. + let dgram = send_and_receive(&mut server, &mut client, now); + assert!(dgram.is_some()); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); + + // Have the server process the ACK. + if let Output::Callback(_) = server.process(dgram, now) { + assert_eq!(server.get_epochs(), (Some(4), Some(3))); + // Now move the server temporarily into the future so that it + // rotates the keys. The client stays in the present. + let _ = server.process(None, now + AT_LEAST_PTO); + assert_eq!(server.get_epochs(), (Some(4), Some(4))); + } else { + panic!("server should have a timer set"); + } + + // Now update keys on the server again. + assert!(server.initiate_key_update().is_ok()); + assert_eq!(server.get_epochs(), (Some(5), Some(4))); + + let dgram = send_something(&mut server, now + AT_LEAST_PTO); + + // However, as the server didn't wait long enough to update again, the + // client hasn't rotated its keys, so the packet gets dropped. + check_discarded(&mut client, dgram, 1, 0); +} + +// Key updates can't be initiated too early. 
+#[test] +fn key_update_before_confirmed() { + let mut client = default_client(); + assert_update_blocked(&mut client); + let mut server = default_server(); + assert_update_blocked(&mut server); + + // Client Initial + let dgram = client.process(None, now()).dgram(); + assert!(dgram.is_some()); + assert_update_blocked(&mut client); + + // Server Initial + Handshake + let dgram = server.process(dgram, now()).dgram(); + assert!(dgram.is_some()); + assert_update_blocked(&mut server); + + // Client Handshake + client.process_input(dgram.unwrap(), now()); + assert_update_blocked(&mut client); + + assert!(maybe_authenticate(&mut client)); + assert_update_blocked(&mut client); + + let dgram = client.process(None, now()).dgram(); + assert!(dgram.is_some()); + assert_update_blocked(&mut client); + + // Server HANDSHAKE_DONE + let dgram = server.process(dgram, now()).dgram(); + assert!(dgram.is_some()); + assert!(server.initiate_key_update().is_ok()); + + // Client receives HANDSHAKE_DONE + let dgram = client.process(dgram, now()).dgram(); + assert!(dgram.is_none()); + assert!(client.initiate_key_update().is_ok()); +} + +#[test] +fn exhaust_write_keys() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + overwrite_invocations(0); + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert!(client.stream_send(stream_id, b"explode!").is_ok()); + let dgram = client.process_output(now()).dgram(); + assert!(dgram.is_none()); + assert!(matches!( + client.state(), + State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + )); +} + +#[test] +fn exhaust_read_keys() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let dgram = send_something(&mut client, now()); + + overwrite_invocations(0); + let dgram = server.process(Some(dgram), now()).dgram(); + assert!(matches!( + server.state(), + 
State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + )); + + client.process_input(dgram.unwrap(), now()); + assert!(matches!(client.state(), State::Draining { + error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), .. + })); +} + +#[test] +fn automatic_update_write_keys() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + overwrite_invocations(UPDATE_WRITE_KEYS_AT); + let _ = send_something(&mut client, now()); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); +} + +#[test] +fn automatic_update_write_keys_later() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + overwrite_invocations(UPDATE_WRITE_KEYS_AT + 2); + // No update after the first. + let _ = send_something(&mut client, now()); + assert_eq!(client.get_epochs(), (Some(3), Some(3))); + // The second will update though. + let _ = send_something(&mut client, now()); + assert_eq!(client.get_epochs(), (Some(4), Some(3))); +} + +#[test] +fn automatic_update_write_keys_blocked() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // An outstanding key update will block the automatic update. + client.initiate_key_update().unwrap(); + + overwrite_invocations(UPDATE_WRITE_KEYS_AT); + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert!(client.stream_send(stream_id, b"explode!").is_ok()); + let dgram = client.process_output(now()).dgram(); + // Not being able to update is fatal. 
+ assert!(dgram.is_none()); + assert!(matches!( + client.state(), + State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + )); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs new file mode 100644 index 0000000000..20add0faad --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs @@ -0,0 +1,305 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![deny(clippy::pedantic)] + +use super::{ + Connection, ConnectionError, FixedConnectionIdManager, Output, State, LOCAL_IDLE_TIMEOUT, +}; +use crate::addr_valid::{AddressValidation, ValidateAddress}; +use crate::cc::CWND_INITIAL_PKTS; +use crate::events::ConnectionEvent; +use crate::frame::StreamType; +use crate::path::PATH_MTU_V6; +use crate::recovery::ACK_ONLY_SIZE_LIMIT; +use crate::ConnectionParameters; + +use std::cell::RefCell; +use std::mem; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use neqo_common::{event::Provider, qdebug, qtrace, Datagram}; +use neqo_crypto::{AllowZeroRtt, AuthenticationStatus, ResumptionToken}; +use test_fixture::{self, fixture_init, loopback, now}; + +// All the tests. +mod cc; +mod close; +mod handshake; +mod idle; +mod keys; +mod recovery; +mod resumption; +mod stream; +mod vn; +mod zerortt; + +const DEFAULT_RTT: Duration = Duration::from_millis(100); +const AT_LEAST_PTO: Duration = Duration::from_secs(1); +const DEFAULT_STREAM_DATA: &[u8] = b"message"; + +// This is fabulous: because test_fixture uses the public API for Connection, +// it gets a different type to the ones that are referenced via super::super::*. 
+// Thus, this code can't use default_client() and default_server() from +// test_fixture because they produce different types. +// +// These are a direct copy of those functions. +pub fn default_client() -> Connection { + fixture_init(); + Connection::new_client( + test_fixture::DEFAULT_SERVER_NAME, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(FixedConnectionIdManager::new(3))), + loopback(), + loopback(), + &ConnectionParameters::default(), + ) + .expect("create a default client") +} +pub fn default_server() -> Connection { + fixture_init(); + + let mut c = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(FixedConnectionIdManager::new(5))), + &ConnectionParameters::default(), + ) + .expect("create a default server"); + c.server_enable_0rtt(&test_fixture::anti_replay(), AllowZeroRtt {}) + .expect("enable 0-RTT"); + c +} + +/// If state is `AuthenticationNeeded` call `authenticated()`. This function will +/// consume all outstanding events on the connection. +pub fn maybe_authenticate(conn: &mut Connection) -> bool { + let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded); + if conn.events().any(authentication_needed) { + conn.authenticated(AuthenticationStatus::Ok, now()); + return true; + } + false +} + +/// Drive the handshake between the client and server. +fn handshake( + client: &mut Connection, + server: &mut Connection, + now: Instant, + rtt: Duration, +) -> Instant { + let mut a = client; + let mut b = server; + let mut now = now; + + let mut input = None; + let is_done = |c: &mut Connection| matches!(c.state(), State::Confirmed | State::Closing { .. 
} | State::Closed(..)); + + while !is_done(a) { + let _ = maybe_authenticate(a); + let had_input = input.is_some(); + let output = a.process(input, now).dgram(); + assert!(had_input || output.is_some()); + input = output; + qtrace!("t += {:?}", rtt / 2); + now += rtt / 2; + mem::swap(&mut a, &mut b); + } + let _ = a.process(input, now); + now +} + +fn connect_with_rtt( + client: &mut Connection, + server: &mut Connection, + now: Instant, + rtt: Duration, +) -> Instant { + let now = handshake(client, server, now, rtt); + assert_eq!(*client.state(), State::Confirmed); + assert_eq!(*client.state(), State::Confirmed); + + assert_eq!(client.loss_recovery.rtt(), rtt); + assert_eq!(server.loss_recovery.rtt(), rtt); + now +} + +fn connect(client: &mut Connection, server: &mut Connection) { + connect_with_rtt(client, server, now(), Duration::new(0, 0)); +} + +fn assert_error(c: &Connection, err: &ConnectionError) { + match c.state() { + State::Closing { error, .. } | State::Draining { error, .. } | State::Closed(error) => { + assert_eq!(*error, *err); + } + _ => panic!("bad state {:?}", c.state()), + } +} + +fn exchange_ticket( + client: &mut Connection, + server: &mut Connection, + now: Instant, +) -> ResumptionToken { + let validation = AddressValidation::new(now, ValidateAddress::NoToken).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + server.send_ticket(now, &[]).expect("can send ticket"); + let ticket = server.process_output(now).dgram(); + assert!(ticket.is_some()); + client.process_input(ticket.unwrap(), now); + assert_eq!(*client.state(), State::Confirmed); + get_tokens(client).pop().expect("should have token") +} + +/// Connect with an RTT and then force both peers to be idle. +/// Getting the client and server to reach an idle state is surprisingly hard. +/// The server sends `HANDSHAKE_DONE` at the end of the handshake, and the client +/// doesn't immediately acknowledge it. 
Reordering packets does the trick. +fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Duration) -> Instant { + let mut now = connect_with_rtt(client, server, now(), rtt); + let p1 = send_something(server, now); + let p2 = send_something(server, now); + now += rtt / 2; + // Delivering p2 first at the client causes it to want to ACK. + client.process_input(p2, now); + // Delivering p1 should not have the client change its mind about the ACK. + let ack = client.process(Some(p1), now).dgram(); + assert!(ack.is_some()); + assert_eq!( + server.process(ack, now), + Output::Callback(LOCAL_IDLE_TIMEOUT) + ); + assert_eq!( + client.process_output(now), + Output::Callback(LOCAL_IDLE_TIMEOUT) + ); + now +} + +fn connect_force_idle(client: &mut Connection, server: &mut Connection) { + connect_rtt_idle(client, server, Duration::new(0, 0)); +} + +/// This fills the congestion window from a single source. +/// As the pacer will interfere with this, this moves time forward +/// as `Output::Callback` is received. Because it is hard to tell +/// from the return value whether a timeout is an ACK delay, PTO, or +/// pacing, this looks at the congestion window to tell when to stop. +/// Returns a list of datagrams and the new time. 
+fn fill_cwnd(src: &mut Connection, stream: u64, mut now: Instant) -> (Vec<Datagram>, Instant) { + const BLOCK_SIZE: usize = 4_096; + let mut total_dgrams = Vec::new(); + + qtrace!( + "fill_cwnd starting cwnd: {}", + src.loss_recovery.cwnd_avail() + ); + + loop { + let bytes_sent = src.stream_send(stream, &[0x42; BLOCK_SIZE]).unwrap(); + qtrace!("fill_cwnd wrote {} bytes", bytes_sent); + if bytes_sent < BLOCK_SIZE { + break; + } + } + + loop { + let pkt = src.process_output(now); + qtrace!( + "fill_cwnd cwnd remaining={}, output: {:?}", + src.loss_recovery.cwnd_avail(), + pkt + ); + match pkt { + Output::Datagram(dgram) => { + total_dgrams.push(dgram); + } + Output::Callback(t) => { + if src.loss_recovery.cwnd_avail() < ACK_ONLY_SIZE_LIMIT { + break; + } + now += t; + } + Output::None => panic!(), + } + } + + (total_dgrams, now) +} + +/// This magic number is the size of the client's CWND after the handshake completes. +/// This is the same as the initial congestion window, because during the handshake +/// the cc is app limited and cwnd is not increased. +/// +/// As we change how we build packets, or even as NSS changes, +/// this number might be different. The tests that depend on this +/// value could fail as a result of variations, so it's OK to just +/// change this value, but it is good to first understand where the +/// change came from. +const POST_HANDSHAKE_CWND: usize = PATH_MTU_V6 * CWND_INITIAL_PKTS; + +/// Determine the number of packets required to fill the CWND. +const fn cwnd_packets(data: usize) -> usize { + (data + ACK_ONLY_SIZE_LIMIT - 1) / PATH_MTU_V6 +} + +/// Determine the size of the last packet. +/// The minimal size of a packet is `ACK_ONLY_SIZE_LIMIT`. +fn last_packet(cwnd: usize) -> usize { + if (cwnd % PATH_MTU_V6) > ACK_ONLY_SIZE_LIMIT { + cwnd % PATH_MTU_V6 + } else { + PATH_MTU_V6 + } +} + +/// Assert that the set of packets fill the CWND. 
+fn assert_full_cwnd(packets: &[Datagram], cwnd: usize) { + assert_eq!(packets.len(), cwnd_packets(cwnd)); + let (last, rest) = packets.split_last().unwrap(); + assert!(rest.iter().all(|d| d.len() == PATH_MTU_V6)); + assert_eq!(last.len(), last_packet(cwnd)); +} + +/// Send something on a stream from `sender` to `receiver`. +/// Return the resulting datagram. +#[must_use] +fn send_something(sender: &mut Connection, now: Instant) -> Datagram { + let stream_id = sender.stream_create(StreamType::UniDi).unwrap(); + assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok()); + assert!(sender.stream_close_send(stream_id).is_ok()); + qdebug!([sender], "send_something on {}", stream_id); + let dgram = sender.process(None, now).dgram(); + dgram.expect("should have something to send") +} + +/// Send something on a stream from `sender` to `receiver`. +/// Return any ACK that might result. +fn send_and_receive( + sender: &mut Connection, + receiver: &mut Connection, + now: Instant, +) -> Option<Datagram> { + let dgram = send_something(sender, now); + receiver.process(Some(dgram), now).dgram() +} + +fn get_tokens(client: &mut Connection) -> Vec<ResumptionToken> { + client + .events() + .filter_map(|e| { + if let ConnectionEvent::ResumptionToken(token) = e { + Some(token) + } else { + None + } + }) + .collect() +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/recovery.rs b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs new file mode 100644 index 0000000000..ba7069ccb2 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs @@ -0,0 +1,636 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use super::super::{Output, State, LOCAL_IDLE_TIMEOUT}; +use super::{ + assert_full_cwnd, connect, connect_force_idle, connect_with_rtt, default_client, + default_server, fill_cwnd, maybe_authenticate, send_and_receive, send_something, AT_LEAST_PTO, + POST_HANDSHAKE_CWND, +}; +use crate::frame::StreamType; +use crate::path::PATH_MTU_V6; +use crate::recovery::PTO_PACKET_COUNT; +use crate::stats::MAX_PTO_COUNTS; +use crate::tparams::TransportParameter; +use crate::tracking::ACK_DELAY; + +use neqo_common::qdebug; +use neqo_crypto::AuthenticationStatus; +use std::time::Duration; +use test_fixture::{self, now, split_datagram}; + +#[test] +fn pto_works_basic() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT)); + + // Send data on two streams + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_send(2, b"hello").unwrap(), 5); + assert_eq!(client.stream_send(2, b" world").unwrap(), 6); + + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 6); + assert_eq!(client.stream_send(6, b"there!").unwrap(), 6); + + // Send a packet after some time. 
+ now += Duration::from_secs(10); + let out = client.process(None, now); + assert!(out.dgram().is_some()); + + // Nothing to do, should return callback + let out = client.process(None, now); + assert!(matches!(out, Output::Callback(_))); + + // One second later, it should want to send PTO packet + now += AT_LEAST_PTO; + let out = client.process(None, now); + + let stream_before = server.stats().frame_rx.stream; + server.process_input(out.dgram().unwrap(), now); + assert_eq!(server.stats().frame_rx.stream, stream_before + 2); +} + +#[test] +fn pto_works_full_cwnd() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let res = client.process(None, now()); + assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT)); + + // Send lots of data. + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + let (dgrams, now) = fill_cwnd(&mut client, 2, now()); + assert_full_cwnd(&dgrams, POST_HANDSHAKE_CWND); + + neqo_common::qwarn!("waiting over"); + // Fill the CWND after waiting for a PTO. + let (dgrams, now) = fill_cwnd(&mut client, 2, now + AT_LEAST_PTO); + // Two packets in the PTO. + // The first should be full sized; the second might be small. + assert_eq!(dgrams.len(), 2); + assert_eq!(dgrams[0].len(), PATH_MTU_V6); + + // Both datagrams contain a STREAM frame. 
+ for d in dgrams { + let stream_before = server.stats().frame_rx.stream; + server.process_input(d, now); + assert_eq!(server.stats().frame_rx.stream, stream_before + 1); + } +} + +#[test] +#[allow(clippy::cognitive_complexity)] +fn pto_works_ping() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let now = now(); + + let res = client.process(None, now); + assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT)); + + // Send "zero" pkt + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_send(2, b"zero").unwrap(), 4); + let pkt0 = client.process(None, now + Duration::from_secs(10)); + assert!(matches!(pkt0, Output::Datagram(_))); + + // Send "one" pkt + assert_eq!(client.stream_send(2, b"one").unwrap(), 3); + let pkt1 = client.process(None, now + Duration::from_secs(10)); + + // Send "two" pkt + assert_eq!(client.stream_send(2, b"two").unwrap(), 3); + let pkt2 = client.process(None, now + Duration::from_secs(10)); + + // Send "three" pkt + assert_eq!(client.stream_send(2, b"three").unwrap(), 5); + let pkt3 = client.process(None, now + Duration::from_secs(10)); + + // Nothing to do, should return callback + let out = client.process(None, now + Duration::from_secs(10)); + // Check callback delay is what we expect + assert!(matches!(out, Output::Callback(x) if x == Duration::from_millis(45))); + + // Process these by server, skipping pkt0 + let srv0_pkt1 = server.process(pkt1.dgram(), now + Duration::from_secs(10)); + // ooo, ack client pkt 1 + assert!(matches!(srv0_pkt1, Output::Datagram(_))); + + // process pkt2 (no ack yet) + let srv2 = server.process( + pkt2.dgram(), + now + Duration::from_secs(10) + Duration::from_millis(20), + ); + assert!(matches!(srv2, Output::Callback(_))); + + // process pkt3 (acked) + let srv2 = server.process( + pkt3.dgram(), + now + Duration::from_secs(10) + Duration::from_millis(20), + ); + // ack client pkt 2 & 3 + 
assert!(matches!(srv2, Output::Datagram(_))); + + // client processes ack + let pkt4 = client.process( + srv2.dgram(), + now + Duration::from_secs(10) + Duration::from_millis(40), + ); + // client resends data from pkt0 + assert!(matches!(pkt4, Output::Datagram(_))); + + // server sees ooo pkt0 and generates ack + let srv_pkt2 = server.process( + pkt0.dgram(), + now + Duration::from_secs(10) + Duration::from_millis(40), + ); + assert!(matches!(srv_pkt2, Output::Datagram(_))); + + // Orig data is acked + let pkt5 = client.process( + srv_pkt2.dgram(), + now + Duration::from_secs(10) + Duration::from_millis(40), + ); + assert!(matches!(pkt5, Output::Callback(_))); + + // PTO expires. No unacked data. Only send PING. + let pkt6 = client.process( + None, + now + Duration::from_secs(10) + Duration::from_millis(110), + ); + + let ping_before = server.stats().frame_rx.ping; + server.process_input( + pkt6.dgram().unwrap(), + now + Duration::from_secs(10) + Duration::from_millis(110), + ); + assert_eq!(server.stats().frame_rx.ping, ping_before + 1); +} + +#[test] +fn pto_initial() { + const INITIAL_PTO: Duration = Duration::from_millis(300); + let mut now = now(); + + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let pkt1 = client.process(None, now).dgram(); + assert!(pkt1.is_some()); + assert_eq!(pkt1.clone().unwrap().len(), PATH_MTU_V6); + + let delay = client.process(None, now).callback(); + assert_eq!(delay, INITIAL_PTO); + + // Resend initial after PTO. + now += delay; + let pkt2 = client.process(None, now).dgram(); + assert!(pkt2.is_some()); + assert_eq!(pkt2.unwrap().len(), PATH_MTU_V6); + + let pkt3 = client.process(None, now).dgram(); + assert!(pkt3.is_some()); + assert_eq!(pkt3.unwrap().len(), PATH_MTU_V6); + + let delay = client.process(None, now).callback(); + // PTO has doubled. + assert_eq!(delay, INITIAL_PTO * 2); + + // Server process the first initial pkt. 
+ let mut server = default_server(); + let out = server.process(pkt1, now).dgram(); + assert!(out.is_some()); + + // Client receives ack for the first initial packet as well a Handshake packet. + // After the handshake packet the initial keys and the crypto stream for the initial + // packet number space will be discarded. + // Here only an ack for the Handshake packet will be sent. + let out = client.process(out, now).dgram(); + assert!(out.is_some()); + + // We do not have PTO for the resent initial packet any more, but + // the Handshake PTO timer should be armed. As the RTT is apparently + // the same as the initial PTO value, and there is only one sample, + // the PTO will be 3x the INITIAL PTO. + let delay = client.process(None, now).callback(); + assert_eq!(delay, INITIAL_PTO * 3); +} + +/// A complete handshake that involves a PTO in the Handshake space. +#[test] +fn pto_handshake_complete() { + let mut now = now(); + // start handshake + let mut client = default_client(); + let mut server = default_server(); + + let pkt = client.process(None, now).dgram(); + let cb = client.process(None, now).callback(); + assert_eq!(cb, Duration::from_millis(300)); + + now += Duration::from_millis(10); + let pkt = server.process(pkt, now).dgram(); + + now += Duration::from_millis(10); + let pkt = client.process(pkt, now).dgram(); + + let cb = client.process(None, now).callback(); + // The client now has a single RTT estimate (20ms), so + // the handshake PTO is set based on that. 
+ assert_eq!(cb, Duration::from_millis(60)); + + now += Duration::from_millis(10); + let pkt = server.process(pkt, now).dgram(); + assert!(pkt.is_none()); + + now += Duration::from_millis(10); + client.authenticated(AuthenticationStatus::Ok, now); + + qdebug!("---- client: SH..FIN -> FIN"); + let pkt1 = client.process(None, now).dgram(); + assert!(pkt1.is_some()); + + let cb = client.process(None, now).callback(); + assert_eq!(cb, Duration::from_millis(60)); + + let mut pto_counts = [0; MAX_PTO_COUNTS]; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Wait for PTO to expire and resend a handshake packet + now += Duration::from_millis(60); + let pkt2 = client.process(None, now).dgram(); + assert!(pkt2.is_some()); + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Get a second PTO packet. + let pkt3 = client.process(None, now).dgram(); + assert!(pkt3.is_some()); + + // PTO has been doubled. + let cb = client.process(None, now).callback(); + assert_eq!(cb, Duration::from_millis(120)); + + // We still have only a single PTO + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + now += Duration::from_millis(10); + // Server receives the first packet. + // The output will be a Handshake packet with an ack and a app pn space packet with + // HANDSHAKE_DONE. + let pkt = server.process(pkt1, now).dgram(); + assert!(pkt.is_some()); + + // Check that the PTO packets (pkt2, pkt3) are Handshake packets. + // The server discarded the Handshake keys already, therefore they are dropped. 
+ let dropped_before1 = server.stats().dropped_rx; + let frames_before = server.stats().frame_rx.all; + server.process_input(pkt2.unwrap(), now); + assert_eq!(1, server.stats().dropped_rx - dropped_before1); + assert_eq!(server.stats().frame_rx.all, frames_before); + + let dropped_before2 = server.stats().dropped_rx; + server.process_input(pkt3.unwrap(), now); + assert_eq!(1, server.stats().dropped_rx - dropped_before2); + assert_eq!(server.stats().frame_rx.all, frames_before); + + now += Duration::from_millis(10); + // Client receive ack for the first packet + let cb = client.process(pkt, now).callback(); + // Ack delay timer for the packet carrying HANDSHAKE_DONE. + assert_eq!(cb, ACK_DELAY); + + // Let the ack timer expire. + now += cb; + let out = client.process(None, now).dgram(); + assert!(out.is_some()); + let cb = client.process(None, now).callback(); + // The handshake keys are discarded, but now we're back to the idle timeout. + // We don't send another PING because the handshake space is done and there + // is nothing to probe for. + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + assert_eq!(cb, LOCAL_IDLE_TIMEOUT - ACK_DELAY); +} + +/// Test that PTO in the Handshake space contains the right frames. 
+#[test] +fn pto_handshake_frames() { + let mut now = now(); + qdebug!("---- client: generate CH"); + let mut client = default_client(); + let pkt = client.process(None, now); + + now += Duration::from_millis(10); + qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); + let mut server = default_server(); + let pkt = server.process(pkt.dgram(), now); + + now += Duration::from_millis(10); + qdebug!("---- client: cert verification"); + let pkt = client.process(pkt.dgram(), now); + + now += Duration::from_millis(10); + let _ = server.process(pkt.dgram(), now); + + now += Duration::from_millis(10); + client.authenticated(AuthenticationStatus::Ok, now); + + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_send(2, b"zero").unwrap(), 4); + qdebug!("---- client: SH..FIN -> FIN and 1RTT packet"); + let pkt1 = client.process(None, now).dgram(); + assert!(pkt1.is_some()); + + // Get PTO timer. + let out = client.process(None, now); + assert_eq!(out, Output::Callback(Duration::from_millis(60))); + + // Wait for PTO to expire and resend a handshake packet. + now += Duration::from_millis(60); + let pkt2 = client.process(None, now).dgram(); + assert!(pkt2.is_some()); + + now += Duration::from_millis(10); + let crypto_before = server.stats().frame_rx.crypto; + server.process_input(pkt2.unwrap(), now); + assert_eq!(server.stats().frame_rx.crypto, crypto_before + 1) +} + +/// In the case that the Handshake takes too many packets, the server might +/// be stalled on the anti-amplification limit. If a Handshake ACK from the +/// client is lost, the client has to keep the PTO timer armed or the server +/// might be unable to send anything, causing a deadlock. 
+#[test] +fn handshake_ack_pto() { + const RTT: Duration = Duration::from_millis(10); + let mut now = now(); + let mut client = default_client(); + let mut server = default_server(); + // This is a greasing transport parameter, and large enough that the + // server needs to send two Handshake packets. + let big = TransportParameter::Bytes(vec![0; PATH_MTU_V6]); + server.set_local_tparam(0xce16, big).unwrap(); + + let c1 = client.process(None, now).dgram(); + + now += RTT / 2; + let s1 = server.process(c1, now).dgram(); + assert!(s1.is_some()); + let s2 = server.process(None, now).dgram(); + assert!(s2.is_some()); + + // Now let the client have the Initial, but drop the first coalesced Handshake packet. + now += RTT / 2; + let (initial, _) = split_datagram(&s1.unwrap()); + client.process_input(initial, now); + let c2 = client.process(s2, now).dgram(); + assert!(c2.is_some()); // This is an ACK. Drop it. + let delay = client.process(None, now).callback(); + assert_eq!(delay, RTT * 3); + + let mut pto_counts = [0; MAX_PTO_COUNTS]; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Wait for the PTO and ensure that the client generates a packet. + now += delay; + let c3 = client.process(None, now).dgram(); + assert!(c3.is_some()); + + now += RTT / 2; + let ping_before = server.stats().frame_rx.ping; + server.process_input(c3.unwrap(), now); + assert_eq!(server.stats().frame_rx.ping, ping_before + 1); + + pto_counts[0] = 1; + assert_eq!(client.stats.borrow().pto_counts, pto_counts); + + // Now complete the handshake as cheaply as possible. 
+ let dgram = server.process(None, now).dgram(); + client.process_input(dgram.unwrap(), now); + maybe_authenticate(&mut client); + let dgram = client.process(None, now).dgram(); + assert_eq!(*client.state(), State::Connected); + let dgram = server.process(dgram, now).dgram(); + assert_eq!(*server.state(), State::Confirmed); + client.process_input(dgram.unwrap(), now); + assert_eq!(*client.state(), State::Confirmed); + + assert_eq!(client.stats.borrow().pto_counts, pto_counts); +} + +#[test] +fn loss_recovery_crash() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + let now = now(); + + // The server sends something, but we will drop this. + let _ = send_something(&mut server, now); + + // Then send something again, but let it through. + let ack = send_and_receive(&mut server, &mut client, now); + assert!(ack.is_some()); + + // Have the server process the ACK. + let cb = server.process(ack, now).callback(); + assert!(cb > Duration::from_secs(0)); + + // Now we leap into the future. The server should regard the first + // packet as lost based on time alone. + let dgram = server.process(None, now + AT_LEAST_PTO).dgram(); + assert!(dgram.is_some()); + + // This crashes. + let _ = send_something(&mut server, now + AT_LEAST_PTO); +} + +// If we receive packets after the PTO timer has fired, we won't clear +// the PTO state, but we might need to acknowledge those packets. +// This shouldn't happen, but we found that some implementations do this. +#[test] +fn ack_after_pto() { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + let mut now = now(); + + // The client sends and is forced into a PTO. + let _ = send_something(&mut client, now); + + // Jump forward to the PTO and drain the PTO packets. 
+ now += AT_LEAST_PTO; + for _ in 0..PTO_PACKET_COUNT { + let dgram = client.process(None, now).dgram(); + assert!(dgram.is_some()); + } + assert!(client.process(None, now).dgram().is_none()); + + // The server now needs to send something that will cause the + // client to want to acknowledge it. A little out of order + // delivery is just the thing. + // Note: The server can't ACK anything here, but none of what + // the client has sent so far has been transferred. + let _ = send_something(&mut server, now); + let dgram = send_something(&mut server, now); + + // The client is now after a PTO, but if it receives something + // that demands acknowledgment, it will send just the ACK. + let ack = client.process(Some(dgram), now).dgram(); + assert!(ack.is_some()); + + // Make sure that the packet only contained an ACK frame. + let all_frames_before = server.stats().frame_rx.all; + let ack_before = server.stats().frame_rx.ack; + server.process_input(ack.unwrap(), now); + assert_eq!(server.stats().frame_rx.all, all_frames_before + 1); + assert_eq!(server.stats().frame_rx.ack, ack_before + 1); +} + +/// When we declare a packet as lost, we keep it around for a while for another loss period. +/// Those packets should not affect how we report the loss recovery timer. +/// As the loss recovery timer is based on RTT, we use that to drive the state. +#[test] +fn lost_but_kept_and_lr_timer() { + const RTT: Duration = Duration::from_secs(1); + let mut client = default_client(); + let mut server = default_server(); + let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + + // Two packets (p1, p2) are sent at around t=0. The first is lost. + let _p1 = send_something(&mut client, now); + let p2 = send_something(&mut client, now); + + // At t=RTT/2 the server receives the packet and ACKs it. + now += RTT / 2; + let ack = server.process(Some(p2), now).dgram(); + assert!(ack.is_some()); + // The client also sends another two packets (p3, p4), again losing the first. 
+ let _p3 = send_something(&mut client, now); + let p4 = send_something(&mut client, now); + + // At t=RTT the client receives the ACK and goes into timed loss recovery. + // The client doesn't call p1 lost at this stage, but it will soon. + now += RTT / 2; + let res = client.process(ack, now); + // The client should be on a loss recovery timer as p1 is missing. + let lr_timer = res.callback(); + // Loss recovery timer should be RTT/8, but only check for 0 or >=RTT/2. + assert_ne!(lr_timer, Duration::from_secs(0)); + assert!(lr_timer < (RTT / 2)); + // The server also receives and acknowledges p4, again sending an ACK. + let ack = server.process(Some(p4), now).dgram(); + assert!(ack.is_some()); + + // At t=RTT*3/2 the client should declare p1 to be lost. + now += RTT / 2; + // So the client will send the data from p1 again. + let res = client.process(None, now); + assert!(res.dgram().is_some()); + // When the client processes the ACK, it should engage the + // loss recovery timer for p3, not p1 (even though it still tracks p1). + let res = client.process(ack, now); + let lr_timer2 = res.callback(); + assert_eq!(lr_timer, lr_timer2); +} + +/// We should not be setting the loss recovery timer based on packets +/// that are sent prior to the largest acknowledged. +/// Testing this requires that we construct a case where one packet +/// number space causes the loss recovery timer to be engaged. At the same time, +/// there is a packet in another space that hasn't been acknowledged AND +/// that packet number space has not received acknowledgments for later packets. +#[test] +fn loss_time_past_largest_acked() { + const RTT: Duration = Duration::from_secs(10); + const INCR: Duration = Duration::from_millis(1); + let mut client = default_client(); + let mut server = default_server(); + + let mut now = now(); + + // Start the handshake. 
+ let c_in = client.process(None, now).dgram(); + now += RTT / 2; + let s_hs1 = server.process(c_in, now).dgram(); + + // Get some spare server handshake packets for the client to ACK. + // This involves a time machine, so be a little cautious. + // This test uses an RTT of 10s, but our server starts + // with a much lower RTT estimate, so the PTO at this point should + // be much smaller than an RTT and so the server shouldn't see + // time go backwards. + let s_pto = server.process(None, now).callback(); + assert_ne!(s_pto, Duration::from_secs(0)); + assert!(s_pto < RTT); + let s_hs2 = server.process(None, now + s_pto).dgram(); + assert!(s_hs2.is_some()); + let s_hs3 = server.process(None, now + s_pto).dgram(); + assert!(s_hs3.is_some()); + + // Get some Handshake packets from the client. + // We need one to be left unacknowledged before one that is acknowledged, + // so that the client engages the loss recovery timer. + // This is complicated by the fact that it is hard to cause the client + // to generate an ack-eliciting packet. For that, we use the Finished message. + // Reordering delivery ensures that the later packet is also acknowledged. + now += RTT / 2; + let c_hs1 = client.process(s_hs1, now).dgram(); + assert!(c_hs1.is_some()); // This comes first, so it's useless. + maybe_authenticate(&mut client); + let c_hs2 = client.process(None, now).dgram(); + assert!(c_hs2.is_some()); // This one will elicit an ACK. + + // Then we need the outstanding packet to be sent after the + // application data packet, so space these out a tiny bit. + let _p1 = send_something(&mut client, now + INCR); + let c_hs3 = client.process(s_hs2, now + (INCR * 2)).dgram(); + assert!(c_hs3.is_some()); // This will be left outstanding. + let c_hs4 = client.process(s_hs3, now + (INCR * 3)).dgram(); + assert!(c_hs4.is_some()); // This will be acknowledged. + + // Process c_hs2 and c_hs4, but skip c_hs3. + // Then get an ACK for the client. 
+ now += RTT / 2; + // Deliver c_hs4 first, but don't generate a packet. + server.process_input(c_hs4.unwrap(), now); + let s_ack = server.process(c_hs2, now).dgram(); + assert!(s_ack.is_some()); + // This includes an ACK, but it also includes HANDSHAKE_DONE, + // which we need to remove because that will cause the Handshake loss + // recovery state to be dropped. + let (s_hs_ack, _s_ap_ack) = split_datagram(&s_ack.unwrap()); + + // Now the client should start its loss recovery timer based on the ACK. + now += RTT / 2; + let c_ack = client.process(Some(s_hs_ack), now).dgram(); + assert!(c_ack.is_none()); + // The client should now have the loss recovery timer active. + let lr_time = client.process(None, now).callback(); + assert_ne!(lr_time, Duration::from_secs(0)); + assert!(lr_time < (RTT / 2)); + + // Skipping forward by the loss recovery timer should cause the client to + // mark packets as lost and retransmit, after which we should be on the PTO + // timer. + now += lr_time; + let delay = client.process(None, now).callback(); + assert_ne!(delay, Duration::from_secs(0)); + assert!(delay > lr_time); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/resumption.rs b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs new file mode 100644 index 0000000000..19421dded8 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs @@ -0,0 +1,182 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use super::{ + connect, connect_with_rtt, default_client, default_server, exchange_ticket, get_tokens, + send_something, AT_LEAST_PTO, +}; +use crate::addr_valid::{AddressValidation, ValidateAddress}; + +use std::cell::RefCell; +use std::rc::Rc; +use std::time::Duration; +use test_fixture::{self, assertions, now}; + +#[test] +fn resume() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = default_server(); + connect(&mut client, &mut server); + assert!(client.tls_info().unwrap().resumed()); + assert!(server.tls_info().unwrap().resumed()); +} + +#[test] +fn remember_smoothed_rtt() { + const RTT1: Duration = Duration::from_millis(130); + const RTT2: Duration = Duration::from_millis(70); + + let mut client = default_client(); + let mut server = default_server(); + + let now = connect_with_rtt(&mut client, &mut server, now(), RTT1); + assert_eq!(client.loss_recovery.rtt(), RTT1); + + let token = exchange_ticket(&mut client, &mut server, now); + let mut client = default_client(); + let mut server = default_server(); + client.enable_resumption(now, token).unwrap(); + assert_eq!( + client.loss_recovery.rtt(), + RTT1, + "client should remember previous RTT" + ); + + connect_with_rtt(&mut client, &mut server, now, RTT2); + assert_eq!( + client.loss_recovery.rtt(), + RTT2, + "previous RTT should be completely erased" + ); +} + +/// Check that a resumed connection uses a token on Initial packets. 
+#[test] +fn address_validation_token_resume() { + const RTT: Duration = Duration::from_millis(10); + + let mut client = default_client(); + let mut server = default_server(); + let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT); + + let token = exchange_ticket(&mut client, &mut server, now); + let mut client = default_client(); + client.enable_resumption(now, token).unwrap(); + let mut server = default_server(); + + // Grab an Initial packet from the client. + let dgram = client.process(None, now).dgram(); + assertions::assert_initial(dgram.as_ref().unwrap(), true); + + // Now try to complete the handshake after giving time for a client PTO. + now += AT_LEAST_PTO; + connect_with_rtt(&mut client, &mut server, now, RTT); + assert!(client.crypto.tls.info().unwrap().resumed()); + assert!(server.crypto.tls.info().unwrap().resumed()); +} + +fn can_resume(token: impl AsRef<[u8]>, initial_has_token: bool) { + let mut client = default_client(); + client.enable_resumption(now(), token).unwrap(); + let initial = client.process_output(now()).dgram(); + assertions::assert_initial(initial.as_ref().unwrap(), initial_has_token); +} + +#[test] +fn two_tickets_on_timer() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // Send two tickets and then bundle those into a packet. + server.send_ticket(now(), &[]).expect("send ticket1"); + server.send_ticket(now(), &[]).expect("send ticket2"); + let pkt = send_something(&mut server, now()); + + // process() will return an ack first + assert!(client.process(Some(pkt), now()).dgram().is_some()); + // We do not have a ResumptionToken event yet, because NEW_TOKEN was not sent. 
+ assert_eq!(get_tokens(&mut client).len(), 0); + + // We need to wait for release_resumption_token_timer to expire. The timer will be + // set to 3 * PTO + let mut now = now() + 3 * client.get_pto(); + let _ = client.process(None, now); + let mut recv_tokens = get_tokens(&mut client); + assert_eq!(recv_tokens.len(), 1); + let token1 = recv_tokens.pop().unwrap(); + // Wait for another 3 * PTO to get the next token. + now += 3 * client.get_pto(); + let _ = client.process(None, now); + let mut recv_tokens = get_tokens(&mut client); + assert_eq!(recv_tokens.len(), 1); + let token2 = recv_tokens.pop().unwrap(); + // Wait for 3 * PTO, but now there are no more tokens. + now += 3 * client.get_pto(); + let _ = client.process(None, now); + assert_eq!(get_tokens(&mut client).len(), 0); + assert_ne!(token1.as_ref(), token2.as_ref()); + + can_resume(&token1, false); + can_resume(&token2, false); +} + +#[test] +fn two_tickets_with_new_token() { + let mut client = default_client(); + let mut server = default_server(); + let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap(); + let validation = Rc::new(RefCell::new(validation)); + server.set_validation(Rc::clone(&validation)); + connect(&mut client, &mut server); + + // Send two tickets with tokens and then bundle those into a packet. + server.send_ticket(now(), &[]).expect("send ticket1"); + server.send_ticket(now(), &[]).expect("send ticket2"); + let pkt = send_something(&mut server, now()); + + client.process_input(pkt, now()); + let mut all_tokens = get_tokens(&mut client); + assert_eq!(all_tokens.len(), 2); + let token1 = all_tokens.pop().unwrap(); + let token2 = all_tokens.pop().unwrap(); + assert_ne!(token1.as_ref(), token2.as_ref()); + + can_resume(&token1, true); + can_resume(&token2, true); +} + +/// By disabling address validation, the server won't send `NEW_TOKEN`, but +/// we can take the session ticket still. 
+#[test] +fn take_token() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + server.send_ticket(now(), &[]).unwrap(); + let dgram = server.process(None, now()).dgram(); + client.process_input(dgram.unwrap(), now()); + + // There should be no ResumptionToken event here. + let tokens = get_tokens(&mut client); + assert_eq!(tokens.len(), 0); + + // But we should be able to get the token directly, and use it. + let token = client.take_resumption_token(now()).unwrap(); + can_resume(&token, false); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs new file mode 100644 index 0000000000..1c9eebaa17 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs @@ -0,0 +1,580 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use super::super::State; +use super::{ + connect, default_client, default_server, maybe_authenticate, send_something, + DEFAULT_STREAM_DATA, +}; +use crate::events::ConnectionEvent; +use crate::frame::StreamType; +use crate::recv_stream::RECV_BUFFER_SIZE; +use crate::send_stream::SEND_BUFFER_SIZE; +use crate::tparams::{self, TransportParameter}; +use crate::tracking::MAX_UNACKED_PKTS; +use crate::{Error, StreamId}; + +use neqo_common::{event::Provider, qdebug}; +use std::convert::TryFrom; +use test_fixture::now; + +#[test] +fn stream_create() { + let mut client = default_client(); + + let out = client.process(None, now()); + let mut server = default_server(); + let out = server.process(out.dgram(), now()); + + let out = client.process(out.dgram(), now()); + let _ = server.process(out.dgram(), now()); + assert!(maybe_authenticate(&mut client)); + let out = client.process(None, now()); + + // client now in State::Connected + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); + assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 6); + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0); + assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 4); + + let _ = server.process(out.dgram(), now()); + // server now in State::Connected + assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 3); + assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 7); + assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 1); + assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 5); +} + +#[test] +#[allow(clippy::cognitive_complexity)] +// tests stream send/recv after connection is established. 
+fn transfer() { + let mut client = default_client(); + let mut server = default_server(); + + qdebug!("---- client"); + let out = client.process(None, now()); + assert!(out.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + // -->> Initial[0]: CRYPTO[CH] + + qdebug!("---- server"); + let out = server.process(out.dgram(), now()); + assert!(out.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + // <<-- Initial[0]: CRYPTO[SH] ACK[0] + // <<-- Handshake[0]: CRYPTO[EE, CERT, CV, FIN] + + qdebug!("---- client"); + let out = client.process(out.dgram(), now()); + assert!(out.as_dgram_ref().is_some()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + // -->> Initial[1]: ACK[0] + + let out = server.process(out.dgram(), now()); + assert!(out.as_dgram_ref().is_none()); + + assert!(maybe_authenticate(&mut client)); + + qdebug!("---- client"); + let out = client.process(out.dgram(), now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(*client.state(), State::Connected); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + // -->> Handshake[0]: CRYPTO[FIN], ACK[0] + + qdebug!("---- server"); + let out = server.process(out.dgram(), now()); + assert!(out.as_dgram_ref().is_some()); + assert_eq!(*server.state(), State::Confirmed); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + // ACK and HANDSHAKE_DONE + // -->> nothing + + qdebug!("---- client"); + // Send + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[6; 100]).unwrap(); + client.stream_send(client_stream_id, &[7; 40]).unwrap(); + client.stream_send(client_stream_id, &[8; 4000]).unwrap(); + + // Send to another stream but some data after fin has been set + let client_stream_id2 = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id2, &[6; 60]).unwrap(); + client.stream_close_send(client_stream_id2).unwrap(); + client.stream_send(client_stream_id2, &[7; 
50]).unwrap_err(); + // Sending this much takes a few datagrams. + let mut datagrams = vec![]; + let mut out = client.process(out.dgram(), now()); + while let Some(d) = out.dgram() { + datagrams.push(d); + out = client.process(None, now()); + } + assert_eq!(datagrams.len(), 4); + assert_eq!(*client.state(), State::Confirmed); + + qdebug!("---- server"); + for (d_num, d) in datagrams.into_iter().enumerate() { + let out = server.process(Some(d), now()); + assert_eq!( + out.as_dgram_ref().is_some(), + (d_num + 1) % (MAX_UNACKED_PKTS + 1) == 0 + ); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + } + assert_eq!(*server.state(), State::Confirmed); + + let mut buf = vec![0; 4000]; + + let mut stream_ids = server.events().filter_map(|evt| match evt { + ConnectionEvent::NewStream { stream_id, .. } => Some(stream_id), + _ => None, + }); + let first_stream = stream_ids.next().expect("should have a new stream event"); + let second_stream = stream_ids + .next() + .expect("should have a second new stream event"); + assert!(stream_ids.next().is_none()); + let (received1, fin1) = server.stream_recv(first_stream.as_u64(), &mut buf).unwrap(); + assert_eq!(received1, 4000); + assert_eq!(fin1, false); + let (received2, fin2) = server.stream_recv(first_stream.as_u64(), &mut buf).unwrap(); + assert_eq!(received2, 140); + assert_eq!(fin2, false); + + let (received3, fin3) = server + .stream_recv(second_stream.as_u64(), &mut buf) + .unwrap(); + assert_eq!(received3, 60); + assert_eq!(fin3, true); +} + +#[test] +// Send fin even if a peer closes a remote bidi send stream before sending any data. +fn report_fin_when_stream_closed_wo_data() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. 
+ let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + let _ = server.process(out.dgram(), now()); + + assert_eq!(Ok(()), server.stream_close_send(stream_id)); + let out = server.process(None, now()); + let _ = client.process(out.dgram(), now()); + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..}); + assert!(client.events().any(stream_readable)); +} + +#[test] +fn max_data() { + const SMALL_MAX_DATA: usize = 16383; + + let mut client = default_client(); + let mut server = default_server(); + + server + .set_local_tparam( + tparams::INITIAL_MAX_DATA, + TransportParameter::Integer(u64::try_from(SMALL_MAX_DATA).unwrap()), + ) + .unwrap(); + + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected) + assert_eq!(stream_id, 2); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SMALL_MAX_DATA + ); + assert_eq!( + client + .stream_send(stream_id, &vec![b'a'; RECV_BUFFER_SIZE].into_boxed_slice()) + .unwrap(), + SMALL_MAX_DATA + ); + assert_eq!(client.events().count(), 0); + + assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0); + client + .send_streams + .get_mut(stream_id.into()) + .unwrap() + .mark_as_sent(0, 4096, false); + assert_eq!(client.events().count(), 0); + client + .send_streams + .get_mut(stream_id.into()) + .unwrap() + .mark_as_acked(0, 4096, false); + assert_eq!(client.events().count(), 0); + + assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0); + // no event because still limited by conn max data + assert_eq!(client.events().count(), 0); + + // Increase max data. 
Avail space now limited by stream credit + client.handle_max_data(100_000_000); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SEND_BUFFER_SIZE - SMALL_MAX_DATA + ); + + // Increase max stream data. Avail space now limited by tx buffer + client + .send_streams + .get_mut(stream_id.into()) + .unwrap() + .set_max_stream_data(100_000_000); + assert_eq!( + client.stream_avail_send_space(stream_id).unwrap(), + SEND_BUFFER_SIZE - SMALL_MAX_DATA + 4096 + ); + + let evts = client.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 1); + assert!(matches!(evts[0], ConnectionEvent::SendStreamWritable{..})); +} + +#[test] +// If we send a stop_sending to the peer, we should not accept more data from the peer. +fn do_not_accept_data_after_stop_sending() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + let _ = server.process(out.dgram(), now()); + + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..}); + assert!(server.events().any(stream_readable)); + + // Send one more packet from client. The packet should arrive after the server + // has already requested stop_sending. + client.stream_send(stream_id, &[0x00]).unwrap(); + let out_second_data_frame = client.process(None, now()); + // Call stop sending. + assert_eq!( + Ok(()), + server.stream_stop_sending(stream_id, Error::NoError.code()) + ); + + // Receive the second data frame. The frame should be ignored and + // DataReadable events shouldn't be posted. 
+ let out = server.process(out_second_data_frame.dgram(), now()); + assert!(!server.events().any(stream_readable)); + + let _ = client.process(out.dgram(), now()); + assert_eq!( + Err(Error::FinalSizeError), + client.stream_send(stream_id, &[0x00]) + ); +} + +#[test] +// The server sends stop_sending while the client simultaneously sends a reset. +fn simultaneous_stop_sending_and_reset() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + let _ = server.process(out.dgram(), now()); + + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..}); + assert!(server.events().any(stream_readable)); + + // The client resets the stream. The packet with reset should arrive after the server + // has already requested stop_sending. + client + .stream_reset_send(stream_id, Error::NoError.code()) + .unwrap(); + let out_reset_frame = client.process(None, now()); + // Call stop sending. + assert_eq!( + Ok(()), + server.stream_stop_sending(stream_id, Error::NoError.code()) + ); + + // Receive the datagram carrying the reset. No RecvStreamReadable + // event should be posted for it. + let out = server.process(out_reset_frame.dgram(), now()); + assert!(!server.events().any(stream_readable)); + + // The client gets the STOP_SENDING frame. + let _ = client.process(out.dgram(), now()); + assert_eq!( + Err(Error::InvalidStreamId), + client.stream_send(stream_id, &[0x00]) + ); +} + +#[test] +// Stream data that reaches the server ahead of the client's `client_fin` datagram +// produces no response and leaves the server in State::Handshaking. +fn client_fin_reorder() { + let mut client = default_client(); + let mut server = default_server(); + + // Send ClientHello. 
+ let client_hs = client.process(None, now()); + assert!(client_hs.as_dgram_ref().is_some()); + + let server_hs = server.process(client_hs.dgram(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc... + + let client_ack = client.process(server_hs.dgram(), now()); + assert!(client_ack.as_dgram_ref().is_some()); + + let server_out = server.process(client_ack.dgram(), now()); + assert!(server_out.as_dgram_ref().is_none()); + + assert!(maybe_authenticate(&mut client)); + assert_eq!(*client.state(), State::Connected); + + let client_fin = client.process(None, now()); + assert!(client_fin.as_dgram_ref().is_some()); + + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[1, 2, 3]).unwrap(); + let client_stream_data = client.process(None, now()); + assert!(client_stream_data.as_dgram_ref().is_some()); + + // Now the stream data is delivered to the server before client_fin. + let server_out = server.process(client_stream_data.dgram(), now()); + assert!(server_out.as_dgram_ref().is_none()); // the packet will be discarded + + assert_eq!(*server.state(), State::Handshaking); + let server_out = server.process(client_fin.dgram(), now()); + assert!(server_out.as_dgram_ref().is_some()); +} + +#[test] +fn after_fin_is_read_conn_events_for_stream_should_be_removed() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let id = server.stream_create(StreamType::BiDi).unwrap(); + server.stream_send(id, &[6; 10]).unwrap(); + server.stream_close_send(id).unwrap(); + let out = server.process(None, now()).dgram(); + assert!(out.is_some()); + + let _ = client.process(out, now()); + + // Read from the stream before checking connection events. + let mut buf = vec![0; 4000]; + let (_, fin) = client.stream_recv(id, &mut buf).unwrap(); + assert_eq!(fin, true); + + // Make sure we do not have RecvStreamReadable events for the stream when fin has been read. 
+ let readable_stream_evt = + |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id); + assert!(!client.events().any(readable_stream_evt)); +} + +#[test] +fn after_stream_stop_sending_is_called_conn_events_for_stream_should_be_removed() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let id = server.stream_create(StreamType::BiDi).unwrap(); + server.stream_send(id, &[6; 10]).unwrap(); + server.stream_close_send(id).unwrap(); + let out = server.process(None, now()).dgram(); + assert!(out.is_some()); + + let _ = client.process(out, now()); + + // Send stop sending. + client + .stream_stop_sending(id, Error::NoError.code()) + .unwrap(); + + // Make sure we do not have RecvStreamReadable events for the stream after stream_stop_sending + // has been called. + let readable_stream_evt = + |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id); + assert!(!client.events().any(readable_stream_evt)); +} + +#[test] +fn stream_data_blocked_generates_max_stream_data() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let now = now(); + + // Send some data and include STREAM_DATA_BLOCKED with any value. + let stream_id = server.stream_create(StreamType::UniDi).unwrap(); + let _ = server.stream_send(stream_id, DEFAULT_STREAM_DATA).unwrap(); + server.flow_mgr.borrow_mut().stream_data_blocked( + StreamId::from(stream_id), + u64::try_from(DEFAULT_STREAM_DATA.len()).unwrap(), + ); + + let dgram = server.process(None, now).dgram(); + assert!(dgram.is_some()); + + let sdb_before = client.stats().frame_rx.stream_data_blocked; + client.process_input(dgram.unwrap(), now); + assert_eq!(client.stats().frame_rx.stream_data_blocked, sdb_before + 1); + + // Consume the data. 
+ let mut buf = [0; 10]; + let (count, end) = client.stream_recv(stream_id, &mut buf[..]).unwrap(); + assert_eq!(count, DEFAULT_STREAM_DATA.len()); + assert!(!end); + + let dgram = client.process_output(now).dgram(); + + // Client should have sent a MAX_STREAM_DATA frame with just a small increase + // on the default window size. + let msd_before = server.stats().frame_rx.max_stream_data; + server.process_input(dgram.unwrap(), now); + assert_eq!(server.stats().frame_rx.max_stream_data, msd_before + 1); + + // Test that more space is available, but that it is small. + let mut written = 0; + loop { + const LARGE_BUFFER: &[u8] = &[0; 1024]; + let amount = server.stream_send(stream_id, LARGE_BUFFER).unwrap(); + if amount == 0 { + break; + } + written += amount; + } + assert_eq!(written, RECV_BUFFER_SIZE - DEFAULT_STREAM_DATA.len()); +} + +/// See <https://github.com/mozilla/neqo/issues/871> +#[test] +fn max_streams_after_bidi_closed() { + const REQUEST: &[u8] = b"ping"; + const RESPONSE: &[u8] = b"pong"; + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + while client.stream_create(StreamType::BiDi).is_ok() { + // Exhaust the stream limit. + } + // Write on the one stream and send that out. + let _ = client.stream_send(stream_id, REQUEST).unwrap(); + client.stream_close_send(stream_id).unwrap(); + let dgram = client.process(None, now()).dgram(); + + // Now handle the stream and send an incomplete response. + server.process_input(dgram.unwrap(), now()); + server.stream_send(stream_id, RESPONSE).unwrap(); + let dgram = server.process_output(now()).dgram(); + + // The server shouldn't have released more stream credit. + client.process_input(dgram.unwrap(), now()); + let e = client.stream_create(StreamType::BiDi).unwrap_err(); + assert!(matches!(e, Error::StreamLimitError)); + + // Closing the stream isn't enough. 
+ server.stream_close_send(stream_id).unwrap(); + let dgram = server.process_output(now()).dgram(); + client.process_input(dgram.unwrap(), now()); + assert!(client.stream_create(StreamType::BiDi).is_err()); + + // The server needs to see an acknowledgment from the client for its + // response AND the server has to read all of the request. + // Read first. + let mut buf = [0; REQUEST.len()]; + let (count, fin) = server.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(&buf[..count], REQUEST); + assert!(fin); + + // We need an ACK from the client now, but that isn't guaranteed, + // so give the client one more packet just in case. + let dgram = send_something(&mut server, now()); + client.process_input(dgram, now()); + + // Now get the client to send the ACK and have the server handle that. + let dgram = send_something(&mut client, now()); + let dgram = server.process(Some(dgram), now()).dgram(); + client.process_input(dgram.unwrap(), now()); + assert!(client.stream_create(StreamType::BiDi).is_ok()); + assert!(client.stream_create(StreamType::BiDi).is_err()); +} + +#[test] +fn no_dupdata_readable_events() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + let _ = server.process(out.dgram(), now()); + + // We have a RecvStreamReadable event. + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..}); + assert!(server.events().any(stream_readable)); + + // Send one more data frame from client. The previous stream data has not been read yet, + // therefore there should not be a new RecvStreamReadable event. 
+ client.stream_send(stream_id, &[0x00]).unwrap(); + let out_second_data_frame = client.process(None, now()); + let _ = server.process(out_second_data_frame.dgram(), now()); + assert!(!server.events().any(stream_readable)); + + // One more frame with a fin will not produce a new DataReadable event, because the + // previous stream data has not been read yet. + client.stream_send(stream_id, &[0x00]).unwrap(); + client.stream_close_send(stream_id).unwrap(); + let out_third_data_frame = client.process(None, now()); + let _ = server.process(out_third_data_frame.dgram(), now()); + assert!(!server.events().any(stream_readable)); +} + +#[test] +fn no_dupdata_readable_events_empty_last_frame() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + // create a stream + let stream_id = client.stream_create(StreamType::BiDi).unwrap(); + client.stream_send(stream_id, &[0x00]).unwrap(); + let out = client.process(None, now()); + let _ = server.process(out.dgram(), now()); + + // We have a data_readable event. + let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..}); + assert!(server.events().any(stream_readable)); + + // An empty frame with a fin will not produce a new DataReadable event, because + // the previous stream data has not been read yet. 
+ client.stream_close_send(stream_id).unwrap(); + let out_second_data_frame = client.process(None, now()); + let _ = server.process(out_second_data_frame.dgram(), now()); + assert!(!server.events().any(stream_readable)); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs new file mode 100644 index 0000000000..f7d37c7864 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs @@ -0,0 +1,201 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use super::super::{ConnectionError, Output, State}; +use super::{default_client, default_server}; +use crate::packet::PACKET_BIT_LONG; +use crate::{Error, QuicVersion}; + +use neqo_common::{Datagram, Decoder, Encoder}; +use std::time::Duration; +use test_fixture::{self, loopback, now}; + +// The expected PTO duration after the first Initial is sent. +const INITIAL_PTO: Duration = Duration::from_millis(300); + +#[test] +fn unknown_version() { + let mut client = default_client(); + // Start the handshake. 
+ let _ = client.process(None, now()).dgram(); + + let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a]; + unknown_version_packet.resize(1200, 0x0); + let _ = client.process( + Some(Datagram::new( + loopback(), + loopback(), + unknown_version_packet, + )), + now(), + ); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn server_receive_unknown_first_packet() { + let mut server = default_server(); + + let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a]; + unknown_version_packet.resize(1200, 0x0); + + assert_eq!( + server.process( + Some(Datagram::new( + loopback(), + loopback(), + unknown_version_packet, + )), + now(), + ), + Output::None + ); + + assert_eq!(1, server.stats().dropped_rx); +} + +fn create_vn(initial_pkt: &[u8], versions: &[u32]) -> Vec<u8> { + let mut dec = Decoder::from(&initial_pkt[5..]); // Skip past version. + let dst_cid = dec.decode_vec(1).expect("client DCID"); + let src_cid = dec.decode_vec(1).expect("client SCID"); + + let mut encoder = Encoder::default(); + encoder.encode_byte(PACKET_BIT_LONG); + encoder.encode(&[0; 4]); // Zero version == VN. + encoder.encode_vec(1, dst_cid); + encoder.encode_vec(1, src_cid); + + for v in versions { + encoder.encode_uint(4, *v); + } + encoder.into() +} + +#[test] +fn version_negotiation_current_version() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn( + &initial_pkt, + &[0x1a1a_1a1a, QuicVersion::default().as_u32()], + ); + + let dgram = Datagram::new(loopback(), loopback(), vn); + let delay = client.process(Some(dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_only_reserved() { + let mut client = default_client(); + // Start the handshake. 
+ let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]); + + let dgram = Datagram::new(loopback(), loopback(), vn); + assert_eq!(client.process(Some(dgram), now()), Output::None); + match client.state() { + State::Closed(err) => { + assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)) + } + _ => panic!("Invalid client state"), + } +} + +#[test] +fn version_negotiation_corrupted() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]); + + let dgram = Datagram::new(loopback(), loopback(), &vn[..vn.len() - 1]); + let delay = client.process(Some(dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_empty() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[]); + + let dgram = Datagram::new(loopback(), loopback(), vn); + let delay = client.process(Some(dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} + +#[test] +fn version_negotiation_not_supported() { + let mut client = default_client(); + // Start the handshake. 
+ let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]); + + assert_eq!( + client.process(Some(Datagram::new(loopback(), loopback(), vn)), now(),), + Output::None + ); + match client.state() { + State::Closed(err) => { + assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)) + } + _ => panic!("Invalid client state"), + } +} + +#[test] +fn version_negotiation_bad_cid() { + let mut client = default_client(); + // Start the handshake. + let initial_pkt = client + .process(None, now()) + .dgram() + .expect("a datagram") + .to_vec(); + + let mut vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]); + vn[6] ^= 0xc4; + + let dgram = Datagram::new(loopback(), loopback(), vn); + let delay = client.process(Some(dgram), now()).callback(); + assert_eq!(delay, INITIAL_PTO); + assert_eq!(*client.state(), State::WaitInitial); + assert_eq!(1, client.stats().dropped_rx); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs new file mode 100644 index 0000000000..4ccffe4203 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs @@ -0,0 +1,193 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use super::super::{Connection, FixedConnectionIdManager}; +use super::{connect, default_client, default_server, exchange_ticket}; +use crate::events::ConnectionEvent; +use crate::frame::StreamType; +use crate::{ConnectionParameters, Error}; + +use neqo_common::event::Provider; +use neqo_crypto::{AllowZeroRtt, AntiReplay}; +use std::cell::RefCell; +use std::rc::Rc; +use test_fixture::{self, assertions, now}; + +#[test] +fn zero_rtt_negotiate() { + // Note that the two servers in this test will get different anti-replay filters. + // That's OK because we aren't testing anti-replay. + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = default_server(); + connect(&mut client, &mut server); + assert!(client.tls_info().unwrap().early_data_accepted()); + assert!(server.tls_info().unwrap().early_data_accepted()); +} + +#[test] +fn zero_rtt_send_recv() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = default_server(); + + // Send ClientHello. + let client_hs = client.process(None, now()); + assert!(client_hs.as_dgram_ref().is_some()); + + // Now send a 0-RTT packet. + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[1, 2, 3]).unwrap(); + let client_0rtt = client.process(None, now()); + assert!(client_0rtt.as_dgram_ref().is_some()); + // 0-RTT packets on their own shouldn't be padded to 1200. 
+ assert!(client_0rtt.as_dgram_ref().unwrap().len() < 1200); + + let server_hs = server.process(client_hs.dgram(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc... + let server_process_0rtt = server.process(client_0rtt.dgram(), now()); + assert!(server_process_0rtt.as_dgram_ref().is_none()); + + let server_stream_id = server + .events() + .find_map(|evt| match evt { + ConnectionEvent::NewStream { stream_id, .. } => Some(stream_id), + _ => None, + }) + .expect("should have received a new stream event"); + assert_eq!(client_stream_id, server_stream_id.as_u64()); +} + +#[test] +fn zero_rtt_send_coalesce() { + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = default_server(); + + // Write 0-RTT before generating any packets. + // This should result in a datagram that coalesces Initial and 0-RTT. + let client_stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(client_stream_id, &[1, 2, 3]).unwrap(); + let client_0rtt = client.process(None, now()); + assert!(client_0rtt.as_dgram_ref().is_some()); + + assertions::assert_coalesced_0rtt(&client_0rtt.as_dgram_ref().unwrap()[..]); + + let server_hs = server.process(client_0rtt.dgram(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc... 
+ + let server_stream_id = server + .events() + .find_map(|evt| match evt { + ConnectionEvent::NewStream { stream_id } => Some(stream_id), + _ => None, + }) + .expect("should have received a new stream event"); + assert_eq!(client_stream_id, server_stream_id.as_u64()); +} + +#[test] +fn zero_rtt_before_resumption_token() { + let mut client = default_client(); + assert!(client.stream_create(StreamType::BiDi).is_err()); +} + +#[test] +fn zero_rtt_send_reject() { + const MESSAGE: &[u8] = &[1, 2, 3]; + + let mut client = default_client(); + let mut server = default_server(); + connect(&mut client, &mut server); + + let token = exchange_ticket(&mut client, &mut server, now()); + let mut client = default_client(); + client + .enable_resumption(now(), token) + .expect("should set token"); + let mut server = Connection::new_server( + test_fixture::DEFAULT_KEYS, + test_fixture::DEFAULT_ALPN, + Rc::new(RefCell::new(FixedConnectionIdManager::new(10))), + &ConnectionParameters::default(), + ) + .unwrap(); + // Using a freshly initialized anti-replay context + // should result in the server rejecting 0-RTT. + let ar = + AntiReplay::new(now(), test_fixture::ANTI_REPLAY_WINDOW, 1, 3).expect("setup anti-replay"); + server + .server_enable_0rtt(&ar, AllowZeroRtt {}) + .expect("enable 0-RTT"); + + // Send ClientHello. + let client_hs = client.process(None, now()); + assert!(client_hs.as_dgram_ref().is_some()); + + // Write some data on the client. + let stream_id = client.stream_create(StreamType::UniDi).unwrap(); + client.stream_send(stream_id, MESSAGE).unwrap(); + let client_0rtt = client.process(None, now()); + assert!(client_0rtt.as_dgram_ref().is_some()); + + let server_hs = server.process(client_hs.dgram(), now()); + assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc... + let server_ignored = server.process(client_0rtt.dgram(), now()); + assert!(server_ignored.as_dgram_ref().is_none()); + + // The server shouldn't receive that 0-RTT data. 
+ let recvd_stream_evt = |e| matches!(e, ConnectionEvent::NewStream { .. }); + assert!(!server.events().any(recvd_stream_evt)); + + // Client should get a rejection. + let client_fin = client.process(server_hs.dgram(), now()); + let recvd_0rtt_reject = |e| e == ConnectionEvent::ZeroRttRejected; + assert!(client.events().any(recvd_0rtt_reject)); + + // Server consume client_fin + let server_ack = server.process(client_fin.dgram(), now()); + assert!(server_ack.as_dgram_ref().is_some()); + let client_out = client.process(server_ack.dgram(), now()); + assert!(client_out.as_dgram_ref().is_none()); + + // ...and the client stream should be gone. + let res = client.stream_send(stream_id, MESSAGE); + assert!(res.is_err()); + assert_eq!(res.unwrap_err(), Error::InvalidStreamId); + + // Open a new stream and send data. StreamId should start with 0. + let stream_id_after_reject = client.stream_create(StreamType::UniDi).unwrap(); + assert_eq!(stream_id, stream_id_after_reject); + client.stream_send(stream_id_after_reject, MESSAGE).unwrap(); + let client_after_reject = client.process(None, now()); + assert!(client_after_reject.as_dgram_ref().is_some()); + + // The server should receive new stream + let server_out = server.process(client_after_reject.dgram(), now()); + assert!(server_out.as_dgram_ref().is_none()); // suppress the ack + assert!(server.events().any(recvd_stream_evt)); +} diff --git a/third_party/rust/neqo-transport/src/crypto.rs b/third_party/rust/neqo-transport/src/crypto.rs new file mode 100644 index 0000000000..bed669f003 --- /dev/null +++ b/third_party/rust/neqo-transport/src/crypto.rs @@ -0,0 +1,1293 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +use std::cell::RefCell; +use std::cmp::{max, min}; +use std::convert::TryFrom; +use std::mem; +use std::ops::{Index, IndexMut, Range}; +use std::rc::Rc; +use std::time::Instant; + +use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role}; +use neqo_crypto::{ + aead::Aead, hkdf, hp::HpKey, Agent, AntiReplay, Cipher, Epoch, HandshakeState, Record, + RecordList, ResumptionToken, SymKey, ZeroRttChecker, TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE, + TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT, + TLS_VERSION_1_3, +}; + +use crate::packet::{PacketBuilder, PacketNumber, QuicVersion}; +use crate::recovery::RecoveryToken; +use crate::recv_stream::RxStreamOrderer; +use crate::send_stream::TxBuffer; +use crate::tparams::{TpZeroRttChecker, TransportParameters, TransportParametersHandler}; +use crate::tracking::PNSpace; +use crate::{Error, Res}; + +const MAX_AUTH_TAG: usize = 32; +/// The number of invocations remaining on a write cipher before we try +/// to update keys. This has to be much smaller than the number returned +/// by `CryptoDxState::limit` or updates will happen too often. As we don't +/// need to ask permission to update, this can be quite small. +pub(crate) const UPDATE_WRITE_KEYS_AT: PacketNumber = 100; + +// This is a testing kludge that allows for overwriting the number of +// invocations of the next cipher to operate. With this, it is possible +// to test what happens when the number of invocations reaches 0, or +// when it hits `UPDATE_WRITE_KEYS_AT` and an automatic update should occur. +// This is a little crude, but it saves a lot of plumbing. 
+#[cfg(test)] +thread_local!(pub(crate) static OVERWRITE_INVOCATIONS: RefCell<Option<PacketNumber>> = RefCell::default()); + +#[derive(Debug)] +pub struct Crypto { + pub(crate) tls: Agent, + pub(crate) streams: CryptoStreams, + pub(crate) states: CryptoStates, +} + +type TpHandler = Rc<RefCell<TransportParametersHandler>>; + +impl Crypto { + pub fn new(mut agent: Agent, protocols: &[impl AsRef<str>], tphandler: TpHandler) -> Res<Self> { + agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; + agent.set_ciphers(&[ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, + ])?; + agent.set_alpn(protocols)?; + agent.disable_end_of_early_data()?; + // Always enable 0-RTT on the client, but the server needs + // more configuration passed to server_enable_0rtt. + if let Agent::Client(c) = &mut agent { + c.enable_0rtt()?; + } + agent.extension_handler(0xffa5, tphandler)?; + Ok(Self { + tls: agent, + streams: Default::default(), + states: Default::default(), + }) + } + + pub fn server_enable_0rtt( + &mut self, + tphandler: TpHandler, + anti_replay: &AntiReplay, + zero_rtt_checker: impl ZeroRttChecker + 'static, + ) -> Res<()> { + if let Agent::Server(s) = &mut self.tls { + Ok(s.enable_0rtt( + anti_replay, + 0xffff_ffff, + TpZeroRttChecker::wrap(tphandler, zero_rtt_checker), + )?) + } else { + panic!("not a server"); + } + } + + pub fn handshake( + &mut self, + now: Instant, + space: PNSpace, + data: Option<&[u8]>, + ) -> Res<&HandshakeState> { + let input = data.map(|d| { + qtrace!("Handshake record received {:0x?} ", d); + let epoch = match space { + PNSpace::Initial => TLS_EPOCH_INITIAL, + PNSpace::Handshake => TLS_EPOCH_HANDSHAKE, + // Our epoch progresses forward, but the TLS epoch is fixed to 3. 
+ PNSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA, + }; + Record { + ct: TLS_CT_HANDSHAKE, + epoch, + data: d.to_vec(), + } + }); + + match self.tls.handshake_raw(now, input) { + Ok(output) => { + self.buffer_records(output)?; + Ok(self.tls.state()) + } + Err(e) => { + qinfo!("Handshake failed"); + Err(match self.tls.alert() { + Some(a) => Error::CryptoAlert(*a), + _ => Error::CryptoError(e), + }) + } + } + } + + /// Enable 0-RTT and return `true` if it is enabled successfully. + pub fn enable_0rtt(&mut self, role: Role) -> Res<bool> { + let info = self.tls.preinfo()?; + // `info.early_data()` returns false for a server, + // so use `early_data_cipher()` to tell if 0-RTT is enabled. + let cipher = info.early_data_cipher(); + if cipher.is_none() { + return Ok(false); + } + let (dir, secret) = match role { + Role::Client => ( + CryptoDxDirection::Write, + self.tls.write_secret(TLS_EPOCH_ZERO_RTT), + ), + Role::Server => ( + CryptoDxDirection::Read, + self.tls.read_secret(TLS_EPOCH_ZERO_RTT), + ), + }; + let secret = secret.ok_or(Error::InternalError)?; + self.states.set_0rtt_keys(dir, &secret, cipher.unwrap()); + Ok(true) + } + + /// Returns true if new handshake keys were installed. + pub fn install_keys(&mut self, role: Role) -> Res<bool> { + if !self.tls.state().is_final() { + let installed_hs = self.install_handshake_keys()?; + if role == Role::Server { + self.maybe_install_application_write_key()?; + } + Ok(installed_hs) + } else { + Ok(false) + } + } + + fn install_handshake_keys(&mut self) -> Res<bool> { + qtrace!([self], "Attempt to install handshake keys"); + let write_secret = if let Some(secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) { + secret + } else { + // No keys is fine. 
+ return Ok(false); + }; + let read_secret = self + .tls + .read_secret(TLS_EPOCH_HANDSHAKE) + .ok_or(Error::InternalError)?; + let cipher = match self.tls.info() { + None => self.tls.preinfo()?.cipher_suite(), + Some(info) => Some(info.cipher_suite()), + } + .ok_or(Error::InternalError)?; + self.states + .set_handshake_keys(&write_secret, &read_secret, cipher); + qdebug!([self], "Handshake keys installed"); + Ok(true) + } + + fn maybe_install_application_write_key(&mut self) -> Res<()> { + qtrace!([self], "Attempt to install application write key"); + if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) { + self.states.set_application_write_key(secret)?; + qdebug!([self], "Application write key installed"); + } + Ok(()) + } + + pub fn install_application_keys(&mut self, expire_0rtt: Instant) -> Res<()> { + self.maybe_install_application_write_key()?; + // The write key might have been installed earlier, but it should + // always be installed now. + debug_assert!(self.states.app_write.is_some()); + let read_secret = self + .tls + .read_secret(TLS_EPOCH_APPLICATION_DATA) + .ok_or(Error::InternalError)?; + self.states + .set_application_read_key(read_secret, expire_0rtt)?; + qdebug!([self], "application read keys installed"); + Ok(()) + } + + /// Buffer crypto records for sending. 
+ pub fn buffer_records(&mut self, records: RecordList) -> Res<()> { + for r in records { + if r.ct != TLS_CT_HANDSHAKE { + return Err(Error::ProtocolViolation); + } + qtrace!([self], "Adding CRYPTO data {:?}", r); + self.streams.send(PNSpace::from(r.epoch), &r.data); + } + Ok(()) + } + + pub fn acked(&mut self, token: &CryptoRecoveryToken) { + qinfo!( + "Acked crypto frame space={} offset={} length={}", + token.space, + token.offset, + token.length + ); + self.streams.acked(token); + } + + pub fn lost(&mut self, token: &CryptoRecoveryToken) { + qinfo!( + "Lost crypto frame space={} offset={} length={}", + token.space, + token.offset, + token.length + ); + self.streams.lost(token); + } + + /// Mark any outstanding frames in the indicated space as "lost" so + /// that they can be sent again. + pub fn resend_unacked(&mut self, space: PNSpace) { + self.streams.resend_unacked(space); + } + + /// Discard state for a packet number space and return true + /// if something was discarded. + pub fn discard(&mut self, space: PNSpace) -> bool { + self.streams.discard(space); + self.states.discard(space) + } + + pub fn create_resumption_token( + &mut self, + new_token: Option<&[u8]>, + tps: &TransportParameters, + rtt: u64, + ) -> Option<ResumptionToken> { + if let Agent::Client(ref mut c) = self.tls { + if let Some(ref t) = c.resumption_token() { + qtrace!("TLS token {}", hex(t.as_ref())); + let mut enc = Encoder::default(); + enc.encode_varint(rtt); + enc.encode_vvec_with(|enc_inner| { + tps.encode(enc_inner); + }); + enc.encode_vvec(new_token.unwrap_or(&[])); + enc.encode(t.as_ref()); + qinfo!("resumption token {}", hex_snip_middle(&enc[..])); + Some(ResumptionToken::new(enc.into(), t.expiration_time())) + } else { + None + } + } else { + unreachable!("It is a server."); + } + } + + pub fn has_resumption_token(&self) -> bool { + if let Agent::Client(c) = &self.tls { + c.has_resumption_token() + } else { + unreachable!("It is a server."); + } + } +} + +impl 
::std::fmt::Display for Crypto { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Crypto") + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CryptoDxDirection { + Read, + Write, +} + +#[derive(Debug)] +pub struct CryptoDxState { + direction: CryptoDxDirection, + /// The epoch of this crypto state. This initially tracks TLS epochs + /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application. + /// But we don't need to keep that, and QUIC isn't limited in how + /// many times keys can be updated, so we don't use `u16` for this. + epoch: usize, + aead: Aead, + hpkey: HpKey, + /// This tracks the range of packet numbers that have been seen. This allows + /// for verifying that packet numbers before a key update are strictly lower + /// than packet numbers after a key update. + used_pn: Range<PacketNumber>, + /// This is the minimum packet number that is allowed. + min_pn: PacketNumber, + /// The total number of operations that are remaining before the keys + /// become exhausted and can't be used any more. + invocations: PacketNumber, +} + +impl CryptoDxState { + #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45. + #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range. 
+ pub fn new( + direction: CryptoDxDirection, + epoch: Epoch, + secret: &SymKey, + cipher: Cipher, + ) -> Self { + qinfo!( + "Making {:?} {} CryptoDxState, cipher={}", + direction, + epoch, + cipher + ); + Self { + direction, + epoch: usize::from(epoch), + aead: Aead::new(TLS_VERSION_1_3, cipher, secret, "quic ").unwrap(), + hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, "quic hp").unwrap(), + used_pn: 0..0, + min_pn: 0, + invocations: Self::limit(direction, cipher), + } + } + + pub fn new_initial( + quic_version: QuicVersion, + direction: CryptoDxDirection, + label: &str, + dcid: &[u8], + ) -> Self { + qtrace!("new_initial for {:?}", quic_version); + const INITIAL_SALT_27: &[u8] = &[ + 0xc3, 0xee, 0xf7, 0x12, 0xc7, 0x2e, 0xbb, 0x5a, 0x11, 0xa7, 0xd2, 0x43, 0x2b, 0xb4, + 0x63, 0x65, 0xbe, 0xf9, 0xf5, 0x02, + ]; + const INITIAL_SALT_29_32: &[u8] = &[ + 0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, + 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99, + ]; + let salt = match quic_version { + QuicVersion::Draft27 | QuicVersion::Draft28 => INITIAL_SALT_27, + QuicVersion::Draft29 + | QuicVersion::Draft30 + | QuicVersion::Draft31 + | QuicVersion::Draft32 => INITIAL_SALT_29_32, + }; + let cipher = TLS_AES_128_GCM_SHA256; + let initial_secret = hkdf::extract( + TLS_VERSION_1_3, + cipher, + Some( + hkdf::import_key(TLS_VERSION_1_3, cipher, salt) + .as_ref() + .unwrap(), + ), + hkdf::import_key(TLS_VERSION_1_3, cipher, dcid) + .as_ref() + .unwrap(), + ) + .unwrap(); + + let secret = + hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap(); + + Self::new(direction, TLS_EPOCH_INITIAL, &secret, cipher) + } + + /// Determine the confidentiality and integrity limits for the cipher. + fn limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber { + match direction { + // This uses the smaller limits for 2^16 byte packets + // as we don't control incoming packet size. 
+ CryptoDxDirection::Read => match cipher { + TLS_AES_128_GCM_SHA256 => 1 << 52, + TLS_AES_256_GCM_SHA384 => PacketNumber::MAX, + TLS_CHACHA20_POLY1305_SHA256 => 1 << 36, + _ => unreachable!(), + }, + // This uses the larger limits for 2^11 byte packets. + CryptoDxDirection::Write => match cipher { + TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => 1 << 28, + TLS_CHACHA20_POLY1305_SHA256 => PacketNumber::MAX, + _ => unreachable!(), + }, + } + } + + fn invoked(&mut self) -> Res<()> { + #[cfg(test)] + OVERWRITE_INVOCATIONS.with(|v| { + if let Some(i) = v.borrow_mut().take() { + neqo_common::qwarn!("Setting {:?} invocations to {}", self.direction, i); + self.invocations = i; + } + }); + self.invocations = self + .invocations + .checked_sub(1) + .ok_or(Error::KeysExhausted)?; + Ok(()) + } + + /// Determine whether we should initiate a key update. + pub fn should_update(&self) -> bool { + // There is no point in updating read keys as the limit is global. + debug_assert_eq!(self.direction, CryptoDxDirection::Write); + self.invocations <= UPDATE_WRITE_KEYS_AT + } + + pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self { + let pn = self.next_pn(); + // We count invocations of each write key just for that key, but all + // attempts to invocations to read count toward a single limit. + // This doesn't count use of Handshake keys. + let invocations = if self.direction == CryptoDxDirection::Read { + self.invocations + } else { + Self::limit(CryptoDxDirection::Write, cipher) + }; + Self { + direction: self.direction, + epoch: self.epoch + 1, + aead: Aead::new(TLS_VERSION_1_3, cipher, next_secret, "quic ").unwrap(), + hpkey: self.hpkey.clone(), + used_pn: pn..pn, + min_pn: pn, + invocations, + } + } + + #[must_use] + pub fn key_phase(&self) -> bool { + // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ... + self.epoch & 1 != 1 + } + + /// This is a continuation of a previous, so adjust the range accordingly. + /// Fail if the two ranges overlap. 
Do nothing if the directions don't match. + pub fn continuation(&mut self, prev: &Self) -> Res<()> { + debug_assert_eq!(self.direction, prev.direction); + let next = prev.next_pn(); + self.min_pn = next; + // TODO(mt) use Range::is_empty() when available + if self.used_pn.start == self.used_pn.end { + self.used_pn = next..next; + Ok(()) + } else if prev.used_pn.end > self.used_pn.start { + qdebug!( + [self], + "Found packet with too new packet number {} > {}, compared to {}", + self.used_pn.start, + prev.used_pn.end, + prev, + ); + Err(Error::PacketNumberOverlap) + } else { + self.used_pn.start = next; + Ok(()) + } + } + + /// Mark a packet number as used. If this is too low, reject it. + /// Note that this won't catch a value that is too high if packets protected with + /// old keys are received after a key update. That needs to be caught elsewhere. + pub fn used(&mut self, pn: PacketNumber) -> Res<()> { + if pn < self.min_pn { + qdebug!( + [self], + "Found packet with too old packet number: {} < {}", + pn, + self.min_pn + ); + return Err(Error::PacketNumberOverlap); + } + if self.used_pn.start == self.used_pn.end { + self.used_pn.start = pn; + } + self.used_pn.end = max(pn + 1, self.used_pn.end); + Ok(()) + } + + #[must_use] + pub fn needs_update(&self) -> bool { + // Only initiate a key update if we have processed exactly one packet + // and we are in an epoch greater than 3. + self.used_pn.start + 1 == self.used_pn.end + && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA) + } + + #[must_use] + pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool { + if let Some(la) = largest_acknowledged { + self.used_pn.contains(&la) + } else { + // If we haven't received any acknowledgments, it's OK to update + // the first application data epoch. 
+ self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA) + } + } + + pub fn compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>> { + let mask = self.hpkey.mask(sample)?; + qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask)); + Ok(mask) + } + + #[must_use] + pub fn next_pn(&self) -> PacketNumber { + self.used_pn.end + } + + pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> { + debug_assert_eq!(self.direction, CryptoDxDirection::Write); + qtrace!( + [self], + "encrypt pn={} hdr={} body={}", + pn, + hex(hdr), + hex(body) + ); + // The numbers in `Self::limit` assume a maximum packet size of 2^11. + assert!(body.len() <= 2048); + self.invoked()?; + + let size = body.len() + MAX_AUTH_TAG; + let mut out = vec![0; size]; + let res = self.aead.encrypt(pn, hdr, body, &mut out)?; + + qtrace!([self], "encrypt ct={}", hex(res)); + debug_assert_eq!(pn, self.next_pn()); + self.used(pn)?; + Ok(res.to_vec()) + } + + #[must_use] + pub fn expansion(&self) -> usize { + self.aead.expansion() + } + + pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> { + debug_assert_eq!(self.direction, CryptoDxDirection::Read); + qtrace!( + [self], + "decrypt pn={} hdr={} body={}", + pn, + hex(hdr), + hex(body) + ); + self.invoked()?; + let mut out = vec![0; body.len()]; + let res = self.aead.decrypt(pn, hdr, body, &mut out)?; + self.used(pn)?; + Ok(res.to_vec()) + } + + #[cfg(test)] + pub(crate) fn test_default() -> Self { + // This matches the value in packet.rs + const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + Self::new_initial( + QuicVersion::default(), + CryptoDxDirection::Write, + "server in", + CLIENT_CID, + ) + } + + /// Get the amount of extra padding packets protected with this profile need. + /// This is the difference between the size of the header protection sample + /// and the AEAD expansion. 
+ pub fn extra_padding(&self) -> usize { + self.hpkey + .sample_size() + .saturating_sub(self.aead.expansion()) + } +} + +impl std::fmt::Display for CryptoDxState { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "epoch {} {:?}", self.epoch, self.direction) + } +} + +#[derive(Debug)] +pub struct CryptoState { + tx: CryptoDxState, + rx: CryptoDxState, +} + +impl Index<CryptoDxDirection> for CryptoState { + type Output = CryptoDxState; + + fn index(&self, dir: CryptoDxDirection) -> &Self::Output { + match dir { + CryptoDxDirection::Read => &self.rx, + CryptoDxDirection::Write => &self.tx, + } + } +} + +impl IndexMut<CryptoDxDirection> for CryptoState { + fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output { + match dir { + CryptoDxDirection::Read => &mut self.rx, + CryptoDxDirection::Write => &mut self.tx, + } + } +} + +/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys. +/// This includes the secret needed to generate the next set of keys. +#[derive(Debug)] +pub(crate) struct CryptoDxAppData { + dx: CryptoDxState, + cipher: Cipher, + // Not the secret used to create `self.dx`, but the one needed for the next iteration. + next_secret: SymKey, +} + +impl CryptoDxAppData { + pub fn new(dir: CryptoDxDirection, secret: SymKey, cipher: Cipher) -> Res<Self> { + Ok(Self { + dx: CryptoDxState::new(dir, TLS_EPOCH_APPLICATION_DATA, &secret, cipher), + cipher, + next_secret: Self::update_secret(cipher, &secret)?, + }) + } + + fn update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey> { + let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?; + Ok(next) + } + + pub fn next(&self) -> Res<Self> { + if self.dx.epoch == usize::max_value() { + // Guard against too many key updates. 
+ return Err(Error::KeysExhausted); + } + let next_secret = Self::update_secret(self.cipher, &self.next_secret)?; + Ok(Self { + dx: self.dx.next(&self.next_secret, self.cipher), + cipher: self.cipher, + next_secret, + }) + } + + pub fn epoch(&self) -> usize { + self.dx.epoch + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum CryptoSpace { + Initial, + ZeroRtt, + Handshake, + ApplicationData, +} + +#[derive(Debug, Default)] +pub struct CryptoStates { + initial: Option<CryptoState>, + handshake: Option<CryptoState>, + zero_rtt: Option<CryptoDxState>, // One direction only! + cipher: Cipher, + app_write: Option<CryptoDxAppData>, + app_read: Option<CryptoDxAppData>, + app_read_next: Option<CryptoDxAppData>, + // If this is set, then we have noticed a genuine update. + // Once this time passes, we should switch in new keys. + read_update_time: Option<Instant>, +} + +impl CryptoStates { + /// Select a `CryptoDxState` and `CryptoSpace` for the given `PNSpace`. + /// This selects 0-RTT keys for `PNSpace::ApplicationData` if 1-RTT keys are + /// not yet available. 
+ pub fn select_tx(&mut self, space: PNSpace) -> Option<(CryptoSpace, &mut CryptoDxState)> { + match space { + PNSpace::Initial => self + .tx(CryptoSpace::Initial) + .map(|dx| (CryptoSpace::Initial, dx)), + PNSpace::Handshake => self + .tx(CryptoSpace::Handshake) + .map(|dx| (CryptoSpace::Handshake, dx)), + PNSpace::ApplicationData => { + if let Some(app) = self.app_write.as_mut() { + Some((CryptoSpace::ApplicationData, &mut app.dx)) + } else { + self.zero_rtt.as_mut().map(|dx| (CryptoSpace::ZeroRtt, dx)) + } + } + } + } + + pub fn tx<'a>(&'a mut self, cspace: CryptoSpace) -> Option<&'a mut CryptoDxState> { + let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx); + match cspace { + CryptoSpace::Initial => tx(self.initial.as_mut()), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_mut() + .filter(|z| z.direction == CryptoDxDirection::Write), + CryptoSpace::Handshake => tx(self.handshake.as_mut()), + CryptoSpace::ApplicationData => self.app_write.as_mut().map(|app| &mut app.dx), + } + } + + pub fn rx_hp(&mut self, cspace: CryptoSpace) -> Option<&mut CryptoDxState> { + if let CryptoSpace::ApplicationData = cspace { + self.app_read.as_mut().map(|ar| &mut ar.dx) + } else { + self.rx(cspace, false) + } + } + + pub fn rx<'a>( + &'a mut self, + cspace: CryptoSpace, + key_phase: bool, + ) -> Option<&'a mut CryptoDxState> { + let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx); + match cspace { + CryptoSpace::Initial => rx(self.initial.as_mut()), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_mut() + .filter(|z| z.direction == CryptoDxDirection::Read), + CryptoSpace::Handshake => rx(self.handshake.as_mut()), + CryptoSpace::ApplicationData => { + let f = |a: Option<&'a mut CryptoDxAppData>| { + a.filter(|ar| ar.dx.key_phase() == key_phase) + }; + // XOR to reduce the leakage about which key is chosen. 
+ f(self.app_read.as_mut()) + .xor(f(self.app_read_next.as_mut())) + .map(|ar| &mut ar.dx) + } + } + } + + /// Whether keys for processing packets in the indicated space are pending. + /// This allows the caller to determine whether to save a packet for later + /// when keys are not available. + /// NOTE: 0-RTT keys are not considered here. The expectation is that a + /// server will have to save 0-RTT packets in a different place. Though it + /// is possible to attribute 0-RTT packets to an existing connection if there + /// is a multi-packet Initial, that is an unusual circumstance, so we + /// don't do caching for that in those places that call this function. + pub fn rx_pending(&self, space: CryptoSpace) -> bool { + match space { + CryptoSpace::Initial | CryptoSpace::ZeroRtt => false, + CryptoSpace::Handshake => self.handshake.is_none() && self.initial.is_some(), + CryptoSpace::ApplicationData => self.app_read.is_none(), + } + } + + /// Create the initial crypto state. + pub fn init(&mut self, quic_version: QuicVersion, role: Role, dcid: &[u8]) { + const CLIENT_INITIAL_LABEL: &str = "client in"; + const SERVER_INITIAL_LABEL: &str = "server in"; + + qinfo!( + [self], + "Creating initial cipher state role={:?} dcid={}", + role, + hex(dcid) + ); + + let (write, read) = match role { + Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL), + Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL), + }; + + let mut initial = CryptoState { + tx: CryptoDxState::new_initial(quic_version, CryptoDxDirection::Write, write, dcid), + rx: CryptoDxState::new_initial(quic_version, CryptoDxDirection::Read, read, dcid), + }; + if let Some(prev) = &self.initial { + qinfo!( + [self], + "Continue packet numbers for initial after retry (write is {:?})", + prev.rx.used_pn, + ); + initial.tx.continuation(&prev.tx).unwrap(); + } + self.initial = Some(initial); + } + + pub fn set_0rtt_keys(&mut self, dir: CryptoDxDirection, secret: &SymKey, cipher: Cipher) { + qtrace!([self], 
"install 0-RTT keys"); + self.zero_rtt = Some(CryptoDxState::new(dir, TLS_EPOCH_ZERO_RTT, secret, cipher)); + } + + /// Discard keys and return true if that happened. + pub fn discard(&mut self, space: PNSpace) -> bool { + match space { + PNSpace::Initial => self.initial.take().is_some(), + PNSpace::Handshake => self.handshake.take().is_some(), + PNSpace::ApplicationData => panic!("Can't drop application data keys"), + } + } + + pub fn discard_0rtt_keys(&mut self) { + qtrace!([self], "discard 0-RTT keys"); + assert!( + self.app_read.is_none(), + "Can't discard 0-RTT after setting application keys" + ); + self.zero_rtt = None; + } + + pub fn set_handshake_keys( + &mut self, + write_secret: &SymKey, + read_secret: &SymKey, + cipher: Cipher, + ) { + self.cipher = cipher; + self.handshake = Some(CryptoState { + tx: CryptoDxState::new( + CryptoDxDirection::Write, + TLS_EPOCH_HANDSHAKE, + write_secret, + cipher, + ), + rx: CryptoDxState::new( + CryptoDxDirection::Read, + TLS_EPOCH_HANDSHAKE, + read_secret, + cipher, + ), + }); + } + + pub fn set_application_write_key(&mut self, secret: SymKey) -> Res<()> { + debug_assert!(self.app_write.is_none()); + debug_assert_ne!(self.cipher, 0); + let mut app = CryptoDxAppData::new(CryptoDxDirection::Write, secret, self.cipher)?; + if let Some(z) = &self.zero_rtt { + if z.direction == CryptoDxDirection::Write { + app.dx.continuation(z)?; + } + } + self.zero_rtt = None; + self.app_write = Some(app); + Ok(()) + } + + pub fn set_application_read_key(&mut self, secret: SymKey, expire_0rtt: Instant) -> Res<()> { + debug_assert!(self.app_write.is_some(), "should have write keys installed"); + debug_assert!(self.app_read.is_none()); + let mut app = CryptoDxAppData::new(CryptoDxDirection::Read, secret, self.cipher)?; + if let Some(z) = &self.zero_rtt { + if z.direction == CryptoDxDirection::Read { + app.dx.continuation(z)?; + } + self.read_update_time = Some(expire_0rtt); + } + self.app_read_next = Some(app.next()?); + self.app_read = 
Some(app); + Ok(()) + } + + /// Update the write keys. + pub fn initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()> { + // Only update if we are able to. We can only do this if we have + // received an acknowledgement for a packet in the current phase. + // Also, skip this if we are waiting for read keys on the existing + // key update to be rolled over. + let write = &self.app_write.as_ref().unwrap().dx; + if write.can_update(largest_acknowledged) && self.read_update_time.is_none() { + // This call additionally checks that we don't advance to the next + // epoch while a key update is in progress. + if self.maybe_update_write()? { + Ok(()) + } else { + qdebug!([self], "Write keys already updated"); + Err(Error::KeyUpdateBlocked) + } + } else { + qdebug!([self], "Waiting for ACK or blocked on read key timer"); + Err(Error::KeyUpdateBlocked) + } + } + + /// Try to update, and return true if it happened. + fn maybe_update_write(&mut self) -> Res<bool> { + // Update write keys. But only do so if the write keys are not already + // ahead of the read keys. If we initiated the key update, the write keys + // will already be ahead. + debug_assert!(self.read_update_time.is_none()); + let write = &self.app_write.as_ref().unwrap(); + let read = &self.app_read.as_ref().unwrap(); + if write.epoch() == read.epoch() { + qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1); + self.app_write = Some(write.next()?); + Ok(true) + } else { + Ok(false) + } + } + + /// Check whether write keys are close to running out of invocations. + /// If that is close, update them if possible. Failing to update at + /// this stage is cause for a fatal error. + pub fn auto_update(&mut self) -> Res<()> { + if let Some(app_write) = self.app_write.as_ref() { + if app_write.dx.should_update() { + qinfo!([self], "Initiating automatic key update"); + if !self.maybe_update_write()? 
{ + return Err(Error::KeysExhausted); + } + } + } + Ok(()) + } + + fn has_0rtt_read(&self) -> bool { + self.zero_rtt + .as_ref() + .filter(|z| z.direction == CryptoDxDirection::Read) + .is_some() + } + + /// Prepare to update read keys. This doesn't happen immediately as + /// we want to ensure that we can continue to receive any delayed + /// packets that use the old keys. So we just set a timer. + pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> { + qtrace!([self], "Key update received"); + // If we received a key update, then we assume that the peer has + // acknowledged a packet we sent in this epoch. It's OK to do that + // because they aren't allowed to update without first having received + // something from us. If the ACK isn't in the packet that triggered this + // key update, it must be in some other packet they have sent. + let _ = self.maybe_update_write()?; + + // We shouldn't have 0-RTT keys at this point, but if we do, dump them. + debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read()); + if self.has_0rtt_read() { + self.zero_rtt = None; + } + self.read_update_time = Some(expiration); + Ok(()) + } + + #[must_use] + pub fn update_time(&self) -> Option<Instant> { + self.read_update_time + } + + /// Check if time has passed for updating key update parameters. + /// If it has, then swap keys over and allow more key updates to be initiated. + /// This is also used to discard 0-RTT read keys at the server in the same way. + pub fn check_key_update(&mut self, now: Instant) -> Res<()> { + if let Some(expiry) = self.read_update_time { + // If enough time has passed, then install new keys and clear the timer. 
+ if now >= expiry { + if self.has_0rtt_read() { + qtrace!([self], "Discarding 0-RTT keys"); + self.zero_rtt = None; + } else { + qtrace!([self], "Rotating read keys"); + mem::swap(&mut self.app_read, &mut self.app_read_next); + self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?); + } + self.read_update_time = None; + } + } + Ok(()) + } + + /// Get the current/highest epoch. This returns (write, read) epochs. + #[cfg(test)] + pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) { + let to_epoch = |app: &Option<CryptoDxAppData>| app.as_ref().map(|a| a.dx.epoch); + (to_epoch(&self.app_write), to_epoch(&self.app_read)) + } + + /// While we are awaiting the completion of a key update, we might receive + /// valid packets that are protected with old keys. We need to ensure that + /// these don't carry packet numbers higher than those in packets protected + /// with the newer keys. To ensure that, this is called after every decryption. + pub fn check_pn_overlap(&mut self) -> Res<()> { + // We only need to do the check while we are waiting for read keys to be updated. + if self.read_update_time.is_some() { + qtrace!([self], "Checking for PN overlap"); + let next_dx = &mut self.app_read_next.as_mut().unwrap().dx; + next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?; + } + Ok(()) + } + + /// Make some state for removing protection in tests. 
+ #[cfg(test)] + pub(crate) fn test_default() -> Self { + let read = |epoch| { + let mut dx = CryptoDxState::test_default(); + dx.direction = CryptoDxDirection::Read; + dx.epoch = epoch; + dx + }; + let app_read = |epoch| CryptoDxAppData { + dx: read(epoch), + cipher: TLS_AES_128_GCM_SHA256, + next_secret: hkdf::import_key(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &[0xaa; 32]) + .unwrap(), + }; + Self { + initial: Some(CryptoState { + tx: CryptoDxState::test_default(), + rx: read(0), + }), + handshake: None, + zero_rtt: None, + cipher: TLS_AES_128_GCM_SHA256, + // This isn't used, but the epoch is read to check for a key update. + app_write: Some(app_read(3)), + app_read: Some(app_read(3)), + app_read_next: Some(app_read(4)), + read_update_time: None, + } + } + + #[cfg(test)] + pub(crate) fn test_chacha() -> Self { + const SECRET: &[u8] = &[ + 0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad, + 0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3, + 0x0f, 0x21, 0x63, 0x2b, + ]; + let secret = + hkdf::import_key(TLS_VERSION_1_3, TLS_CHACHA20_POLY1305_SHA256, SECRET).unwrap(); + let app_read = |epoch| CryptoDxAppData { + dx: CryptoDxState { + direction: CryptoDxDirection::Read, + epoch, + aead: Aead::new( + TLS_VERSION_1_3, + TLS_CHACHA20_POLY1305_SHA256, + &secret, + "quic ", + ) + .unwrap(), + hpkey: HpKey::extract( + TLS_VERSION_1_3, + TLS_CHACHA20_POLY1305_SHA256, + &secret, + "quic hp", + ) + .unwrap(), + used_pn: 0..645_971_972, + min_pn: 0, + invocations: 10, + }, + cipher: TLS_CHACHA20_POLY1305_SHA256, + next_secret: secret.clone(), + }; + Self { + initial: None, + handshake: None, + zero_rtt: None, + cipher: TLS_CHACHA20_POLY1305_SHA256, + app_write: Some(app_read(3)), + app_read: Some(app_read(3)), + app_read_next: Some(app_read(4)), + read_update_time: None, + } + } +} + +impl std::fmt::Display for CryptoStates { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, 
"CryptoStates") + } +} + +#[derive(Debug, Default)] +pub struct CryptoStream { + tx: TxBuffer, + rx: RxStreamOrderer, +} + +#[derive(Debug)] +#[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408 +pub enum CryptoStreams { + Initial { + initial: CryptoStream, + handshake: CryptoStream, + application: CryptoStream, + }, + Handshake { + handshake: CryptoStream, + application: CryptoStream, + }, + ApplicationData { + application: CryptoStream, + }, +} + +impl CryptoStreams { + pub fn discard(&mut self, space: PNSpace) { + match space { + PNSpace::Initial => { + if let Self::Initial { + handshake, + application, + .. + } = self + { + *self = Self::Handshake { + handshake: mem::take(handshake), + application: mem::take(application), + }; + } + } + PNSpace::Handshake => { + if let Self::Handshake { application, .. } = self { + *self = Self::ApplicationData { + application: mem::take(application), + }; + } else if matches!(self, Self::Initial {..}) { + panic!("Discarding handshake before initial discarded"); + } + } + PNSpace::ApplicationData => panic!("Discarding application data crypto streams"), + } + } + + pub fn send(&mut self, space: PNSpace, data: &[u8]) { + self.get_mut(space).unwrap().tx.send(data); + } + + pub fn inbound_frame(&mut self, space: PNSpace, offset: u64, data: &[u8]) { + self.get_mut(space).unwrap().rx.inbound_frame(offset, data); + } + + pub fn data_ready(&self, space: PNSpace) -> bool { + self.get(space).map_or(false, |cs| cs.rx.data_ready()) + } + + pub fn read_to_end(&mut self, space: PNSpace, buf: &mut Vec<u8>) -> usize { + self.get_mut(space).unwrap().rx.read_to_end(buf) + } + + pub fn acked(&mut self, token: &CryptoRecoveryToken) { + self.get_mut(token.space) + .unwrap() + .tx + .mark_as_acked(token.offset, token.length); + } + + pub fn lost(&mut self, token: &CryptoRecoveryToken) { + // See BZ 1624800, ignore lost packets in spaces we've dropped keys + if let Some(cs) = self.get_mut(token.space) { + 
cs.tx.mark_as_lost(token.offset, token.length); + } + } + + /// Resend any Initial or Handshake CRYPTO frames that might be outstanding. + /// This can help speed up handshake times. + pub fn resend_unacked(&mut self, space: PNSpace) { + if space != PNSpace::ApplicationData { + if let Some(cs) = self.get_mut(space) { + cs.tx.unmark_sent(); + } + } + } + + fn get(&self, space: PNSpace) -> Option<&CryptoStream> { + let (initial, hs, app) = match self { + Self::Initial { + initial, + handshake, + application, + } => (Some(initial), Some(handshake), Some(application)), + Self::Handshake { + handshake, + application, + } => (None, Some(handshake), Some(application)), + Self::ApplicationData { application } => (None, None, Some(application)), + }; + match space { + PNSpace::Initial => initial, + PNSpace::Handshake => hs, + PNSpace::ApplicationData => app, + } + } + + fn get_mut(&mut self, space: PNSpace) -> Option<&mut CryptoStream> { + let (initial, hs, app) = match self { + Self::Initial { + initial, + handshake, + application, + } => (Some(initial), Some(handshake), Some(application)), + Self::Handshake { + handshake, + application, + } => (None, Some(handshake), Some(application)), + Self::ApplicationData { application } => (None, None, Some(application)), + }; + match space { + PNSpace::Initial => initial, + PNSpace::Handshake => hs, + PNSpace::ApplicationData => app, + } + } + + pub fn write_frame( + &mut self, + space: PNSpace, + builder: &mut PacketBuilder, + ) -> Option<RecoveryToken> { + let cs = self.get_mut(space).unwrap(); + if let Some((offset, data)) = cs.tx.next_bytes() { + let mut header_len = 1 + Encoder::varint_len(offset) + 1; + + // Don't bother if there isn't room for the header and some data. 
+ if builder.remaining() < header_len + 1 { + return None; + } + // Calculate length of data based on the minimum of: + // - available data + // - remaining space, less the header, which counts only one byte + // for the length at first to avoid underestimating length + let length = min(data.len(), builder.remaining() - header_len); + header_len += Encoder::varint_len(u64::try_from(length).unwrap()) - 1; + let length = min(data.len(), builder.remaining() - header_len); + + builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO); + builder.encode_varint(offset); + builder.encode_vvec(&data[..length]); + cs.tx.mark_as_sent(offset, length); + + qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length); + Some(RecoveryToken::Crypto(CryptoRecoveryToken { + space, + offset, + length, + })) + } else { + None + } + } +} + +impl Default for CryptoStreams { + fn default() -> Self { + Self::Initial { + initial: CryptoStream::default(), + handshake: CryptoStream::default(), + application: CryptoStream::default(), + } + } +} + +#[derive(Debug, Clone)] +pub struct CryptoRecoveryToken { + space: PNSpace, + offset: u64, + length: usize, +} diff --git a/third_party/rust/neqo-transport/src/dump.rs b/third_party/rust/neqo-transport/src/dump.rs new file mode 100644 index 0000000000..e8f5b32ae9 --- /dev/null +++ b/third_party/rust/neqo-transport/src/dump.rs @@ -0,0 +1,32 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Enable just this file for logging to just see packets. +// e.g. "RUST_LOG=neqo_transport::dump neqo-client ..." 
+ +use crate::connection::Connection; +use crate::frame::Frame; +use crate::packet::{PacketNumber, PacketType}; +use neqo_common::{qdebug, Decoder}; + +#[allow(clippy::module_name_repetitions)] +pub fn dump_packet(conn: &Connection, dir: &str, pt: PacketType, pn: PacketNumber, payload: &[u8]) { + let mut s = String::from(""); + let mut d = Decoder::from(payload); + while d.remaining() > 0 { + let f = match Frame::decode(&mut d) { + Ok(f) => f, + Err(_) => { + s.push_str(" [broken]..."); + break; + } + }; + if let Some(x) = f.dump() { + s.push_str(&format!("\n {} {}", dir, &x)); + } + } + qdebug!([conn], "pn={} type={:?}{}", pn, pt, s); +} diff --git a/third_party/rust/neqo-transport/src/events.rs b/third_party/rust/neqo-transport/src/events.rs new file mode 100644 index 0000000000..fc998cf32c --- /dev/null +++ b/third_party/rust/neqo-transport/src/events.rs @@ -0,0 +1,254 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Collecting a list of events relevant to whoever is using the Connection. + +use std::cell::RefCell; +use std::collections::VecDeque; +use std::rc::Rc; + +use crate::connection::State; +use crate::frame::StreamType; +use crate::stream_id::StreamId; +use crate::AppError; +use neqo_common::event::Provider as EventProvider; +use neqo_crypto::ResumptionToken; + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)] +pub enum ConnectionEvent { + /// Cert authentication needed + AuthenticationNeeded, + /// A new uni (read) or bidi stream has been opened by the peer. + NewStream { + stream_id: StreamId, + }, + /// Space available in the buffer for an application write to succeed. + SendStreamWritable { + stream_id: StreamId, + }, + /// New bytes available for reading. 
+ RecvStreamReadable { + stream_id: u64, + }, + /// Peer reset the stream. + RecvStreamReset { + stream_id: u64, + app_error: AppError, + }, + /// Peer has sent STOP_SENDING + SendStreamStopSending { + stream_id: u64, + app_error: AppError, + }, + /// Peer has acked everything sent on the stream. + SendStreamComplete { + stream_id: u64, + }, + /// Peer increased MAX_STREAMS + SendStreamCreatable { + stream_type: StreamType, + }, + /// Connection state change. + StateChange(State), + /// The server rejected 0-RTT. + /// This event invalidates all state in streams that has been created. + /// Any data written to streams needs to be written again. + ZeroRttRejected, + ResumptionToken(ResumptionToken), +} + +#[derive(Debug, Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct ConnectionEvents { + events: Rc<RefCell<VecDeque<ConnectionEvent>>>, +} + +impl ConnectionEvents { + pub fn authentication_needed(&self) { + self.insert(ConnectionEvent::AuthenticationNeeded); + } + + pub fn new_stream(&self, stream_id: StreamId) { + self.insert(ConnectionEvent::NewStream { stream_id }); + } + + pub fn recv_stream_readable(&self, stream_id: StreamId) { + self.insert(ConnectionEvent::RecvStreamReadable { + stream_id: stream_id.as_u64(), + }); + } + + pub fn recv_stream_reset(&self, stream_id: StreamId, app_error: AppError) { + // If reset, no longer readable. + self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64())); + + self.insert(ConnectionEvent::RecvStreamReset { + stream_id: stream_id.as_u64(), + app_error, + }); + } + + pub fn send_stream_writable(&self, stream_id: StreamId) { + self.insert(ConnectionEvent::SendStreamWritable { stream_id }); + } + + pub fn send_stream_stop_sending(&self, stream_id: StreamId, app_error: AppError) { + // If stopped, no longer writable. 
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id.as_u64())); + + self.insert(ConnectionEvent::SendStreamStopSending { + stream_id: stream_id.as_u64(), + app_error, + }); + } + + pub fn send_stream_complete(&self, stream_id: StreamId) { + self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id)); + + self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. } if *x == stream_id.as_u64())); + + self.insert(ConnectionEvent::SendStreamComplete { + stream_id: stream_id.as_u64(), + }); + } + + pub fn send_stream_creatable(&self, stream_type: StreamType) { + self.insert(ConnectionEvent::SendStreamCreatable { stream_type }); + } + + pub fn connection_state_change(&self, state: State) { + // If closing, existing events no longer relevant. + match state { + State::Closing { .. } | State::Closed(_) => self.events.borrow_mut().clear(), + _ => (), + } + self.insert(ConnectionEvent::StateChange(state)); + } + + pub fn client_resumption_token(&self, token: ResumptionToken) { + self.insert(ConnectionEvent::ResumptionToken(token)); + } + + pub fn client_0rtt_rejected(&self) { + // If 0rtt rejected, must start over and existing events are no longer + // relevant. + self.events.borrow_mut().clear(); + self.insert(ConnectionEvent::ZeroRttRejected); + } + + pub fn recv_stream_complete(&self, stream_id: StreamId) { + // If stopped, no longer readable. + self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64())); + } + + fn insert(&self, event: ConnectionEvent) { + let mut q = self.events.borrow_mut(); + + // Special-case two enums that are not strictly PartialEq equal but that + // we wish to avoid inserting duplicates. + let already_present = match &event { + ConnectionEvent::SendStreamStopSending { stream_id, .. 
} => q.iter().any(|evt| { + matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. } + if *x == *stream_id) + }), + ConnectionEvent::RecvStreamReset { stream_id, .. } => q.iter().any(|evt| { + matches!(evt, ConnectionEvent::RecvStreamReset { stream_id: x, .. } + if *x == *stream_id) + }), + _ => q.contains(&event), + }; + if !already_present { + q.push_back(event); + } + } + + fn remove<F>(&self, f: F) + where + F: Fn(&ConnectionEvent) -> bool, + { + self.events.borrow_mut().retain(|evt| !f(evt)) + } +} + +impl EventProvider for ConnectionEvents { + type Event = ConnectionEvent; + + fn has_events(&self) -> bool { + !self.events.borrow().is_empty() + } + + fn next_event(&mut self) -> Option<Self::Event> { + self.events.borrow_mut().pop_front() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ConnectionError, Error}; + + #[test] + fn event_culling() { + let mut evts = ConnectionEvents::default(); + + evts.client_0rtt_rejected(); + evts.client_0rtt_rejected(); + assert_eq!(evts.events().count(), 1); + assert_eq!(evts.events().count(), 0); + + evts.new_stream(4.into()); + evts.new_stream(4.into()); + assert_eq!(evts.events().count(), 1); + + evts.recv_stream_readable(6.into()); + evts.recv_stream_reset(6.into(), 66); + evts.recv_stream_reset(6.into(), 65); + assert_eq!(evts.events().count(), 1); + + evts.send_stream_writable(8.into()); + evts.send_stream_writable(8.into()); + evts.send_stream_stop_sending(8.into(), 55); + evts.send_stream_stop_sending(8.into(), 56); + let events = evts.events().collect::<Vec<_>>(); + assert_eq!(events.len(), 1); + assert_eq!( + events[0], + ConnectionEvent::SendStreamStopSending { + stream_id: 8, + app_error: 55 + } + ); + + evts.send_stream_writable(8.into()); + evts.send_stream_writable(8.into()); + evts.send_stream_stop_sending(8.into(), 55); + evts.send_stream_stop_sending(8.into(), 56); + evts.send_stream_complete(8.into()); + assert_eq!(evts.events().count(), 1); + + 
evts.send_stream_writable(8.into()); + evts.send_stream_writable(9.into()); + evts.send_stream_stop_sending(10.into(), 55); + evts.send_stream_stop_sending(11.into(), 56); + evts.send_stream_complete(12.into()); + assert_eq!(evts.events().count(), 5); + + evts.send_stream_writable(8.into()); + evts.send_stream_writable(9.into()); + evts.send_stream_stop_sending(10.into(), 55); + evts.send_stream_stop_sending(11.into(), 56); + evts.send_stream_complete(12.into()); + evts.client_0rtt_rejected(); + assert_eq!(evts.events().count(), 1); + + evts.send_stream_writable(9.into()); + evts.send_stream_stop_sending(10.into(), 55); + evts.connection_state_change(State::Closed(ConnectionError::Transport( + Error::StreamStateError, + ))); + assert_eq!(evts.events().count(), 1); + } +} diff --git a/third_party/rust/neqo-transport/src/flow_mgr.rs b/third_party/rust/neqo-transport/src/flow_mgr.rs new file mode 100644 index 0000000000..0ae7fd3c00 --- /dev/null +++ b/third_party/rust/neqo-transport/src/flow_mgr.rs @@ -0,0 +1,400 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Tracks possibly-redundant flow control signals from other code and converts +// into flow control frames needing to be sent to the remote. 
+ +use std::collections::HashMap; +use std::mem; + +use neqo_common::{qinfo, qwarn, Encoder}; +use smallvec::{smallvec, SmallVec}; + +use crate::frame::{Frame, StreamType}; +use crate::packet::PacketBuilder; +use crate::recovery::RecoveryToken; +use crate::recv_stream::RecvStreams; +use crate::send_stream::SendStreams; +use crate::stats::FrameStats; +use crate::stream_id::{StreamId, StreamIndex, StreamIndexes}; +use crate::AppError; + +type FlowFrame = Frame<'static>; +pub type FlowControlRecoveryToken = FlowFrame; + +#[derive(Debug, Default)] +pub struct FlowMgr { + // Discriminant as key ensures only 1 of every frame type will be queued. + from_conn: HashMap<mem::Discriminant<FlowFrame>, FlowFrame>, + + // (id, discriminant) as key ensures only 1 of every frame type per stream + // will be queued. + from_streams: HashMap<(StreamId, mem::Discriminant<FlowFrame>), FlowFrame>, + + // (stream_type, discriminant) as key ensures only 1 of every frame type + // per stream type will be queued. + from_stream_types: HashMap<(StreamType, mem::Discriminant<FlowFrame>), FlowFrame>, + + used_data: u64, + max_data: u64, +} + +impl FlowMgr { + pub fn conn_credit_avail(&self) -> u64 { + self.max_data - self.used_data + } + + pub fn conn_increase_credit_used(&mut self, amount: u64) { + self.used_data += amount; + assert!(self.used_data <= self.max_data) + } + + // Dummy DataBlocked frame for discriminant use below + + /// Returns whether max credit was actually increased. 
+ pub fn conn_increase_max_credit(&mut self, new: u64) -> bool { + const DB_FRAME: FlowFrame = Frame::DataBlocked { data_limit: 0 }; + + if new > self.max_data { + self.max_data = new; + self.from_conn.remove(&mem::discriminant(&DB_FRAME)); + + true + } else { + false + } + } + + // -- frames scoped on connection -- + + pub fn data_blocked(&mut self) { + let frame = Frame::DataBlocked { + data_limit: self.max_data, + }; + self.from_conn.insert(mem::discriminant(&frame), frame); + } + + pub fn path_response(&mut self, data: [u8; 8]) { + let frame = Frame::PathResponse { data }; + self.from_conn.insert(mem::discriminant(&frame), frame); + } + + pub fn max_data(&mut self, maximum_data: u64) { + let frame = Frame::MaxData { maximum_data }; + self.from_conn.insert(mem::discriminant(&frame), frame); + } + + // -- frames scoped on stream -- + + /// Indicate to receiving remote the stream is reset + pub fn stream_reset( + &mut self, + stream_id: StreamId, + application_error_code: AppError, + final_size: u64, + ) { + let frame = Frame::ResetStream { + stream_id, + application_error_code, + final_size, + }; + self.from_streams + .insert((stream_id, mem::discriminant(&frame)), frame); + } + + /// Indicate to sending remote we are no longer interested in the stream + pub fn stop_sending(&mut self, stream_id: StreamId, application_error_code: AppError) { + let frame = Frame::StopSending { + stream_id, + application_error_code, + }; + self.from_streams + .insert((stream_id, mem::discriminant(&frame)), frame); + } + + /// Update sending remote with more credits + pub fn max_stream_data(&mut self, stream_id: StreamId, maximum_stream_data: u64) { + let frame = Frame::MaxStreamData { + stream_id, + maximum_stream_data, + }; + self.from_streams + .insert((stream_id, mem::discriminant(&frame)), frame); + } + + /// Don't send stream data updates if no more data is coming + pub fn clear_max_stream_data(&mut self, stream_id: StreamId) { + let frame = Frame::MaxStreamData { + stream_id, 
+ maximum_stream_data: 0, + }; + self.from_streams + .remove(&(stream_id, mem::discriminant(&frame))); + } + + /// Indicate to receiving remote we need more credits + pub fn stream_data_blocked(&mut self, stream_id: StreamId, stream_data_limit: u64) { + let frame = Frame::StreamDataBlocked { + stream_id, + stream_data_limit, + }; + self.from_streams + .insert((stream_id, mem::discriminant(&frame)), frame); + } + + // -- frames scoped on stream type -- + + pub fn max_streams(&mut self, stream_limit: StreamIndex, stream_type: StreamType) { + let frame = Frame::MaxStreams { + stream_type, + maximum_streams: stream_limit, + }; + self.from_stream_types + .insert((stream_type, mem::discriminant(&frame)), frame); + } + + pub fn streams_blocked(&mut self, stream_limit: StreamIndex, stream_type: StreamType) { + let frame = Frame::StreamsBlocked { + stream_type, + stream_limit, + }; + self.from_stream_types + .insert((stream_type, mem::discriminant(&frame)), frame); + } + + pub fn peek(&self) -> Option<&Frame> { + if let Some(key) = self.from_conn.keys().next() { + self.from_conn.get(key) + } else if let Some(key) = self.from_streams.keys().next() { + self.from_streams.get(key) + } else if let Some(key) = self.from_stream_types.keys().next() { + self.from_stream_types.get(key) + } else { + None + } + } + + pub(crate) fn acked( + &mut self, + token: &FlowControlRecoveryToken, + send_streams: &mut SendStreams, + ) { + const RESET_STREAM: &Frame = &Frame::ResetStream { + stream_id: StreamId::new(0), + application_error_code: 0, + final_size: 0, + }; + + if let Frame::ResetStream { stream_id, .. 
} = token { + qinfo!("Reset received stream={}", stream_id.as_u64()); + + if self + .from_streams + .remove(&(*stream_id, mem::discriminant(RESET_STREAM))) + .is_some() + { + qinfo!("Removed RESET_STREAM frame for {}", stream_id.as_u64()); + } + + send_streams.reset_acked(*stream_id); + } + } + + pub(crate) fn lost( + &mut self, + token: &FlowControlRecoveryToken, + send_streams: &mut SendStreams, + recv_streams: &mut RecvStreams, + indexes: &mut StreamIndexes, + ) { + match *token { + // Always resend ResetStream if lost + Frame::ResetStream { + stream_id, + application_error_code, + final_size, + } => { + qinfo!( + "Reset lost stream={} err={} final_size={}", + stream_id.as_u64(), + application_error_code, + final_size + ); + if send_streams.get(stream_id).is_ok() { + self.stream_reset(stream_id, application_error_code, final_size); + } + } + // Resend MaxStreams if lost (with updated value) + Frame::MaxStreams { stream_type, .. } => { + let local_max = match stream_type { + StreamType::BiDi => &mut indexes.local_max_stream_bidi, + StreamType::UniDi => &mut indexes.local_max_stream_uni, + }; + + self.max_streams(*local_max, stream_type) + } + // Only resend "*Blocked" frames if still blocked + Frame::DataBlocked { .. } => { + if self.conn_credit_avail() == 0 { + self.data_blocked() + } + } + Frame::StreamDataBlocked { stream_id, .. } => { + if let Ok(ss) = send_streams.get(stream_id) { + if ss.credit_avail() == 0 { + self.stream_data_blocked(stream_id, ss.max_stream_data()) + } + } + } + Frame::StreamsBlocked { stream_type, .. 
} => match stream_type { + StreamType::UniDi => { + if indexes.remote_next_stream_uni >= indexes.remote_max_stream_uni { + self.streams_blocked(indexes.remote_max_stream_uni, StreamType::UniDi); + } + } + StreamType::BiDi => { + if indexes.remote_next_stream_bidi >= indexes.remote_max_stream_bidi { + self.streams_blocked(indexes.remote_max_stream_bidi, StreamType::BiDi); + } + } + }, + // Resend StopSending + Frame::StopSending { + stream_id, + application_error_code, + } => self.stop_sending(stream_id, application_error_code), + Frame::MaxStreamData { stream_id, .. } => { + if let Some(rs) = recv_streams.get_mut(&stream_id) { + if let Some(msd) = rs.max_stream_data() { + self.max_stream_data(stream_id, msd) + } + } + } + Frame::PathResponse { .. } => qinfo!("Path Response lost, not re-sent"), + _ => qwarn!("Unexpected Flow frame {:?} lost, not re-sent", token), + } + } + + pub(crate) fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + while let Some(frame) = self.peek() { + // All these frames are bags of varints, so we can just extract the + // varints and use common code for writing. + let values: SmallVec<[_; 3]> = match frame { + Frame::ResetStream { + stream_id, + application_error_code, + final_size, + } => { + stats.reset_stream += 1; + smallvec![stream_id.as_u64(), *application_error_code, *final_size] + } + Frame::StopSending { + stream_id, + application_error_code, + } => { + stats.stop_sending += 1; + smallvec![stream_id.as_u64(), *application_error_code] + } + + Frame::MaxStreams { + maximum_streams, .. + } => { + stats.max_streams += 1; + smallvec![maximum_streams.as_u64()] + } + Frame::StreamsBlocked { stream_limit, .. 
} => { + stats.streams_blocked += 1; + smallvec![stream_limit.as_u64()] + } + + Frame::MaxData { maximum_data } => { + stats.max_data += 1; + smallvec![*maximum_data] + } + Frame::DataBlocked { data_limit } => { + stats.data_blocked += 1; + smallvec![*data_limit] + } + + Frame::MaxStreamData { + stream_id, + maximum_stream_data, + } => { + stats.max_stream_data += 1; + smallvec![stream_id.as_u64(), *maximum_stream_data] + } + Frame::StreamDataBlocked { + stream_id, + stream_data_limit, + } => { + stats.stream_data_blocked += 1; + smallvec![stream_id.as_u64(), *stream_data_limit] + } + + // A special case, just write it out and move on.. + Frame::PathResponse { data } => { + stats.path_response += 1; + if builder.remaining() >= Encoder::varint_len(frame.get_type()) + data.len() { + builder.encode_varint(frame.get_type()); + builder.encode(data); + tokens.push(RecoveryToken::Flow(self.next().unwrap())); + continue; + } else { + return; + } + } + + _ => unreachable!("{:?}", frame), + }; + debug_assert!(!values.spilled()); + + if builder.remaining() + >= Encoder::varint_len(frame.get_type()) + + values + .iter() + .map(|&v| Encoder::varint_len(v)) + .sum::<usize>() + { + builder.encode_varint(frame.get_type()); + for v in values { + builder.encode_varint(v); + } + tokens.push(RecoveryToken::Flow(self.next().unwrap())); + } else { + return; + } + } + } +} + +impl Iterator for FlowMgr { + type Item = FlowFrame; + + /// Used by generator to get a flow control frame. 
+ fn next(&mut self) -> Option<Self::Item> { + let first_key = self.from_conn.keys().next(); + if let Some(&first_key) = first_key { + return self.from_conn.remove(&first_key); + } + + let first_key = self.from_streams.keys().next(); + if let Some(&first_key) = first_key { + return self.from_streams.remove(&first_key); + } + + let first_key = self.from_stream_types.keys().next(); + if let Some(&first_key) = first_key { + return self.from_stream_types.remove(&first_key); + } + + None + } +} diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs new file mode 100644 index 0000000000..5f3e7c9b78 --- /dev/null +++ b/third_party/rust/neqo-transport/src/frame.rs @@ -0,0 +1,835 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Directly relating to QUIC frames. 
+ +use neqo_common::{qtrace, Decoder}; + +use crate::cid::MAX_CONNECTION_ID_LEN; +use crate::packet::PacketType; +use crate::stream_id::{StreamId, StreamIndex}; +use crate::{AppError, ConnectionError, Error, Res, TransportError, ERROR_APPLICATION_CLOSE}; + +use std::convert::TryFrom; +use std::ops::RangeInclusive; + +#[allow(clippy::module_name_repetitions)] +pub type FrameType = u64; + +const FRAME_TYPE_PADDING: FrameType = 0x0; +pub const FRAME_TYPE_PING: FrameType = 0x1; +pub const FRAME_TYPE_ACK: FrameType = 0x2; +const FRAME_TYPE_ACK_ECN: FrameType = 0x3; +const FRAME_TYPE_RST_STREAM: FrameType = 0x4; +const FRAME_TYPE_STOP_SENDING: FrameType = 0x5; +pub const FRAME_TYPE_CRYPTO: FrameType = 0x6; +pub const FRAME_TYPE_NEW_TOKEN: FrameType = 0x7; +const FRAME_TYPE_STREAM: FrameType = 0x8; +const FRAME_TYPE_STREAM_MAX: FrameType = 0xf; +const FRAME_TYPE_MAX_DATA: FrameType = 0x10; +const FRAME_TYPE_MAX_STREAM_DATA: FrameType = 0x11; +const FRAME_TYPE_MAX_STREAMS_BIDI: FrameType = 0x12; +const FRAME_TYPE_MAX_STREAMS_UNIDI: FrameType = 0x13; +const FRAME_TYPE_DATA_BLOCKED: FrameType = 0x14; +const FRAME_TYPE_STREAM_DATA_BLOCKED: FrameType = 0x15; +const FRAME_TYPE_STREAMS_BLOCKED_BIDI: FrameType = 0x16; +const FRAME_TYPE_STREAMS_BLOCKED_UNIDI: FrameType = 0x17; +const FRAME_TYPE_NEW_CONNECTION_ID: FrameType = 0x18; +const FRAME_TYPE_RETIRE_CONNECTION_ID: FrameType = 0x19; +const FRAME_TYPE_PATH_CHALLENGE: FrameType = 0x1a; +const FRAME_TYPE_PATH_RESPONSE: FrameType = 0x1b; +pub const FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT: FrameType = 0x1c; +pub const FRAME_TYPE_CONNECTION_CLOSE_APPLICATION: FrameType = 0x1d; +const FRAME_TYPE_HANDSHAKE_DONE: FrameType = 0x1e; + +const STREAM_FRAME_BIT_FIN: u64 = 0x01; +const STREAM_FRAME_BIT_LEN: u64 = 0x02; +const STREAM_FRAME_BIT_OFF: u64 = 0x04; + +/// `FRAME_APPLICATION_CLOSE` is the default CONNECTION_CLOSE frame that +/// is sent when an application error code needs to be sent in an +/// Initial or Handshake packet. 
+const FRAME_APPLICATION_CLOSE: &Frame = &Frame::ConnectionClose { + error_code: CloseError::Transport(ERROR_APPLICATION_CLOSE), + frame_type: 0, + reason_phrase: Vec::new(), +}; + +#[derive(PartialEq, Debug, Copy, Clone, PartialOrd, Eq, Ord, Hash)] +/// Bi-Directional or Uni-Directional. +pub enum StreamType { + BiDi, + UniDi, +} + +impl StreamType { + fn frame_type_bit(self) -> u64 { + match self { + Self::BiDi => 0, + Self::UniDi => 1, + } + } + fn from_type_bit(bit: u64) -> Self { + if (bit & 0x01) == 0 { + Self::BiDi + } else { + Self::UniDi + } + } +} + +#[derive(PartialEq, Eq, Debug, PartialOrd, Ord, Clone, Copy)] +pub enum CloseError { + Transport(TransportError), + Application(AppError), +} + +impl CloseError { + fn frame_type_bit(self) -> u64 { + match self { + Self::Transport(_) => 0, + Self::Application(_) => 1, + } + } + + fn from_type_bit(bit: u64, code: u64) -> Self { + if (bit & 0x01) == 0 { + Self::Transport(code) + } else { + Self::Application(code) + } + } + + pub fn code(&self) -> u64 { + match self { + Self::Transport(c) | Self::Application(c) => *c, + } + } +} + +impl From<ConnectionError> for CloseError { + fn from(err: ConnectionError) -> Self { + match err { + ConnectionError::Transport(c) => Self::Transport(c.code()), + ConnectionError::Application(c) => Self::Application(c), + } + } +} + +#[derive(PartialEq, Debug, Default, Clone)] +pub struct AckRange { + pub(crate) gap: u64, + pub(crate) range: u64, +} + +#[derive(PartialEq, Debug, Clone)] +pub enum Frame<'a> { + Padding, + Ping, + Ack { + largest_acknowledged: u64, + ack_delay: u64, + first_ack_range: u64, + ack_ranges: Vec<AckRange>, + }, + ResetStream { + stream_id: StreamId, + application_error_code: AppError, + final_size: u64, + }, + StopSending { + stream_id: StreamId, + application_error_code: AppError, + }, + Crypto { + offset: u64, + data: &'a [u8], + }, + NewToken { + token: &'a [u8], + }, + Stream { + stream_id: StreamId, + offset: u64, + data: &'a [u8], + fin: bool, + fill: 
bool, + }, + MaxData { + maximum_data: u64, + }, + MaxStreamData { + stream_id: StreamId, + maximum_stream_data: u64, + }, + MaxStreams { + stream_type: StreamType, + maximum_streams: StreamIndex, + }, + DataBlocked { + data_limit: u64, + }, + StreamDataBlocked { + stream_id: StreamId, + stream_data_limit: u64, + }, + StreamsBlocked { + stream_type: StreamType, + stream_limit: StreamIndex, + }, + NewConnectionId { + sequence_number: u64, + retire_prior: u64, + connection_id: &'a [u8], + stateless_reset_token: &'a [u8; 16], + }, + RetireConnectionId { + sequence_number: u64, + }, + PathChallenge { + data: [u8; 8], + }, + PathResponse { + data: [u8; 8], + }, + ConnectionClose { + error_code: CloseError, + frame_type: u64, + // Not a reference as we use this to hold the value. + // This is not used in optimized builds anyway. + reason_phrase: Vec<u8>, + }, + HandshakeDone, +} + +impl<'a> Frame<'a> { + pub fn get_type(&self) -> FrameType { + match self { + Self::Padding => FRAME_TYPE_PADDING, + Self::Ping => FRAME_TYPE_PING, + Self::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN. + Self::ResetStream { .. } => FRAME_TYPE_RST_STREAM, + Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING, + Self::Crypto { .. } => FRAME_TYPE_CRYPTO, + Self::NewToken { .. } => FRAME_TYPE_NEW_TOKEN, + Self::Stream { + fin, offset, fill, .. + } => Self::stream_type(*fin, *offset > 0, *fill), + Self::MaxData { .. } => FRAME_TYPE_MAX_DATA, + Self::MaxStreamData { .. } => FRAME_TYPE_MAX_STREAM_DATA, + Self::MaxStreams { stream_type, .. } => { + FRAME_TYPE_MAX_STREAMS_BIDI + stream_type.frame_type_bit() + } + Self::DataBlocked { .. } => FRAME_TYPE_DATA_BLOCKED, + Self::StreamDataBlocked { .. } => FRAME_TYPE_STREAM_DATA_BLOCKED, + Self::StreamsBlocked { stream_type, .. } => { + FRAME_TYPE_STREAMS_BLOCKED_BIDI + stream_type.frame_type_bit() + } + Self::NewConnectionId { .. } => FRAME_TYPE_NEW_CONNECTION_ID, + Self::RetireConnectionId { .. 
} => FRAME_TYPE_RETIRE_CONNECTION_ID, + Self::PathChallenge { .. } => FRAME_TYPE_PATH_CHALLENGE, + Self::PathResponse { .. } => FRAME_TYPE_PATH_RESPONSE, + Self::ConnectionClose { error_code, .. } => { + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT + error_code.frame_type_bit() + } + Self::HandshakeDone => FRAME_TYPE_HANDSHAKE_DONE, + } + } + + pub fn stream_type(fin: bool, nonzero_offset: bool, fill: bool) -> u64 { + let mut t = FRAME_TYPE_STREAM; + if fin { + t |= STREAM_FRAME_BIT_FIN; + } + if nonzero_offset { + t |= STREAM_FRAME_BIT_OFF; + } + if !fill { + t |= STREAM_FRAME_BIT_LEN; + } + t + } + + /// Convert a CONNECTION_CLOSE into a nicer CONNECTION_CLOSE. + pub fn sanitize_close(&self) -> &Self { + if let Self::ConnectionClose { error_code, .. } = &self { + if let CloseError::Application(_) = error_code { + FRAME_APPLICATION_CLOSE + } else { + self + } + } else { + panic!("Attempted to sanitize a non-close frame"); + } + } + + pub fn ack_eliciting(&self) -> bool { + !matches!(self, Self::Ack { .. } | Self::Padding | Self::ConnectionClose { .. }) + } + + /// Converts AckRanges as encoded in a ACK frame (see -transport + /// 19.3.1) into ranges of acked packets (end, start), inclusive of + /// start and end values. 
+ pub fn decode_ack_frame( + largest_acked: u64, + first_ack_range: u64, + ack_ranges: &[AckRange], + ) -> Res<Vec<RangeInclusive<u64>>> { + let mut acked_ranges = Vec::with_capacity(ack_ranges.len() + 1); + + if largest_acked < first_ack_range { + return Err(Error::FrameEncodingError); + } + acked_ranges.push((largest_acked - first_ack_range)..=largest_acked); + if !ack_ranges.is_empty() && largest_acked < first_ack_range + 1 { + return Err(Error::FrameEncodingError); + } + let mut cur = if ack_ranges.is_empty() { + 0 + } else { + largest_acked - first_ack_range - 1 + }; + for r in ack_ranges { + if cur < r.gap + 1 { + return Err(Error::FrameEncodingError); + } + cur = cur - r.gap - 1; + + if cur < r.range { + return Err(Error::FrameEncodingError); + } + acked_ranges.push((cur - r.range)..=cur); + + if cur > r.range + 1 { + cur -= r.range + 1; + } else { + cur -= r.range; + } + } + + Ok(acked_ranges) + } + + pub fn dump(&self) -> Option<String> { + match self { + Self::Crypto { offset, data } => Some(format!( + "Crypto {{ offset: {}, len: {} }}", + offset, + data.len() + )), + Self::Stream { + stream_id, + offset, + fill, + data, + fin, + } => Some(format!( + "Stream {{ stream_id: {}, offset: {}, len: {}{}, fin: {} }}", + stream_id.as_u64(), + offset, + if *fill { ">>" } else { "" }, + data.len(), + fin, + )), + Self::Padding => None, + _ => Some(format!("{:?}", self)), + } + } + + pub fn is_allowed(&self, pt: PacketType) -> bool { + match self { + Self::Padding | Self::Ping => true, + Self::Crypto { .. } + | Self::Ack { .. } + | Self::ConnectionClose { + error_code: CloseError::Transport(_), + .. + } => pt != PacketType::ZeroRtt, + Self::NewToken { .. } | Self::ConnectionClose { .. 
} => pt == PacketType::Short, + _ => pt == PacketType::ZeroRtt || pt == PacketType::Short, + } + } + + pub fn decode(dec: &mut Decoder<'a>) -> Res<Self> { + fn d<T>(v: Option<T>) -> Res<T> { + v.ok_or(Error::NoMoreData) + } + fn dv(dec: &mut Decoder) -> Res<u64> { + d(dec.decode_varint()) + } + + // TODO(ekr@rtfm.com): check for minimal encoding + let t = d(dec.decode_varint())?; + match t { + FRAME_TYPE_PADDING => Ok(Self::Padding), + FRAME_TYPE_PING => Ok(Self::Ping), + FRAME_TYPE_RST_STREAM => Ok(Self::ResetStream { + stream_id: StreamId::from(dv(dec)?), + application_error_code: d(dec.decode_varint())?, + final_size: match dec.decode_varint() { + Some(v) => v, + _ => return Err(Error::NoMoreData), + }, + }), + FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => { + let la = dv(dec)?; + let ad = dv(dec)?; + let nr = dv(dec)?; + let fa = dv(dec)?; + let mut arr: Vec<AckRange> = Vec::with_capacity(nr as usize); + for _ in 0..nr { + let ar = AckRange { + gap: dv(dec)?, + range: dv(dec)?, + }; + arr.push(ar); + } + + // Now check for the values for ACK_ECN. + if t == FRAME_TYPE_ACK_ECN { + dv(dec)?; + dv(dec)?; + dv(dec)?; + } + + Ok(Self::Ack { + largest_acknowledged: la, + ack_delay: ad, + first_ack_range: fa, + ack_ranges: arr, + }) + } + FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending { + stream_id: StreamId::from(dv(dec)?), + application_error_code: d(dec.decode_varint())?, + }), + FRAME_TYPE_CRYPTO => { + let offset = dv(dec)?; + let data = d(dec.decode_vvec())?; + if offset + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) { + return Err(Error::FrameEncodingError); + } + Ok(Self::Crypto { offset, data }) + } + FRAME_TYPE_NEW_TOKEN => { + let token = d(dec.decode_vvec())?; + if token.is_empty() { + return Err(Error::FrameEncodingError); + } + Ok(Self::NewToken { token }) + } + FRAME_TYPE_STREAM..=FRAME_TYPE_STREAM_MAX => { + let s = dv(dec)?; + let o = if t & STREAM_FRAME_BIT_OFF == 0 { + 0 + } else { + dv(dec)? 
+ }; + let fill = (t & STREAM_FRAME_BIT_LEN) == 0; + let data = if fill { + qtrace!("STREAM frame, extends to the end of the packet"); + dec.decode_remainder() + } else { + qtrace!("STREAM frame, with length"); + d(dec.decode_vvec())? + }; + if o + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) { + return Err(Error::FrameEncodingError); + } + Ok(Self::Stream { + fin: (t & STREAM_FRAME_BIT_FIN) != 0, + stream_id: StreamId::from(s), + offset: o, + data, + fill, + }) + } + FRAME_TYPE_MAX_DATA => Ok(Self::MaxData { + maximum_data: dv(dec)?, + }), + FRAME_TYPE_MAX_STREAM_DATA => Ok(Self::MaxStreamData { + stream_id: StreamId::from(dv(dec)?), + maximum_stream_data: dv(dec)?, + }), + FRAME_TYPE_MAX_STREAMS_BIDI | FRAME_TYPE_MAX_STREAMS_UNIDI => { + let m = dv(dec)?; + if m > (1 << 60) { + return Err(Error::StreamLimitError); + } + Ok(Self::MaxStreams { + stream_type: StreamType::from_type_bit(t), + maximum_streams: StreamIndex::new(m), + }) + } + FRAME_TYPE_DATA_BLOCKED => Ok(Self::DataBlocked { + data_limit: dv(dec)?, + }), + FRAME_TYPE_STREAM_DATA_BLOCKED => Ok(Self::StreamDataBlocked { + stream_id: dv(dec)?.into(), + stream_data_limit: dv(dec)?, + }), + FRAME_TYPE_STREAMS_BLOCKED_BIDI | FRAME_TYPE_STREAMS_BLOCKED_UNIDI => { + Ok(Self::StreamsBlocked { + stream_type: StreamType::from_type_bit(t), + stream_limit: StreamIndex::new(dv(dec)?), + }) + } + FRAME_TYPE_NEW_CONNECTION_ID => { + let sequence_number = dv(dec)?; + let retire_prior = dv(dec)?; + let connection_id = d(dec.decode_vec(1))?; + if connection_id.len() > MAX_CONNECTION_ID_LEN { + return Err(Error::DecodingFrame); + } + let srt = d(dec.decode(16))?; + let stateless_reset_token = <&[_; 16]>::try_from(srt).unwrap(); + + Ok(Self::NewConnectionId { + sequence_number, + retire_prior, + connection_id, + stateless_reset_token, + }) + } + FRAME_TYPE_RETIRE_CONNECTION_ID => Ok(Self::RetireConnectionId { + sequence_number: dv(dec)?, + }), + FRAME_TYPE_PATH_CHALLENGE => { + let data = d(dec.decode(8))?; + let 
mut datav: [u8; 8] = [0; 8]; + datav.copy_from_slice(&data); + Ok(Self::PathChallenge { data: datav }) + } + FRAME_TYPE_PATH_RESPONSE => { + let data = d(dec.decode(8))?; + let mut datav: [u8; 8] = [0; 8]; + datav.copy_from_slice(&data); + Ok(Self::PathResponse { data: datav }) + } + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT | FRAME_TYPE_CONNECTION_CLOSE_APPLICATION => { + let error_code = CloseError::from_type_bit(t, d(dec.decode_varint())?); + let frame_type = if t == FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT { + dv(dec)? + } else { + 0 + }; + // We can tolerate this copy for now. + let reason_phrase = d(dec.decode_vvec())?.to_vec(); + Ok(Self::ConnectionClose { + error_code, + frame_type, + reason_phrase, + }) + } + FRAME_TYPE_HANDSHAKE_DONE => Ok(Self::HandshakeDone), + _ => Err(Error::UnknownFrameType), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use neqo_common::{Decoder, Encoder}; + + fn just_dec(f: &Frame, s: &str) { + let encoded = Encoder::from_hex(s); + let decoded = Frame::decode(&mut encoded.as_decoder()).unwrap(); + assert_eq!(*f, decoded); + } + + #[test] + fn padding() { + let f = Frame::Padding; + just_dec(&f, "00"); + } + + #[test] + fn ping() { + let f = Frame::Ping; + just_dec(&f, "01"); + } + + #[test] + fn ack() { + let ar = vec![AckRange { gap: 1, range: 2 }, AckRange { gap: 3, range: 4 }]; + + let f = Frame::Ack { + largest_acknowledged: 0x1234, + ack_delay: 0x1235, + first_ack_range: 0x1236, + ack_ranges: ar, + }; + + just_dec(&f, "025234523502523601020304"); + + // Try to parse ACK_ECN without ECN values + let enc = Encoder::from_hex("035234523502523601020304"); + let mut dec = enc.as_decoder(); + assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData); + + // Try to parse ACK_ECN without ECN values + let enc = Encoder::from_hex("035234523502523601020304010203"); + let mut dec = enc.as_decoder(); + assert_eq!(Frame::decode(&mut dec).unwrap(), f); + } + + #[test] + fn reset_stream() { + let f = Frame::ResetStream { + 
/// A NEW_TOKEN frame (type 0x07) with a zero-length token must be rejected.
#[test]
fn empty_new_token() {
    let enc = Encoder::from_hex("0700");
    let res = Frame::decode(&mut enc.as_decoder());
    assert_eq!(res.unwrap_err(), Error::FrameEncodingError);
}
/// MAX_STREAMS decodes for both stream types (0x12 bidirectional,
/// 0x13 unidirectional) with the same varint-encoded limit.
#[test]
fn max_streams() {
    let check = |stream_type, hex| {
        let f = Frame::MaxStreams {
            stream_type,
            maximum_streams: StreamIndex::new(0x1234),
        };
        just_dec(&f, hex);
    };

    check(StreamType::BiDi, "125234");
    check(StreamType::UniDi, "135234");
}
/// Largest acknowledged 7 with a first range of 2, then one more range
/// with no extra gap that reaches down to packet number zero.
#[test]
fn decode_ack_frame() {
    let ranges = [AckRange { gap: 0, range: 3 }];
    let acked = Frame::decode_ack_frame(7, 2, &ranges).unwrap();
    assert_eq!(acked, vec![5..=7, 0..=3]);
}
/// Errors reported by the transport.
///
/// `Error::code` converts a value into a `TransportError` number.  Most of
/// the variants listed before the "internal errors" marker have their own
/// number there; every variant without an explicit arm in `code` is
/// reported as 1.
#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
#[allow(clippy::pub_enum_variant_names)]
pub enum Error {
    /// Reported by `code` as 0.
    NoError,
    // Has no arm of its own in `code`, so it takes the catch-all value of 1.
    InternalError,
    ConnectionRefused,
    FlowControlError,
    StreamLimitError,
    StreamStateError,
    FinalSizeError,
    FrameEncodingError,
    TransportParameterError,
    ProtocolViolation,
    InvalidToken,
    /// Reported by `code` as `ERROR_APPLICATION_CLOSE`.
    ApplicationError,
    /// Wraps an error from `neqo_crypto`; this is what the
    /// `From<neqo_crypto::Error>` conversion produces.
    CryptoError(neqo_crypto::Error),
    /// Produced by the `From<::qlog::Error>` conversion, which drops the
    /// underlying error.
    QlogError,
    /// Reported by `code` as `0x100` plus the contained value.
    CryptoAlert(u8),

    // All internal errors from here.
    AckedUnsentPacket,
    ConnectionState,
    DecodingFrame,
    DecryptError,
    HandshakeFailed,
    /// Reported by `code` as 0.
    IdleTimeout,
    /// Produced by the `From<std::num::TryFromIntError>` conversion.
    IntegerOverflow,
    InvalidInput,
    InvalidMigration,
    InvalidPacket,
    InvalidResumptionToken,
    InvalidRetry,
    InvalidStreamId,
    KeysDiscarded,
    /// Packet protection keys are exhausted.
    /// Also used when too many key updates have happened.
    // Unlike its neighbours in this section, `code` gives this one its own
    // value: `ERROR_AEAD_LIMIT_REACHED`.
    KeysExhausted,
    /// Packet protection keys aren't available yet for the identified space.
    KeysPending(crypto::CryptoSpace),
    /// An attempt to update keys can be blocked if
    /// a packet sent with the current keys hasn't been acknowledged.
    KeyUpdateBlocked,
    NoMoreData,
    NotConnected,
    PacketNumberOverlap,
    /// Reported by `code` as 0.
    PeerApplicationError(AppError),
    /// Reported by `code` as 0.  `From<CloseError> for ConnectionError`
    /// wraps a transport close code in this variant.
    PeerError(TransportError),
    StatelessReset,
    TooMuchData,
    UnexpectedMessage,
    UnknownFrameType,
    VersionNegotiation,
    WrongRole,
}
+ _ => 1, + } + } +} + +impl From<neqo_crypto::Error> for Error { + fn from(err: neqo_crypto::Error) -> Self { + qinfo!("Crypto operation failed {:?}", err); + Self::CryptoError(err) + } +} + +impl From<::qlog::Error> for Error { + fn from(_err: ::qlog::Error) -> Self { + Self::QlogError + } +} + +impl From<std::num::TryFromIntError> for Error { + fn from(_: std::num::TryFromIntError) -> Self { + Self::IntegerOverflow + } +} + +impl ::std::error::Error for Error { + fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { + match self { + Self::CryptoError(e) => Some(e), + _ => None, + } + } +} + +impl ::std::fmt::Display for Error { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Transport error: {:?}", self) + } +} + +pub type AppError = u64; + +#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] +pub enum ConnectionError { + Transport(Error), + Application(AppError), +} + +impl ConnectionError { + pub fn app_code(&self) -> Option<AppError> { + match self { + Self::Application(e) => Some(*e), + _ => None, + } + } +} + +impl From<CloseError> for ConnectionError { + fn from(err: CloseError) -> Self { + match err { + CloseError::Transport(c) => Self::Transport(Error::PeerError(c)), + CloseError::Application(c) => Self::Application(c), + } + } +} + +pub type Res<T> = std::result::Result<T, Error>; diff --git a/third_party/rust/neqo-transport/src/pace.rs b/third_party/rust/neqo-transport/src/pace.rs new file mode 100644 index 0000000000..624fb8622e --- /dev/null +++ b/third_party/rust/neqo-transport/src/pace.rs @@ -0,0 +1,138 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// Pacer +#![deny(clippy::pedantic)] + +use neqo_common::qtrace; + +use std::cmp::min; +use std::convert::TryFrom; +use std::fmt::{Debug, Display}; +use std::time::{Duration, Instant}; + +/// This value determines how much faster the pacer operates than the +/// congestion window. +/// +/// A value of 1 would cause all packets to be spaced over the entire RTT, +/// which is a little slow and might act as an additional restriction in +/// the case the congestion controller increases the congestion window. +/// This value spaces packets over half the congestion window, which matches +/// our current congestion controller, which double the window every RTT. +const PACER_SPEEDUP: usize = 2; + +/// A pacer that uses a leaky bucket. +pub struct Pacer { + /// The last update time. + t: Instant, + /// The maximum capacity, or burst size, in bytes. + m: usize, + /// The current capacity, in bytes. + c: usize, + /// The packet size or minimum capacity for sending, in bytes. + p: usize, +} + +impl Pacer { + /// Create a new `Pacer`. This takes the current time and the + /// initial congestion window. + /// + /// The value of `m` is the maximum capacity. `m` primes the pacer + /// with credit and determines the burst size. `m` must not exceed + /// the initial congestion window, but it should probably be lower. + /// + /// The value of `p` is the packet size, which determines the minimum + /// credit needed before a packet is sent. This should be a substantial + /// fraction of the maximum packet size, if not the packet size. + pub fn new(now: Instant, m: usize, p: usize) -> Self { + assert!(m >= p, "maximum capacity has to be at least one packet"); + Self { t: now, m, c: m, p } + } + + /// Determine when the next packet will be available based on the provided RTT + /// and congestion window. This doesn't update state. + /// This returns a time, which could be in the past (this object doesn't know what + /// the current time is). 
+ pub fn next(&self, rtt: Duration, cwnd: usize) -> Instant { + if self.c >= self.p { + qtrace!([self], "next {}/{:?} no wait = {:?}", cwnd, rtt, self.t); + self.t + } else { + // This is the inverse of the function in `spend`: + // self.t + rtt * (self.p - self.c) / (PACER_SPEEDUP * cwnd) + let r = rtt.as_nanos(); + let d = r.saturating_mul(u128::try_from(self.p - self.c).unwrap()); + let add = d / u128::try_from(cwnd * PACER_SPEEDUP).unwrap(); + let w = u64::try_from(add).map(Duration::from_nanos).unwrap_or(rtt); + let nxt = self.t + w; + qtrace!([self], "next {}/{:?} wait {:?} = {:?}", cwnd, rtt, w, nxt); + nxt + } + } + + /// Spend credit. This cannot fail; users of this API are expected to call + /// next() to determine when to spend. This takes the current time (`now`), + /// an estimate of the round trip time (`rtt`), the estimated congestion + /// window (`cwnd`), and the number of bytes that were sent (`count`). + pub fn spend(&mut self, now: Instant, rtt: Duration, cwnd: usize, count: usize) { + qtrace!([self], "spend {} over {}, {:?}", count, cwnd, rtt); + // Increase the capacity by: + // `(now - self.t) * PACER_SPEEDUP * cwnd / rtt` + // That is, the elapsed fraction of the RTT times rate that data is added. + let incr = now + .saturating_duration_since(self.t) + .as_nanos() + .saturating_mul(u128::try_from(cwnd * PACER_SPEEDUP).unwrap()) + .checked_div(rtt.as_nanos()) + .and_then(|i| usize::try_from(i).ok()) + .unwrap_or(self.m); + + // Add the capacity up to a limit of `self.m`, then subtract `count`. 
+ self.c = min(self.m, (self.c + incr).saturating_sub(count)); + self.t = now; + } +} + +impl Display for Pacer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Pacer {}/{}", self.c, self.p) + } +} + +impl Debug for Pacer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Pacer@{:?} {}/{}..{}", self.t, self.c, self.p, self.m) + } +} + +#[cfg(tests)] +mod tests { + use super::Pacer; + use test_fixture::now; + + const RTT: Duration = Duration::from_millis(1000); + const PACKET: usize = 1000; + const CWND: usize = PACKET * 10; + + #[test] + fn even() { + let mut n = now(); + let p = Pacer::new(n, PACKET, PACKET); + assert_eq!(p.next(RTT, CWND), None); + p.spend(n, RTT, CWND, PACKET); + assert_eq!(p.next(RTT, CWND), Some(n + (RTT / 10))); + } + + #[test] + fn backwards_in_time() { + let mut n = now(); + let p = Pacer::new(n + RTT, PACKET, PACKET); + assert_eq!(p.next(RTT, CWND), None); + // Now spend some credit in the past using a time machine. + p.spend(n, RTT, CWND, PACKET); + assert_eq!(p.next(RTT, CWND), Some(n + (RTT / 10))); + } +} diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs new file mode 100644 index 0000000000..9ab4d811a1 --- /dev/null +++ b/third_party/rust/neqo-transport/src/packet/mod.rs @@ -0,0 +1,1339 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Encoding and decoding packets off the wire. 
+use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN}; +use crate::crypto::{CryptoDxState, CryptoSpace, CryptoStates}; +use crate::{Error, Res}; + +use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder}; +use neqo_crypto::random; + +use std::cmp::min; +use std::convert::TryFrom; +use std::fmt; +use std::iter::ExactSizeIterator; +use std::ops::{Deref, DerefMut, Range}; +use std::time::Instant; + +const PACKET_TYPE_INITIAL: u8 = 0x0; +const PACKET_TYPE_0RTT: u8 = 0x01; +const PACKET_TYPE_HANDSHAKE: u8 = 0x2; +const PACKET_TYPE_RETRY: u8 = 0x03; + +pub const PACKET_BIT_LONG: u8 = 0x80; +const PACKET_BIT_SHORT: u8 = 0x00; +const PACKET_BIT_FIXED_QUIC: u8 = 0x40; +const PACKET_BIT_SPIN: u8 = 0x20; +const PACKET_BIT_KEY_PHASE: u8 = 0x04; + +const PACKET_HP_MASK_LONG: u8 = 0x0f; +const PACKET_HP_MASK_SHORT: u8 = 0x1f; + +const SAMPLE_SIZE: usize = 16; +const SAMPLE_OFFSET: usize = 4; +const MAX_PACKET_NUMBER_LEN: usize = 4; + +mod retry; + +pub type PacketNumber = u64; +type Version = u32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + VersionNegotiation, + Initial, + Handshake, + ZeroRtt, + Retry, + Short, + OtherVersion, +} + +impl PacketType { + #[must_use] + fn code(self) -> u8 { + match self { + Self::Initial => PACKET_TYPE_INITIAL, + Self::ZeroRtt => PACKET_TYPE_0RTT, + Self::Handshake => PACKET_TYPE_HANDSHAKE, + Self::Retry => PACKET_TYPE_RETRY, + _ => panic!("shouldn't be here"), + } + } +} + +impl Into<CryptoSpace> for PacketType { + fn into(self) -> CryptoSpace { + match self { + Self::Initial => CryptoSpace::Initial, + Self::ZeroRtt => CryptoSpace::ZeroRtt, + Self::Handshake => CryptoSpace::Handshake, + Self::Short => CryptoSpace::ApplicationData, + _ => panic!("shouldn't be here"), + } + } +} + +impl From<CryptoSpace> for PacketType { + fn from(cs: CryptoSpace) -> Self { + match cs { + CryptoSpace::Initial => Self::Initial, + CryptoSpace::ZeroRtt => Self::ZeroRtt, + 
CryptoSpace::Handshake => Self::Handshake, + CryptoSpace::ApplicationData => Self::Short, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum QuicVersion { + Draft27, + Draft28, + Draft29, + Draft30, + Draft31, + Draft32, +} + +impl QuicVersion { + pub fn as_u32(self) -> Version { + match self { + Self::Draft27 => 0xff00_0000 + 27, + Self::Draft28 => 0xff00_0000 + 28, + Self::Draft29 => 0xff00_0000 + 29, + Self::Draft30 => 0xff00_0000 + 30, + Self::Draft31 => 0xff00_0000 + 31, + Self::Draft32 => 0xff00_0000 + 32, + } + } +} + +impl Default for QuicVersion { + fn default() -> Self { + Self::Draft29 + } +} + +impl TryFrom<Version> for QuicVersion { + type Error = Error; + + fn try_from(ver: Version) -> Res<Self> { + if ver == 0xff00_0000 + 27 { + Ok(Self::Draft27) + } else if ver == 0xff00_0000 + 28 { + Ok(Self::Draft28) + } else if ver == 0xff00_0000 + 29 { + Ok(Self::Draft29) + } else if ver == 0xff00_0000 + 30 { + Ok(Self::Draft30) + } else if ver == 0xff00_0000 + 31 { + Ok(Self::Draft31) + } else if ver == 0xff00_0000 + 32 { + Ok(Self::Draft32) + } else { + Err(Error::VersionNegotiation) + } + } +} + +struct PacketBuilderOffsets { + /// The bits of the first octet that need masking. + first_byte_mask: u8, + /// The offset of the length field. + len: usize, + /// The location of the packet number field. + pn: Range<usize>, +} + +/// A packet builder that can be used to produce short packets and long packets. +/// This does not produce Retry or Version Negotiation. +pub struct PacketBuilder { + encoder: Encoder, + pn: PacketNumber, + header: Range<usize>, + offsets: PacketBuilderOffsets, + limit: usize, +} + +impl PacketBuilder { + fn infer_limit(encoder: &Encoder) -> usize { + if encoder.capacity() > 64 { + encoder.capacity() + } else { + 2048 + } + } + + /// Start building a short header packet. + #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45. 
+ #[allow(clippy::reversed_empty_ranges)] + pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self { + let header_start = encoder.len(); + encoder.encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); + encoder.encode(dcid.as_ref()); + let limit = Self::infer_limit(&encoder); + Self { + encoder, + pn: u64::max_value(), + header: header_start..header_start, + offsets: PacketBuilderOffsets { + first_byte_mask: PACKET_HP_MASK_SHORT, + pn: 0..0, + len: 0, + }, + limit, + } + } + + /// Start building a long header packet. + /// For an Initial packet you will need to call initial_token(), + /// even if the token is empty. + #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45. + #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range. + pub fn long( + mut encoder: Encoder, + pt: PacketType, + quic_version: QuicVersion, + dcid: impl AsRef<[u8]>, + scid: impl AsRef<[u8]>, + ) -> Self { + let header_start = encoder.len(); + encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.code() << 4); + encoder.encode_uint(4, quic_version.as_u32()); + encoder.encode_vec(1, dcid.as_ref()); + encoder.encode_vec(1, scid.as_ref()); + let limit = Self::infer_limit(&encoder); + Self { + encoder, + pn: u64::max_value(), + header: header_start..header_start, + offsets: PacketBuilderOffsets { + first_byte_mask: PACKET_HP_MASK_LONG, + pn: 0..0, + len: 0, + }, + limit, + } + } + + fn is_long(&self) -> bool { + self[self.header.start] & 0x80 == PACKET_BIT_LONG + } + + /// This stores a value that can be used as a limit. This does not cause + /// this limit to be enforced until encryption occurs. Prior to that, it + /// is only used voluntarily by users of the builder, through `remaining()`. + pub fn set_limit(&mut self, limit: usize) { + self.limit = limit; + } + + /// How many bytes remain against the size limit for the builder. 
+ #[must_use] + pub fn remaining(&self) -> usize { + self.limit - self.encoder.len() + } + + /// Pad with "PADDING" frames. + pub fn pad(&mut self) { + self.encoder.pad_to(self.limit, 0); + } + + /// Add unpredictable values for unprotected parts of the packet. + pub fn scramble(&mut self, quic_bit: bool) { + let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 } + | if self.is_long() { 0 } else { PACKET_BIT_SPIN }; + let first = self.header.start; + self[first] ^= random(1)[0] & mask; + } + + /// For an Initial packet, encode the token. + /// If you fail to do this, then you will not get a valid packet. + pub fn initial_token(&mut self, token: &[u8]) { + debug_assert_eq!( + self.encoder[self.header.start] & 0xb0, + PACKET_BIT_LONG | PACKET_TYPE_INITIAL << 4 + ); + self.encoder.encode_vvec(token); + } + + /// Add a packet number of the given size. + /// For a long header packet, this also inserts a dummy length. + /// The length is filled in after calling `build`. + pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) { + // Reserve space for a length in long headers. + if self.is_long() { + self.offsets.len = self.encoder.len(); + self.encoder.encode(&[0; 2]); + } + + // This allows the input to be >4, which is absurd, but we can eat that. + let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len); + debug_assert_ne!(pn_len, 0); + // Encode the packet number and save its offset. + let pn_offset = self.encoder.len(); + self.encoder.encode_uint(pn_len, pn); + self.offsets.pn = pn_offset..self.encoder.len(); + + // Now encode the packet number length and save the header length. 
/// Build the packet and return the encoder.
///
/// Consumes the builder.  The payload is padded if needed, the length
/// field is filled in when one was reserved, the payload is encrypted
/// with `crypto`, and a mask derived from a sample of the ciphertext is
/// applied to the first header byte and to the packet number bytes.
///
/// # Errors
/// `Error::InternalError` if the contents already exceed the configured
/// limit; otherwise whatever `crypto.encrypt` or `crypto.compute_mask`
/// return.
///
/// # Panics
/// If the ciphertext is too short to take the mask sample from.
pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> {
    if self.len() > self.limit {
        qwarn!("Packet contents are more than the limit");
        // Fail loudly in debug builds; release builds just report the error.
        debug_assert!(false);
        return Err(Error::InternalError);
    }

    self.pad_for_crypto(crypto);
    // `offsets.len` starts at 0 and is only set by `pn()` for long headers,
    // where a two-byte length placeholder was reserved.
    if self.offsets.len > 0 {
        self.write_len(crypto.expansion());
    }

    // `header` covers everything up to and including the packet number;
    // the rest of the encoder is the payload to protect.
    let hdr = &self.encoder[self.header.clone()];
    let body = &self.encoder[self.header.end..];
    qtrace!(
        "Packet build pn={} hdr={} body={}",
        self.pn,
        hex(hdr),
        hex(body)
    );
    let ciphertext = crypto.encrypt(self.pn, hdr, body)?;

    // Calculate the mask.
    // The sample position is `SAMPLE_OFFSET` bytes from the start of the
    // packet number; the ciphertext starts right after the packet number,
    // so the packet number length is subtracted.
    let offset = SAMPLE_OFFSET - self.offsets.pn.len();
    assert!(offset + SAMPLE_SIZE <= ciphertext.len());
    let sample = &ciphertext[offset..offset + SAMPLE_SIZE];
    let mask = crypto.compute_mask(sample)?;

    // Apply the mask.
    // mask[0] covers the first byte (limited to the maskable bits for this
    // header form); mask[1..] covers the packet number bytes in order.
    self.encoder[self.header.start] ^= mask[0] & self.offsets.first_byte_mask;
    for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) {
        self.encoder[j] ^= mask[i];
    }

    // Finally, cut off the plaintext and add back the ciphertext.
    self.encoder.truncate(self.header.end);
    self.encoder.encode(&ciphertext);
    qtrace!("Packet built {}", hex(&self.encoder));
    Ok(self.encoder)
}
+ encoder.encode_vec(1, dcid); + encoder.encode_vec(1, scid); + encoder.encode_uint(4, QuicVersion::Draft27.as_u32()); + encoder.encode_uint(4, QuicVersion::Draft28.as_u32()); + encoder.encode_uint(4, QuicVersion::Draft29.as_u32()); + encoder.encode_uint(4, QuicVersion::Draft30.as_u32()); + encoder.encode_uint(4, QuicVersion::Draft31.as_u32()); + encoder.encode_uint(4, QuicVersion::Draft32.as_u32()); + // Add a greased version, using the randomness already generated. + for g in &mut grease[..4] { + *g = *g & 0xf0 | 0x0a; + } + encoder.encode(&grease[0..4]); + encoder.into() + } +} + +impl Deref for PacketBuilder { + type Target = Encoder; + + fn deref(&self) -> &Self::Target { + &self.encoder + } +} + +impl DerefMut for PacketBuilder { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.encoder + } +} + +impl Into<Encoder> for PacketBuilder { + fn into(self) -> Encoder { + self.encoder + } +} + +/// PublicPacket holds information from packets that is public only. This allows for +/// processing of packets prior to decryption. +pub struct PublicPacket<'a> { + /// The packet type. + packet_type: PacketType, + /// The recovered destination connection ID. + dcid: ConnectionIdRef<'a>, + /// The source connection ID, if this is a long header packet. + scid: Option<ConnectionIdRef<'a>>, + /// Any token that is included in the packet (Retry always has a token; Initial sometimes does). + /// This is empty when there is no token. + token: &'a [u8], + /// The size of the header, not including the packet number. + header_len: usize, + /// Protocol version, if present in header. + quic_version: Option<QuicVersion>, + /// A reference to the entire packet, including the header. + data: &'a [u8], +} + +impl<'a> PublicPacket<'a> { + fn opt<T>(v: Option<T>) -> Res<T> { + if let Some(v) = v { + Ok(v) + } else { + Err(Error::NoMoreData) + } + } + + /// Decode the type-specific portions of a long header. + /// This includes reading the length and the remainder of the packet. 
+ /// Returns a tuple of any token and the length of the header. + fn decode_long( + decoder: &mut Decoder<'a>, + packet_type: PacketType, + quic_version: QuicVersion, + ) -> Res<(&'a [u8], usize)> { + if packet_type == PacketType::Retry { + let header_len = decoder.offset(); + let expansion = retry::expansion(quic_version); + let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?; + if token.is_empty() { + return Err(Error::InvalidPacket); + } + Self::opt(decoder.decode(expansion))?; + return Ok((token, header_len)); + } + let token = if packet_type == PacketType::Initial { + Self::opt(decoder.decode_vvec())? + } else { + &[] + }; + let len = Self::opt(decoder.decode_varint())?; + let header_len = decoder.offset(); + let _body = Self::opt(decoder.decode(usize::try_from(len)?))?; + Ok((token, header_len)) + } + + /// Decode the common parts of a packet. This provides minimal parsing and validation. + /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram. + pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> { + let mut decoder = Decoder::new(data); + let first = Self::opt(decoder.decode_byte())?; + + if first & 0x80 == PACKET_BIT_SHORT { + let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?; + if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { + return Err(Error::InvalidPacket); + } + let header_len = decoder.offset(); + return Ok(( + Self { + packet_type: PacketType::Short, + dcid, + scid: None, + token: &[], + header_len, + quic_version: None, + data, + }, + &[], + )); + } + + // Generic long header. + let version = Version::try_from(Self::opt(decoder.decode_uint(4))?).unwrap(); + let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + + // Version negotiation. 
+ if version == 0 { + return Ok(( + Self { + packet_type: PacketType::VersionNegotiation, + dcid, + scid: Some(scid), + token: &[], + header_len: decoder.offset(), + quic_version: None, + data, + }, + &[], + )); + } + + // Check that this is a long header from a supported version. + let quic_version = if let Ok(v) = QuicVersion::try_from(version) { + v + } else { + return Ok(( + Self { + packet_type: PacketType::OtherVersion, + dcid, + scid: Some(scid), + token: &[], + header_len: decoder.offset(), + quic_version: None, + data, + }, + &[], + )); + }; + + if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN { + return Err(Error::InvalidPacket); + } + let packet_type = match (first >> 4) & 3 { + PACKET_TYPE_INITIAL => PacketType::Initial, + PACKET_TYPE_0RTT => PacketType::ZeroRtt, + PACKET_TYPE_HANDSHAKE => PacketType::Handshake, + PACKET_TYPE_RETRY => PacketType::Retry, + _ => unreachable!(), + }; + + // The type-specific code includes a token. This consumes the remainder of the packet. + let (token, header_len) = Self::decode_long(&mut decoder, packet_type, quic_version)?; + let end = data.len() - decoder.remaining(); + let (data, remainder) = data.split_at(end); + Ok(( + Self { + packet_type, + dcid, + scid: Some(scid), + token, + header_len, + quic_version: Some(quic_version), + data, + }, + remainder, + )) + } + + /// Validate the given packet as though it were a retry. 
+ pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool { + if self.packet_type != PacketType::Retry { + return false; + } + let version = self.quic_version.unwrap(); + let expansion = retry::expansion(version); + if self.data.len() <= expansion { + return false; + } + let (header, tag) = self.data.split_at(self.data.len() - expansion); + let mut encoder = Encoder::with_capacity(self.data.len()); + encoder.encode_vec(1, odcid); + encoder.encode(header); + retry::use_aead(version, |aead| { + let mut buf = vec![0; expansion]; + Ok(aead.decrypt(0, &encoder, tag, &mut buf)?.is_empty()) + }) + .unwrap_or(false) + } + + pub fn is_valid_initial(&self) -> bool { + // Packet has to be an initial, with a DCID of 8 bytes, or a token. + // Note: the Server class validates the token and checks the length. + self.packet_type == PacketType::Initial + && (self.dcid().len() >= 8 || !self.token.is_empty()) + } + + pub fn packet_type(&self) -> PacketType { + self.packet_type + } + + pub fn dcid(&self) -> &ConnectionIdRef<'a> { + &self.dcid + } + + pub fn scid(&self) -> &ConnectionIdRef<'a> { + self.scid + .as_ref() + .expect("should only be called for long header packets") + } + + pub fn token(&self) -> &'a [u8] { + self.token + } + + pub fn version(&self) -> Option<QuicVersion> { + self.quic_version + } + + pub fn len(&self) -> usize { + self.data.len() + } + + fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber { + let window = 1_u64 << (w * 8); + let candidate = (expected & !(window - 1)) | pn; + if candidate + (window / 2) <= expected { + candidate + window + } else if candidate > expected + (window / 2) { + match candidate.checked_sub(window) { + Some(pn_sub) => pn_sub, + None => candidate, + } + } else { + candidate + } + } + + /// Decrypt the header of the packet. 
+ fn decrypt_header( + &self, + crypto: &mut CryptoDxState, + ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> { + assert_ne!(self.packet_type, PacketType::Retry); + assert_ne!(self.packet_type, PacketType::VersionNegotiation); + + qtrace!( + "unmask hdr={}", + hex(&self.data[..self.header_len + SAMPLE_OFFSET]) + ); + + let sample_offset = self.header_len + SAMPLE_OFFSET; + let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE)) + { + crypto.compute_mask(sample) + } else { + Err(Error::NoMoreData) + }?; + + // Un-mask the leading byte. + let bits = if self.packet_type == PacketType::Short { + PACKET_HP_MASK_SHORT + } else { + PACKET_HP_MASK_LONG + }; + let first_byte = self.data[0] ^ (mask[0] & bits); + + // Make a copy of the header to work on. + let mut hdrbytes = self.data[..self.header_len + 4].to_vec(); + hdrbytes[0] = first_byte; + + // Unmask the PN. + let mut pn_encoded: u64 = 0; + for i in 0..MAX_PACKET_NUMBER_LEN { + hdrbytes[self.header_len + i] ^= mask[1 + i]; + pn_encoded <<= 8; + pn_encoded += u64::from(hdrbytes[self.header_len + i]); + } + + // Now decode the packet number length and apply it, hopefully in constant time. + let pn_len = usize::from((first_byte & 0x3) + 1); + hdrbytes.truncate(self.header_len + pn_len); + pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len); + + qtrace!("unmasked hdr={}", hex(&hdrbytes)); + + let key_phase = self.packet_type == PacketType::Short + && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE; + let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len); + Ok(( + key_phase, + pn, + hdrbytes, + &self.data[self.header_len + pn_len..], + )) + } + + pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> { + let cspace: CryptoSpace = self.packet_type.into(); + // This has to work in two stages because we need to remove header protection + // before picking the keys to use. 
+ if let Some(rx) = crypto.rx_hp(cspace) { + // Note that this will dump early, which creates a side-channel. + // This is OK in this case because we the only reason this can + // fail is if the cryptographic module is bad or the packet is + // too small (which is public information). + let (key_phase, pn, header, body) = self.decrypt_header(rx)?; + qtrace!([rx], "decoded header: {:?}", header); + let rx = crypto.rx(cspace, key_phase).unwrap(); + let d = rx.decrypt(pn, &header, body)?; + // If this is the first packet ever successfully decrypted + // using `rx`, make sure to initiate a key update. + if rx.needs_update() { + crypto.key_update_received(release_at)?; + } + crypto.check_pn_overlap()?; + Ok(DecryptedPacket { + pt: self.packet_type, + pn, + data: d, + }) + } else if crypto.rx_pending(cspace) { + Err(Error::KeysPending(cspace)) + } else { + qtrace!("keys for {:?} already discarded", cspace); + Err(Error::KeysDiscarded) + } + } + + pub fn supported_versions(&self) -> Res<Vec<Version>> { + assert_eq!(self.packet_type, PacketType::VersionNegotiation); + let mut decoder = Decoder::new(&self.data[self.header_len..]); + let mut res = Vec::new(); + while decoder.remaining() > 0 { + let version = Version::try_from(Self::opt(decoder.decode_uint(4))?)?; + res.push(version); + } + Ok(res) + } +} + +impl fmt::Debug for PublicPacket<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{:?}: {} {}", + self.packet_type(), + hex_with_len(&self.data[..self.header_len]), + hex_with_len(&self.data[self.header_len..]) + ) + } +} + +pub struct DecryptedPacket { + pt: PacketType, + pn: PacketNumber, + data: Vec<u8>, +} + +impl DecryptedPacket { + pub fn packet_type(&self) -> PacketType { + self.pt + } + + pub fn pn(&self) -> PacketNumber { + self.pn + } +} + +impl Deref for DecryptedPacket { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.data[..] 
+ } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::crypto::{CryptoDxState, CryptoStates}; + use crate::{FixedConnectionIdManager, QuicVersion}; + use neqo_common::Encoder; + use test_fixture::{fixture_init, now}; + + const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5]; + + /// This is a connection ID manager, which is only used for decoding short header packets. + fn cid_mgr() -> FixedConnectionIdManager { + FixedConnectionIdManager::new(SERVER_CID.len()) + } + + const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[ + 0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03, + 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd, + 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04, + 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, + 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14, + 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, + 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04, + ]; + const SAMPLE_INITIAL: &[u8] = &[ + 0xc7, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x00, 0x40, 0x75, 0xfb, 0x12, 0xff, 0x07, 0x82, 0x3a, 0x5d, 0x24, 0x53, 0x4d, 0x90, 0x6c, + 0xe4, 0xc7, 0x67, 0x82, 0xa2, 0x16, 0x7e, 0x34, 0x79, 0xc0, 0xf7, 0xf6, 0x39, 0x5d, 0xc2, + 0xc9, 0x16, 0x76, 0x30, 0x2f, 0xe6, 0xd7, 0x0b, 0xb7, 0xcb, 0xeb, 0x11, 0x7b, 0x4d, 0xdb, + 0x7d, 0x17, 0x34, 0x98, 0x44, 0xfd, 0x61, 0xda, 0xe2, 0x00, 0xb8, 0x33, 0x8e, 0x1b, 0x93, + 0x29, 0x76, 0xb6, 0x1d, 0x91, 0xe6, 0x4a, 0x02, 0xe9, 0xe0, 0xee, 0x72, 0xe3, 0xa6, 0xf6, + 0x3a, 0xba, 0x4c, 0xee, 0xee, 0xc5, 0xbe, 0x2f, 0x24, 0xf2, 0xd8, 0x60, 0x27, 0x57, 0x29, + 0x43, 0x53, 0x38, 0x46, 0xca, 0xa1, 0x3e, 0x6f, 0x16, 0x3f, 0xb2, 0x57, 0x47, 
0x3d, 0xcc, + 0xa2, 0x53, 0x96, 0xe8, 0x87, 0x24, 0xf1, 0xe5, 0xd9, 0x64, 0xde, 0xde, 0xe9, 0xb6, 0x33, + ]; + + #[test] + fn sample_server_initial() { + fixture_init(); + let mut prot = CryptoDxState::test_default(); + + // The spec uses PN=1, but our crypto refuses to skip packet numbers. + // So burn an encryption: + let burn = prot.encrypt(0, &[], &[]).expect("burn OK"); + assert_eq!(burn.len(), prot.expansion()); + + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Initial, + QuicVersion::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(SERVER_CID), + ); + builder.initial_token(&[]); + builder.pn(1, 2); + builder.encode(&SAMPLE_INITIAL_PAYLOAD); + let packet = builder.build(&mut prot).expect("build"); + assert_eq!(&packet[..], SAMPLE_INITIAL); + } + + #[test] + fn decrypt_initial() { + const EXTRA: &[u8] = &[0xce; 33]; + + fixture_init(); + let mut padded = SAMPLE_INITIAL.to_vec(); + padded.extend_from_slice(EXTRA); + let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap(); + assert_eq!(packet.packet_type(), PacketType::Initial); + assert_eq!(&packet.dcid()[..], &[] as &[u8]); + assert_eq!(&packet.scid()[..], SERVER_CID); + assert!(packet.token().is_empty()); + assert_eq!(remainder, EXTRA); + + let decrypted = packet + .decrypt(&mut CryptoStates::test_default(), now()) + .unwrap(); + assert_eq!(decrypted.pn(), 1); + } + + #[test] + fn disallow_long_dcid() { + let mut enc = Encoder::new(); + enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC); + enc.encode_uint(4, QuicVersion::default().as_u32()); + enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]); + enc.encode_vec(1, &[]); + enc.encode(&[0xff; 40]); // junk + + assert!(PublicPacket::decode(&enc, &cid_mgr()).is_err()); + } + + #[test] + fn disallow_long_scid() { + let mut enc = Encoder::new(); + enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC); + enc.encode_uint(4, QuicVersion::default().as_u32()); + enc.encode_vec(1, &[]); + 
enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]); + enc.encode(&[0xff; 40]); // junk + + assert!(PublicPacket::decode(&enc, &cid_mgr()).is_err()); + } + + const SAMPLE_SHORT: &[u8] = &[ + 0x55, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x99, 0x9c, 0xbd, 0x77, 0xf5, 0xd7, + 0x0a, 0x28, 0xe8, 0xfb, 0xc3, 0xed, 0xf5, 0x71, 0xb1, 0x04, 0x32, 0x2a, 0xae, 0xae, + ]; + const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3]; + + #[test] + fn build_short() { + fixture_init(); + let mut builder = + PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); + builder.pn(0, 1); + builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. + let packet = builder + .build(&mut CryptoDxState::test_default()) + .expect("build"); + assert_eq!(&packet[..], SAMPLE_SHORT); + } + + #[test] + fn scramble_short() { + fixture_init(); + let mut firsts = Vec::new(); + for _ in 0..64 { + let mut builder = + PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); + builder.scramble(true); + builder.pn(0, 1); + firsts.push(builder[0]); + } + let is_set = |bit| move |v| v & bit == bit; + // There should be at least one value with the QUIC bit set: + assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC))); + // ... but not all: + assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC))); + // There should be at least one value with the spin bit set: + assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN))); + // ... 
but not all: + assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN))); + } + + #[test] + fn decode_short() { + fixture_init(); + let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap(); + assert_eq!(packet.packet_type(), PacketType::Short); + assert!(remainder.is_empty()); + let decrypted = packet + .decrypt(&mut CryptoStates::test_default(), now()) + .unwrap(); + assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD); + } + + /// By telling the decoder that the connection ID is shorter than it really is, we get a decryption error. + #[test] + fn decode_short_bad_cid() { + fixture_init(); + let (packet, remainder) = PublicPacket::decode( + SAMPLE_SHORT, + &FixedConnectionIdManager::new(SERVER_CID.len() - 1), + ) + .unwrap(); + assert_eq!(packet.packet_type(), PacketType::Short); + assert!(remainder.is_empty()); + assert!(packet + .decrypt(&mut CryptoStates::test_default(), now()) + .is_err()); + } + + /// Saying that the connection ID is longer causes the initial decode to fail. + #[test] + fn decode_short_long_cid() { + assert!(PublicPacket::decode( + SAMPLE_SHORT, + &FixedConnectionIdManager::new(SERVER_CID.len() + 1) + ) + .is_err()); + } + + #[test] + fn build_two() { + fixture_init(); + let mut prot = CryptoDxState::test_default(); + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + QuicVersion::default(), + &ConnectionId::from(SERVER_CID), + &ConnectionId::from(CLIENT_CID), + ); + builder.pn(0, 1); + builder.encode(&[0; 3]); + let encoder = builder.build(&mut prot).expect("build"); + assert_eq!(encoder.len(), 45); + let first = encoder.clone(); + + let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID)); + builder.pn(1, 3); + builder.encode(&[0]); // Minimal size (packet number is big enough). 
+ let encoder = builder.build(&mut prot).expect("build"); + assert_eq!( + &first[..], + &encoder[..first.len()], + "the first packet should be a prefix" + ); + assert_eq!(encoder.len(), 45 + 29); + } + + #[test] + fn build_long() { + const EXPECTED: &[u8] = &[ + 0xe5, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x40, 0x14, 0xa8, 0x9d, 0xbf, 0x74, 0x70, + 0x32, 0xda, 0xba, 0xfb, 0x87, 0x61, 0xb8, 0x31, 0x90, 0xf3, 0x25, 0x52, 0x0b, 0xbe, + 0xdb, + ]; + + fixture_init(); + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + QuicVersion::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(&[][..]), + ); + builder.pn(0, 1); + builder.encode(&[1, 2, 3]); + let packet = builder.build(&mut CryptoDxState::test_default()).unwrap(); + assert_eq!(&packet[..], EXPECTED); + } + + #[test] + fn scramble_long() { + fixture_init(); + let mut found_unset = false; + let mut found_set = false; + for _ in 1..64 { + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + QuicVersion::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(&[][..]), + ); + builder.pn(0, 1); + builder.scramble(true); + if (builder[0] & PACKET_BIT_FIXED_QUIC) == 0 { + found_unset = true; + } else { + found_set = true; + } + } + assert!(found_unset); + assert!(found_set); + } + + #[test] + fn build_abort() { + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Initial, + QuicVersion::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(SERVER_CID), + ); + builder.initial_token(&[]); + builder.pn(1, 2); + let encoder = builder.abort(); + assert!(encoder.is_empty()); + } + + const SAMPLE_RETRY_27: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1b, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xa5, 0x23, 0xcb, 0x5b, 0xa5, 0x24, 0x69, 0x5f, 0x65, 0x69, + 0xf2, 0x93, 0xa1, 0x35, 0x9d, 0x8e, + ]; + + const SAMPLE_RETRY_28: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1c, 0x00, 
0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xf7, 0x1a, 0x5f, 0x12, 0xaf, 0xe3, 0xec, 0xf8, 0x00, 0x1a, + 0x92, 0x0e, 0x6f, 0xdf, 0x1d, 0x63, + ]; + + const SAMPLE_RETRY_29: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a, + 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49, + ]; + + const SAMPLE_RETRY_30: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94, + 0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61, + ]; + + const SAMPLE_RETRY_31: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1, + 0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86, + ]; + + const SAMPLE_RETRY_32: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e, + 0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85, + ]; + + const RETRY_TOKEN: &[u8] = b"token"; + + fn build_retry_single(quic_version: QuicVersion, sample_retry: &[u8]) { + fixture_init(); + let retry = + PacketBuilder::retry(quic_version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap(); + + let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap(); + assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); + assert!(remainder.is_empty()); + + // The builder adds randomness, which makes expectations hard. + // So only do a full check when that randomness matches up. + if retry[0] == sample_retry[0] { + assert_eq!(&retry, &sample_retry); + } else { + // Otherwise, just check that the header is OK. 
+ assert_eq!(retry[0] & 0xf0, 0xf0); + let header_range = 1..retry.len() - 16; + assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]); + } + } + + #[test] + fn build_retry_27() { + build_retry_single(QuicVersion::Draft27, SAMPLE_RETRY_27); + } + + #[test] + fn build_retry_28() { + build_retry_single(QuicVersion::Draft28, SAMPLE_RETRY_28); + } + + #[test] + fn build_retry_29() { + build_retry_single(QuicVersion::Draft29, SAMPLE_RETRY_29); + } + + #[test] + fn build_retry_30() { + build_retry_single(QuicVersion::Draft30, SAMPLE_RETRY_30); + } + + #[test] + fn build_retry_31() { + build_retry_single(QuicVersion::Draft31, SAMPLE_RETRY_31); + } + + #[test] + fn build_retry_32() { + build_retry_single(QuicVersion::Draft32, SAMPLE_RETRY_32); + } + + #[test] + fn build_retry_multiple() { + // Run the build_retry test a few times. + // Odds are approximately 1 in 8 that the full comparison doesn't happen + // for a given version. + for _ in 0..32 { + build_retry_27(); + build_retry_28(); + build_retry_29(); + build_retry_30(); + } + } + + fn decode_retry(quic_version: QuicVersion, sample_retry: &[u8]) { + fixture_init(); + let (packet, remainder) = + PublicPacket::decode(sample_retry, &FixedConnectionIdManager::new(5)).unwrap(); + assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); + assert_eq!(Some(quic_version), packet.quic_version); + assert!(packet.dcid().is_empty()); + assert_eq!(&packet.scid()[..], SERVER_CID); + assert_eq!(packet.token(), RETRY_TOKEN); + assert!(remainder.is_empty()); + } + + #[test] + fn decode_retry_27() { + decode_retry(QuicVersion::Draft27, SAMPLE_RETRY_27); + } + + #[test] + fn decode_retry_28() { + decode_retry(QuicVersion::Draft28, SAMPLE_RETRY_28); + } + + #[test] + fn decode_retry_29() { + decode_retry(QuicVersion::Draft29, SAMPLE_RETRY_29); + } + + #[test] + fn decode_retry_30() { + decode_retry(QuicVersion::Draft30, SAMPLE_RETRY_30); + } + + #[test] + fn decode_retry_31() { + decode_retry(QuicVersion::Draft31, 
SAMPLE_RETRY_31); + } + + #[test] + fn decode_retry_32() { + decode_retry(QuicVersion::Draft32, SAMPLE_RETRY_32); + } + + /// Check some packets that are clearly not valid Retry packets. + #[test] + fn invalid_retry() { + fixture_init(); + let cid_mgr = FixedConnectionIdManager::new(5); + let odcid = ConnectionId::from(CLIENT_CID); + + assert!(PublicPacket::decode(&[], &cid_mgr).is_err()); + + let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_28, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(packet.is_valid_retry(&odcid)); + + let mut damaged_retry = SAMPLE_RETRY_28.to_vec(); + let last = damaged_retry.len() - 1; + damaged_retry[last] ^= 66; + let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(!packet.is_valid_retry(&odcid)); + + damaged_retry.truncate(last); + let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(!packet.is_valid_retry(&odcid)); + + // An invalid token should be rejected sooner. + damaged_retry.truncate(last - 4); + assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + + damaged_retry.truncate(last - 1); + assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + } + + const SAMPLE_VN: &[u8] = &[ + 0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08, + 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0xff, 0x00, 0x00, 0x1b, 0xff, 0x00, 0x00, + 0x1c, 0xff, 0x00, 0x00, 0x1d, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00, + 0x00, 0x20, 0x0a, 0x0a, 0x0a, 0x0a, + ]; + + #[test] + fn build_vn() { + fixture_init(); + let mut vn = PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID); + // Erase randomness from greasing... 
+ assert_eq!(vn.len(), SAMPLE_VN.len()); + vn[0] &= 0x80; + for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) { + *v &= 0x0f; + } + assert_eq!(&vn, &SAMPLE_VN); + } + + #[test] + fn parse_vn() { + let (packet, remainder) = + PublicPacket::decode(SAMPLE_VN, &FixedConnectionIdManager::new(5)).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(&packet.dcid[..], SERVER_CID); + assert!(packet.scid.is_some()); + assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID); + } + + /// A Version Negotiation packet can have a long connection ID. + #[test] + fn parse_vn_big_cid() { + const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1]; + const BIG_SCID: &[u8] = &[0xee; 255]; + + let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]); + enc.encode_vec(1, BIG_DCID); + enc.encode_vec(1, BIG_SCID); + enc.encode_uint(4, 0x1a2a_3a4a_u64); + enc.encode_uint(4, QuicVersion::default().as_u32()); + enc.encode_uint(4, 0x5a6a_7a8a_u64); + + let (packet, remainder) = + PublicPacket::decode(&enc, &FixedConnectionIdManager::new(5)).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(&packet.dcid[..], BIG_DCID); + assert!(packet.scid.is_some()); + assert_eq!(&packet.scid.unwrap()[..], BIG_SCID); + } + + #[test] + fn decode_pn() { + // When the expected value is low, the value doesn't go negative. + assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff); + assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100); + assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2); + assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff); + assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe); + + // This is invalid by spec, as we are expected to check for overflow around 2^62-1, + // but we don't need to worry about overflow + // and hitting this is basically impossible in practice. 
+ assert_eq!( + PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4), + 0x4000_0000_0000_0002 + ); + } + + #[test] + fn chacha20_sample() { + const PACKET: &[u8] = &[ + 0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57, + 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb, + ]; + fixture_init(); + let (packet, slice) = + PublicPacket::decode(PACKET, &FixedConnectionIdManager::new(0)).unwrap(); + assert!(slice.is_empty()); + let decrypted = packet + .decrypt(&mut CryptoStates::test_chacha(), now()) + .unwrap(); + assert_eq!(decrypted.packet_type(), PacketType::Short); + assert_eq!(decrypted.pn(), 654_360_564); + assert_eq!(&decrypted[..], &[0x01]); + } +} diff --git a/third_party/rust/neqo-transport/src/packet/retry.rs b/third_party/rust/neqo-transport/src/packet/retry.rs new file mode 100644 index 0000000000..596714aa6d --- /dev/null +++ b/third_party/rust/neqo-transport/src/packet/retry.rs @@ -0,0 +1,63 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +#![deny(clippy::pedantic)] + +use crate::packet::QuicVersion; +use crate::{Error, Res}; + +use neqo_common::qerror; +use neqo_crypto::{aead::Aead, hkdf, TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3}; + +use std::cell::RefCell; + +const RETRY_SECRET_27: &[u8] = &[ + 0x65, 0x6e, 0x61, 0xe3, 0x36, 0xae, 0x94, 0x17, 0xf7, 0xf0, 0xed, 0xd8, 0xd7, 0x8d, 0x46, 0x1e, + 0x2a, 0xa7, 0x08, 0x4a, 0xba, 0x7a, 0x14, 0xc1, 0xe9, 0xf7, 0x26, 0xd5, 0x57, 0x09, 0x16, 0x9a, +]; +const RETRY_SECRET_29: &[u8] = &[ + 0x8b, 0x0d, 0x37, 0xeb, 0x85, 0x35, 0x02, 0x2e, 0xbc, 0x8d, 0x76, 0xa2, 0x07, 0xd8, 0x0d, 0xf2, + 0x26, 0x46, 0xec, 0x06, 0xdc, 0x80, 0x96, 0x42, 0xc3, 0x0a, 0x8b, 0xaa, 0x2b, 0xaa, 0xff, 0x4c, +]; + +/// The AEAD used for Retry is fixed, so use thread local storage. +fn make_aead(secret: &[u8]) -> Aead { + #[cfg(debug_assertions)] + ::neqo_crypto::assert_initialized(); + + let secret = hkdf::import_key(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, secret).unwrap(); + Aead::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic ").unwrap() +} +thread_local!(static RETRY_AEAD_27: RefCell<Aead> = RefCell::new(make_aead(RETRY_SECRET_27))); +thread_local!(static RETRY_AEAD_29: RefCell<Aead> = RefCell::new(make_aead(RETRY_SECRET_29))); + +/// Run a function with the appropriate Retry AEAD. +pub fn use_aead<F, T>(quic_version: QuicVersion, f: F) -> Res<T> +where + F: FnOnce(&Aead) -> Res<T>, +{ + match quic_version { + QuicVersion::Draft27 | QuicVersion::Draft28 => &RETRY_AEAD_27, + QuicVersion::Draft29 + | QuicVersion::Draft30 + | QuicVersion::Draft31 + | QuicVersion::Draft32 => &RETRY_AEAD_29, + } + .try_with(|aead| f(&aead.borrow())) + .map_err(|e| { + qerror!("Unable to access Retry AEAD: {:?}", e); + Error::InternalError + })? +} + +/// Determine how large the expansion is for a given key. 
+pub fn expansion(quic_version: QuicVersion) -> usize { + if let Ok(ex) = use_aead(quic_version, |aead| Ok(aead.expansion())) { + ex + } else { + panic!("Unable to access Retry AEAD") + } +} diff --git a/third_party/rust/neqo-transport/src/path.rs b/third_party/rust/neqo-transport/src/path.rs new file mode 100644 index 0000000000..a4e6b2f361 --- /dev/null +++ b/third_party/rust/neqo-transport/src/path.rs @@ -0,0 +1,109 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::net::SocketAddr; + +use crate::cid::{ConnectionId, ConnectionIdRef}; + +use neqo_common::Datagram; + +/// This is the MTU that we assume when using IPv6. +/// We use this size for Initial packets, so we don't need to worry about probing for support. +/// If the path doesn't support this MTU, we will assume that it doesn't support QUIC. +/// +/// This is a multiple of 16 greater than the largest possible short header (1 + 20 + 4). +pub const PATH_MTU_V6: usize = 1337; +/// The path MTU for IPv4 can be 20 bytes larger than for v6. +pub const PATH_MTU_V4: usize = PATH_MTU_V6 + 20; + +#[derive(Clone, Debug, PartialEq)] +pub struct Path { + local: SocketAddr, + remote: SocketAddr, + local_cids: Vec<ConnectionId>, + remote_cid: ConnectionId, + reset_token: Option<[u8; 16]>, +} + +impl Path { + /// Create a path from addresses and connection IDs. 
+ pub fn new( + local: SocketAddr, + remote: SocketAddr, + local_cid: ConnectionId, + remote_cid: ConnectionId, + ) -> Self { + Self { + local, + remote, + local_cids: vec![local_cid], + remote_cid, + reset_token: None, + } + } + + pub fn received_on(&self, d: &Datagram) -> bool { + self.local == d.destination() && self.remote == d.source() + } + + pub fn mtu(&self) -> usize { + if self.local.is_ipv4() { + PATH_MTU_V4 + } else { + PATH_MTU_V6 // IPv6 + } + } + + /// Add a connection ID to the local set. + pub fn add_local_cid(&mut self, cid: ConnectionId) { + self.local_cids.push(cid); + } + + /// Determine if the given connection ID is valid. + pub fn valid_local_cid(&self, cid: &ConnectionIdRef) -> bool { + self.local_cids.iter().any(|c| c == cid) + } + + /// Get the first local connection ID. + pub fn local_cid(&self) -> &ConnectionId { + self.local_cids.first().as_ref().unwrap() + } + + /// Set the remote connection ID based on the peer's choice. + pub fn set_remote_cid(&mut self, cid: &ConnectionIdRef) { + self.remote_cid = ConnectionId::from(cid); + } + + /// Access the remote connection ID. + pub fn remote_cid(&self) -> &ConnectionId { + &self.remote_cid + } + + /// Set the stateless reset token for the connection ID that is currently in use. + pub fn set_reset_token(&mut self, token: [u8; 16]) { + self.reset_token = Some(token); + } + + /// Access the reset token. + pub fn reset_token(&self) -> Option<&[u8; 16]> { + self.reset_token.as_ref() + } + + /// Make a datagram. 
    pub fn datagram<V: Into<Vec<u8>>>(&self, payload: V) -> Datagram {
        // The datagram is addressed from our local address to the remote address.
        Datagram::new(self.local, self.remote, payload)
    }

    /// Get local address as `SocketAddr`
    pub fn local_address(&self) -> SocketAddr {
        self.local
    }

    /// Get remote address as `SocketAddr`
    pub fn remote_address(&self) -> SocketAddr {
        self.remote
    }
}
// diff --git a/third_party/rust/neqo-transport/src/qlog.rs b/third_party/rust/neqo-transport/src/qlog.rs
// new file mode 100644
// index 0000000000..b4cc1adf67
// --- /dev/null
// +++ b/third_party/rust/neqo-transport/src/qlog.rs
// @@ -0,0 +1,442 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// Functions that handle capturing QLOG traces.

use std::convert::TryFrom;
use std::ops::RangeInclusive;
use std::string::String;
use std::time::Duration;

use qlog::{self, event::Event, PacketHeader, QuicFrame};

use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder};

use crate::connection::State;
use crate::frame::{self, Frame};
use crate::packet::{DecryptedPacket, PacketNumber, PacketType, PublicPacket};
use crate::path::Path;
use crate::tparams::{self, TransportParametersHandler};
use crate::tracking::SentPacket;
use crate::QuicVersion;

/// Log a `transport_parameters_set` event built from the remote transport parameters.
///
/// The arguments to `Event::transport_parameters_set` are positional; only the
/// values read from `tph.remote()` below are filled in, the rest are `None`.
pub fn connection_tparams_set(qlog: &mut NeqoQlog, tph: &TransportParametersHandler) {
    qlog.add_event(|| {
        let remote = tph.remote();
        Some(Event::transport_parameters_set(
            None,
            None,
            None,
            None,
            None,
            None,
            if let Some(ocid) = remote.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID) {
                // Cannot use packet::ConnectionId's Display trait implementation
                // because it does not include the 0x prefix.
                Some(hex(ocid))
            } else {
                None
            },
            if let Some(srt) = remote.get_bytes(tparams::STATELESS_RESET_TOKEN) {
                Some(hex(srt))
            } else {
                None
            },
            // Only reported when the parameter is present; absence is logged as `None`, not `false`.
            if remote.get_empty(tparams::DISABLE_MIGRATION) {
                Some(true)
            } else {
                None
            },
            Some(remote.get_integer(tparams::IDLE_TIMEOUT)),
            Some(remote.get_integer(tparams::MAX_UDP_PAYLOAD_SIZE)),
            Some(remote.get_integer(tparams::ACK_DELAY_EXPONENT)),
            Some(remote.get_integer(tparams::MAX_ACK_DELAY)),
            // TODO(hawkinsw@obs.cr): We do not yet handle ACTIVE_CONNECTION_ID_LIMIT in tparams.
            None,
            Some(format!("{}", remote.get_integer(tparams::INITIAL_MAX_DATA))),
            Some(format!(
                "{}",
                remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL)
            )),
            Some(format!(
                "{}",
                remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)
            )),
            Some(format!(
                "{}",
                remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)
            )),
            Some(format!(
                "{}",
                remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)
            )),
            Some(format!(
                "{}",
                remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI)
            )),
            // TODO(hawkinsw@obs.cr): We do not yet handle PREFERRED_ADDRESS in tparams.
            None,
        ))
    })
}

/// Log a `connection_started` event; currently identical to the client variant.
pub fn server_connection_started(qlog: &mut NeqoQlog, path: &Path) {
    connection_started(qlog, path)
}

/// Log a `connection_started` event; currently identical to the server variant.
pub fn client_connection_started(qlog: &mut NeqoQlog, path: &Path) {
    connection_started(qlog, path)
}

/// Log a `connection_started` event describing `path`: address family,
/// both IP addresses and ports, and both connection IDs.
fn connection_started(qlog: &mut NeqoQlog, path: &Path) {
    qlog.add_event(|| {
        Some(Event::connection_started(
            if path.local_address().ip().is_ipv4() {
                "ipv4".into()
            } else {
                "ipv6".into()
            },
            format!("{}", path.local_address().ip()),
            format!("{}", path.remote_address().ip()),
            Some("QUIC".into()),
            path.local_address().port().into(),
            path.remote_address().port().into(),
            // NOTE(review): this logs the default QUIC version, not a version
            // taken from the connection -- confirm that is intended.
            Some(format!("{:x}", QuicVersion::default().as_u32())),
            Some(format!("{}", path.local_cid())),
            Some(format!("{}", path.remote_cid())),
        ))
    })
}

/// Log a `connection_state_updated` event, mapping the connection `State`
/// onto the coarser `qlog::ConnectionState`.
pub fn connection_state_updated(qlog: &mut NeqoQlog, new: &State) {
    qlog.add_event(|| {
        Some(Event::connection_state_updated_min(match new {
            State::Init => qlog::ConnectionState::Attempted,
            State::WaitInitial => qlog::ConnectionState::Attempted,
            State::Handshaking => qlog::ConnectionState::Handshake,
            State::Connected => qlog::ConnectionState::Active,
            State::Confirmed => qlog::ConnectionState::Active,
            State::Closing { .. } => qlog::ConnectionState::Draining,
            State::Draining { .. } => qlog::ConnectionState::Draining,
            State::Closed { .. } => qlog::ConnectionState::Closed,
        }))
    })
}

/// Log a `packet_sent` event followed by one qlog frame for each frame in `body`.
///
/// `body` is decoded again here purely for logging; decoding stops at the
/// first frame that fails to decode.
///
/// # Panics
/// If `plen` does not fit in a `u64`.
pub fn packet_sent(
    qlog: &mut NeqoQlog,
    pt: PacketType,
    pn: PacketNumber,
    plen: usize,
    body: &[u8],
) {
    qlog.add_event_with_stream(|stream| {
        let mut d = Decoder::from(body);

        stream.add_event(Event::packet_sent_min(
            to_qlog_pkt_type(pt),
            PacketHeader::new(
                pn,
                Some(u64::try_from(plen).unwrap()),
                None,
                None,
                None,
                None,
            ),
            Some(Vec::new()),
        ))?;

        while d.remaining() > 0 {
            match Frame::decode(&mut d) {
                Ok(f) => {
                    stream.add_frame(frame_to_qlogframe(&f), false)?;
                }
                Err(_) => {
                    qinfo!("qlog: invalid frame");
                    break;
                }
            }
        }

        stream.finish_frames()
    })
}

/// Log a `packet_dropped` event carrying the packet type and length.
///
/// # Panics
/// If the packet length does not fit in a `u64`.
pub fn packet_dropped(qlog: &mut NeqoQlog, payload: &PublicPacket) {
    qlog.add_event(|| {
        Some(Event::packet_dropped(
            Some(to_qlog_pkt_type(payload.packet_type())),
            Some(u64::try_from(payload.len()).unwrap()),
            None,
        ))
    })
}

/// Log one `packet_lost` event per packet in `pkts`, each with an empty frame list.
pub fn packets_lost(qlog: &mut NeqoQlog, pkts: &[SentPacket]) {
    qlog.add_event_with_stream(|stream| {
        for pkt in pkts {
            stream.add_event(Event::packet_lost_min(
                to_qlog_pkt_type(pkt.pt),
                pkt.pn.to_string(),
                Vec::new(),
            ))?;

            stream.finish_frames()?;
        }
        Ok(())
    })
}

/// Log a `packet_received` event followed by one qlog frame for each frame
/// in the decrypted payload.
///
/// The logged length is that of `public_packet`; the type and packet number
/// come from `payload`. Decoding stops at the first frame that fails to decode.
///
/// # Panics
/// If the packet length does not fit in a `u64`.
pub fn packet_received(
    qlog: &mut NeqoQlog,
    public_packet: &PublicPacket,
    payload: &DecryptedPacket,
) {
    qlog.add_event_with_stream(|stream| {
        let mut d = Decoder::from(&payload[..]);

        stream.add_event(Event::packet_received(
            to_qlog_pkt_type(payload.packet_type()),
            PacketHeader::new(
                payload.pn(),
                Some(u64::try_from(public_packet.len()).unwrap()),
                None,
                None,
                None,
                None,
            ),
            Some(Vec::new()),
            None,
            None,
            None,
        ))?;

        while d.remaining() > 0 {
            match Frame::decode(&mut d) {
                Ok(f) => stream.add_frame(frame_to_qlogframe(&f), false)?,
                Err(_) => {
                    qinfo!("qlog: invalid frame");
                    break;
                }
            }
        }

        stream.finish_frames()
    })
}

/// A single metric that can be reported through `metrics_updated`.
// Not every variant is constructed, hence the `dead_code` allowance.
#[allow(dead_code)]
pub enum QlogMetric {
    MinRtt(Duration),
    SmoothedRtt(Duration),
    LatestRtt(Duration),
    RttVariance(u64),
    MaxAckDelay(u64),
    PtoCount(usize),
    CongestionWindow(usize),
    BytesInFlight(usize),
    SsThresh(usize),
    PacketsInFlight(u64),
    InRecovery(bool),
    PacingRate(u64),
}

/// Log a `metrics_updated` event containing the supplied metrics.
///
/// Metrics not present in `updated_metrics` are reported as `None`; if a
/// metric appears more than once the last value wins. The RTT metrics are
/// converted to whole milliseconds.
///
/// # Panics
/// If a value does not fit in a `u64`; in debug builds, if `updated_metrics` is empty.
pub fn metrics_updated(qlog: &mut NeqoQlog, updated_metrics: &[QlogMetric]) {
    debug_assert!(!updated_metrics.is_empty());

    qlog.add_event(|| {
        let mut min_rtt: Option<u64> = None;
        let mut smoothed_rtt: Option<u64> = None;
        let mut latest_rtt: Option<u64> = None;
        let mut rtt_variance: Option<u64> = None;
        let mut max_ack_delay: Option<u64> = None;
        let mut pto_count: Option<u64> = None;
        let mut congestion_window: Option<u64> = None;
        let mut bytes_in_flight: Option<u64> = None;
        let mut ssthresh: Option<u64> = None;
        let mut packets_in_flight: Option<u64> = None;
        let mut in_recovery: Option<bool> = None;
        let mut pacing_rate: Option<u64> = None;

        for metric in updated_metrics {
            match metric {
                QlogMetric::MinRtt(v) => min_rtt = Some(u64::try_from(v.as_millis()).unwrap()),
                QlogMetric::SmoothedRtt(v) => {
                    smoothed_rtt = Some(u64::try_from(v.as_millis()).unwrap())
                }
                QlogMetric::LatestRtt(v) => {
                    latest_rtt = Some(u64::try_from(v.as_millis()).unwrap())
                }
                QlogMetric::RttVariance(v) => rtt_variance = Some(*v),
                QlogMetric::MaxAckDelay(v) => max_ack_delay = Some(*v),
                QlogMetric::PtoCount(v) => pto_count = Some(u64::try_from(*v).unwrap()),
                QlogMetric::CongestionWindow(v) => {
                    congestion_window = Some(u64::try_from(*v).unwrap())
                }
                QlogMetric::BytesInFlight(v) => bytes_in_flight = Some(u64::try_from(*v).unwrap()),
                QlogMetric::SsThresh(v) => ssthresh = Some(u64::try_from(*v).unwrap()),
                QlogMetric::PacketsInFlight(v) => packets_in_flight = Some(*v),
                QlogMetric::InRecovery(v) => in_recovery = Some(*v),
                QlogMetric::PacingRate(v) => pacing_rate = Some(*v),
            }
        }

        Some(Event::metrics_updated(
            min_rtt,
            smoothed_rtt,
            latest_rtt,
            rtt_variance,
            max_ack_delay,
            pto_count,
            congestion_window,
            bytes_in_flight,
            ssthresh,
            packets_in_flight,
            in_recovery,
            pacing_rate,
        ))
    })
}

// Helper functions

/// Convert a decoded `Frame` into the corresponding `QuicFrame` for logging.
///
/// Numeric fields are mostly rendered as decimal strings; byte fields are
/// rendered with `hex`.
fn frame_to_qlogframe(frame: &Frame) -> QuicFrame {
    match frame {
        Frame::Padding => QuicFrame::padding(),
        Frame::Ping => QuicFrame::ping(),
        Frame::Ack {
            largest_acknowledged,
            ack_delay,
            first_ack_range,
            ack_ranges,
        } => {
            // If the ranges cannot be decoded, the ACK is still logged, just without ranges.
            let ranges =
                Frame::decode_ack_frame(*largest_acknowledged, *first_ack_range, ack_ranges).ok();

            QuicFrame::ack(
                Some(ack_delay.to_string()),
                ranges.map(|all| {
                    all.into_iter()
                        .map(RangeInclusive::into_inner)
                        .collect::<Vec<_>>()
                }),
                None,
                None,
                None,
            )
        }
        Frame::ResetStream {
            stream_id,
            application_error_code,
            final_size,
        } => QuicFrame::reset_stream(
            stream_id.as_u64().to_string(),
            *application_error_code,
            final_size.to_string(),
        ),
        Frame::StopSending {
            stream_id,
            application_error_code,
        } => QuicFrame::stop_sending(stream_id.as_u64().to_string(), *application_error_code),
        // Only the offset and length are logged, not the crypto data itself.
        Frame::Crypto { offset, data } => {
            QuicFrame::crypto(offset.to_string(), data.len().to_string())
        }
        Frame::NewToken { token } => QuicFrame::new_token(token.len().to_string(), hex(&token)),
        // Only the offset and length are logged, not the stream data itself.
        Frame::Stream {
            fin,
            stream_id,
            offset,
            data,
            ..
        } => QuicFrame::stream(
            stream_id.as_u64().to_string(),
            offset.to_string(),
            data.len().to_string(),
            *fin,
            None,
        ),
        Frame::MaxData { maximum_data } => QuicFrame::max_data(maximum_data.to_string()),
        Frame::MaxStreamData {
            stream_id,
            maximum_stream_data,
        } => QuicFrame::max_stream_data(
            stream_id.as_u64().to_string(),
            maximum_stream_data.to_string(),
        ),
        Frame::MaxStreams {
            stream_type,
            maximum_streams,
        } => QuicFrame::max_streams(
            match stream_type {
                frame::StreamType::BiDi => qlog::StreamType::Bidirectional,
                frame::StreamType::UniDi => qlog::StreamType::Unidirectional,
            },
            maximum_streams.as_u64().to_string(),
        ),
        Frame::DataBlocked { data_limit } => QuicFrame::data_blocked(data_limit.to_string()),
        Frame::StreamDataBlocked {
            stream_id,
            stream_data_limit,
        } => QuicFrame::stream_data_blocked(
            stream_id.as_u64().to_string(),
            stream_data_limit.to_string(),
        ),
        Frame::StreamsBlocked {
            stream_type,
            stream_limit,
        } => QuicFrame::streams_blocked(
            match stream_type {
                frame::StreamType::BiDi => qlog::StreamType::Bidirectional,
                frame::StreamType::UniDi => qlog::StreamType::Unidirectional,
            },
            stream_limit.as_u64().to_string(),
        ),
        Frame::NewConnectionId {
            sequence_number,
            retire_prior,
            connection_id,
            stateless_reset_token,
        } => QuicFrame::new_connection_id(
            sequence_number.to_string(),
            retire_prior.to_string(),
            connection_id.len() as u64,
            hex(&connection_id),
            hex(stateless_reset_token),
        ),
        Frame::RetireConnectionId { sequence_number } => {
            QuicFrame::retire_connection_id(sequence_number.to_string())
        }
        Frame::PathChallenge { data } => QuicFrame::path_challenge(Some(hex(data))),
        Frame::PathResponse { data } => QuicFrame::path_response(Some(hex(data))),
        Frame::ConnectionClose {
            error_code,
            frame_type,
            reason_phrase,
        } => QuicFrame::connection_close(
            match error_code {
                frame::CloseError::Transport(_) => qlog::ErrorSpace::TransportError,
                frame::CloseError::Application(_) => qlog::ErrorSpace::ApplicationError,
            },
            error_code.code(),
            0,
            // The reason phrase may not be valid UTF-8; invalid sequences are replaced.
            String::from_utf8_lossy(&reason_phrase).to_string(),
            Some(frame_type.to_string()),
        ),
        Frame::HandshakeDone => QuicFrame::handshake_done(),
    }
}

/// Map a `PacketType` onto the qlog packet type.
/// `Short` is logged as `OneRtt` and `OtherVersion` as `Unknown`.
fn to_qlog_pkt_type(ptype: PacketType) -> qlog::PacketType {
    match ptype {
        PacketType::Initial => qlog::PacketType::Initial,
        PacketType::Handshake => qlog::PacketType::Handshake,
        PacketType::ZeroRtt => qlog::PacketType::ZeroRtt,
        PacketType::Short => qlog::PacketType::OneRtt,
        PacketType::Retry => qlog::PacketType::Retry,
        PacketType::VersionNegotiation => qlog::PacketType::VersionNegotiation,
        PacketType::OtherVersion => qlog::PacketType::Unknown,
    }
}
// diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs
// new file mode 100644
// index 0000000000..bac07dc988
// --- /dev/null
// +++ b/third_party/rust/neqo-transport/src/recovery.rs
// @@ -0,0 +1,1470 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// Tracking of sent packets and detecting their loss.
#![deny(clippy::pedantic)]

use std::cmp::{max, min};
use std::collections::BTreeMap;
use std::mem;
use std::ops::RangeInclusive;
use std::time::{Duration, Instant};

use smallvec::{smallvec, SmallVec};

use neqo_common::{qdebug, qinfo, qlog::NeqoQlog, qtrace};

use crate::cc::CongestionControlAlgorithm;
use crate::connection::LOCAL_IDLE_TIMEOUT;
use crate::crypto::CryptoRecoveryToken;
use crate::flow_mgr::FlowControlRecoveryToken;
use crate::packet::PacketNumber;
use crate::qlog::{self, QlogMetric};
use crate::send_stream::StreamRecoveryToken;
use crate::stats::{Stats, StatsCell};
use crate::tracking::{AckToken, PNSpace, PNSpaceSet, SentPacket};
use crate::PacketSender;

/// The lower bound applied to the `4 * rttvar` term when computing the PTO.
pub const GRANULARITY: Duration = Duration::from_millis(20);
/// The default value for the maximum time a peer can delay acknowledgment
/// of an ack-eliciting packet.
pub const MAX_ACK_DELAY: Duration = Duration::from_millis(25);
// The RTT assumed before any sample is taken.
// Defined in -recovery 6.2 as 333ms but using lower value.
const INITIAL_RTT: Duration = Duration::from_millis(100);
/// A packet is declared lost once the largest acknowledged packet number is
/// at least this many packets ahead of it.
pub(crate) const PACKET_THRESHOLD: u64 = 3;
/// `ACK_ONLY_SIZE_LIMIT` is the minimum size of the congestion window.
/// If the congestion window is this small, we will only send ACK frames.
pub(crate) const ACK_ONLY_SIZE_LIMIT: usize = 256;
/// The number of packets we send on a PTO.
/// And the number to declare lost when the PTO timer is hit.
pub const PTO_PACKET_COUNT: usize = 2;

/// The kinds of item recorded against a sent packet for recovery purposes.
// NOTE(review): presumably each variant identifies frame content that may need
// resending or acknowledging -- confirm against the users of this type.
#[derive(Debug, Clone)]
#[allow(clippy::module_name_repetitions)]
pub enum RecoveryToken {
    Ack(AckToken),
    Stream(StreamRecoveryToken),
    Crypto(CryptoRecoveryToken),
    Flow(FlowControlRecoveryToken),
    HandshakeDone,
    NewToken(usize),
}

/// Round-trip time estimator state.
#[derive(Debug)]
struct RttVals {
    /// When the first real RTT sample was taken; `None` until then.
    first_sample_time: Option<Instant>,
    /// The most recent RTT sample, after any ack-delay adjustment.
    latest_rtt: Duration,
    /// Exponentially weighted moving average of the RTT.
    smoothed_rtt: Duration,
    /// Exponentially weighted moving average of the RTT deviation.
    rttvar: Duration,
    /// The smallest raw RTT sample seen (ack delay is not subtracted).
    min_rtt: Duration,
    /// The peer's maximum acknowledgment delay.
    max_ack_delay: Duration,
}

impl RttVals {
    /// Seed every RTT value from `rtt`; `rttvar` is set to half of it.
    pub fn set_initial_rtt(&mut self, rtt: Duration) {
        // Only allow this when there are no samples.
        debug_assert!(self.first_sample_time.is_none());
        self.latest_rtt = rtt;
        self.min_rtt = rtt;
        self.smoothed_rtt = rtt;
        self.rttvar = rtt / 2;
    }

    /// Record the maximum acknowledgment delay for the peer.
    pub fn set_peer_max_ack_delay(&mut self, mad: Duration) {
        self.max_ack_delay = mad;
    }

    /// Fold a new RTT sample into the estimate and report the result to qlog.
    ///
    /// The first sample replaces the initial values outright; later samples
    /// update `rttvar` and `smoothed_rtt` with weights of 1/4 and 1/8.
    fn update_rtt(
        &mut self,
        mut qlog: &mut NeqoQlog,
        mut rtt_sample: Duration,
        ack_delay: Duration,
        now: Instant,
    ) {
        // min_rtt ignores ack delay.
        self.min_rtt = min(self.min_rtt, rtt_sample);
        // Note: the caller adjusts `ack_delay` based on `max_ack_delay`.
        // Adjust for ack delay unless it goes below `min_rtt`.
        // (The subtraction cannot underflow: `min_rtt <= rtt_sample` after the line above.)
        if rtt_sample - self.min_rtt >= ack_delay {
            rtt_sample -= ack_delay;
        }

        if self.first_sample_time.is_none() {
            self.set_initial_rtt(rtt_sample);
            self.first_sample_time = Some(now);
        } else {
            // Calculate EWMA RTT (based on {{?RFC6298}}).
            // `Duration` is unsigned, so take the absolute difference by hand.
            let rttvar_sample = if self.smoothed_rtt > rtt_sample {
                self.smoothed_rtt - rtt_sample
            } else {
                rtt_sample - self.smoothed_rtt
            };

            self.latest_rtt = rtt_sample;
            // `rttvar` is updated using the smoothed RTT from before this sample.
            self.rttvar = (self.rttvar * 3 + rttvar_sample) / 4;
            self.smoothed_rtt = (self.smoothed_rtt * 7 + rtt_sample) / 8;
        }
        qtrace!(
            "RTT latest={:?} -> estimate={:?}~{:?}",
            self.latest_rtt,
            self.smoothed_rtt,
            self.rttvar
        );
        qlog::metrics_updated(
            &mut qlog,
            &[
                QlogMetric::LatestRtt(self.latest_rtt),
                QlogMetric::MinRtt(self.min_rtt),
                QlogMetric::SmoothedRtt(self.smoothed_rtt),
            ],
        );
    }

    /// The current smoothed RTT estimate.
    pub fn rtt(&self) -> Duration {
        self.smoothed_rtt
    }

    /// The probe timeout: smoothed RTT plus `max(4 * rttvar, GRANULARITY)`,
    /// plus the peer's max ack delay for the application data space only.
    fn pto(&self, pn_space: PNSpace) -> Duration {
        self.rtt()
            + max(4 * self.rttvar, GRANULARITY)
            + if pn_space == PNSpace::ApplicationData {
                self.max_ack_delay
            } else {
                Duration::from_millis(0)
            }
    }

    /// When the first RTT sample was taken, if any has been.
    fn first_sample_time(&self) -> Option<Instant> {
        self.first_sample_time
    }
}

impl Default for RttVals {
    /// Start from `INITIAL_RTT` with no samples and the default `MAX_ACK_DELAY`.
    fn default() -> Self {
        Self {
            first_sample_time: None,
            latest_rtt: INITIAL_RTT,
            smoothed_rtt: INITIAL_RTT,
            rttvar: INITIAL_RTT / 2,
            min_rtt: INITIAL_RTT,
            max_ack_delay: MAX_ACK_DELAY,
        }
    }
}

/// `SendProfile` tells a sender how to send packets.
#[derive(Debug)]
pub struct SendProfile {
    /// The size limit for what can be sent.
    limit: usize,
    /// The packet number space whose PTO fired, if this profile is for a PTO.
    pto: Option<PNSpace>,
    /// The packet number spaces in which probing is helpful.
    probe: PNSpaceSet,
    /// Whether sending is currently held back by pacing.
    paced: bool,
}

impl SendProfile {
    /// A profile limited to `limit`, but never below `ACK_ONLY_SIZE_LIMIT - 1`.
    pub fn new_limited(limit: usize) -> Self {
        // When the limit is too low, we only send ACK frames.
        // Set the limit to `ACK_ONLY_SIZE_LIMIT - 1` to ensure that
        // ACK-only packets are still limited in size.
        Self {
            limit: max(ACK_ONLY_SIZE_LIMIT - 1, limit),
            pto: None,
            probe: PNSpaceSet::default(),
            paced: false,
        }
    }

    /// A profile for when pacing prevents sending; its limit makes it ACK-only.
    pub fn new_paced() -> Self {
        // When pacing, we still allow ACK frames to be sent.
        Self {
            limit: ACK_ONLY_SIZE_LIMIT - 1,
            pto: None,
            probe: PNSpaceSet::default(),
            paced: true,
        }
    }

    /// A profile for sending a PTO probe in `pn_space`, limited to `mtu`.
    /// `probe` must include `pn_space` (checked in debug builds).
    pub fn new_pto(pn_space: PNSpace, mtu: usize, probe: PNSpaceSet) -> Self {
        debug_assert!(mtu > ACK_ONLY_SIZE_LIMIT);
        debug_assert!(probe[pn_space]);
        Self {
            limit: mtu,
            pto: Some(pn_space),
            probe,
            paced: false,
        }
    }

    /// Whether probing this space is helpful. This isn't necessarily the space
    /// that caused the timer to pop, but it is helpful to send a PING in a space
    /// that has the PTO timer armed.
    pub fn should_probe(&self, space: PNSpace) -> bool {
        self.probe[space]
    }

    /// Determine whether an ACK-only packet should be sent for the given packet
    /// number space.
    /// Send only ACKs either: when the space available is too small, or when a PTO
    /// exists for a later packet number space (which should get the most space).
    pub fn ack_only(&self, space: PNSpace) -> bool {
        self.limit < ACK_ONLY_SIZE_LIMIT || self.pto.map_or(false, |sp| space < sp)
    }

    /// Whether this profile was produced because of pacing.
    pub fn paced(&self) -> bool {
        self.paced
    }

    /// The size limit for this profile.
    pub fn limit(&self) -> usize {
        self.limit
    }
}

/// Loss recovery state for a single packet number space.
#[derive(Debug)]
pub(crate) struct LossRecoverySpace {
    /// The packet number space this state belongs to.
    space: PNSpace,
    /// The largest packet number acknowledged so far, if any.
    largest_acked: Option<PacketNumber>,
    /// When the packet with the largest acknowledged number was sent.
    largest_acked_sent_time: Option<Instant>,
    /// The time used to calculate the PTO timer for this space.
    /// This is the time that the last ACK-eliciting packet in this space
    /// was sent. This might be the time that a probe was sent.
    pto_base_time: Option<Instant>,
    /// The number of outstanding packets in this space that are in flight.
    /// This might be less than the number of ACK-eliciting packets,
    /// because PTO packets don't count.
    in_flight_outstanding: u64,
    /// Tracked sent packets, keyed by packet number.
    sent_packets: BTreeMap<u64, SentPacket>,
    /// The time that the first out-of-order packet was sent.
    /// This is `None` if there were no out-of-order packets detected.
    /// When set to `Some(T)`, time-based loss detection should be enabled.
    first_ooo_time: Option<Instant>,
}

impl LossRecoverySpace {
    /// Create empty recovery state for `space`.
    pub fn new(space: PNSpace) -> Self {
        Self {
            space,
            largest_acked: None,
            largest_acked_sent_time: None,
            pto_base_time: None,
            in_flight_outstanding: 0,
            sent_packets: BTreeMap::default(),
            first_ooo_time: None,
        }
    }

    /// The packet number space this state belongs to.
    #[must_use]
    pub fn space(&self) -> PNSpace {
        self.space
    }

    /// Find the time we sent the first packet that is lower than the
    /// largest acknowledged and that isn't yet declared lost.
    /// Use the value we prepared earlier in `detect_lost_packets`.
    #[must_use]
    pub fn loss_recovery_timer_start(&self) -> Option<Instant> {
        self.first_ooo_time
    }

    /// Whether any in-flight packets are outstanding in this space.
    pub fn in_flight_outstanding(&self) -> bool {
        self.in_flight_outstanding > 0
    }

    /// Mark packets as having had a PTO fire, in ascending packet number order,
    /// yielding at most `count` packets for which `SentPacket::pto` returns true.
    ///
    /// The iterator is lazy: packets are only marked as it is consumed.
    pub fn pto_packets(&mut self, count: usize) -> impl Iterator<Item = &SentPacket> {
        self.sent_packets
            .iter_mut()
            .filter_map(|(pn, sent)| {
                if sent.pto() {
                    qtrace!("PTO: marking packet {} lost ", pn);
                    Some(&*sent)
                } else {
                    None
                }
            })
            .take(count)
    }

    /// The time from which the PTO timer for this space should be measured,
    /// or `None` if the timer should not be armed.
    pub fn pto_base_time(&self) -> Option<Instant> {
        if self.in_flight_outstanding() {
            debug_assert!(self.pto_base_time.is_some());
            self.pto_base_time
        } else if self.space == PNSpace::ApplicationData {
            None
        } else {
            // Nasty special case to prevent handshake deadlocks.
            // A client needs to keep the PTO timer armed to prevent a stall
            // of the handshake. Technically, this has to stop once we receive
            // an ACK of Handshake or 1-RTT, or when we receive HANDSHAKE_DONE,
            // but a few extra probes won't hurt.
            // A server shouldn't arm its PTO timer this way. The server sends
            // ack-eliciting, in-flight packets immediately so this only
            // happens when the server has nothing outstanding. If we had
            // client authentication, this might cause some extra probes,
            // but they would be harmless anyway.
            self.pto_base_time
        }
    }

    /// Start tracking a sent packet, updating the PTO baseline and the
    /// in-flight count when the packet is ack-eliciting.
    pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
        if sent_packet.ack_eliciting() {
            self.pto_base_time = Some(sent_packet.time_sent);
            self.in_flight_outstanding += 1;
        } else if self.space != PNSpace::ApplicationData && self.pto_base_time.is_none() {
            // For Initial and Handshake spaces, make sure that we have a PTO baseline
            // always. See `LossRecoverySpace::pto_base_time()` for details.
            self.pto_base_time = Some(sent_packet.time_sent);
        }
        self.sent_packets.insert(sent_packet.pn, sent_packet);
    }

    /// Update the in-flight accounting for a packet that is no longer tracked.
    /// This does not touch `sent_packets`; the caller removes the entry.
    fn remove_packet(&mut self, p: &SentPacket) {
        if p.ack_eliciting() {
            debug_assert!(self.in_flight_outstanding > 0);
            self.in_flight_outstanding -= 1;
            if self.in_flight_outstanding == 0 {
                qtrace!("remove_packet outstanding == 0 for space {}", self.space);

                // See above comments; keep PTO armed for Initial/Handshake even
                // if no outstanding packets.
                if self.space == PNSpace::ApplicationData {
                    self.pto_base_time = None;
                }
            }
        }
    }

    /// Remove all acknowledged packets.
    /// Returns all the acknowledged packets, with the largest packet number first.
    /// ...and a boolean indicating if any of those packets were ack-eliciting.
    /// This operates more efficiently because it assumes that the input is sorted
    /// in the order that an ACK frame is (from the top).
+ fn remove_acked<R>(&mut self, acked_ranges: R, stats: &mut Stats) -> (Vec<SentPacket>, bool) + where + R: IntoIterator<Item = RangeInclusive<u64>>, + R::IntoIter: ExactSizeIterator, + { + let acked_ranges = acked_ranges.into_iter(); + let mut keep = Vec::with_capacity(acked_ranges.len()); + + let mut acked = Vec::new(); + let mut eliciting = false; + for range in acked_ranges { + let first_keep = *range.end() + 1; + if let Some((&first, _)) = self.sent_packets.range(range).next() { + let mut tail = self.sent_packets.split_off(&first); + if let Some((&next, _)) = tail.range(first_keep..).next() { + keep.push(tail.split_off(&next)); + } + for (_, p) in tail.into_iter().rev() { + self.remove_packet(&p); + eliciting |= p.ack_eliciting(); + if p.lost() { + stats.late_ack += 1; + } + if p.pto_fired() { + stats.pto_ack += 1; + } + acked.push(p); + } + } + } + + for mut k in keep.into_iter().rev() { + self.sent_packets.append(&mut k); + } + + (acked, eliciting) + } + + /// Remove all tracked packets from the space. + /// This is called by a client when 0-RTT packets are dropped, when a Retry is received + /// and when keys are dropped. + fn remove_ignored(&mut self) -> impl Iterator<Item = SentPacket> { + self.in_flight_outstanding = 0; + mem::take(&mut self.sent_packets) + .into_iter() + .map(|(_, v)| v) + } + + /// Remove old packets that we've been tracking in case they get acknowledged. + /// We try to keep these around until a probe is sent for them, so it is + /// important that `cd` is set to at least the current PTO time; otherwise we + /// might remove all in-flight packets and stop sending probes. + #[allow(clippy::option_if_let_else, clippy::unknown_clippy_lints)] // Hard enough to read as-is. + fn remove_old_lost(&mut self, now: Instant, cd: Duration) { + let mut it = self.sent_packets.iter(); + // If the first item is not expired, do nothing. + if it.next().map_or(false, |(_, p)| p.expired(now, cd)) { + // Find the index of the first unexpired packet. 
+ let to_remove = if let Some(first_keep) = + it.find_map(|(i, p)| if p.expired(now, cd) { None } else { Some(*i) }) + { + // Some packets haven't expired, so keep those. + let keep = self.sent_packets.split_off(&first_keep); + mem::replace(&mut self.sent_packets, keep) + } else { + // All packets are expired. + mem::take(&mut self.sent_packets) + }; + for (_, p) in to_remove { + self.remove_packet(&p); + } + } + } + + /// Detect lost packets. + /// `loss_delay` is the time we will wait before declaring something lost. + /// `cleanup_delay` is the time we will wait before cleaning up a lost packet. + pub fn detect_lost_packets( + &mut self, + now: Instant, + loss_delay: Duration, + cleanup_delay: Duration, + lost_packets: &mut Vec<SentPacket>, + ) { + // Housekeeping. + self.remove_old_lost(now, cleanup_delay); + + // Packets sent before this time are deemed lost. + let lost_deadline = now - loss_delay; + qtrace!( + "detect lost {}: now={:?} delay={:?} deadline={:?}", + self.space, + now, + loss_delay, + lost_deadline + ); + self.first_ooo_time = None; + + let largest_acked = self.largest_acked; + + // Lost for retrans/CC purposes + let mut lost_pns = SmallVec::<[_; 8]>::new(); + + for (pn, packet) in self + .sent_packets + .iter_mut() + // BTreeMap iterates in order of ascending PN + .take_while(|(&k, _)| Some(k) < largest_acked) + { + if packet.time_sent <= lost_deadline { + qtrace!( + "lost={}, time sent {:?} is before lost_deadline {:?}", + pn, + packet.time_sent, + lost_deadline + ); + } else if largest_acked >= Some(*pn + PACKET_THRESHOLD) { + qtrace!( + "lost={}, is >= {} from largest acked {:?}", + pn, + PACKET_THRESHOLD, + largest_acked + ); + } else { + self.first_ooo_time = Some(packet.time_sent); + // No more packets can be declared lost after this one. 
+ break; + }; + + if packet.declare_lost(now) { + lost_pns.push(*pn); + } + } + + lost_packets.extend(lost_pns.iter().map(|pn| self.sent_packets[pn].clone())); + } +} + +#[derive(Debug)] +pub(crate) struct LossRecoverySpaces { + /// When we have all of the loss recovery spaces, this will use a separate + /// allocation, but this is reduced once the handshake is done. + spaces: SmallVec<[LossRecoverySpace; 1]>, +} + +impl LossRecoverySpaces { + fn idx(space: PNSpace) -> usize { + match space { + PNSpace::ApplicationData => 0, + PNSpace::Handshake => 1, + PNSpace::Initial => 2, + } + } + + /// Drop a packet number space and return all the packets that were + /// outstanding, so that those can be marked as lost. + /// # Panics + /// If the space has already been removed. + pub fn drop_space(&mut self, space: PNSpace) -> Vec<SentPacket> { + let sp = match space { + PNSpace::Initial => self.spaces.pop(), + PNSpace::Handshake => { + let sp = self.spaces.pop(); + self.spaces.shrink_to_fit(); + sp + } + PNSpace::ApplicationData => panic!("discarding application space"), + }; + let mut sp = sp.unwrap(); + assert_eq!(sp.space(), space, "dropping spaces out of order"); + sp.remove_ignored().collect() + } + + pub fn get(&self, space: PNSpace) -> Option<&LossRecoverySpace> { + self.spaces.get(Self::idx(space)) + } + + pub fn get_mut(&mut self, space: PNSpace) -> Option<&mut LossRecoverySpace> { + self.spaces.get_mut(Self::idx(space)) + } + + fn iter(&self) -> impl Iterator<Item = &LossRecoverySpace> { + self.spaces.iter() + } + + fn iter_mut(&mut self) -> impl Iterator<Item = &mut LossRecoverySpace> { + self.spaces.iter_mut() + } +} + +impl Default for LossRecoverySpaces { + fn default() -> Self { + Self { + spaces: smallvec![ + LossRecoverySpace::new(PNSpace::ApplicationData), + LossRecoverySpace::new(PNSpace::Handshake), + LossRecoverySpace::new(PNSpace::Initial), + ], + } + } +} + +#[derive(Debug)] +struct PtoState { + /// The packet number space that caused the PTO to fire. 
space: PNSpace,
/// The number of probes that we have sent.
count: usize,
/// How many more probe packets `send_profile` may hand out for the current PTO.
packets: usize,
/// The complete set of packet number spaces that can have probes sent.
probe: PNSpaceSet,
}

impl PtoState {
    /// Create the state for a first PTO in `space`; `probe` must include `space`.
    pub fn new(space: PNSpace, probe: PNSpaceSet) -> Self {
        debug_assert!(probe[space]);
        Self {
            space,
            count: 1,
            packets: PTO_PACKET_COUNT,
            probe,
        }
    }

    /// Record another PTO: bump the count and refill the probe packet allowance.
    pub fn pto(&mut self, space: PNSpace, probe: PNSpaceSet) {
        debug_assert!(probe[space]);
        self.space = space;
        self.count += 1;
        self.packets = PTO_PACKET_COUNT;
        self.probe = probe;
    }

    /// The number of PTOs recorded so far.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Report the current PTO count to `stats`.
    pub fn count_pto(&self, stats: &mut Stats) {
        stats.add_pto_count(self.count);
    }

    /// Generate a sending profile, indicating what space it should be from.
    /// This takes a packet from the supply or returns an ack-only profile if it can't.
    pub fn send_profile(&mut self, mtu: usize) -> SendProfile {
        if self.packets > 0 {
            self.packets -= 1;
            SendProfile::new_pto(self.space, mtu, self.probe)
        } else {
            SendProfile::new_limited(0)
        }
    }
}

#[derive(Debug)]
pub(crate) struct LossRecovery {
    /// When the handshake was confirmed, if it has been.
    confirmed_time: Option<Instant>,
    pto_state: Option<PtoState>,
    rtt_vals: RttVals,
    packet_sender: PacketSender,
    spaces: LossRecoverySpaces,
    qlog: NeqoQlog,
    stats: StatsCell,
}

impl LossRecovery {
    /// Create loss recovery state using congestion control algorithm `alg`,
    /// recording statistics into `stats`.
    pub fn new(alg: CongestionControlAlgorithm, stats: StatsCell) -> Self {
        Self {
            confirmed_time: None,
            pto_state: None,
            rtt_vals: RttVals::default(),
            packet_sender: PacketSender::new(alg),
            spaces: LossRecoverySpaces::default(),
            qlog: NeqoQlog::default(),
            stats,
        }
    }

    #[cfg(test)]
    pub fn cwnd(&self) -> usize {
        self.packet_sender.cwnd()
    }

    /// The current RTT estimate, as reported by the RTT state.
    pub fn rtt(&self) -> Duration {
        self.rtt_vals.rtt()
    }

    pub fn set_initial_rtt(&mut self, rtt: Duration) {
        self.rtt_vals.set_initial_rtt(rtt)
    }

    pub fn set_peer_max_ack_delay(&mut self, mad: Duration) {
        self.rtt_vals.set_peer_max_ack_delay(mad);
    }

    /// Forwards to the packet sender's `cwnd_avail`.
    pub fn cwnd_avail(&self) -> usize {
        self.packet_sender.cwnd_avail()
    }

    /// The largest acknowledged packet number in `pn_space`, or `None` if the
    /// space is gone or nothing has been acknowledged in it.
    pub fn largest_acknowledged_pn(&self, pn_space: PNSpace) -> Option<PacketNumber> {
        self.spaces.get(pn_space).and_then(|sp| sp.largest_acked)
    }

    /// Install `qlog` both here and in the packet sender.
    pub fn set_qlog(&mut self, qlog: NeqoQlog) {
        self.packet_sender.set_qlog(qlog.clone());
        self.qlog = qlog;
    }

    /// Remove the ignored packets from the application data space, telling the
    /// packet sender to discard each of them, and return them.
    ///
    /// # Panics
    /// If the application data space is gone or has already seen an acknowledgment.
    pub fn drop_0rtt(&mut self) -> Vec<SentPacket> {
        // The largest acknowledged or loss_time should still be unset.
        // The client should not have received any ACK frames when it drops 0-RTT.
        assert!(self
            .spaces
            .get(PNSpace::ApplicationData)
            .unwrap()
            .largest_acked
            .is_none());
        self.spaces
            .get_mut(PNSpace::ApplicationData)
            .unwrap()
            .remove_ignored()
            .inspect(|p| self.packet_sender.discard(&p))
            .collect()
    }

    /// Track a newly sent packet. Packets for a space that has already been
    /// dropped are logged and otherwise ignored.
    pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
        let pn_space = PNSpace::from(sent_packet.pt);
        qdebug!([self], "packet {}-{} sent", pn_space, sent_packet.pn);
        let rtt = self.rtt();
        if let Some(space) = self.spaces.get_mut(pn_space) {
            self.packet_sender.on_packet_sent(&sent_packet, rtt);
            space.on_packet_sent(sent_packet);
        } else {
            qinfo!(
                [self],
                "ignoring {}-{} from dropped space",
                pn_space,
                sent_packet.pn
            );
        }
    }

    /// Record an RTT sample for a packet sent at `send_time` and acknowledged at `now`.
    ///
    /// The peer-reported `ack_delay` is capped at the peer's `max_ack_delay`, but only
    /// for packets sent after the handshake was confirmed: RFC 9002, Section 5.3 says
    /// to ignore the peer's `max_ack_delay` until the handshake is confirmed and to use
    /// the lesser of the two values afterwards.
    fn rtt_sample(&mut self, send_time: Instant, now: Instant, ack_delay: Duration) {
        // There is no meaningful sample if the acknowledgment claims to predate
        // the packet; skip it rather than panic or record a zero RTT.
        let sample = match now.checked_duration_since(send_time) {
            Some(sample) => sample,
            None => return,
        };

        // Limit ack delay by max_ack_delay if the packet was sent after confirmation.
        let sent_after_confirmation = self
            .confirmed_time
            .map_or(false, |confirmed| confirmed < send_time);
        let delay = if sent_after_confirmation {
            min(ack_delay, self.rtt_vals.max_ack_delay)
        } else {
            ack_delay
        };

        self.rtt_vals.update_rtt(&mut self.qlog, sample, delay, now);
    }

    /// Process an ACK frame for `pn_space`.
    ///
    /// Returns (acked packets, lost packets).  Both are empty when the frame
    /// acknowledges nothing new.
    ///
    /// # Panics
    /// If `pn_space` has already been discarded.
    pub fn on_ack_received(
        &mut self,
        pn_space: PNSpace,
        largest_acked: u64,
        acked_ranges: Vec<RangeInclusive<u64>>,
        ack_delay: Duration,
        now: Instant,
    ) -> (Vec<SentPacket>, Vec<SentPacket>) {
        qdebug!(
            [self],
            "ACK for {} - largest_acked={}.",
            pn_space,
            largest_acked
        );

        let space = self
            .spaces
            .get_mut(pn_space)
            .expect("ACK on discarded space");
        let (acked_packets, any_ack_eliciting) =
            space.remove_acked(acked_ranges, &mut *self.stats.borrow_mut());
        if acked_packets.is_empty() {
            // No new information.
            return (Vec::new(), Vec::new());
        }

        // Track largest PN acked per space
        let prev_largest_acked = space.largest_acked_sent_time;
        if Some(largest_acked) > space.largest_acked {
            space.largest_acked = Some(largest_acked);

            // If the largest acknowledged is newly acked and any newly acked
            // packet was ack-eliciting, update the RTT. (-recovery 5.1)
            let largest_acked_pkt = acked_packets.first().expect("must be there");
            space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent);
            if any_ack_eliciting {
                self.rtt_sample(largest_acked_pkt.time_sent, now, ack_delay);
            }
        }

        // Perform loss detection.
        // PTO is used to remove lost packets from in-flight accounting.
        // We need to ensure that we have sent any PTO probes before they are removed
        // as we rely on the count of in-flight packets to determine whether to send
        // another probe.  Removing them too soon would result in not sending on PTO.
        let loss_delay = self.loss_delay();
        let cleanup = self.pto_period(pn_space);
        let mut lost = Vec::new();
        self.spaces
            .get_mut(pn_space)
            .unwrap()
            .detect_lost_packets(now, loss_delay, cleanup, &mut lost);
        self.stats.borrow_mut().lost += lost.len();

        // Tell the congestion controller about any lost packets.
        // The PTO for congestion control is the raw number, without exponential
        // backoff, so that we can determine persistent congestion.
        let pto_raw = self.pto_raw(pn_space);
        let first_rtt_sample = self.rtt_vals.first_sample_time();
        self.packet_sender
            .on_packets_lost(first_rtt_sample, prev_largest_acked, pto_raw, &lost);

        // This must happen after on_packets_lost. If in recovery, this could
        // take us out, and then lost packets will start a new recovery period
        // when it shouldn't.
        self.packet_sender.on_packets_acked(&acked_packets);

        // New acknowledgments reset the PTO state.
        self.pto_state = None;

        (acked_packets, lost)
    }

    /// The time after which an unacknowledged packet is declared lost.
    fn loss_delay(&self) -> Duration {
        // kTimeThreshold = 9/8
        // loss_delay = kTimeThreshold * max(latest_rtt, smoothed_rtt)
        // loss_delay = max(loss_delay, kGranularity)
        let rtt = max(self.rtt_vals.latest_rtt, self.rtt_vals.smoothed_rtt);
        max(rtt * 9 / 8, GRANULARITY)
    }

    /// When receiving a retry, get all the sent packets so that they can be flushed.
    /// We also need to pretend that they never happened for the purposes of congestion control.
    pub fn retry(&mut self) -> Vec<SentPacket> {
        self.pto_state = None;
        // Borrow the sender separately so the closure does not need all of `self`
        // while `self.spaces` is mutably borrowed.
        let packet_sender = &mut self.packet_sender;
        self.spaces
            .iter_mut()
            .flat_map(LossRecoverySpace::remove_ignored)
            .inspect(|p| packet_sender.discard(&p))
            .collect()
    }

    /// Record that the handshake was confirmed at `now`.
    fn confirmed(&mut self, now: Instant) {
        debug_assert!(self.confirmed_time.is_none());
        self.confirmed_time = Some(now);
        // Up until now, the ApplicationData space has been ignored for PTO.
        // So maybe fire a PTO.
        if let Some(pto) = self.pto_time(PNSpace::ApplicationData) {
            if pto < now {
                let probes = PNSpaceSet::from(&[PNSpace::ApplicationData]);
                self.fire_pto(PNSpace::ApplicationData, probes);
            }
        }
    }

    /// Discard state for a given packet number space.
    /// Discarding the Handshake space also marks the handshake as confirmed.
    pub fn discard(&mut self, space: PNSpace, now: Instant) {
        qdebug!([self], "Reset loss recovery state for {}", space);
        for p in self.spaces.drop_space(space) {
            self.packet_sender.discard(&p);
        }

        // We just made progress, so discard PTO count.
        // The spec says that clients should not do this until confirming that
        // the server has completed address validation, but ignore that.
        self.pto_state = None;

        if space == PNSpace::Handshake {
            self.confirmed(now);
        }
    }

    /// Calculate when the next timeout is likely to be.  This is the earlier of the loss timer
    /// and the PTO timer; either or both might be disabled, so this can return `None`.
+ pub fn next_timeout(&mut self) -> Option<Instant> { + let loss_time = self.earliest_loss_time(); + let pto_time = self.earliest_pto(); + qtrace!( + [self], + "next_timeout loss={:?} pto={:?}", + loss_time, + pto_time + ); + match (loss_time, pto_time) { + (Some(loss_time), Some(pto_time)) => Some(min(loss_time, pto_time)), + (Some(loss_time), None) => Some(loss_time), + (None, Some(pto_time)) => Some(pto_time), + _ => None, + } + } + + /// Find when the earliest sent packet should be considered lost. + fn earliest_loss_time(&self) -> Option<Instant> { + self.spaces + .iter() + .filter_map(LossRecoverySpace::loss_recovery_timer_start) + .min() + .map(|val| val + self.loss_delay()) + } + + // The borrow checker is a harsh mistress. + // It's important that calls to `RttVals::pto()` are routed through a central point + // because that ensures consistency, but we often have a mutable borrow on other + // pieces of `self` that prevents that. + // An associated function avoids another borrow on `&self`. + fn pto_raw_inner(rtt_vals: &RttVals, space: PNSpace) -> Duration { + rtt_vals.pto(space) + } + + // Borrow checker hack, see above. + fn pto_period_inner( + rtt_vals: &RttVals, + pto_state: &Option<PtoState>, + pn_space: PNSpace, + ) -> Duration { + Self::pto_raw_inner(rtt_vals, pn_space) + .checked_mul(1 << pto_state.as_ref().map_or(0, |p| p.count)) + .unwrap_or(LOCAL_IDLE_TIMEOUT * 2) + } + + /// Get the Base PTO value, which is derived only from the `RTT` and `RTTvar` values. + /// This is for those cases where you need a value for the time you might sensibly + /// wait for a packet to propagate. Using `3*pto_raw(..)` is common. + pub fn pto_raw(&self, space: PNSpace) -> Duration { + Self::pto_raw_inner(&self.rtt_vals, space) + } + + /// Get the current PTO period for the given packet number space. + /// Unlike `pto_raw`, this includes calculation for the exponential backoff. 
+ fn pto_period(&self, pn_space: PNSpace) -> Duration { + Self::pto_period_inner(&self.rtt_vals, &self.pto_state, pn_space) + } + + // Calculate PTO time for the given space. + fn pto_time(&self, pn_space: PNSpace) -> Option<Instant> { + if self.confirmed_time.is_none() && pn_space == PNSpace::ApplicationData { + None + } else { + self.spaces + .get(pn_space) + .and_then(|space| space.pto_base_time().map(|t| t + self.pto_period(pn_space))) + } + } + + /// Find the earliest PTO time for all active packet number spaces. + /// Ignore Application if either Initial or Handshake have an active PTO. + fn earliest_pto(&self) -> Option<Instant> { + if self.confirmed_time.is_some() { + self.pto_time(PNSpace::ApplicationData) + } else { + self.pto_time(PNSpace::Initial) + .iter() + .chain(self.pto_time(PNSpace::Handshake).iter()) + .min() + .cloned() + } + } + + fn fire_pto(&mut self, pn_space: PNSpace, allow_probes: PNSpaceSet) { + if let Some(st) = &mut self.pto_state { + st.pto(pn_space, allow_probes); + } else { + self.pto_state = Some(PtoState::new(pn_space, allow_probes)); + } + + self.pto_state + .as_mut() + .unwrap() + .count_pto(&mut *self.stats.borrow_mut()); + + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::PtoCount( + self.pto_state.as_ref().unwrap().count(), + )], + ); + } + + /// This checks whether the PTO timer has fired and fires it if needed. + /// When it has, mark a few packets as "lost" for the purposes of having frames + /// regenerated in subsequent packets. The packets aren't truly lost, so + /// we have to clone the `SentPacket` instance. + fn maybe_fire_pto(&mut self, now: Instant, lost: &mut Vec<SentPacket>) { + let mut pto_space = None; + // The spaces in which we will allow probing. 
+ let mut allow_probes = PNSpaceSet::default(); + for pn_space in PNSpace::iter() { + if let Some(t) = self.pto_time(*pn_space) { + allow_probes[*pn_space] = true; + if t <= now { + qdebug!([self], "PTO timer fired for {}", pn_space); + let space = self.spaces.get_mut(*pn_space).unwrap(); + lost.extend(space.pto_packets(PTO_PACKET_COUNT).cloned()); + + pto_space = pto_space.or(Some(*pn_space)); + } + } + } + + // This has to happen outside the loop. Increasing the PTO count here causes the + // pto_time to increase which might cause PTO for later packet number spaces to not fire. + if let Some(pn_space) = pto_space { + qtrace!([self], "PTO {}, probing {:?}", pn_space, allow_probes); + self.fire_pto(pn_space, allow_probes); + } + } + + pub fn timeout(&mut self, now: Instant) -> Vec<SentPacket> { + qtrace!([self], "timeout {:?}", now); + + let loss_delay = self.loss_delay(); + let first_rtt_sample = self.rtt_vals.first_sample_time(); + + let mut lost_packets = Vec::new(); + for space in self.spaces.iter_mut() { + let first = lost_packets.len(); // The first packet lost in this space. + let pto = Self::pto_period_inner(&self.rtt_vals, &self.pto_state, space.space()); + space.detect_lost_packets(now, loss_delay, pto, &mut lost_packets); + self.packet_sender.on_packets_lost( + first_rtt_sample, + space.largest_acked_sent_time, + Self::pto_raw_inner(&self.rtt_vals, space.space()), + &lost_packets[first..], + ); + } + self.stats.borrow_mut().lost += lost_packets.len(); + + self.maybe_fire_pto(now, &mut lost_packets); + lost_packets + } + + /// Start the packet pacer. + pub fn start_pacer(&mut self, now: Instant) { + self.packet_sender.start_pacer(now); + } + + /// Get the next time that a paced packet might be sent. + pub fn next_paced(&self) -> Option<Instant> { + self.packet_sender.next_paced(self.rtt()) + } + + /// Check how packets should be sent, based on whether there is a PTO, + /// what the current congestion window is, and what the pacer says. 
+ #[allow(clippy::option_if_let_else, clippy::unknown_clippy_lints)] + pub fn send_profile(&mut self, now: Instant, mtu: usize) -> SendProfile { + qdebug!([self], "get send profile {:?}", now); + if let Some(pto) = self.pto_state.as_mut() { + pto.send_profile(mtu) + } else { + let cwnd = self.cwnd_avail(); + if cwnd > mtu { + // More than an MTU available; we might need to pace. + if self.next_paced().map_or(false, |t| t > now) { + SendProfile::new_paced() + } else { + SendProfile::new_limited(mtu) + } + } else if self.packet_sender.recovery_packet() { + // After entering recovery, allow a packet to be sent immediately. + // This uses the PTO machinery, probing in all spaces. This will + // result in a PING being sent in every active space. + SendProfile::new_pto(PNSpace::Initial, mtu, PNSpaceSet::all()) + } else { + SendProfile::new_limited(cwnd) + } + } + } +} + +impl ::std::fmt::Display for LossRecovery { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "LossRecovery") + } +} + +#[cfg(test)] +mod tests { + use super::{ + CongestionControlAlgorithm, LossRecovery, LossRecoverySpace, PNSpace, SentPacket, + INITIAL_RTT, MAX_ACK_DELAY, + }; + use crate::packet::PacketType; + use crate::stats::{Stats, StatsCell}; + use std::convert::TryInto; + use std::time::{Duration, Instant}; + use test_fixture::now; + + const ON_SENT_SIZE: usize = 100; + + fn assert_rtts( + lr: &LossRecovery, + latest_rtt: Duration, + smoothed_rtt: Duration, + rttvar: Duration, + min_rtt: Duration, + ) { + println!( + "rtts: {:?} {:?} {:?} {:?}", + lr.rtt_vals.latest_rtt, + lr.rtt_vals.smoothed_rtt, + lr.rtt_vals.rttvar, + lr.rtt_vals.min_rtt, + ); + assert_eq!(lr.rtt_vals.latest_rtt, latest_rtt, "latest RTT"); + assert_eq!(lr.rtt_vals.smoothed_rtt, smoothed_rtt, "smoothed RTT"); + assert_eq!(lr.rtt_vals.rttvar, rttvar, "RTT variance"); + assert_eq!(lr.rtt_vals.min_rtt, min_rtt, "min RTT"); + } + + fn assert_sent_times( + lr: &LossRecovery, + initial: 
Option<Instant>, + handshake: Option<Instant>, + app_data: Option<Instant>, + ) { + let est = |sp| { + lr.spaces + .get(sp) + .and_then(LossRecoverySpace::loss_recovery_timer_start) + }; + println!( + "loss times: {:?} {:?} {:?}", + est(PNSpace::Initial), + est(PNSpace::Handshake), + est(PNSpace::ApplicationData), + ); + assert_eq!(est(PNSpace::Initial), initial, "Initial earliest sent time"); + assert_eq!( + est(PNSpace::Handshake), + handshake, + "Handshake earliest sent time" + ); + assert_eq!( + est(PNSpace::ApplicationData), + app_data, + "AppData earliest sent time" + ); + } + + fn assert_no_sent_times(lr: &LossRecovery) { + assert_sent_times(lr, None, None, None); + } + + // Time in milliseconds. + macro_rules! ms { + ($t:expr) => { + Duration::from_millis($t) + }; + } + + // In most of the tests below, packets are sent at a fixed cadence, with PACING between each. + const PACING: Duration = ms!(7); + fn pn_time(pn: u64) -> Instant { + now() + (PACING * pn.try_into().unwrap()) + } + + fn pace(lr: &mut LossRecovery, count: u64) { + for pn in 0..count { + lr.on_packet_sent(SentPacket::new( + PacketType::Short, + pn, + pn_time(pn), + true, + Vec::new(), + ON_SENT_SIZE, + )); + } + } + + const ACK_DELAY: Duration = ms!(24); + /// Acknowledge PN with the identified delay. 
+ fn ack(lr: &mut LossRecovery, pn: u64, delay: Duration) { + lr.on_ack_received( + PNSpace::ApplicationData, + pn, + vec![pn..=pn], + ACK_DELAY, + pn_time(pn) + delay, + ); + } + + fn add_sent(lrs: &mut LossRecoverySpace, packet_numbers: &[u64]) { + for &pn in packet_numbers { + lrs.on_packet_sent(SentPacket::new( + PacketType::Short, + pn, + pn_time(pn), + true, + Vec::new(), + ON_SENT_SIZE, + )); + } + } + + fn match_acked(acked: &[SentPacket], expected: &[u64]) { + assert!(acked.iter().map(|p| &p.pn).eq(expected)); + } + + #[test] + fn remove_acked() { + let mut lrs = LossRecoverySpace::new(PNSpace::ApplicationData); + let mut stats = Stats::default(); + add_sent(&mut lrs, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let (acked, _) = lrs.remove_acked(vec![], &mut stats); + assert!(acked.is_empty()); + let (acked, _) = lrs.remove_acked(vec![7..=8, 2..=4], &mut stats); + match_acked(&acked, &[8, 7, 4, 3, 2]); + let (acked, _) = lrs.remove_acked(vec![8..=11], &mut stats); + match_acked(&acked, &[10, 9]); + let (acked, _) = lrs.remove_acked(vec![0..=2], &mut stats); + match_acked(&acked, &[1]); + let (acked, _) = lrs.remove_acked(vec![5..=6], &mut stats); + match_acked(&acked, &[6, 5]); + } + + #[test] + fn initial_rtt() { + let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default()); + lr.start_pacer(now()); + pace(&mut lr, 1); + let rtt = ms!(100); + ack(&mut lr, 0, rtt); + assert_rtts(&lr, rtt, rtt, rtt / 2, rtt); + assert_no_sent_times(&lr); + } + + /// An initial RTT for using with `setup_lr`. + const TEST_RTT: Duration = ms!(80); + const TEST_RTTVAR: Duration = ms!(40); + + /// Send `n` packets (using PACING), then acknowledge the first. 
+ fn setup_lr(n: u64) -> LossRecovery { + let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default()); + lr.start_pacer(now()); + pace(&mut lr, n); + ack(&mut lr, 0, TEST_RTT); + assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT); + assert_no_sent_times(&lr); + lr + } + + // The ack delay is removed from any RTT estimate. + #[test] + fn ack_delay_adjusted() { + let mut lr = setup_lr(2); + ack(&mut lr, 1, TEST_RTT + ACK_DELAY); + // RTT stays the same, but the RTTVAR is adjusted downwards. + assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR * 3 / 4, TEST_RTT); + assert_no_sent_times(&lr); + } + + // The ack delay is ignored when it would cause a sample to be less than min_rtt. + #[test] + fn ack_delay_ignored() { + let mut lr = setup_lr(2); + let extra = ms!(8); + assert!(extra < ACK_DELAY); + ack(&mut lr, 1, TEST_RTT + extra); + let expected_rtt = TEST_RTT + (extra / 8); + let expected_rttvar = (TEST_RTTVAR * 3 + extra) / 4; + assert_rtts( + &lr, + TEST_RTT + extra, + expected_rtt, + expected_rttvar, + TEST_RTT, + ); + assert_no_sent_times(&lr); + } + + // A lower observed RTT is used as min_rtt (and ack delay is ignored). + #[test] + fn reduce_min_rtt() { + let mut lr = setup_lr(2); + let delta = ms!(4); + let reduced_rtt = TEST_RTT - delta; + ack(&mut lr, 1, reduced_rtt); + let expected_rtt = TEST_RTT - (delta / 8); + let expected_rttvar = (TEST_RTTVAR * 3 + delta) / 4; + assert_rtts(&lr, reduced_rtt, expected_rtt, expected_rttvar, reduced_rtt); + assert_no_sent_times(&lr); + } + + // Acknowledging something again has no effect. + #[test] + fn no_new_acks() { + let mut lr = setup_lr(1); + let check = |lr: &LossRecovery| { + assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT); + assert_no_sent_times(&lr); + }; + check(&lr); + + ack(&mut lr, 0, ms!(1339)); // much delayed ACK + check(&lr); + + ack(&mut lr, 0, ms!(3)); // time travel! + check(&lr); + } + + // Test time loss detection as part of handling a regular ACK. 
#[test]
fn time_loss_detection_gap() {
    let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
    lr.start_pacer(now());
    // Create a single packet gap, and have pn 0 time out.
    // This can't use the default pacing, which is too tight.
    // So send two packets with 1/4 RTT between them.  Acknowledge pn 1 after 1 RTT.
    // pn 0 should then be marked lost because it is then outstanding for 5RTT/4
    // the loss time for packets is 9RTT/8.
    lr.on_packet_sent(SentPacket::new(
        PacketType::Short,
        0,
        pn_time(0),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));
    lr.on_packet_sent(SentPacket::new(
        PacketType::Short,
        1,
        pn_time(0) + TEST_RTT / 4,
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));
    let (_, lost) = lr.on_ack_received(
        PNSpace::ApplicationData,
        1,
        vec![1..=1],
        ACK_DELAY,
        pn_time(0) + (TEST_RTT * 5 / 4),
    );
    // Only pn 0 is reported lost, and nothing remains on the loss timer.
    assert_eq!(lost.len(), 1);
    assert_no_sent_times(&lr);
}

// Test time loss detection as part of an explicit timeout.
#[test]
fn time_loss_detection_timeout() {
    let mut lr = setup_lr(3);

    // We want to declare PN 2 as acknowledged before we declare PN 1 as lost.
    // For this to work, we need PACING above to be less than 1/8 of an RTT.
    let pn1_sent_time = pn_time(1);
    let pn1_loss_time = pn1_sent_time + (TEST_RTT * 9 / 8);
    let pn2_ack_time = pn_time(2) + TEST_RTT;
    assert!(pn1_loss_time > pn2_ack_time);

    let (_, lost) = lr.on_ack_received(
        PNSpace::ApplicationData,
        2,
        vec![2..=2],
        ACK_DELAY,
        pn2_ack_time,
    );
    assert!(lost.is_empty());
    // Run the timeout function here to force time-based loss recovery to be enabled.
    let lost = lr.timeout(pn2_ack_time);
    assert!(lost.is_empty());
    assert_sent_times(&lr, None, None, Some(pn1_sent_time));

    // After time elapses, pn 1 is marked lost.
    let callback_time = lr.next_timeout();
    assert_eq!(callback_time, Some(pn1_loss_time));
    let packets = lr.timeout(pn1_loss_time);
    assert_eq!(packets.len(), 1);
    // Checking for expiration with zero delay lets us check the loss time.
    assert!(packets[0].expired(pn1_loss_time, Duration::new(0, 0)));
    assert_no_sent_times(&lr);
}

// Test loss detection triggered by the packet-number gap, not by time.
#[test]
fn big_gap_loss() {
    // This sends packets 0-4 and acknowledges pn 0.
    let mut lr = setup_lr(5);
    // Acknowledge just 2-4, which will cause pn 1 to be marked as lost.
    assert_eq!(super::PACKET_THRESHOLD, 3);
    let (_, lost) = lr.on_ack_received(
        PNSpace::ApplicationData,
        4,
        vec![2..=4],
        ACK_DELAY,
        pn_time(4),
    );
    assert_eq!(lost.len(), 1);
}

// Discarding the application data space is expected to panic.
#[test]
#[should_panic(expected = "discarding application space")]
fn drop_app() {
    let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
    lr.discard(PNSpace::ApplicationData, now());
}

// Discarding Handshake while Initial is still present is expected to panic.
#[test]
#[should_panic(expected = "dropping spaces out of order")]
fn drop_out_of_order() {
    let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
    lr.discard(PNSpace::Handshake, now());
}

// An ACK for a space that was already discarded is expected to panic.
#[test]
#[should_panic(expected = "ACK on discarded space")]
fn ack_after_drop() {
    let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
    lr.start_pacer(now());
    lr.discard(PNSpace::Initial, now());
    lr.on_ack_received(
        PNSpace::Initial,
        0,
        vec![],
        Duration::from_millis(0),
        pn_time(0),
    );
}

// Discarding a space removes its packets from the loss timer and leaves the
// other spaces untouched.
#[test]
fn drop_spaces() {
    let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
    lr.start_pacer(now());
    // Send packet 0 in each of the three spaces, at distinct times.
    lr.on_packet_sent(SentPacket::new(
        PacketType::Initial,
        0,
        pn_time(0),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));
    lr.on_packet_sent(SentPacket::new(
        PacketType::Handshake,
        0,
        pn_time(1),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));
    lr.on_packet_sent(SentPacket::new(
        PacketType::Short,
        0,
        pn_time(2),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));

    // Now put all spaces on the LR timer so we can see them.
    for sp in &[
        PacketType::Initial,
        PacketType::Handshake,
        PacketType::Short,
    ] {
        // Send and immediately acknowledge packet 1, leaving packet 0 outstanding.
        let sent_pkt = SentPacket::new(*sp, 1, pn_time(3), true, Vec::new(), ON_SENT_SIZE);
        let pn_space = PNSpace::from(sent_pkt.pt);
        lr.on_packet_sent(sent_pkt);
        lr.on_ack_received(pn_space, 1, vec![1..=1], Duration::from_secs(0), pn_time(3));
        let mut lost = Vec::new();
        lr.spaces.get_mut(pn_space).unwrap().detect_lost_packets(
            pn_time(3),
            TEST_RTT,
            TEST_RTT * 3, // unused
            &mut lost,
        );
        assert!(lost.is_empty());
    }

    lr.discard(PNSpace::Initial, pn_time(3));
    assert_sent_times(&lr, None, Some(pn_time(1)), Some(pn_time(2)));

    lr.discard(PNSpace::Handshake, pn_time(3));
    assert_sent_times(&lr, None, None, Some(pn_time(2)));

    // There are cases where we send a packet that is not subsequently tracked.
    // So check that this works.
    lr.on_packet_sent(SentPacket::new(
        PacketType::Initial,
        0,
        pn_time(3),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));
    assert_sent_times(&lr, None, None, Some(pn_time(2)));
}

// Discarding the Handshake space confirms the handshake; an ApplicationData PTO
// that has already expired by then must fire immediately.
#[test]
fn rearm_pto_after_confirmed() {
    let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
    lr.start_pacer(now());
    lr.on_packet_sent(SentPacket::new(
        PacketType::Handshake,
        0,
        now(),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));
    lr.on_packet_sent(SentPacket::new(
        PacketType::Short,
        0,
        now(),
        true,
        Vec::new(),
        ON_SENT_SIZE,
    ));

    // ApplicationData has no PTO until the handshake is confirmed.
    assert_eq!(lr.pto_time(PNSpace::ApplicationData), None);
    lr.discard(PNSpace::Initial, pn_time(1));
    assert_eq!(lr.pto_time(PNSpace::ApplicationData), None);

    // Expiring state after the PTO on the ApplicationData space has
    // expired should result in setting a PTO state.
    let expected_pto = pn_time(2) + (INITIAL_RTT * 3) + MAX_ACK_DELAY;
    lr.discard(PNSpace::Handshake, expected_pto);
    let profile = lr.send_profile(expected_pto, 10000);
    assert!(profile.pto.is_some());
    assert!(!profile.should_probe(PNSpace::Initial));
    assert!(!profile.should_probe(PNSpace::Handshake));
    assert!(profile.should_probe(PNSpace::ApplicationData));
}
}
diff --git a/third_party/rust/neqo-transport/src/recv_stream.rs b/third_party/rust/neqo-transport/src/recv_stream.rs
new file mode 100644
index 0000000000..e82e08d2ec
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/recv_stream.rs
@@ -0,0 +1,1110 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// Building a stream of ordered bytes to give the application from a series of
// incoming STREAM frames.

use std::cell::RefCell;
use std::cmp::max;
use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::mem;
use std::ops::Bound::{Included, Unbounded};
use std::rc::Rc;

use smallvec::SmallVec;

use crate::events::ConnectionEvents;
use crate::flow_mgr::FlowMgr;
use crate::stream_id::StreamId;
use crate::{AppError, Error, Res};
use neqo_common::qtrace;

const RX_STREAM_DATA_WINDOW: u64 = 0x10_0000; // 1MiB

// Export as usize for consistency with SEND_BUFFER_SIZE
pub const RECV_BUFFER_SIZE: usize = RX_STREAM_DATA_WINDOW as usize;

pub(crate) type RecvStreams = BTreeMap<StreamId, RecvStream>;

/// Holds data not yet read by application. Orders and dedupes data ranges
/// from incoming STREAM frames.
+#[derive(Debug, Default)] +pub struct RxStreamOrderer { + data_ranges: BTreeMap<u64, Vec<u8>>, // (start_offset, data) + retired: u64, // Number of bytes the application has read +} + +impl RxStreamOrderer { + pub fn new() -> Self { + Self::default() + } + + /// Process an incoming stream frame off the wire. This may result in data + /// being available to upper layers if frame is not out of order (ooo) or + /// if the frame fills a gap. + pub fn inbound_frame(&mut self, mut new_start: u64, mut new_data: &[u8]) { + qtrace!("Inbound data offset={} len={}", new_start, new_data.len()); + + // Get entry before where new entry would go, so we can see if we already + // have the new bytes. + // Avoid copies and duplicated data. + let new_end = new_start + u64::try_from(new_data.len()).unwrap(); + + if new_end <= self.retired { + // Range already read by application, this frame is very late and unneeded. + return; + } + + if new_start < self.retired { + new_data = &new_data[usize::try_from(self.retired - new_start).unwrap()..]; + new_start = self.retired; + } + + if new_data.is_empty() { + // No data to insert + return; + } + + let extend = if let Some((&prev_start, prev_vec)) = self + .data_ranges + .range_mut((Unbounded, Included(new_start))) + .next_back() + { + let prev_end = prev_start + u64::try_from(prev_vec.len()).unwrap(); + if new_end > prev_end { + // PPPPPP -> PPPPPP + // NNNNNN NN + // NNNNNNNN NN + // Add a range containing only new data + // (In-order frames will take this path, with no overlap) + let overlap = prev_end.saturating_sub(new_start); + qtrace!( + "New frame {}-{} received, overlap: {}", + new_start, + new_end, + overlap + ); + new_start += overlap; + new_data = &new_data[usize::try_from(overlap).unwrap()..]; + // If it is small enough, extend the previous buffer. + // This can't always extend, because otherwise the buffer could end up + // growing indefinitely without being released. 
+ prev_vec.len() < 4096 && prev_end == new_start + } else { + // PPPPPP -> PPPPPP + // NNNN + // NNNN + // Do nothing + qtrace!( + "Dropping frame with already-received range {}-{}", + new_start, + new_end + ); + return; + } + } else { + qtrace!("New frame {}-{} received", new_start, new_end); + false + }; + + // Now handle possible overlap with next entries + let mut to_remove = SmallVec::<[_; 8]>::new(); + let mut to_add = new_data; + + for (&next_start, next_data) in self.data_ranges.range_mut(new_start..) { + let next_end = next_start + u64::try_from(next_data.len()).unwrap(); + let overlap = new_end.saturating_sub(next_start); + if overlap == 0 { + break; + } else if next_end >= new_end { + qtrace!( + "New frame {}-{} overlaps with next frame by {}, truncating", + new_start, + new_end, + overlap + ); + let truncate_to = new_data.len() - usize::try_from(overlap).unwrap(); + to_add = &new_data[..truncate_to]; + break; + } else { + qtrace!( + "New frame {}-{} spans entire next frame {}-{}, replacing", + new_start, + new_end, + next_start, + next_end + ); + to_remove.push(next_start); + } + } + + for start in to_remove { + self.data_ranges.remove(&start); + } + + if !to_add.is_empty() { + if extend { + let (_, buf) = self + .data_ranges + .range_mut((Unbounded, Included(new_start))) + .next_back() + .unwrap(); + buf.extend_from_slice(to_add); + } else { + self.data_ranges.insert(new_start, to_add.to_vec()); + } + } + } + + /// Are any bytes readable? + pub fn data_ready(&self) -> bool { + self.data_ranges + .keys() + .next() + .map_or(false, |&start| start <= self.retired) + } + + /// How many bytes are readable? + fn bytes_ready(&self) -> usize { + let mut prev_end = self.retired; + self.data_ranges + .iter() + .map(|(start_offset, data)| { + // All ranges don't overlap but we could have partially + // retired some of the first entry's data. 
+ let data_len = data.len() as u64 - self.retired.saturating_sub(*start_offset); + (start_offset, data_len) + }) + .take_while(|(start_offset, data_len)| { + if **start_offset <= prev_end { + prev_end += data_len; + true + } else { + false + } + }) + .map(|(_, data_len)| data_len as usize) + .sum() + } + + /// Bytes read by the application. + fn retired(&self) -> u64 { + self.retired + } + + /// Data bytes buffered. Could be more than bytes_readable if there are + /// ranges missing. + fn buffered(&self) -> u64 { + self.data_ranges + .iter() + .map(|(&start, data)| data.len() as u64 - (self.retired.saturating_sub(start))) + .sum() + } + + /// Copy received data (if any) into the buffer. Returns bytes copied. + fn read(&mut self, buf: &mut [u8]) -> usize { + qtrace!("Reading {} bytes, {} available", buf.len(), self.buffered()); + let mut copied = 0; + + for (&range_start, range_data) in &mut self.data_ranges { + let mut keep = false; + if self.retired >= range_start { + // Frame data has new contiguous bytes. + let copy_offset = + usize::try_from(max(range_start, self.retired) - range_start).unwrap(); + assert!(range_data.len() >= copy_offset); + let available = range_data.len() - copy_offset; + let space = buf.len() - copied; + let copy_bytes = if available > space { + keep = true; + space + } else { + available + }; + + if copy_bytes > 0 { + let copy_slc = &range_data[copy_offset..copy_offset + copy_bytes]; + buf[copied..copied + copy_bytes].copy_from_slice(copy_slc); + copied += copy_bytes; + self.retired += u64::try_from(copy_bytes).unwrap(); + } + } else { + // The data in the buffer isn't contiguous. + keep = true; + } + if keep { + let mut keep = self.data_ranges.split_off(&range_start); + mem::swap(&mut self.data_ranges, &mut keep); + return copied; + } + } + + self.data_ranges.clear(); + copied + } + + /// Extend the given Vector with any available data. 
pub fn read_to_end(&mut self, buf: &mut Vec<u8>) -> usize {
    let start = buf.len();
    let readable = self.bytes_ready();
    // Make room for everything that is readable, then fill it.
    buf.resize(start + readable, 0);
    self.read(&mut buf[start..])
}

/// The end offset of the last buffered chunk, or the number of bytes read
/// if nothing is buffered.
fn highest_seen_offset(&self) -> u64 {
    self.data_ranges
        .iter()
        .next_back()
        .map_or(self.retired, |(start, data)| *start + data.len() as u64)
}
}

/// QUIC receiving states, based on -transport 3.2.
#[derive(Debug)]
#[allow(dead_code)]
// Because a dead_code warning is easier than clippy::unused_self, see https://github.com/rust-lang/rust/issues/68408
enum RecvStreamState {
    Recv {
        recv_buf: RxStreamOrderer,
        max_bytes: u64, // Maximum size of recv_buf
        max_stream_data: u64,
    },
    SizeKnown {
        recv_buf: RxStreamOrderer,
        final_size: u64,
    },
    DataRecvd {
        recv_buf: RxStreamOrderer,
    },
    DataRead,
    ResetRecvd,
    // Defined by spec but we don't use it: ResetRead
}

impl RecvStreamState {
    /// The starting state: `Recv` with an empty buffer, where `max_bytes` is
    /// both the buffer limit and the initial `max_stream_data`.
    fn new(max_bytes: u64) -> Self {
        let recv_buf = RxStreamOrderer::new();
        Self::Recv {
            recv_buf,
            max_bytes,
            max_stream_data: max_bytes,
        }
    }

    /// The variant name, for logging.
    fn name(&self) -> &str {
        match self {
            Self::Recv { .. } => "Recv",
            Self::SizeKnown { .. } => "SizeKnown",
            Self::DataRecvd { .. } => "DataRecvd",
            Self::DataRead => "DataRead",
            Self::ResetRecvd => "ResetRecvd",
        }
    }

    /// The receive buffer, in the states that have one.
    fn recv_buf(&self) -> Option<&RxStreamOrderer> {
        match self {
            Self::DataRead | Self::ResetRecvd => None,
            Self::Recv { recv_buf, .. }
            | Self::SizeKnown { recv_buf, .. }
            | Self::DataRecvd { recv_buf } => Some(recv_buf),
        }
    }

    /// The final size of the stream, which is only available in `SizeKnown`.
    fn final_size(&self) -> Option<u64> {
        if let Self::SizeKnown { final_size, .. } = self {
            Some(*final_size)
        } else {
            None
        }
    }

    /// The current flow control limit, which is only available in `Recv`.
    fn max_stream_data(&self) -> Option<u64> {
        if let Self::Recv {
            max_stream_data, ..
        } = self
        {
            Some(*max_stream_data)
        } else {
            None
        }
    }
}

/// Implement a QUIC receive stream.
+#[derive(Debug)] +pub struct RecvStream { + stream_id: StreamId, + state: RecvStreamState, + flow_mgr: Rc<RefCell<FlowMgr>>, + conn_events: ConnectionEvents, +} + +impl RecvStream { + pub fn new( + stream_id: StreamId, + max_stream_data: u64, + flow_mgr: Rc<RefCell<FlowMgr>>, + conn_events: ConnectionEvents, + ) -> Self { + Self { + stream_id, + state: RecvStreamState::new(max_stream_data), + flow_mgr, + conn_events, + } + } + + fn set_state(&mut self, new_state: RecvStreamState) { + debug_assert_ne!( + mem::discriminant(&self.state), + mem::discriminant(&new_state) + ); + qtrace!( + "RecvStream {} state {} -> {}", + self.stream_id.as_u64(), + self.state.name(), + new_state.name() + ); + + if let RecvStreamState::Recv { .. } = &self.state { + self.flow_mgr + .borrow_mut() + .clear_max_stream_data(self.stream_id) + } + + if let RecvStreamState::DataRead = new_state { + self.conn_events.recv_stream_complete(self.stream_id); + } + + self.state = new_state; + } + + pub fn inbound_stream_frame(&mut self, fin: bool, offset: u64, data: &[u8]) -> Res<()> { + // We should post a DataReadable event only once when we change from no-data-ready to + // data-ready. Therefore remember the state before processing a new frame. + let already_data_ready = self.data_ready(); + let new_end = offset + u64::try_from(data.len()).unwrap(); + + // Send final size errors even if stream is closed + if let Some(final_size) = self.state.final_size() { + if new_end > final_size || (fin && new_end != final_size) { + return Err(Error::FinalSizeError); + } + } + + match &mut self.state { + RecvStreamState::Recv { + recv_buf, + max_stream_data, + .. 
+ } => { + if new_end > *max_stream_data { + qtrace!("Stream RX window {} exceeded: {}", max_stream_data, new_end); + return Err(Error::FlowControlError); + } + + if fin { + let final_size = offset + data.len() as u64; + if final_size < recv_buf.highest_seen_offset() { + return Err(Error::FinalSizeError); + } + recv_buf.inbound_frame(offset, data); + + let buf = mem::replace(recv_buf, RxStreamOrderer::new()); + if final_size == buf.retired() + buf.bytes_ready() as u64 { + self.set_state(RecvStreamState::DataRecvd { recv_buf: buf }); + } else { + self.set_state(RecvStreamState::SizeKnown { + recv_buf: buf, + final_size, + }); + } + } else { + recv_buf.inbound_frame(offset, data); + } + } + RecvStreamState::SizeKnown { + recv_buf, + final_size, + } => { + recv_buf.inbound_frame(offset, data); + if *final_size == recv_buf.retired() + recv_buf.bytes_ready() as u64 { + let buf = mem::replace(recv_buf, RxStreamOrderer::new()); + self.set_state(RecvStreamState::DataRecvd { recv_buf: buf }); + } + } + RecvStreamState::DataRecvd { .. } + | RecvStreamState::DataRead + | RecvStreamState::ResetRecvd => { + qtrace!("data received when we are in state {}", self.state.name()) + } + } + + if !already_data_ready && (self.data_ready() || self.needs_to_inform_app_about_fin()) { + self.conn_events.recv_stream_readable(self.stream_id) + } + + Ok(()) + } + + pub fn reset(&mut self, application_error_code: AppError) { + match self.state { + RecvStreamState::Recv { .. } | RecvStreamState::SizeKnown { .. 
} => { + self.conn_events + .recv_stream_reset(self.stream_id, application_error_code); + self.set_state(RecvStreamState::ResetRecvd); + } + _ => { + // Ignore reset if in DataRecvd, DataRead, or ResetRecvd + } + } + } + + /// If we should tell the sender they have more credit, return an offset + pub fn maybe_send_flowc_update(&mut self) { + // Only ever needed if actively receiving and not in SizeKnown state + if let RecvStreamState::Recv { + max_bytes, + max_stream_data, + recv_buf, + } = &mut self.state + { + // Algo: send an update if app has consumed more than half + // the data in the current window + // TODO(agrover@mozilla.com): This algo is not great but + // should prevent Silly Window Syndrome. Spec refers to using + // highest seen offset somehow? RTT maybe? + let maybe_new_max = recv_buf.retired() + *max_bytes; + if maybe_new_max > (*max_bytes / 2) + *max_stream_data { + *max_stream_data = maybe_new_max; + self.flow_mgr + .borrow_mut() + .max_stream_data(self.stream_id, maybe_new_max) + } + } + } + + pub fn max_stream_data(&self) -> Option<u64> { + self.state.max_stream_data() + } + + pub fn is_terminal(&self) -> bool { + matches!( + self.state, + RecvStreamState::ResetRecvd | RecvStreamState::DataRead + ) + } + + // App got all data but did not get the fin signal. + fn needs_to_inform_app_about_fin(&self) -> bool { + matches!(self.state, RecvStreamState::DataRecvd { .. }) + } + + fn data_ready(&self) -> bool { + self.state + .recv_buf() + .map_or(false, RxStreamOrderer::data_ready) + } + + /// # Errors + /// `NoMoreData` if data and fin bit were previously read by the application. + pub fn read(&mut self, buf: &mut [u8]) -> Res<(usize, bool)> { + let res = match &mut self.state { + RecvStreamState::Recv { recv_buf, .. } + | RecvStreamState::SizeKnown { recv_buf, .. 
} => Ok((recv_buf.read(buf), false)), + RecvStreamState::DataRecvd { recv_buf } => { + let bytes_read = recv_buf.read(buf); + let fin_read = recv_buf.buffered() == 0; + if fin_read { + self.set_state(RecvStreamState::DataRead); + } + Ok((bytes_read, fin_read)) + } + RecvStreamState::DataRead | RecvStreamState::ResetRecvd => Err(Error::NoMoreData), + }; + self.maybe_send_flowc_update(); + res + } + + pub fn stop_sending(&mut self, err: AppError) { + qtrace!("stop_sending called when in state {}", self.state.name()); + match &self.state { + RecvStreamState::Recv { .. } | RecvStreamState::SizeKnown { .. } => { + self.set_state(RecvStreamState::ResetRecvd); + self.flow_mgr.borrow_mut().stop_sending(self.stream_id, err) + } + RecvStreamState::DataRecvd { .. } => self.set_state(RecvStreamState::DataRead), + RecvStreamState::DataRead | RecvStreamState::ResetRecvd => { + // Already in terminal state + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::frame::Frame; + use std::ops::Range; + + fn recv_ranges(ranges: &[Range<u64>], available: usize) { + const ZEROES: &[u8] = &[0; 100]; + qtrace!("recv_ranges {:?}", ranges); + + let mut s = RxStreamOrderer::default(); + for r in ranges { + let data = &ZEROES[..usize::try_from(r.end - r.start).unwrap()]; + s.inbound_frame(r.start, data); + } + + let mut buf = [0xff; 100]; + let mut total_recvd = 0; + loop { + let recvd = s.read(&mut buf[..]); + qtrace!("recv_ranges read {}", recvd); + total_recvd += recvd; + if recvd == 0 { + assert_eq!(total_recvd, available); + break; + } + } + } + + #[test] + fn recv_noncontiguous() { + // Non-contiguous with the start, no data available. + recv_ranges(&[10..20], 0); + } + + /// Overlaps with the start of a 10..20 range of bytes. + #[test] + fn recv_overlap_start() { + // Overlap the start, with a larger new value. + // More overlap than not. + recv_ranges(&[10..20, 4..18, 0..4], 20); + // Overlap the start, with a larger new value. + // Less overlap than not. 
+ recv_ranges(&[10..20, 2..15, 0..2], 20); + // Overlap the start, with a smaller new value. + // More overlap than not. + recv_ranges(&[10..20, 8..14, 0..8], 20); + // Overlap the start, with a smaller new value. + // Less overlap than not. + recv_ranges(&[10..20, 6..13, 0..6], 20); + + // Again with some of the first range split in two. + recv_ranges(&[10..11, 11..20, 4..18, 0..4], 20); + recv_ranges(&[10..11, 11..20, 2..15, 0..2], 20); + recv_ranges(&[10..11, 11..20, 8..14, 0..8], 20); + recv_ranges(&[10..11, 11..20, 6..13, 0..6], 20); + + // Again with a gap in the first range. + recv_ranges(&[10..11, 12..20, 4..18, 0..4], 20); + recv_ranges(&[10..11, 12..20, 2..15, 0..2], 20); + recv_ranges(&[10..11, 12..20, 8..14, 0..8], 20); + recv_ranges(&[10..11, 12..20, 6..13, 0..6], 20); + } + + /// Overlaps with the end of a 10..20 range of bytes. + #[test] + fn recv_overlap_end() { + // Overlap the end, with a larger new value. + // More overlap than not. + recv_ranges(&[10..20, 12..25, 0..10], 25); + // Overlap the end, with a larger new value. + // Less overlap than not. + recv_ranges(&[10..20, 17..33, 0..10], 33); + // Overlap the end, with a smaller new value. + // More overlap than not. + recv_ranges(&[10..20, 15..21, 0..10], 21); + // Overlap the end, with a smaller new value. + // Less overlap than not. + recv_ranges(&[10..20, 17..25, 0..10], 25); + + // Again with some of the first range split in two. + recv_ranges(&[10..19, 19..20, 12..25, 0..10], 25); + recv_ranges(&[10..19, 19..20, 17..33, 0..10], 33); + recv_ranges(&[10..19, 19..20, 15..21, 0..10], 21); + recv_ranges(&[10..19, 19..20, 17..25, 0..10], 25); + + // Again with a gap in the first range. + recv_ranges(&[10..18, 19..20, 12..25, 0..10], 25); + recv_ranges(&[10..18, 19..20, 17..33, 0..10], 33); + recv_ranges(&[10..18, 19..20, 15..21, 0..10], 21); + recv_ranges(&[10..18, 19..20, 17..25, 0..10], 25); + } + + /// Complete overlaps with the start of a 10..20 range of bytes. 
+ #[test] + fn recv_overlap_complete() { + // Complete overlap, more at the end. + recv_ranges(&[10..20, 9..23, 0..9], 23); + // Complete overlap, more at the start. + recv_ranges(&[10..20, 3..23, 0..3], 23); + // Complete overlap, to end. + recv_ranges(&[10..20, 5..20, 0..5], 20); + // Complete overlap, from start. + recv_ranges(&[10..20, 10..27, 0..10], 27); + // Complete overlap, from 0 and more. + recv_ranges(&[10..20, 0..23], 23); + + // Again with the first range split in two. + recv_ranges(&[10..14, 14..20, 9..23, 0..9], 23); + recv_ranges(&[10..14, 14..20, 3..23, 0..3], 23); + recv_ranges(&[10..14, 14..20, 5..20, 0..5], 20); + recv_ranges(&[10..14, 14..20, 10..27, 0..10], 27); + recv_ranges(&[10..14, 14..20, 0..23], 23); + + // Again with the a gap in the first range. + recv_ranges(&[10..13, 14..20, 9..23, 0..9], 23); + recv_ranges(&[10..13, 14..20, 3..23, 0..3], 23); + recv_ranges(&[10..13, 14..20, 5..20, 0..5], 20); + recv_ranges(&[10..13, 14..20, 10..27, 0..10], 27); + recv_ranges(&[10..13, 14..20, 0..23], 23); + } + + /// An overlap with no new bytes. + #[test] + fn recv_overlap_duplicate() { + recv_ranges(&[10..20, 11..12, 0..10], 20); + recv_ranges(&[10..20, 10..15, 0..10], 20); + recv_ranges(&[10..20, 14..20, 0..10], 20); + // Now with the first range split. + recv_ranges(&[10..14, 14..20, 10..15, 0..10], 20); + recv_ranges(&[10..15, 16..20, 21..25, 10..25, 0..10], 25); + } + + /// Reading exactly one chunk works, when the next chunk starts immediately. + #[test] + fn stop_reading_at_chunk() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add three chunks. + s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + // Read, providing only enough space for the first. 
+ let mut buf = vec![0; 100]; + let count = s.read(&mut buf[..CHUNK_SIZE]); + assert_eq!(count, CHUNK_SIZE); + let count = s.read(&mut buf[..]); + assert_eq!(count, EXTRA_SIZE * 2); + } + + #[test] + fn recv_overlap_while_reading() { + let mut s = RxStreamOrderer::new(); + + // Add a chunk + s.inbound_frame(0, &[0; 150]); + assert_eq!(s.data_ranges.get(&0).unwrap().len(), 150); + // Read, providing only enough space for the first 100. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..]); + assert_eq!(count, 100); + assert_eq!(s.retired, 100); + + // Add a second frame that overlaps. + // This shouldn't truncate the first frame, as we're already + // Reading from it. + s.inbound_frame(120, &[0; 60]); + assert_eq!(s.data_ranges.get(&0).unwrap().len(), 180); + // Read second part of first frame and all of the second frame + let count = s.read(&mut buf[..]); + assert_eq!(count, 80); + } + + /// Reading exactly one chunk works, when there is a gap. + #[test] + fn stop_reading_at_gap() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add three chunks. + s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + // Read, providing only enough space for the first chunk. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..CHUNK_SIZE]); + assert_eq!(count, CHUNK_SIZE); + + // Now fill the gap and ensure that everything can be read. + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + let count = s.read(&mut buf[..]); + assert_eq!(count, EXTRA_SIZE * 2); + } + + /// Reading exactly one chunk works, when there is a gap. + #[test] + fn stop_reading_in_chunk() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add two chunks. 
+ s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + // Read, providing only enough space for some of the first chunk. + let mut buf = [0; 100]; + let count = s.read(&mut buf[..CHUNK_SIZE - EXTRA_SIZE]); + assert_eq!(count, CHUNK_SIZE - EXTRA_SIZE); + + let count = s.read(&mut buf[..]); + assert_eq!(count, EXTRA_SIZE * 2); + } + + /// Read one byte at a time. + #[test] + fn read_byte_at_a_time() { + const CHUNK_SIZE: usize = 10; + const EXTRA_SIZE: usize = 3; + let mut s = RxStreamOrderer::new(); + + // Add two chunks. + s.inbound_frame(0, &[0; CHUNK_SIZE]); + let offset = u64::try_from(CHUNK_SIZE).unwrap(); + s.inbound_frame(offset, &[0; EXTRA_SIZE]); + + let mut buf = [0; 1]; + for _ in 0..CHUNK_SIZE + EXTRA_SIZE { + let count = s.read(&mut buf[..]); + assert_eq!(count, 1); + } + assert_eq!(0, s.read(&mut buf[..])); + } + + #[test] + fn stream_rx() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + let conn_events = ConnectionEvents::default(); + + let mut s = RecvStream::new(StreamId::from(567), 1024, Rc::clone(&flow_mgr), conn_events); + + // test receiving a contig frame and reading it works + s.inbound_stream_frame(false, 0, &[1; 10]).unwrap(); + assert_eq!(s.data_ready(), true); + let mut buf = vec![0u8; 100]; + assert_eq!(s.read(&mut buf).unwrap(), (10, false)); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 0); + + // test receiving a noncontig frame + s.inbound_stream_frame(false, 12, &[2; 12]).unwrap(); + assert_eq!(s.data_ready(), false); + assert_eq!(s.read(&mut buf).unwrap(), (0, false)); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + + // another frame that overlaps the first + s.inbound_stream_frame(false, 14, &[3; 8]).unwrap(); + assert_eq!(s.data_ready(), false); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); 
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + + // fill in the gap, but with a FIN + s.inbound_stream_frame(true, 10, &[4; 6]).unwrap_err(); + assert_eq!(s.data_ready(), false); + assert_eq!(s.read(&mut buf).unwrap(), (0, false)); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + + // fill in the gap + s.inbound_stream_frame(false, 10, &[5; 10]).unwrap(); + assert_eq!(s.data_ready(), true); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 14); + + // a legit FIN + s.inbound_stream_frame(true, 24, &[6; 18]).unwrap(); + assert_eq!(s.state.recv_buf().unwrap().retired(), 10); + assert_eq!(s.state.recv_buf().unwrap().buffered(), 32); + assert_eq!(s.data_ready(), true); + assert_eq!(s.read(&mut buf).unwrap(), (32, true)); + + // Stream now no longer readable (is in DataRead state) + s.read(&mut buf).unwrap_err(); + } + + fn check_chunks(s: &mut RxStreamOrderer, expected: &[(u64, usize)]) { + assert_eq!(s.data_ranges.len(), expected.len()); + for ((start, buf), (expected_start, expected_len)) in s.data_ranges.iter().zip(expected) { + assert_eq!((*start, buf.len()), (*expected_start, *expected_len)); + } + } + + // Test deduplication when the new data is at the end. + #[test] + fn stream_rx_dedupe_tail() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(0, &[1; 6]); + check_chunks(&mut s, &[(0, 6)]); + + // New data that overlaps entirely (starting from the head), is ignored. + s.inbound_frame(0, &[2; 3]); + check_chunks(&mut s, &[(0, 6)]); + + // New data that overlaps at the tail has any new data appended. + s.inbound_frame(2, &[3; 6]); + check_chunks(&mut s, &[(0, 8)]); + + // New data that overlaps entirely (up to the tail), is ignored. + s.inbound_frame(4, &[4; 4]); + check_chunks(&mut s, &[(0, 8)]); + + // New data that overlaps, starting from the beginning is appended too. 
+ s.inbound_frame(0, &[5; 10]); + check_chunks(&mut s, &[(0, 10)]); + + // New data that is entirely subsumed is ignored. + s.inbound_frame(2, &[6; 2]); + check_chunks(&mut s, &[(0, 10)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 10); + assert_eq!(buf[..10], [1, 1, 1, 1, 1, 1, 3, 3, 5, 5]); + } + + /// When chunks are added before existing data, they aren't merged. + #[test] + fn stream_rx_dedupe_head() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(1, &[6; 6]); + check_chunks(&mut s, &[(1, 6)]); + + // Insertion before an existing chunk causes truncation of the new chunk. + s.inbound_frame(0, &[7; 6]); + check_chunks(&mut s, &[(0, 1), (1, 6)]); + + // Perfect overlap with existing slices has no effect. + s.inbound_frame(0, &[8; 7]); + check_chunks(&mut s, &[(0, 1), (1, 6)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 7); + assert_eq!(buf[..7], [7, 6, 6, 6, 6, 6, 6]); + } + + #[test] + fn stream_rx_dedupe_new_tail() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(1, &[6; 6]); + check_chunks(&mut s, &[(1, 6)]); + + // Insertion before an existing chunk causes truncation of the new chunk. + s.inbound_frame(0, &[7; 6]); + check_chunks(&mut s, &[(0, 1), (1, 6)]); + + // New data at the end causes the tail to be added to the first chunk, + // replacing later chunks entirely. + s.inbound_frame(0, &[9; 8]); + check_chunks(&mut s, &[(0, 8)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 8); + assert_eq!(buf[..8], [7, 9, 9, 9, 9, 9, 9, 9]); + } + + #[test] + fn stream_rx_dedupe_replace() { + let mut s = RxStreamOrderer::new(); + + s.inbound_frame(2, &[6; 6]); + check_chunks(&mut s, &[(2, 6)]); + + // Insertion before an existing chunk causes truncation of the new chunk. + s.inbound_frame(1, &[7; 6]); + check_chunks(&mut s, &[(1, 1), (2, 6)]); + + // New data at the start and end replaces all the slices. 
+ s.inbound_frame(0, &[9; 10]); + check_chunks(&mut s, &[(0, 10)]); + + let mut buf = [0; 16]; + assert_eq!(s.read(&mut buf[..]), 10); + assert_eq!(buf[..10], [9; 10]); + } + + #[test] + fn trim_retired() { + let mut s = RxStreamOrderer::new(); + + let mut buf = [0; 18]; + s.inbound_frame(0, &[1; 10]); + + // Partially read slices are retained. + assert_eq!(s.read(&mut buf[..6]), 6); + check_chunks(&mut s, &[(0, 10)]); + + // Partially read slices are kept and so are added to. + s.inbound_frame(3, &buf[..10]); + check_chunks(&mut s, &[(0, 13)]); + + // Wholly read pieces are dropped. + assert_eq!(s.read(&mut buf[..]), 7); + assert!(s.data_ranges.is_empty()); + + // New data that overlaps with retired data is trimmed. + s.inbound_frame(0, &buf[..]); + check_chunks(&mut s, &[(13, 5)]); + } + + #[test] + fn stream_flowc_update() { + let flow_mgr = Rc::default(); + let conn_events = ConnectionEvents::default(); + + let frame1 = vec![0; RECV_BUFFER_SIZE]; + + let mut s = RecvStream::new( + StreamId::from(4), + RX_STREAM_DATA_WINDOW, + Rc::clone(&flow_mgr), + conn_events, + ); + + let mut buf = vec![0u8; RECV_BUFFER_SIZE + 100]; // Make it overlarge + + s.maybe_send_flowc_update(); + assert_eq!(s.flow_mgr.borrow().peek(), None); + s.inbound_stream_frame(false, 0, &frame1).unwrap(); + s.maybe_send_flowc_update(); + assert_eq!(s.flow_mgr.borrow().peek(), None); + assert_eq!(s.read(&mut buf).unwrap(), (RECV_BUFFER_SIZE, false)); + assert_eq!(s.data_ready(), false); + s.maybe_send_flowc_update(); + + // flow msg generated! 
+ assert!(s.flow_mgr.borrow().peek().is_some()); + + // consume it + s.flow_mgr.borrow_mut().next().unwrap(); + + // it should be gone + s.maybe_send_flowc_update(); + assert_eq!(s.flow_mgr.borrow().peek(), None); + } + + #[test] + fn stream_max_stream_data() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + let conn_events = ConnectionEvents::default(); + + let frame1 = vec![0; RECV_BUFFER_SIZE]; + let mut s = RecvStream::new( + StreamId::from(67), + RX_STREAM_DATA_WINDOW, + Rc::clone(&flow_mgr), + conn_events, + ); + + s.maybe_send_flowc_update(); + assert_eq!(s.flow_mgr.borrow().peek(), None); + s.inbound_stream_frame(false, 0, &frame1).unwrap(); + s.inbound_stream_frame(false, RX_STREAM_DATA_WINDOW, &[1; 1]) + .unwrap_err(); + } + + #[test] + fn stream_orderer_bytes_ready() { + let mut rx_ord = RxStreamOrderer::new(); + + rx_ord.inbound_frame(0, &[1; 6]); + assert_eq!(rx_ord.bytes_ready(), 6); + assert_eq!(rx_ord.buffered(), 6); + assert_eq!(rx_ord.retired(), 0); + + // read some so there's an offset into the first frame + let mut buf = [0u8; 10]; + rx_ord.read(&mut buf[..2]); + assert_eq!(rx_ord.bytes_ready(), 4); + assert_eq!(rx_ord.buffered(), 4); + assert_eq!(rx_ord.retired(), 2); + + // an overlapping frame + rx_ord.inbound_frame(5, &[2; 6]); + assert_eq!(rx_ord.bytes_ready(), 9); + assert_eq!(rx_ord.buffered(), 9); + assert_eq!(rx_ord.retired(), 2); + + // a noncontig frame + rx_ord.inbound_frame(20, &[3; 6]); + assert_eq!(rx_ord.bytes_ready(), 9); + assert_eq!(rx_ord.buffered(), 15); + assert_eq!(rx_ord.retired(), 2); + + // an old frame + rx_ord.inbound_frame(0, &[4; 2]); + assert_eq!(rx_ord.bytes_ready(), 9); + assert_eq!(rx_ord.buffered(), 15); + assert_eq!(rx_ord.retired(), 2); + } + + #[test] + fn no_stream_flowc_event_after_exiting_recv() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + let conn_events = ConnectionEvents::default(); + + let frame1 = vec![0; RECV_BUFFER_SIZE]; + let stream_id = StreamId::from(67); + let 
mut s = RecvStream::new( + stream_id, + RX_STREAM_DATA_WINDOW, + Rc::clone(&flow_mgr), + conn_events, + ); + + s.inbound_stream_frame(false, 0, &frame1).unwrap(); + flow_mgr.borrow_mut().max_stream_data(stream_id, 100); + assert!(matches!(s.flow_mgr.borrow().peek().unwrap(), Frame::MaxStreamData{..})); + s.inbound_stream_frame(true, RX_STREAM_DATA_WINDOW, &[]) + .unwrap(); + assert!(matches!(s.flow_mgr.borrow().peek(), None)); + } + + #[test] + fn resend_flowc_if_lost() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + let conn_events = ConnectionEvents::default(); + + let frame1 = &[0; RECV_BUFFER_SIZE]; + let stream_id = StreamId::from(67); + let mut s = RecvStream::new( + stream_id, + RX_STREAM_DATA_WINDOW, + Rc::clone(&flow_mgr), + conn_events, + ); + + // A flow control update is queued + s.inbound_stream_frame(false, 0, frame1).unwrap(); + flow_mgr.borrow_mut().max_stream_data(stream_id, 100); + // Generates frame + assert!(matches!( + s.flow_mgr.borrow_mut().next().unwrap(), + Frame::MaxStreamData { .. } + )); + // Nothing else queued + assert!(matches!(s.flow_mgr.borrow().peek(), None)); + // Asking for another one won't get you one + s.maybe_send_flowc_update(); + assert!(matches!(s.flow_mgr.borrow().peek(), None)); + // But if lost, another frame is generated + flow_mgr.borrow_mut().max_stream_data(stream_id, 100); + assert!(matches!(s.flow_mgr.borrow_mut().next().unwrap(), Frame::MaxStreamData{..})); + } +} diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs new file mode 100644 index 0000000000..b6b9eea5f5 --- /dev/null +++ b/third_party/rust/neqo-transport/src/send_stream.rs @@ -0,0 +1,1746 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +// Buffering data to send until it is acked. + +use std::cell::RefCell; +use std::cmp::{max, min}; +use std::collections::{BTreeMap, VecDeque}; +use std::convert::{TryFrom, TryInto}; +use std::mem; +use std::rc::Rc; + +use indexmap::IndexMap; +use smallvec::SmallVec; + +use neqo_common::{qdebug, qerror, qinfo, qtrace, Encoder}; + +use crate::events::ConnectionEvents; +use crate::flow_mgr::FlowMgr; +use crate::frame::Frame; +use crate::packet::PacketBuilder; +use crate::recovery::RecoveryToken; +use crate::stats::FrameStats; +use crate::stream_id::StreamId; +use crate::{AppError, Error, Res}; + +pub const SEND_BUFFER_SIZE: usize = 0x10_0000; // 1 MiB + +#[derive(Debug, PartialEq, Clone, Copy)] +enum RangeState { + Sent, + Acked, +} + +/// Track ranges in the stream as sent or acked. Acked implies sent. Not in a +/// range implies needing-to-be-sent, either initially or as a retransmission. +#[derive(Debug, Default, PartialEq)] +struct RangeTracker { + // offset, (len, RangeState). Use u64 for len because ranges can exceed 32bits. + used: BTreeMap<u64, (u64, RangeState)>, +} + +impl RangeTracker { + fn highest_offset(&self) -> u64 { + self.used + .range(..) + .next_back() + .map_or(0, |(k, (v, _))| *k + *v) + } + + fn acked_from_zero(&self) -> u64 { + self.used + .get(&0) + .filter(|(_, state)| *state == RangeState::Acked) + .map_or(0, |(v, _)| *v) + } + + /// Find the first unmarked range. If all are contiguous, this will return + /// (highest_offset(), None). + fn first_unmarked_range(&self) -> (u64, Option<u64>) { + let mut prev_end = 0; + + for (cur_off, (cur_len, _)) in &self.used { + if prev_end == *cur_off { + prev_end = cur_off + cur_len; + } else { + return (prev_end, Some(cur_off - prev_end)); + } + } + (prev_end, None) + } + + /// Turn one range into a list of subranges that align with existing + /// ranges. 
+ /// Check impermissible overlaps in subregions: Sent cannot overwrite Acked. + // + // e.g. given N is new and ABC are existing: + // NNNNNNNNNNNNNNNN + // AAAAA BBBCCCCC ...then we want 5 chunks: + // 1122222333444555 + // + // but also if we have this: + // NNNNNNNNNNNNNNNN + // AAAAAAAAAA BBBB ...then break existing A and B ranges up: + // + // 1111111122222233 + // aaAAAAAAAA BBbb + // + // Doing all this work up front should make handling each chunk much + // easier. + fn chunk_range_on_edges( + &mut self, + new_off: u64, + new_len: u64, + new_state: RangeState, + ) -> Vec<(u64, u64, RangeState)> { + let mut tmp_off = new_off; + let mut tmp_len = new_len; + let mut v = Vec::new(); + + // cut previous overlapping range if needed + let prev = self.used.range_mut(..tmp_off).next_back(); + if let Some((prev_off, (prev_len, prev_state))) = prev { + let prev_state = *prev_state; + let overlap = (*prev_off + *prev_len).saturating_sub(new_off); + *prev_len -= overlap; + if overlap > 0 { + self.used.insert(new_off, (overlap, prev_state)); + } + } + + let mut last_existing_remaining = None; + for (off, (len, state)) in self.used.range(tmp_off..tmp_off + tmp_len) { + // Create chunk for "overhang" before an existing range + if tmp_off < *off { + let sub_len = off - tmp_off; + v.push((tmp_off, sub_len, new_state)); + tmp_off += sub_len; + tmp_len -= sub_len; + } + + // Create chunk to match existing range + let sub_len = min(*len, tmp_len); + let remaining_len = len - sub_len; + if new_state == RangeState::Sent && *state == RangeState::Acked { + qinfo!( + "Attempted to downgrade overlapping range Acked range {}-{} with Sent {}-{}", + off, + len, + new_off, + new_len + ); + } else { + v.push((tmp_off, sub_len, new_state)); + } + tmp_off += sub_len; + tmp_len -= sub_len; + + if remaining_len > 0 { + last_existing_remaining = Some((*off, sub_len, remaining_len, *state)); + } + } + + // Maybe break last existing range in two so that a final chunk will + // have the same 
length as an existing range entry + if let Some((off, sub_len, remaining_len, state)) = last_existing_remaining { + *self.used.get_mut(&off).expect("must be there") = (sub_len, state); + self.used.insert(off + sub_len, (remaining_len, state)); + } + + // Create final chunk if anything remains of the new range + if tmp_len > 0 { + v.push((tmp_off, tmp_len, new_state)) + } + + v + } + + /// Merge contiguous Acked ranges into the first entry (0). This range may + /// be dropped from the send buffer. + fn coalesce_acked_from_zero(&mut self) { + let acked_range_from_zero = self + .used + .get_mut(&0) + .filter(|(_, state)| *state == RangeState::Acked) + .map(|(len, _)| *len); + + if let Some(len_from_zero) = acked_range_from_zero { + let mut to_remove = SmallVec::<[_; 8]>::new(); + + let mut new_len_from_zero = len_from_zero; + + // See if there's another Acked range entry contiguous to this one + while let Some((next_len, _)) = self + .used + .get(&new_len_from_zero) + .filter(|(_, state)| *state == RangeState::Acked) + { + to_remove.push(new_len_from_zero); + new_len_from_zero += *next_len; + } + + if len_from_zero != new_len_from_zero { + self.used.get_mut(&0).expect("must be there").0 = new_len_from_zero; + } + + for val in to_remove { + self.used.remove(&val); + } + } + } + + fn mark_range(&mut self, off: u64, len: usize, state: RangeState) { + if len == 0 { + qinfo!("mark 0-length range at {}", off); + return; + } + + let subranges = self.chunk_range_on_edges(off, len as u64, state); + + for (sub_off, sub_len, sub_state) in subranges { + self.used.insert(sub_off, (sub_len, sub_state)); + } + + self.coalesce_acked_from_zero() + } + + fn unmark_range(&mut self, off: u64, len: usize) { + if len == 0 { + qdebug!("unmark 0-length range at {}", off); + return; + } + + let len = u64::try_from(len).unwrap(); + let end_off = off + len; + + let mut to_remove = SmallVec::<[_; 8]>::new(); + let mut to_add = None; + + // Walk backwards through possibly affected existing ranges 
+ for (cur_off, (cur_len, cur_state)) in self.used.range_mut(..off + len).rev() { + // Maybe fixup range preceding the removed range + if *cur_off < off { + // Check for overlap + if *cur_off + *cur_len > off { + if *cur_state == RangeState::Acked { + qdebug!( + "Attempted to unmark Acked range {}-{} with unmark_range {}-{}", + cur_off, + cur_len, + off, + off + len + ); + } else { + *cur_len = off - cur_off; + } + } + break; + } + + if *cur_state == RangeState::Acked { + qdebug!( + "Attempted to unmark Acked range {}-{} with unmark_range {}-{}", + cur_off, + cur_len, + off, + off + len + ); + continue; + } + + // Add a new range for old subrange extending beyond + // to-be-unmarked range + let cur_end_off = cur_off + *cur_len; + if cur_end_off > end_off { + let new_cur_off = off + len; + let new_cur_len = cur_end_off - end_off; + assert_eq!(to_add, None); + to_add = Some((new_cur_off, new_cur_len, *cur_state)); + } + + to_remove.push(*cur_off); + } + + for remove_off in to_remove { + self.used.remove(&remove_off); + } + + if let Some((new_cur_off, new_cur_len, cur_state)) = to_add { + self.used.insert(new_cur_off, (new_cur_len, cur_state)); + } + } + + /// Unmark all sent ranges. + pub fn unmark_sent(&mut self) { + self.unmark_range(0, usize::try_from(self.highest_offset()).unwrap()); + } +} + +/// Buffer to contain queued bytes and track their state. +#[derive(Debug, Default, PartialEq)] +pub struct TxBuffer { + retired: u64, // contig acked bytes, no longer in buffer + send_buf: VecDeque<u8>, // buffer of not-acked bytes + ranges: RangeTracker, // ranges in buffer that have been sent or acked +} + +impl TxBuffer { + pub fn new() -> Self { + Self::default() + } + + /// Attempt to add some or all of the passed-in buffer to the TxBuffer. 
+ pub fn send(&mut self, buf: &[u8]) -> usize { + let can_buffer = min(SEND_BUFFER_SIZE - self.buffered(), buf.len()); + if can_buffer > 0 { + self.send_buf.extend(&buf[..can_buffer]); + assert!(self.send_buf.len() <= SEND_BUFFER_SIZE); + } + can_buffer + } + + pub fn next_bytes(&self) -> Option<(u64, &[u8])> { + let (start, maybe_len) = self.ranges.first_unmarked_range(); + + if start == self.retired + u64::try_from(self.buffered()).unwrap() { + return None; + } + + // Convert from ranges-relative-to-zero to + // ranges-relative-to-buffer-start + let buff_off = usize::try_from(start - self.retired).unwrap(); + + // Deque returns two slices. Create a subslice from whichever + // one contains the first unmarked data. + let slc = if buff_off < self.send_buf.as_slices().0.len() { + &self.send_buf.as_slices().0[buff_off..] + } else { + &self.send_buf.as_slices().1[buff_off - self.send_buf.as_slices().0.len()..] + }; + + let len = if let Some(range_len) = maybe_len { + // Truncate if range crosses deque slices + min(usize::try_from(range_len).unwrap(), slc.len()) + } else { + slc.len() + }; + + debug_assert!(len > 0); + debug_assert!(len <= slc.len()); + + Some((start, &slc[..len])) + } + + pub fn mark_as_sent(&mut self, offset: u64, len: usize) { + self.ranges.mark_range(offset, len, RangeState::Sent) + } + + pub fn mark_as_acked(&mut self, offset: u64, len: usize) { + self.ranges.mark_range(offset, len, RangeState::Acked); + + // We can drop contig acked range from the buffer + let new_retirable = self.ranges.acked_from_zero() - self.retired; + debug_assert!(new_retirable <= self.buffered() as u64); + let keep_len = + self.buffered() - usize::try_from(new_retirable).expect("should fit in usize"); + + // Truncate front + self.send_buf.rotate_left(self.buffered() - keep_len); + self.send_buf.truncate(keep_len); + + self.retired += new_retirable; + } + + pub fn mark_as_lost(&mut self, offset: u64, len: usize) { + self.ranges.unmark_range(offset, len) + } + + /// Forget 
about anything that was marked as sent. + pub fn unmark_sent(&mut self) { + self.ranges.unmark_sent(); + } + + fn data_limit(&self) -> u64 { + self.buffered() as u64 + self.retired + } + + fn buffered(&self) -> usize { + self.send_buf.len() + } + + fn avail(&self) -> usize { + SEND_BUFFER_SIZE - self.buffered() + } + + pub fn highest_sent(&self) -> u64 { + self.ranges.highest_offset() + } +} + +/// QUIC sending stream states, based on -transport 3.1. +#[derive(Debug, PartialEq)] +enum SendStreamState { + Ready, + Send { + send_buf: TxBuffer, + }, + DataSent { + send_buf: TxBuffer, + final_size: u64, + fin_sent: bool, + fin_acked: bool, + }, + DataRecvd { + final_size: u64, + }, + ResetSent, + ResetRecvd, +} + +impl SendStreamState { + fn tx_buf(&self) -> Option<&TxBuffer> { + match self { + Self::Send { send_buf } | Self::DataSent { send_buf, .. } => Some(send_buf), + Self::Ready | Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => None, + } + } + + fn tx_buf_mut(&mut self) -> Option<&mut TxBuffer> { + match self { + Self::Send { send_buf } | Self::DataSent { send_buf, .. } => Some(send_buf), + Self::Ready | Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => None, + } + } + + fn tx_avail(&self) -> u64 { + match self { + // In Ready, TxBuffer not yet allocated but size is known + Self::Ready => SEND_BUFFER_SIZE.try_into().unwrap(), + Self::Send { send_buf } | Self::DataSent { send_buf, .. } => { + send_buf.avail().try_into().unwrap() + } + Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => 0, + } + } + + fn final_size(&self) -> Option<u64> { + match self { + Self::DataSent { final_size, .. } | Self::DataRecvd { final_size } => Some(*final_size), + Self::Ready | Self::Send { .. } | Self::ResetSent | Self::ResetRecvd => None, + } + } + + fn name(&self) -> &str { + match self { + Self::Ready => "Ready", + Self::Send { .. } => "Send", + Self::DataSent { .. } => "DataSent", + Self::DataRecvd { .. 
} => "DataRecvd", + Self::ResetSent => "ResetSent", + Self::ResetRecvd => "ResetRecvd", + } + } + + fn transition(&mut self, new_state: Self) { + qtrace!("SendStream state {} -> {}", self.name(), new_state.name()); + *self = new_state; + } +} + +/// Implement a QUIC send stream. +#[derive(Debug)] +pub struct SendStream { + stream_id: StreamId, + max_stream_data: u64, + state: SendStreamState, + flow_mgr: Rc<RefCell<FlowMgr>>, + conn_events: ConnectionEvents, +} + +impl SendStream { + pub fn new( + stream_id: StreamId, + max_stream_data: u64, + flow_mgr: Rc<RefCell<FlowMgr>>, + conn_events: ConnectionEvents, + ) -> Self { + let ss = Self { + stream_id, + max_stream_data, + state: SendStreamState::Ready, + flow_mgr, + conn_events, + }; + if ss.avail() > 0 { + ss.conn_events.send_stream_writable(stream_id); + } + ss + } + + /// Return the next range to be sent, if any. + pub fn next_bytes(&mut self) -> Option<(u64, &[u8])> { + match self.state { + SendStreamState::Send { ref send_buf } => send_buf.next_bytes(), + SendStreamState::DataSent { + ref send_buf, + fin_sent, + final_size, + .. + } => { + let bytes = send_buf.next_bytes(); + if bytes.is_some() { + // Must be a resend + bytes + } else if fin_sent { + None + } else { + // Send empty stream frame with fin set + Some((final_size, &[])) + } + } + SendStreamState::Ready + | SendStreamState::DataRecvd { .. } + | SendStreamState::ResetSent + | SendStreamState::ResetRecvd => None, + } + } + + /// Calculate how many bytes (length) can fit into available space and whether + /// the remainder of the space can be filled (or if a length field is needed). + fn length_and_fill(data_len: usize, space: usize) -> (usize, bool) { + if data_len >= space { + // Either more data than space allows, or an exact fit. + qtrace!("SendStream::length_and_fill fill {}", space); + return (space, true); + } + + // Estimate size of the length field based on the available space, + // less 1, which is the worst case. 
+ let length = min(space.saturating_sub(1), data_len); + let length_len = Encoder::varint_len(u64::try_from(length).unwrap()); + if length_len > space { + qtrace!( + "SendStream::length_and_fill no room for length of {} in {}", + length, + space + ); + return (0, false); + } + + let length = min(data_len, space - length_len); + qtrace!("SendStream::length_and_fill {} in {}", length, space); + (length, false) + } + + pub fn write_frame(&mut self, builder: &mut PacketBuilder) -> Option<RecoveryToken> { + let id = self.stream_id; + let final_size = self.final_size(); + if let Some((offset, data)) = self.next_bytes() { + let overhead = 1 // Frame type + + Encoder::varint_len(id.as_u64()) + + if offset > 0 { + Encoder::varint_len(offset) + } else { + 0 + }; + if overhead > builder.remaining() { + qtrace!("SendStream::write_frame no space for header"); + return None; + } + + let (length, fill) = Self::length_and_fill(data.len(), builder.remaining() - overhead); + let fin = final_size.map_or(false, |fs| fs == offset + u64::try_from(length).unwrap()); + if length == 0 && !fin { + qtrace!("SendStream::write_frame no data, no fin"); + return None; + } + + // Write the stream out. + builder.encode_varint(Frame::stream_type(fin, offset > 0, fill)); + builder.encode_varint(id.as_u64()); + if offset > 0 { + builder.encode_varint(offset); + } + if fill { + builder.encode(&data[..length]); + } else { + builder.encode_vvec(&data[..length]); + } + + self.mark_as_sent(offset, length, fin); + Some(RecoveryToken::Stream(StreamRecoveryToken { + id, + offset, + length, + fin, + })) + } else { + None + } + } + + pub fn mark_as_sent(&mut self, offset: u64, len: usize, fin: bool) { + if let Some(buf) = self.state.tx_buf_mut() { + buf.mark_as_sent(offset, len); + self.send_blocked_if_space_needed(0); + }; + + if fin { + if let SendStreamState::DataSent { fin_sent, .. 
} = &mut self.state { + *fin_sent = true; + } + } + } + + pub fn mark_as_acked(&mut self, offset: u64, len: usize, fin: bool) { + match self.state { + SendStreamState::Send { ref mut send_buf } => { + send_buf.mark_as_acked(offset, len); + if self.avail() > 0 { + self.conn_events.send_stream_writable(self.stream_id) + } + } + SendStreamState::DataSent { + ref mut send_buf, + final_size, + ref mut fin_acked, + .. + } => { + send_buf.mark_as_acked(offset, len); + if fin { + *fin_acked = true; + } + if *fin_acked && send_buf.buffered() == 0 { + self.conn_events.send_stream_complete(self.stream_id); + self.state + .transition(SendStreamState::DataRecvd { final_size }); + } + } + _ => qtrace!("mark_as_acked called from state {}", self.state.name()), + } + } + + pub fn mark_as_lost(&mut self, offset: u64, len: usize, fin: bool) { + if let Some(buf) = self.state.tx_buf_mut() { + buf.mark_as_lost(offset, len); + } + + if fin { + if let SendStreamState::DataSent { + fin_sent, + fin_acked, + .. + } = &mut self.state + { + *fin_sent = *fin_acked; + } + } + } + + pub fn final_size(&self) -> Option<u64> { + self.state.final_size() + } + + /// Stream credit available + pub fn credit_avail(&self) -> u64 { + if self.state == SendStreamState::Ready { + self.max_stream_data + } else { + self.state + .tx_buf() + .map_or(0, |tx| self.max_stream_data - tx.data_limit()) + } + } + + /// Bytes sendable on stream. Constrained by stream credit available, + /// connection credit available, and space in the tx buffer. 
+ pub fn avail(&self) -> usize { + min( + min(self.state.tx_avail(), self.credit_avail()), + self.flow_mgr.borrow().conn_credit_avail(), + ) + .try_into() + .unwrap() + } + + pub fn max_stream_data(&self) -> u64 { + self.max_stream_data + } + + pub fn set_max_stream_data(&mut self, value: u64) { + let stream_was_blocked = self.avail() == 0; + self.max_stream_data = max(self.max_stream_data, value); + if stream_was_blocked && self.avail() > 0 { + self.conn_events.send_stream_writable(self.stream_id) + } + } + + pub fn reset_acked(&mut self) { + match self.state { + SendStreamState::Ready + | SendStreamState::Send { .. } + | SendStreamState::DataSent { .. } + | SendStreamState::DataRecvd { .. } => { + qtrace!("Reset acked while in {} state?", self.state.name()) + } + SendStreamState::ResetSent => self.state.transition(SendStreamState::ResetRecvd), + SendStreamState::ResetRecvd => qtrace!("already in ResetRecvd state"), + }; + } + + pub fn is_terminal(&self) -> bool { + matches!(self.state, SendStreamState::DataRecvd { .. 
} | SendStreamState::ResetRecvd) + } + + pub fn send(&mut self, buf: &[u8]) -> Res<usize> { + self.send_internal(buf, false) + } + + pub fn send_atomic(&mut self, buf: &[u8]) -> Res<usize> { + self.send_internal(buf, true) + } + + fn send_blocked_if_space_needed(&mut self, needed_space: u64) { + if self.credit_avail() <= needed_space { + self.flow_mgr + .borrow_mut() + .stream_data_blocked(self.stream_id, self.max_stream_data); + } + + if self.flow_mgr.borrow().conn_credit_avail() <= needed_space { + self.flow_mgr.borrow_mut().data_blocked(); + } + } + + fn send_internal(&mut self, buf: &[u8], atomic: bool) -> Res<usize> { + if buf.is_empty() { + qerror!("zero-length send on stream {}", self.stream_id.as_u64()); + return Err(Error::InvalidInput); + } + + if let SendStreamState::Ready = self.state { + self.state.transition(SendStreamState::Send { + send_buf: TxBuffer::new(), + }); + } + + if !matches!(self.state, SendStreamState::Send{..}) { + return Err(Error::FinalSizeError); + } + + let buf = if buf.is_empty() || (self.avail() == 0) { + return Ok(0); + } else if self.avail() < buf.len() { + if atomic { + self.send_blocked_if_space_needed(buf.len() as u64); + return Ok(0); + } else { + &buf[..self.avail()] + } + } else { + buf + }; + + let sent = match &mut self.state { + SendStreamState::Ready => unreachable!(), + SendStreamState::Send { send_buf } => send_buf.send(buf), + _ => return Err(Error::FinalSizeError), + }; + + self.flow_mgr + .borrow_mut() + .conn_increase_credit_used(sent as u64); + + Ok(sent) + } + + pub fn close(&mut self) { + match &mut self.state { + SendStreamState::Ready => { + self.state.transition(SendStreamState::DataSent { + send_buf: TxBuffer::new(), + final_size: 0, + fin_sent: false, + fin_acked: false, + }); + } + SendStreamState::Send { send_buf } => { + let final_size = send_buf.retired + send_buf.buffered() as u64; + let owned_buf = mem::replace(send_buf, TxBuffer::new()); + self.state.transition(SendStreamState::DataSent { + 
send_buf: owned_buf, + final_size, + fin_sent: false, + fin_acked: false, + }); + } + SendStreamState::DataSent { .. } => qtrace!("already in DataSent state"), + SendStreamState::DataRecvd { .. } => qtrace!("already in DataRecvd state"), + SendStreamState::ResetSent => qtrace!("already in ResetSent state"), + SendStreamState::ResetRecvd => qtrace!("already in ResetRecvd state"), + } + } + + pub fn reset(&mut self, err: AppError) { + match &self.state { + SendStreamState::Ready => { + self.flow_mgr + .borrow_mut() + .stream_reset(self.stream_id, err, 0); + + self.state.transition(SendStreamState::ResetSent); + } + SendStreamState::Send { send_buf } => { + self.flow_mgr.borrow_mut().stream_reset( + self.stream_id, + err, + send_buf.highest_sent(), + ); + + self.state.transition(SendStreamState::ResetSent); + } + SendStreamState::DataSent { final_size, .. } => { + self.flow_mgr + .borrow_mut() + .stream_reset(self.stream_id, err, *final_size); + + self.state.transition(SendStreamState::ResetSent); + } + SendStreamState::DataRecvd { .. 
} => qtrace!("already in DataRecvd state"), + SendStreamState::ResetSent => qtrace!("already in ResetSent state"), + SendStreamState::ResetRecvd => qtrace!("already in ResetRecvd state"), + }; + } +} + +#[derive(Debug, Default)] +pub(crate) struct SendStreams(IndexMap<StreamId, SendStream>); + +impl SendStreams { + pub fn get(&self, id: StreamId) -> Res<&SendStream> { + self.0.get(&id).ok_or(Error::InvalidStreamId) + } + + pub fn get_mut(&mut self, id: StreamId) -> Res<&mut SendStream> { + self.0.get_mut(&id).ok_or(Error::InvalidStreamId) + } + + pub fn exists(&self, id: StreamId) -> bool { + self.0.contains_key(&id) + } + + pub fn insert(&mut self, id: StreamId, stream: SendStream) { + self.0.insert(id, stream); + } + + pub fn acked(&mut self, token: &StreamRecoveryToken) { + if let Some(ss) = self.0.get_mut(&token.id) { + ss.mark_as_acked(token.offset, token.length, token.fin); + } + } + + pub fn reset_acked(&mut self, id: StreamId) { + if let Some(ss) = self.0.get_mut(&id) { + ss.reset_acked() + } + } + + pub fn lost(&mut self, token: &StreamRecoveryToken) { + if let Some(ss) = self.0.get_mut(&token.id) { + ss.mark_as_lost(token.offset, token.length, token.fin); + } + } + + pub fn clear(&mut self) { + self.0.clear() + } + + pub fn clear_terminal(&mut self) { + self.0.retain(|_, stream| !stream.is_terminal()) + } + + pub(crate) fn write_frames( + &mut self, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) { + for (_, stream) in self { + if let Some(t) = stream.write_frame(builder) { + tokens.push(t); + stats.stream += 1; + } + } + } +} + +impl<'a> IntoIterator for &'a mut SendStreams { + type Item = (&'a StreamId, &'a mut SendStream); + type IntoIter = indexmap::map::IterMut<'a, StreamId, SendStream>; + + fn into_iter(self) -> indexmap::map::IterMut<'a, StreamId, SendStream> { + self.0.iter_mut() + } +} + +#[derive(Debug, Clone)] +pub struct StreamRecoveryToken { + pub(crate) id: StreamId, + offset: u64, + length: 
usize, + fin: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::events::ConnectionEvent; + use neqo_common::{event::Provider, hex_with_len, qtrace}; + + #[test] + fn test_mark_range() { + let mut rt = RangeTracker::default(); + + // ranges can go from nothing->Sent if queued for retrans and then + // acks arrive + rt.mark_range(5, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 10); + assert_eq!(rt.acked_from_zero(), 0); + rt.mark_range(10, 4, RangeState::Acked); + assert_eq!(rt.highest_offset(), 14); + assert_eq!(rt.acked_from_zero(), 0); + + rt.mark_range(0, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 14); + assert_eq!(rt.acked_from_zero(), 0); + rt.mark_range(0, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 14); + assert_eq!(rt.acked_from_zero(), 14); + + rt.mark_range(12, 20, RangeState::Acked); + assert_eq!(rt.highest_offset(), 32); + assert_eq!(rt.acked_from_zero(), 32); + + // ack the lot + rt.mark_range(0, 400, RangeState::Acked); + assert_eq!(rt.highest_offset(), 400); + assert_eq!(rt.acked_from_zero(), 400); + + // acked trumps sent + rt.mark_range(0, 200, RangeState::Sent); + assert_eq!(rt.highest_offset(), 400); + assert_eq!(rt.acked_from_zero(), 400); + } + + #[test] + fn unmark_sent_start() { + let mut rt = RangeTracker::default(); + + rt.mark_range(0, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 0); + + rt.unmark_sent(); + assert_eq!(rt.highest_offset(), 0); + assert_eq!(rt.acked_from_zero(), 0); + assert_eq!(rt.first_unmarked_range(), (0, None)); + } + + #[test] + fn unmark_sent_middle() { + let mut rt = RangeTracker::default(); + + rt.mark_range(0, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 5); + rt.mark_range(5, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 10); + assert_eq!(rt.acked_from_zero(), 5); + rt.mark_range(10, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 15); + 
assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (15, None)); + + rt.unmark_sent(); + assert_eq!(rt.highest_offset(), 15); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (5, Some(5))); + } + + #[test] + fn unmark_sent_end() { + let mut rt = RangeTracker::default(); + + rt.mark_range(0, 5, RangeState::Acked); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 5); + rt.mark_range(5, 5, RangeState::Sent); + assert_eq!(rt.highest_offset(), 10); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (10, None)); + + rt.unmark_sent(); + assert_eq!(rt.highest_offset(), 5); + assert_eq!(rt.acked_from_zero(), 5); + assert_eq!(rt.first_unmarked_range(), (5, None)); + } + + #[test] + fn truncate_front() { + let mut v = VecDeque::new(); + v.push_back(5); + v.push_back(6); + v.push_back(7); + v.push_front(4usize); + + v.rotate_left(1); + v.truncate(3); + assert_eq!(*v.front().unwrap(), 5); + assert_eq!(*v.back().unwrap(), 7); + } + + #[test] + fn test_unmark_range() { + let mut rt = RangeTracker::default(); + + rt.mark_range(5, 5, RangeState::Acked); + rt.mark_range(10, 5, RangeState::Sent); + + // Should unmark sent but not acked range + rt.unmark_range(7, 6); + + let res = rt.first_unmarked_range(); + assert_eq!(res, (0, Some(5))); + assert_eq!( + rt.used.iter().next().unwrap(), + (&5, &(5, RangeState::Acked)) + ); + assert_eq!( + rt.used.iter().nth(1).unwrap(), + (&13, &(2, RangeState::Sent)) + ); + assert!(rt.used.iter().nth(2).is_none()); + rt.mark_range(0, 5, RangeState::Sent); + + let res = rt.first_unmarked_range(); + assert_eq!(res, (10, Some(3))); + rt.mark_range(10, 3, RangeState::Sent); + + let res = rt.first_unmarked_range(); + assert_eq!(res, (15, None)); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn tx_buffer_next_bytes_1() { + let mut txb = TxBuffer::new(); + + assert_eq!(txb.avail(), SEND_BUFFER_SIZE); + + // Fill the buffer + 
assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE); + assert!(matches!(txb.next_bytes(), + Some((0, x)) if x.len()==SEND_BUFFER_SIZE + && x.iter().all(|ch| *ch == 1))); + + // Mark almost all as sent. Get what's left + let one_byte_from_end = SEND_BUFFER_SIZE as u64 - 1; + txb.mark_as_sent(0, one_byte_from_end as usize); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 1 + && start == one_byte_from_end + && x.iter().all(|ch| *ch == 1))); + + // Mark all as sent. Get nothing + txb.mark_as_sent(0, SEND_BUFFER_SIZE); + assert!(matches!(txb.next_bytes(), None)); + + // Mark as lost. Get it again + txb.mark_as_lost(one_byte_from_end, 1); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 1 + && start == one_byte_from_end + && x.iter().all(|ch| *ch == 1))); + + // Mark a larger range lost, including beyond what's in the buffer even. + // Get a little more + let five_bytes_from_end = SEND_BUFFER_SIZE as u64 - 5; + txb.mark_as_lost(five_bytes_from_end, 100); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 5 + && start == five_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + + // Contig acked range at start means it can be removed from buffer + // Impl of vecdeque should now result in a split buffer when more data + // is sent + txb.mark_as_acked(0, five_bytes_from_end as usize); + assert_eq!(txb.send(&[2; 30]), 30); + // Just get 5 even though there is more + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 5 + && start == five_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + assert_eq!(txb.retired, five_bytes_from_end); + assert_eq!(txb.buffered(), 35); + + // Marking that bit as sent should let the last contig bit be returned + // when called again + txb.mark_as_sent(five_bytes_from_end, 5); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 30 + && start == SEND_BUFFER_SIZE as u64 + && x.iter().all(|ch| *ch == 2))); + } + + #[test] + fn 
tx_buffer_next_bytes_2() { + let mut txb = TxBuffer::new(); + + assert_eq!(txb.avail(), SEND_BUFFER_SIZE); + + // Fill the buffer + assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE); + assert!(matches!(txb.next_bytes(), + Some((0, x)) if x.len()==SEND_BUFFER_SIZE + && x.iter().all(|ch| *ch == 1))); + + // As above + let forty_bytes_from_end = SEND_BUFFER_SIZE as u64 - 40; + + txb.mark_as_acked(0, forty_bytes_from_end as usize); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 40 + && start == forty_bytes_from_end + )); + + // Valid new data placed in split locations + assert_eq!(txb.send(&[2; 100]), 100); + + // Mark a little more as sent + txb.mark_as_sent(forty_bytes_from_end, 10); + let thirty_bytes_from_end = forty_bytes_from_end + 10; + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 30 + && start == thirty_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + + // Mark a range 'A' in second slice as sent. Should still return the same + let range_a_start = SEND_BUFFER_SIZE as u64 + 30; + let range_a_end = range_a_start + 10; + txb.mark_as_sent(range_a_start, 10); + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 30 + && start == thirty_bytes_from_end + && x.iter().all(|ch| *ch == 1))); + + // Ack entire first slice and into second slice + let ten_bytes_past_end = SEND_BUFFER_SIZE as u64 + 10; + txb.mark_as_acked(0, ten_bytes_past_end as usize); + + // Get up to marked range A + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 20 + && start == ten_bytes_past_end + && x.iter().all(|ch| *ch == 2))); + + txb.mark_as_sent(ten_bytes_past_end, 20); + + // Get bit after earlier marked range A + assert!(matches!(txb.next_bytes(), + Some((start, x)) if x.len() == 60 + && start == range_a_end + && x.iter().all(|ch| *ch == 2))); + + // No more bytes. 
+ txb.mark_as_sent(range_a_end, 60); + assert!(matches!(txb.next_bytes(), None)); + } + + #[test] + fn test_stream_tx() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(4096); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(4.into(), 1024, Rc::clone(&flow_mgr), conn_events); + + let res = s.send(&[4; 100]).unwrap(); + assert_eq!(res, 100); + s.mark_as_sent(0, 50, false); + assert_eq!(s.state.tx_buf().unwrap().data_limit(), 100); + + // Should hit stream flow control limit before filling up send buffer + let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap(); + assert_eq!(res, 1024 - 100); + + // should do nothing, max stream data already 1024 + s.set_max_stream_data(1024); + let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap(); + assert_eq!(res, 0); + + // should now hit the conn flow control (4096) + s.set_max_stream_data(1_048_576); + let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap(); + assert_eq!(res, 3072); + + // should now hit the tx buffer size + flow_mgr + .borrow_mut() + .conn_increase_max_credit(SEND_BUFFER_SIZE as u64); + let res = s.send(&[4; SEND_BUFFER_SIZE + 100]).unwrap(); + assert_eq!(res, SEND_BUFFER_SIZE - 4096); + + // TODO(agrover@mozilla.com): test ooo acks somehow + s.mark_as_acked(0, 40, false); + } + + #[test] + fn test_tx_buffer_acks() { + let mut tx = TxBuffer::new(); + assert_eq!(tx.send(&[4; 100]), 100); + let res = tx.next_bytes().unwrap(); + assert_eq!(res.0, 0); + assert_eq!(res.1.len(), 100); + tx.mark_as_sent(0, 100); + let res = tx.next_bytes(); + assert_eq!(res, None); + + tx.mark_as_acked(0, 100); + let res = tx.next_bytes(); + assert_eq!(res, None); + } + + #[test] + fn send_stream_writable_event_gen() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(2); + let mut conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(4.into(), 0, Rc::clone(&flow_mgr), 
conn_events.clone()); + + // Stream is initially blocked (conn:2, stream:0) + // and will not accept data. + assert_eq!(s.send(b"hi").unwrap(), 0); + + // increasing to (conn:2, stream:2) will allow 2 bytes, and also + // generate a SendStreamWritable event. + s.set_max_stream_data(2); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 1); + assert!(matches!(evts[0], ConnectionEvent::SendStreamWritable{..})); + assert_eq!(s.send(b"hello").unwrap(), 2); + + // increasing to (conn:2, stream:4) will not generate an event or allow + // sending anything. + s.set_max_stream_data(4); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 0); + assert_eq!(s.send(b"hello").unwrap(), 0); + + // Increasing conn max (conn:4, stream:4) will unblock but not emit + // event b/c that happens in Connection::emit_frame() (tested in + // connection.rs) + assert_eq!(flow_mgr.borrow_mut().conn_increase_max_credit(4), true); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 0); + assert_eq!(s.avail(), 2); + assert_eq!(s.send(b"hello").unwrap(), 2); + + // No event because still blocked by conn + s.set_max_stream_data(1_000_000_000); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 0); + + // No event because happens in emit_frame() + flow_mgr + .borrow_mut() + .conn_increase_max_credit(1_000_000_000); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 0); + + // Unblocking both by a large amount will cause avail() to be limited by + // tx buffer size. 
+ assert_eq!(s.avail(), SEND_BUFFER_SIZE - 4); + + assert_eq!( + s.send(&[b'a'; SEND_BUFFER_SIZE]).unwrap(), + SEND_BUFFER_SIZE - 4 + ); + + // No event because still blocked by tx buffer full + s.set_max_stream_data(2_000_000_000); + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 0); + assert_eq!(s.send(b"hello").unwrap(), 0); + } + + #[test] + fn send_stream_writable_event_new_stream() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(2); + let mut conn_events = ConnectionEvents::default(); + + let _s = SendStream::new(4.into(), 100, flow_mgr, conn_events.clone()); + + // Creating a new stream with conn and stream credits should result in + // an event. + let evts = conn_events.events().collect::<Vec<_>>(); + assert_eq!(evts.len(), 1); + assert!(matches!(evts[0], ConnectionEvent::SendStreamWritable{..})); + } + + #[test] + // Verify lost frames handle fin properly + fn send_stream_get_frame_data() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(100); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(0.into(), 100, flow_mgr, conn_events); + s.send(&[0; 10]).unwrap(); + s.close(); + + let mut ss = SendStreams::default(); + ss.insert(StreamId::from(0), s); + + let mut tokens = Vec::new(); + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + + // Write a small frame: no fin. + let written = builder.len(); + builder.set_limit(written + 6); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(builder.len(), written + 6); + assert_eq!(tokens.len(), 1); + let f1_token = tokens.remove(0); + assert!(matches!(&f1_token, RecoveryToken::Stream(x) if !x.fin)); + + // Write the rest: fin. 
+ let written = builder.len(); + builder.set_limit(written + 200); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(builder.len(), written + 10); + assert_eq!(tokens.len(), 1); + let f2_token = tokens.remove(0); + assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.fin)); + + // Should be no more data to frame. + let written = builder.len(); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(builder.len(), written); + assert!(tokens.is_empty()); + + // Mark frame 1 as lost + if let RecoveryToken::Stream(rt) = f1_token { + ss.lost(&rt); + } else { + panic!(); + } + + // Next frame should not set fin even though stream has fin but frame + // does not include end of stream + let written = builder.len(); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(builder.len(), written + 7); // Needs a length this time. + assert_eq!(tokens.len(), 1); + let f4_token = tokens.remove(0); + assert!(matches!(&f4_token, RecoveryToken::Stream(x) if !x.fin)); + + // Mark frame 2 as lost + if let RecoveryToken::Stream(rt) = f2_token { + ss.lost(&rt); + } else { + panic!(); + } + + // Next frame should set fin because it includes end of stream + let written = builder.len(); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert_eq!(builder.len(), written + 10); + assert_eq!(tokens.len(), 1); + let f5_token = tokens.remove(0); + assert!(matches!(&f5_token, RecoveryToken::Stream(x) if x.fin)); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + // Verify lost frames handle fin properly with zero length fin + fn send_stream_get_frame_zerolength_fin() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(100); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(0.into(), 100, flow_mgr, conn_events); + s.send(&[0; 10]).unwrap(); + + let mut ss = 
SendStreams::default(); + ss.insert(StreamId::from(0), s); + + let mut tokens = Vec::new(); + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + let f1_token = tokens.remove(0); + assert!(matches!(&f1_token, RecoveryToken::Stream(x) if x.offset == 0)); + assert!(matches!(&f1_token, RecoveryToken::Stream(x) if x.length == 10)); + assert!(matches!(&f1_token, RecoveryToken::Stream(x) if !x.fin)); + + // Should be no more data to frame + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + assert!(tokens.is_empty()); + + ss.get_mut(StreamId::from(0)).unwrap().close(); + + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + let f2_token = tokens.remove(0); + assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.offset == 10)); + assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.length == 0)); + assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.fin)); + + // Mark frame 2 as lost + if let RecoveryToken::Stream(rt) = f2_token { + ss.lost(&rt); + } else { + panic!(); + } + + // Next frame should set fin + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + let f3_token = tokens.remove(0); + assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.offset == 10)); + assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.length == 0)); + assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.fin)); + + // Mark frame 1 as lost + if let RecoveryToken::Stream(rt) = f1_token { + ss.lost(&rt); + } else { + panic!(); + } + + // Next frame should set fin and include all data + ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); + let f4_token = tokens.remove(0); + assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.offset == 0)); + assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.length == 10)); + assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.fin)); + } + 
+ #[test] + fn send_atomic() { + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(5); + let conn_events = ConnectionEvents::default(); + + let stream_id = StreamId::from(4); + let mut s = SendStream::new(stream_id, 0, Rc::clone(&flow_mgr), conn_events); + s.set_max_stream_data(2); + + // Stream is initially blocked (conn:5, stream:2) + // and will not accept atomic write of 3 bytes. + assert_eq!(s.send_atomic(b"abc").unwrap(), 0); + + // assert that STREAM_DATA_BLOCKED is sent. + assert_eq!( + flow_mgr.borrow_mut().next().unwrap(), + Frame::StreamDataBlocked { + stream_id, + stream_data_limit: 0x2 + } + ); + + // assert non-atomic write works + assert_eq!(s.send(b"abc").unwrap(), 2); + assert_eq!(s.next_bytes(), Some((0, &b"ab"[..]))); + // STREAM_DATA_BLOCKED is not sent yet. + assert!(flow_mgr.borrow_mut().next().is_none()); + + // STREAM_DATA_BLOCKED is queued once bytes using all credit are sent. + s.mark_as_sent(0, 2, false); + assert_eq!( + flow_mgr.borrow_mut().next().unwrap(), + Frame::StreamDataBlocked { + stream_id, + stream_data_limit: 0x2 + } + ); + + // increasing to (conn:5, stream:10) + s.set_max_stream_data(10); + // will not accept atomic write of 4 bytes. + assert_eq!(s.send_atomic(b"abcd").unwrap(), 0); + + // assert that STREAM_DATA_BLOCKED is sent. + assert_eq!( + flow_mgr.borrow_mut().next().unwrap(), + Frame::DataBlocked { data_limit: 0x5 } + ); + + // assert non-atomic write works + assert_eq!(s.send(b"abcd").unwrap(), 3); + assert_eq!(s.next_bytes(), Some((2, &b"abc"[..]))); + // DATA_BLOCKED is not sent yet. + assert!(flow_mgr.borrow_mut().next().is_none()); + + // DATA_BLOCKED is queued once bytes using all credit are sent. 
+ s.mark_as_sent(2, 3, false); + assert_eq!( + flow_mgr.borrow_mut().next().unwrap(), + Frame::DataBlocked { data_limit: 0x5 } + ); + + // increasing to (conn:15, stream:15) + s.set_max_stream_data(15); + flow_mgr.borrow_mut().conn_increase_max_credit(15); + + // assert that an atomic write of 10 bytes works + assert_eq!(s.send_atomic(b"abcdefghij").unwrap(), 10); + } + + #[test] + fn ack_fin_first() { + const MESSAGE: &[u8] = b"hello"; + let len_u64 = u64::try_from(MESSAGE.len()).unwrap(); + + let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); + flow_mgr.borrow_mut().conn_increase_max_credit(len_u64); + let conn_events = ConnectionEvents::default(); + + let mut s = SendStream::new(StreamId::new(100), 0, Rc::clone(&flow_mgr), conn_events); + s.set_max_stream_data(len_u64); + + // Send all the data, then the fin. + let _ = s.send(MESSAGE).unwrap(); + s.mark_as_sent(0, MESSAGE.len(), false); + s.close(); + s.mark_as_sent(len_u64, 0, true); + + // Ack the fin, then the data. + s.mark_as_acked(len_u64, 0, true); + s.mark_as_acked(0, MESSAGE.len(), false); + // Everything is acknowledged now, so the order of the acks must not matter. + assert!(s.is_terminal()); + } + + #[test] + fn ack_then_lose_fin() { + const MESSAGE: &[u8] = b"hello"; + let len_u64 = u64::try_from(MESSAGE.len()).unwrap(); + + let mut flow_mgr = FlowMgr::default(); + flow_mgr.conn_increase_max_credit(len_u64); + let conn_events = ConnectionEvents::default(); + + let id = StreamId::new(100); + let mut s = SendStream::new(id, 0, Rc::new(RefCell::new(flow_mgr)), conn_events); + s.set_max_stream_data(len_u64); + + // Send all the data, then the fin. + let _ = s.send(MESSAGE).unwrap(); + s.mark_as_sent(0, MESSAGE.len(), false); + s.close(); + s.mark_as_sent(len_u64, 0, true); + + // Ack the fin, then mark it lost. + s.mark_as_acked(len_u64, 0, true); + s.mark_as_lost(len_u64, 0, true); + + // No frame should be sent here.
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + assert!(s.write_frame(&mut builder).is_none()); + } + + /// Create a `SendStream` and force it into a state where it believes that + /// `offset` bytes have already been sent and acknowledged. + fn stream_with_sent(stream: u64, offset: usize) -> SendStream { + const MAX_VARINT: u64 = (1 << 62) - 1; + + let mut flow_mgr = FlowMgr::default(); + flow_mgr.conn_increase_max_credit(MAX_VARINT); + + let mut s = SendStream::new( + StreamId::from(stream), + MAX_VARINT, + Rc::new(RefCell::new(flow_mgr)), + ConnectionEvents::default(), + ); + + let mut send_buf = TxBuffer::new(); + send_buf.retired = u64::try_from(offset).unwrap(); + send_buf.ranges.mark_range(0, offset, RangeState::Acked); + s.state = SendStreamState::Send { send_buf }; + s + } + + /// Build a stream with `stream_with_sent(stream, offset)`, queue `len` bytes of data + /// (at most the 128 bytes of `BUF`), close the stream if `fin` is set, and then try + /// to write a STREAM frame into a packet that has `space` bytes left after its header. + /// Returns `true` if a frame was written. + fn frame_sent_sid(stream: u64, offset: usize, len: usize, fin: bool, space: usize) -> bool { + const BUF: &[u8] = &[0x42; 128]; + let mut s = stream_with_sent(stream, offset); + + // Now write out the prescribed data and maybe close. + if len > 0 { + s.send(&BUF[..len]).unwrap(); + } + if fin { + s.close(); + } + + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + let header_len = builder.len(); + builder.set_limit(header_len + space); + let token = s.write_frame(&mut builder); + qtrace!("STREAM frame: {}", hex_with_len(&builder[header_len..])); + token.is_some() + } + + /// `frame_sent_sid` for stream 0. + fn frame_sent(offset: usize, len: usize, fin: bool, space: usize) -> bool { + frame_sent_sid(0, offset, len, fin, space) + } + + #[test] + fn stream_frame_empty() { + // Stream frames with empty data and no fin never work. + assert!(!frame_sent(10, 0, false, 2)); + assert!(!frame_sent(10, 0, false, 3)); + assert!(!frame_sent(10, 0, false, 4)); + assert!(!frame_sent(10, 0, false, 5)); + assert!(!frame_sent(10, 0, false, 100)); + + // Empty data with fin is only a problem if there is no space.
+ assert!(!frame_sent(0, 0, true, 1)); + assert!(frame_sent(0, 0, true, 2)); + assert!(!frame_sent(10, 0, true, 2)); + assert!(frame_sent(10, 0, true, 3)); + assert!(frame_sent(10, 0, true, 4)); + assert!(frame_sent(10, 0, true, 5)); + assert!(frame_sent(10, 0, true, 100)); + } + + #[test] + fn stream_frame_minimum() { + // Add minimum data + assert!(!frame_sent(10, 1, false, 3)); + assert!(!frame_sent(10, 1, true, 3)); + assert!(frame_sent(10, 1, false, 4)); + assert!(frame_sent(10, 1, true, 4)); + assert!(frame_sent(10, 1, false, 5)); + assert!(frame_sent(10, 1, true, 5)); + assert!(frame_sent(10, 1, false, 100)); + assert!(frame_sent(10, 1, true, 100)); + } + + #[test] + fn stream_frame_more() { + // Try more data + assert!(!frame_sent(10, 100, false, 3)); + assert!(!frame_sent(10, 100, true, 3)); + assert!(frame_sent(10, 100, false, 4)); + assert!(frame_sent(10, 100, true, 4)); + assert!(frame_sent(10, 100, false, 5)); + assert!(frame_sent(10, 100, true, 5)); + assert!(frame_sent(10, 100, false, 100)); + assert!(frame_sent(10, 100, true, 100)); + + assert!(frame_sent(10, 100, false, 1000)); + assert!(frame_sent(10, 100, true, 1000)); + } + + #[test] + fn stream_frame_big_id() { + // The smallest value that needs the largest (8-byte) varint encoding.
+ const BIG: u64 = 1 << 30; + const BIGSZ: usize = 1 << 30; + + assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 16)); + assert!(!frame_sent_sid(BIG, BIGSZ, 0, true, 16)); + assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 17)); + assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 17)); + assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 18)); + assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 18)); + + assert!(!frame_sent_sid(BIG, BIGSZ, 1, false, 17)); + assert!(!frame_sent_sid(BIG, BIGSZ, 1, true, 17)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 18)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 18)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 19)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 19)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 100)); + assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 100)); + } + + #[test] + fn stream_frame_16384() { + const DATA16384: &[u8] = &[0x43; 16384]; + + // 16383/16384 is an odd boundary in STREAM frame construction. + // That is the boundary where a length goes from 2 bytes to 4 bytes. + // If the data fits in the available space, then it is simple: + let mut s = stream_with_sent(0, 0); + s.send(DATA16384).unwrap(); + s.close(); + + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + let header_len = builder.len(); + builder.set_limit(header_len + DATA16384.len() + 2); + let token = s.write_frame(&mut builder); + assert!(token.is_some()); + // Expect STREAM + FIN only. + assert_eq!(&builder[header_len..header_len + 2], &[0b1001, 0]); + assert_eq!(&builder[header_len + 2..], DATA16384); + + s.mark_as_lost(0, DATA16384.len(), true); + + // However, if there is one extra byte of space, we will try to add a length. + // That length will then make the frame to be too large and the data will be + // truncated. The frame could carry one more byte of data, but it's a corner + // case we don't want to address as it should be rare (if not impossible). 
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + let header_len = builder.len(); + builder.set_limit(header_len + DATA16384.len() + 3); + let token = s.write_frame(&mut builder); + assert!(token.is_some()); + // Expect STREAM + LEN, not FIN (the data is truncated, see below). + assert_eq!( + &builder[header_len..header_len + 4], + &[0b1010, 0, 0x7f, 0xfd] + ); + assert_eq!( + &builder[header_len + 4..], + &DATA16384[..DATA16384.len() - 3] + ); + } + + #[test] + fn stream_frame_64() { + const DATA64: &[u8] = &[0x43; 64]; + + // Unlike 16383/16384, the boundary at 63/64 is easy because the difference + // is just one byte. We lose just the last byte when there is more space. + let mut s = stream_with_sent(0, 0); + s.send(DATA64).unwrap(); + s.close(); + + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + let header_len = builder.len(); + builder.set_limit(header_len + 66); + let token = s.write_frame(&mut builder); + assert!(token.is_some()); + // Expect STREAM + FIN only. + assert_eq!(&builder[header_len..header_len + 2], &[0b1001, 0]); + assert_eq!(&builder[header_len + 2..], DATA64); + + s.mark_as_lost(0, DATA64.len(), true); + + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + let header_len = builder.len(); + builder.set_limit(header_len + 67); + let token = s.write_frame(&mut builder); + assert!(token.is_some()); + // Expect STREAM + LEN, not FIN. + assert_eq!(&builder[header_len..header_len + 3], &[0b1010, 0, 63]); + assert_eq!(&builder[header_len + 3..], &DATA64[..63]); + } +} diff --git a/third_party/rust/neqo-transport/src/sender.rs b/third_party/rust/neqo-transport/src/sender.rs new file mode 100644 index 0000000000..15675f4b2b --- /dev/null +++ b/third_party/rust/neqo-transport/src/sender.rs @@ -0,0 +1,124 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option.
This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +use crate::cc::{ + ClassicCongestionControl, CongestionControl, CongestionControlAlgorithm, NewReno, + MAX_DATAGRAM_SIZE, +}; +use crate::pace::Pacer; +use crate::tracking::SentPacket; +use neqo_common::qlog::NeqoQlog; + +use std::fmt::{self, Debug, Display}; +use std::time::{Duration, Instant}; + +/// The number of packets we allow to burst from the pacer. +pub const PACING_BURST_SIZE: usize = 2; + +/// A congestion controller paired with an optional pacer. +/// The pacer is `None` until `start_pacer` is called. +#[derive(Debug)] +pub struct PacketSender { + cc: Box<dyn CongestionControl>, + pacer: Option<Pacer>, +} + +impl Display for PacketSender { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.cc)?; + if let Some(p) = &self.pacer { + write!(f, " {}", p)?; + } + Ok(()) + } +} + +impl PacketSender { + /// Create a sender that uses the congestion control algorithm `alg`. + /// No pacer is configured yet. + #[must_use] + pub fn new(alg: CongestionControlAlgorithm) -> Self { + Self { + cc: match alg { + CongestionControlAlgorithm::NewReno => { + Box::new(ClassicCongestionControl::new(NewReno::default())) + } + }, + pacer: None, + } + } + + pub fn set_qlog(&mut self, qlog: NeqoQlog) { + self.cc.set_qlog(qlog); + } + + #[cfg(test)] + #[must_use] + pub fn cwnd(&self) -> usize { + self.cc.cwnd() + } + + #[must_use] + pub fn cwnd_avail(&self) -> usize { + self.cc.cwnd_avail() + } + + // Multi-packet version of OnPacketAckedCC + pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]) { + self.cc.on_packets_acked(acked_pkts); + } + + pub fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) { + self.cc.on_packets_lost( + first_rtt_sample_time, + prev_largest_acked_sent, + pto, + lost_packets, + ); + } + + pub fn discard(&mut self, pkt: &SentPacket) { + self.cc.discard(pkt); + } + + /// Report a sent packet to the pacer and then to the congestion controller. + /// + /// # Panics + /// If `start_pacer` has not been called. + pub fn on_packet_sent(&mut self, pkt: &SentPacket, rtt: Duration) { +
self.pacer + .as_mut() + .expect("start_pacer() must be called before packets are sent") + .spend(pkt.time_sent, rtt, self.cc.cwnd(), pkt.size); + self.cc.on_packet_sent(pkt); + } + + /// Create the pacer, replacing any existing one. This has to happen before + /// `on_packet_sent` or (with bytes in flight) `next_paced` is called. + pub fn start_pacer(&mut self, now: Instant) { + // Start the pacer with a small burst size. + self.pacer = Some(Pacer::new( + now, + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE, + MAX_DATAGRAM_SIZE, + )); + } + + /// The next instant reported by the pacer, or `None` when no bytes are in flight. + /// + /// # Panics + /// If bytes are in flight and `start_pacer` has not been called. + #[must_use] + pub fn next_paced(&self, rtt: Duration) -> Option<Instant> { + // Only pace if there are bytes in flight. + if self.cc.bytes_in_flight() > 0 { + let pacer = self + .pacer + .as_ref() + .expect("start_pacer() must be called before packets are paced"); + Some(pacer.next(rtt, self.cc.cwnd())) + } else { + None + } + } + + #[must_use] + pub fn recovery_packet(&self) -> bool { + self.cc.recovery_packet() + } +} diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs new file mode 100644 index 0000000000..ef161a9fe4 --- /dev/null +++ b/third_party/rust/neqo-transport/src/server.rs @@ -0,0 +1,636 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// This file implements a server that can handle multiple connections.
+ +use neqo_common::{ + self as common, event::Provider, hex, qdebug, qerror, qinfo, qlog::NeqoQlog, qtrace, qwarn, + timer::Timer, Datagram, Decoder, Role, +}; +use neqo_crypto::{AntiReplay, Cipher, ZeroRttCheckResult, ZeroRttChecker}; + +pub use crate::addr_valid::ValidateAddress; +use crate::addr_valid::{AddressValidation, AddressValidationResult}; +use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdManager, ConnectionIdRef}; +use crate::connection::{Connection, Output, State}; +use crate::packet::{PacketBuilder, PacketType, PublicPacket}; +use crate::{ConnectionParameters, QuicVersion, Res}; + +use std::cell::RefCell; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::fs::OpenOptions; +use std::mem; +use std::net::SocketAddr; +use std::ops::{Deref, DerefMut}; +use std::path::PathBuf; +use std::rc::{Rc, Weak}; +use std::time::{Duration, Instant}; + +pub enum InitialResult { + Accept, + Drop, + Retry(Vec<u8>), +} + +/// MIN_INITIAL_PACKET_SIZE is the smallest packet that can be used to establish +/// a new connection across all QUIC versions this server supports. +const MIN_INITIAL_PACKET_SIZE: usize = 1200; +// Granularity and capacity for the `Timer` that holds the connection timers. +const TIMER_GRANULARITY: Duration = Duration::from_millis(10); +const TIMER_CAPACITY: usize = 16384; + +type StateRef = Rc<RefCell<ServerConnectionState>>; +type CidMgr = Rc<RefCell<dyn ConnectionIdManager>>; +type ConnectionTableRef = Rc<RefCell<HashMap<ConnectionId, StateRef>>>; + +/// A `Connection` plus what the server tracks for it: the key of its connection +/// attempt while that is still active, and the time its timer was last set for. +#[derive(Debug)] +pub struct ServerConnectionState { + c: Connection, + active_attempt: Option<AttemptKey>, + last_timer: Instant, +} + +impl Deref for ServerConnectionState { + type Target = Connection; + fn deref(&self) -> &Self::Target { + &self.c + } +} + +impl DerefMut for ServerConnectionState { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.c + } +} + +/// An `AttemptKey` is used to disambiguate connection attempts. +/// Multiple connection attempts with the same key won't produce multiple connections.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)] +struct AttemptKey { + // Using the remote address is sufficient for disambiguation, + // until we support multiple local socket addresses. + remote_address: SocketAddr, + odcid: ConnectionId, +} + +/// A `ServerZeroRttChecker` is a simple wrapper around a single checker. +/// It uses `Rc<RefCell<_>>` so that the wrapped checker can be shared between +/// multiple connections created by the server. +#[derive(Clone, Debug)] +struct ServerZeroRttChecker { + checker: Rc<RefCell<Box<dyn ZeroRttChecker>>>, +} + +impl ServerZeroRttChecker { + pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self { + Self { + checker: Rc::new(RefCell::new(checker)), + } + } +} + +impl ZeroRttChecker for ServerZeroRttChecker { + fn check(&self, token: &[u8]) -> ZeroRttCheckResult { + self.checker.borrow().check(token) + } +} + +/// `InitialDetails` holds important information for processing `Initial` packets. +struct InitialDetails { + src_cid: ConnectionId, + dst_cid: ConnectionId, + token: Vec<u8>, + quic_version: QuicVersion, +} + +impl InitialDetails { + /// Copy the connection IDs, the token and the version out of `packet`. + /// This panics if `packet.version()` has no value to unwrap. + fn new(packet: &PublicPacket) -> Self { + Self { + src_cid: ConnectionId::from(packet.scid()), + dst_cid: ConnectionId::from(packet.dcid()), + token: packet.token().to_vec(), + quic_version: packet.version().unwrap(), + } + } +} + +/// A server that handles multiple connections. +pub struct Server { + /// The names of certificates. + certs: Vec<String>, + /// The ALPN values that the server supports. + protocols: Vec<String>, + /// The cipher suites that the server supports. + ciphers: Vec<Cipher>, + /// Anti-replay configuration for 0-RTT. + anti_replay: AntiReplay, + /// A function for determining if 0-RTT can be accepted. + zero_rtt_checker: ServerZeroRttChecker, + /// A connection ID manager. + cid_manager: CidMgr, + /// Connection parameters. + conn_params: ConnectionParameters, + /// Active connection attempts, keyed by `AttemptKey`. Initial packets with + /// the same key are routed to the connection that was first accepted.
+ /// This is cleared out when the connection is closed or established. + active_attempts: HashMap<AttemptKey, StateRef>, + /// All connections, keyed by ConnectionId. + connections: ConnectionTableRef, + /// The connections that have new events. + active: HashSet<ActiveConnectionRef>, + /// The set of connections that need immediate processing. + waiting: VecDeque<StateRef>, + /// Outstanding timers for connections. + timers: Timer<StateRef>, + /// Address validation logic, which determines whether we send a Retry. + address_validation: Rc<RefCell<AddressValidation>>, + /// Directory to create qlog traces in + qlog_dir: Option<PathBuf>, +} + +impl Server { + /// Construct a new server. + /// * `now` is the time that the server is instantiated. + /// * `certs` is a list of the certificates that should be configured. + /// * `protocols` is the preference list of ALPN values. + /// * `anti_replay` is an anti-replay context. + /// * `zero_rtt_checker` determines whether 0-RTT should be accepted. This + /// will be passed the value of the `extra` argument that was passed to + /// `Connection::send_ticket` to see if it is OK. + /// * `cid_manager` is responsible for generating connection IDs and parsing them; + /// connection IDs produced by the manager cannot be zero-length. + /// * `conn_params` are the parameters for the connections that are created; + /// the QUIC version in them is replaced by the version of the Initial packet. + /// + /// # Errors + /// Returns an error if the address validation state cannot be created.
+ pub fn new( + now: Instant, + certs: &[impl AsRef<str>], + protocols: &[impl AsRef<str>], + anti_replay: AntiReplay, + zero_rtt_checker: Box<dyn ZeroRttChecker>, + cid_manager: CidMgr, + conn_params: ConnectionParameters, + ) -> Res<Self> { + let validation = AddressValidation::new(now, ValidateAddress::Never)?; + Ok(Self { + certs: certs.iter().map(|x| String::from(x.as_ref())).collect(), + protocols: protocols.iter().map(|x| String::from(x.as_ref())).collect(), + ciphers: Vec::new(), + anti_replay, + zero_rtt_checker: ServerZeroRttChecker::new(zero_rtt_checker), + cid_manager, + conn_params, + active_attempts: HashMap::default(), + connections: Rc::default(), + active: HashSet::default(), + waiting: VecDeque::default(), + timers: Timer::new(now, TIMER_GRANULARITY, TIMER_CAPACITY), + address_validation: Rc::new(RefCell::new(validation)), + qlog_dir: None, + }) + } + + /// Set or clear directory to create logs of connection events in QLOG format. + pub fn set_qlog_dir(&mut self, dir: Option<PathBuf>) { + self.qlog_dir = dir; + } + + /// Set the policy for address validation. Until this is called, the policy is + /// `ValidateAddress::Never`, which is what `new` configures. + pub fn set_validation(&mut self, v: ValidateAddress) { + self.address_validation.borrow_mut().set_validation(v); + } + + /// Set the cipher suites that should be used. Set an empty value to use + /// default values. + /// NOTE(review): nothing in this file reads the stored value back, so it does not + /// appear to be applied to the connections that are created; confirm.
+ pub fn set_ciphers(&mut self, ciphers: impl AsRef<[Cipher]>) { + self.ciphers = Vec::from(ciphers.as_ref()); + } + + fn remove_timer(&mut self, c: &StateRef) { + let last = c.borrow().last_timer; + self.timers.remove(last, |t| Rc::ptr_eq(t, c)); + } + + fn process_connection( + &mut self, + c: StateRef, + dgram: Option<Datagram>, + now: Instant, + ) -> Option<Datagram> { + qtrace!([self], "Process connection {:?}", c); + let out = c.borrow_mut().process(dgram, now); + match out { + Output::Datagram(_) => { + qtrace!([self], "Sending packet, added to waiting connections"); + self.waiting.push_back(Rc::clone(&c)); + } + Output::Callback(delay) => { + let next = now + delay; + if next != c.borrow().last_timer { + qtrace!([self], "Change timer to {:?}", next); + self.remove_timer(&c); + c.borrow_mut().last_timer = next; + self.timers.add(next, Rc::clone(&c)); + } + } + _ => { + self.remove_timer(&c); + } + } + if c.borrow().has_events() { + qtrace!([self], "Connection active: {:?}", c); + self.active.insert(ActiveConnectionRef { c: Rc::clone(&c) }); + } + + if *c.borrow().state() > State::Handshaking { + // Remove any active connection attempt now that this is no longer handshaking. 
+ if let Some(k) = c.borrow_mut().active_attempt.take() { + self.active_attempts.remove(&k); + } + } + + if matches!(c.borrow().state(), State::Closed(_)) { + c.borrow_mut().set_qlog(NeqoQlog::disabled()); + self.connections + .borrow_mut() + .retain(|_, v| !Rc::ptr_eq(v, &c)); + } + out.dgram() + } + + /// Find the connection that is registered under `cid`, if there is one. + fn connection(&self, cid: &ConnectionIdRef) -> Option<StateRef> { + self.connections.borrow().get(&cid[..]).map(Rc::clone) + } + + /// Run address validation for an Initial packet and act on the result: + /// drop the packet, start or continue a connection attempt, or answer with a Retry. + fn handle_initial( + &mut self, + initial: InitialDetails, + dgram: Datagram, + now: Instant, + ) -> Option<Datagram> { + qdebug!([self], "Handle initial"); + let res = self + .address_validation + .borrow() + .validate(&initial.token, dgram.source(), now); + match res { + AddressValidationResult::Invalid => None, + AddressValidationResult::Pass => self.connection_attempt(initial, dgram, None, now), + AddressValidationResult::ValidRetry(orig_dcid) => { + self.connection_attempt(initial, dgram, Some(orig_dcid), now) + } + AddressValidationResult::Validate => { + qinfo!([self], "Send retry for {:?}", initial.dst_cid); + + let res = self.address_validation.borrow().generate_retry_token( + &initial.dst_cid, + dgram.source(), + now, + ); + let token = if let Ok(t) = res { + t + } else { + qerror!([self], "unable to generate token, dropping packet"); + return None; + }; + let new_dcid = self.cid_manager.borrow_mut().generate_cid(); + let packet = PacketBuilder::retry( + initial.quic_version, + &initial.src_cid, + &new_dcid, + &token, + &initial.dst_cid, + ); + if let Ok(p) = packet { + let retry = Datagram::new(dgram.destination(), dgram.source(), p); + Some(retry) + } else { + qerror!([self], "unable to encode retry, dropping packet"); + None + } + } + } + } + + fn connection_attempt( + &mut self, + initial: InitialDetails, + dgram: Datagram, + orig_dcid: Option<ConnectionId>, + now: Instant, + ) -> Option<Datagram> { + let attempt_key = AttemptKey { + remote_address: dgram.source(), + odcid:
orig_dcid.as_ref().unwrap_or(&initial.dst_cid).clone(), + }; + if let Some(c) = self.active_attempts.get(&attempt_key) { + qdebug!( + [self], + "Handle Initial for existing connection attempt {:?}", + attempt_key + ); + let c = Rc::clone(c); + self.process_connection(c, Some(dgram), now) + } else { + self.accept_connection(attempt_key, initial, dgram, orig_dcid, now) + } + } + + fn create_qlog_trace(&self, attempt_key: &AttemptKey) -> NeqoQlog { + if let Some(qlog_dir) = &self.qlog_dir { + let mut qlog_path = qlog_dir.to_path_buf(); + + qlog_path.push(format!("{}.qlog", attempt_key.odcid)); + + // The original DCID is chosen by the client. Using create_new() + // prevents attackers from overwriting existing logs. + match OpenOptions::new() + .write(true) + .create_new(true) + .open(&qlog_path) + { + Ok(f) => { + qinfo!("Qlog output to {}", qlog_path.display()); + + let streamer = ::qlog::QlogStreamer::new( + qlog::QLOG_VERSION.to_string(), + Some("Neqo server qlog".to_string()), + Some("Neqo server qlog".to_string()), + None, + std::time::Instant::now(), + common::qlog::new_trace(Role::Server), + Box::new(f), + ); + let n_qlog = NeqoQlog::enabled(streamer, qlog_path); + match n_qlog { + Ok(nql) => nql, + Err(e) => { + // Keep going but w/o qlogging + qerror!("NeqoQlog error: {}", e); + NeqoQlog::disabled() + } + } + } + Err(e) => { + qerror!( + "Could not open file {} for qlog output: {}", + qlog_path.display(), + e + ); + NeqoQlog::disabled() + } + } + } else { + NeqoQlog::disabled() + } + } + + fn accept_connection( + &mut self, + attempt_key: AttemptKey, + initial: InitialDetails, + dgram: Datagram, + orig_dcid: Option<ConnectionId>, + now: Instant, + ) -> Option<Datagram> { + qinfo!([self], "Accept connection {:?}", attempt_key); + // The internal connection ID manager that we use is not used directly. + // Instead, wrap it so that we can save connection IDs. 
+ + let cid_mgr = Rc::new(RefCell::new(ServerConnectionIdManager { + c: Weak::new(), + cid_manager: Rc::clone(&self.cid_manager), + connections: Rc::clone(&self.connections), + saved_cids: Vec::new(), + })); + + let sconn = Connection::new_server( + &self.certs, + &self.protocols, + Rc::clone(&cid_mgr) as _, + &self.conn_params.clone().quic_version(initial.quic_version), + ); + + if let Ok(mut c) = sconn { + let zcheck = self.zero_rtt_checker.clone(); + if c.server_enable_0rtt(&self.anti_replay, zcheck).is_err() { + qwarn!([self], "Unable to enable 0-RTT"); + } + if let Some(odcid) = orig_dcid { + // There was a retry, so record the connection IDs that were used for it. + c.set_retry_cids(odcid, initial.src_cid, initial.dst_cid); + } + c.set_validation(Rc::clone(&self.address_validation)); + c.set_qlog(self.create_qlog_trace(&attempt_key)); + let c = Rc::new(RefCell::new(ServerConnectionState { + c, + last_timer: now, + active_attempt: Some(attempt_key.clone()), + })); + // Now that the connection state exists, hand it to the connection ID manager, + // which registers any connection IDs that it saved in the meantime. + cid_mgr.borrow_mut().set_connection(Rc::clone(&c)); + let previous_attempt = self.active_attempts.insert(attempt_key, Rc::clone(&c)); + debug_assert!(previous_attempt.is_none()); + self.process_connection(c, Some(dgram), now) + } else { + qwarn!([self], "Unable to create connection"); + None + } + } + + fn process_input(&mut self, dgram: Datagram, now: Instant) -> Option<Datagram> { + qtrace!("Process datagram: {}", hex(&dgram[..])); + + // This is only looking at the first packet header in the datagram. + // All packets in the datagram are routed to the same connection. + let res = PublicPacket::decode(&dgram[..], self.cid_manager.borrow().as_decoder()); + let (packet, _remainder) = match res { + Ok(res) => res, + _ => { + qtrace!([self], "Discarding {:?}", dgram); + return None; + } + }; + + // Finding an existing connection. Should be the most common case.
+ if let Some(c) = self.connection(packet.dcid()) { + return self.process_connection(c, Some(dgram), now); + } + + if packet.packet_type() == PacketType::Short { + // TODO send a stateless reset here. + qtrace!([self], "Short header packet for an unknown connection"); + return None; + } + + if dgram.len() < MIN_INITIAL_PACKET_SIZE { + qtrace!([self], "Bogus packet: too short"); + return None; + } + match packet.packet_type() { + PacketType::Initial => { + // Copy values from `packet` because they are currently still borrowing from `dgram`. + let initial = InitialDetails::new(&packet); + self.handle_initial(initial, dgram, now) + } + PacketType::OtherVersion => { + let vn = PacketBuilder::version_negotiation(packet.scid(), packet.dcid()); + Some(Datagram::new(dgram.destination(), dgram.source(), vn)) + } + _ => { + qtrace!([self], "Not an initial packet"); + None + } + } + } + + /// Iterate through the pending connections looking for any that might want + /// to send a datagram. Stop at the first one that does. 
+ fn process_next_output(&mut self, now: Instant) -> Option<Datagram> { + qtrace!([self], "No packet to send, look at waiting connections"); + while let Some(c) = self.waiting.pop_front() { + if let Some(d) = self.process_connection(c, None, now) { + return Some(d); + } + } + qtrace!([self], "No packet to send still, run timers"); + while let Some(c) = self.timers.take_next(now) { + if let Some(d) = self.process_connection(c, None, now) { + return Some(d); + } + } + None + } + + /// The delay until the server needs to be processed again: zero if any + /// connection is waiting, otherwise the time from `now` to the next timer, + /// or `None` if there is no timer. + fn next_time(&mut self, now: Instant) -> Option<Duration> { + if self.waiting.is_empty() { + // A timer that is already overdue gives a zero delay. Plain `x - now` + // panicked in that case on older Rust versions. + self.timers + .next_time() + .map(|x| x.saturating_duration_since(now)) + } else { + Some(Duration::new(0, 0)) + } + } + + /// Process an optional inbound datagram, then report a datagram to send, + /// the delay until the next callback, or `Output::None` if there is neither. + pub fn process(&mut self, dgram: Option<Datagram>, now: Instant) -> Output { + let out = if let Some(d) = dgram { + self.process_input(d, now) + } else { + None + }; + let out = out.or_else(|| self.process_next_output(now)); + match out { + Some(d) => { + qtrace!([self], "Send packet: {:?}", d); + Output::Datagram(d) + } + _ => match self.next_time(now) { + Some(delay) => { + qtrace!([self], "Wait: {:?}", delay); + Output::Callback(delay) + } + _ => { + qtrace!([self], "Go dormant"); + Output::None + } + }, + } + } + + /// This lists the connections that have received new events + /// as a result of calling `process()`.
+ pub fn active_connections(&mut self) -> Vec<ActiveConnectionRef> { + mem::take(&mut self.active).into_iter().collect() + } + + pub fn add_to_waiting(&mut self, c: ActiveConnectionRef) { + self.waiting.push_back(c.connection()); + } +} + +#[derive(Clone, Debug)] +pub struct ActiveConnectionRef { + c: StateRef, +} + +impl ActiveConnectionRef { + pub fn borrow<'a>(&'a self) -> impl Deref<Target = Connection> + 'a { + std::cell::Ref::map(self.c.borrow(), |c| &c.c) + } + + pub fn borrow_mut<'a>(&'a mut self) -> impl DerefMut<Target = Connection> + 'a { + std::cell::RefMut::map(self.c.borrow_mut(), |c| &mut c.c) + } + + pub fn connection(&self) -> StateRef { + Rc::clone(&self.c) + } +} + +// Hashing and equality both use the identity of the shared connection state. +impl std::hash::Hash for ActiveConnectionRef { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + let ptr: *const _ = self.c.as_ref(); + ptr.hash(state) + } +} + +impl PartialEq for ActiveConnectionRef { + fn eq(&self, other: &Self) -> bool { + Rc::ptr_eq(&self.c, &other.c) + } +} + +impl Eq for ActiveConnectionRef {} + +/// Wraps the server's connection ID manager so that every connection ID that is +/// generated for a connection is also added to the server's connection table. +struct ServerConnectionIdManager { + c: Weak<RefCell<ServerConnectionState>>, + connections: ConnectionTableRef, + cid_manager: CidMgr, + saved_cids: Vec<ConnectionId>, +} + +impl ServerConnectionIdManager { + /// Attach the connection: register the connection IDs that were saved before + /// the connection was known, then keep a weak reference to it. + pub fn set_connection(&mut self, c: StateRef) { + let saved = mem::take(&mut self.saved_cids); + for cid in saved { + qtrace!("ServerConnectionIdManager inserting saved cid {}", cid); + self.insert_cid(cid, Rc::clone(&c)); + } + self.c = Rc::downgrade(&c); + } + + fn insert_cid(&mut self, cid: ConnectionId, rc: StateRef) { + debug_assert!(!cid.is_empty()); + self.connections.borrow_mut().insert(cid, rc); + } +} + +impl ConnectionIdDecoder for ServerConnectionIdManager { + fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> { + // Decoding only needs `&self`, so a shared borrow is enough and cannot + // conflict with other shared borrows of the manager. + self.cid_manager.borrow().decode_cid(dec) + } +} + +impl ConnectionIdManager for ServerConnectionIdManager { + fn generate_cid(&mut self) -> ConnectionId { + let cid =
self.cid_manager.borrow_mut().generate_cid(); + if let Some(rc) = self.c.upgrade() { + self.insert_cid(cid.clone(), rc); + } else { + // This function can be called before the connection is set. + // So save any connection IDs until that hookup happens. + qtrace!("ServerConnectionIdManager saving cid {}", cid); + self.saved_cids.push(cid.clone()); + } + cid + } + + fn as_decoder(&self) -> &dyn ConnectionIdDecoder { + self + } +} + +impl ::std::fmt::Display for Server { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Server") + } +} diff --git a/third_party/rust/neqo-transport/src/stats.rs b/third_party/rust/neqo-transport/src/stats.rs new file mode 100644 index 0000000000..9d01dfe211 --- /dev/null +++ b/third_party/rust/neqo-transport/src/stats.rs @@ -0,0 +1,195 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Tracking of some useful statistics. 
+#![deny(clippy::pedantic)] + +use crate::packet::PacketNumber; +use neqo_common::qinfo; +use std::cell::RefCell; +use std::fmt::{self, Debug}; +use std::ops::Deref; +use std::rc::Rc; + +pub(crate) const MAX_PTO_COUNTS: usize = 16; + +#[derive(Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct FrameStats { + pub all: usize, + pub ack: usize, + pub largest_acknowledged: PacketNumber, + + pub crypto: usize, + pub stream: usize, + pub reset_stream: usize, + pub stop_sending: usize, + + pub ping: usize, + pub padding: usize, + + pub max_streams: usize, + pub streams_blocked: usize, + pub max_data: usize, + pub data_blocked: usize, + pub max_stream_data: usize, + pub stream_data_blocked: usize, + + pub new_connection_id: usize, + pub retire_connection_id: usize, + + pub path_challenge: usize, + pub path_response: usize, + + pub connection_close: usize, + pub handshake_done: usize, + pub new_token: usize, +} + +impl Debug for FrameStats { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + " crypto {} done {} token {} close {}", + self.crypto, self.handshake_done, self.new_token, self.connection_close, + )?; + writeln!( + f, + " ack {} (max {}) ping {} padding {}", + self.ack, self.largest_acknowledged, self.ping, self.padding + )?; + writeln!( + f, + " stream {} reset {} stop {}", + self.stream, self.reset_stream, self.stop_sending, + )?; + writeln!( + f, + " max: stream {} data {} stream_data {}", + self.max_streams, self.max_data, self.max_stream_data, + )?; + writeln!( + f, + " blocked: stream {} data {} stream_data {}", + self.streams_blocked, self.data_blocked, self.stream_data_blocked, + )?; + writeln!( + f, + " ncid {} rcid {} pchallenge {} presponse {}", + self.new_connection_id, + self.retire_connection_id, + self.path_challenge, + self.path_response, + ) + } +} + +/// Connection statistics +#[derive(Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct Stats { + info: String, + + /// Total packets 
received, including all the bad ones. + pub packets_rx: usize, + /// Duplicate packets received. + pub dups_rx: usize, + /// Dropped packets or dropped garbage. + pub dropped_rx: usize, + /// The number of packet that were saved for later processing. + pub saved_datagrams: usize, + + /// Total packets sent. + pub packets_tx: usize, + /// Total number of packets that are declared lost. + pub lost: usize, + /// Late acknowledgments, for packets that were declared lost already. + pub late_ack: usize, + /// Acknowledgments for packets that contained data that was marked + /// for retransmission when the PTO timer popped. + pub pto_ack: usize, + + /// Whether the connection was resumed successfully. + pub resumed: bool, + + /// Count PTOs. Single PTOs, 2 PTOs in a row, 3 PTOs in row, etc. are counted + /// separately. + pub pto_counts: [usize; MAX_PTO_COUNTS], + + /// Count frames received. + pub frame_rx: FrameStats, + /// Count frames sent. + pub frame_tx: FrameStats, +} + +impl Stats { + pub fn init(&mut self, info: String) { + self.info = info; + } + + pub fn pkt_dropped(&mut self, reason: impl AsRef<str>) { + self.dropped_rx += 1; + qinfo!( + [self.info], + "Dropped received packet: {}; Total: {}", + reason.as_ref(), + self.dropped_rx + ) + } + + pub fn add_pto_count(&mut self, count: usize) { + debug_assert!(count > 0); + if count >= MAX_PTO_COUNTS { + // We can't move this count any further, so stop. 
+ return; + } + self.pto_counts[count - 1] += 1; + if count > 1 { + debug_assert!(self.pto_counts[count - 2] > 0); + self.pto_counts[count - 2] -= 1; + } + } +} + +impl Debug for Stats { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "stats for {}", self.info)?; + writeln!( + f, + " rx: {} drop {} dup {} saved {}", + self.packets_rx, self.dropped_rx, self.dups_rx, self.saved_datagrams + )?; + writeln!( + f, + " tx: {} lost {} lateack {} ptoack {}", + self.packets_tx, self.lost, self.late_ack, self.pto_ack + )?; + writeln!(f, " resumed: {} ", self.resumed)?; + writeln!(f, " frames rx:")?; + self.frame_rx.fmt(f)?; + writeln!(f, " frames tx:")?; + self.frame_tx.fmt(f) + } +} + +#[derive(Default, Clone)] +#[allow(clippy::module_name_repetitions)] +pub struct StatsCell { + stats: Rc<RefCell<Stats>>, +} + +impl Deref for StatsCell { + type Target = RefCell<Stats>; + fn deref(&self) -> &Self::Target { + &*self.stats + } +} + +impl Debug for StatsCell { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.stats.borrow().fmt(f) + } +} diff --git a/third_party/rust/neqo-transport/src/stream_id.rs b/third_party/rust/neqo-transport/src/stream_id.rs new file mode 100644 index 0000000000..486bfb937c --- /dev/null +++ b/third_party/rust/neqo-transport/src/stream_id.rs @@ -0,0 +1,205 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Stream ID and stream index handling. 
+ +use std::ops::AddAssign; + +use neqo_common::Role; + +use crate::connection::{LOCAL_STREAM_LIMIT_BIDI, LOCAL_STREAM_LIMIT_UNI}; +use crate::frame::StreamType; + +pub struct StreamIndexes { + pub local_max_stream_uni: StreamIndex, + pub local_max_stream_bidi: StreamIndex, + pub local_next_stream_uni: StreamIndex, + pub local_next_stream_bidi: StreamIndex, + pub remote_max_stream_uni: StreamIndex, + pub remote_max_stream_bidi: StreamIndex, + pub remote_next_stream_uni: StreamIndex, + pub remote_next_stream_bidi: StreamIndex, +} + +impl StreamIndexes { + pub fn new() -> Self { + Self { + local_max_stream_bidi: StreamIndex::new(LOCAL_STREAM_LIMIT_BIDI), + local_max_stream_uni: StreamIndex::new(LOCAL_STREAM_LIMIT_UNI), + local_next_stream_uni: StreamIndex::new(0), + local_next_stream_bidi: StreamIndex::new(0), + remote_max_stream_bidi: StreamIndex::new(0), + remote_max_stream_uni: StreamIndex::new(0), + remote_next_stream_uni: StreamIndex::new(0), + remote_next_stream_bidi: StreamIndex::new(0), + } + } +} + +#[derive(Debug, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)] +pub struct StreamId(u64); + +impl StreamId { + pub const fn new(id: u64) -> Self { + Self(id) + } + + pub fn as_u64(self) -> u64 { + self.0 + } + + pub fn is_bidi(self) -> bool { + self.as_u64() & 0x02 == 0 + } + + pub fn is_uni(self) -> bool { + !self.is_bidi() + } + + pub fn stream_type(self) -> StreamType { + if self.is_bidi() { + StreamType::BiDi + } else { + StreamType::UniDi + } + } + + pub fn is_client_initiated(self) -> bool { + self.as_u64() & 0x01 == 0 + } + + pub fn is_server_initiated(self) -> bool { + !self.is_client_initiated() + } + + pub fn role(self) -> Role { + if self.is_client_initiated() { + Role::Client + } else { + Role::Server + } + } + + pub fn is_self_initiated(self, my_role: Role) -> bool { + match my_role { + Role::Client if self.is_client_initiated() => true, + Role::Server if self.is_server_initiated() => true, + _ => false, + } + } + + pub fn 
is_remote_initiated(self, my_role: Role) -> bool { + !self.is_self_initiated(my_role) + } + + pub fn is_send_only(self, my_role: Role) -> bool { + self.is_uni() && self.is_self_initiated(my_role) + } + + pub fn is_recv_only(self, my_role: Role) -> bool { + self.is_uni() && self.is_remote_initiated(my_role) + } +} + +impl From<u64> for StreamId { + fn from(val: u64) -> Self { + Self::new(val) + } +} + +impl PartialEq<u64> for StreamId { + fn eq(&self, other: &u64) -> bool { + self.as_u64() == *other + } +} + +impl ::std::fmt::Display for StreamId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}", self.as_u64()) + } +} + +#[derive(Debug, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)] +pub struct StreamIndex(u64); + +impl StreamIndex { + pub fn new(val: u64) -> Self { + Self(val) + } + + pub fn to_stream_id(self, stream_type: StreamType, role: Role) -> StreamId { + let type_val = match stream_type { + StreamType::BiDi => 0, + StreamType::UniDi => 2, + }; + let role_val = match role { + Role::Server => 1, + Role::Client => 0, + }; + + StreamId::from((self.0 << 2) + type_val + role_val) + } + + pub fn as_u64(self) -> u64 { + self.0 + } +} + +impl From<StreamId> for StreamIndex { + fn from(val: StreamId) -> Self { + Self(val.as_u64() >> 2) + } +} + +impl AddAssign<u64> for StreamIndex { + fn add_assign(&mut self, other: u64) { + *self = Self::new(self.as_u64() + other) + } +} + +#[cfg(test)] +mod test { + use super::{StreamIndex, StreamType}; + use neqo_common::Role; + + #[test] + fn bidi_stream_properties() { + let id1 = StreamIndex::new(4).to_stream_id(StreamType::BiDi, Role::Client); + assert_eq!(id1.is_bidi(), true); + assert_eq!(id1.is_uni(), false); + assert_eq!(id1.is_client_initiated(), true); + assert_eq!(id1.is_server_initiated(), false); + assert_eq!(id1.role(), Role::Client); + assert_eq!(id1.is_self_initiated(Role::Client), true); + assert_eq!(id1.is_self_initiated(Role::Server), false); + 
assert_eq!(id1.is_remote_initiated(Role::Client), false); + assert_eq!(id1.is_remote_initiated(Role::Server), true); + assert_eq!(id1.is_send_only(Role::Server), false); + assert_eq!(id1.is_send_only(Role::Client), false); + assert_eq!(id1.is_recv_only(Role::Server), false); + assert_eq!(id1.is_recv_only(Role::Client), false); + assert_eq!(id1.as_u64(), 16); + } + + #[test] + fn uni_stream_properties() { + let id2 = StreamIndex::new(8).to_stream_id(StreamType::UniDi, Role::Server); + assert_eq!(id2.is_bidi(), false); + assert_eq!(id2.is_uni(), true); + assert_eq!(id2.is_client_initiated(), false); + assert_eq!(id2.is_server_initiated(), true); + assert_eq!(id2.role(), Role::Server); + assert_eq!(id2.is_self_initiated(Role::Client), false); + assert_eq!(id2.is_self_initiated(Role::Server), true); + assert_eq!(id2.is_remote_initiated(Role::Client), true); + assert_eq!(id2.is_remote_initiated(Role::Server), false); + assert_eq!(id2.is_send_only(Role::Server), true); + assert_eq!(id2.is_send_only(Role::Client), false); + assert_eq!(id2.is_recv_only(Role::Server), false); + assert_eq!(id2.is_recv_only(Role::Client), true); + assert_eq!(id2.as_u64(), 35); + } +} diff --git a/third_party/rust/neqo-transport/src/tparams.rs b/third_party/rust/neqo-transport/src/tparams.rs new file mode 100644 index 0000000000..f0cfbf2203 --- /dev/null +++ b/third_party/rust/neqo-transport/src/tparams.rs @@ -0,0 +1,541 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Transport parameters. See -transport section 7.3. 
+ +#![allow(dead_code)] +use crate::{Error, Res}; +use neqo_common::{hex, qdebug, qinfo, qtrace, Decoder, Encoder}; +use neqo_crypto::constants::{TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS}; +use neqo_crypto::ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult}; +use neqo_crypto::{HandshakeMessage, ZeroRttCheckResult, ZeroRttChecker}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; + +struct PreferredAddress { + // TODO(ekr@rtfm.com): Implement. +} + +pub type TransportParameterId = u64; +macro_rules! tpids { + { $($n:ident = $v:expr),+ $(,)? } => { + $(pub const $n: TransportParameterId = $v as TransportParameterId;)+ + }; + } +tpids! { + ORIGINAL_DESTINATION_CONNECTION_ID = 0x00, + IDLE_TIMEOUT = 0x01, + STATELESS_RESET_TOKEN = 0x02, + MAX_UDP_PAYLOAD_SIZE = 0x03, + INITIAL_MAX_DATA = 0x04, + INITIAL_MAX_STREAM_DATA_BIDI_LOCAL = 0x05, + INITIAL_MAX_STREAM_DATA_BIDI_REMOTE = 0x06, + INITIAL_MAX_STREAM_DATA_UNI = 0x07, + INITIAL_MAX_STREAMS_BIDI = 0x08, + INITIAL_MAX_STREAMS_UNI = 0x09, + ACK_DELAY_EXPONENT = 0x0a, + MAX_ACK_DELAY = 0x0b, + DISABLE_MIGRATION = 0x0c, + PREFERRED_ADDRESS = 0x0d, + ACTIVE_CONNECTION_ID_LIMIT = 0x0e, + INITIAL_SOURCE_CONNECTION_ID = 0x0f, + RETRY_SOURCE_CONNECTION_ID = 0x10, + GREASE_QUIC_BIT = 0x2ab2, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum TransportParameter { + Bytes(Vec<u8>), + Integer(u64), + Empty, +} + +impl TransportParameter { + fn encode(&self, enc: &mut Encoder, tp: TransportParameterId) { + qdebug!("TP encoded; type 0x{:02x} val {:?}", tp, self); + enc.encode_varint(tp); + match self { + Self::Bytes(a) => { + enc.encode_vvec(a); + } + Self::Integer(a) => { + enc.encode_vvec_with(|enc_inner| { + enc_inner.encode_varint(*a); + }); + } + Self::Empty => { + enc.encode_varint(0_u64); + } + }; + } + + fn decode(dec: &mut Decoder) -> Res<Option<(TransportParameterId, Self)>> { + let tp = match dec.decode_varint() { + Some(v) => v, + _ => return Err(Error::NoMoreData), + 
}; + let content = match dec.decode_vvec() { + Some(v) => v, + _ => return Err(Error::NoMoreData), + }; + qtrace!("TP {:x} length {:x}", tp, content.len()); + let mut d = Decoder::from(content); + let value = match tp { + ORIGINAL_DESTINATION_CONNECTION_ID + | INITIAL_SOURCE_CONNECTION_ID + | RETRY_SOURCE_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()), + STATELESS_RESET_TOKEN => { + if d.remaining() != 16 { + return Err(Error::TransportParameterError); + } + Self::Bytes(d.decode_remainder().to_vec()) + } + IDLE_TIMEOUT + | INITIAL_MAX_DATA + | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL + | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE + | INITIAL_MAX_STREAM_DATA_UNI + | MAX_ACK_DELAY => match d.decode_varint() { + Some(v) => Self::Integer(v), + None => return Err(Error::TransportParameterError), + }, + + INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI => match d.decode_varint() { + Some(v) if v <= (1 << 60) => Self::Integer(v), + _ => return Err(Error::StreamLimitError), + }, + + MAX_UDP_PAYLOAD_SIZE => match d.decode_varint() { + Some(v) if v >= 1200 => Self::Integer(v), + _ => return Err(Error::TransportParameterError), + }, + + ACK_DELAY_EXPONENT => match d.decode_varint() { + Some(v) if v <= 20 => Self::Integer(v), + _ => return Err(Error::TransportParameterError), + }, + ACTIVE_CONNECTION_ID_LIMIT => match d.decode_varint() { + Some(v) if v >= 2 => Self::Integer(v), + _ => return Err(Error::TransportParameterError), + }, + + DISABLE_MIGRATION | GREASE_QUIC_BIT => Self::Empty, + // Skip. + _ => return Ok(None), + }; + if d.remaining() > 0 { + return Err(Error::TooMuchData); + } + qdebug!("TP decoded; type 0x{:02x} val {:?}", tp, value); + Ok(Some((tp, value))) + } +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct TransportParameters { + params: HashMap<TransportParameterId, TransportParameter>, +} + +impl TransportParameters { + /// Set a value. 
+ pub fn set(&mut self, k: TransportParameterId, v: TransportParameter) { + self.params.insert(k, v); + } + + /// Clear a key. + pub fn remove(&mut self, k: TransportParameterId) { + self.params.remove(&k); + } + + /// Decode is a static function that parses transport parameters + /// using the provided decoder. + pub(crate) fn decode(d: &mut Decoder) -> Res<Self> { + let mut tps = Self::default(); + qtrace!("Parsed fixed TP header"); + + while d.remaining() > 0 { + match TransportParameter::decode(d) { + Ok(Some((tipe, tp))) => { + tps.set(tipe, tp); + } + Ok(None) => {} + Err(e) => return Err(e), + } + } + Ok(tps) + } + + pub(crate) fn encode(&self, enc: &mut Encoder) { + for (tipe, tp) in &self.params { + tp.encode(enc, *tipe); + } + } + + // Get an integer type or a default. + pub fn get_integer(&self, tp: TransportParameterId) -> u64 { + let default = match tp { + IDLE_TIMEOUT + | INITIAL_MAX_DATA + | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL + | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE + | INITIAL_MAX_STREAM_DATA_UNI + | INITIAL_MAX_STREAMS_BIDI + | INITIAL_MAX_STREAMS_UNI => 0, + MAX_UDP_PAYLOAD_SIZE => 65527, + ACK_DELAY_EXPONENT => 3, + MAX_ACK_DELAY => 25, + ACTIVE_CONNECTION_ID_LIMIT => 2, + _ => panic!("Transport parameter not known or not an Integer"), + }; + match self.params.get(&tp) { + None => default, + Some(TransportParameter::Integer(x)) => *x, + _ => panic!("Internal error"), + } + } + + // Set an integer type or a default. 
+ pub fn set_integer(&mut self, tp: TransportParameterId, value: u64) { + match tp { + IDLE_TIMEOUT + | INITIAL_MAX_DATA + | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL + | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE + | INITIAL_MAX_STREAM_DATA_UNI + | INITIAL_MAX_STREAMS_BIDI + | INITIAL_MAX_STREAMS_UNI + | MAX_UDP_PAYLOAD_SIZE + | ACK_DELAY_EXPONENT + | MAX_ACK_DELAY + | ACTIVE_CONNECTION_ID_LIMIT => { + self.set(tp, TransportParameter::Integer(value)); + } + _ => panic!("Transport parameter not known"), + } + } + + pub fn get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]> { + match tp { + ORIGINAL_DESTINATION_CONNECTION_ID + | INITIAL_SOURCE_CONNECTION_ID + | RETRY_SOURCE_CONNECTION_ID + | STATELESS_RESET_TOKEN => {} + _ => panic!("Transport parameter not known or not type bytes"), + } + + match self.params.get(&tp) { + None => None, + Some(TransportParameter::Bytes(x)) => Some(&x), + _ => panic!("Internal error"), + } + } + + pub fn set_bytes(&mut self, tp: TransportParameterId, value: Vec<u8>) { + match tp { + ORIGINAL_DESTINATION_CONNECTION_ID + | INITIAL_SOURCE_CONNECTION_ID + | RETRY_SOURCE_CONNECTION_ID + | STATELESS_RESET_TOKEN => { + self.set(tp, TransportParameter::Bytes(value)); + } + _ => panic!("Transport parameter not known or not type bytes"), + } + } + + pub fn set_empty(&mut self, tp: TransportParameterId) { + match tp { + DISABLE_MIGRATION | GREASE_QUIC_BIT => { + self.set(tp, TransportParameter::Empty); + } + _ => panic!("Transport parameter not known or not type empty"), + } + } + + pub fn get_empty(&self, tipe: TransportParameterId) -> bool { + match self.params.get(&tipe) { + None => false, + Some(TransportParameter::Empty) => true, + _ => panic!("Internal error"), + } + } + + /// Return true if the remembered transport parameters are OK for 0-RTT. + /// Generally this means that any value that is currently in effect is greater than + /// or equal to the promised value. 
+ pub(crate) fn ok_for_0rtt(&self, remembered: &Self) -> bool { + for (k, v_rem) in &remembered.params { + // Skip checks for these, which don't affect 0-RTT. + if matches!( + *k, + ORIGINAL_DESTINATION_CONNECTION_ID + | INITIAL_SOURCE_CONNECTION_ID + | RETRY_SOURCE_CONNECTION_ID + | STATELESS_RESET_TOKEN + | IDLE_TIMEOUT + | ACK_DELAY_EXPONENT + | MAX_ACK_DELAY + | ACTIVE_CONNECTION_ID_LIMIT + ) { + continue; + } + match self.params.get(k) { + Some(v_self) => match (v_self, v_rem) { + (TransportParameter::Integer(i_self), TransportParameter::Integer(i_rem)) => { + if *i_self < *i_rem { + return false; + } + } + (TransportParameter::Empty, TransportParameter::Empty) => {} + _ => return false, + }, + _ => return false, + } + } + true + } + + fn was_sent(&self, tp: TransportParameterId) -> bool { + self.params.contains_key(&tp) + } +} + +#[derive(Default, Debug)] +pub struct TransportParametersHandler { + pub(crate) local: TransportParameters, + pub(crate) remote: Option<TransportParameters>, + pub(crate) remote_0rtt: Option<TransportParameters>, +} + +impl TransportParametersHandler { + pub fn remote(&self) -> &TransportParameters { + match (self.remote.as_ref(), self.remote_0rtt.as_ref()) { + (Some(tp), _) | (_, Some(tp)) => tp, + _ => panic!("no transport parameters from peer"), + } + } +} + +impl ExtensionHandler for TransportParametersHandler { + fn write(&mut self, msg: HandshakeMessage, d: &mut [u8]) -> ExtensionWriterResult { + if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) { + return ExtensionWriterResult::Skip; + } + + qdebug!("Writing transport parameters, msg={:?}", msg); + + // TODO(ekr@rtfm.com): Modify to avoid a copy. 
+ let mut enc = Encoder::default(); + self.local.encode(&mut enc); + assert!(enc.len() <= d.len()); + d[..enc.len()].copy_from_slice(&enc); + ExtensionWriterResult::Write(enc.len()) + } + + fn handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult { + qtrace!( + "Handling transport parameters, msg={:?} value={}", + msg, + hex(d), + ); + + if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) { + return ExtensionHandlerResult::Alert(110); // unsupported_extension + } + + let mut dec = Decoder::from(d); + match TransportParameters::decode(&mut dec) { + Ok(tp) => { + self.remote = Some(tp); + ExtensionHandlerResult::Ok + } + _ => ExtensionHandlerResult::Alert(47), // illegal_parameter + } + } +} + +#[derive(Debug)] +pub(crate) struct TpZeroRttChecker<T> { + handler: Rc<RefCell<TransportParametersHandler>>, + app_checker: T, +} + +impl<T> TpZeroRttChecker<T> +where + T: ZeroRttChecker + 'static, +{ + pub fn wrap( + handler: Rc<RefCell<TransportParametersHandler>>, + app_checker: T, + ) -> Box<dyn ZeroRttChecker> { + Box::new(Self { + handler, + app_checker, + }) + } +} + +impl<T> ZeroRttChecker for TpZeroRttChecker<T> +where + T: ZeroRttChecker, +{ + fn check(&self, token: &[u8]) -> ZeroRttCheckResult { + // Reject 0-RTT if there is no token. 
+ if token.is_empty() { + qdebug!("0-RTT: no token, no 0-RTT"); + return ZeroRttCheckResult::Reject; + } + let mut dec = Decoder::from(token); + let tpslice = if let Some(v) = dec.decode_vvec() { + v + } else { + qinfo!("0-RTT: token code error"); + return ZeroRttCheckResult::Fail; + }; + let mut dec_tp = Decoder::from(tpslice); + let remembered = if let Ok(v) = TransportParameters::decode(&mut dec_tp) { + v + } else { + qinfo!("0-RTT: transport parameter decode error"); + return ZeroRttCheckResult::Fail; + }; + if self.handler.borrow().local.ok_for_0rtt(&remembered) { + qinfo!("0-RTT: transport parameters OK, passing to application checker"); + self.app_checker.check(dec.decode_remainder()) + } else { + qinfo!("0-RTT: transport parameters bad, rejecting"); + ZeroRttCheckResult::Reject + } + } +} + +// TODO(ekr@rtfm.com): Need to write more TP unit tests. +#[cfg(test)] +#[allow(unused_variables)] +mod tests { + use super::*; + + #[test] + fn basic_tps() { + const RESET_TOKEN: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]; + let mut tps = TransportParameters::default(); + tps.set( + STATELESS_RESET_TOKEN, + TransportParameter::Bytes(RESET_TOKEN.to_vec()), + ); + tps.params + .insert(INITIAL_MAX_STREAMS_BIDI, TransportParameter::Integer(10)); + + let mut enc = Encoder::default(); + tps.encode(&mut enc); + + let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode"); + assert_eq!(tps, tps2); + + println!("TPS = {:?}", tps); + assert_eq!(tps2.get_integer(IDLE_TIMEOUT), 0); // Default + assert_eq!(tps2.get_integer(MAX_ACK_DELAY), 25); // Default + assert_eq!(tps2.get_integer(ACTIVE_CONNECTION_ID_LIMIT), 2); // Default + assert_eq!(tps2.get_integer(INITIAL_MAX_STREAMS_BIDI), 10); // Sent + assert_eq!(tps2.get_bytes(STATELESS_RESET_TOKEN), Some(RESET_TOKEN)); + assert_eq!(tps2.get_bytes(ORIGINAL_DESTINATION_CONNECTION_ID), None); + assert_eq!(tps2.get_bytes(INITIAL_SOURCE_CONNECTION_ID), None); + 
assert_eq!(tps2.get_bytes(RETRY_SOURCE_CONNECTION_ID), None); + assert_eq!(tps2.was_sent(ORIGINAL_DESTINATION_CONNECTION_ID), false); + assert_eq!(tps2.was_sent(INITIAL_SOURCE_CONNECTION_ID), false); + assert_eq!(tps2.was_sent(RETRY_SOURCE_CONNECTION_ID), false); + assert_eq!(tps2.was_sent(STATELESS_RESET_TOKEN), true); + + let mut enc = Encoder::default(); + tps.encode(&mut enc); + + let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode"); + } + + #[test] + fn compatible_0rtt_ignored_values() { + let mut tps_a = TransportParameters::default(); + tps_a.set( + STATELESS_RESET_TOKEN, + TransportParameter::Bytes(vec![1, 2, 3]), + ); + tps_a.set(IDLE_TIMEOUT, TransportParameter::Integer(10)); + tps_a.set(MAX_ACK_DELAY, TransportParameter::Integer(22)); + tps_a.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(33)); + + let mut tps_b = TransportParameters::default(); + assert!(tps_a.ok_for_0rtt(&tps_b)); + assert!(tps_b.ok_for_0rtt(&tps_a)); + + tps_b.set( + STATELESS_RESET_TOKEN, + TransportParameter::Bytes(vec![8, 9, 10]), + ); + tps_b.set(IDLE_TIMEOUT, TransportParameter::Integer(100)); + tps_b.set(MAX_ACK_DELAY, TransportParameter::Integer(2)); + tps_b.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(44)); + assert!(tps_a.ok_for_0rtt(&tps_b)); + assert!(tps_b.ok_for_0rtt(&tps_a)); + } + + #[test] + fn compatible_0rtt_integers() { + let mut tps_a = TransportParameters::default(); + const INTEGER_KEYS: &[TransportParameterId] = &[ + INITIAL_MAX_DATA, + INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, + INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, + INITIAL_MAX_STREAM_DATA_UNI, + INITIAL_MAX_STREAMS_BIDI, + INITIAL_MAX_STREAMS_UNI, + MAX_UDP_PAYLOAD_SIZE, + ]; + for i in INTEGER_KEYS { + tps_a.set(*i, TransportParameter::Integer(12)); + } + + let tps_b = tps_a.clone(); + assert!(tps_a.ok_for_0rtt(&tps_b)); + assert!(tps_b.ok_for_0rtt(&tps_a)); + + // For each integer key, increase the value by one. 
+ for i in INTEGER_KEYS { + let mut tps_b = tps_a.clone(); + tps_b.set(*i, TransportParameter::Integer(13)); + // If an increased value is remembered, then we can't attempt 0-RTT with these parameters. + assert!(!tps_a.ok_for_0rtt(&tps_b)); + // If an increased value is lower, then we can attempt 0-RTT with these parameters. + assert!(tps_b.ok_for_0rtt(&tps_a)); + } + + // Drop integer values and check. + for i in INTEGER_KEYS { + let mut tps_b = tps_a.clone(); + tps_b.remove(*i); + // A value that is missing from what is rememebered is OK. + assert!(tps_a.ok_for_0rtt(&tps_b)); + // A value that is rememebered, but not current is not OK. + assert!(!tps_b.ok_for_0rtt(&tps_a)); + } + } + + #[test] + fn active_connection_id_limit_lt_2_is_error() { + let mut tps = TransportParameters::default(); + + // Intentionally set an invalid value for the ACTIVE_CONNECTION_ID_LIMIT transport parameter. + tps.params + .insert(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(1)); + + let mut enc = Encoder::default(); + tps.encode(&mut enc); + + // When decoding a set of transport parameters with an invalid ACTIVE_CONNECTION_ID_LIMIT + // the result should be an error. + let invalid_decode_result = TransportParameters::decode(&mut enc.as_decoder()); + assert!(invalid_decode_result.is_err()); + } +} diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs new file mode 100644 index 0000000000..efd3b06069 --- /dev/null +++ b/third_party/rust/neqo-transport/src/tracking.rs @@ -0,0 +1,992 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Tracking of received packets and generating acks thereof. 
+ +#![deny(clippy::pedantic)] + +use std::cmp::min; +use std::collections::VecDeque; +use std::convert::TryFrom; +use std::ops::{Index, IndexMut}; +use std::time::{Duration, Instant}; + +use neqo_common::{qdebug, qinfo, qtrace, qwarn}; +use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL}; + +use crate::packet::{PacketBuilder, PacketNumber, PacketType}; +use crate::recovery::RecoveryToken; +use crate::stats::FrameStats; + +use smallvec::{smallvec, SmallVec}; + +// TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574 +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)] +pub enum PNSpace { + Initial, + Handshake, + ApplicationData, +} + +#[allow(clippy::use_self)] // https://github.com/rust-lang/rust-clippy/issues/3410 +impl PNSpace { + pub fn iter() -> impl Iterator<Item = &'static PNSpace> { + const SPACES: &[PNSpace] = &[ + PNSpace::Initial, + PNSpace::Handshake, + PNSpace::ApplicationData, + ]; + SPACES.iter() + } +} + +impl From<Epoch> for PNSpace { + fn from(epoch: Epoch) -> Self { + match epoch { + TLS_EPOCH_INITIAL => Self::Initial, + TLS_EPOCH_HANDSHAKE => Self::Handshake, + _ => Self::ApplicationData, + } + } +} + +impl From<PacketType> for PNSpace { + fn from(pt: PacketType) -> Self { + match pt { + PacketType::Initial => Self::Initial, + PacketType::Handshake => Self::Handshake, + PacketType::ZeroRtt | PacketType::Short => Self::ApplicationData, + _ => panic!("Attempted to get space from wrong packet type"), + } + } +} + +#[derive(Clone, Copy, Default)] +pub struct PNSpaceSet { + initial: bool, + handshake: bool, + application_data: bool, +} + +impl PNSpaceSet { + pub fn all() -> Self { + Self { + initial: true, + handshake: true, + application_data: true, + } + } +} + +impl Index<PNSpace> for PNSpaceSet { + type Output = bool; + + fn index(&self, space: PNSpace) -> &Self::Output { + match space { + PNSpace::Initial => &self.initial, + PNSpace::Handshake => &self.handshake, + 
PNSpace::ApplicationData => &self.application_data, + } + } +} + +impl IndexMut<PNSpace> for PNSpaceSet { + fn index_mut(&mut self, space: PNSpace) -> &mut Self::Output { + match space { + PNSpace::Initial => &mut self.initial, + PNSpace::Handshake => &mut self.handshake, + PNSpace::ApplicationData => &mut self.application_data, + } + } +} + +impl<T: AsRef<[PNSpace]>> From<T> for PNSpaceSet { + fn from(spaces: T) -> Self { + let mut v = Self::default(); + for sp in spaces.as_ref() { + v[*sp] = true; + } + v + } +} + +impl std::fmt::Debug for PNSpaceSet { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + f.write_str("(")?; + for sp in PNSpace::iter() { + if self[*sp] { + if !first { + f.write_str("+")?; + first = false; + } + std::fmt::Display::fmt(sp, f)?; + } + } + f.write_str(")") + } +} + +#[derive(Debug, Clone)] +pub struct SentPacket { + pub pt: PacketType, + pub pn: PacketNumber, + ack_eliciting: bool, + pub time_sent: Instant, + pub tokens: Vec<RecoveryToken>, + + time_declared_lost: Option<Instant>, + /// After a PTO, this is true when the packet has been released. + pto: bool, + + pub size: usize, +} + +impl SentPacket { + pub fn new( + pt: PacketType, + pn: PacketNumber, + time_sent: Instant, + ack_eliciting: bool, + tokens: Vec<RecoveryToken>, + size: usize, + ) -> Self { + Self { + pt, + pn, + time_sent, + ack_eliciting, + tokens, + time_declared_lost: None, + pto: false, + size, + } + } + + /// Returns `true` if the packet will elicit an ACK. + pub fn ack_eliciting(&self) -> bool { + self.ack_eliciting + } + + /// Whether the packet has been declared lost. + pub fn lost(&self) -> bool { + self.time_declared_lost.is_some() + } + + /// Whether accounting for the loss or acknowledgement in the + /// congestion controller is pending. + /// Returns `true` if the packet counts as being "in flight", + /// and has not previously been declared lost. 
+ /// Note that this should count packets that contain only ACK and PADDING, + /// but we don't send PADDING, so we don't track that. + pub fn cc_outstanding(&self) -> bool { + self.ack_eliciting() && !self.lost() + } + + /// Declare the packet as lost. Returns `true` if this is the first time. + pub fn declare_lost(&mut self, now: Instant) -> bool { + if self.lost() { + false + } else { + self.time_declared_lost = Some(now); + true + } + } + + /// Ask whether this tracked packet has been declared lost for long enough + /// that it can be expired and no longer tracked. + pub fn expired(&self, now: Instant, expiration_period: Duration) -> bool { + self.time_declared_lost + .map_or(false, |loss_time| (loss_time + expiration_period) <= now) + } + + /// Whether the packet contents were cleared out after a PTO. + pub fn pto_fired(&self) -> bool { + self.pto + } + + /// On PTO, we need to get the recovery tokens so that we can ensure that + /// the frames we sent can be sent again in the PTO packet(s). Do that just once. + pub fn pto(&mut self) -> bool { + if self.pto || self.lost() { + false + } else { + self.pto = true; + true + } + } +} + +impl std::fmt::Display for PNSpace { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(match self { + Self::Initial => "in", + Self::Handshake => "hs", + Self::ApplicationData => "ap", + }) + } +} + +/// `InsertionResult` tracks whether something was inserted for `PacketRange::add()`. +pub enum InsertionResult { + Largest, + Smallest, + NotInserted, +} + +#[derive(Clone, Debug, Default)] +pub struct PacketRange { + largest: PacketNumber, + smallest: PacketNumber, + ack_needed: bool, +} + +impl PacketRange { + /// Make a single packet range. + pub fn new(pn: PacketNumber) -> Self { + Self { + largest: pn, + smallest: pn, + ack_needed: true, + } + } + + /// Get the number of acknowleged packets in the range. 
+ pub fn len(&self) -> u64 { + self.largest - self.smallest + 1 + } + + /// Returns whether this needs to be sent. + pub fn ack_needed(&self) -> bool { + self.ack_needed + } + + /// Return whether the given number is in the range. + pub fn contains(&self, pn: PacketNumber) -> bool { + (pn >= self.smallest) && (pn <= self.largest) + } + + /// Maybe add a packet number to the range. Returns true if it was added + /// at the small end (which indicates that this might need merging with a + /// preceding range). + pub fn add(&mut self, pn: PacketNumber) -> InsertionResult { + assert!(!self.contains(pn)); + // Only insert if this is adjacent the current range. + if (self.largest + 1) == pn { + qtrace!([self], "Adding largest {}", pn); + self.largest += 1; + self.ack_needed = true; + InsertionResult::Largest + } else if self.smallest == (pn + 1) { + qtrace!([self], "Adding smallest {}", pn); + self.smallest -= 1; + self.ack_needed = true; + InsertionResult::Smallest + } else { + InsertionResult::NotInserted + } + } + + /// Maybe merge a higher-numbered range into this. + fn merge_larger(&mut self, other: &Self) { + qinfo!([self], "Merging {}", other); + // This only works if they are immediately adjacent. + assert_eq!(self.largest + 1, other.smallest); + + self.largest = other.largest; + self.ack_needed = self.ack_needed || other.ack_needed; + } + + /// When a packet containing the range `other` is acknowledged, + /// clear the `ack_needed` attribute on this. + /// Requires that other is equal to this, or a larger range. + pub fn acknowledged(&mut self, other: &Self) { + if (other.smallest <= self.smallest) && (other.largest >= self.largest) { + self.ack_needed = false; + } + } +} + +impl ::std::fmt::Display for PacketRange { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}->{}", self.largest, self.smallest) + } +} + +/// The ACK delay we use. 
+pub const ACK_DELAY: Duration = Duration::from_millis(20); // 20ms +pub const MAX_UNACKED_PKTS: usize = 1; +const MAX_TRACKED_RANGES: usize = 32; +const MAX_ACKS_PER_FRAME: usize = 32; + +/// A structure that tracks what was included in an ACK. +#[derive(Debug, Clone)] +pub struct AckToken { + space: PNSpace, + ranges: Vec<PacketRange>, +} + +/// A structure that tracks what packets have been received, +/// and what needs acknowledgement for a packet number space. +#[derive(Debug)] +pub struct RecvdPackets { + space: PNSpace, + ranges: VecDeque<PacketRange>, + /// The packet number of the lowest number packet that we are tracking. + min_tracked: PacketNumber, + /// The time we got the largest acknowledged. + largest_pn_time: Option<Instant>, + // The time that we should be sending an ACK. + ack_time: Option<Instant>, + pkts_since_last_ack: usize, +} + +impl RecvdPackets { + /// Make a new `RecvdPackets` for the indicated packet number space. + pub fn new(space: PNSpace) -> Self { + Self { + space, + ranges: VecDeque::new(), + min_tracked: 0, + largest_pn_time: None, + ack_time: None, + pkts_since_last_ack: 0, + } + } + + /// Get the time at which the next ACK should be sent. + pub fn ack_time(&self) -> Option<Instant> { + self.ack_time + } + + /// Returns true if an ACK frame should be sent now. + fn ack_now(&self, now: Instant) -> bool { + match self.ack_time { + Some(t) => t <= now, + None => false, + } + } + + // A simple addition of a packet number to the tracked set. + // This doesn't do a binary search on the assumption that + // new packets will generally be added to the start of the list. + fn add(&mut self, pn: PacketNumber) { + for i in 0..self.ranges.len() { + match self.ranges[i].add(pn) { + InsertionResult::Largest => return, + InsertionResult::Smallest => { + // If this was the smallest, it might have filled a gap. 
+ let nxt = i + 1; + if (nxt < self.ranges.len()) && (pn - 1 == self.ranges[nxt].largest) { + let larger = self.ranges.remove(i).unwrap(); + self.ranges[i].merge_larger(&larger); + } + return; + } + InsertionResult::NotInserted => { + if self.ranges[i].largest < pn { + self.ranges.insert(i, PacketRange::new(pn)); + return; + } + } + } + } + self.ranges.push_back(PacketRange::new(pn)); + } + + fn trim_ranges(&mut self) { + // Limit the number of ranges that are tracked to MAX_TRACKED_RANGES. + if self.ranges.len() > MAX_TRACKED_RANGES { + let oldest = self.ranges.pop_back().unwrap(); + if oldest.ack_needed { + qwarn!([self], "Dropping unacknowledged ACK range: {}", oldest); + // TODO(mt) Record some statistics about this so we can tune MAX_TRACKED_RANGES. + } else { + qdebug!([self], "Drop ACK range: {}", oldest); + } + self.min_tracked = oldest.largest + 1; + } + } + + /// Add the packet to the tracked set. + pub fn set_received(&mut self, now: Instant, pn: PacketNumber, ack_eliciting: bool) { + let next_in_order_pn = self.ranges.front().map_or(0, |pr| pr.largest + 1); + qdebug!( + [self], + "received {}, next in order pn: {}", + pn, + next_in_order_pn + ); + + self.add(pn); + self.trim_ranges(); + + // The new addition was the largest, so update the time we use for calculating ACK delay. + if pn >= next_in_order_pn { + self.largest_pn_time = Some(now); + } + + if ack_eliciting { + self.pkts_since_last_ack += 1; + + // Send ACK right away if out-of-order + // On the first in-order ack-eliciting packet since sending an ACK, + // set a delay. + // Count packets until we exceed MAX_UNACKED_PKTS, then remove the + // delay. 
+ if pn != next_in_order_pn { + self.ack_time = Some(now); + } else if self.space == PNSpace::ApplicationData { + match &mut self.pkts_since_last_ack { + 0 => unreachable!(), + 1 => self.ack_time = Some(now + ACK_DELAY), + x if *x > MAX_UNACKED_PKTS => self.ack_time = Some(now), + _ => debug_assert!(self.ack_time.is_some()), + } + } else { + self.ack_time = Some(now); + } + qdebug!([self], "Set ACK timer to {:?}", self.ack_time); + } + } + + /// Check if the packet is a duplicate. + pub fn is_duplicate(&self, pn: PacketNumber) -> bool { + if pn < self.min_tracked { + return true; + } + // TODO(mt) consider a binary search or early exit. + for range in &self.ranges { + if range.contains(pn) { + return true; + } + } + false + } + + /// Mark the given range as having been acknowledged. + pub fn acknowledged(&mut self, acked: &[PacketRange]) { + let mut range_iter = self.ranges.iter_mut(); + let mut cur = range_iter.next().expect("should have at least one range"); + for ack in acked { + while cur.smallest > ack.largest { + cur = match range_iter.next() { + Some(c) => c, + None => return, + }; + } + cur.acknowledged(&ack); + } + } + + /// Generate an ACK frame for this packet number space. + /// + /// Unlike other frame generators this doesn't modify the underlying instance + /// to track what has been sent. This only clears the delayed ACK timer. + /// + /// When sending ACKs, we want to always send the most recent ranges, + /// even if they have been sent in other packets. + /// + /// We don't send ranges that have been acknowledged, but they still need + /// to be tracked so that duplicates can be detected. + fn write_frame( + &mut self, + now: Instant, + builder: &mut PacketBuilder, + stats: &mut FrameStats, + ) -> Option<RecoveryToken> { + // The worst possible ACK frame, assuming only one range. + // Note that this assumes one byte for the type and count of extra ranges. 
+ const LONGEST_ACK_HEADER: usize = 1 + 8 + 8 + 1 + 8; + + // Check that we aren't delaying ACKs. + if !self.ack_now(now) { + return None; + } + + // Drop extra ACK ranges to fit the available space. Do this based on + // a worst-case estimate of frame size for simplicity. + // + // When congestion limited, ACK-only packets are 255 bytes at most + // (`recovery::ACK_ONLY_SIZE_LIMIT - 1`). This results in limiting the + // ranges to 13 here. + let max_ranges = if let Some(avail) = builder.remaining().checked_sub(LONGEST_ACK_HEADER) { + // Apply a hard maximum to keep plenty of space for other stuff. + min(1 + (avail / 16), MAX_ACKS_PER_FRAME) + } else { + return None; + }; + + let ranges = self + .ranges + .iter() + .filter(|r| r.ack_needed()) + .take(max_ranges) + .cloned() + .collect::<Vec<_>>(); + + builder.encode_varint(crate::frame::FRAME_TYPE_ACK); + let mut iter = ranges.iter(); + let first = match iter.next() { + Some(v) => v, + None => return None, // Nothing to send. + }; + builder.encode_varint(first.largest); + stats.largest_acknowledged = first.largest; + stats.ack += 1; + + let elapsed = now.duration_since(self.largest_pn_time.unwrap()); + // We use the default exponent, so delay is in multiples of 8 microseconds. + let ack_delay = u64::try_from(elapsed.as_micros() / 8).unwrap_or(u64::MAX); + let ack_delay = min((1 << 62) - 1, ack_delay); + builder.encode_varint(ack_delay); + builder.encode_varint(u64::try_from(ranges.len() - 1).unwrap()); // extra ranges + builder.encode_varint(first.len() - 1); // first range + + let mut last = first.smallest; + for r in iter { + // the difference must be at least 2 because 0-length gaps, + // (difference 1) are illegal. + builder.encode_varint(last - r.largest - 2); // Gap + builder.encode_varint(r.len() - 1); // Range + last = r.smallest; + } + + // We've sent an ACK, reset the timer. 
+ self.ack_time = None; + self.pkts_since_last_ack = 0; + + Some(RecoveryToken::Ack(AckToken { + space: self.space, + ranges, + })) + } +} + +impl ::std::fmt::Display for RecvdPackets { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Recvd-{}", self.space) + } +} + +#[derive(Debug)] +pub struct AckTracker { + /// This stores information about received packets in *reverse* order + /// by spaces. Why reverse? Because we ultimately only want to keep + /// `ApplicationData` and this allows us to drop other spaces easily. + spaces: SmallVec<[RecvdPackets; 1]>, +} + +impl AckTracker { + pub fn drop_space(&mut self, space: PNSpace) { + let sp = match space { + PNSpace::Initial => self.spaces.pop(), + PNSpace::Handshake => { + let sp = self.spaces.pop(); + self.spaces.shrink_to_fit(); + sp + } + PNSpace::ApplicationData => panic!("discarding application space"), + }; + assert_eq!(sp.unwrap().space, space, "dropping spaces out of order"); + } + + pub fn get_mut(&mut self, space: PNSpace) -> Option<&mut RecvdPackets> { + self.spaces.get_mut(match space { + PNSpace::ApplicationData => 0, + PNSpace::Handshake => 1, + PNSpace::Initial => 2, + }) + } + + /// Determine the earliest time that an ACK might be needed. + pub fn ack_time(&self, now: Instant) -> Option<Instant> { + if self.spaces.len() == 1 { + self.spaces[0].ack_time() + } else { + // Ignore any time that is in the past relative to `now`. + // That is something of a hack, but there are cases where we can't send ACK + // frames for all spaces, which can mean that one space is stuck in the past. + // That isn't a problem because we guarantee that earlier spaces will always + // be able to send ACK frames. 
+ self.spaces + .iter() + .filter_map(|recvd| recvd.ack_time().filter(|t| *t > now)) + .min() + } + } + + pub fn acked(&mut self, token: &AckToken) { + if let Some(space) = self.get_mut(token.space) { + space.acknowledged(&token.ranges); + } + } + + pub(crate) fn write_frame( + &mut self, + pn_space: PNSpace, + now: Instant, + builder: &mut PacketBuilder, + stats: &mut FrameStats, + ) -> Option<RecoveryToken> { + self.get_mut(pn_space) + .and_then(|space| space.write_frame(now, builder, stats)) + } +} + +impl Default for AckTracker { + fn default() -> Self { + Self { + spaces: smallvec![ + RecvdPackets::new(PNSpace::ApplicationData), + RecvdPackets::new(PNSpace::Handshake), + RecvdPackets::new(PNSpace::Initial), + ], + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + AckTracker, Duration, Instant, PNSpace, PNSpaceSet, RecoveryToken, RecvdPackets, ACK_DELAY, + MAX_TRACKED_RANGES, MAX_UNACKED_PKTS, + }; + use crate::frame::Frame; + use crate::packet::PacketBuilder; + use crate::stats::FrameStats; + use lazy_static::lazy_static; + use neqo_common::Encoder; + use std::collections::HashSet; + use std::convert::TryFrom; + + lazy_static! { + static ref NOW: Instant = Instant::now(); + } + + fn test_ack_range(pns: &[u64], nranges: usize) { + let mut rp = RecvdPackets::new(PNSpace::Initial); // Any space will do. + let mut packets = HashSet::new(); + + for pn in pns { + rp.set_received(*NOW, *pn, true); + packets.insert(*pn); + } + + assert_eq!(rp.ranges.len(), nranges); + + // Check that all these packets will be detected as duplicates. + for pn in pns { + assert!(rp.is_duplicate(*pn)); + } + + // Check that the ranges decrease monotonically and don't overlap. + let mut iter = rp.ranges.iter(); + let mut last = iter.next().expect("should have at least one"); + for n in iter { + assert!(n.largest + 1 < last.smallest); + last = n; + } + + // Check that the ranges include the right values. 
+ let mut in_ranges = HashSet::new(); + for range in &rp.ranges { + for included in range.smallest..=range.largest { + in_ranges.insert(included); + } + } + assert_eq!(packets, in_ranges); + } + + #[test] + fn pn0() { + test_ack_range(&[0], 1); + } + + #[test] + fn pn1() { + test_ack_range(&[1], 1); + } + + #[test] + fn two_ranges() { + test_ack_range(&[0, 1, 2, 5, 6, 7], 2); + } + + #[test] + fn fill_in_range() { + test_ack_range(&[0, 1, 2, 5, 6, 7, 3, 4], 1); + } + + #[test] + fn too_many_ranges() { + let mut rp = RecvdPackets::new(PNSpace::Initial); // Any space will do. + + // This will add one too many disjoint ranges. + for i in 0..=MAX_TRACKED_RANGES { + rp.set_received(*NOW, (i * 2) as u64, true); + } + + assert_eq!(rp.ranges.len(), MAX_TRACKED_RANGES); + assert_eq!(rp.ranges.back().unwrap().largest, 2); + + // Even though the range was dropped, we still consider it a duplicate. + assert!(rp.is_duplicate(0)); + assert!(!rp.is_duplicate(1)); + assert!(rp.is_duplicate(2)); + } + + #[test] + fn ack_delay() { + // Only application data packets are delayed. + let mut rp = RecvdPackets::new(PNSpace::ApplicationData); + assert!(rp.ack_time().is_none()); + assert!(!rp.ack_now(*NOW)); + + // Some packets won't cause an ACK to be needed. + let max_unacked = u64::try_from(MAX_UNACKED_PKTS).unwrap(); + for num in 0..max_unacked { + rp.set_received(*NOW, num, true); + assert_eq!(Some(*NOW + ACK_DELAY), rp.ack_time()); + assert!(!rp.ack_now(*NOW)); + assert!(rp.ack_now(*NOW + ACK_DELAY)); + } + + // Exceeding MAX_UNACKED_PKTS will move the ACK time to now. + rp.set_received(*NOW, max_unacked, true); + assert_eq!(Some(*NOW), rp.ack_time()); + assert!(rp.ack_now(*NOW)); + } + + #[test] + fn no_ack_delay() { + for space in &[PNSpace::Initial, PNSpace::Handshake] { + let mut rp = RecvdPackets::new(*space); + assert!(rp.ack_time().is_none()); + assert!(!rp.ack_now(*NOW)); + + // Any packet will be acknowledged straight away. 
+ rp.set_received(*NOW, 0, true); + assert_eq!(Some(*NOW), rp.ack_time()); + assert!(rp.ack_now(*NOW)); + } + } + + #[test] + fn ooo_no_ack_delay() { + for space in &[ + PNSpace::Initial, + PNSpace::Handshake, + PNSpace::ApplicationData, + ] { + let mut rp = RecvdPackets::new(*space); + assert!(rp.ack_time().is_none()); + assert!(!rp.ack_now(*NOW)); + + // Any OoO packet will be acknowledged straight away. + rp.set_received(*NOW, 3, true); + assert_eq!(Some(*NOW), rp.ack_time()); + assert!(rp.ack_now(*NOW)); + } + } + + #[test] + fn aggregate_ack_time() { + let mut tracker = AckTracker::default(); + // This packet won't trigger an ACK. + tracker + .get_mut(PNSpace::Handshake) + .unwrap() + .set_received(*NOW, 0, false); + assert_eq!(None, tracker.ack_time(*NOW)); + + // This should be delayed. + tracker + .get_mut(PNSpace::ApplicationData) + .unwrap() + .set_received(*NOW, 0, true); + assert_eq!(Some(*NOW + ACK_DELAY), tracker.ack_time(*NOW)); + + // This should move the time forward. + let later = *NOW + ACK_DELAY.checked_div(2).unwrap(); + tracker + .get_mut(PNSpace::Initial) + .unwrap() + .set_received(later, 0, true); + assert_eq!(Some(later), tracker.ack_time(*NOW)); + } + + #[test] + #[should_panic(expected = "discarding application space")] + fn drop_app() { + let mut tracker = AckTracker::default(); + tracker.drop_space(PNSpace::ApplicationData); + } + + #[test] + #[should_panic(expected = "dropping spaces out of order")] + fn drop_out_of_order() { + let mut tracker = AckTracker::default(); + tracker.drop_space(PNSpace::Handshake); + } + + #[test] + fn drop_spaces() { + let mut tracker = AckTracker::default(); + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + tracker + .get_mut(PNSpace::Initial) + .unwrap() + .set_received(*NOW, 0, true); + // The reference time for `ack_time` has to be in the past or we filter out the timer. 
+ assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some()); + let token = tracker.write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default(), + ); + assert!(token.is_some()); + + // Mark another packet as received so we have cause to send another ACK in that space. + tracker + .get_mut(PNSpace::Initial) + .unwrap() + .set_received(*NOW, 1, true); + assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some()); + + // Now drop that space. + tracker.drop_space(PNSpace::Initial); + + assert!(tracker.get_mut(PNSpace::Initial).is_none()); + assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_none()); + assert!(tracker + .write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default() + ) + .is_none()); + if let RecoveryToken::Ack(tok) = token.unwrap() { + tracker.acked(&tok); // Should be a noop. + } else { + panic!("not an ACK token"); + } + } + + #[test] + fn no_room_for_ack() { + let mut tracker = AckTracker::default(); + tracker + .get_mut(PNSpace::Initial) + .unwrap() + .set_received(*NOW, 0, true); + assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some()); + + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + builder.set_limit(10); + + let token = tracker.write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default(), + ); + assert!(token.is_none()); + assert_eq!(builder.len(), 1); // Only the short packet header has been added. 
+ } + + #[test] + fn no_room_for_extra_range() { + let mut tracker = AckTracker::default(); + tracker + .get_mut(PNSpace::Initial) + .unwrap() + .set_received(*NOW, 0, true); + tracker + .get_mut(PNSpace::Initial) + .unwrap() + .set_received(*NOW, 2, true); + assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some()); + + let mut builder = PacketBuilder::short(Encoder::new(), false, &[]); + builder.set_limit(32); + + let token = tracker.write_frame( + PNSpace::Initial, + *NOW, + &mut builder, + &mut FrameStats::default(), + ); + assert!(token.is_some()); + + let mut dec = builder.as_decoder(); + let _ = dec.decode_byte().unwrap(); // Skip the short header. + let frame = Frame::decode(&mut dec).unwrap(); + if let Frame::Ack { ack_ranges, .. } = frame { + assert_eq!(ack_ranges.len(), 0); + } else { + panic!("not an ACK!"); + } + } + + #[test] + fn ack_time_elapsed() { + let mut tracker = AckTracker::default(); + + // While we have multiple PN spaces, we ignore ACK timers from the past. + // Send out of order to cause the delayed ack timer to be set to `*NOW`. + tracker + .get_mut(PNSpace::ApplicationData) + .unwrap() + .set_received(*NOW, 3, true); + assert!(tracker.ack_time(*NOW + Duration::from_millis(1)).is_none()); + + // When we are reduced to one space, that filter is off. 
+ tracker.drop_space(PNSpace::Initial); + tracker.drop_space(PNSpace::Handshake); + assert_eq!( + tracker.ack_time(*NOW + Duration::from_millis(1)), + Some(*NOW) + ); + } + + #[test] + fn pnspaceset_default() { + let set = PNSpaceSet::default(); + assert!(!set[PNSpace::Initial]); + assert!(!set[PNSpace::Handshake]); + assert!(!set[PNSpace::ApplicationData]); + } + + #[test] + fn pnspaceset_from() { + let set = PNSpaceSet::from(&[PNSpace::Initial]); + assert!(set[PNSpace::Initial]); + assert!(!set[PNSpace::Handshake]); + assert!(!set[PNSpace::ApplicationData]); + + let set = PNSpaceSet::from(&[PNSpace::Handshake, PNSpace::Initial]); + assert!(set[PNSpace::Initial]); + assert!(set[PNSpace::Handshake]); + assert!(!set[PNSpace::ApplicationData]); + + let set = PNSpaceSet::from(&[PNSpace::ApplicationData, PNSpace::ApplicationData]); + assert!(!set[PNSpace::Initial]); + assert!(!set[PNSpace::Handshake]); + assert!(set[PNSpace::ApplicationData]); + } + + #[test] + fn pnspaceset_copy() { + let set = PNSpaceSet::from(&[PNSpace::Handshake, PNSpace::ApplicationData]); + let copy = set; + assert!(!copy[PNSpace::Initial]); + assert!(copy[PNSpace::Handshake]); + assert!(copy[PNSpace::ApplicationData]); + } +} |