diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/neqo-transport/src/cc | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-transport/src/cc')
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/classic_cc.rs | 1186 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/cubic.rs | 215 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/mod.rs | 87 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/new_reno.rs | 51 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/tests/cubic.rs | 333 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/tests/mod.rs | 7 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/src/cc/tests/new_reno.rs | 219 |
7 files changed, 2098 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs new file mode 100644 index 0000000000..6f4a01d795 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs @@ -0,0 +1,1186 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] + +use std::{ + cmp::{max, min}, + fmt::{self, Debug, Display}, + time::{Duration, Instant}, +}; + +use super::CongestionControl; +use crate::{ + cc::MAX_DATAGRAM_SIZE, + packet::PacketNumber, + qlog::{self, QlogMetric}, + rtt::RttEstimate, + sender::PACING_BURST_SIZE, + tracking::SentPacket, +}; +#[rustfmt::skip] // to keep `::` and thus prevent conflict with `crate::qlog` +use ::qlog::events::{quic::CongestionStateUpdated, EventData}; +use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace}; + +pub const CWND_INITIAL_PKTS: usize = 10; +pub const CWND_INITIAL: usize = const_min( + CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE, + const_max(2 * MAX_DATAGRAM_SIZE, 14720), +); +pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2; +const PERSISTENT_CONG_THRESH: u32 = 3; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + /// In either slow start or congestion avoidance, not recovery. + SlowStart, + /// In congestion avoidance. + CongestionAvoidance, + /// In a recovery period, but no packets have been sent yet. This is a + /// transient state because we want to exempt the first packet sent after + /// entering recovery from the congestion window. + RecoveryStart, + /// In a recovery period, with the first packet sent at this time. + Recovery, + /// Start of persistent congestion, which is transient, like `RecoveryStart`. 
+ PersistentCongestion, +} + +impl State { + pub fn in_recovery(self) -> bool { + matches!(self, Self::RecoveryStart | Self::Recovery) + } + + pub fn in_slow_start(self) -> bool { + self == Self::SlowStart + } + + /// These states are transient, we tell qlog on entry, but not on exit. + pub fn transient(self) -> bool { + matches!(self, Self::RecoveryStart | Self::PersistentCongestion) + } + + /// Update a transient state to the true state. + pub fn update(&mut self) { + *self = match self { + Self::PersistentCongestion => Self::SlowStart, + Self::RecoveryStart => Self::Recovery, + _ => unreachable!(), + }; + } + + pub fn to_qlog(self) -> &'static str { + match self { + Self::SlowStart | Self::PersistentCongestion => "slow_start", + Self::CongestionAvoidance => "congestion_avoidance", + Self::Recovery | Self::RecoveryStart => "recovery", + } + } +} + +pub trait WindowAdjustment: Display + Debug { + /// This is called when an ack is received. + /// The function calculates the amount of acked bytes congestion controller needs + /// to collect before increasing its cwnd by `MAX_DATAGRAM_SIZE`. + fn bytes_for_cwnd_increase( + &mut self, + curr_cwnd: usize, + new_acked_bytes: usize, + min_rtt: Duration, + now: Instant, + ) -> usize; + /// This function is called when a congestion event has beed detected and it + /// returns new (decreased) values of `curr_cwnd` and `acked_bytes`. + /// This value can be very small; the calling code is responsible for ensuring that the + /// congestion window doesn't drop below the minimum of `CWND_MIN`. + fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize); + /// Cubic needs this signal to reset its epoch. 
+ fn on_app_limited(&mut self); + #[cfg(test)] + fn last_max_cwnd(&self) -> f64; + #[cfg(test)] + fn set_last_max_cwnd(&mut self, last_max_cwnd: f64); +} + +#[derive(Debug)] +pub struct ClassicCongestionControl<T> { + cc_algorithm: T, + state: State, + congestion_window: usize, // = kInitialWindow + bytes_in_flight: usize, + acked_bytes: usize, + ssthresh: usize, + recovery_start: Option<PacketNumber>, + /// `first_app_limited` indicates the packet number after which the application might be + /// underutilizing the congestion window. When underutilizing the congestion window due to not + /// sending out enough data, we SHOULD NOT increase the congestion window.[1] Packets sent + /// before this point are deemed to fully utilize the congestion window and count towards + /// increasing the congestion window. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc9002#section-7.8 + first_app_limited: PacketNumber, + + qlog: NeqoQlog, +} + +impl<T: WindowAdjustment> Display for ClassicCongestionControl<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} CongCtrl {}/{} ssthresh {}", + self.cc_algorithm, self.bytes_in_flight, self.congestion_window, self.ssthresh, + )?; + Ok(()) + } +} + +impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> { + fn set_qlog(&mut self, qlog: NeqoQlog) { + self.qlog = qlog; + } + + #[must_use] + fn cwnd(&self) -> usize { + self.congestion_window + } + + #[must_use] + fn bytes_in_flight(&self) -> usize { + self.bytes_in_flight + } + + #[must_use] + fn cwnd_avail(&self) -> usize { + // BIF can be higher than cwnd due to PTO packets, which are sent even + // if avail is 0, but still count towards BIF. 
+ self.congestion_window.saturating_sub(self.bytes_in_flight) + } + + // Multi-packet version of OnPacketAckedCC + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant) { + let mut is_app_limited = true; + let mut new_acked = 0; + for pkt in acked_pkts { + qinfo!( + "packet_acked this={:p}, pn={}, ps={}, ignored={}, lost={}, rtt_est={:?}", + self, + pkt.pn, + pkt.size, + i32::from(!pkt.cc_outstanding()), + i32::from(pkt.lost()), + rtt_est, + ); + if !pkt.cc_outstanding() { + continue; + } + if pkt.pn < self.first_app_limited { + is_app_limited = false; + } + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + + if !self.after_recovery_start(pkt) { + // Do not increase congestion window for packets sent before + // recovery last started. + continue; + } + + if self.state.in_recovery() { + self.set_state(State::CongestionAvoidance); + qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]); + } + + new_acked += pkt.size; + } + + if is_app_limited { + self.cc_algorithm.on_app_limited(); + qinfo!("on_packets_acked this={:p}, limited=1, bytes_in_flight={}, cwnd={}, state={:?}, new_acked={}", self, self.bytes_in_flight, self.congestion_window, self.state, new_acked); + return; + } + + // Slow start, up to the slow start threshold. + if self.congestion_window < self.ssthresh { + self.acked_bytes += new_acked; + let increase = min(self.ssthresh - self.congestion_window, self.acked_bytes); + self.congestion_window += increase; + self.acked_bytes -= increase; + qinfo!([self], "slow start += {}", increase); + if self.congestion_window == self.ssthresh { + // This doesn't look like it is necessary, but it can happen + // after persistent congestion. + self.set_state(State::CongestionAvoidance); + } + } + // Congestion avoidance, above the slow start threshold. 
+ if self.congestion_window >= self.ssthresh { + // The following function return the amount acked bytes a controller needs + // to collect to be allowed to increase its cwnd by MAX_DATAGRAM_SIZE. + let bytes_for_increase = self.cc_algorithm.bytes_for_cwnd_increase( + self.congestion_window, + new_acked, + rtt_est.minimum(), + now, + ); + debug_assert!(bytes_for_increase > 0); + // If enough credit has been accumulated already, apply them gradually. + // If we have sudden increase in allowed rate we actually increase cwnd gently. + if self.acked_bytes >= bytes_for_increase { + self.acked_bytes = 0; + self.congestion_window += MAX_DATAGRAM_SIZE; + } + self.acked_bytes += new_acked; + if self.acked_bytes >= bytes_for_increase { + self.acked_bytes -= bytes_for_increase; + self.congestion_window += MAX_DATAGRAM_SIZE; // or is this the current MTU? + } + // The number of bytes we require can go down over time with Cubic. + // That might result in an excessive rate of increase, so limit the number of unused + // acknowledged bytes after increasing the congestion window twice. + self.acked_bytes = min(bytes_for_increase, self.acked_bytes); + } + qlog::metrics_updated( + &mut self.qlog, + &[ + QlogMetric::CongestionWindow(self.congestion_window), + QlogMetric::BytesInFlight(self.bytes_in_flight), + ], + ); + qinfo!([self], "on_packets_acked this={:p}, limited=0, bytes_in_flight={}, cwnd={}, state={:?}, new_acked={}", self, self.bytes_in_flight, self.congestion_window, self.state, new_acked); + } + + /// Update congestion controller state based on lost packets. 
+ fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) -> bool { + if lost_packets.is_empty() { + return false; + } + + for pkt in lost_packets.iter().filter(|pkt| pkt.cc_in_flight()) { + qinfo!( + "packet_lost this={:p}, pn={}, ps={}", + self, + pkt.pn, + pkt.size + ); + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + } + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + + let congestion = self.on_congestion_event(lost_packets.last().unwrap()); + let persistent_congestion = self.detect_persistent_congestion( + first_rtt_sample_time, + prev_largest_acked_sent, + pto, + lost_packets, + ); + qinfo!( + "on_packets_lost this={:p}, bytes_in_flight={}, cwnd={}, state={:?}", + self, + self.bytes_in_flight, + self.congestion_window, + self.state + ); + congestion || persistent_congestion + } + + fn discard(&mut self, pkt: &SentPacket) { + if pkt.cc_outstanding() { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + qtrace!([self], "Ignore pkt with size {}", pkt.size); + } + } + + fn discard_in_flight(&mut self) { + self.bytes_in_flight = 0; + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + } + + fn on_packet_sent(&mut self, pkt: &SentPacket) { + // Record the recovery time and exit any transient state. + if self.state.transient() { + self.recovery_start = Some(pkt.pn); + self.state.update(); + } + + if !pkt.cc_in_flight() { + return; + } + if !self.app_limited() { + // Given the current non-app-limited condition, we're fully utilizing the congestion + // window. Assume that all in-flight packets up to this one are NOT app-limited. + // However, subsequent packets might be app-limited. 
Set `first_app_limited` to the + // next packet number. + self.first_app_limited = pkt.pn + 1; + } + + self.bytes_in_flight += pkt.size; + qinfo!( + "packet_sent this={:p}, pn={}, ps={}", + self, + pkt.pn, + pkt.size + ); + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::BytesInFlight(self.bytes_in_flight)], + ); + } + + /// Whether a packet can be sent immediately as a result of entering recovery. + fn recovery_packet(&self) -> bool { + self.state == State::RecoveryStart + } +} + +impl<T: WindowAdjustment> ClassicCongestionControl<T> { + pub fn new(cc_algorithm: T) -> Self { + Self { + cc_algorithm, + state: State::SlowStart, + congestion_window: CWND_INITIAL, + bytes_in_flight: 0, + acked_bytes: 0, + ssthresh: usize::MAX, + recovery_start: None, + qlog: NeqoQlog::disabled(), + first_app_limited: 0, + } + } + + #[cfg(test)] + #[must_use] + pub fn ssthresh(&self) -> usize { + self.ssthresh + } + + #[cfg(test)] + pub fn set_ssthresh(&mut self, v: usize) { + self.ssthresh = v; + } + + #[cfg(test)] + pub fn last_max_cwnd(&self) -> f64 { + self.cc_algorithm.last_max_cwnd() + } + + #[cfg(test)] + pub fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) { + self.cc_algorithm.set_last_max_cwnd(last_max_cwnd); + } + + #[cfg(test)] + pub fn acked_bytes(&self) -> usize { + self.acked_bytes + } + + fn set_state(&mut self, state: State) { + if self.state != state { + qdebug!([self], "state -> {:?}", state); + let old_state = self.state; + self.qlog.add_event_data(|| { + // No need to tell qlog about exit from transient states. 
+ if old_state.transient() { + None + } else { + let ev_data = EventData::CongestionStateUpdated(CongestionStateUpdated { + old: Some(old_state.to_qlog().to_owned()), + new: state.to_qlog().to_owned(), + trigger: None, + }); + Some(ev_data) + } + }); + self.state = state; + } + } + + fn detect_persistent_congestion( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) -> bool { + if first_rtt_sample_time.is_none() { + return false; + } + + let pc_period = pto * PERSISTENT_CONG_THRESH; + + let mut last_pn = 1 << 62; // Impossibly large, but not enough to overflow. + let mut start = None; + + // Look for the first lost packet after the previous largest acknowledged. + // Ignore packets that weren't ack-eliciting for the start of this range. + // Also, make sure to ignore any packets sent before we got an RTT estimate + // as we might not have sent PTO packets soon enough after those. + let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent); + for p in lost_packets + .iter() + .skip_while(|p| Some(p.time_sent) < cutoff) + { + if p.pn != last_pn + 1 { + // Not a contiguous range of lost packets, start over. + start = None; + } + last_pn = p.pn; + if !p.cc_in_flight() { + // Not interesting, keep looking. 
+ continue; + } + if let Some(t) = start { + let elapsed = p + .time_sent + .checked_duration_since(t) + .expect("time is monotonic"); + if elapsed > pc_period { + qinfo!([self], "persistent congestion"); + self.congestion_window = CWND_MIN; + self.acked_bytes = 0; + self.set_state(State::PersistentCongestion); + qlog::metrics_updated( + &mut self.qlog, + &[QlogMetric::CongestionWindow(self.congestion_window)], + ); + return true; + } + } else { + start = Some(p.time_sent); + } + } + false + } + + #[must_use] + fn after_recovery_start(&mut self, packet: &SentPacket) -> bool { + // At the start of the recovery period, the state is transient and + // all packets will have been sent before recovery. When sending out + // the first packet we transition to the non-transient `Recovery` + // state and update the variable `self.recovery_start`. Before the + // first recovery, all packets were sent after the recovery event, + // allowing to reduce the cwnd on congestion events. + !self.state.transient() && self.recovery_start.map_or(true, |pn| packet.pn >= pn) + } + + /// Handle a congestion event. + /// Returns true if this was a true congestion event. + fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool { + // Start a new congestion event if lost packet was sent after the start + // of the previous congestion recovery period. 
+ if !self.after_recovery_start(last_packet) { + return false; + } + + let (cwnd, acked_bytes) = self + .cc_algorithm + .reduce_cwnd(self.congestion_window, self.acked_bytes); + self.congestion_window = max(cwnd, CWND_MIN); + self.acked_bytes = acked_bytes; + self.ssthresh = self.congestion_window; + qinfo!( + [self], + "Cong event -> recovery; cwnd {}, ssthresh {}", + self.congestion_window, + self.ssthresh + ); + qlog::metrics_updated( + &mut self.qlog, + &[ + QlogMetric::CongestionWindow(self.congestion_window), + QlogMetric::SsThresh(self.ssthresh), + QlogMetric::InRecovery(true), + ], + ); + self.set_state(State::RecoveryStart); + true + } + + #[allow(clippy::unused_self)] + fn app_limited(&self) -> bool { + if self.bytes_in_flight >= self.congestion_window { + false + } else if self.state.in_slow_start() { + // Allow for potential doubling of the congestion window during slow start. + // That is, the application might not have been able to send enough to respond + // to increases to the congestion window. + self.bytes_in_flight < self.congestion_window / 2 + } else { + // We're not limited if the in-flight data is within a single burst of the + // congestion window. 
+ (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + convert::TryFrom, + time::{Duration, Instant}, + }; + + use neqo_common::qinfo; + use test_fixture::now; + + use super::{ + ClassicCongestionControl, WindowAdjustment, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH, + }; + use crate::{ + cc::{ + classic_cc::State, + cubic::{Cubic, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR}, + new_reno::NewReno, + CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE, + }, + packet::{PacketNumber, PacketType}, + rtt::RttEstimate, + tracking::SentPacket, + }; + + const PTO: Duration = Duration::from_millis(100); + const RTT: Duration = Duration::from_millis(98); + const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98)); + const ZERO: Duration = Duration::from_secs(0); + const EPSILON: Duration = Duration::from_nanos(1); + const GAP: Duration = Duration::from_secs(1); + /// The largest time between packets without causing persistent congestion. + const SUB_PC: Duration = Duration::from_millis(100 * PERSISTENT_CONG_THRESH as u64); + /// The minimum time between packets to cause persistent congestion. + /// Uses an odd expression because `Duration` arithmetic isn't `const`. 
+ const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1); + + fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.ssthresh(), usize::MAX); + } + + fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL / 2); + assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); + } + + fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket { + SentPacket::new( + PacketType::Short, + pn, + now() + t, + ack_eliciting, + Vec::new(), + 100, + ) + } + + fn congestion_control(cc: CongestionControlAlgorithm) -> Box<dyn CongestionControl> { + match cc { + CongestionControlAlgorithm::NewReno => { + Box::new(ClassicCongestionControl::new(NewReno::default())) + } + CongestionControlAlgorithm::Cubic => { + Box::new(ClassicCongestionControl::new(Cubic::default())) + } + } + } + + fn persistent_congestion_by_algorithm( + cc_alg: CongestionControlAlgorithm, + reduced_cwnd: usize, + lost_packets: &[SentPacket], + persistent_expected: bool, + ) { + let mut cc = congestion_control(cc_alg); + for p in lost_packets { + cc.on_packet_sent(p); + } + + cc.on_packets_lost(Some(now()), None, PTO, lost_packets); + + let persistent = if cc.cwnd() == reduced_cwnd { + false + } else if cc.cwnd() == CWND_MIN { + true + } else { + panic!("unexpected cwnd"); + }; + assert_eq!(persistent, persistent_expected); + } + + fn persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool) { + persistent_congestion_by_algorithm( + CongestionControlAlgorithm::NewReno, + CWND_INITIAL / 2, + lost_packets, + persistent_expected, + ); + persistent_congestion_by_algorithm( + CongestionControlAlgorithm::Cubic, + CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + lost_packets, + persistent_expected, + ); + } + + /// A span of exactly the PC threshold only reduces the window on loss. 
+ #[test] + fn persistent_congestion_none() { + persistent_congestion(&[lost(1, true, ZERO), lost(2, true, SUB_PC)], false); + } + + /// A span of just more than the PC threshold causes persistent congestion. + #[test] + fn persistent_congestion_simple() { + persistent_congestion(&[lost(1, true, ZERO), lost(2, true, PC)], true); + } + + /// Both packets need to be ack-eliciting. + #[test] + fn persistent_congestion_non_ack_eliciting() { + persistent_congestion(&[lost(1, false, ZERO), lost(2, true, PC)], false); + persistent_congestion(&[lost(1, true, ZERO), lost(2, false, PC)], false); + } + + /// Packets in the middle, of any type, are OK. + #[test] + fn persistent_congestion_middle() { + persistent_congestion( + &[lost(1, true, ZERO), lost(2, false, RTT), lost(3, true, PC)], + true, + ); + persistent_congestion( + &[lost(1, true, ZERO), lost(2, true, RTT), lost(3, true, PC)], + true, + ); + } + + /// Leading non-ack-eliciting packets are skipped. + #[test] + fn persistent_congestion_leading_non_ack_eliciting() { + persistent_congestion( + &[lost(1, false, ZERO), lost(2, true, RTT), lost(3, true, PC)], + false, + ); + persistent_congestion( + &[ + lost(1, false, ZERO), + lost(2, true, RTT), + lost(3, true, RTT + PC), + ], + true, + ); + } + + /// Trailing non-ack-eliciting packets aren't relevant. + #[test] + fn persistent_congestion_trailing_non_ack_eliciting() { + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PC), + lost(3, false, PC + EPSILON), + ], + true, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, SUB_PC), + lost(3, false, PC), + ], + false, + ); + } + + /// Gaps in the middle, of any type, restart the count. 
+ #[test] + fn persistent_congestion_gap_reset() { + persistent_congestion(&[lost(1, true, ZERO), lost(3, true, PC)], false); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, RTT), + lost(4, true, GAP), + lost(5, true, GAP + PTO * PERSISTENT_CONG_THRESH), + ], + false, + ); + } + + /// A span either side of a gap will cause persistent congestion. + #[test] + fn persistent_congestion_gap_or() { + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PC), + lost(4, true, GAP), + lost(5, true, GAP + PTO), + ], + true, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, true, GAP), + lost(5, true, GAP + PC), + ], + true, + ); + } + + /// A gap only restarts after an ack-eliciting packet. + #[test] + fn persistent_congestion_gap_non_ack_eliciting() { + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + PC), + ], + false, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + RTT), + lost(6, true, GAP + RTT + SUB_PC), + ], + false, + ); + persistent_congestion( + &[ + lost(1, true, ZERO), + lost(2, true, PTO), + lost(4, false, GAP), + lost(5, true, GAP + RTT), + lost(6, true, GAP + RTT + PC), + ], + true, + ); + } + + /// Get a time, in multiples of `PTO`, relative to `now()`. + fn by_pto(t: u32) -> Instant { + now() + (PTO * t) + } + + /// Make packets that will be made lost. + /// `times` is the time of sending, in multiples of `PTO`, relative to `now()`. + fn make_lost(times: &[u32]) -> Vec<SentPacket> { + times + .iter() + .enumerate() + .map(|(i, &t)| { + SentPacket::new( + PacketType::Short, + u64::try_from(i).unwrap(), + by_pto(t), + true, + Vec::new(), + 1000, + ) + }) + .collect::<Vec<_>>() + } + + /// Call `detect_persistent_congestion` using times relative to now and the fixed PTO time. 
+ /// `last_ack` and `rtt_time` are times in multiples of `PTO`, relative to `now()`, + /// for the time of the largest acknowledged and the first RTT sample, respectively. + fn persistent_congestion_by_pto<T: WindowAdjustment>( + mut cc: ClassicCongestionControl<T>, + last_ack: u32, + rtt_time: u32, + lost: &[SentPacket], + ) -> bool { + assert_eq!(cc.cwnd(), CWND_INITIAL); + + let last_ack = Some(by_pto(last_ack)); + let rtt_time = Some(by_pto(rtt_time)); + + // Persistent congestion is never declared if the RTT time is `None`. + cc.detect_persistent_congestion(None, None, PTO, lost); + assert_eq!(cc.cwnd(), CWND_INITIAL); + cc.detect_persistent_congestion(None, last_ack, PTO, lost); + assert_eq!(cc.cwnd(), CWND_INITIAL); + + cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost); + cc.cwnd() == CWND_MIN + } + + /// No persistent congestion can be had if there are no lost packets. + #[test] + fn persistent_congestion_no_lost() { + let lost = make_lost(&[]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// No persistent congestion can be had if there is only one lost packet. + #[test] + fn persistent_congestion_one_lost() { + let lost = make_lost(&[1]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// Persistent congestion can't happen based on old packets. + #[test] + fn persistent_congestion_past() { + // Packets sent prior to either the last acknowledged or the first RTT + // sample are not considered. So 0 is ignored. 
+ let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 1, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 1, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 1, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 1, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 1, + 0, + &lost + )); + } + + /// Persistent congestion doesn't start unless the packet is ack-eliciting. + #[test] + fn persistent_congestion_ack_eliciting() { + let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + lost[0] = SentPacket::new( + lost[0].pt, + lost[0].pn, + lost[0].time_sent, + false, + Vec::new(), + lost[0].size, + ); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// Detect persistent congestion. Note that the first lost packet needs to have a time + /// greater than the previously acknowledged packet AND the first RTT sample. And the + /// difference in times needs to be greater than the persistent congestion threshold. 
+ #[test] + fn persistent_congestion_min() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + assert!(persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + assert!(persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + /// Make sure that not having a previous largest acknowledged also results + /// in detecting persistent congestion. (This is not expected to happen, but + /// the code permits it). + #[test] + fn persistent_congestion_no_prev_ack_newreno() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + let mut cc = ClassicCongestionControl::new(NewReno::default()); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); + assert_eq!(cc.cwnd(), CWND_MIN); + } + + #[test] + fn persistent_congestion_no_prev_ack_cubic() { + let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); + let mut cc = ClassicCongestionControl::new(Cubic::default()); + cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost); + assert_eq!(cc.cwnd(), CWND_MIN); + } + + /// The code asserts on ordering errors. + #[test] + #[should_panic(expected = "time is monotonic")] + fn persistent_congestion_unsorted_newreno() { + let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(NewReno::default()), + 0, + 0, + &lost + )); + } + + /// The code asserts on ordering errors. 
+ #[test] + #[should_panic(expected = "time is monotonic")] + fn persistent_congestion_unsorted_cubic() { + let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); + assert!(!persistent_congestion_by_pto( + ClassicCongestionControl::new(Cubic::default()), + 0, + 0, + &lost + )); + } + + #[test] + fn app_limited_slow_start() { + const BELOW_APP_LIMIT_PKTS: usize = 5; + const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let cwnd = cc.congestion_window; + let mut now = now(); + let mut next_pn = 0; + + // simulate packet bursts below app_limit + for packet_burst_size in 1..=BELOW_APP_LIMIT_PKTS { + // always stay below app_limit during sent. + let mut pkts = Vec::new(); + for _ in 0..packet_burst_size { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE); + now += RTT; + cc.on_packets_acked(&pkts, &RTT_ESTIMATE, now); + assert_eq!(cc.bytes_in_flight(), 0); + assert_eq!(cc.acked_bytes, 0); + assert_eq!(cwnd, cc.congestion_window); // CWND doesn't grow because we're app limited + } + + // Fully utilize the congestion window by sending enough packets to + // have `bytes_in_flight` above the `app_limited` threshold. 
+ let mut pkts = Vec::new(); + for _ in 0..ABOVE_APP_LIMIT_PKTS { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!( + cc.bytes_in_flight(), + ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE + ); + now += RTT; + // Check if congestion window gets increased for all packets currently in flight + for (i, pkt) in pkts.into_iter().enumerate() { + cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); + + assert_eq!( + cc.bytes_in_flight(), + (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE + ); + // increase acked_bytes with each packet + qinfo!("{} {}", cc.congestion_window, cwnd + i * MAX_DATAGRAM_SIZE); + assert_eq!(cc.congestion_window, cwnd + (i + 1) * MAX_DATAGRAM_SIZE); + assert_eq!(cc.acked_bytes, 0); + } + } + + #[test] + fn app_limited_congestion_avoidance() { + const CWND_PKTS_CA: usize = CWND_INITIAL_PKTS / 2; + const BELOW_APP_LIMIT_PKTS: usize = CWND_PKTS_CA - 2; + const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; + + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut now = now(); + + // Change state to congestion avoidance by introducing loss. + + let p_lost = SentPacket::new( + PacketType::Short, + 1, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p_lost); + cwnd_is_default(&cc); + now += PTO; + cc.on_packets_lost(Some(now), None, PTO, &[p_lost]); + cwnd_is_halved(&cc); + let p_not_lost = SentPacket::new( + PacketType::Short, + 2, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&p_not_lost); + now += RTT; + cc.on_packets_acked(&[p_not_lost], &RTT_ESTIMATE, now); + cwnd_is_halved(&cc); + // cc is app limited therefore cwnd in not increased. 
+ assert_eq!(cc.acked_bytes, 0); + + // Now we are in the congestion avoidance state. + assert_eq!(cc.state, State::CongestionAvoidance); + // simulate packet bursts below app_limit + let mut next_pn = 3; + for packet_burst_size in 1..=BELOW_APP_LIMIT_PKTS { + // always stay below app_limit during sent. + let mut pkts = Vec::new(); + for _ in 0..packet_burst_size { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE); + now += RTT; + for (i, pkt) in pkts.into_iter().enumerate() { + cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); + + assert_eq!( + cc.bytes_in_flight(), + (packet_burst_size - i - 1) * MAX_DATAGRAM_SIZE + ); + cwnd_is_halved(&cc); // CWND doesn't grow because we're app limited + assert_eq!(cc.acked_bytes, 0); + } + } + + // Fully utilize the congestion window by sending enough packets to + // have `bytes_in_flight` above the `app_limited` threshold. 
+ let mut pkts = Vec::new(); + for _ in 0..ABOVE_APP_LIMIT_PKTS { + let p = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + next_pn += 1; + cc.on_packet_sent(&p); + pkts.push(p); + } + assert_eq!( + cc.bytes_in_flight(), + ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE + ); + now += RTT; + let mut last_acked_bytes = 0; + // Check if congestion window gets increased for all packets currently in flight + for (i, pkt) in pkts.into_iter().enumerate() { + cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now); + + assert_eq!( + cc.bytes_in_flight(), + (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE + ); + // The cwnd doesn't increase, but the acked_bytes do, which will eventually lead to an + // increase, once the number of bytes reaches the necessary level + cwnd_is_halved(&cc); + // increase acked_bytes with each packet + assert_ne!(cc.acked_bytes, last_acked_bytes); + last_acked_bytes = cc.acked_bytes; + } + } +} diff --git a/third_party/rust/neqo-transport/src/cc/cubic.rs b/third_party/rust/neqo-transport/src/cc/cubic.rs new file mode 100644 index 0000000000..c04a29b443 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/cubic.rs @@ -0,0 +1,215 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![deny(clippy::pedantic)] + +use std::{ + convert::TryFrom, + fmt::{self, Display}, + time::{Duration, Instant}, +}; + +use neqo_common::qtrace; + +use crate::cc::{classic_cc::WindowAdjustment, MAX_DATAGRAM_SIZE_F64}; + +// CUBIC congestion control + +// C is a constant fixed to determine the aggressiveness of window +// increase in high BDP networks. 
+pub const CUBIC_C: f64 = 0.4; +pub const CUBIC_ALPHA: f64 = 3.0 * (1.0 - 0.7) / (1.0 + 0.7); + +// CUBIC_BETA = 0.7; +pub const CUBIC_BETA_USIZE_DIVIDEND: usize = 7; +pub const CUBIC_BETA_USIZE_DIVISOR: usize = 10; + +/// The fast convergence ratio further reduces the congestion window when a congestion event +/// occurs before reaching the previous `W_max`. +pub const CUBIC_FAST_CONVERGENCE: f64 = 0.85; // (1.0 + CUBIC_BETA) / 2.0; + +/// The minimum number of multiples of the datagram size that need +/// to be received to cause an increase in the congestion window. +/// When there is no loss, Cubic can return to exponential increase, but +/// this value reduces the magnitude of the resulting growth by a constant factor. +/// A value of 1.0 would mean a return to the rate used in slow start. +const EXPONENTIAL_GROWTH_REDUCTION: f64 = 2.0; + +/// Convert an integer congestion window value into a floating point value. +/// This has the effect of reducing larger values to `1<<53`. +/// If you have a congestion window that large, something is probably wrong. 
+fn convert_to_f64(v: usize) -> f64 {
+    let mut f_64 = f64::from(u32::try_from(v >> 21).unwrap_or(u32::MAX));
+    f_64 *= 2_097_152.0; // f_64 <<= 21
+    // Cannot fail: the mask keeps only the low 21 bits, which always fit in a u32.
+    f_64 += f64::from(u32::try_from(v & 0x1f_ffff).unwrap());
+    f_64
+}
+
+#[derive(Debug)]
+pub struct Cubic {
+    /// Congestion window size (in bytes) just before the last congestion event.
+    last_max_cwnd: f64,
+    /// Estimate of the window a TCP (Reno) flow would have; the TCP-friendly
+    /// region (RFC 8312, Section 4.2) keeps CUBIC at least this aggressive.
+    estimated_tcp_cwnd: f64,
+    /// The period K (in seconds) the cubic function needs to grow back to `w_max`.
+    k: f64,
+    /// The window size at which the cubic function plateaus (W_max in RFC 8312).
+    w_max: f64,
+    /// Start of the current congestion avoidance epoch; `None` outside of one.
+    ca_epoch_start: Option<Instant>,
+    /// Bytes acknowledged so far that count towards the TCP-friendly estimate.
+    tcp_acked_bytes: f64,
+}
+
+impl Default for Cubic {
+    fn default() -> Self {
+        Self {
+            last_max_cwnd: 0.0,
+            estimated_tcp_cwnd: 0.0,
+            k: 0.0,
+            w_max: 0.0,
+            ca_epoch_start: None,
+            tcp_acked_bytes: 0.0,
+        }
+    }
+}
+
+impl Display for Cubic {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Cubic [last_max_cwnd: {}, k: {}, w_max: {}, ca_epoch_start: {:?}]",
+            self.last_max_cwnd, self.k, self.w_max, self.ca_epoch_start
+        )?;
+        Ok(())
+    }
+}
+
+#[allow(clippy::doc_markdown)]
+impl Cubic {
+    /// The original equation is:
+    /// K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2 RFC8312)
+    /// W_max is the number of segments of the maximum segment size (MSS).
+    ///
+    /// K is actually the time that W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) would
+    /// take to increase to W_max. We use bytes, not MSS units, therefore this
+    /// equation will be: W_cubic(t) = C*MSS*(t-K)^3 + W_max.
+    ///
+    /// From that equation we can calculate K as:
+    /// K = cubic_root((W_max - W_cubic) / C / MSS);
+    fn calc_k(&self, curr_cwnd: f64) -> f64 {
+        ((self.w_max - curr_cwnd) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt()
+    }
+
+    /// W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+    /// t is relative to the start of the congestion avoidance phase and it is in seconds.
+ fn w_cubic(&self, t: f64) -> f64 { + CUBIC_C * (t - self.k).powi(3) * MAX_DATAGRAM_SIZE_F64 + self.w_max + } + + fn start_epoch(&mut self, curr_cwnd_f64: f64, new_acked_f64: f64, now: Instant) { + self.ca_epoch_start = Some(now); + // reset tcp_acked_bytes and estimated_tcp_cwnd; + self.tcp_acked_bytes = new_acked_f64; + self.estimated_tcp_cwnd = curr_cwnd_f64; + if self.last_max_cwnd <= curr_cwnd_f64 { + self.w_max = curr_cwnd_f64; + self.k = 0.0; + } else { + self.w_max = self.last_max_cwnd; + self.k = self.calc_k(curr_cwnd_f64); + } + qtrace!([self], "New epoch"); + } +} + +impl WindowAdjustment for Cubic { + // This is because of the cast in the last line from f64 to usize. + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + fn bytes_for_cwnd_increase( + &mut self, + curr_cwnd: usize, + new_acked_bytes: usize, + min_rtt: Duration, + now: Instant, + ) -> usize { + let curr_cwnd_f64 = convert_to_f64(curr_cwnd); + let new_acked_f64 = convert_to_f64(new_acked_bytes); + if self.ca_epoch_start.is_none() { + // This is a start of a new congestion avoidance phase. + self.start_epoch(curr_cwnd_f64, new_acked_f64, now); + } else { + self.tcp_acked_bytes += new_acked_f64; + } + + let time_ca = self + .ca_epoch_start + .map_or(min_rtt, |t| { + if now + min_rtt < t { + // This only happens when processing old packets + // that were saved and replayed with old timestamps. 
+                    min_rtt
+                } else {
+                    now + min_rtt - t
+                }
+            })
+            .as_secs_f64();
+        let target_cubic = self.w_cubic(time_ca);
+
+        // Update the TCP-friendly (Reno) estimate: grow `estimated_tcp_cwnd` by one
+        // datagram for every `estimated_tcp_cwnd / CUBIC_ALPHA` bytes acknowledged.
+        let tcp_cnt = self.estimated_tcp_cwnd / CUBIC_ALPHA;
+        while self.tcp_acked_bytes > tcp_cnt {
+            self.tcp_acked_bytes -= tcp_cnt;
+            self.estimated_tcp_cwnd += MAX_DATAGRAM_SIZE_F64;
+        }
+
+        // The target is the larger of the cubic window and the TCP-friendly estimate.
+        let target_cwnd = target_cubic.max(self.estimated_tcp_cwnd);
+
+        // Calculate the number of bytes that would need to be acknowledged for an increase
+        // of `MAX_DATAGRAM_SIZE` to match the increase of `target - cwnd / cwnd` as defined
+        // in the specification (Sections 4.4 and 4.5).
+        // The amount of data required therefore reduces asymptotically as the target increases.
+        // If the target is not significantly higher than the congestion window, require a very
+        // large amount of acknowledged data (effectively block increases).
+        let mut acked_to_increase =
+            MAX_DATAGRAM_SIZE_F64 * curr_cwnd_f64 / (target_cwnd - curr_cwnd_f64).max(1.0);
+
+        // Limit increase to max 1 MSS per EXPONENTIAL_GROWTH_REDUCTION ack packets.
+        // This effectively limits target_cwnd to (1 + 1 / EXPONENTIAL_GROWTH_REDUCTION) cwnd.
+        acked_to_increase =
+            acked_to_increase.max(EXPONENTIAL_GROWTH_REDUCTION * MAX_DATAGRAM_SIZE_F64);
+        acked_to_increase as usize
+    }
+
+    fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) {
+        let curr_cwnd_f64 = convert_to_f64(curr_cwnd);
+        // Fast Convergence
+        // If a congestion event occurs before the congestion window has regrown to the
+        // maximum it reached before the last congestion event, we reduce the remembered
+        // maximum (and thereby W_max) further.
+        // Check cwnd + MAX_DATAGRAM_SIZE instead of cwnd because with cwnd in bytes, cwnd may be
+        // slightly off.
+ self.last_max_cwnd = if curr_cwnd_f64 + MAX_DATAGRAM_SIZE_F64 < self.last_max_cwnd { + curr_cwnd_f64 * CUBIC_FAST_CONVERGENCE + } else { + curr_cwnd_f64 + }; + self.ca_epoch_start = None; + ( + curr_cwnd * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + acked_bytes * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR, + ) + } + + fn on_app_limited(&mut self) { + // Reset ca_epoch_start. Let it start again when the congestion controller + // exits the app-limited period. + self.ca_epoch_start = None; + } + + #[cfg(test)] + fn last_max_cwnd(&self) -> f64 { + self.last_max_cwnd + } + + #[cfg(test)] + fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) { + self.last_max_cwnd = last_max_cwnd; + } +} diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs new file mode 100644 index 0000000000..a1a43bd157 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/mod.rs @@ -0,0 +1,87 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +// Congestion control +#![deny(clippy::pedantic)] + +use std::{ + fmt::{Debug, Display}, + str::FromStr, + time::{Duration, Instant}, +}; + +use neqo_common::qlog::NeqoQlog; + +use crate::{path::PATH_MTU_V6, rtt::RttEstimate, tracking::SentPacket, Error}; + +mod classic_cc; +mod cubic; +mod new_reno; + +pub use classic_cc::ClassicCongestionControl; +#[cfg(test)] +pub use classic_cc::{CWND_INITIAL, CWND_INITIAL_PKTS, CWND_MIN}; +pub use cubic::Cubic; +pub use new_reno::NewReno; + +pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6; +#[allow(clippy::cast_precision_loss)] +pub const MAX_DATAGRAM_SIZE_F64: f64 = MAX_DATAGRAM_SIZE as f64; + +pub trait CongestionControl: Display + Debug { + fn set_qlog(&mut self, qlog: NeqoQlog); + + #[must_use] + fn cwnd(&self) -> usize; + + #[must_use] + fn bytes_in_flight(&self) -> usize; + + #[must_use] + fn cwnd_avail(&self) -> usize; + + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant); + + /// Returns true if the congestion window was reduced. + fn on_packets_lost( + &mut self, + first_rtt_sample_time: Option<Instant>, + prev_largest_acked_sent: Option<Instant>, + pto: Duration, + lost_packets: &[SentPacket], + ) -> bool; + + #[must_use] + fn recovery_packet(&self) -> bool; + + fn discard(&mut self, pkt: &SentPacket); + + fn on_packet_sent(&mut self, pkt: &SentPacket); + + fn discard_in_flight(&mut self); +} + +#[derive(Debug, Copy, Clone)] +pub enum CongestionControlAlgorithm { + NewReno, + Cubic, +} + +// A `FromStr` implementation so that this can be used in command-line interfaces. 
+impl FromStr for CongestionControlAlgorithm { + type Err = Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.trim().to_ascii_lowercase().as_str() { + "newreno" | "reno" => Ok(Self::NewReno), + "cubic" => Ok(Self::Cubic), + _ => Err(Error::InvalidInput), + } + } +} + +#[cfg(test)] +mod tests; diff --git a/third_party/rust/neqo-transport/src/cc/new_reno.rs b/third_party/rust/neqo-transport/src/cc/new_reno.rs new file mode 100644 index 0000000000..e51b3d6cc0 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/new_reno.rs @@ -0,0 +1,51 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] + +use std::{ + fmt::{self, Display}, + time::{Duration, Instant}, +}; + +use crate::cc::classic_cc::WindowAdjustment; + +#[derive(Debug, Default)] +pub struct NewReno {} + +impl Display for NewReno { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NewReno")?; + Ok(()) + } +} + +impl WindowAdjustment for NewReno { + fn bytes_for_cwnd_increase( + &mut self, + curr_cwnd: usize, + _new_acked_bytes: usize, + _min_rtt: Duration, + _now: Instant, + ) -> usize { + curr_cwnd + } + + fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) { + (curr_cwnd / 2, acked_bytes / 2) + } + + fn on_app_limited(&mut self) {} + + #[cfg(test)] + fn last_max_cwnd(&self) -> f64 { + 0.0 + } + + #[cfg(test)] + fn set_last_max_cwnd(&mut self, _last_max_cwnd: f64) {} +} diff --git a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs new file mode 100644 index 0000000000..0c82e47817 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs @@ -0,0 +1,333 @@ +// 
Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::cast_sign_loss)] + +use std::{ + convert::TryFrom, + ops::Sub, + time::{Duration, Instant}, +}; + +use test_fixture::now; + +use crate::{ + cc::{ + classic_cc::{ClassicCongestionControl, CWND_INITIAL}, + cubic::{ + Cubic, CUBIC_ALPHA, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR, CUBIC_C, + CUBIC_FAST_CONVERGENCE, + }, + CongestionControl, MAX_DATAGRAM_SIZE, MAX_DATAGRAM_SIZE_F64, + }, + packet::PacketType, + rtt::RttEstimate, + tracking::SentPacket, +}; + +const RTT: Duration = Duration::from_millis(100); +const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(100)); +const CWND_INITIAL_F64: f64 = 10.0 * MAX_DATAGRAM_SIZE_F64; +const CWND_INITIAL_10_F64: f64 = 10.0 * CWND_INITIAL_F64; +const CWND_INITIAL_10: usize = 10 * CWND_INITIAL; +const CWND_AFTER_LOSS: usize = CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR; +const CWND_AFTER_LOSS_SLOW_START: usize = + (CWND_INITIAL + MAX_DATAGRAM_SIZE) * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR; + +fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: Instant) -> u64 { + while cc.bytes_in_flight() < cc.cwnd() { + let sent = SentPacket::new( + PacketType::Short, + next_pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packet_sent(&sent); + next_pn += 1; + } + next_pn +} + +fn ack_packet(cc: &mut ClassicCongestionControl<Cubic>, pn: u64, now: Instant) { + let acked = SentPacket::new( + PacketType::Short, + pn, // pn + now, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // 
size + ); + cc.on_packets_acked(&[acked], &RTT_ESTIMATE, now); +} + +fn packet_lost(cc: &mut ClassicCongestionControl<Cubic>, pn: u64) { + const PTO: Duration = Duration::from_millis(120); + let p_lost = SentPacket::new( + PacketType::Short, + pn, // pn + now(), // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + cc.on_packets_lost(None, None, PTO, &[p_lost]); +} + +fn expected_tcp_acks(cwnd_rtt_start: usize) -> u64 { + (f64::from(i32::try_from(cwnd_rtt_start).unwrap()) / MAX_DATAGRAM_SIZE_F64 / CUBIC_ALPHA) + .round() as u64 +} + +#[test] +fn tcp_phase() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + // change to congestion avoidance state. + cubic.set_ssthresh(1); + + let mut now = now(); + let start_time = now; + // helper variables to remember the next packet number to be sent/acked. + let mut next_pn_send = 0; + let mut next_pn_ack = 0; + + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + + // This will start with TCP phase. + // in this phase cwnd is increase by CUBIC_ALPHA every RTT. We can look at it as + // increase of MAX_DATAGRAM_SIZE every 1 / CUBIC_ALPHA RTTs. + // The phase will end when cwnd calculated with cubic equation is equal to TCP estimate: + // CUBIC_C * (n * RTT / CUBIC_ALPHA)^3 * MAX_DATAGRAM_SIZE = n * MAX_DATAGRAM_SIZE + // from this n = sqrt(CUBIC_ALPHA^3/ (CUBIC_C * RTT^3)). + let num_tcp_increases = (CUBIC_ALPHA.powi(3) / (CUBIC_C * RTT.as_secs_f64().powi(3))) + .sqrt() + .floor() as u64; + + for _ in 0..num_tcp_increases { + let cwnd_rtt_start = cubic.cwnd(); + // Expected acks during a period of RTT / CUBIC_ALPHA. + let acks = expected_tcp_acks(cwnd_rtt_start); + // The time between acks if they are ideally paced over a RTT. 
+        let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap();
+
+        for _ in 0..acks {
+            now += time_increase;
+            ack_packet(&mut cubic, next_pn_ack, now);
+            next_pn_ack += 1;
+            next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+        }
+
+        assert_eq!(cubic.cwnd() - cwnd_rtt_start, MAX_DATAGRAM_SIZE);
+    }
+
+    // The next increase will be according to the cubic equation.
+
+    let cwnd_rtt_start = cubic.cwnd();
+    // cwnd_rtt_start has changed, therefore calculate a new time_increase (the time
+    // between acks if they are ideally paced over an RTT).
+    let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap();
+    let mut num_acks = 0; // count the number of acks until cwnd is increased by MAX_DATAGRAM_SIZE.
+
+    while cwnd_rtt_start == cubic.cwnd() {
+        num_acks += 1;
+        now += time_increase;
+        ack_packet(&mut cubic, next_pn_ack, now);
+        next_pn_ack += 1;
+        next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+    }
+
+    // Make sure that the increase is not according to the TCP equation, i.e., that it took
+    // less than RTT / CUBIC_ALPHA.
+    let expected_ack_tcp_increase = expected_tcp_acks(cwnd_rtt_start);
+    assert!(num_acks < expected_ack_tcp_increase);
+
+    // This first increase after a TCP phase may be shorter than what it would take by a regular
+    // cubic phase, because of the proper byte counting and the credit it already had before
+    // entering this phase. Therefore, we will perform another round and compare it to the
+    // expected increase using the cubic equation.
+
+    let cwnd_rtt_start_after_tcp = cubic.cwnd();
+    let elapsed_time = now - start_time;
+
+    // Calculate the new time_increase.
+    let time_increase = RTT / u32::try_from(cwnd_rtt_start_after_tcp / MAX_DATAGRAM_SIZE).unwrap();
+    let mut num_acks2 = 0; // count the number of acks until cwnd is increased by MAX_DATAGRAM_SIZE.
+
+    while cwnd_rtt_start_after_tcp == cubic.cwnd() {
+        num_acks2 += 1;
+        now += time_increase;
+        ack_packet(&mut cubic, next_pn_ack, now);
+        next_pn_ack += 1;
+        next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+    }
+
+    let expected_ack_tcp_increase2 = expected_tcp_acks(cwnd_rtt_start_after_tcp);
+    assert!(num_acks2 < expected_ack_tcp_increase2);
+
+    // The time needed to increase cwnd by MAX_DATAGRAM_SIZE using the cubic equation is
+    // calculated from: W_cubic(elapsed_time + t_to_increase) - W_cubic(elapsed_time) =
+    // MAX_DATAGRAM_SIZE => CUBIC_C * (elapsed_time + t_to_increase)^3 * MAX_DATAGRAM_SIZE +
+    // CWND_INITIAL - CUBIC_C * elapsed_time^3 * MAX_DATAGRAM_SIZE + CWND_INITIAL =
+    // MAX_DATAGRAM_SIZE => t_to_increase = cbrt((1 + CUBIC_C * elapsed_time^3) / CUBIC_C) -
+    // elapsed_time (t_to_increase is in seconds)
+    // The number of acks needed is t_to_increase / time_increase.
+    let expected_ack_cubic_increase =
+        ((((1.0 + CUBIC_C * (elapsed_time).as_secs_f64().powi(3)) / CUBIC_C).cbrt()
+            - elapsed_time.as_secs_f64())
+            / time_increase.as_secs_f64())
+        .ceil() as u64;
+    // num_acks2 is very close to the calculated value. The exact value is hard to calculate
+    // because of the proportional increase (i.e. curr_cwnd_f64 / (target - curr_cwnd_f64) *
+    // MAX_DATAGRAM_SIZE_F64) and the byte counting.
+    assert_eq!(num_acks2, expected_ack_cubic_increase + 2);
+}
+
+#[test]
+fn cubic_phase() {
+    let mut cubic = ClassicCongestionControl::new(Cubic::default());
+    // Set last_max_cwnd to a higher number to make sure that cc is in the cubic phase (cwnd is
+    // calculated by the cubic equation).
+    cubic.set_last_max_cwnd(CWND_INITIAL_10_F64);
+    // Set ssthresh to something small to make sure that cc is in the congestion avoidance phase.
+ cubic.set_ssthresh(1); + let mut now = now(); + let mut next_pn_send = 0; + let mut next_pn_ack = 0; + + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + + let k = ((CWND_INITIAL_10_F64 - CWND_INITIAL_F64) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt(); + let epoch_start = now; + + // The number of RTT until W_max is reached. + let num_rtts_w_max = (k / RTT.as_secs_f64()).round() as u64; + for _ in 0..num_rtts_w_max { + let cwnd_rtt_start = cubic.cwnd(); + // Expected acks + let acks = cwnd_rtt_start / MAX_DATAGRAM_SIZE; + let time_increase = RTT / u32::try_from(acks).unwrap(); + for _ in 0..acks { + now += time_increase; + ack_packet(&mut cubic, next_pn_ack, now); + next_pn_ack += 1; + next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now); + } + + let expected = + (CUBIC_C * ((now - epoch_start).as_secs_f64() - k).powi(3) * MAX_DATAGRAM_SIZE_F64 + + CWND_INITIAL_10_F64) + .round() as usize; + + assert_within(cubic.cwnd(), expected, MAX_DATAGRAM_SIZE); + } + assert_eq!(cubic.cwnd(), CWND_INITIAL_10); +} + +fn assert_within<T: Sub<Output = T> + PartialOrd + Copy>(value: T, expected: T, margin: T) { + if value >= expected { + assert!(value - expected < margin); + } else { + assert!(expected - value < margin); + } +} + +#[test] +fn congestion_event_slow_start() { + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + _ = fill_cwnd(&mut cubic, 0, now()); + ack_packet(&mut cubic, 0, now()); + + assert_within(cubic.last_max_cwnd(), 0.0, f64::EPSILON); + + // cwnd is increased by 1 in slow start phase, after an ack. + assert_eq!(cubic.cwnd(), CWND_INITIAL + MAX_DATAGRAM_SIZE); + + // Trigger a congestion_event in slow start phase + packet_lost(&mut cubic, 1); + + // last_max_cwnd is equal to cwnd before decrease. 
+    assert_within(
+        cubic.last_max_cwnd(),
+        CWND_INITIAL_F64 + MAX_DATAGRAM_SIZE_F64,
+        f64::EPSILON,
+    );
+    assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS_SLOW_START);
+}
+
+#[test]
+fn congestion_event_congestion_avoidance() {
+    let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+    // Set ssthresh to something small to make sure that cc is in the congestion avoidance phase.
+    cubic.set_ssthresh(1);
+
+    // Set last_max_cwnd to something smaller than cwnd so that the fast convergence is not
+    // triggered.
+    cubic.set_last_max_cwnd(3.0 * MAX_DATAGRAM_SIZE_F64);
+
+    _ = fill_cwnd(&mut cubic, 0, now());
+    ack_packet(&mut cubic, 0, now());
+
+    assert_eq!(cubic.cwnd(), CWND_INITIAL);
+
+    // Trigger a congestion_event in the congestion avoidance phase (ssthresh is 1 above).
+    packet_lost(&mut cubic, 1);
+
+    assert_within(cubic.last_max_cwnd(), CWND_INITIAL_F64, f64::EPSILON);
+    assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS);
+}
+
+#[test]
+fn congestion_event_congestion_avoidance_2() {
+    let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+    // Set ssthresh to something small to make sure that cc is in the congestion avoidance phase.
+    cubic.set_ssthresh(1);
+
+    // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered.
+    cubic.set_last_max_cwnd(CWND_INITIAL_10_F64);
+
+    _ = fill_cwnd(&mut cubic, 0, now());
+    ack_packet(&mut cubic, 0, now());
+
+    assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON);
+    assert_eq!(cubic.cwnd(), CWND_INITIAL);
+
+    // Trigger a congestion_event.
+ packet_lost(&mut cubic, 1); + + assert_within( + cubic.last_max_cwnd(), + CWND_INITIAL_F64 * CUBIC_FAST_CONVERGENCE, + f64::EPSILON, + ); + assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS); +} + +#[test] +fn congestion_event_congestion_avoidance_test_no_overflow() { + const PTO: Duration = Duration::from_millis(120); + let mut cubic = ClassicCongestionControl::new(Cubic::default()); + + // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. + cubic.set_ssthresh(1); + + // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered. + cubic.set_last_max_cwnd(CWND_INITIAL_10_F64); + + _ = fill_cwnd(&mut cubic, 0, now()); + ack_packet(&mut cubic, 1, now()); + + assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON); + assert_eq!(cubic.cwnd(), CWND_INITIAL); + + // Now ack packet that was send earlier. + ack_packet(&mut cubic, 0, now().checked_sub(PTO).unwrap()); +} diff --git a/third_party/rust/neqo-transport/src/cc/tests/mod.rs b/third_party/rust/neqo-transport/src/cc/tests/mod.rs new file mode 100644 index 0000000000..238a7ad012 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/tests/mod.rs @@ -0,0 +1,7 @@ +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod cubic; +mod new_reno; diff --git a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs new file mode 100644 index 0000000000..a73844a755 --- /dev/null +++ b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs @@ -0,0 +1,219 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +// Congestion control +#![deny(clippy::pedantic)] + +use std::time::Duration; + +use test_fixture::now; + +use crate::{ + cc::{ + new_reno::NewReno, ClassicCongestionControl, CongestionControl, CWND_INITIAL, + MAX_DATAGRAM_SIZE, + }, + packet::PacketType, + rtt::RttEstimate, + tracking::SentPacket, +}; + +const PTO: Duration = Duration::from_millis(100); +const RTT: Duration = Duration::from_millis(98); +const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98)); + +fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL); + assert_eq!(cc.ssthresh(), usize::MAX); +} + +fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) { + assert_eq!(cc.cwnd(), CWND_INITIAL / 2); + assert_eq!(cc.ssthresh(), CWND_INITIAL / 2); +} + +#[test] +fn issue_876() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let time_now = now(); + let time_before = time_now.checked_sub(Duration::from_millis(100)).unwrap(); + let time_after = time_now + Duration::from_millis(150); + + let sent_packets = &[ + SentPacket::new( + PacketType::Short, + 1, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 1, // size + ), + SentPacket::new( + PacketType::Short, + 2, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 2, // size + ), + SentPacket::new( + PacketType::Short, + 3, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 4, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 5, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + 
SentPacket::new( + PacketType::Short, + 6, // pn + time_before, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ), + SentPacket::new( + PacketType::Short, + 7, // pn + time_after, // time sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE - 3, // size + ), + ]; + + // Send some more packets so that the cc is not app-limited. + for p in &sent_packets[..6] { + cc.on_packet_sent(p); + } + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_default(&cc); + assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 3); + + cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[0..1]); + + // We are now in recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + + // Send a packet after recovery starts + cc.on_packet_sent(&sent_packets[6]); + assert!(!cc.recovery_packet()); + cwnd_is_halved(&cc); + assert_eq!(cc.acked_bytes(), 0); + assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 5); + + // and ack it. cwnd increases slightly + cc.on_packets_acked(&sent_packets[6..], &RTT_ESTIMATE, time_now); + assert_eq!(cc.acked_bytes(), sent_packets[6].size); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); + + // Packet from before is lost. Should not hurt cwnd. 
+ cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]); + assert!(!cc.recovery_packet()); + assert_eq!(cc.acked_bytes(), sent_packets[6].size); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE); +} + +#[test] +// https://github.com/mozilla/neqo/pull/1465 +fn issue_1465() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let mut pn = 0; + let mut now = now(); + let mut next_packet = |now| { + let p = SentPacket::new( + PacketType::Short, + pn, // pn + now, // time_sent + true, // ack eliciting + Vec::new(), // tokens + MAX_DATAGRAM_SIZE, // size + ); + pn += 1; + p + }; + let mut send_next = |cc: &mut ClassicCongestionControl<NewReno>, now| { + let p = next_packet(now); + cc.on_packet_sent(&p); + p + }; + + let p1 = send_next(&mut cc, now); + let p2 = send_next(&mut cc, now); + let p3 = send_next(&mut cc, now); + + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_default(&cc); + assert_eq!(cc.bytes_in_flight(), 3 * MAX_DATAGRAM_SIZE); + + // advance one rtt to detect lost packet there this simplifies the timers, because + // on_packet_loss would only be called after RTO, but that is not relevant to the problem + now += RTT; + cc.on_packets_lost(Some(now), None, PTO, &[p1]); + + // We are now in recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_halved(&cc); + assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE); + + // Don't reduce the cwnd again on second packet loss + cc.on_packets_lost(Some(now), None, PTO, &[p3]); + assert_eq!(cc.acked_bytes(), 0); + cwnd_is_halved(&cc); // still the same as after first packet loss + assert_eq!(cc.bytes_in_flight(), MAX_DATAGRAM_SIZE); + + // the acked packets before on_packet_sent were the cause of + // https://github.com/mozilla/neqo/pull/1465 + cc.on_packets_acked(&[p2], &RTT_ESTIMATE, now); + + assert_eq!(cc.bytes_in_flight(), 0); + + // send out recovery packet and get it acked to get out of recovery state + let p4 = 
send_next(&mut cc, now); + cc.on_packet_sent(&p4); + now += RTT; + cc.on_packets_acked(&[p4], &RTT_ESTIMATE, now); + + // do the same as in the first rtt but now the bug appears + let p5 = send_next(&mut cc, now); + let p6 = send_next(&mut cc, now); + now += RTT; + + let cur_cwnd = cc.cwnd(); + cc.on_packets_lost(Some(now), None, PTO, &[p5]); + + // go back into recovery + assert!(cc.recovery_packet()); + assert_eq!(cc.cwnd(), cur_cwnd / 2); + assert_eq!(cc.acked_bytes(), 0); + assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE); + + // this shouldn't introduce further cwnd reduction, but it did before https://github.com/mozilla/neqo/pull/1465 + cc.on_packets_lost(Some(now), None, PTO, &[p6]); + assert_eq!(cc.cwnd(), cur_cwnd / 2); +} |