diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-06-12 05:43:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-06-12 05:43:14 +0000 |
commit | 8dd16259287f58f9273002717ec4d27e97127719 (patch) | |
tree | 3863e62a53829a84037444beab3abd4ed9dfc7d0 /third_party/rust/neqo-transport/src | |
parent | Releasing progress-linux version 126.0.1-1~progress7.99u1. (diff) | |
download | firefox-8dd16259287f58f9273002717ec4d27e97127719.tar.xz firefox-8dd16259287f58f9273002717ec4d27e97127719.zip |
Merging upstream version 127.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-transport/src')
28 files changed, 1371 insertions, 422 deletions
diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs index f8bcee6722..6914e91f67 100644 --- a/third_party/rust/neqo-transport/src/cc/classic_cc.rs +++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs @@ -298,6 +298,14 @@ impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> { congestion || persistent_congestion } + /// Report received ECN CE mark(s) to the congestion controller as a + /// congestion event. + /// + /// See <https://datatracker.ietf.org/doc/html/rfc9002#section-b.7>. + fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool { + self.on_congestion_event(largest_acked_pkt) + } + fn discard(&mut self, pkt: &SentPacket) { if pkt.cc_outstanding() { assert!(self.bytes_in_flight >= pkt.size); @@ -488,8 +496,8 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> { /// Handle a congestion event. /// Returns true if this was a true congestion event. fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool { - // Start a new congestion event if lost packet was sent after the start - // of the previous congestion recovery period. + // Start a new congestion event if lost or ECN CE marked packet was sent + // after the start of the previous congestion recovery period. if !self.after_recovery_start(last_packet) { return false; } @@ -538,7 +546,7 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> { mod tests { use std::time::{Duration, Instant}; - use neqo_common::qinfo; + use neqo_common::{qinfo, IpTosEcn}; use test_fixture::now; use super::{ @@ -582,6 +590,7 @@ mod tests { SentPacket::new( PacketType::Short, pn, + IpTosEcn::default(), now() + t, ack_eliciting, Vec::new(), @@ -795,6 +804,7 @@ mod tests { SentPacket::new( PacketType::Short, u64::try_from(i).unwrap(), + IpTosEcn::default(), by_pto(t), true, Vec::new(), @@ -915,6 +925,7 @@ mod tests { lost[0] = SentPacket::new( lost[0].pt, lost[0].pn, + lost[0].ecn_mark, lost[0].time_sent, false, Vec::new(), @@ -1015,11 +1026,12 @@ mod tests { for _ in 0..packet_burst_size { let p = SentPacket::new( PacketType::Short, - next_pn, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + next_pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); next_pn += 1; cc.on_packet_sent(&p); @@ -1039,11 +1051,12 @@ mod tests { for _ in 0..ABOVE_APP_LIMIT_PKTS { let p = SentPacket::new( PacketType::Short, - next_pn, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + next_pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); next_pn += 1; cc.on_packet_sent(&p); @@ -1082,11 +1095,12 @@ mod tests { let p_lost = SentPacket::new( PacketType::Short, - 1, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + 1, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); cc.on_packet_sent(&p_lost); cwnd_is_default(&cc); @@ -1095,11 +1109,12 @@ mod tests { cwnd_is_halved(&cc); let p_not_lost = SentPacket::new( PacketType::Short, - 2, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + 2, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); cc.on_packet_sent(&p_not_lost); now += RTT; @@ -1118,11 +1133,12 @@ mod tests { for _ in 0..packet_burst_size { let p = SentPacket::new( PacketType::Short, - next_pn, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + next_pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); next_pn += 1; cc.on_packet_sent(&p); @@ -1148,11 +1164,12 @@ mod tests { for _ in 0..ABOVE_APP_LIMIT_PKTS { let p = SentPacket::new( PacketType::Short, - next_pn, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + next_pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); next_pn += 1; cc.on_packet_sent(&p); @@ -1180,4 +1197,26 @@ mod tests { last_acked_bytes = cc.acked_bytes; } } + + #[test] + fn ecn_ce() { + let mut cc = ClassicCongestionControl::new(NewReno::default()); + let p_ce = SentPacket::new( + PacketType::Short, + 1, + IpTosEcn::default(), + now(), + true, + Vec::new(), + MAX_DATAGRAM_SIZE, + ); + cc.on_packet_sent(&p_ce); + cwnd_is_default(&cc); + assert_eq!(cc.state, State::SlowStart); + + // Signal congestion (ECN CE) and thus change state to recovery start. + cc.on_ecn_ce_received(&p_ce); + cwnd_is_halved(&cc); + assert_eq!(cc.state, State::RecoveryStart); + } } diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs index 486d15e67e..2adffbc0c4 100644 --- a/third_party/rust/neqo-transport/src/cc/mod.rs +++ b/third_party/rust/neqo-transport/src/cc/mod.rs @@ -53,6 +53,9 @@ pub trait CongestionControl: Display + Debug { lost_packets: &[SentPacket], ) -> bool; + /// Returns true if the congestion window was reduced. + fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool; + #[must_use] fn recovery_packet(&self) -> bool; diff --git a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs index 2e0200fd6d..8ff591cb47 100644 --- a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs +++ b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs @@ -12,6 +12,7 @@ use std::{ time::{Duration, Instant}, }; +use neqo_common::IpTosEcn; use test_fixture::now; use crate::{ @@ -41,11 +42,12 @@ fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: In while cc.bytes_in_flight() < cc.cwnd() { let sent = SentPacket::new( PacketType::Short, - next_pn, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + next_pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); cc.on_packet_sent(&sent); next_pn += 1; @@ -56,11 +58,12 @@ fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: In fn ack_packet(cc: &mut ClassicCongestionControl<Cubic>, pn: u64, now: Instant) { let acked = SentPacket::new( PacketType::Short, - pn, // pn - now, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); cc.on_packets_acked(&[acked], &RTT_ESTIMATE, now); } @@ -69,11 +72,12 @@ fn packet_lost(cc: &mut ClassicCongestionControl<Cubic>, pn: u64) { const PTO: Duration = Duration::from_millis(120); let p_lost = SentPacket::new( PacketType::Short, - pn, // pn - now(), // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + pn, + IpTosEcn::default(), + now(), + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); cc.on_packets_lost(None, None, PTO, &[p_lost]); } diff --git a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs index 4cc20de5a7..0cc560bf2b 100644 --- a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs +++ b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs @@ -8,6 +8,7 @@ use std::time::Duration; +use neqo_common::IpTosEcn; use test_fixture::now; use crate::{ @@ -44,59 +45,66 @@ fn issue_876() { let sent_packets = &[ SentPacket::new( PacketType::Short, - 1, // pn - time_before, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE - 1, // size + 1, + IpTosEcn::default(), + time_before, + true, + Vec::new(), + MAX_DATAGRAM_SIZE - 1, ), SentPacket::new( PacketType::Short, - 2, // pn - time_before, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE - 2, // size + 2, + IpTosEcn::default(), + time_before, + true, + Vec::new(), + MAX_DATAGRAM_SIZE - 2, ), SentPacket::new( PacketType::Short, - 3, // pn - time_before, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + 3, + IpTosEcn::default(), + time_before, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ), SentPacket::new( PacketType::Short, - 4, // pn - time_before, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + 4, + IpTosEcn::default(), + time_before, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ), SentPacket::new( PacketType::Short, - 5, // pn - time_before, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + 5, + IpTosEcn::default(), + time_before, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ), SentPacket::new( PacketType::Short, - 6, // pn - time_before, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + 6, + IpTosEcn::default(), + time_before, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ), SentPacket::new( PacketType::Short, - 7, // pn - time_after, // time sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE - 3, // size + 7, + IpTosEcn::default(), + time_after, + true, + Vec::new(), + MAX_DATAGRAM_SIZE - 3, ), ]; @@ -146,11 +154,12 @@ fn issue_1465() { let mut next_packet = |now| { let p = SentPacket::new( PacketType::Short, - pn, // pn - now, // time_sent - true, // ack eliciting - Vec::new(), // tokens - MAX_DATAGRAM_SIZE, // size + pn, + IpTosEcn::default(), + now, + true, + Vec::new(), + MAX_DATAGRAM_SIZE, ); pn += 1; p diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs index 8522507a69..f955381414 100644 --- a/third_party/rust/neqo-transport/src/connection/mod.rs +++ b/third_party/rust/neqo-transport/src/connection/mod.rs @@ -19,7 +19,7 @@ use std::{ use neqo_common::{ event::Provider as EventProvider, hex, hex_snip_middle, hrtime, qdebug, qerror, qinfo, - qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, IpTos, Role, + qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role, }; use neqo_crypto::{ agent::CertificateInfo, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group, @@ -35,6 +35,7 @@ use crate::{ ConnectionIdRef, ConnectionIdStore, LOCAL_ACTIVE_CID_LIMIT, }, crypto::{Crypto, CryptoDxState, CryptoSpace}, + ecn::EcnCount, events::{ConnectionEvent, ConnectionEvents, OutgoingDatagramOutcome}, frame::{ CloseError, Frame, FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, @@ -46,7 +47,7 @@ use crate::{ quic_datagrams::{DatagramTracking, QuicDatagrams}, recovery::{LossRecovery, RecoveryToken, SendProfile}, recv_stream::RecvStreamStats, - rtt::GRANULARITY, + rtt::{RttEstimate, GRANULARITY}, send_stream::SendStream, stats::{Stats, StatsCell}, stream_id::StreamType, @@ -55,9 +56,9 @@ use crate::{ self, TransportParameter, TransportParameterId, TransportParameters, TransportParametersHandler, }, - tracking::{AckTracker, PacketNumberSpace, SentPacket}, + tracking::{AckTracker, PacketNumberSpace, RecvdPackets, SentPacket}, version::{Version, WireVersion}, - AppError, ConnectionError, Error, Res, StreamId, + AppError, CloseReason, Error, Res, StreamId, }; mod dump; @@ -291,7 +292,7 @@ impl Debug for Connection { "{:?} Connection: {:?} {:?}", self.role, self.state, - self.paths.primary_fallible() + self.paths.primary() ) } } @@ -591,7 +592,11 @@ impl Connection { fn make_resumption_token(&mut self) -> ResumptionToken { debug_assert_eq!(self.role, Role::Client); debug_assert!(self.crypto.has_resumption_token()); - let rtt = self.paths.primary().borrow().rtt().estimate(); + let rtt = self.paths.primary().map_or_else( + || RttEstimate::default().estimate(), + |p| p.borrow().rtt().estimate(), + ); + self.crypto .create_resumption_token( self.new_token.take_token(), @@ -610,11 +615,10 @@ impl Connection { /// a value of this approximate order. Don't use this for loss recovery, /// only use it where a more precise value is not important. fn pto(&self) -> Duration { - self.paths - .primary() - .borrow() - .rtt() - .pto(PacketNumberSpace::ApplicationData) + self.paths.primary().map_or_else( + || RttEstimate::default().pto(PacketNumberSpace::ApplicationData), + |p| p.borrow().rtt().pto(PacketNumberSpace::ApplicationData), + ) } fn create_resumption_token(&mut self, now: Instant) { @@ -746,7 +750,12 @@ impl Connection { if !init_token.is_empty() { self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec()); } - self.paths.primary().borrow_mut().rtt_mut().set_initial(rtt); + self.paths + .primary() + .ok_or(Error::InternalError)? + .borrow_mut() + .rtt_mut() + .set_initial(rtt); self.set_initial_limits(); // Start up TLS, which has the effect of setting up all the necessary // state for 0-RTT. This only stages the CRYPTO frames. @@ -786,7 +795,7 @@ impl Connection { // If we are able, also send a NEW_TOKEN frame. // This should be recording all remote addresses that are valid, // but there are just 0 or 1 in the current implementation. - if let Some(path) = self.paths.primary_fallible() { + if let Some(path) = self.paths.primary() { if let Some(token) = self .address_validation .generate_new_token(path.borrow().remote_address(), now) @@ -858,7 +867,7 @@ impl Connection { #[must_use] pub fn stats(&self) -> Stats { let mut v = self.stats.borrow().clone(); - if let Some(p) = self.paths.primary_fallible() { + if let Some(p) = self.paths.primary() { let p = p.borrow(); v.rtt = p.rtt().estimate(); v.rttvar = p.rtt().rttvar(); @@ -880,7 +889,7 @@ impl Connection { let msg = format!("{v:?}"); #[cfg(not(debug_assertions))] let msg = ""; - let error = ConnectionError::Transport(v.clone()); + let error = CloseReason::Transport(v.clone()); match &self.state { State::Closing { error: err, .. } | State::Draining { error: err, .. } @@ -895,14 +904,14 @@ impl Connection { State::WaitInitial => { // We don't have any state yet, so don't bother with // the closing state, just send one CONNECTION_CLOSE. - if let Some(path) = path.or_else(|| self.paths.primary_fallible()) { + if let Some(path) = path.or_else(|| self.paths.primary()) { self.state_signaling .close(path, error.clone(), frame_type, msg); } self.set_state(State::Closed(error)); } _ => { - if let Some(path) = path.or_else(|| self.paths.primary_fallible()) { + if let Some(path) = path.or_else(|| self.paths.primary()) { self.state_signaling .close(path, error.clone(), frame_type, msg); if matches!(v, Error::KeysExhausted) { @@ -951,9 +960,7 @@ impl Connection { let pto = self.pto(); if self.idle_timeout.expired(now, pto) { qinfo!([self], "idle timeout expired"); - self.set_state(State::Closed(ConnectionError::Transport( - Error::IdleTimeout, - ))); + self.set_state(State::Closed(CloseReason::Transport(Error::IdleTimeout))); return; } @@ -962,9 +969,11 @@ impl Connection { let res = self.crypto.states.check_key_update(now); self.absorb_error(now, res); - let lost = self.loss_recovery.timeout(&self.paths.primary(), now); - self.handle_lost_packets(&lost); - qlog::packets_lost(&mut self.qlog, &lost); + if let Some(path) = self.paths.primary() { + let lost = self.loss_recovery.timeout(&path, now); + self.handle_lost_packets(&lost); + qlog::packets_lost(&mut self.qlog, &lost); + } if self.release_resumption_token_timer.is_some() { self.create_resumption_token(now); @@ -1014,7 +1023,7 @@ impl Connection { delays.push(ack_time); } - if let Some(p) = self.paths.primary_fallible() { + if let Some(p) = self.paths.primary() { let path = p.borrow(); let rtt = path.rtt(); let pto = rtt.pto(PacketNumberSpace::ApplicationData); @@ -1102,7 +1111,15 @@ impl Connection { self.input(d, now, now); self.process_saved(now); } - self.process_output(now) + #[allow(clippy::let_and_return)] + let output = self.process_output(now); + #[cfg(all(feature = "build-fuzzing-corpus", test))] + if self.test_frame_writer.is_none() { + if let Some(d) = output.clone().dgram() { + neqo_common::write_item_to_fuzzing_corpus("packet", &d); + } + } + output } fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) { @@ -1123,7 +1140,13 @@ impl Connection { } // At this point, we should only have the connection ID that we generated. // Update to the one that the server prefers. - let path = self.paths.primary(); + let Some(path) = self.paths.primary() else { + self.stats + .borrow_mut() + .pkt_dropped("Retry without an existing path"); + return; + }; + path.borrow_mut().set_remote_cid(packet.scid()); let retry_scid = ConnectionId::from(packet.scid()); @@ -1151,8 +1174,9 @@ impl Connection { fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) { if self.crypto.discard(space) { qdebug!([self], "Drop packet number space {}", space); - let primary = self.paths.primary(); - self.loss_recovery.discard(&primary, space, now); + if let Some(path) = self.paths.primary() { + self.loss_recovery.discard(&path, space, now); + } self.acks.drop_space(space); } } @@ -1180,7 +1204,7 @@ impl Connection { qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..])); self.state_signaling.reset(); self.set_state(State::Draining { - error: ConnectionError::Transport(Error::StatelessReset), + error: CloseReason::Transport(Error::StatelessReset), timeout: self.get_closing_period_time(now), }); Err(Error::StatelessReset) @@ -1227,8 +1251,9 @@ impl Connection { assert_ne!(self.version, version); qinfo!([self], "Version negotiation: trying {:?}", version); - let local_addr = self.paths.primary().borrow().local_address(); - let remote_addr = self.paths.primary().borrow().remote_address(); + let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; + let local_addr = path.borrow().local_address(); + let remote_addr = path.borrow().remote_address(); let conn_params = self .conn_params .clone() @@ -1256,7 +1281,7 @@ impl Connection { } else { qinfo!([self], "Version negotiation: failed with {:?}", supported); // This error goes straight to closed. - self.set_state(State::Closed(ConnectionError::Transport( + self.set_state(State::Closed(CloseReason::Transport( Error::VersionNegotiation, ))); Err(Error::VersionNegotiation) @@ -1417,6 +1442,13 @@ impl Connection { migrate: bool, now: Instant, ) { + let space = PacketNumberSpace::from(packet.packet_type()); + if let Some(space) = self.acks.get_mut(space) { + *space.ecn_marks() += d.tos().into(); + } else { + qtrace!("Not tracking ECN for dropped packet number space"); + } + if self.state == State::WaitInitial { self.start_handshake(path, packet, now); } @@ -1491,6 +1523,16 @@ impl Connection { d.tos(), ); + #[cfg(feature = "build-fuzzing-corpus")] + if packet.packet_type() == PacketType::Initial { + let target = if self.role == Role::Client { + "server_initial" + } else { + "client_initial" + }; + neqo_common::write_item_to_fuzzing_corpus(target, &payload[..]); + } + qlog::packet_received(&mut self.qlog, &packet, &payload); let space = PacketNumberSpace::from(payload.packet_type()); if self.acks.get_mut(space).unwrap().is_duplicate(payload.pn()) { @@ -1562,7 +1604,11 @@ impl Connection { let mut probing = true; let mut d = Decoder::from(&packet[..]); while d.remaining() > 0 { + #[cfg(feature = "build-fuzzing-corpus")] + let pos = d.offset(); let f = Frame::decode(&mut d)?; + #[cfg(feature = "build-fuzzing-corpus")] + neqo_common::write_item_to_fuzzing_corpus("frame", &packet[pos..d.offset()]); ack_eliciting |= f.ack_eliciting(); probing &= f.path_probing(); let t = f.get_type(); @@ -1623,10 +1669,15 @@ impl Connection { if let Some(cid) = self.connection_ids.next() { self.paths.make_permanent(path, None, cid); Ok(()) - } else if self.paths.primary().borrow().remote_cid().is_empty() { - self.paths - .make_permanent(path, None, ConnectionIdEntry::empty_remote()); - Ok(()) + } else if let Some(primary) = self.paths.primary() { + if primary.borrow().remote_cid().is_empty() { + self.paths + .make_permanent(path, None, ConnectionIdEntry::empty_remote()); + Ok(()) + } else { + qtrace!([self], "Unable to make path permanent: {}", path.borrow()); + Err(Error::InvalidMigration) + } } else { qtrace!([self], "Unable to make path permanent: {}", path.borrow()); Err(Error::InvalidMigration) @@ -1719,8 +1770,10 @@ impl Connection { // Pointless migration is pointless. return Err(Error::InvalidMigration); } - let local = local.unwrap_or_else(|| self.paths.primary().borrow().local_address()); - let remote = remote.unwrap_or_else(|| self.paths.primary().borrow().remote_address()); + + let path = self.paths.primary().ok_or(Error::InvalidMigration)?; + let local = local.unwrap_or_else(|| path.borrow().local_address()); + let remote = remote.unwrap_or_else(|| path.borrow().remote_address()); if mem::discriminant(&local.ip()) != mem::discriminant(&remote.ip()) { // Can't mix address families. @@ -1773,7 +1826,12 @@ impl Connection { // has to use the existing address. So only pay attention to a preferred // address from the same family as is currently in use. More thought will // be needed to work out how to get addresses from a different family. - let prev = self.paths.primary().borrow().remote_address(); + let prev = self + .paths + .primary() + .ok_or(Error::NoAvailablePath)? + .borrow() + .remote_address(); let remote = match prev.ip() { IpAddr::V4(_) => addr.ipv4().map(SocketAddr::V4), IpAddr::V6(_) => addr.ipv6().map(SocketAddr::V6), @@ -1937,20 +1995,15 @@ impl Connection { } } - self.streams - .write_frames(TransmissionPriority::Critical, builder, tokens, frame_stats); - if builder.is_full() { - return; - } - - self.streams.write_frames( + for prio in [ + TransmissionPriority::Critical, TransmissionPriority::Important, - builder, - tokens, - frame_stats, - ); - if builder.is_full() { - return; + ] { + self.streams + .write_frames(prio, builder, tokens, frame_stats); + if builder.is_full() { + return; + } } // NEW_CONNECTION_ID, RETIRE_CONNECTION_ID, and ACK_FREQUENCY. @@ -1958,21 +2011,18 @@ impl Connection { if builder.is_full() { return; } - self.paths.write_frames(builder, tokens, frame_stats); - if builder.is_full() { - return; - } - self.streams - .write_frames(TransmissionPriority::High, builder, tokens, frame_stats); + self.paths.write_frames(builder, tokens, frame_stats); if builder.is_full() { return; } - self.streams - .write_frames(TransmissionPriority::Normal, builder, tokens, frame_stats); - if builder.is_full() { - return; + for prio in [TransmissionPriority::High, TransmissionPriority::Normal] { + self.streams + .write_frames(prio, builder, tokens, &mut stats.frame_tx); + if builder.is_full() { + return; + } } // Datagrams are best-effort and unreliable. Let streams starve them for now. @@ -1981,9 +2031,9 @@ impl Connection { return; } - let frame_stats = &mut stats.frame_tx; // CRYPTO here only includes NewSessionTicket, plus NEW_TOKEN. // Both of these are only used for resumption and so can be relatively low priority. + let frame_stats = &mut stats.frame_tx; self.crypto.write_frame( PacketNumberSpace::ApplicationData, builder, @@ -1993,6 +2043,7 @@ impl Connection { if builder.is_full() { return; } + self.new_token.write_frames(builder, tokens, frame_stats); if builder.is_full() { return; @@ -2002,10 +2053,8 @@ impl Connection { .write_frames(TransmissionPriority::Low, builder, tokens, frame_stats); #[cfg(test)] - { - if let Some(w) = &mut self.test_frame_writer { - w.write_frames(builder); - } + if let Some(w) = &mut self.test_frame_writer { + w.write_frames(builder); } } @@ -2138,6 +2187,40 @@ impl Connection { (tokens, ack_eliciting, padded) } + fn write_closing_frames( + &mut self, + close: &ClosingFrame, + builder: &mut PacketBuilder, + space: PacketNumberSpace, + now: Instant, + path: &PathRef, + tokens: &mut Vec<RecoveryToken>, + ) { + if builder.remaining() > ClosingFrame::MIN_LENGTH + RecvdPackets::USEFUL_ACK_LEN { + // Include an ACK frame with the CONNECTION_CLOSE. + let limit = builder.limit(); + builder.set_limit(limit - ClosingFrame::MIN_LENGTH); + self.acks.immediate_ack(now); + self.acks.write_frame( + space, + now, + path.borrow().rtt().estimate(), + builder, + tokens, + &mut self.stats.borrow_mut().frame_tx, + ); + builder.set_limit(limit); + } + // CloseReason::Application is only allowed at 1RTT. + let sanitized = if space == PacketNumberSpace::ApplicationData { + None + } else { + close.sanitize() + }; + sanitized.as_ref().unwrap_or(close).write_frame(builder); + self.stats.borrow_mut().frame_tx.connection_close += 1; + } + /// Build a datagram, possibly from multiple packets (for different PN /// spaces) and each containing 1+ frames. #[allow(clippy::too_many_lines)] // Yeah, that's just the way it is. @@ -2201,17 +2284,7 @@ impl Connection { let payload_start = builder.len(); let (mut tokens, mut ack_eliciting, mut padded) = (Vec::new(), false, false); if let Some(ref close) = closing_frame { - // ConnectionError::Application is only allowed at 1RTT. - let sanitized = if *space == PacketNumberSpace::ApplicationData { - None - } else { - close.sanitize() - }; - sanitized - .as_ref() - .unwrap_or(close) - .write_frame(&mut builder); - self.stats.borrow_mut().frame_tx.connection_close += 1; + self.write_closing_frames(close, &mut builder, *space, now, path, &mut tokens); } else { (tokens, ack_eliciting, padded) = self.write_frames(path, *space, &profile, &mut builder, now); @@ -2229,7 +2302,7 @@ impl Connection { pt, pn, &builder.as_ref()[payload_start..], - IpTos::default(), // TODO: set from path + path.borrow().tos(), ); qlog::packet_sent( &mut self.qlog, @@ -2251,6 +2324,7 @@ impl Connection { let sent = SentPacket::new( pt, pn, + path.borrow().tos().into(), now, ack_eliciting, tokens, @@ -2303,7 +2377,7 @@ impl Connection { self.loss_recovery.on_packet_sent(path, initial); } path.borrow_mut().add_sent(packets.len()); - Ok(SendOption::Yes(path.borrow().datagram(packets))) + Ok(SendOption::Yes(path.borrow_mut().datagram(packets))) } } @@ -2330,7 +2404,9 @@ impl Connection { fn client_start(&mut self, now: Instant) -> Res<()> { qdebug!([self], "client_start"); debug_assert_eq!(self.role, Role::Client); - qlog::client_connection_started(&mut self.qlog, &self.paths.primary()); + if let Some(path) = self.paths.primary() { + qlog::client_connection_started(&mut self.qlog, &path); + } qlog::client_version_information_initiated(&mut self.qlog, self.conn_params.get_versions()); self.handshake(now, self.version, PacketNumberSpace::Initial, None)?; @@ -2351,9 +2427,9 @@ impl Connection { /// Close the connection. pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) { - let error = ConnectionError::Application(app_error); + let error = CloseReason::Application(app_error); let timeout = self.get_closing_period_time(now); - if let Some(path) = self.paths.primary_fallible() { + if let Some(path) = self.paths.primary() { self.state_signaling.close(path, error.clone(), 0, msg); self.set_state(State::Closing { error, timeout }); } else { @@ -2411,10 +2487,8 @@ impl Connection { // That's OK, they can try guessing this. ConnectionIdEntry::random_srt() }; - self.paths - .primary() - .borrow_mut() - .set_reset_token(reset_token); + let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; + path.borrow_mut().set_reset_token(reset_token); let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY)); let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) { @@ -2426,11 +2500,8 @@ impl Connection { } else { None }; - self.paths.primary().borrow_mut().set_ack_delay( - max_ad, - min_ad, - self.conn_params.get_ack_ratio(), - ); + path.borrow_mut() + .set_ack_delay(max_ad, min_ad, self.conn_params.get_ack_ratio()); let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT); self.cid_manager.set_limit(max_active_cids); @@ -2673,10 +2744,18 @@ impl Connection { ack_delay, first_ack_range, ack_ranges, + ecn_count, } => { let ranges = Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?; - self.handle_ack(space, largest_acknowledged, ranges, ack_delay, now); + self.handle_ack( + space, + largest_acknowledged, + ranges, + ecn_count, + ack_delay, + now, + ); } Frame::Crypto { offset, data } => { qtrace!( @@ -2747,7 +2826,6 @@ impl Connection { reason_phrase, } => { self.stats.borrow_mut().frame_rx.connection_close += 1; - let reason_phrase = String::from_utf8_lossy(&reason_phrase); qinfo!( [self], "ConnectionClose received. Error code: {:?} frame type {:x} reason {}", @@ -2768,7 +2846,7 @@ impl Connection { FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT, ) }; - let error = ConnectionError::Transport(detail); + let error = CloseReason::Transport(detail); self.state_signaling .drain(Rc::clone(path), error.clone(), frame_type, ""); self.set_state(State::Draining { @@ -2853,6 +2931,7 @@ impl Connection { space: PacketNumberSpace, largest_acknowledged: u64, ack_ranges: R, + ack_ecn: Option<EcnCount>, ack_delay: u64, now: Instant, ) where @@ -2861,11 +2940,15 @@ impl Connection { { qdebug!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges); + let Some(path) = self.paths.primary() else { + return; + }; let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received( - &self.paths.primary(), + &path, space, largest_acknowledged, ack_ranges, + ack_ecn, self.decode_ack_delay(ack_delay), now, ); @@ -2903,8 +2986,10 @@ impl Connection { qdebug!([self], "0-RTT rejected"); // Tell 0-RTT packets that they were "lost". - let dropped = self.loss_recovery.drop_0rtt(&self.paths.primary(), now); - self.handle_lost_packets(&dropped); + if let Some(path) = self.paths.primary() { + let dropped = self.loss_recovery.drop_0rtt(&path, now); + self.handle_lost_packets(&dropped); + } self.streams.zero_rtt_rejected(); @@ -2923,7 +3008,7 @@ impl Connection { // Remove the randomized client CID from the list of acceptable CIDs. self.cid_manager.remove_odcid(); // Mark the path as validated, if it isn't already. - let path = self.paths.primary(); + let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; path.borrow_mut().set_valid(now); // Generate a qlog event that the server connection started. qlog::server_connection_started(&mut self.qlog, &path); @@ -3191,7 +3276,7 @@ impl Connection { else { return Err(Error::NotAvailable); }; - let path = self.paths.primary_fallible().ok_or(Error::NotAvailable)?; + let path = self.paths.primary().ok_or(Error::NotAvailable)?; let mtu = path.borrow().mtu(); let encoder = Encoder::with_capacity(mtu); diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs index cc2f6e30d2..e76f937938 100644 --- a/third_party/rust/neqo-transport/src/connection/state.rs +++ b/third_party/rust/neqo-transport/src/connection/state.rs @@ -21,7 +21,7 @@ use crate::{ packet::PacketBuilder, path::PathRef, recovery::RecoveryToken, - ConnectionError, Error, + CloseReason, Error, }; #[derive(Clone, Debug, PartialEq, Eq)] @@ -42,14 +42,14 @@ pub enum State { Connected, Confirmed, Closing { - error: ConnectionError, + error: CloseReason, timeout: Instant, }, Draining { - error: ConnectionError, + error: CloseReason, timeout: Instant, }, - Closed(ConnectionError), + Closed(CloseReason), } impl State { @@ -67,7 +67,7 @@ impl State { } #[must_use] - pub fn error(&self) -> Option<&ConnectionError> { + pub fn error(&self) -> Option<&CloseReason> { if let Self::Closing { error, .. } | Self::Draining { error, .. } | Self::Closed(error) = self { @@ -116,7 +116,7 @@ impl Ord for State { #[derive(Debug, Clone)] pub struct ClosingFrame { path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, reason_phrase: Vec<u8>, } @@ -124,7 +124,7 @@ pub struct ClosingFrame { impl ClosingFrame { fn new( path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, message: impl AsRef<str>, ) -> Self { @@ -142,12 +142,12 @@ impl ClosingFrame { } pub fn sanitize(&self) -> Option<Self> { - if let ConnectionError::Application(_) = self.error { + if let CloseReason::Application(_) = self.error { // The default CONNECTION_CLOSE frame that is sent when an application // error code needs to be sent in an Initial or Handshake packet. Some(Self { path: Rc::clone(&self.path), - error: ConnectionError::Transport(Error::ApplicationError), + error: CloseReason::Transport(Error::ApplicationError), frame_type: 0, reason_phrase: Vec::new(), }) @@ -156,19 +156,22 @@ impl ClosingFrame { } } + /// Length of a closing frame with a truncated `reason_length`. Allow 8 bytes for the reason + /// phrase to ensure that if it needs to be truncated there is still at least a few bytes of + /// the value. + pub const MIN_LENGTH: usize = 1 + 8 + 8 + 2 + 8; + pub fn write_frame(&self, builder: &mut PacketBuilder) { - // Allow 8 bytes for the reason phrase to ensure that if it needs to be - // truncated there is still at least a few bytes of the value. - if builder.remaining() < 1 + 8 + 8 + 2 + 8 { + if builder.remaining() < ClosingFrame::MIN_LENGTH { return; } match &self.error { - ConnectionError::Transport(e) => { + CloseReason::Transport(e) => { builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT); builder.encode_varint(e.code()); builder.encode_varint(self.frame_type); } - ConnectionError::Application(code) => { + CloseReason::Application(code) => { builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_APPLICATION); builder.encode_varint(*code); } @@ -209,10 +212,6 @@ pub enum StateSignaling { impl StateSignaling { pub fn handshake_done(&mut self) { if !matches!(self, Self::Idle) { - debug_assert!( - false, - "StateSignaling must be in Idle state but is in {self:?} state.", - ); return; } *self = Self::HandshakeDone; @@ -231,7 +230,7 @@ impl StateSignaling { pub fn close( &mut self, path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, message: impl AsRef<str>, ) { @@ -243,7 +242,7 @@ impl StateSignaling { pub fn drain( &mut self, path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, message: impl AsRef<str>, ) { diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs index b708bc421d..f21f4e184f 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/cc.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs @@ -6,7 +6,7 @@ use std::{mem, time::Duration}; -use neqo_common::{qdebug, qinfo, Datagram}; +use neqo_common::{qdebug, qinfo, Datagram, IpTosEcn}; use super::{ super::Output, ack_bytes, assert_full_cwnd, connect_rtt_idle, cwnd, cwnd_avail, cwnd_packets, @@ -36,9 +36,13 @@ fn cc_slow_start() { assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT); } -#[test] -/// Verify that CC moves to cong avoidance when a packet is marked lost. -fn cc_slow_start_to_cong_avoidance_recovery_period() { +#[derive(PartialEq, Eq, Clone, Copy)] +enum CongestionSignal { + PacketLoss, + EcnCe, +} + +fn cc_slow_start_to_cong_avoidance_recovery_period(congestion_signal: CongestionSignal) { let mut client = default_client(); let mut server = default_server(); let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); @@ -78,9 +82,17 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() { assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2); let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap(); - // Server: Receive and generate ack again, but drop first packet + // Server: Receive and generate ack again, but this time add congestion + // signal first. now += DEFAULT_RTT / 2; - c_tx_dgrams.remove(0); + match congestion_signal { + CongestionSignal::PacketLoss => { + c_tx_dgrams.remove(0); + } + CongestionSignal::EcnCe => { + c_tx_dgrams.last_mut().unwrap().set_tos(IpTosEcn::Ce.into()); + } + } let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now); assert_eq!( server.stats().frame_tx.largest_acknowledged, @@ -98,6 +110,18 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() { } #[test] +/// Verify that CC moves to cong avoidance when a packet is marked lost. +fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_packet_loss() { + cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::PacketLoss); +} + +/// Verify that CC moves to cong avoidance when ACK is marked with ECN CE. +#[test] +fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_ecn_ce() { + cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::EcnCe); +} + +#[test] /// Verify that CC stays in recovery period when packet sent before start of /// recovery period is acked. fn cc_cong_avoidance_recovery_period_unchanged() { diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs index 5351dd0d5c..7c620de17e 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/close.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs @@ -14,13 +14,13 @@ use super::{ }; use crate::{ tparams::{self, TransportParameter}, - AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE, + AppError, CloseReason, Error, ERROR_APPLICATION_CLOSE, }; fn assert_draining(c: &Connection, expected: &Error) { assert!(c.state().closed()); if let State::Draining { - error: ConnectionError::Transport(error), + error: CloseReason::Transport(error), .. } = c.state() { @@ -40,7 +40,14 @@ fn connection_close() { client.close(now, 42, ""); + let stats_before = client.stats().frame_tx; let out = client.process(None, now); + let stats_after = client.stats().frame_tx; + assert_eq!( + stats_after.connection_close, + stats_before.connection_close + 1 + ); + assert_eq!(stats_after.ack, stats_before.ack + 1); server.process_input(&out.dgram().unwrap(), now); assert_draining(&server, &Error::PeerApplicationError(42)); @@ -57,7 +64,14 @@ fn connection_close_with_long_reason_string() { let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap(); client.close(now, 42, long_reason); + let stats_before = client.stats().frame_tx; let out = client.process(None, now); + let stats_after = client.stats().frame_tx; + assert_eq!( + stats_after.connection_close, + stats_before.connection_close + 1 + ); + assert_eq!(stats_after.ack, stats_before.ack + 1); server.process_input(&out.dgram().unwrap(), now); assert_draining(&server, &Error::PeerApplicationError(42)); @@ -100,7 +114,7 @@ fn bad_tls_version() { let dgram = server.process(dgram.as_ref(), now()).dgram(); assert_eq!( *server.state(), - State::Closed(ConnectionError::Transport(Error::ProtocolViolation)) + State::Closed(CloseReason::Transport(Error::ProtocolViolation)) ); assert!(dgram.is_some()); client.process_input(&dgram.unwrap(), now()); @@ -154,7 +168,6 @@ fn closing_and_draining() { assert!(client_close.is_some()); let client_close_timer = client.process(None, now()).callback(); assert_ne!(client_close_timer, Duration::from_secs(0)); - // The client will spit out the same packet in response to anything it receives. let p3 = send_something(&mut server, now()); let client_close2 = client.process(Some(&p3), now()).dgram(); @@ -168,7 +181,7 @@ fn closing_and_draining() { assert_eq!(end, Output::None); assert_eq!( *client.state(), - State::Closed(ConnectionError::Application(APP_ERROR)) + State::Closed(CloseReason::Application(APP_ERROR)) ); // When the server receives the close, it too should generate CONNECTION_CLOSE. @@ -186,7 +199,7 @@ fn closing_and_draining() { assert_eq!(end, Output::None); assert_eq!( *server.state(), - State::Closed(ConnectionError::Transport(Error::PeerApplicationError( + State::Closed(CloseReason::Transport(Error::PeerApplicationError( APP_ERROR ))) ); diff --git a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs index ade8c753be..f1b64b3c8f 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs @@ -19,7 +19,7 @@ use crate::{ packet::PacketBuilder, quic_datagrams::MAX_QUIC_DATAGRAM, send_stream::{RetransmissionPriority, TransmissionPriority}, - Connection, ConnectionError, ConnectionParameters, Error, StreamType, + CloseReason, Connection, ConnectionParameters, Error, StreamType, }; const DATAGRAM_LEN_MTU: u64 = 1310; @@ -362,10 +362,7 @@ fn dgram_no_allowed() { client.process_input(&out, now()); - assert_error( - &client, - &ConnectionError::Transport(Error::ProtocolViolation), - ); + assert_error(&client, &CloseReason::Transport(Error::ProtocolViolation)); } #[test] @@ -383,10 +380,7 @@ fn dgram_too_big() { client.process_input(&out, now()); - assert_error( - &client, - &ConnectionError::Transport(Error::ProtocolViolation), - ); + assert_error(&client, &CloseReason::Transport(Error::ProtocolViolation)); } #[test] @@ -587,7 +581,7 @@ fn datagram_fill() { // Work out how much space we have for a datagram. let space = { - let p = client.paths.primary(); + let p = client.paths.primary().unwrap(); let path = p.borrow(); // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. diff --git a/third_party/rust/neqo-transport/src/connection/tests/ecn.rs b/third_party/rust/neqo-transport/src/connection/tests/ecn.rs new file mode 100644 index 0000000000..87957297e5 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/ecn.rs @@ -0,0 +1,392 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::time::Duration; + +use neqo_common::{Datagram, IpTos, IpTosEcn}; +use test_fixture::{ + assertions::{assert_v4_path, assert_v6_path}, + fixture_init, now, DEFAULT_ADDR_V4, +}; + +use super::send_something_with_modifier; +use crate::{ + connection::tests::{ + connect_force_idle, connect_force_idle_with_modifier, default_client, default_server, + migration::get_cid, new_client, new_server, send_something, + }, + ecn::ECN_TEST_COUNT, + ConnectionId, ConnectionParameters, StreamType, +}; + +fn assert_ecn_enabled(tos: IpTos) { + assert!(tos.is_ecn_marked()); +} + +fn assert_ecn_disabled(tos: IpTos) { + assert!(!tos.is_ecn_marked()); +} + +fn set_tos(mut d: Datagram, ecn: IpTosEcn) -> Datagram { + d.set_tos(ecn.into()); + d +} + +fn noop() -> fn(Datagram) -> Option<Datagram> { + Some +} + +fn bleach() -> fn(Datagram) -> Option<Datagram> { + |d| Some(set_tos(d, IpTosEcn::NotEct)) +} + +fn remark() -> fn(Datagram) -> Option<Datagram> { + |d| { + if d.tos().is_ecn_marked() { + Some(set_tos(d, IpTosEcn::Ect1)) + } else { + Some(d) + } + } +} + +fn ce() -> fn(Datagram) -> Option<Datagram> { + |d| { + if d.tos().is_ecn_marked() { + Some(set_tos(d, IpTosEcn::Ce)) + } else { + Some(d) + } + } +} + +fn drop() -> fn(Datagram) -> Option<Datagram> { + |_| None +} + +#[test] +fn disables_on_loss() { + let now = now(); + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // Right after the handshake, the ECN validation should still be in progress. + let client_pkt = send_something(&mut client, now); + assert_ecn_enabled(client_pkt.tos()); + + for _ in 0..ECN_TEST_COUNT { + send_something(&mut client, now); + } + + // ECN should now be disabled. + let client_pkt = send_something(&mut client, now); + assert_ecn_disabled(client_pkt.tos()); +} + +/// This function performs a handshake over a path that modifies packets via `orig_path_modifier`. +/// It then sends `burst` packets on that path, and then migrates to a new path that +/// modifies packets via `new_path_modifier`. It sends `burst` packets on the new path. +/// The function returns the TOS value of the last packet sent on the old path and the TOS value +/// of the last packet sent on the new path to allow for verification of correct behavior. +pub fn migration_with_modifiers( + orig_path_modifier: fn(Datagram) -> Option<Datagram>, + new_path_modifier: fn(Datagram) -> Option<Datagram>, + burst: usize, +) -> (IpTos, IpTos, bool) { + fixture_init(); + let mut client = new_client(ConnectionParameters::default().max_streams(StreamType::UniDi, 64)); + let mut server = new_server(ConnectionParameters::default().max_streams(StreamType::UniDi, 64)); + + connect_force_idle_with_modifier(&mut client, &mut server, orig_path_modifier); + let mut now = now(); + + // Right after the handshake, the ECN validation should still be in progress. + let client_pkt = send_something(&mut client, now); + assert_ecn_enabled(client_pkt.tos()); + server.process_input(&orig_path_modifier(client_pkt).unwrap(), now); + + // Send some data on the current path. + for _ in 0..burst { + let client_pkt = send_something_with_modifier(&mut client, now, orig_path_modifier); + server.process_input(&client_pkt, now); + } + + if let Some(ack) = server.process_output(now).dgram() { + client.process_input(&ack, now); + } + + let client_pkt = send_something(&mut client, now); + let tos_before_migration = client_pkt.tos(); + server.process_input(&orig_path_modifier(client_pkt).unwrap(), now); + + client + .migrate(Some(DEFAULT_ADDR_V4), Some(DEFAULT_ADDR_V4), false, now) + .unwrap(); + + let mut migrated = false; + let probe = new_path_modifier(client.process_output(now).dgram().unwrap()); + if let Some(probe) = probe { + assert_v4_path(&probe, true); // Contains PATH_CHALLENGE. + assert_eq!(client.stats().frame_tx.path_challenge, 1); + let probe_cid = ConnectionId::from(get_cid(&probe)); + + let resp = new_path_modifier(server.process(Some(&probe), now).dgram().unwrap()).unwrap(); + assert_v4_path(&resp, true); + assert_eq!(server.stats().frame_tx.path_response, 1); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + + // Data continues to be exchanged on the old path. + let client_data = send_something_with_modifier(&mut client, now, orig_path_modifier); + assert_ne!(get_cid(&client_data), probe_cid); + assert_v6_path(&client_data, false); + server.process_input(&client_data, now); + let server_data = send_something_with_modifier(&mut server, now, orig_path_modifier); + assert_v6_path(&server_data, false); + client.process_input(&server_data, now); + + // Once the client receives the probe response, it migrates to the new path. + client.process_input(&resp, now); + assert_eq!(client.stats().frame_rx.path_challenge, 1); + migrated = true; + + let migrate_client = send_something_with_modifier(&mut client, now, new_path_modifier); + assert_v4_path(&migrate_client, true); // Responds to server probe. + + // The server now sees the migration and will switch over. + // However, it will probe the old path again, even though it has just + // received a response to its last probe, because it needs to verify + // that the migration is genuine. + server.process_input(&migrate_client, now); + } + + let stream_before = server.stats().frame_tx.stream; + let probe_old_server = send_something_with_modifier(&mut server, now, orig_path_modifier); + // This is just the double-check probe; no STREAM frames. + assert_v6_path(&probe_old_server, migrated); + assert_eq!( + server.stats().frame_tx.path_challenge, + if migrated { 2 } else { 0 } + ); + assert_eq!( + server.stats().frame_tx.stream, + if migrated { stream_before } else { 1 } + ); + + if migrated { + // The server then sends data on the new path. + let migrate_server = + new_path_modifier(server.process_output(now).dgram().unwrap()).unwrap(); + assert_v4_path(&migrate_server, false); + assert_eq!(server.stats().frame_tx.path_challenge, 2); + assert_eq!(server.stats().frame_tx.stream, stream_before + 1); + + // The client receives these checks and responds to the probe, but uses the new path. + client.process_input(&migrate_server, now); + client.process_input(&probe_old_server, now); + let old_probe_resp = send_something_with_modifier(&mut client, now, new_path_modifier); + assert_v6_path(&old_probe_resp, true); + let client_confirmation = client.process_output(now).dgram().unwrap(); + assert_v4_path(&client_confirmation, false); + + // The server has now sent 2 packets, so it is blocked on the pacer. Wait. + let server_pacing = server.process_output(now).callback(); + assert_ne!(server_pacing, Duration::new(0, 0)); + // ... then confirm that the server sends on the new path still. + let server_confirmation = + send_something_with_modifier(&mut server, now + server_pacing, new_path_modifier); + assert_v4_path(&server_confirmation, false); + client.process_input(&server_confirmation, now); + + // Send some data on the new path. + for _ in 0..burst { + now += client.process_output(now).callback(); + let client_pkt = send_something_with_modifier(&mut client, now, new_path_modifier); + server.process_input(&client_pkt, now); + } + + if let Some(ack) = server.process_output(now).dgram() { + client.process_input(&ack, now); + } + } + + now += client.process_output(now).callback(); + let mut client_pkt = send_something(&mut client, now); + while !migrated && client_pkt.source() == DEFAULT_ADDR_V4 { + client_pkt = send_something(&mut client, now); + } + let tos_after_migration = client_pkt.tos(); + (tos_before_migration, tos_after_migration, migrated) +} + +#[test] +fn ecn_migration_zero_burst_all_cases() { + for orig_path_mod in &[noop(), bleach(), remark(), ce()] { + for new_path_mod in &[noop(), bleach(), remark(), ce(), drop()] { + let (before, after, migrated) = + migration_with_modifiers(*orig_path_mod, *new_path_mod, 0); + // Too few packets sent before and after migration to conclude ECN validation. + assert_ecn_enabled(before); + assert_ecn_enabled(after); + // Migration succeeds except if the new path drops ECN. + assert!(*new_path_mod == drop() || migrated); + } + } +} + +#[test] +fn ecn_migration_noop_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), bleach(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching. + assert!(migrated); +} + +#[test] +fn ecn_migration_noop_remark_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), remark(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_noop_ce_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), ce(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_noop_drop_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), drop(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_enabled(after); // Migration failed, ECN on original path is still validated. + assert!(!migrated); +} + +#[test] +fn ecn_migration_bleach_noop_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), noop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_enabled(after); // ECN validation concludes after migration. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), bleach(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_remark_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), remark(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_ce_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), ce(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_drop_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), drop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + // Migration failed, ECN on original path is still disabled. + assert_ecn_disabled(after); + assert!(!migrated); +} + +#[test] +fn ecn_migration_remark_noop_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), noop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_enabled(after); // ECN validation succeeds after migration. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), bleach(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_remark_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), remark(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_ce_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), ce(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_drop_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), drop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_disabled(after); // Migration failed, ECN on original path is still disabled. + assert!(!migrated); +} + +#[test] +fn ecn_migration_ce_noop_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), noop(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_enabled(after); // ECN validation concludes after migration. + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), bleach(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_remark_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), remark(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_ce_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), ce(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_drop_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), drop(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + // Migration failed, ECN on original path is still enabled. + assert_ecn_enabled(after); + assert!(!migrated); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs index f2103523ec..c908340616 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs @@ -35,7 +35,7 @@ use crate::{ server::ValidateAddress, tparams::{TransportParameter, MIN_ACK_DELAY}, tracking::DEFAULT_ACK_DELAY, - ConnectionError, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version, + CloseReason, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version, }; const ECH_CONFIG_ID: u8 = 7; @@ -111,8 +111,8 @@ fn handshake_failed_authentication() { qdebug!("---- server: Alert(certificate_revoked)"); let out = server.process(out.as_dgram_ref(), now()); assert!(out.as_dgram_ref().is_some()); - assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44))); - assert_error(&server, &ConnectionError::Transport(Error::PeerError(300))); + assert_error(&client, &CloseReason::Transport(Error::CryptoAlert(44))); + assert_error(&server, &CloseReason::Transport(Error::PeerError(300))); } #[test] @@ -133,11 +133,8 @@ fn no_alpn() { handshake(&mut client, &mut server, now(), Duration::new(0, 0)); // TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE // and the client never sees the server's rejection of its handshake. - // assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120))); - assert_error( - &server, - &ConnectionError::Transport(Error::CryptoAlert(120)), - ); + // assert_error(&client, CloseReason::Transport(Error::CryptoAlert(120))); + assert_error(&server, &CloseReason::Transport(Error::CryptoAlert(120))); } #[test] @@ -934,10 +931,10 @@ fn ech_retry() { server.process_input(&dgram.unwrap(), now()); assert_eq!( server.state().error(), - Some(&ConnectionError::Transport(Error::PeerError(0x100 + 121))) + Some(&CloseReason::Transport(Error::PeerError(0x100 + 121))) ); - let Some(ConnectionError::Transport(Error::EchRetry(updated_config))) = client.state().error() + let Some(CloseReason::Transport(Error::EchRetry(updated_config))) = client.state().error() else { panic!( "Client state should be failed with EchRetry, is {:?}", @@ -984,7 +981,7 @@ fn ech_retry_fallback_rejected() { client.authenticated(AuthenticationStatus::PolicyRejection, now()); assert!(client.state().error().is_some()); - if let Some(ConnectionError::Transport(Error::EchRetry(_))) = client.state().error() { + if let Some(CloseReason::Transport(Error::EchRetry(_))) = client.state().error() { panic!("Client should not get EchRetry error"); } @@ -993,14 +990,13 @@ fn ech_retry_fallback_rejected() { server.process_input(&dgram.unwrap(), now()); assert_eq!( server.state().error(), - Some(&ConnectionError::Transport(Error::PeerError(298))) + Some(&CloseReason::Transport(Error::PeerError(298))) ); // A bad_certificate alert. } #[test] fn bad_min_ack_delay() { - const EXPECTED_ERROR: ConnectionError = - ConnectionError::Transport(Error::TransportParameterError); + const EXPECTED_ERROR: CloseReason = CloseReason::Transport(Error::TransportParameterError); let mut server = default_server(); let max_ad = u64::try_from(DEFAULT_ACK_DELAY.as_micros()).unwrap(); server @@ -1018,7 +1014,7 @@ fn bad_min_ack_delay() { server.process_input(&dgram.unwrap(), now()); assert_eq!( server.state().error(), - Some(&ConnectionError::Transport(Error::PeerError( + Some(&CloseReason::Transport(Error::PeerError( Error::TransportParameterError.code() ))) ); diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs index 847b253284..c2ae9529bf 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/keys.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs @@ -11,7 +11,7 @@ use test_fixture::now; use super::{ super::{ - super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED}, + super::{CloseReason, ERROR_AEAD_LIMIT_REACHED}, Connection, ConnectionParameters, Error, Output, State, StreamType, }, connect, connect_force_idle, default_client, default_server, maybe_authenticate, @@ -269,7 +269,7 @@ fn exhaust_write_keys() { assert!(dgram.is_none()); assert!(matches!( client.state(), - State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + State::Closed(CloseReason::Transport(Error::KeysExhausted)) )); } @@ -285,14 +285,14 @@ fn exhaust_read_keys() { let dgram = server.process(Some(&dgram), now()).dgram(); assert!(matches!( server.state(), - State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + State::Closed(CloseReason::Transport(Error::KeysExhausted)) )); client.process_input(&dgram.unwrap(), now()); assert!(matches!( client.state(), State::Draining { - error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), + error: CloseReason::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), .. } )); @@ -341,6 +341,6 @@ fn automatic_update_write_keys_blocked() { assert!(dgram.is_none()); assert!(matches!( client.state(), - State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + State::Closed(CloseReason::Transport(Error::KeysExhausted)) )); } diff --git a/third_party/rust/neqo-transport/src/connection/tests/migration.rs b/third_party/rust/neqo-transport/src/connection/tests/migration.rs index 405ae161a4..779cc78c53 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/migration.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/migration.rs @@ -30,7 +30,7 @@ use crate::{ packet::PacketBuilder, path::{PATH_MTU_V4, PATH_MTU_V6}, tparams::{self, PreferredAddress, TransportParameter}, - ConnectionError, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef, + CloseReason, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef, ConnectionParameters, EmptyConnectionIdGenerator, Error, }; @@ -357,13 +357,13 @@ fn migrate_same_fail() { assert!(matches!(res, Output::None)); assert!(matches!( client.state(), - State::Closed(ConnectionError::Transport(Error::NoAvailablePath)) + State::Closed(CloseReason::Transport(Error::NoAvailablePath)) )); } /// This gets the connection ID from a datagram using the default /// connection ID generator/decoder. -fn get_cid(d: &Datagram) -> ConnectionIdRef { +pub fn get_cid(d: &Datagram) -> ConnectionIdRef { let gen = CountingConnectionIdGenerator::default(); assert_eq!(d[0] & 0x80, 0); // Only support short packets for now. gen.decode_cid(&mut Decoder::from(&d[1..])).unwrap() @@ -894,7 +894,7 @@ fn retire_prior_to_migration_failure() { assert!(matches!( client.state(), State::Closing { - error: ConnectionError::Transport(Error::InvalidMigration), + error: CloseReason::Transport(Error::InvalidMigration), .. } )); diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs index c8c87a0df0..65283b8eb8 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/mod.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs @@ -17,7 +17,7 @@ use neqo_common::{event::Provider, qdebug, qtrace, Datagram, Decoder, Role}; use neqo_crypto::{random, AllowZeroRtt, AuthenticationStatus, ResumptionToken}; use test_fixture::{fixture_init, new_neqo_qlog, now, DEFAULT_ADDR}; -use super::{Connection, ConnectionError, ConnectionId, Output, State}; +use super::{CloseReason, Connection, ConnectionId, Output, State}; use crate::{ addr_valid::{AddressValidation, ValidateAddress}, cc::{CWND_INITIAL_PKTS, CWND_MIN}, @@ -37,6 +37,7 @@ mod ackrate; mod cc; mod close; mod datagram; +mod ecn; mod handshake; mod idle; mod keys; @@ -170,17 +171,13 @@ impl crate::connection::test_internal::FrameWriter for PingWriter { } } -trait DatagramModifier: FnMut(Datagram) -> Option<Datagram> {} - -impl<T> DatagramModifier for T where T: FnMut(Datagram) -> Option<Datagram> {} - /// Drive the handshake between the client and server. fn handshake_with_modifier( client: &mut Connection, server: &mut Connection, now: Instant, rtt: Duration, - mut modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Instant { let mut a = client; let mut b = server; @@ -248,8 +245,8 @@ fn connect_fail( server_error: Error, ) { handshake(client, server, now(), Duration::new(0, 0)); - assert_error(client, &ConnectionError::Transport(client_error)); - assert_error(server, &ConnectionError::Transport(server_error)); + assert_error(client, &CloseReason::Transport(client_error)); + assert_error(server, &CloseReason::Transport(server_error)); } fn connect_with_rtt_and_modifier( @@ -257,7 +254,7 @@ fn connect_with_rtt_and_modifier( server: &mut Connection, now: Instant, rtt: Duration, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Instant { fn check_rtt(stats: &Stats, rtt: Duration) { assert_eq!(stats.rtt, rtt); @@ -287,7 +284,7 @@ fn connect(client: &mut Connection, server: &mut Connection) { connect_with_rtt(client, server, now(), Duration::new(0, 0)); } -fn assert_error(c: &Connection, expected: &ConnectionError) { +fn assert_error(c: &Connection, expected: &CloseReason) { match c.state() { State::Closing { error, .. } | State::Draining { error, .. } | State::Closed(error) => { assert_eq!(*error, *expected, "{c} error mismatch"); @@ -333,7 +330,7 @@ fn connect_rtt_idle_with_modifier( client: &mut Connection, server: &mut Connection, rtt: Duration, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Instant { let now = connect_with_rtt_and_modifier(client, server, now(), rtt, modifier); assert_idle(client, server, rtt, now); @@ -351,7 +348,7 @@ fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Durat fn connect_force_idle_with_modifier( client: &mut Connection, server: &mut Connection, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) { connect_rtt_idle_with_modifier(client, server, Duration::new(0, 0), modifier); } @@ -380,7 +377,7 @@ fn fill_stream(c: &mut Connection, stream: StreamId) { fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec<Datagram>, Instant) { // Train wreck function to get the remaining congestion window on the primary path. fn cwnd(c: &Connection) -> usize { - c.paths.primary().borrow().sender().cwnd_avail() + c.paths.primary().unwrap().borrow().sender().cwnd_avail() } qtrace!("fill_cwnd starting cwnd: {}", cwnd(c)); @@ -478,10 +475,10 @@ where // Get the current congestion window for the connection. fn cwnd(c: &Connection) -> usize { - c.paths.primary().borrow().sender().cwnd() + c.paths.primary().unwrap().borrow().sender().cwnd() } fn cwnd_avail(c: &Connection) -> usize { - c.paths.primary().borrow().sender().cwnd_avail() + c.paths.primary().unwrap().borrow().sender().cwnd_avail() } fn induce_persistent_congestion( @@ -576,7 +573,7 @@ fn send_something_paced_with_modifier( sender: &mut Connection, mut now: Instant, allow_pacing: bool, - mut modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> (Datagram, Instant) { let stream_id = sender.stream_create(StreamType::UniDi).unwrap(); assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok()); @@ -608,7 +605,7 @@ fn send_something_paced( fn send_something_with_modifier( sender: &mut Connection, now: Instant, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Datagram { send_something_paced_with_modifier(sender, now, false, modifier).0 } diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs index 66d3bf32f3..f7472d917f 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/stream.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs @@ -19,9 +19,9 @@ use crate::{ send_stream::{OrderGroup, SendStreamState, SEND_BUFFER_SIZE}, streams::{SendOrder, StreamOrder}, tparams::{self, TransportParameter}, + CloseReason, // tracking::DEFAULT_ACK_PACKET_TOLERANCE, Connection, - ConnectionError, ConnectionParameters, Error, StreamId, @@ -494,12 +494,9 @@ fn exceed_max_data() { assert_error( &client, - &ConnectionError::Transport(Error::PeerError(Error::FlowControlError.code())), - ); - assert_error( - &server, - &ConnectionError::Transport(Error::FlowControlError), + &CloseReason::Transport(Error::PeerError(Error::FlowControlError.code())), ); + assert_error(&server, &CloseReason::Transport(Error::FlowControlError)); } #[test] diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs index 93872a94f4..815868d78d 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/vn.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs @@ -10,7 +10,7 @@ use neqo_common::{event::Provider, Decoder, Encoder}; use test_fixture::{assertions, datagram, now}; use super::{ - super::{ConnectionError, ConnectionEvent, Output, State, ZeroRttState}, + super::{CloseReason, ConnectionEvent, Output, State, ZeroRttState}, connect, connect_fail, default_client, default_server, exchange_ticket, new_client, new_server, send_something, }; @@ -124,7 +124,7 @@ fn version_negotiation_only_reserved() { assert_eq!(client.process(Some(&dgram), now()), Output::None); match client.state() { State::Closed(err) => { - assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)); + assert_eq!(*err, CloseReason::Transport(Error::VersionNegotiation)); } _ => panic!("Invalid client state"), } @@ -183,7 +183,7 @@ fn version_negotiation_not_supported() { assert_eq!(client.process(Some(&dgram), now()), Output::None); match client.state() { State::Closed(err) => { - assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)); + assert_eq!(*err, CloseReason::Transport(Error::VersionNegotiation)); } _ => panic!("Invalid client state"), } @@ -338,7 +338,7 @@ fn invalid_server_version() { // The server effectively hasn't reacted here. match server.state() { State::Closed(err) => { - assert_eq!(*err, ConnectionError::Transport(Error::CryptoAlert(47))); + assert_eq!(*err, CloseReason::Transport(Error::CryptoAlert(47))); } _ => panic!("invalid server state"), } diff --git a/third_party/rust/neqo-transport/src/ecn.rs b/third_party/rust/neqo-transport/src/ecn.rs new file mode 100644 index 0000000000..20eb4da003 --- /dev/null +++ b/third_party/rust/neqo-transport/src/ecn.rs @@ -0,0 +1,225 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::ops::{AddAssign, Deref, DerefMut, Sub}; + +use enum_map::EnumMap; +use neqo_common::{qdebug, qinfo, qwarn, IpTosEcn}; + +use crate::{packet::PacketNumber, tracking::SentPacket}; + +/// The number of packets to use for testing a path for ECN capability. +pub const ECN_TEST_COUNT: usize = 10; + +/// The state information related to testing a path for ECN capability. +/// See RFC9000, Appendix A.4. +#[derive(Debug, PartialEq, Clone)] +enum EcnValidationState { + /// The path is currently being tested for ECN capability, with the number of probes sent so + /// far on the path during the ECN validation. + Testing(usize), + /// The validation test has concluded but the path's ECN capability is not yet known. + Unknown, + /// The path is known to **not** be ECN capable. + Failed, + /// The path is known to be ECN capable. + Capable, +} + +impl Default for EcnValidationState { + fn default() -> Self { + EcnValidationState::Testing(0) + } +} + +/// The counts for different ECN marks. +#[derive(PartialEq, Eq, Debug, Clone, Copy, Default)] +pub struct EcnCount(EnumMap<IpTosEcn, u64>); + +impl Deref for EcnCount { + type Target = EnumMap<IpTosEcn, u64>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for EcnCount { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl EcnCount { + pub fn new(not_ect: u64, ect0: u64, ect1: u64, ce: u64) -> Self { + // Yes, the enum array order is different from the argument order. + Self(EnumMap::from_array([not_ect, ect1, ect0, ce])) + } + + /// Whether any of the ECN counts are non-zero. + pub fn is_some(&self) -> bool { + self[IpTosEcn::Ect0] > 0 || self[IpTosEcn::Ect1] > 0 || self[IpTosEcn::Ce] > 0 + } +} + +impl Sub<EcnCount> for EcnCount { + type Output = EcnCount; + + /// Subtract the ECN counts in `other` from `self`. + fn sub(self, other: EcnCount) -> EcnCount { + let mut diff = EcnCount::default(); + for (ecn, count) in &mut *diff { + *count = self[ecn].saturating_sub(other[ecn]); + } + diff + } +} + +impl AddAssign<IpTosEcn> for EcnCount { + fn add_assign(&mut self, ecn: IpTosEcn) { + self[ecn] += 1; + } +} + +#[derive(Debug, Default)] +pub struct EcnInfo { + /// The current state of ECN validation on this path. + state: EcnValidationState, + + /// The largest ACK seen so far. + largest_acked: PacketNumber, + + /// The ECN counts from the last ACK frame that increased `largest_acked`. + baseline: EcnCount, +} + +impl EcnInfo { + /// Set the baseline (= the ECN counts from the last ACK Frame). + pub fn set_baseline(&mut self, baseline: EcnCount) { + self.baseline = baseline; + } + + /// Expose the current baseline. + pub fn baseline(&self) -> EcnCount { + self.baseline + } + + /// Count the number of packets sent out on this path during ECN validation. + /// Exit ECN validation if the number of packets sent exceeds `ECN_TEST_COUNT`. + /// We do not implement the part of the RFC that says to exit ECN validation if the time since + /// the start of ECN validation exceeds 3 * PTO, since this seems to happen much too quickly. + pub fn on_packet_sent(&mut self) { + if let EcnValidationState::Testing(ref mut probes_sent) = &mut self.state { + *probes_sent += 1; + qdebug!("ECN probing: sent {} probes", probes_sent); + if *probes_sent == ECN_TEST_COUNT { + qdebug!("ECN probing concluded with {} probes sent", probes_sent); + self.state = EcnValidationState::Unknown; + } + } + } + + /// Process ECN counts from an ACK frame. + /// + /// Returns whether ECN counts contain new valid ECN CE marks. + pub fn on_packets_acked( + &mut self, + acked_packets: &[SentPacket], + ack_ecn: Option<EcnCount>, + ) -> bool { + let prev_baseline = self.baseline; + + self.validate_ack_ecn_and_update(acked_packets, ack_ecn); + + matches!(self.state, EcnValidationState::Capable) + && (self.baseline - prev_baseline)[IpTosEcn::Ce] > 0 + } + + /// After the ECN validation test has ended, check if the path is ECN capable. + pub fn validate_ack_ecn_and_update( + &mut self, + acked_packets: &[SentPacket], + ack_ecn: Option<EcnCount>, + ) { + // RFC 9000, Appendix A.4: + // + // > From the "unknown" state, successful validation of the ECN counts in an ACK frame + // > (see Section 13.4.2.1) causes the ECN state for the path to become "capable", unless + // > no marked packet has been acknowledged. + match self.state { + EcnValidationState::Testing { .. } | EcnValidationState::Failed => return, + EcnValidationState::Unknown | EcnValidationState::Capable => {} + } + + // RFC 9000, Section 13.4.2.1: + // + // > Validating ECN counts from reordered ACK frames can result in failure. An endpoint MUST + // > NOT fail ECN validation as a result of processing an ACK frame that does not increase + // > the largest acknowledged packet number. + let largest_acked = acked_packets.first().expect("must be there").pn; + if largest_acked <= self.largest_acked { + return; + } + + // RFC 9000, Section 13.4.2.1: + // + // > An endpoint that receives an ACK frame with ECN counts therefore validates + // > the counts before using them. It performs this validation by comparing newly + // > received counts against those from the last successfully processed ACK frame. + // + // > If an ACK frame newly acknowledges a packet that the endpoint sent with + // > either the ECT(0) or ECT(1) codepoint set, ECN validation fails if the + // > corresponding ECN counts are not present in the ACK frame. + let Some(ack_ecn) = ack_ecn else { + qwarn!("ECN validation failed, no ECN counts in ACK frame"); + self.state = EcnValidationState::Failed; + return; + }; + + // We always mark with ECT(0) - if at all - so we only need to check for that. + // + // > ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is + // > less than the number of newly acknowledged packets that were originally sent with an + // > ECT(0) marking. + let newly_acked_sent_with_ect0: u64 = acked_packets + .iter() + .filter(|p| p.ecn_mark == IpTosEcn::Ect0) + .count() + .try_into() + .unwrap(); + if newly_acked_sent_with_ect0 == 0 { + qwarn!("ECN validation failed, no ECT(0) packets were newly acked"); + self.state = EcnValidationState::Failed; + return; + } + let ecn_diff = ack_ecn - self.baseline; + let sum_inc = ecn_diff[IpTosEcn::Ect0] + ecn_diff[IpTosEcn::Ce]; + if sum_inc < newly_acked_sent_with_ect0 { + qwarn!( + "ECN validation failed, ACK counted {} new marks, but {} of newly acked packets were sent with ECT(0)", + sum_inc, + newly_acked_sent_with_ect0 + ); + self.state = EcnValidationState::Failed; + } else if ecn_diff[IpTosEcn::Ect1] > 0 { + qwarn!("ECN validation failed, ACK counted ECT(1) marks that were never sent"); + self.state = EcnValidationState::Failed; + } else { + qinfo!("ECN validation succeeded, path is capable",); + self.state = EcnValidationState::Capable; + } + self.baseline = ack_ecn; + self.largest_acked = largest_acked; + } + + /// The ECN mark to use for packets sent on this path. + pub fn ecn_mark(&self) -> IpTosEcn { + match self.state { + EcnValidationState::Testing { .. } | EcnValidationState::Capable => IpTosEcn::Ect0, + EcnValidationState::Failed | EcnValidationState::Unknown => IpTosEcn::NotEct, + } + } +} diff --git a/third_party/rust/neqo-transport/src/events.rs b/third_party/rust/neqo-transport/src/events.rs index a892e384b9..68ef0d6798 100644 --- a/third_party/rust/neqo-transport/src/events.rs +++ b/third_party/rust/neqo-transport/src/events.rs @@ -256,7 +256,7 @@ impl EventProvider for ConnectionEvents { mod tests { use neqo_common::event::Provider; - use crate::{ConnectionError, ConnectionEvent, ConnectionEvents, Error, State, StreamId}; + use crate::{CloseReason, ConnectionEvent, ConnectionEvents, Error, State, StreamId}; #[test] fn event_culling() { @@ -314,7 +314,7 @@ mod tests { evts.send_stream_writable(9.into()); evts.send_stream_stop_sending(10.into(), 55); - evts.connection_state_change(State::Closed(ConnectionError::Transport( + evts.connection_state_change(State::Closed(CloseReason::Transport( Error::StreamStateError, ))); assert_eq!(evts.events().count(), 1); diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs index d84eb61ce8..7d009f3b46 100644 --- a/third_party/rust/neqo-transport/src/frame.rs +++ b/third_party/rust/neqo-transport/src/frame.rs @@ -8,13 +8,14 @@ use std::ops::RangeInclusive; -use neqo_common::{qtrace, Decoder}; +use neqo_common::{qtrace, Decoder, Encoder}; use crate::{ cid::MAX_CONNECTION_ID_LEN, + ecn::EcnCount, packet::PacketType, stream_id::{StreamId, StreamType}, - AppError, ConnectionError, Error, Res, TransportError, + AppError, CloseReason, Error, Res, TransportError, }; #[allow(clippy::module_name_repetitions)] @@ -23,7 +24,7 @@ pub type FrameType = u64; pub const FRAME_TYPE_PADDING: FrameType = 0x0; pub const FRAME_TYPE_PING: FrameType = 0x1; pub const FRAME_TYPE_ACK: FrameType = 0x2; -const FRAME_TYPE_ACK_ECN: FrameType = 0x3; +pub const FRAME_TYPE_ACK_ECN: FrameType = 0x3; pub const FRAME_TYPE_RESET_STREAM: FrameType = 0x4; pub const FRAME_TYPE_STOP_SENDING: FrameType = 0x5; pub const FRAME_TYPE_CRYPTO: FrameType = 0x6; @@ -86,11 +87,11 @@ impl CloseError { } } -impl From<ConnectionError> for CloseError { - fn from(err: ConnectionError) -> Self { +impl From<CloseReason> for CloseError { + fn from(err: CloseReason) -> Self { match err { - ConnectionError::Transport(c) => Self::Transport(c.code()), - ConnectionError::Application(c) => Self::Application(c), + CloseReason::Transport(c) => Self::Transport(c.code()), + CloseReason::Application(c) => Self::Application(c), } } } @@ -116,6 +117,7 @@ pub enum Frame<'a> { ack_delay: u64, first_ack_range: u64, ack_ranges: Vec<AckRange>, + ecn_count: Option<EcnCount>, }, ResetStream { stream_id: StreamId, @@ -182,7 +184,7 @@ pub enum Frame<'a> { frame_type: u64, // Not a reference as we use this to hold the value. // This is not used in optimized builds anyway. - reason_phrase: Vec<u8>, + reason_phrase: String, }, HandshakeDone, AckFrequency { @@ -224,7 +226,7 @@ impl<'a> Frame<'a> { match self { Self::Padding { .. } => FRAME_TYPE_PADDING, Self::Ping => FRAME_TYPE_PING, - Self::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN. + Self::Ack { .. } => FRAME_TYPE_ACK, Self::ResetStream { .. } => FRAME_TYPE_RESET_STREAM, Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING, Self::Crypto { .. } => FRAME_TYPE_CRYPTO, @@ -426,8 +428,54 @@ impl<'a> Frame<'a> { d(dec.decode_varint()) } - // TODO(ekr@rtfm.com): check for minimal encoding + fn decode_ack<'a>(dec: &mut Decoder<'a>, ecn: bool) -> Res<Frame<'a>> { + let la = dv(dec)?; + let ad = dv(dec)?; + let nr = dv(dec).and_then(|nr| { + if nr < MAX_ACK_RANGE_COUNT { + Ok(nr) + } else { + Err(Error::TooMuchData) + } + })?; + let fa = dv(dec)?; + let mut arr: Vec<AckRange> = Vec::with_capacity(usize::try_from(nr)?); + for _ in 0..nr { + let ar = AckRange { + gap: dv(dec)?, + range: dv(dec)?, + }; + arr.push(ar); + } + + // Now check for the values for ACK_ECN. + let ecn_count = if ecn { + Some(EcnCount::new(0, dv(dec)?, dv(dec)?, dv(dec)?)) + } else { + None + }; + + Ok(Frame::Ack { + largest_acknowledged: la, + ack_delay: ad, + first_ack_range: fa, + ack_ranges: arr, + ecn_count, + }) + } + + // Check for minimal encoding of frame type. + let pos = dec.offset(); let t = dv(dec)?; + // RFC 9000, Section 12.4: + // + // The Frame Type field uses a variable-length integer encoding [...], + // with one exception. To ensure simple and efficient implementations of + // frame parsing, a frame type MUST use the shortest possible encoding. + if Encoder::varint_len(t) != dec.offset() - pos { + return Err(Error::ProtocolViolation); + } + match t { FRAME_TYPE_PADDING => { let mut length: u16 = 1; @@ -449,40 +497,8 @@ impl<'a> Frame<'a> { _ => return Err(Error::NoMoreData), }, }), - FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => { - let la = dv(dec)?; - let ad = dv(dec)?; - let nr = dv(dec).and_then(|nr| { - if nr < MAX_ACK_RANGE_COUNT { - Ok(nr) - } else { - Err(Error::TooMuchData) - } - })?; - let fa = dv(dec)?; - let mut arr: Vec<AckRange> = Vec::with_capacity(usize::try_from(nr)?); - for _ in 0..nr { - let ar = AckRange { - gap: dv(dec)?, - range: dv(dec)?, - }; - arr.push(ar); - } - - // Now check for the values for ACK_ECN. - if t == FRAME_TYPE_ACK_ECN { - dv(dec)?; - dv(dec)?; - dv(dec)?; - } - - Ok(Self::Ack { - largest_acknowledged: la, - ack_delay: ad, - first_ack_range: fa, - ack_ranges: arr, - }) - } + FRAME_TYPE_ACK => decode_ack(dec, false), + FRAME_TYPE_ACK_ECN => decode_ack(dec, true), FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending { stream_id: StreamId::from(dv(dec)?), application_error_code: dv(dec)?, @@ -598,7 +614,7 @@ impl<'a> Frame<'a> { 0 }; // We can tolerate this copy for now. - let reason_phrase = d(dec.decode_vvec())?.to_vec(); + let reason_phrase = String::from_utf8_lossy(d(dec.decode_vvec())?).to_string(); Ok(Self::ConnectionClose { error_code, frame_type, @@ -647,13 +663,14 @@ mod tests { use crate::{ cid::MAX_CONNECTION_ID_LEN, + ecn::EcnCount, frame::{AckRange, Frame, FRAME_TYPE_ACK}, CloseError, Error, StreamId, StreamType, }; fn just_dec(f: &Frame, s: &str) { let encoded = Encoder::from_hex(s); - let decoded = Frame::decode(&mut encoded.as_decoder()).unwrap(); + let decoded = Frame::decode(&mut encoded.as_decoder()).expect("Failed to decode frame"); assert_eq!(*f, decoded); } @@ -679,7 +696,8 @@ mod tests { largest_acknowledged: 0x1234, ack_delay: 0x1235, first_ack_range: 0x1236, - ack_ranges: ar, + ack_ranges: ar.clone(), + ecn_count: None, }; just_dec(&f, "025234523502523601020304"); @@ -689,10 +707,18 @@ mod tests { let mut dec = enc.as_decoder(); assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData); - // Try to parse ACK_ECN without ECN values + // Try to parse ACK_ECN with ECN values + let ecn_count = Some(EcnCount::new(0, 1, 2, 3)); + let fe = Frame::Ack { + largest_acknowledged: 0x1234, + ack_delay: 0x1235, + first_ack_range: 0x1236, + ack_ranges: ar, + ecn_count, + }; let enc = Encoder::from_hex("035234523502523601020304010203"); let mut dec = enc.as_decoder(); - assert_eq!(Frame::decode(&mut dec).unwrap(), f); + assert_eq!(Frame::decode(&mut dec).unwrap(), fe); } #[test] @@ -899,7 +925,7 @@ mod tests { let f = Frame::ConnectionClose { error_code: CloseError::Transport(0x5678), frame_type: 0x1234, - reason_phrase: vec![0x01, 0x02, 0x03], + reason_phrase: String::from("\x01\x02\x03"), }; just_dec(&f, "1c80005678523403010203"); @@ -910,7 +936,7 @@ mod tests { let f = Frame::ConnectionClose { error_code: CloseError::Application(0x5678), frame_type: 0, - reason_phrase: vec![0x01, 0x02, 0x03], + reason_phrase: String::from("\x01\x02\x03"), }; just_dec(&f, "1d8000567803010203"); @@ -989,14 +1015,14 @@ mod tests { fill: true, }; - just_dec(&f, "4030010203"); + just_dec(&f, "30010203"); // With the length bit. let f = Frame::Datagram { data: &[1, 2, 3], fill: false, }; - just_dec(&f, "403103010203"); + just_dec(&f, "3103010203"); } #[test] @@ -1010,4 +1036,15 @@ mod tests { assert_eq!(Err(Error::TooMuchData), Frame::decode(&mut e.as_decoder())); } + + #[test] + #[should_panic(expected = "Failed to decode frame")] + fn invalid_frame_type_len() { + let f = Frame::Datagram { + data: &[1, 2, 3], + fill: true, + }; + + just_dec(&f, "4030010203"); + } } diff --git a/third_party/rust/neqo-transport/src/lib.rs b/third_party/rust/neqo-transport/src/lib.rs index 5488472b58..723a86980e 100644 --- a/third_party/rust/neqo-transport/src/lib.rs +++ b/third_party/rust/neqo-transport/src/lib.rs @@ -15,10 +15,17 @@ mod cc; mod cid; mod connection; mod crypto; +mod ecn; mod events; mod fc; +#[cfg(fuzzing)] +pub mod frame; +#[cfg(not(fuzzing))] mod frame; mod pace; +#[cfg(fuzzing)] +pub mod packet; +#[cfg(not(fuzzing))] mod packet; mod path; mod qlog; @@ -202,13 +209,17 @@ impl ::std::fmt::Display for Error { pub type AppError = u64; +#[deprecated(note = "use `CloseReason` instead")] +pub type ConnectionError = CloseReason; + +/// Reason why a connection closed. #[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] -pub enum ConnectionError { +pub enum CloseReason { Transport(Error), Application(AppError), } -impl ConnectionError { +impl CloseReason { #[must_use] pub fn app_code(&self) -> Option<AppError> { match self { @@ -216,9 +227,19 @@ impl ConnectionError { Self::Transport(_) => None, } } + + /// Checks enclosed error for [`Error::NoError`] and + /// [`CloseReason::Application(0)`]. + #[must_use] + pub fn is_error(&self) -> bool { + !matches!( + self, + CloseReason::Transport(Error::NoError) | CloseReason::Application(0), + ) + } } -impl From<CloseError> for ConnectionError { +impl From<CloseError> for CloseReason { fn from(err: CloseError) -> Self { match err { CloseError::Transport(c) => Self::Transport(Error::PeerError(c)), diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs index ce611a9664..10d9b13208 100644 --- a/third_party/rust/neqo-transport/src/packet/mod.rs +++ b/third_party/rust/neqo-transport/src/packet/mod.rs @@ -740,6 +740,7 @@ impl<'a> PublicPacket<'a> { } #[must_use] + #[allow(clippy::len_without_is_empty)] // is_empty() would always return false in this case pub fn len(&self) -> usize { self.data.len() } diff --git a/third_party/rust/neqo-transport/src/path.rs b/third_party/rust/neqo-transport/src/path.rs index 50e458ff36..0e4c82b1ca 100644 --- a/third_party/rust/neqo-transport/src/path.rs +++ b/third_party/rust/neqo-transport/src/path.rs @@ -22,6 +22,7 @@ use crate::{ ackrate::{AckRate, PeerAckDelay}, cc::CongestionControlAlgorithm, cid::{ConnectionId, ConnectionIdRef, ConnectionIdStore, RemoteConnectionIdEntry}, + ecn::{EcnCount, EcnInfo}, frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID}, packet::PacketBuilder, recovery::RecoveryToken, @@ -145,15 +146,8 @@ impl Paths { }) } - /// Get a reference to the primary path. This will assert if there is no primary - /// path, which happens at a server prior to receiving a valid Initial packet - /// from a client. So be careful using this method. - pub fn primary(&self) -> PathRef { - self.primary_fallible().unwrap() - } - - /// Get a reference to the primary path. Use this prior to handshake completion. - pub fn primary_fallible(&self) -> Option<PathRef> { + /// Get a reference to the primary path, if one exists. + pub fn primary(&self) -> Option<PathRef> { self.primary.clone() } @@ -242,6 +236,11 @@ impl Paths { /// Returns `true` if the path was migrated. pub fn migrate(&mut self, path: &PathRef, force: bool, now: Instant) -> bool { debug_assert!(!self.is_temporary(path)); + let baseline = self.primary().map_or_else( + || EcnInfo::default().baseline(), + |p| p.borrow().ecn_info.baseline(), + ); + path.borrow_mut().set_ecn_baseline(baseline); if force || path.borrow().is_valid() { path.borrow_mut().set_valid(now); mem::drop(self.select_primary(path)); @@ -307,7 +306,6 @@ impl Paths { /// Set the identified path to be primary. /// This panics if `make_permanent` hasn't been called. pub fn handle_migration(&mut self, path: &PathRef, remote: SocketAddr, now: Instant) { - qtrace!([self.primary().borrow()], "handle_migration"); // The update here needs to match the checks in `Path::received_on`. // Here, we update the remote port number to match the source port on the // datagram that was received. This ensures that we send subsequent @@ -425,10 +423,10 @@ impl Paths { stats.retire_connection_id += 1; } - // Write out any ACK_FREQUENCY frames. - self.primary() - .borrow_mut() - .write_cc_frames(builder, tokens, stats); + if let Some(path) = self.primary() { + // Write out any ACK_FREQUENCY frames. + path.borrow_mut().write_cc_frames(builder, tokens, stats); + } } pub fn lost_retire_cid(&mut self, lost: u64) { @@ -440,11 +438,15 @@ impl Paths { } pub fn lost_ack_frequency(&mut self, lost: &AckRate) { - self.primary().borrow_mut().lost_ack_frequency(lost); + if let Some(path) = self.primary() { + path.borrow_mut().lost_ack_frequency(lost); + } } pub fn acked_ack_frequency(&mut self, acked: &AckRate) { - self.primary().borrow_mut().acked_ack_frequency(acked); + if let Some(path) = self.primary() { + path.borrow_mut().acked_ack_frequency(acked); + } } /// Get an estimate of the RTT on the primary path. @@ -454,7 +456,7 @@ impl Paths { // make a new RTT esimate and interrogate that. // That is more expensive, but it should be rare and breaking encapsulation // is worse, especially as this is only used in tests. - self.primary_fallible() + self.primary() .map_or(RttEstimate::default().estimate(), |p| { p.borrow().rtt().estimate() }) @@ -532,8 +534,6 @@ pub struct Path { rtt: RttEstimate, /// A packet sender for the path, which includes congestion control and a pacer. sender: PacketSender, - /// The DSCP/ECN marking to use for outgoing packets on this path. - tos: IpTos, /// The IP TTL to use for outgoing packets on this path. ttl: u8, @@ -543,7 +543,8 @@ pub struct Path { received_bytes: usize, /// The number of bytes sent on this path. sent_bytes: usize, - + /// The ECN-related state for this path (see RFC9000, Section 13.4 and Appendix A.4) + ecn_info: EcnInfo, /// For logging of events. qlog: NeqoQlog, } @@ -572,14 +573,23 @@ impl Path { challenge: None, rtt: RttEstimate::default(), sender, - tos: IpTos::default(), // TODO: Default to Ect0 when ECN is supported. - ttl: 64, // This is the default TTL on many OSes. + ttl: 64, // This is the default TTL on many OSes. received_bytes: 0, sent_bytes: 0, + ecn_info: EcnInfo::default(), qlog, } } + pub fn set_ecn_baseline(&mut self, baseline: EcnCount) { + self.ecn_info.set_baseline(baseline); + } + + /// Return the DSCP/ECN marking to use for outgoing packets on this path. + pub fn tos(&self) -> IpTos { + self.ecn_info.ecn_mark().into() + } + /// Whether this path is the primary or current path for the connection. pub fn is_primary(&self) -> bool { self.primary @@ -695,8 +705,9 @@ impl Path { } /// Make a datagram. - pub fn datagram<V: Into<Vec<u8>>>(&self, payload: V) -> Datagram { - Datagram::new(self.local, self.remote, self.tos, Some(self.ttl), payload) + pub fn datagram<V: Into<Vec<u8>>>(&mut self, payload: V) -> Datagram { + self.ecn_info.on_packet_sent(); + Datagram::new(self.local, self.remote, self.tos(), Some(self.ttl), payload) } /// Get local address as `SocketAddr` @@ -959,8 +970,24 @@ impl Path { } /// Record packets as acknowledged with the sender. - pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], now: Instant) { + pub fn on_packets_acked( + &mut self, + acked_pkts: &[SentPacket], + ack_ecn: Option<EcnCount>, + now: Instant, + ) { debug_assert!(self.is_primary()); + + let ecn_ce_received = self.ecn_info.on_packets_acked(acked_pkts, ack_ecn); + if ecn_ce_received { + let cwnd_reduced = self + .sender + .on_ecn_ce_received(acked_pkts.first().expect("must be there")); + if cwnd_reduced { + self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu()); + } + } + self.sender.on_packets_acked(acked_pkts, &self.rtt, now); } diff --git a/third_party/rust/neqo-transport/src/qlog.rs b/third_party/rust/neqo-transport/src/qlog.rs index a8ad986d2a..715ba85e81 100644 --- a/third_party/rust/neqo-transport/src/qlog.rs +++ b/third_party/rust/neqo-transport/src/qlog.rs @@ -11,7 +11,7 @@ use std::{ time::Duration, }; -use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder}; +use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder, IpTosEcn}; use qlog::events::{ connectivity::{ConnectionStarted, ConnectionState, ConnectionStateUpdated}, quic::{ @@ -205,7 +205,7 @@ pub fn packet_sent( let mut frames = SmallVec::new(); while d.remaining() > 0 { if let Ok(f) = Frame::decode(&mut d) { - frames.push(QuicFrame::from(&f)); + frames.push(QuicFrame::from(f)); } else { qinfo!("qlog: invalid frame"); break; @@ -293,7 +293,7 @@ pub fn packet_received( while d.remaining() > 0 { if let Ok(f) = Frame::decode(&mut d) { - frames.push(QuicFrame::from(&f)); + frames.push(QuicFrame::from(f)); } else { qinfo!("qlog: invalid frame"); break; @@ -387,21 +387,26 @@ pub fn metrics_updated(qlog: &mut NeqoQlog, updated_metrics: &[QlogMetric]) { #[allow(clippy::too_many_lines)] // Yeah, but it's a nice match. #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] // No choice here. -impl From<&Frame<'_>> for QuicFrame { - fn from(frame: &Frame) -> Self { +impl From<Frame<'_>> for QuicFrame { + fn from(frame: Frame) -> Self { match frame { - // TODO: Add payload length to `QuicFrame::Padding` once - // https://github.com/cloudflare/quiche/pull/1745 is available via the qlog crate. - Frame::Padding { .. } => QuicFrame::Padding, - Frame::Ping => QuicFrame::Ping, + Frame::Padding(len) => QuicFrame::Padding { + length: None, + payload_length: u32::from(len), + }, + Frame::Ping => QuicFrame::Ping { + length: None, + payload_length: None, + }, Frame::Ack { largest_acknowledged, ack_delay, first_ack_range, ack_ranges, + ecn_count, } => { let ranges = - Frame::decode_ack_frame(*largest_acknowledged, *first_ack_range, ack_ranges) + Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges) .ok(); let acked_ranges = ranges.map(|all| { @@ -413,11 +418,13 @@ impl From<&Frame<'_>> for QuicFrame { }); QuicFrame::Ack { - ack_delay: Some(*ack_delay as f32 / 1000.0), + ack_delay: Some(ack_delay as f32 / 1000.0), acked_ranges, - ect1: None, - ect0: None, - ce: None, + ect1: ecn_count.map(|c| c[IpTosEcn::Ect1]), + ect0: ecn_count.map(|c| c[IpTosEcn::Ect0]), + ce: ecn_count.map(|c| c[IpTosEcn::Ce]), + length: None, + payload_length: None, } } Frame::ResetStream { @@ -426,18 +433,22 @@ impl From<&Frame<'_>> for QuicFrame { final_size, } => QuicFrame::ResetStream { stream_id: stream_id.as_u64(), - error_code: *application_error_code, - final_size: *final_size, + error_code: application_error_code, + final_size, + length: None, + payload_length: None, }, Frame::StopSending { stream_id, application_error_code, } => QuicFrame::StopSending { stream_id: stream_id.as_u64(), - error_code: *application_error_code, + error_code: application_error_code, + length: None, + payload_length: None, }, Frame::Crypto { offset, data } => QuicFrame::Crypto { - offset: *offset, + offset, length: data.len() as u64, }, Frame::NewToken { token } => QuicFrame::NewToken { @@ -459,20 +470,20 @@ impl From<&Frame<'_>> for QuicFrame { .. } => QuicFrame::Stream { stream_id: stream_id.as_u64(), - offset: *offset, + offset, length: data.len() as u64, - fin: Some(*fin), + fin: Some(fin), raw: None, }, Frame::MaxData { maximum_data } => QuicFrame::MaxData { - maximum: *maximum_data, + maximum: maximum_data, }, Frame::MaxStreamData { stream_id, maximum_stream_data, } => QuicFrame::MaxStreamData { stream_id: stream_id.as_u64(), - maximum: *maximum_stream_data, + maximum: maximum_stream_data, }, Frame::MaxStreams { stream_type, @@ -482,15 +493,15 @@ impl From<&Frame<'_>> for QuicFrame { NeqoStreamType::BiDi => StreamType::Bidirectional, NeqoStreamType::UniDi => StreamType::Unidirectional, }, - maximum: *maximum_streams, + maximum: maximum_streams, }, - Frame::DataBlocked { data_limit } => QuicFrame::DataBlocked { limit: *data_limit }, + Frame::DataBlocked { data_limit } => QuicFrame::DataBlocked { limit: data_limit }, Frame::StreamDataBlocked { stream_id, stream_data_limit, } => QuicFrame::StreamDataBlocked { stream_id: stream_id.as_u64(), - limit: *stream_data_limit, + limit: stream_data_limit, }, Frame::StreamsBlocked { stream_type, @@ -500,7 +511,7 @@ impl From<&Frame<'_>> for QuicFrame { NeqoStreamType::BiDi => StreamType::Bidirectional, NeqoStreamType::UniDi => StreamType::Unidirectional, }, - limit: *stream_limit, + limit: stream_limit, }, Frame::NewConnectionId { sequence_number, @@ -508,14 +519,14 @@ impl From<&Frame<'_>> for QuicFrame { connection_id, stateless_reset_token, } => QuicFrame::NewConnectionId { - sequence_number: *sequence_number as u32, - retire_prior_to: *retire_prior as u32, + sequence_number: sequence_number as u32, + retire_prior_to: retire_prior as u32, connection_id_length: Some(connection_id.len() as u8), connection_id: hex(connection_id), stateless_reset_token: Some(hex(stateless_reset_token)), }, Frame::RetireConnectionId { sequence_number } => QuicFrame::RetireConnectionId { - sequence_number: *sequence_number as u32, + sequence_number: sequence_number as u32, }, Frame::PathChallenge { data } => QuicFrame::PathChallenge { data: Some(hex(data)), @@ -534,8 +545,8 @@ impl From<&Frame<'_>> for QuicFrame { }, error_code: Some(error_code.code()), error_code_value: Some(0), - reason: Some(String::from_utf8_lossy(reason_phrase).to_string()), - trigger_frame_type: Some(*frame_type), + reason: Some(reason_phrase), + trigger_frame_type: Some(frame_type), }, Frame::HandshakeDone => QuicFrame::HandshakeDone, Frame::AckFrequency { .. } => QuicFrame::Unknown { diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs index dbea3aaf57..22a635d9f3 100644 --- a/third_party/rust/neqo-transport/src/recovery.rs +++ b/third_party/rust/neqo-transport/src/recovery.rs @@ -21,6 +21,7 @@ use crate::{ ackrate::AckRate, cid::ConnectionIdEntry, crypto::CryptoRecoveryToken, + ecn::EcnCount, packet::PacketNumber, path::{Path, PathRef}, qlog::{self, QlogMetric}, @@ -665,12 +666,14 @@ impl LossRecovery { } /// Returns (acked packets, lost packets) + #[allow(clippy::too_many_arguments)] pub fn on_ack_received<R>( &mut self, primary_path: &PathRef, pn_space: PacketNumberSpace, largest_acked: u64, acked_ranges: R, + ack_ecn: Option<EcnCount>, ack_delay: Duration, now: Instant, ) -> (Vec<SentPacket>, Vec<SentPacket>) @@ -692,10 +695,10 @@ impl LossRecovery { let (acked_packets, any_ack_eliciting) = space.remove_acked(acked_ranges, &mut self.stats.borrow_mut()); - if acked_packets.is_empty() { + let Some(largest_acked_pkt) = acked_packets.first() else { // No new information. return (Vec::new(), Vec::new()); - } + }; // Track largest PN acked per space let prev_largest_acked = space.largest_acked_sent_time; @@ -704,7 +707,6 @@ impl LossRecovery { // If the largest acknowledged is newly acked and any newly acked // packet was ack-eliciting, update the RTT. (-recovery 5.1) - let largest_acked_pkt = acked_packets.first().expect("must be there"); space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent); if any_ack_eliciting && largest_acked_pkt.on_primary_path() { self.rtt_sample( @@ -744,7 +746,7 @@ impl LossRecovery { // when it shouldn't. primary_path .borrow_mut() - .on_packets_acked(&acked_packets, now); + .on_packets_acked(&acked_packets, ack_ecn, now); self.pto_state = None; @@ -1022,7 +1024,7 @@ mod tests { time::{Duration, Instant}, }; - use neqo_common::qlog::NeqoQlog; + use neqo_common::{qlog::NeqoQlog, IpTosEcn}; use test_fixture::{now, DEFAULT_ADDR}; use super::{ @@ -1031,6 +1033,7 @@ mod tests { use crate::{ cc::CongestionControlAlgorithm, cid::{ConnectionId, ConnectionIdEntry}, + ecn::EcnCount, packet::PacketType, path::{Path, PathRef}, rtt::RttEstimate, @@ -1060,6 +1063,7 @@ mod tests { pn_space: PacketNumberSpace, largest_acked: u64, acked_ranges: Vec<RangeInclusive<u64>>, + ack_ecn: Option<EcnCount>, ack_delay: Duration, now: Instant, ) -> (Vec<SentPacket>, Vec<SentPacket>) { @@ -1068,6 +1072,7 @@ mod tests { pn_space, largest_acked, acked_ranges, + ack_ecn, ack_delay, now, ) @@ -1208,6 +1213,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Short, pn, + IpTosEcn::default(), pn_time(pn), true, Vec::new(), @@ -1223,6 +1229,7 @@ mod tests { PacketNumberSpace::ApplicationData, pn, vec![pn..=pn], + None, ACK_DELAY, pn_time(pn) + delay, ); @@ -1233,6 +1240,7 @@ mod tests { lrs.on_packet_sent(SentPacket::new( PacketType::Short, pn, + IpTosEcn::default(), pn_time(pn), true, Vec::new(), @@ -1353,6 +1361,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Short, 0, + IpTosEcn::default(), pn_time(0), true, Vec::new(), @@ -1361,6 +1370,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Short, 1, + IpTosEcn::default(), pn_time(0) + TEST_RTT / 4, true, Vec::new(), @@ -1370,6 +1380,7 @@ mod tests { PacketNumberSpace::ApplicationData, 1, vec![1..=1], + None, ACK_DELAY, pn_time(0) + (TEST_RTT * 5 / 4), ); @@ -1393,6 +1404,7 @@ mod tests { PacketNumberSpace::ApplicationData, 2, vec![2..=2], + None, ACK_DELAY, pn2_ack_time, ); @@ -1422,6 +1434,7 @@ mod tests { PacketNumberSpace::ApplicationData, 4, vec![2..=4], + None, ACK_DELAY, pn_time(4), ); @@ -1450,6 +1463,7 @@ mod tests { PacketNumberSpace::Initial, 0, vec![], + None, Duration::from_millis(0), pn_time(0), ); @@ -1463,6 +1477,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Initial, 0, + IpTosEcn::default(), pn_time(0), true, Vec::new(), @@ -1471,6 +1486,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Handshake, 0, + IpTosEcn::default(), pn_time(1), true, Vec::new(), @@ -1479,6 +1495,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Short, 0, + IpTosEcn::default(), pn_time(2), true, Vec::new(), @@ -1491,10 +1508,25 @@ mod tests { PacketType::Handshake, PacketType::Short, ] { - let sent_pkt = SentPacket::new(*sp, 1, pn_time(3), true, Vec::new(), ON_SENT_SIZE); + let sent_pkt = SentPacket::new( + *sp, + 1, + IpTosEcn::default(), + pn_time(3), + true, + Vec::new(), + ON_SENT_SIZE, + ); let pn_space = PacketNumberSpace::from(sent_pkt.pt); lr.on_packet_sent(sent_pkt); - lr.on_ack_received(pn_space, 1, vec![1..=1], Duration::from_secs(0), pn_time(3)); + lr.on_ack_received( + pn_space, + 1, + vec![1..=1], + None, + Duration::from_secs(0), + pn_time(3), + ); let mut lost = Vec::new(); lr.spaces.get_mut(pn_space).unwrap().detect_lost_packets( pn_time(3), @@ -1516,6 +1548,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Initial, 0, + IpTosEcn::default(), pn_time(3), true, Vec::new(), @@ -1530,6 +1563,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Initial, 0, + IpTosEcn::default(), now(), true, Vec::new(), @@ -1542,6 +1576,7 @@ mod tests { PacketNumberSpace::Initial, 0, vec![0..=0], + None, Duration::new(0, 0), now() + rtt, ); @@ -1549,6 +1584,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Handshake, 0, + IpTosEcn::default(), now(), true, Vec::new(), @@ -1557,6 +1593,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Short, 0, + IpTosEcn::default(), now(), true, Vec::new(), @@ -1594,6 +1631,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Initial, 1, + IpTosEcn::default(), now(), true, Vec::new(), diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs index 8771ec7765..98476e9d18 100644 --- a/third_party/rust/neqo-transport/src/send_stream.rs +++ b/third_party/rust/neqo-transport/src/send_stream.rs @@ -1269,7 +1269,7 @@ impl SendStream { return Err(Error::FinalSizeError); } - let buf = if buf.is_empty() || (self.avail() == 0) { + let buf = if self.avail() == 0 { return Ok(0); } else if self.avail() < buf.len() { if atomic { @@ -1634,20 +1634,16 @@ impl SendStreams { } pub fn remove_terminal(&mut self) { - let map: &mut IndexMap<StreamId, SendStream> = &mut self.map; - let regular: &mut OrderGroup = &mut self.regular; - let sendordered: &mut BTreeMap<SendOrder, OrderGroup> = &mut self.sendordered; - - // Take refs to all the items we need to modify instead of &mut - // self to keep the compiler happy (if we use self.map.retain it - // gets upset due to borrows) - map.retain(|stream_id, stream| { + self.map.retain(|stream_id, stream| { if stream.is_terminal() { if stream.is_fair() { match stream.sendorder() { - None => regular.remove(*stream_id), + None => self.regular.remove(*stream_id), Some(sendorder) => { - sendordered.get_mut(&sendorder).unwrap().remove(*stream_id); + self.sendordered + .get_mut(&sendorder) + .unwrap() + .remove(*stream_id); } }; } diff --git a/third_party/rust/neqo-transport/src/sender.rs b/third_party/rust/neqo-transport/src/sender.rs index 3a54851533..abb14d0a25 100644 --- a/third_party/rust/neqo-transport/src/sender.rs +++ b/third_party/rust/neqo-transport/src/sender.rs @@ -97,6 +97,11 @@ impl PacketSender { ) } + /// Called when ECN CE mark received. Returns true if the congestion window was reduced. + pub fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool { + self.cc.on_ecn_ce_received(largest_acked_pkt) + } + pub fn discard(&mut self, pkt: &SentPacket) { self.cc.discard(pkt); } diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs index 96a6244ef1..60909d71e1 100644 --- a/third_party/rust/neqo-transport/src/server.rs +++ b/third_party/rust/neqo-transport/src/server.rs @@ -689,6 +689,13 @@ impl Server { mem::take(&mut self.active).into_iter().collect() } + /// Whether any connections have received new events as a result of calling + /// `process()`. + #[must_use] + pub fn has_active_connections(&self) -> bool { + !self.active.is_empty() + } + pub fn add_to_waiting(&mut self, c: &ActiveConnectionRef) { self.waiting.push_back(c.connection()); } diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs index bdd0f250c7..6643d516e3 100644 --- a/third_party/rust/neqo-transport/src/tracking.rs +++ b/third_party/rust/neqo-transport/src/tracking.rs @@ -13,18 +13,21 @@ use std::{ time::{Duration, Instant}, }; -use neqo_common::{qdebug, qinfo, qtrace, qwarn}; +use enum_map::Enum; +use neqo_common::{qdebug, qinfo, qtrace, qwarn, IpTosEcn}; use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL}; use smallvec::{smallvec, SmallVec}; use crate::{ + ecn::EcnCount, + frame::{FRAME_TYPE_ACK, FRAME_TYPE_ACK_ECN}, packet::{PacketBuilder, PacketNumber, PacketType}, recovery::RecoveryToken, stats::FrameStats, }; // TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574 -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq, Enum)] pub enum PacketNumberSpace { Initial, Handshake, @@ -134,6 +137,7 @@ impl std::fmt::Debug for PacketNumberSpaceSet { pub struct SentPacket { pub pt: PacketType, pub pn: PacketNumber, + pub ecn_mark: IpTosEcn, ack_eliciting: bool, pub time_sent: Instant, primary_path: bool, @@ -150,6 +154,7 @@ impl SentPacket { pub fn new( pt: PacketType, pn: PacketNumber, + ecn_mark: IpTosEcn, time_sent: Instant, ack_eliciting: bool, tokens: Vec<RecoveryToken>, @@ -158,6 +163,7 @@ impl SentPacket { Self { pt, pn, + ecn_mark, time_sent, ack_eliciting, primary_path: true, @@ -377,6 +383,8 @@ pub struct RecvdPackets { /// Whether we are ignoring packets that arrive out of order /// for the purposes of generating immediate acknowledgment. ignore_order: bool, + // The counts of different ECN marks that have been received. + ecn_count: EcnCount, } impl RecvdPackets { @@ -394,9 +402,15 @@ impl RecvdPackets { unacknowledged_count: 0, unacknowledged_tolerance: DEFAULT_ACK_PACKET_TOLERANCE, ignore_order: false, + ecn_count: EcnCount::default(), } } + /// Get the ECN counts. + pub fn ecn_marks(&mut self) -> &mut EcnCount { + &mut self.ecn_count + } + /// Get the time at which the next ACK should be sent. pub fn ack_time(&self) -> Option<Instant> { self.ack_time @@ -545,6 +559,10 @@ impl RecvdPackets { } } + /// Length of the worst possible ACK frame, assuming only one range and ECN counts. + /// Note that this assumes one byte for the type and count of extra ranges. + pub const USEFUL_ACK_LEN: usize = 1 + 8 + 8 + 1 + 8 + 3 * 8; + /// Generate an ACK frame for this packet number space. /// /// Unlike other frame generators this doesn't modify the underlying instance @@ -563,10 +581,6 @@ impl RecvdPackets { tokens: &mut Vec<RecoveryToken>, stats: &mut FrameStats, ) { - // The worst possible ACK frame, assuming only one range. - // Note that this assumes one byte for the type and count of extra ranges. - const LONGEST_ACK_HEADER: usize = 1 + 8 + 8 + 1 + 8; - // Check that we aren't delaying ACKs. if !self.ack_now(now, rtt) { return; @@ -578,7 +592,10 @@ impl RecvdPackets { // When congestion limited, ACK-only packets are 255 bytes at most // (`recovery::ACK_ONLY_SIZE_LIMIT - 1`). This results in limiting the // ranges to 13 here. - let max_ranges = if let Some(avail) = builder.remaining().checked_sub(LONGEST_ACK_HEADER) { + let max_ranges = if let Some(avail) = builder + .remaining() + .checked_sub(RecvdPackets::USEFUL_ACK_LEN) + { // Apply a hard maximum to keep plenty of space for other stuff. min(1 + (avail / 16), MAX_ACKS_PER_FRAME) } else { @@ -593,7 +610,11 @@ impl RecvdPackets { .cloned() .collect::<Vec<_>>(); - builder.encode_varint(crate::frame::FRAME_TYPE_ACK); + builder.encode_varint(if self.ecn_count.is_some() { + FRAME_TYPE_ACK_ECN + } else { + FRAME_TYPE_ACK + }); let mut iter = ranges.iter(); let Some(first) = iter.next() else { return }; builder.encode_varint(first.largest); @@ -617,6 +638,12 @@ impl RecvdPackets { last = r.smallest; } + if self.ecn_count.is_some() { + builder.encode_varint(self.ecn_count[IpTosEcn::Ect0]); + builder.encode_varint(self.ecn_count[IpTosEcn::Ect1]); + builder.encode_varint(self.ecn_count[IpTosEcn::Ce]); + } + // We've sent an ACK, reset the timer. self.ack_time = None; self.last_ack_time = Some(now); @@ -1134,7 +1161,9 @@ mod tests { .is_some()); let mut builder = PacketBuilder::short(Encoder::new(), false, []); - builder.set_limit(32); + // The code pessimistically assumes that each range needs 16 bytes to express. + // So this won't be enough for a second range. + builder.set_limit(RecvdPackets::USEFUL_ACK_LEN + 8); let mut stats = FrameStats::default(); tracker.write_frame( |