summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-transport/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/neqo-transport/src')
-rw-r--r--third_party/rust/neqo-transport/src/cc/classic_cc.rs105
-rw-r--r--third_party/rust/neqo-transport/src/cc/mod.rs3
-rw-r--r--third_party/rust/neqo-transport/src/cc/tests/cubic.rs34
-rw-r--r--third_party/rust/neqo-transport/src/cc/tests/new_reno.rs89
-rw-r--r--third_party/rust/neqo-transport/src/connection/mod.rs285
-rw-r--r--third_party/rust/neqo-transport/src/connection/state.rs39
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/cc.rs36
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/close.rs25
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/datagram.rs14
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/ecn.rs392
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/handshake.rs26
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/keys.rs10
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/migration.rs8
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/mod.rs31
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/stream.rs9
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/vn.rs8
-rw-r--r--third_party/rust/neqo-transport/src/ecn.rs225
-rw-r--r--third_party/rust/neqo-transport/src/events.rs4
-rw-r--r--third_party/rust/neqo-transport/src/frame.rs143
-rw-r--r--third_party/rust/neqo-transport/src/lib.rs27
-rw-r--r--third_party/rust/neqo-transport/src/packet/mod.rs1
-rw-r--r--third_party/rust/neqo-transport/src/path.rs77
-rw-r--r--third_party/rust/neqo-transport/src/qlog.rs73
-rw-r--r--third_party/rust/neqo-transport/src/recovery.rs52
-rw-r--r--third_party/rust/neqo-transport/src/send_stream.rs18
-rw-r--r--third_party/rust/neqo-transport/src/sender.rs5
-rw-r--r--third_party/rust/neqo-transport/src/server.rs7
-rw-r--r--third_party/rust/neqo-transport/src/tracking.rs47
28 files changed, 1371 insertions, 422 deletions
diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs
index f8bcee6722..6914e91f67 100644
--- a/third_party/rust/neqo-transport/src/cc/classic_cc.rs
+++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs
@@ -298,6 +298,14 @@ impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> {
congestion || persistent_congestion
}
+ /// Report received ECN CE mark(s) to the congestion controller as a
+ /// congestion event.
+ ///
+ /// See <https://datatracker.ietf.org/doc/html/rfc9002#section-b.7>.
+ fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool {
+ self.on_congestion_event(largest_acked_pkt)
+ }
+
fn discard(&mut self, pkt: &SentPacket) {
if pkt.cc_outstanding() {
assert!(self.bytes_in_flight >= pkt.size);
@@ -488,8 +496,8 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {
/// Handle a congestion event.
/// Returns true if this was a true congestion event.
fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool {
- // Start a new congestion event if lost packet was sent after the start
- // of the previous congestion recovery period.
+ // Start a new congestion event if lost or ECN CE marked packet was sent
+ // after the start of the previous congestion recovery period.
if !self.after_recovery_start(last_packet) {
return false;
}
@@ -538,7 +546,7 @@ impl<T: WindowAdjustment> ClassicCongestionControl<T> {
mod tests {
use std::time::{Duration, Instant};
- use neqo_common::qinfo;
+ use neqo_common::{qinfo, IpTosEcn};
use test_fixture::now;
use super::{
@@ -582,6 +590,7 @@ mod tests {
SentPacket::new(
PacketType::Short,
pn,
+ IpTosEcn::default(),
now() + t,
ack_eliciting,
Vec::new(),
@@ -795,6 +804,7 @@ mod tests {
SentPacket::new(
PacketType::Short,
u64::try_from(i).unwrap(),
+ IpTosEcn::default(),
by_pto(t),
true,
Vec::new(),
@@ -915,6 +925,7 @@ mod tests {
lost[0] = SentPacket::new(
lost[0].pt,
lost[0].pn,
+ lost[0].ecn_mark,
lost[0].time_sent,
false,
Vec::new(),
@@ -1015,11 +1026,12 @@ mod tests {
for _ in 0..packet_burst_size {
let p = SentPacket::new(
PacketType::Short,
- next_pn, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ next_pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
next_pn += 1;
cc.on_packet_sent(&p);
@@ -1039,11 +1051,12 @@ mod tests {
for _ in 0..ABOVE_APP_LIMIT_PKTS {
let p = SentPacket::new(
PacketType::Short,
- next_pn, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ next_pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
next_pn += 1;
cc.on_packet_sent(&p);
@@ -1082,11 +1095,12 @@ mod tests {
let p_lost = SentPacket::new(
PacketType::Short,
- 1, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ 1,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
cc.on_packet_sent(&p_lost);
cwnd_is_default(&cc);
@@ -1095,11 +1109,12 @@ mod tests {
cwnd_is_halved(&cc);
let p_not_lost = SentPacket::new(
PacketType::Short,
- 2, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ 2,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
cc.on_packet_sent(&p_not_lost);
now += RTT;
@@ -1118,11 +1133,12 @@ mod tests {
for _ in 0..packet_burst_size {
let p = SentPacket::new(
PacketType::Short,
- next_pn, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ next_pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
next_pn += 1;
cc.on_packet_sent(&p);
@@ -1148,11 +1164,12 @@ mod tests {
for _ in 0..ABOVE_APP_LIMIT_PKTS {
let p = SentPacket::new(
PacketType::Short,
- next_pn, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ next_pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
next_pn += 1;
cc.on_packet_sent(&p);
@@ -1180,4 +1197,26 @@ mod tests {
last_acked_bytes = cc.acked_bytes;
}
}
+
+ #[test]
+ fn ecn_ce() {
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ let p_ce = SentPacket::new(
+ PacketType::Short,
+ 1,
+ IpTosEcn::default(),
+ now(),
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
+ );
+ cc.on_packet_sent(&p_ce);
+ cwnd_is_default(&cc);
+ assert_eq!(cc.state, State::SlowStart);
+
+ // Signal congestion (ECN CE) and thus change state to recovery start.
+ cc.on_ecn_ce_received(&p_ce);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.state, State::RecoveryStart);
+ }
}
diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs
index 486d15e67e..2adffbc0c4 100644
--- a/third_party/rust/neqo-transport/src/cc/mod.rs
+++ b/third_party/rust/neqo-transport/src/cc/mod.rs
@@ -53,6 +53,9 @@ pub trait CongestionControl: Display + Debug {
lost_packets: &[SentPacket],
) -> bool;
+ /// Returns true if the congestion window was reduced.
+ fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool;
+
#[must_use]
fn recovery_packet(&self) -> bool;
diff --git a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs
index 2e0200fd6d..8ff591cb47 100644
--- a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs
+++ b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs
@@ -12,6 +12,7 @@ use std::{
time::{Duration, Instant},
};
+use neqo_common::IpTosEcn;
use test_fixture::now;
use crate::{
@@ -41,11 +42,12 @@ fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: In
while cc.bytes_in_flight() < cc.cwnd() {
let sent = SentPacket::new(
PacketType::Short,
- next_pn, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ next_pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
cc.on_packet_sent(&sent);
next_pn += 1;
@@ -56,11 +58,12 @@ fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: In
fn ack_packet(cc: &mut ClassicCongestionControl<Cubic>, pn: u64, now: Instant) {
let acked = SentPacket::new(
PacketType::Short,
- pn, // pn
- now, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
cc.on_packets_acked(&[acked], &RTT_ESTIMATE, now);
}
@@ -69,11 +72,12 @@ fn packet_lost(cc: &mut ClassicCongestionControl<Cubic>, pn: u64) {
const PTO: Duration = Duration::from_millis(120);
let p_lost = SentPacket::new(
PacketType::Short,
- pn, // pn
- now(), // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ pn,
+ IpTosEcn::default(),
+ now(),
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
cc.on_packets_lost(None, None, PTO, &[p_lost]);
}
diff --git a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs
index 4cc20de5a7..0cc560bf2b 100644
--- a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs
+++ b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs
@@ -8,6 +8,7 @@
use std::time::Duration;
+use neqo_common::IpTosEcn;
use test_fixture::now;
use crate::{
@@ -44,59 +45,66 @@ fn issue_876() {
let sent_packets = &[
SentPacket::new(
PacketType::Short,
- 1, // pn
- time_before, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE - 1, // size
+ 1,
+ IpTosEcn::default(),
+ time_before,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE - 1,
),
SentPacket::new(
PacketType::Short,
- 2, // pn
- time_before, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE - 2, // size
+ 2,
+ IpTosEcn::default(),
+ time_before,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE - 2,
),
SentPacket::new(
PacketType::Short,
- 3, // pn
- time_before, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ 3,
+ IpTosEcn::default(),
+ time_before,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
),
SentPacket::new(
PacketType::Short,
- 4, // pn
- time_before, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ 4,
+ IpTosEcn::default(),
+ time_before,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
),
SentPacket::new(
PacketType::Short,
- 5, // pn
- time_before, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ 5,
+ IpTosEcn::default(),
+ time_before,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
),
SentPacket::new(
PacketType::Short,
- 6, // pn
- time_before, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ 6,
+ IpTosEcn::default(),
+ time_before,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
),
SentPacket::new(
PacketType::Short,
- 7, // pn
- time_after, // time sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE - 3, // size
+ 7,
+ IpTosEcn::default(),
+ time_after,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE - 3,
),
];
@@ -146,11 +154,12 @@ fn issue_1465() {
let mut next_packet = |now| {
let p = SentPacket::new(
PacketType::Short,
- pn, // pn
- now, // time_sent
- true, // ack eliciting
- Vec::new(), // tokens
- MAX_DATAGRAM_SIZE, // size
+ pn,
+ IpTosEcn::default(),
+ now,
+ true,
+ Vec::new(),
+ MAX_DATAGRAM_SIZE,
);
pn += 1;
p
diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs
index 8522507a69..f955381414 100644
--- a/third_party/rust/neqo-transport/src/connection/mod.rs
+++ b/third_party/rust/neqo-transport/src/connection/mod.rs
@@ -19,7 +19,7 @@ use std::{
use neqo_common::{
event::Provider as EventProvider, hex, hex_snip_middle, hrtime, qdebug, qerror, qinfo,
- qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, IpTos, Role,
+ qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role,
};
use neqo_crypto::{
agent::CertificateInfo, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group,
@@ -35,6 +35,7 @@ use crate::{
ConnectionIdRef, ConnectionIdStore, LOCAL_ACTIVE_CID_LIMIT,
},
crypto::{Crypto, CryptoDxState, CryptoSpace},
+ ecn::EcnCount,
events::{ConnectionEvent, ConnectionEvents, OutgoingDatagramOutcome},
frame::{
CloseError, Frame, FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION,
@@ -46,7 +47,7 @@ use crate::{
quic_datagrams::{DatagramTracking, QuicDatagrams},
recovery::{LossRecovery, RecoveryToken, SendProfile},
recv_stream::RecvStreamStats,
- rtt::GRANULARITY,
+ rtt::{RttEstimate, GRANULARITY},
send_stream::SendStream,
stats::{Stats, StatsCell},
stream_id::StreamType,
@@ -55,9 +56,9 @@ use crate::{
self, TransportParameter, TransportParameterId, TransportParameters,
TransportParametersHandler,
},
- tracking::{AckTracker, PacketNumberSpace, SentPacket},
+ tracking::{AckTracker, PacketNumberSpace, RecvdPackets, SentPacket},
version::{Version, WireVersion},
- AppError, ConnectionError, Error, Res, StreamId,
+ AppError, CloseReason, Error, Res, StreamId,
};
mod dump;
@@ -291,7 +292,7 @@ impl Debug for Connection {
"{:?} Connection: {:?} {:?}",
self.role,
self.state,
- self.paths.primary_fallible()
+ self.paths.primary()
)
}
}
@@ -591,7 +592,11 @@ impl Connection {
fn make_resumption_token(&mut self) -> ResumptionToken {
debug_assert_eq!(self.role, Role::Client);
debug_assert!(self.crypto.has_resumption_token());
- let rtt = self.paths.primary().borrow().rtt().estimate();
+ let rtt = self.paths.primary().map_or_else(
+ || RttEstimate::default().estimate(),
+ |p| p.borrow().rtt().estimate(),
+ );
+
self.crypto
.create_resumption_token(
self.new_token.take_token(),
@@ -610,11 +615,10 @@ impl Connection {
/// a value of this approximate order. Don't use this for loss recovery,
/// only use it where a more precise value is not important.
fn pto(&self) -> Duration {
- self.paths
- .primary()
- .borrow()
- .rtt()
- .pto(PacketNumberSpace::ApplicationData)
+ self.paths.primary().map_or_else(
+ || RttEstimate::default().pto(PacketNumberSpace::ApplicationData),
+ |p| p.borrow().rtt().pto(PacketNumberSpace::ApplicationData),
+ )
}
fn create_resumption_token(&mut self, now: Instant) {
@@ -746,7 +750,12 @@ impl Connection {
if !init_token.is_empty() {
self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec());
}
- self.paths.primary().borrow_mut().rtt_mut().set_initial(rtt);
+ self.paths
+ .primary()
+ .ok_or(Error::InternalError)?
+ .borrow_mut()
+ .rtt_mut()
+ .set_initial(rtt);
self.set_initial_limits();
// Start up TLS, which has the effect of setting up all the necessary
// state for 0-RTT. This only stages the CRYPTO frames.
@@ -786,7 +795,7 @@ impl Connection {
// If we are able, also send a NEW_TOKEN frame.
// This should be recording all remote addresses that are valid,
// but there are just 0 or 1 in the current implementation.
- if let Some(path) = self.paths.primary_fallible() {
+ if let Some(path) = self.paths.primary() {
if let Some(token) = self
.address_validation
.generate_new_token(path.borrow().remote_address(), now)
@@ -858,7 +867,7 @@ impl Connection {
#[must_use]
pub fn stats(&self) -> Stats {
let mut v = self.stats.borrow().clone();
- if let Some(p) = self.paths.primary_fallible() {
+ if let Some(p) = self.paths.primary() {
let p = p.borrow();
v.rtt = p.rtt().estimate();
v.rttvar = p.rtt().rttvar();
@@ -880,7 +889,7 @@ impl Connection {
let msg = format!("{v:?}");
#[cfg(not(debug_assertions))]
let msg = "";
- let error = ConnectionError::Transport(v.clone());
+ let error = CloseReason::Transport(v.clone());
match &self.state {
State::Closing { error: err, .. }
| State::Draining { error: err, .. }
@@ -895,14 +904,14 @@ impl Connection {
State::WaitInitial => {
// We don't have any state yet, so don't bother with
// the closing state, just send one CONNECTION_CLOSE.
- if let Some(path) = path.or_else(|| self.paths.primary_fallible()) {
+ if let Some(path) = path.or_else(|| self.paths.primary()) {
self.state_signaling
.close(path, error.clone(), frame_type, msg);
}
self.set_state(State::Closed(error));
}
_ => {
- if let Some(path) = path.or_else(|| self.paths.primary_fallible()) {
+ if let Some(path) = path.or_else(|| self.paths.primary()) {
self.state_signaling
.close(path, error.clone(), frame_type, msg);
if matches!(v, Error::KeysExhausted) {
@@ -951,9 +960,7 @@ impl Connection {
let pto = self.pto();
if self.idle_timeout.expired(now, pto) {
qinfo!([self], "idle timeout expired");
- self.set_state(State::Closed(ConnectionError::Transport(
- Error::IdleTimeout,
- )));
+ self.set_state(State::Closed(CloseReason::Transport(Error::IdleTimeout)));
return;
}
@@ -962,9 +969,11 @@ impl Connection {
let res = self.crypto.states.check_key_update(now);
self.absorb_error(now, res);
- let lost = self.loss_recovery.timeout(&self.paths.primary(), now);
- self.handle_lost_packets(&lost);
- qlog::packets_lost(&mut self.qlog, &lost);
+ if let Some(path) = self.paths.primary() {
+ let lost = self.loss_recovery.timeout(&path, now);
+ self.handle_lost_packets(&lost);
+ qlog::packets_lost(&mut self.qlog, &lost);
+ }
if self.release_resumption_token_timer.is_some() {
self.create_resumption_token(now);
@@ -1014,7 +1023,7 @@ impl Connection {
delays.push(ack_time);
}
- if let Some(p) = self.paths.primary_fallible() {
+ if let Some(p) = self.paths.primary() {
let path = p.borrow();
let rtt = path.rtt();
let pto = rtt.pto(PacketNumberSpace::ApplicationData);
@@ -1102,7 +1111,15 @@ impl Connection {
self.input(d, now, now);
self.process_saved(now);
}
- self.process_output(now)
+ #[allow(clippy::let_and_return)]
+ let output = self.process_output(now);
+ #[cfg(all(feature = "build-fuzzing-corpus", test))]
+ if self.test_frame_writer.is_none() {
+ if let Some(d) = output.clone().dgram() {
+ neqo_common::write_item_to_fuzzing_corpus("packet", &d);
+ }
+ }
+ output
}
fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) {
@@ -1123,7 +1140,13 @@ impl Connection {
}
// At this point, we should only have the connection ID that we generated.
// Update to the one that the server prefers.
- let path = self.paths.primary();
+ let Some(path) = self.paths.primary() else {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Retry without an existing path");
+ return;
+ };
+
path.borrow_mut().set_remote_cid(packet.scid());
let retry_scid = ConnectionId::from(packet.scid());
@@ -1151,8 +1174,9 @@ impl Connection {
fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) {
if self.crypto.discard(space) {
qdebug!([self], "Drop packet number space {}", space);
- let primary = self.paths.primary();
- self.loss_recovery.discard(&primary, space, now);
+ if let Some(path) = self.paths.primary() {
+ self.loss_recovery.discard(&path, space, now);
+ }
self.acks.drop_space(space);
}
}
@@ -1180,7 +1204,7 @@ impl Connection {
qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..]));
self.state_signaling.reset();
self.set_state(State::Draining {
- error: ConnectionError::Transport(Error::StatelessReset),
+ error: CloseReason::Transport(Error::StatelessReset),
timeout: self.get_closing_period_time(now),
});
Err(Error::StatelessReset)
@@ -1227,8 +1251,9 @@ impl Connection {
assert_ne!(self.version, version);
qinfo!([self], "Version negotiation: trying {:?}", version);
- let local_addr = self.paths.primary().borrow().local_address();
- let remote_addr = self.paths.primary().borrow().remote_address();
+ let path = self.paths.primary().ok_or(Error::NoAvailablePath)?;
+ let local_addr = path.borrow().local_address();
+ let remote_addr = path.borrow().remote_address();
let conn_params = self
.conn_params
.clone()
@@ -1256,7 +1281,7 @@ impl Connection {
} else {
qinfo!([self], "Version negotiation: failed with {:?}", supported);
// This error goes straight to closed.
- self.set_state(State::Closed(ConnectionError::Transport(
+ self.set_state(State::Closed(CloseReason::Transport(
Error::VersionNegotiation,
)));
Err(Error::VersionNegotiation)
@@ -1417,6 +1442,13 @@ impl Connection {
migrate: bool,
now: Instant,
) {
+ let space = PacketNumberSpace::from(packet.packet_type());
+ if let Some(space) = self.acks.get_mut(space) {
+ *space.ecn_marks() += d.tos().into();
+ } else {
+ qtrace!("Not tracking ECN for dropped packet number space");
+ }
+
if self.state == State::WaitInitial {
self.start_handshake(path, packet, now);
}
@@ -1491,6 +1523,16 @@ impl Connection {
d.tos(),
);
+ #[cfg(feature = "build-fuzzing-corpus")]
+ if packet.packet_type() == PacketType::Initial {
+ let target = if self.role == Role::Client {
+ "server_initial"
+ } else {
+ "client_initial"
+ };
+ neqo_common::write_item_to_fuzzing_corpus(target, &payload[..]);
+ }
+
qlog::packet_received(&mut self.qlog, &packet, &payload);
let space = PacketNumberSpace::from(payload.packet_type());
if self.acks.get_mut(space).unwrap().is_duplicate(payload.pn()) {
@@ -1562,7 +1604,11 @@ impl Connection {
let mut probing = true;
let mut d = Decoder::from(&packet[..]);
while d.remaining() > 0 {
+ #[cfg(feature = "build-fuzzing-corpus")]
+ let pos = d.offset();
let f = Frame::decode(&mut d)?;
+ #[cfg(feature = "build-fuzzing-corpus")]
+ neqo_common::write_item_to_fuzzing_corpus("frame", &packet[pos..d.offset()]);
ack_eliciting |= f.ack_eliciting();
probing &= f.path_probing();
let t = f.get_type();
@@ -1623,10 +1669,15 @@ impl Connection {
if let Some(cid) = self.connection_ids.next() {
self.paths.make_permanent(path, None, cid);
Ok(())
- } else if self.paths.primary().borrow().remote_cid().is_empty() {
- self.paths
- .make_permanent(path, None, ConnectionIdEntry::empty_remote());
- Ok(())
+ } else if let Some(primary) = self.paths.primary() {
+ if primary.borrow().remote_cid().is_empty() {
+ self.paths
+ .make_permanent(path, None, ConnectionIdEntry::empty_remote());
+ Ok(())
+ } else {
+ qtrace!([self], "Unable to make path permanent: {}", path.borrow());
+ Err(Error::InvalidMigration)
+ }
} else {
qtrace!([self], "Unable to make path permanent: {}", path.borrow());
Err(Error::InvalidMigration)
@@ -1719,8 +1770,10 @@ impl Connection {
// Pointless migration is pointless.
return Err(Error::InvalidMigration);
}
- let local = local.unwrap_or_else(|| self.paths.primary().borrow().local_address());
- let remote = remote.unwrap_or_else(|| self.paths.primary().borrow().remote_address());
+
+ let path = self.paths.primary().ok_or(Error::InvalidMigration)?;
+ let local = local.unwrap_or_else(|| path.borrow().local_address());
+ let remote = remote.unwrap_or_else(|| path.borrow().remote_address());
if mem::discriminant(&local.ip()) != mem::discriminant(&remote.ip()) {
// Can't mix address families.
@@ -1773,7 +1826,12 @@ impl Connection {
// has to use the existing address. So only pay attention to a preferred
// address from the same family as is currently in use. More thought will
// be needed to work out how to get addresses from a different family.
- let prev = self.paths.primary().borrow().remote_address();
+ let prev = self
+ .paths
+ .primary()
+ .ok_or(Error::NoAvailablePath)?
+ .borrow()
+ .remote_address();
let remote = match prev.ip() {
IpAddr::V4(_) => addr.ipv4().map(SocketAddr::V4),
IpAddr::V6(_) => addr.ipv6().map(SocketAddr::V6),
@@ -1937,20 +1995,15 @@ impl Connection {
}
}
- self.streams
- .write_frames(TransmissionPriority::Critical, builder, tokens, frame_stats);
- if builder.is_full() {
- return;
- }
-
- self.streams.write_frames(
+ for prio in [
+ TransmissionPriority::Critical,
TransmissionPriority::Important,
- builder,
- tokens,
- frame_stats,
- );
- if builder.is_full() {
- return;
+ ] {
+ self.streams
+ .write_frames(prio, builder, tokens, frame_stats);
+ if builder.is_full() {
+ return;
+ }
}
// NEW_CONNECTION_ID, RETIRE_CONNECTION_ID, and ACK_FREQUENCY.
@@ -1958,21 +2011,18 @@ impl Connection {
if builder.is_full() {
return;
}
- self.paths.write_frames(builder, tokens, frame_stats);
- if builder.is_full() {
- return;
- }
- self.streams
- .write_frames(TransmissionPriority::High, builder, tokens, frame_stats);
+ self.paths.write_frames(builder, tokens, frame_stats);
if builder.is_full() {
return;
}
- self.streams
- .write_frames(TransmissionPriority::Normal, builder, tokens, frame_stats);
- if builder.is_full() {
- return;
+ for prio in [TransmissionPriority::High, TransmissionPriority::Normal] {
+ self.streams
+ .write_frames(prio, builder, tokens, &mut stats.frame_tx);
+ if builder.is_full() {
+ return;
+ }
}
// Datagrams are best-effort and unreliable. Let streams starve them for now.
@@ -1981,9 +2031,9 @@ impl Connection {
return;
}
- let frame_stats = &mut stats.frame_tx;
// CRYPTO here only includes NewSessionTicket, plus NEW_TOKEN.
// Both of these are only used for resumption and so can be relatively low priority.
+ let frame_stats = &mut stats.frame_tx;
self.crypto.write_frame(
PacketNumberSpace::ApplicationData,
builder,
@@ -1993,6 +2043,7 @@ impl Connection {
if builder.is_full() {
return;
}
+
self.new_token.write_frames(builder, tokens, frame_stats);
if builder.is_full() {
return;
@@ -2002,10 +2053,8 @@ impl Connection {
.write_frames(TransmissionPriority::Low, builder, tokens, frame_stats);
#[cfg(test)]
- {
- if let Some(w) = &mut self.test_frame_writer {
- w.write_frames(builder);
- }
+ if let Some(w) = &mut self.test_frame_writer {
+ w.write_frames(builder);
}
}
@@ -2138,6 +2187,40 @@ impl Connection {
(tokens, ack_eliciting, padded)
}
+ fn write_closing_frames(
+ &mut self,
+ close: &ClosingFrame,
+ builder: &mut PacketBuilder,
+ space: PacketNumberSpace,
+ now: Instant,
+ path: &PathRef,
+ tokens: &mut Vec<RecoveryToken>,
+ ) {
+ if builder.remaining() > ClosingFrame::MIN_LENGTH + RecvdPackets::USEFUL_ACK_LEN {
+ // Include an ACK frame with the CONNECTION_CLOSE.
+ let limit = builder.limit();
+ builder.set_limit(limit - ClosingFrame::MIN_LENGTH);
+ self.acks.immediate_ack(now);
+ self.acks.write_frame(
+ space,
+ now,
+ path.borrow().rtt().estimate(),
+ builder,
+ tokens,
+ &mut self.stats.borrow_mut().frame_tx,
+ );
+ builder.set_limit(limit);
+ }
+ // CloseReason::Application is only allowed at 1RTT.
+ let sanitized = if space == PacketNumberSpace::ApplicationData {
+ None
+ } else {
+ close.sanitize()
+ };
+ sanitized.as_ref().unwrap_or(close).write_frame(builder);
+ self.stats.borrow_mut().frame_tx.connection_close += 1;
+ }
+
/// Build a datagram, possibly from multiple packets (for different PN
/// spaces) and each containing 1+ frames.
#[allow(clippy::too_many_lines)] // Yeah, that's just the way it is.
@@ -2201,17 +2284,7 @@ impl Connection {
let payload_start = builder.len();
let (mut tokens, mut ack_eliciting, mut padded) = (Vec::new(), false, false);
if let Some(ref close) = closing_frame {
- // ConnectionError::Application is only allowed at 1RTT.
- let sanitized = if *space == PacketNumberSpace::ApplicationData {
- None
- } else {
- close.sanitize()
- };
- sanitized
- .as_ref()
- .unwrap_or(close)
- .write_frame(&mut builder);
- self.stats.borrow_mut().frame_tx.connection_close += 1;
+ self.write_closing_frames(close, &mut builder, *space, now, path, &mut tokens);
} else {
(tokens, ack_eliciting, padded) =
self.write_frames(path, *space, &profile, &mut builder, now);
@@ -2229,7 +2302,7 @@ impl Connection {
pt,
pn,
&builder.as_ref()[payload_start..],
- IpTos::default(), // TODO: set from path
+ path.borrow().tos(),
);
qlog::packet_sent(
&mut self.qlog,
@@ -2251,6 +2324,7 @@ impl Connection {
let sent = SentPacket::new(
pt,
pn,
+ path.borrow().tos().into(),
now,
ack_eliciting,
tokens,
@@ -2303,7 +2377,7 @@ impl Connection {
self.loss_recovery.on_packet_sent(path, initial);
}
path.borrow_mut().add_sent(packets.len());
- Ok(SendOption::Yes(path.borrow().datagram(packets)))
+ Ok(SendOption::Yes(path.borrow_mut().datagram(packets)))
}
}
@@ -2330,7 +2404,9 @@ impl Connection {
fn client_start(&mut self, now: Instant) -> Res<()> {
qdebug!([self], "client_start");
debug_assert_eq!(self.role, Role::Client);
- qlog::client_connection_started(&mut self.qlog, &self.paths.primary());
+ if let Some(path) = self.paths.primary() {
+ qlog::client_connection_started(&mut self.qlog, &path);
+ }
qlog::client_version_information_initiated(&mut self.qlog, self.conn_params.get_versions());
self.handshake(now, self.version, PacketNumberSpace::Initial, None)?;
@@ -2351,9 +2427,9 @@ impl Connection {
/// Close the connection.
pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) {
- let error = ConnectionError::Application(app_error);
+ let error = CloseReason::Application(app_error);
let timeout = self.get_closing_period_time(now);
- if let Some(path) = self.paths.primary_fallible() {
+ if let Some(path) = self.paths.primary() {
self.state_signaling.close(path, error.clone(), 0, msg);
self.set_state(State::Closing { error, timeout });
} else {
@@ -2411,10 +2487,8 @@ impl Connection {
// That's OK, they can try guessing this.
ConnectionIdEntry::random_srt()
};
- self.paths
- .primary()
- .borrow_mut()
- .set_reset_token(reset_token);
+ let path = self.paths.primary().ok_or(Error::NoAvailablePath)?;
+ path.borrow_mut().set_reset_token(reset_token);
let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY));
let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) {
@@ -2426,11 +2500,8 @@ impl Connection {
} else {
None
};
- self.paths.primary().borrow_mut().set_ack_delay(
- max_ad,
- min_ad,
- self.conn_params.get_ack_ratio(),
- );
+ path.borrow_mut()
+ .set_ack_delay(max_ad, min_ad, self.conn_params.get_ack_ratio());
let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT);
self.cid_manager.set_limit(max_active_cids);
@@ -2673,10 +2744,18 @@ impl Connection {
ack_delay,
first_ack_range,
ack_ranges,
+ ecn_count,
} => {
let ranges =
Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?;
- self.handle_ack(space, largest_acknowledged, ranges, ack_delay, now);
+ self.handle_ack(
+ space,
+ largest_acknowledged,
+ ranges,
+ ecn_count,
+ ack_delay,
+ now,
+ );
}
Frame::Crypto { offset, data } => {
qtrace!(
@@ -2747,7 +2826,6 @@ impl Connection {
reason_phrase,
} => {
self.stats.borrow_mut().frame_rx.connection_close += 1;
- let reason_phrase = String::from_utf8_lossy(&reason_phrase);
qinfo!(
[self],
"ConnectionClose received. Error code: {:?} frame type {:x} reason {}",
@@ -2768,7 +2846,7 @@ impl Connection {
FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
)
};
- let error = ConnectionError::Transport(detail);
+ let error = CloseReason::Transport(detail);
self.state_signaling
.drain(Rc::clone(path), error.clone(), frame_type, "");
self.set_state(State::Draining {
@@ -2853,6 +2931,7 @@ impl Connection {
space: PacketNumberSpace,
largest_acknowledged: u64,
ack_ranges: R,
+ ack_ecn: Option<EcnCount>,
ack_delay: u64,
now: Instant,
) where
@@ -2861,11 +2940,15 @@ impl Connection {
{
qdebug!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges);
+ let Some(path) = self.paths.primary() else {
+ return;
+ };
let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received(
- &self.paths.primary(),
+ &path,
space,
largest_acknowledged,
ack_ranges,
+ ack_ecn,
self.decode_ack_delay(ack_delay),
now,
);
@@ -2903,8 +2986,10 @@ impl Connection {
qdebug!([self], "0-RTT rejected");
// Tell 0-RTT packets that they were "lost".
- let dropped = self.loss_recovery.drop_0rtt(&self.paths.primary(), now);
- self.handle_lost_packets(&dropped);
+ if let Some(path) = self.paths.primary() {
+ let dropped = self.loss_recovery.drop_0rtt(&path, now);
+ self.handle_lost_packets(&dropped);
+ }
self.streams.zero_rtt_rejected();
@@ -2923,7 +3008,7 @@ impl Connection {
// Remove the randomized client CID from the list of acceptable CIDs.
self.cid_manager.remove_odcid();
// Mark the path as validated, if it isn't already.
- let path = self.paths.primary();
+ let path = self.paths.primary().ok_or(Error::NoAvailablePath)?;
path.borrow_mut().set_valid(now);
// Generate a qlog event that the server connection started.
qlog::server_connection_started(&mut self.qlog, &path);
@@ -3191,7 +3276,7 @@ impl Connection {
else {
return Err(Error::NotAvailable);
};
- let path = self.paths.primary_fallible().ok_or(Error::NotAvailable)?;
+ let path = self.paths.primary().ok_or(Error::NotAvailable)?;
let mtu = path.borrow().mtu();
let encoder = Encoder::with_capacity(mtu);
diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs
index cc2f6e30d2..e76f937938 100644
--- a/third_party/rust/neqo-transport/src/connection/state.rs
+++ b/third_party/rust/neqo-transport/src/connection/state.rs
@@ -21,7 +21,7 @@ use crate::{
packet::PacketBuilder,
path::PathRef,
recovery::RecoveryToken,
- ConnectionError, Error,
+ CloseReason, Error,
};
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -42,14 +42,14 @@ pub enum State {
Connected,
Confirmed,
Closing {
- error: ConnectionError,
+ error: CloseReason,
timeout: Instant,
},
Draining {
- error: ConnectionError,
+ error: CloseReason,
timeout: Instant,
},
- Closed(ConnectionError),
+ Closed(CloseReason),
}
impl State {
@@ -67,7 +67,7 @@ impl State {
}
#[must_use]
- pub fn error(&self) -> Option<&ConnectionError> {
+ pub fn error(&self) -> Option<&CloseReason> {
if let Self::Closing { error, .. } | Self::Draining { error, .. } | Self::Closed(error) =
self
{
@@ -116,7 +116,7 @@ impl Ord for State {
#[derive(Debug, Clone)]
pub struct ClosingFrame {
path: PathRef,
- error: ConnectionError,
+ error: CloseReason,
frame_type: FrameType,
reason_phrase: Vec<u8>,
}
@@ -124,7 +124,7 @@ pub struct ClosingFrame {
impl ClosingFrame {
fn new(
path: PathRef,
- error: ConnectionError,
+ error: CloseReason,
frame_type: FrameType,
message: impl AsRef<str>,
) -> Self {
@@ -142,12 +142,12 @@ impl ClosingFrame {
}
pub fn sanitize(&self) -> Option<Self> {
- if let ConnectionError::Application(_) = self.error {
+ if let CloseReason::Application(_) = self.error {
// The default CONNECTION_CLOSE frame that is sent when an application
// error code needs to be sent in an Initial or Handshake packet.
Some(Self {
path: Rc::clone(&self.path),
- error: ConnectionError::Transport(Error::ApplicationError),
+ error: CloseReason::Transport(Error::ApplicationError),
frame_type: 0,
reason_phrase: Vec::new(),
})
@@ -156,19 +156,22 @@ impl ClosingFrame {
}
}
+ /// Length of a closing frame with a truncated `reason_length`. Allow 8 bytes for the reason
+ /// phrase to ensure that if it needs to be truncated there is still at least a few bytes of
+ /// the value.
+ pub const MIN_LENGTH: usize = 1 + 8 + 8 + 2 + 8;
+
pub fn write_frame(&self, builder: &mut PacketBuilder) {
- // Allow 8 bytes for the reason phrase to ensure that if it needs to be
- // truncated there is still at least a few bytes of the value.
- if builder.remaining() < 1 + 8 + 8 + 2 + 8 {
+ if builder.remaining() < ClosingFrame::MIN_LENGTH {
return;
}
match &self.error {
- ConnectionError::Transport(e) => {
+ CloseReason::Transport(e) => {
builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT);
builder.encode_varint(e.code());
builder.encode_varint(self.frame_type);
}
- ConnectionError::Application(code) => {
+ CloseReason::Application(code) => {
builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_APPLICATION);
builder.encode_varint(*code);
}
@@ -209,10 +212,6 @@ pub enum StateSignaling {
impl StateSignaling {
pub fn handshake_done(&mut self) {
if !matches!(self, Self::Idle) {
- debug_assert!(
- false,
- "StateSignaling must be in Idle state but is in {self:?} state.",
- );
return;
}
*self = Self::HandshakeDone;
@@ -231,7 +230,7 @@ impl StateSignaling {
pub fn close(
&mut self,
path: PathRef,
- error: ConnectionError,
+ error: CloseReason,
frame_type: FrameType,
message: impl AsRef<str>,
) {
@@ -243,7 +242,7 @@ impl StateSignaling {
pub fn drain(
&mut self,
path: PathRef,
- error: ConnectionError,
+ error: CloseReason,
frame_type: FrameType,
message: impl AsRef<str>,
) {
diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs
index b708bc421d..f21f4e184f 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/cc.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs
@@ -6,7 +6,7 @@
use std::{mem, time::Duration};
-use neqo_common::{qdebug, qinfo, Datagram};
+use neqo_common::{qdebug, qinfo, Datagram, IpTosEcn};
use super::{
super::Output, ack_bytes, assert_full_cwnd, connect_rtt_idle, cwnd, cwnd_avail, cwnd_packets,
@@ -36,9 +36,13 @@ fn cc_slow_start() {
assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT);
}
-#[test]
-/// Verify that CC moves to cong avoidance when a packet is marked lost.
-fn cc_slow_start_to_cong_avoidance_recovery_period() {
+#[derive(PartialEq, Eq, Clone, Copy)]
+enum CongestionSignal {
+ PacketLoss,
+ EcnCe,
+}
+
+fn cc_slow_start_to_cong_avoidance_recovery_period(congestion_signal: CongestionSignal) {
let mut client = default_client();
let mut server = default_server();
let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
@@ -78,9 +82,17 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() {
assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2);
let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap();
- // Server: Receive and generate ack again, but drop first packet
+ // Server: Receive and generate ack again, but this time add congestion
+ // signal first.
now += DEFAULT_RTT / 2;
- c_tx_dgrams.remove(0);
+ match congestion_signal {
+ CongestionSignal::PacketLoss => {
+ c_tx_dgrams.remove(0);
+ }
+ CongestionSignal::EcnCe => {
+ c_tx_dgrams.last_mut().unwrap().set_tos(IpTosEcn::Ce.into());
+ }
+ }
let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
assert_eq!(
server.stats().frame_tx.largest_acknowledged,
@@ -98,6 +110,18 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() {
}
#[test]
+/// Verify that CC moves to cong avoidance when a packet is marked lost.
+fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_packet_loss() {
+ cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::PacketLoss);
+}
+
+/// Verify that CC moves to cong avoidance when ACK is marked with ECN CE.
+#[test]
+fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_ecn_ce() {
+ cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::EcnCe);
+}
+
+#[test]
/// Verify that CC stays in recovery period when packet sent before start of
/// recovery period is acked.
fn cc_cong_avoidance_recovery_period_unchanged() {
diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs
index 5351dd0d5c..7c620de17e 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/close.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs
@@ -14,13 +14,13 @@ use super::{
};
use crate::{
tparams::{self, TransportParameter},
- AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE,
+ AppError, CloseReason, Error, ERROR_APPLICATION_CLOSE,
};
fn assert_draining(c: &Connection, expected: &Error) {
assert!(c.state().closed());
if let State::Draining {
- error: ConnectionError::Transport(error),
+ error: CloseReason::Transport(error),
..
} = c.state()
{
@@ -40,7 +40,14 @@ fn connection_close() {
client.close(now, 42, "");
+ let stats_before = client.stats().frame_tx;
let out = client.process(None, now);
+ let stats_after = client.stats().frame_tx;
+ assert_eq!(
+ stats_after.connection_close,
+ stats_before.connection_close + 1
+ );
+ assert_eq!(stats_after.ack, stats_before.ack + 1);
server.process_input(&out.dgram().unwrap(), now);
assert_draining(&server, &Error::PeerApplicationError(42));
@@ -57,7 +64,14 @@ fn connection_close_with_long_reason_string() {
let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap();
client.close(now, 42, long_reason);
+ let stats_before = client.stats().frame_tx;
let out = client.process(None, now);
+ let stats_after = client.stats().frame_tx;
+ assert_eq!(
+ stats_after.connection_close,
+ stats_before.connection_close + 1
+ );
+ assert_eq!(stats_after.ack, stats_before.ack + 1);
server.process_input(&out.dgram().unwrap(), now);
assert_draining(&server, &Error::PeerApplicationError(42));
@@ -100,7 +114,7 @@ fn bad_tls_version() {
let dgram = server.process(dgram.as_ref(), now()).dgram();
assert_eq!(
*server.state(),
- State::Closed(ConnectionError::Transport(Error::ProtocolViolation))
+ State::Closed(CloseReason::Transport(Error::ProtocolViolation))
);
assert!(dgram.is_some());
client.process_input(&dgram.unwrap(), now());
@@ -154,7 +168,6 @@ fn closing_and_draining() {
assert!(client_close.is_some());
let client_close_timer = client.process(None, now()).callback();
assert_ne!(client_close_timer, Duration::from_secs(0));
-
// The client will spit out the same packet in response to anything it receives.
let p3 = send_something(&mut server, now());
let client_close2 = client.process(Some(&p3), now()).dgram();
@@ -168,7 +181,7 @@ fn closing_and_draining() {
assert_eq!(end, Output::None);
assert_eq!(
*client.state(),
- State::Closed(ConnectionError::Application(APP_ERROR))
+ State::Closed(CloseReason::Application(APP_ERROR))
);
// When the server receives the close, it too should generate CONNECTION_CLOSE.
@@ -186,7 +199,7 @@ fn closing_and_draining() {
assert_eq!(end, Output::None);
assert_eq!(
*server.state(),
- State::Closed(ConnectionError::Transport(Error::PeerApplicationError(
+ State::Closed(CloseReason::Transport(Error::PeerApplicationError(
APP_ERROR
)))
);
diff --git a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs
index ade8c753be..f1b64b3c8f 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs
@@ -19,7 +19,7 @@ use crate::{
packet::PacketBuilder,
quic_datagrams::MAX_QUIC_DATAGRAM,
send_stream::{RetransmissionPriority, TransmissionPriority},
- Connection, ConnectionError, ConnectionParameters, Error, StreamType,
+ CloseReason, Connection, ConnectionParameters, Error, StreamType,
};
const DATAGRAM_LEN_MTU: u64 = 1310;
@@ -362,10 +362,7 @@ fn dgram_no_allowed() {
client.process_input(&out, now());
- assert_error(
- &client,
- &ConnectionError::Transport(Error::ProtocolViolation),
- );
+ assert_error(&client, &CloseReason::Transport(Error::ProtocolViolation));
}
#[test]
@@ -383,10 +380,7 @@ fn dgram_too_big() {
client.process_input(&out, now());
- assert_error(
- &client,
- &ConnectionError::Transport(Error::ProtocolViolation),
- );
+ assert_error(&client, &CloseReason::Transport(Error::ProtocolViolation));
}
#[test]
@@ -587,7 +581,7 @@ fn datagram_fill() {
// Work out how much space we have for a datagram.
let space = {
- let p = client.paths.primary();
+ let p = client.paths.primary().unwrap();
let path = p.borrow();
// Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number,
// 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD.
diff --git a/third_party/rust/neqo-transport/src/connection/tests/ecn.rs b/third_party/rust/neqo-transport/src/connection/tests/ecn.rs
new file mode 100644
index 0000000000..87957297e5
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/ecn.rs
@@ -0,0 +1,392 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::time::Duration;
+
+use neqo_common::{Datagram, IpTos, IpTosEcn};
+use test_fixture::{
+ assertions::{assert_v4_path, assert_v6_path},
+ fixture_init, now, DEFAULT_ADDR_V4,
+};
+
+use super::send_something_with_modifier;
+use crate::{
+ connection::tests::{
+ connect_force_idle, connect_force_idle_with_modifier, default_client, default_server,
+ migration::get_cid, new_client, new_server, send_something,
+ },
+ ecn::ECN_TEST_COUNT,
+ ConnectionId, ConnectionParameters, StreamType,
+};
+
+fn assert_ecn_enabled(tos: IpTos) {
+ assert!(tos.is_ecn_marked());
+}
+
+fn assert_ecn_disabled(tos: IpTos) {
+ assert!(!tos.is_ecn_marked());
+}
+
+fn set_tos(mut d: Datagram, ecn: IpTosEcn) -> Datagram {
+ d.set_tos(ecn.into());
+ d
+}
+
+fn noop() -> fn(Datagram) -> Option<Datagram> {
+ Some
+}
+
+fn bleach() -> fn(Datagram) -> Option<Datagram> {
+ |d| Some(set_tos(d, IpTosEcn::NotEct))
+}
+
+fn remark() -> fn(Datagram) -> Option<Datagram> {
+ |d| {
+ if d.tos().is_ecn_marked() {
+ Some(set_tos(d, IpTosEcn::Ect1))
+ } else {
+ Some(d)
+ }
+ }
+}
+
+fn ce() -> fn(Datagram) -> Option<Datagram> {
+ |d| {
+ if d.tos().is_ecn_marked() {
+ Some(set_tos(d, IpTosEcn::Ce))
+ } else {
+ Some(d)
+ }
+ }
+}
+
+fn drop() -> fn(Datagram) -> Option<Datagram> {
+ |_| None
+}
+
+#[test]
+fn disables_on_loss() {
+ let now = now();
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // Right after the handshake, the ECN validation should still be in progress.
+ let client_pkt = send_something(&mut client, now);
+ assert_ecn_enabled(client_pkt.tos());
+
+ for _ in 0..ECN_TEST_COUNT {
+ send_something(&mut client, now);
+ }
+
+ // ECN should now be disabled.
+ let client_pkt = send_something(&mut client, now);
+ assert_ecn_disabled(client_pkt.tos());
+}
+
+/// This function performs a handshake over a path that modifies packets via `orig_path_modifier`.
+/// It then sends `burst` packets on that path, and then migrates to a new path that
+/// modifies packets via `new_path_modifier`. It sends `burst` packets on the new path.
+/// The function returns the TOS value of the last packet sent on the old path and the TOS value
+/// of the last packet sent on the new path to allow for verification of correct behavior.
+pub fn migration_with_modifiers(
+ orig_path_modifier: fn(Datagram) -> Option<Datagram>,
+ new_path_modifier: fn(Datagram) -> Option<Datagram>,
+ burst: usize,
+) -> (IpTos, IpTos, bool) {
+ fixture_init();
+ let mut client = new_client(ConnectionParameters::default().max_streams(StreamType::UniDi, 64));
+ let mut server = new_server(ConnectionParameters::default().max_streams(StreamType::UniDi, 64));
+
+ connect_force_idle_with_modifier(&mut client, &mut server, orig_path_modifier);
+ let mut now = now();
+
+ // Right after the handshake, the ECN validation should still be in progress.
+ let client_pkt = send_something(&mut client, now);
+ assert_ecn_enabled(client_pkt.tos());
+ server.process_input(&orig_path_modifier(client_pkt).unwrap(), now);
+
+ // Send some data on the current path.
+ for _ in 0..burst {
+ let client_pkt = send_something_with_modifier(&mut client, now, orig_path_modifier);
+ server.process_input(&client_pkt, now);
+ }
+
+ if let Some(ack) = server.process_output(now).dgram() {
+ client.process_input(&ack, now);
+ }
+
+ let client_pkt = send_something(&mut client, now);
+ let tos_before_migration = client_pkt.tos();
+ server.process_input(&orig_path_modifier(client_pkt).unwrap(), now);
+
+ client
+ .migrate(Some(DEFAULT_ADDR_V4), Some(DEFAULT_ADDR_V4), false, now)
+ .unwrap();
+
+ let mut migrated = false;
+ let probe = new_path_modifier(client.process_output(now).dgram().unwrap());
+ if let Some(probe) = probe {
+ assert_v4_path(&probe, true); // Contains PATH_CHALLENGE.
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ let probe_cid = ConnectionId::from(get_cid(&probe));
+
+ let resp = new_path_modifier(server.process(Some(&probe), now).dgram().unwrap()).unwrap();
+ assert_v4_path(&resp, true);
+ assert_eq!(server.stats().frame_tx.path_response, 1);
+ assert_eq!(server.stats().frame_tx.path_challenge, 1);
+
+ // Data continues to be exchanged on the old path.
+ let client_data = send_something_with_modifier(&mut client, now, orig_path_modifier);
+ assert_ne!(get_cid(&client_data), probe_cid);
+ assert_v6_path(&client_data, false);
+ server.process_input(&client_data, now);
+ let server_data = send_something_with_modifier(&mut server, now, orig_path_modifier);
+ assert_v6_path(&server_data, false);
+ client.process_input(&server_data, now);
+
+ // Once the client receives the probe response, it migrates to the new path.
+ client.process_input(&resp, now);
+ assert_eq!(client.stats().frame_rx.path_challenge, 1);
+ migrated = true;
+
+ let migrate_client = send_something_with_modifier(&mut client, now, new_path_modifier);
+ assert_v4_path(&migrate_client, true); // Responds to server probe.
+
+ // The server now sees the migration and will switch over.
+ // However, it will probe the old path again, even though it has just
+ // received a response to its last probe, because it needs to verify
+ // that the migration is genuine.
+ server.process_input(&migrate_client, now);
+ }
+
+ let stream_before = server.stats().frame_tx.stream;
+ let probe_old_server = send_something_with_modifier(&mut server, now, orig_path_modifier);
+ // This is just the double-check probe; no STREAM frames.
+ assert_v6_path(&probe_old_server, migrated);
+ assert_eq!(
+ server.stats().frame_tx.path_challenge,
+ if migrated { 2 } else { 0 }
+ );
+ assert_eq!(
+ server.stats().frame_tx.stream,
+ if migrated { stream_before } else { 1 }
+ );
+
+ if migrated {
+ // The server then sends data on the new path.
+ let migrate_server =
+ new_path_modifier(server.process_output(now).dgram().unwrap()).unwrap();
+ assert_v4_path(&migrate_server, false);
+ assert_eq!(server.stats().frame_tx.path_challenge, 2);
+ assert_eq!(server.stats().frame_tx.stream, stream_before + 1);
+
+ // The client receives these checks and responds to the probe, but uses the new path.
+ client.process_input(&migrate_server, now);
+ client.process_input(&probe_old_server, now);
+ let old_probe_resp = send_something_with_modifier(&mut client, now, new_path_modifier);
+ assert_v6_path(&old_probe_resp, true);
+ let client_confirmation = client.process_output(now).dgram().unwrap();
+ assert_v4_path(&client_confirmation, false);
+
+ // The server has now sent 2 packets, so it is blocked on the pacer. Wait.
+ let server_pacing = server.process_output(now).callback();
+ assert_ne!(server_pacing, Duration::new(0, 0));
+ // ... then confirm that the server sends on the new path still.
+ let server_confirmation =
+ send_something_with_modifier(&mut server, now + server_pacing, new_path_modifier);
+ assert_v4_path(&server_confirmation, false);
+ client.process_input(&server_confirmation, now);
+
+ // Send some data on the new path.
+ for _ in 0..burst {
+ now += client.process_output(now).callback();
+ let client_pkt = send_something_with_modifier(&mut client, now, new_path_modifier);
+ server.process_input(&client_pkt, now);
+ }
+
+ if let Some(ack) = server.process_output(now).dgram() {
+ client.process_input(&ack, now);
+ }
+ }
+
+ now += client.process_output(now).callback();
+ let mut client_pkt = send_something(&mut client, now);
+ while !migrated && client_pkt.source() == DEFAULT_ADDR_V4 {
+ client_pkt = send_something(&mut client, now);
+ }
+ let tos_after_migration = client_pkt.tos();
+ (tos_before_migration, tos_after_migration, migrated)
+}
+
+#[test]
+fn ecn_migration_zero_burst_all_cases() {
+ for orig_path_mod in &[noop(), bleach(), remark(), ce()] {
+ for new_path_mod in &[noop(), bleach(), remark(), ce(), drop()] {
+ let (before, after, migrated) =
+ migration_with_modifiers(*orig_path_mod, *new_path_mod, 0);
+ // Too few packets sent before and after migration to conclude ECN validation.
+ assert_ecn_enabled(before);
+ assert_ecn_enabled(after);
+ // Migration succeeds except if the new path drops ECN.
+ assert!(*new_path_mod == drop() || migrated);
+ }
+ }
+}
+
+#[test]
+fn ecn_migration_noop_bleach_data() {
+ let (before, after, migrated) = migration_with_modifiers(noop(), bleach(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_noop_remark_data() {
+ let (before, after, migrated) = migration_with_modifiers(noop(), remark(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to remarking.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_noop_ce_data() {
+ let (before, after, migrated) = migration_with_modifiers(noop(), ce(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration.
+ assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_noop_drop_data() {
+ let (before, after, migrated) = migration_with_modifiers(noop(), drop(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration.
+ assert_ecn_enabled(after); // Migration failed, ECN on original path is still validated.
+ assert!(!migrated);
+}
+
+#[test]
+fn ecn_migration_bleach_noop_data() {
+ let (before, after, migrated) = migration_with_modifiers(bleach(), noop(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching.
+ assert_ecn_enabled(after); // ECN validation concludes after migration.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_bleach_bleach_data() {
+ let (before, after, migrated) = migration_with_modifiers(bleach(), bleach(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_bleach_remark_data() {
+ let (before, after, migrated) = migration_with_modifiers(bleach(), remark(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to remarking.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_bleach_ce_data() {
+ let (before, after, migrated) = migration_with_modifiers(bleach(), ce(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching.
+ assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_bleach_drop_data() {
+ let (before, after, migrated) = migration_with_modifiers(bleach(), drop(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching.
+ // Migration failed, ECN on original path is still disabled.
+ assert_ecn_disabled(after);
+ assert!(!migrated);
+}
+
+#[test]
+fn ecn_migration_remark_noop_data() {
+ let (before, after, migrated) = migration_with_modifiers(remark(), noop(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to remarking.
+ assert_ecn_enabled(after); // ECN validation succeeds after migration.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_remark_bleach_data() {
+ let (before, after, migrated) = migration_with_modifiers(remark(), bleach(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to remarking.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_remark_remark_data() {
+ let (before, after, migrated) = migration_with_modifiers(remark(), remark(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to remarking.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to remarking.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_remark_ce_data() {
+ let (before, after, migrated) = migration_with_modifiers(remark(), ce(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to remarking.
+ assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_remark_drop_data() {
+ let (before, after, migrated) = migration_with_modifiers(remark(), drop(), ECN_TEST_COUNT);
+ assert_ecn_disabled(before); // ECN validation fails before migration due to remarking.
+ assert_ecn_disabled(after); // Migration failed, ECN on original path is still disabled.
+ assert!(!migrated);
+}
+
+#[test]
+fn ecn_migration_ce_noop_data() {
+ let (before, after, migrated) = migration_with_modifiers(ce(), noop(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks.
+ assert_ecn_enabled(after); // ECN validation concludes after migration.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_ce_bleach_data() {
+ let (before, after, migrated) = migration_with_modifiers(ce(), bleach(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_ce_remark_data() {
+ let (before, after, migrated) = migration_with_modifiers(ce(), remark(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks.
+ assert_ecn_disabled(after); // ECN validation fails after migration due to remarking.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_ce_ce_data() {
+ let (before, after, migrated) = migration_with_modifiers(ce(), ce(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks.
+ assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks.
+ assert!(migrated);
+}
+
+#[test]
+fn ecn_migration_ce_drop_data() {
+ let (before, after, migrated) = migration_with_modifiers(ce(), drop(), ECN_TEST_COUNT);
+ assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks.
+ // Migration failed, ECN on original path is still enabled.
+ assert_ecn_enabled(after);
+ assert!(!migrated);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
index f2103523ec..c908340616 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
@@ -35,7 +35,7 @@ use crate::{
server::ValidateAddress,
tparams::{TransportParameter, MIN_ACK_DELAY},
tracking::DEFAULT_ACK_DELAY,
- ConnectionError, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version,
+ CloseReason, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version,
};
const ECH_CONFIG_ID: u8 = 7;
@@ -111,8 +111,8 @@ fn handshake_failed_authentication() {
qdebug!("---- server: Alert(certificate_revoked)");
let out = server.process(out.as_dgram_ref(), now());
assert!(out.as_dgram_ref().is_some());
- assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44)));
- assert_error(&server, &ConnectionError::Transport(Error::PeerError(300)));
+ assert_error(&client, &CloseReason::Transport(Error::CryptoAlert(44)));
+ assert_error(&server, &CloseReason::Transport(Error::PeerError(300)));
}
#[test]
@@ -133,11 +133,8 @@ fn no_alpn() {
handshake(&mut client, &mut server, now(), Duration::new(0, 0));
// TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE
// and the client never sees the server's rejection of its handshake.
- // assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120)));
- assert_error(
- &server,
- &ConnectionError::Transport(Error::CryptoAlert(120)),
- );
+ // assert_error(&client, CloseReason::Transport(Error::CryptoAlert(120)));
+ assert_error(&server, &CloseReason::Transport(Error::CryptoAlert(120)));
}
#[test]
@@ -934,10 +931,10 @@ fn ech_retry() {
server.process_input(&dgram.unwrap(), now());
assert_eq!(
server.state().error(),
- Some(&ConnectionError::Transport(Error::PeerError(0x100 + 121)))
+ Some(&CloseReason::Transport(Error::PeerError(0x100 + 121)))
);
- let Some(ConnectionError::Transport(Error::EchRetry(updated_config))) = client.state().error()
+ let Some(CloseReason::Transport(Error::EchRetry(updated_config))) = client.state().error()
else {
panic!(
"Client state should be failed with EchRetry, is {:?}",
@@ -984,7 +981,7 @@ fn ech_retry_fallback_rejected() {
client.authenticated(AuthenticationStatus::PolicyRejection, now());
assert!(client.state().error().is_some());
- if let Some(ConnectionError::Transport(Error::EchRetry(_))) = client.state().error() {
+ if let Some(CloseReason::Transport(Error::EchRetry(_))) = client.state().error() {
panic!("Client should not get EchRetry error");
}
@@ -993,14 +990,13 @@ fn ech_retry_fallback_rejected() {
server.process_input(&dgram.unwrap(), now());
assert_eq!(
server.state().error(),
- Some(&ConnectionError::Transport(Error::PeerError(298)))
+ Some(&CloseReason::Transport(Error::PeerError(298)))
); // A bad_certificate alert.
}
#[test]
fn bad_min_ack_delay() {
- const EXPECTED_ERROR: ConnectionError =
- ConnectionError::Transport(Error::TransportParameterError);
+ const EXPECTED_ERROR: CloseReason = CloseReason::Transport(Error::TransportParameterError);
let mut server = default_server();
let max_ad = u64::try_from(DEFAULT_ACK_DELAY.as_micros()).unwrap();
server
@@ -1018,7 +1014,7 @@ fn bad_min_ack_delay() {
server.process_input(&dgram.unwrap(), now());
assert_eq!(
server.state().error(),
- Some(&ConnectionError::Transport(Error::PeerError(
+ Some(&CloseReason::Transport(Error::PeerError(
Error::TransportParameterError.code()
)))
);
diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs
index 847b253284..c2ae9529bf 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/keys.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs
@@ -11,7 +11,7 @@ use test_fixture::now;
use super::{
super::{
- super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED},
+ super::{CloseReason, ERROR_AEAD_LIMIT_REACHED},
Connection, ConnectionParameters, Error, Output, State, StreamType,
},
connect, connect_force_idle, default_client, default_server, maybe_authenticate,
@@ -269,7 +269,7 @@ fn exhaust_write_keys() {
assert!(dgram.is_none());
assert!(matches!(
client.state(),
- State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ State::Closed(CloseReason::Transport(Error::KeysExhausted))
));
}
@@ -285,14 +285,14 @@ fn exhaust_read_keys() {
let dgram = server.process(Some(&dgram), now()).dgram();
assert!(matches!(
server.state(),
- State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ State::Closed(CloseReason::Transport(Error::KeysExhausted))
));
client.process_input(&dgram.unwrap(), now());
assert!(matches!(
client.state(),
State::Draining {
- error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)),
+ error: CloseReason::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)),
..
}
));
@@ -341,6 +341,6 @@ fn automatic_update_write_keys_blocked() {
assert!(dgram.is_none());
assert!(matches!(
client.state(),
- State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ State::Closed(CloseReason::Transport(Error::KeysExhausted))
));
}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/migration.rs b/third_party/rust/neqo-transport/src/connection/tests/migration.rs
index 405ae161a4..779cc78c53 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/migration.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/migration.rs
@@ -30,7 +30,7 @@ use crate::{
packet::PacketBuilder,
path::{PATH_MTU_V4, PATH_MTU_V6},
tparams::{self, PreferredAddress, TransportParameter},
- ConnectionError, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef,
+ CloseReason, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef,
ConnectionParameters, EmptyConnectionIdGenerator, Error,
};
@@ -357,13 +357,13 @@ fn migrate_same_fail() {
assert!(matches!(res, Output::None));
assert!(matches!(
client.state(),
- State::Closed(ConnectionError::Transport(Error::NoAvailablePath))
+ State::Closed(CloseReason::Transport(Error::NoAvailablePath))
));
}
/// This gets the connection ID from a datagram using the default
/// connection ID generator/decoder.
-fn get_cid(d: &Datagram) -> ConnectionIdRef {
+pub fn get_cid(d: &Datagram) -> ConnectionIdRef {
let gen = CountingConnectionIdGenerator::default();
assert_eq!(d[0] & 0x80, 0); // Only support short packets for now.
gen.decode_cid(&mut Decoder::from(&d[1..])).unwrap()
@@ -894,7 +894,7 @@ fn retire_prior_to_migration_failure() {
assert!(matches!(
client.state(),
State::Closing {
- error: ConnectionError::Transport(Error::InvalidMigration),
+ error: CloseReason::Transport(Error::InvalidMigration),
..
}
));
diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs
index c8c87a0df0..65283b8eb8 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/mod.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs
@@ -17,7 +17,7 @@ use neqo_common::{event::Provider, qdebug, qtrace, Datagram, Decoder, Role};
use neqo_crypto::{random, AllowZeroRtt, AuthenticationStatus, ResumptionToken};
use test_fixture::{fixture_init, new_neqo_qlog, now, DEFAULT_ADDR};
-use super::{Connection, ConnectionError, ConnectionId, Output, State};
+use super::{CloseReason, Connection, ConnectionId, Output, State};
use crate::{
addr_valid::{AddressValidation, ValidateAddress},
cc::{CWND_INITIAL_PKTS, CWND_MIN},
@@ -37,6 +37,7 @@ mod ackrate;
mod cc;
mod close;
mod datagram;
+mod ecn;
mod handshake;
mod idle;
mod keys;
@@ -170,17 +171,13 @@ impl crate::connection::test_internal::FrameWriter for PingWriter {
}
}
-trait DatagramModifier: FnMut(Datagram) -> Option<Datagram> {}
-
-impl<T> DatagramModifier for T where T: FnMut(Datagram) -> Option<Datagram> {}
-
/// Drive the handshake between the client and server.
fn handshake_with_modifier(
client: &mut Connection,
server: &mut Connection,
now: Instant,
rtt: Duration,
- mut modifier: impl DatagramModifier,
+ modifier: fn(Datagram) -> Option<Datagram>,
) -> Instant {
let mut a = client;
let mut b = server;
@@ -248,8 +245,8 @@ fn connect_fail(
server_error: Error,
) {
handshake(client, server, now(), Duration::new(0, 0));
- assert_error(client, &ConnectionError::Transport(client_error));
- assert_error(server, &ConnectionError::Transport(server_error));
+ assert_error(client, &CloseReason::Transport(client_error));
+ assert_error(server, &CloseReason::Transport(server_error));
}
fn connect_with_rtt_and_modifier(
@@ -257,7 +254,7 @@ fn connect_with_rtt_and_modifier(
server: &mut Connection,
now: Instant,
rtt: Duration,
- modifier: impl DatagramModifier,
+ modifier: fn(Datagram) -> Option<Datagram>,
) -> Instant {
fn check_rtt(stats: &Stats, rtt: Duration) {
assert_eq!(stats.rtt, rtt);
@@ -287,7 +284,7 @@ fn connect(client: &mut Connection, server: &mut Connection) {
connect_with_rtt(client, server, now(), Duration::new(0, 0));
}
-fn assert_error(c: &Connection, expected: &ConnectionError) {
+fn assert_error(c: &Connection, expected: &CloseReason) {
match c.state() {
State::Closing { error, .. } | State::Draining { error, .. } | State::Closed(error) => {
assert_eq!(*error, *expected, "{c} error mismatch");
@@ -333,7 +330,7 @@ fn connect_rtt_idle_with_modifier(
client: &mut Connection,
server: &mut Connection,
rtt: Duration,
- modifier: impl DatagramModifier,
+ modifier: fn(Datagram) -> Option<Datagram>,
) -> Instant {
let now = connect_with_rtt_and_modifier(client, server, now(), rtt, modifier);
assert_idle(client, server, rtt, now);
@@ -351,7 +348,7 @@ fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Durat
fn connect_force_idle_with_modifier(
client: &mut Connection,
server: &mut Connection,
- modifier: impl DatagramModifier,
+ modifier: fn(Datagram) -> Option<Datagram>,
) {
connect_rtt_idle_with_modifier(client, server, Duration::new(0, 0), modifier);
}
@@ -380,7 +377,7 @@ fn fill_stream(c: &mut Connection, stream: StreamId) {
fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec<Datagram>, Instant) {
// Train wreck function to get the remaining congestion window on the primary path.
fn cwnd(c: &Connection) -> usize {
- c.paths.primary().borrow().sender().cwnd_avail()
+ c.paths.primary().unwrap().borrow().sender().cwnd_avail()
}
qtrace!("fill_cwnd starting cwnd: {}", cwnd(c));
@@ -478,10 +475,10 @@ where
// Get the current congestion window for the connection.
fn cwnd(c: &Connection) -> usize {
- c.paths.primary().borrow().sender().cwnd()
+ c.paths.primary().unwrap().borrow().sender().cwnd()
}
fn cwnd_avail(c: &Connection) -> usize {
- c.paths.primary().borrow().sender().cwnd_avail()
+ c.paths.primary().unwrap().borrow().sender().cwnd_avail()
}
fn induce_persistent_congestion(
@@ -576,7 +573,7 @@ fn send_something_paced_with_modifier(
sender: &mut Connection,
mut now: Instant,
allow_pacing: bool,
- mut modifier: impl DatagramModifier,
+ modifier: fn(Datagram) -> Option<Datagram>,
) -> (Datagram, Instant) {
let stream_id = sender.stream_create(StreamType::UniDi).unwrap();
assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok());
@@ -608,7 +605,7 @@ fn send_something_paced(
fn send_something_with_modifier(
sender: &mut Connection,
now: Instant,
- modifier: impl DatagramModifier,
+ modifier: fn(Datagram) -> Option<Datagram>,
) -> Datagram {
send_something_paced_with_modifier(sender, now, false, modifier).0
}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs
index 66d3bf32f3..f7472d917f 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/stream.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs
@@ -19,9 +19,9 @@ use crate::{
send_stream::{OrderGroup, SendStreamState, SEND_BUFFER_SIZE},
streams::{SendOrder, StreamOrder},
tparams::{self, TransportParameter},
+ CloseReason,
// tracking::DEFAULT_ACK_PACKET_TOLERANCE,
Connection,
- ConnectionError,
ConnectionParameters,
Error,
StreamId,
@@ -494,12 +494,9 @@ fn exceed_max_data() {
assert_error(
&client,
- &ConnectionError::Transport(Error::PeerError(Error::FlowControlError.code())),
- );
- assert_error(
- &server,
- &ConnectionError::Transport(Error::FlowControlError),
+ &CloseReason::Transport(Error::PeerError(Error::FlowControlError.code())),
);
+ assert_error(&server, &CloseReason::Transport(Error::FlowControlError));
}
#[test]
diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs
index 93872a94f4..815868d78d 100644
--- a/third_party/rust/neqo-transport/src/connection/tests/vn.rs
+++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs
@@ -10,7 +10,7 @@ use neqo_common::{event::Provider, Decoder, Encoder};
use test_fixture::{assertions, datagram, now};
use super::{
- super::{ConnectionError, ConnectionEvent, Output, State, ZeroRttState},
+ super::{CloseReason, ConnectionEvent, Output, State, ZeroRttState},
connect, connect_fail, default_client, default_server, exchange_ticket, new_client, new_server,
send_something,
};
@@ -124,7 +124,7 @@ fn version_negotiation_only_reserved() {
assert_eq!(client.process(Some(&dgram), now()), Output::None);
match client.state() {
State::Closed(err) => {
- assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation));
+ assert_eq!(*err, CloseReason::Transport(Error::VersionNegotiation));
}
_ => panic!("Invalid client state"),
}
@@ -183,7 +183,7 @@ fn version_negotiation_not_supported() {
assert_eq!(client.process(Some(&dgram), now()), Output::None);
match client.state() {
State::Closed(err) => {
- assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation));
+ assert_eq!(*err, CloseReason::Transport(Error::VersionNegotiation));
}
_ => panic!("Invalid client state"),
}
@@ -338,7 +338,7 @@ fn invalid_server_version() {
// The server effectively hasn't reacted here.
match server.state() {
State::Closed(err) => {
- assert_eq!(*err, ConnectionError::Transport(Error::CryptoAlert(47)));
+ assert_eq!(*err, CloseReason::Transport(Error::CryptoAlert(47)));
}
_ => panic!("invalid server state"),
}
diff --git a/third_party/rust/neqo-transport/src/ecn.rs b/third_party/rust/neqo-transport/src/ecn.rs
new file mode 100644
index 0000000000..20eb4da003
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/ecn.rs
@@ -0,0 +1,225 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::ops::{AddAssign, Deref, DerefMut, Sub};
+
+use enum_map::EnumMap;
+use neqo_common::{qdebug, qinfo, qwarn, IpTosEcn};
+
+use crate::{packet::PacketNumber, tracking::SentPacket};
+
+/// The number of packets to use for testing a path for ECN capability.
+pub const ECN_TEST_COUNT: usize = 10;
+
+/// The state information related to testing a path for ECN capability.
+/// See RFC9000, Appendix A.4.
+#[derive(Debug, PartialEq, Clone)]
+enum EcnValidationState {
+ /// The path is currently being tested for ECN capability, with the number of probes sent so
+ /// far on the path during the ECN validation.
+ Testing(usize),
+ /// The validation test has concluded but the path's ECN capability is not yet known.
+ Unknown,
+ /// The path is known to **not** be ECN capable.
+ Failed,
+ /// The path is known to be ECN capable.
+ Capable,
+}
+
+impl Default for EcnValidationState {
+ fn default() -> Self {
+ EcnValidationState::Testing(0)
+ }
+}
+
+/// The counts for different ECN marks.
+#[derive(PartialEq, Eq, Debug, Clone, Copy, Default)]
+pub struct EcnCount(EnumMap<IpTosEcn, u64>);
+
+impl Deref for EcnCount {
+ type Target = EnumMap<IpTosEcn, u64>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl DerefMut for EcnCount {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+impl EcnCount {
+ pub fn new(not_ect: u64, ect0: u64, ect1: u64, ce: u64) -> Self {
+ // Yes, the enum array order is different from the argument order.
+ Self(EnumMap::from_array([not_ect, ect1, ect0, ce]))
+ }
+
+ /// Whether any of the ECN counts are non-zero.
+ pub fn is_some(&self) -> bool {
+ self[IpTosEcn::Ect0] > 0 || self[IpTosEcn::Ect1] > 0 || self[IpTosEcn::Ce] > 0
+ }
+}
+
+impl Sub<EcnCount> for EcnCount {
+ type Output = EcnCount;
+
+ /// Subtract the ECN counts in `other` from `self`.
+ fn sub(self, other: EcnCount) -> EcnCount {
+ let mut diff = EcnCount::default();
+ for (ecn, count) in &mut *diff {
+ *count = self[ecn].saturating_sub(other[ecn]);
+ }
+ diff
+ }
+}
+
+impl AddAssign<IpTosEcn> for EcnCount {
+ fn add_assign(&mut self, ecn: IpTosEcn) {
+ self[ecn] += 1;
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct EcnInfo {
+ /// The current state of ECN validation on this path.
+ state: EcnValidationState,
+
+ /// The largest ACK seen so far.
+ largest_acked: PacketNumber,
+
+ /// The ECN counts from the last ACK frame that increased `largest_acked`.
+ baseline: EcnCount,
+}
+
+impl EcnInfo {
+ /// Set the baseline (= the ECN counts from the last ACK Frame).
+ pub fn set_baseline(&mut self, baseline: EcnCount) {
+ self.baseline = baseline;
+ }
+
+ /// Expose the current baseline.
+ pub fn baseline(&self) -> EcnCount {
+ self.baseline
+ }
+
+ /// Count the number of packets sent out on this path during ECN validation.
+ /// Exit ECN validation if the number of packets sent exceeds `ECN_TEST_COUNT`.
+ /// We do not implement the part of the RFC that says to exit ECN validation if the time since
+ /// the start of ECN validation exceeds 3 * PTO, since this seems to happen much too quickly.
+ pub fn on_packet_sent(&mut self) {
+ if let EcnValidationState::Testing(ref mut probes_sent) = &mut self.state {
+ *probes_sent += 1;
+ qdebug!("ECN probing: sent {} probes", probes_sent);
+ if *probes_sent == ECN_TEST_COUNT {
+ qdebug!("ECN probing concluded with {} probes sent", probes_sent);
+ self.state = EcnValidationState::Unknown;
+ }
+ }
+ }
+
+ /// Process ECN counts from an ACK frame.
+ ///
+ /// Returns whether ECN counts contain new valid ECN CE marks.
+ pub fn on_packets_acked(
+ &mut self,
+ acked_packets: &[SentPacket],
+ ack_ecn: Option<EcnCount>,
+ ) -> bool {
+ let prev_baseline = self.baseline;
+
+ self.validate_ack_ecn_and_update(acked_packets, ack_ecn);
+
+ matches!(self.state, EcnValidationState::Capable)
+ && (self.baseline - prev_baseline)[IpTosEcn::Ce] > 0
+ }
+
+ /// After the ECN validation test has ended, check if the path is ECN capable.
+ pub fn validate_ack_ecn_and_update(
+ &mut self,
+ acked_packets: &[SentPacket],
+ ack_ecn: Option<EcnCount>,
+ ) {
+ // RFC 9000, Appendix A.4:
+ //
+ // > From the "unknown" state, successful validation of the ECN counts in an ACK frame
+ // > (see Section 13.4.2.1) causes the ECN state for the path to become "capable", unless
+ // > no marked packet has been acknowledged.
+ match self.state {
+ EcnValidationState::Testing { .. } | EcnValidationState::Failed => return,
+ EcnValidationState::Unknown | EcnValidationState::Capable => {}
+ }
+
+ // RFC 9000, Section 13.4.2.1:
+ //
+ // > Validating ECN counts from reordered ACK frames can result in failure. An endpoint MUST
+ // > NOT fail ECN validation as a result of processing an ACK frame that does not increase
+ // > the largest acknowledged packet number.
+ let largest_acked = acked_packets.first().expect("must be there").pn;
+ if largest_acked <= self.largest_acked {
+ return;
+ }
+
+ // RFC 9000, Section 13.4.2.1:
+ //
+ // > An endpoint that receives an ACK frame with ECN counts therefore validates
+ // > the counts before using them. It performs this validation by comparing newly
+ // > received counts against those from the last successfully processed ACK frame.
+ //
+ // > If an ACK frame newly acknowledges a packet that the endpoint sent with
+ // > either the ECT(0) or ECT(1) codepoint set, ECN validation fails if the
+ // > corresponding ECN counts are not present in the ACK frame.
+ let Some(ack_ecn) = ack_ecn else {
+ qwarn!("ECN validation failed, no ECN counts in ACK frame");
+ self.state = EcnValidationState::Failed;
+ return;
+ };
+
+ // We always mark with ECT(0) - if at all - so we only need to check for that.
+ //
+ // > ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is
+ // > less than the number of newly acknowledged packets that were originally sent with an
+ // > ECT(0) marking.
+ let newly_acked_sent_with_ect0: u64 = acked_packets
+ .iter()
+ .filter(|p| p.ecn_mark == IpTosEcn::Ect0)
+ .count()
+ .try_into()
+ .unwrap();
+ if newly_acked_sent_with_ect0 == 0 {
+ qwarn!("ECN validation failed, no ECT(0) packets were newly acked");
+ self.state = EcnValidationState::Failed;
+ return;
+ }
+ let ecn_diff = ack_ecn - self.baseline;
+ let sum_inc = ecn_diff[IpTosEcn::Ect0] + ecn_diff[IpTosEcn::Ce];
+ if sum_inc < newly_acked_sent_with_ect0 {
+ qwarn!(
+ "ECN validation failed, ACK counted {} new marks, but {} of newly acked packets were sent with ECT(0)",
+ sum_inc,
+ newly_acked_sent_with_ect0
+ );
+ self.state = EcnValidationState::Failed;
+ } else if ecn_diff[IpTosEcn::Ect1] > 0 {
+ qwarn!("ECN validation failed, ACK counted ECT(1) marks that were never sent");
+ self.state = EcnValidationState::Failed;
+ } else {
+ qinfo!("ECN validation succeeded, path is capable",);
+ self.state = EcnValidationState::Capable;
+ }
+ self.baseline = ack_ecn;
+ self.largest_acked = largest_acked;
+ }
+
+ /// The ECN mark to use for packets sent on this path.
+ pub fn ecn_mark(&self) -> IpTosEcn {
+ match self.state {
+ EcnValidationState::Testing { .. } | EcnValidationState::Capable => IpTosEcn::Ect0,
+ EcnValidationState::Failed | EcnValidationState::Unknown => IpTosEcn::NotEct,
+ }
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/events.rs b/third_party/rust/neqo-transport/src/events.rs
index a892e384b9..68ef0d6798 100644
--- a/third_party/rust/neqo-transport/src/events.rs
+++ b/third_party/rust/neqo-transport/src/events.rs
@@ -256,7 +256,7 @@ impl EventProvider for ConnectionEvents {
mod tests {
use neqo_common::event::Provider;
- use crate::{ConnectionError, ConnectionEvent, ConnectionEvents, Error, State, StreamId};
+ use crate::{CloseReason, ConnectionEvent, ConnectionEvents, Error, State, StreamId};
#[test]
fn event_culling() {
@@ -314,7 +314,7 @@ mod tests {
evts.send_stream_writable(9.into());
evts.send_stream_stop_sending(10.into(), 55);
- evts.connection_state_change(State::Closed(ConnectionError::Transport(
+ evts.connection_state_change(State::Closed(CloseReason::Transport(
Error::StreamStateError,
)));
assert_eq!(evts.events().count(), 1);
diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs
index d84eb61ce8..7d009f3b46 100644
--- a/third_party/rust/neqo-transport/src/frame.rs
+++ b/third_party/rust/neqo-transport/src/frame.rs
@@ -8,13 +8,14 @@
use std::ops::RangeInclusive;
-use neqo_common::{qtrace, Decoder};
+use neqo_common::{qtrace, Decoder, Encoder};
use crate::{
cid::MAX_CONNECTION_ID_LEN,
+ ecn::EcnCount,
packet::PacketType,
stream_id::{StreamId, StreamType},
- AppError, ConnectionError, Error, Res, TransportError,
+ AppError, CloseReason, Error, Res, TransportError,
};
#[allow(clippy::module_name_repetitions)]
@@ -23,7 +24,7 @@ pub type FrameType = u64;
pub const FRAME_TYPE_PADDING: FrameType = 0x0;
pub const FRAME_TYPE_PING: FrameType = 0x1;
pub const FRAME_TYPE_ACK: FrameType = 0x2;
-const FRAME_TYPE_ACK_ECN: FrameType = 0x3;
+pub const FRAME_TYPE_ACK_ECN: FrameType = 0x3;
pub const FRAME_TYPE_RESET_STREAM: FrameType = 0x4;
pub const FRAME_TYPE_STOP_SENDING: FrameType = 0x5;
pub const FRAME_TYPE_CRYPTO: FrameType = 0x6;
@@ -86,11 +87,11 @@ impl CloseError {
}
}
-impl From<ConnectionError> for CloseError {
- fn from(err: ConnectionError) -> Self {
+impl From<CloseReason> for CloseError {
+ fn from(err: CloseReason) -> Self {
match err {
- ConnectionError::Transport(c) => Self::Transport(c.code()),
- ConnectionError::Application(c) => Self::Application(c),
+ CloseReason::Transport(c) => Self::Transport(c.code()),
+ CloseReason::Application(c) => Self::Application(c),
}
}
}
@@ -116,6 +117,7 @@ pub enum Frame<'a> {
ack_delay: u64,
first_ack_range: u64,
ack_ranges: Vec<AckRange>,
+ ecn_count: Option<EcnCount>,
},
ResetStream {
stream_id: StreamId,
@@ -182,7 +184,7 @@ pub enum Frame<'a> {
frame_type: u64,
// Not a reference as we use this to hold the value.
// This is not used in optimized builds anyway.
- reason_phrase: Vec<u8>,
+ reason_phrase: String,
},
HandshakeDone,
AckFrequency {
@@ -224,7 +226,7 @@ impl<'a> Frame<'a> {
match self {
Self::Padding { .. } => FRAME_TYPE_PADDING,
Self::Ping => FRAME_TYPE_PING,
- Self::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN.
+ Self::Ack { .. } => FRAME_TYPE_ACK,
Self::ResetStream { .. } => FRAME_TYPE_RESET_STREAM,
Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING,
Self::Crypto { .. } => FRAME_TYPE_CRYPTO,
@@ -426,8 +428,54 @@ impl<'a> Frame<'a> {
d(dec.decode_varint())
}
- // TODO(ekr@rtfm.com): check for minimal encoding
+ fn decode_ack<'a>(dec: &mut Decoder<'a>, ecn: bool) -> Res<Frame<'a>> {
+ let la = dv(dec)?;
+ let ad = dv(dec)?;
+ let nr = dv(dec).and_then(|nr| {
+ if nr < MAX_ACK_RANGE_COUNT {
+ Ok(nr)
+ } else {
+ Err(Error::TooMuchData)
+ }
+ })?;
+ let fa = dv(dec)?;
+ let mut arr: Vec<AckRange> = Vec::with_capacity(usize::try_from(nr)?);
+ for _ in 0..nr {
+ let ar = AckRange {
+ gap: dv(dec)?,
+ range: dv(dec)?,
+ };
+ arr.push(ar);
+ }
+
+ // Now check for the values for ACK_ECN.
+ let ecn_count = if ecn {
+ Some(EcnCount::new(0, dv(dec)?, dv(dec)?, dv(dec)?))
+ } else {
+ None
+ };
+
+ Ok(Frame::Ack {
+ largest_acknowledged: la,
+ ack_delay: ad,
+ first_ack_range: fa,
+ ack_ranges: arr,
+ ecn_count,
+ })
+ }
+
+ // Check for minimal encoding of frame type.
+ let pos = dec.offset();
let t = dv(dec)?;
+ // RFC 9000, Section 12.4:
+ //
+ // The Frame Type field uses a variable-length integer encoding [...],
+ // with one exception. To ensure simple and efficient implementations of
+ // frame parsing, a frame type MUST use the shortest possible encoding.
+ if Encoder::varint_len(t) != dec.offset() - pos {
+ return Err(Error::ProtocolViolation);
+ }
+
match t {
FRAME_TYPE_PADDING => {
let mut length: u16 = 1;
@@ -449,40 +497,8 @@ impl<'a> Frame<'a> {
_ => return Err(Error::NoMoreData),
},
}),
- FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => {
- let la = dv(dec)?;
- let ad = dv(dec)?;
- let nr = dv(dec).and_then(|nr| {
- if nr < MAX_ACK_RANGE_COUNT {
- Ok(nr)
- } else {
- Err(Error::TooMuchData)
- }
- })?;
- let fa = dv(dec)?;
- let mut arr: Vec<AckRange> = Vec::with_capacity(usize::try_from(nr)?);
- for _ in 0..nr {
- let ar = AckRange {
- gap: dv(dec)?,
- range: dv(dec)?,
- };
- arr.push(ar);
- }
-
- // Now check for the values for ACK_ECN.
- if t == FRAME_TYPE_ACK_ECN {
- dv(dec)?;
- dv(dec)?;
- dv(dec)?;
- }
-
- Ok(Self::Ack {
- largest_acknowledged: la,
- ack_delay: ad,
- first_ack_range: fa,
- ack_ranges: arr,
- })
- }
+ FRAME_TYPE_ACK => decode_ack(dec, false),
+ FRAME_TYPE_ACK_ECN => decode_ack(dec, true),
FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending {
stream_id: StreamId::from(dv(dec)?),
application_error_code: dv(dec)?,
@@ -598,7 +614,7 @@ impl<'a> Frame<'a> {
0
};
// We can tolerate this copy for now.
- let reason_phrase = d(dec.decode_vvec())?.to_vec();
+ let reason_phrase = String::from_utf8_lossy(d(dec.decode_vvec())?).to_string();
Ok(Self::ConnectionClose {
error_code,
frame_type,
@@ -647,13 +663,14 @@ mod tests {
use crate::{
cid::MAX_CONNECTION_ID_LEN,
+ ecn::EcnCount,
frame::{AckRange, Frame, FRAME_TYPE_ACK},
CloseError, Error, StreamId, StreamType,
};
fn just_dec(f: &Frame, s: &str) {
let encoded = Encoder::from_hex(s);
- let decoded = Frame::decode(&mut encoded.as_decoder()).unwrap();
+ let decoded = Frame::decode(&mut encoded.as_decoder()).expect("Failed to decode frame");
assert_eq!(*f, decoded);
}
@@ -679,7 +696,8 @@ mod tests {
largest_acknowledged: 0x1234,
ack_delay: 0x1235,
first_ack_range: 0x1236,
- ack_ranges: ar,
+ ack_ranges: ar.clone(),
+ ecn_count: None,
};
just_dec(&f, "025234523502523601020304");
@@ -689,10 +707,18 @@ mod tests {
let mut dec = enc.as_decoder();
assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData);
- // Try to parse ACK_ECN without ECN values
+ // Try to parse ACK_ECN with ECN values
+ let ecn_count = Some(EcnCount::new(0, 1, 2, 3));
+ let fe = Frame::Ack {
+ largest_acknowledged: 0x1234,
+ ack_delay: 0x1235,
+ first_ack_range: 0x1236,
+ ack_ranges: ar,
+ ecn_count,
+ };
let enc = Encoder::from_hex("035234523502523601020304010203");
let mut dec = enc.as_decoder();
- assert_eq!(Frame::decode(&mut dec).unwrap(), f);
+ assert_eq!(Frame::decode(&mut dec).unwrap(), fe);
}
#[test]
@@ -899,7 +925,7 @@ mod tests {
let f = Frame::ConnectionClose {
error_code: CloseError::Transport(0x5678),
frame_type: 0x1234,
- reason_phrase: vec![0x01, 0x02, 0x03],
+ reason_phrase: String::from("\x01\x02\x03"),
};
just_dec(&f, "1c80005678523403010203");
@@ -910,7 +936,7 @@ mod tests {
let f = Frame::ConnectionClose {
error_code: CloseError::Application(0x5678),
frame_type: 0,
- reason_phrase: vec![0x01, 0x02, 0x03],
+ reason_phrase: String::from("\x01\x02\x03"),
};
just_dec(&f, "1d8000567803010203");
@@ -989,14 +1015,14 @@ mod tests {
fill: true,
};
- just_dec(&f, "4030010203");
+ just_dec(&f, "30010203");
// With the length bit.
let f = Frame::Datagram {
data: &[1, 2, 3],
fill: false,
};
- just_dec(&f, "403103010203");
+ just_dec(&f, "3103010203");
}
#[test]
@@ -1010,4 +1036,15 @@ mod tests {
assert_eq!(Err(Error::TooMuchData), Frame::decode(&mut e.as_decoder()));
}
+
+ #[test]
+ #[should_panic(expected = "Failed to decode frame")]
+ fn invalid_frame_type_len() {
+ let f = Frame::Datagram {
+ data: &[1, 2, 3],
+ fill: true,
+ };
+
+ just_dec(&f, "4030010203");
+ }
}
diff --git a/third_party/rust/neqo-transport/src/lib.rs b/third_party/rust/neqo-transport/src/lib.rs
index 5488472b58..723a86980e 100644
--- a/third_party/rust/neqo-transport/src/lib.rs
+++ b/third_party/rust/neqo-transport/src/lib.rs
@@ -15,10 +15,17 @@ mod cc;
mod cid;
mod connection;
mod crypto;
+mod ecn;
mod events;
mod fc;
+#[cfg(fuzzing)]
+pub mod frame;
+#[cfg(not(fuzzing))]
mod frame;
mod pace;
+#[cfg(fuzzing)]
+pub mod packet;
+#[cfg(not(fuzzing))]
mod packet;
mod path;
mod qlog;
@@ -202,13 +209,17 @@ impl ::std::fmt::Display for Error {
pub type AppError = u64;
+#[deprecated(note = "use `CloseReason` instead")]
+pub type ConnectionError = CloseReason;
+
+/// Reason why a connection closed.
#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
-pub enum ConnectionError {
+pub enum CloseReason {
Transport(Error),
Application(AppError),
}
-impl ConnectionError {
+impl CloseReason {
#[must_use]
pub fn app_code(&self) -> Option<AppError> {
match self {
@@ -216,9 +227,19 @@ impl ConnectionError {
Self::Transport(_) => None,
}
}
+
+ /// Checks enclosed error for [`Error::NoError`] and
+ /// [`CloseReason::Application(0)`].
+ #[must_use]
+ pub fn is_error(&self) -> bool {
+ !matches!(
+ self,
+ CloseReason::Transport(Error::NoError) | CloseReason::Application(0),
+ )
+ }
}
-impl From<CloseError> for ConnectionError {
+impl From<CloseError> for CloseReason {
fn from(err: CloseError) -> Self {
match err {
CloseError::Transport(c) => Self::Transport(Error::PeerError(c)),
diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs
index ce611a9664..10d9b13208 100644
--- a/third_party/rust/neqo-transport/src/packet/mod.rs
+++ b/third_party/rust/neqo-transport/src/packet/mod.rs
@@ -740,6 +740,7 @@ impl<'a> PublicPacket<'a> {
}
#[must_use]
+ #[allow(clippy::len_without_is_empty)] // is_empty() would always return false in this case
pub fn len(&self) -> usize {
self.data.len()
}
diff --git a/third_party/rust/neqo-transport/src/path.rs b/third_party/rust/neqo-transport/src/path.rs
index 50e458ff36..0e4c82b1ca 100644
--- a/third_party/rust/neqo-transport/src/path.rs
+++ b/third_party/rust/neqo-transport/src/path.rs
@@ -22,6 +22,7 @@ use crate::{
ackrate::{AckRate, PeerAckDelay},
cc::CongestionControlAlgorithm,
cid::{ConnectionId, ConnectionIdRef, ConnectionIdStore, RemoteConnectionIdEntry},
+ ecn::{EcnCount, EcnInfo},
frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID},
packet::PacketBuilder,
recovery::RecoveryToken,
@@ -145,15 +146,8 @@ impl Paths {
})
}
- /// Get a reference to the primary path. This will assert if there is no primary
- /// path, which happens at a server prior to receiving a valid Initial packet
- /// from a client. So be careful using this method.
- pub fn primary(&self) -> PathRef {
- self.primary_fallible().unwrap()
- }
-
- /// Get a reference to the primary path. Use this prior to handshake completion.
- pub fn primary_fallible(&self) -> Option<PathRef> {
+ /// Get a reference to the primary path, if one exists.
+ pub fn primary(&self) -> Option<PathRef> {
self.primary.clone()
}
@@ -242,6 +236,11 @@ impl Paths {
/// Returns `true` if the path was migrated.
pub fn migrate(&mut self, path: &PathRef, force: bool, now: Instant) -> bool {
debug_assert!(!self.is_temporary(path));
+ let baseline = self.primary().map_or_else(
+ || EcnInfo::default().baseline(),
+ |p| p.borrow().ecn_info.baseline(),
+ );
+ path.borrow_mut().set_ecn_baseline(baseline);
if force || path.borrow().is_valid() {
path.borrow_mut().set_valid(now);
mem::drop(self.select_primary(path));
@@ -307,7 +306,6 @@ impl Paths {
/// Set the identified path to be primary.
/// This panics if `make_permanent` hasn't been called.
pub fn handle_migration(&mut self, path: &PathRef, remote: SocketAddr, now: Instant) {
- qtrace!([self.primary().borrow()], "handle_migration");
// The update here needs to match the checks in `Path::received_on`.
// Here, we update the remote port number to match the source port on the
// datagram that was received. This ensures that we send subsequent
@@ -425,10 +423,10 @@ impl Paths {
stats.retire_connection_id += 1;
}
- // Write out any ACK_FREQUENCY frames.
- self.primary()
- .borrow_mut()
- .write_cc_frames(builder, tokens, stats);
+ if let Some(path) = self.primary() {
+ // Write out any ACK_FREQUENCY frames.
+ path.borrow_mut().write_cc_frames(builder, tokens, stats);
+ }
}
pub fn lost_retire_cid(&mut self, lost: u64) {
@@ -440,11 +438,15 @@ impl Paths {
}
pub fn lost_ack_frequency(&mut self, lost: &AckRate) {
- self.primary().borrow_mut().lost_ack_frequency(lost);
+ if let Some(path) = self.primary() {
+ path.borrow_mut().lost_ack_frequency(lost);
+ }
}
pub fn acked_ack_frequency(&mut self, acked: &AckRate) {
- self.primary().borrow_mut().acked_ack_frequency(acked);
+ if let Some(path) = self.primary() {
+ path.borrow_mut().acked_ack_frequency(acked);
+ }
}
/// Get an estimate of the RTT on the primary path.
@@ -454,7 +456,7 @@ impl Paths {
// make a new RTT esimate and interrogate that.
// That is more expensive, but it should be rare and breaking encapsulation
// is worse, especially as this is only used in tests.
- self.primary_fallible()
+ self.primary()
.map_or(RttEstimate::default().estimate(), |p| {
p.borrow().rtt().estimate()
})
@@ -532,8 +534,6 @@ pub struct Path {
rtt: RttEstimate,
/// A packet sender for the path, which includes congestion control and a pacer.
sender: PacketSender,
- /// The DSCP/ECN marking to use for outgoing packets on this path.
- tos: IpTos,
/// The IP TTL to use for outgoing packets on this path.
ttl: u8,
@@ -543,7 +543,8 @@ pub struct Path {
received_bytes: usize,
/// The number of bytes sent on this path.
sent_bytes: usize,
-
+ /// The ECN-related state for this path (see RFC9000, Section 13.4 and Appendix A.4)
+ ecn_info: EcnInfo,
/// For logging of events.
qlog: NeqoQlog,
}
@@ -572,14 +573,23 @@ impl Path {
challenge: None,
rtt: RttEstimate::default(),
sender,
- tos: IpTos::default(), // TODO: Default to Ect0 when ECN is supported.
- ttl: 64, // This is the default TTL on many OSes.
+ ttl: 64, // This is the default TTL on many OSes.
received_bytes: 0,
sent_bytes: 0,
+ ecn_info: EcnInfo::default(),
qlog,
}
}
+ pub fn set_ecn_baseline(&mut self, baseline: EcnCount) {
+ self.ecn_info.set_baseline(baseline);
+ }
+
+ /// Return the DSCP/ECN marking to use for outgoing packets on this path.
+ pub fn tos(&self) -> IpTos {
+ self.ecn_info.ecn_mark().into()
+ }
+
/// Whether this path is the primary or current path for the connection.
pub fn is_primary(&self) -> bool {
self.primary
@@ -695,8 +705,9 @@ impl Path {
}
/// Make a datagram.
- pub fn datagram<V: Into<Vec<u8>>>(&self, payload: V) -> Datagram {
- Datagram::new(self.local, self.remote, self.tos, Some(self.ttl), payload)
+ pub fn datagram<V: Into<Vec<u8>>>(&mut self, payload: V) -> Datagram {
+ self.ecn_info.on_packet_sent();
+ Datagram::new(self.local, self.remote, self.tos(), Some(self.ttl), payload)
}
/// Get local address as `SocketAddr`
@@ -959,8 +970,24 @@ impl Path {
}
/// Record packets as acknowledged with the sender.
- pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], now: Instant) {
+ pub fn on_packets_acked(
+ &mut self,
+ acked_pkts: &[SentPacket],
+ ack_ecn: Option<EcnCount>,
+ now: Instant,
+ ) {
debug_assert!(self.is_primary());
+
+ let ecn_ce_received = self.ecn_info.on_packets_acked(acked_pkts, ack_ecn);
+ if ecn_ce_received {
+ let cwnd_reduced = self
+ .sender
+ .on_ecn_ce_received(acked_pkts.first().expect("must be there"));
+ if cwnd_reduced {
+ self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu());
+ }
+ }
+
self.sender.on_packets_acked(acked_pkts, &self.rtt, now);
}
diff --git a/third_party/rust/neqo-transport/src/qlog.rs b/third_party/rust/neqo-transport/src/qlog.rs
index a8ad986d2a..715ba85e81 100644
--- a/third_party/rust/neqo-transport/src/qlog.rs
+++ b/third_party/rust/neqo-transport/src/qlog.rs
@@ -11,7 +11,7 @@ use std::{
time::Duration,
};
-use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder};
+use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder, IpTosEcn};
use qlog::events::{
connectivity::{ConnectionStarted, ConnectionState, ConnectionStateUpdated},
quic::{
@@ -205,7 +205,7 @@ pub fn packet_sent(
let mut frames = SmallVec::new();
while d.remaining() > 0 {
if let Ok(f) = Frame::decode(&mut d) {
- frames.push(QuicFrame::from(&f));
+ frames.push(QuicFrame::from(f));
} else {
qinfo!("qlog: invalid frame");
break;
@@ -293,7 +293,7 @@ pub fn packet_received(
while d.remaining() > 0 {
if let Ok(f) = Frame::decode(&mut d) {
- frames.push(QuicFrame::from(&f));
+ frames.push(QuicFrame::from(f));
} else {
qinfo!("qlog: invalid frame");
break;
@@ -387,21 +387,26 @@ pub fn metrics_updated(qlog: &mut NeqoQlog, updated_metrics: &[QlogMetric]) {
#[allow(clippy::too_many_lines)] // Yeah, but it's a nice match.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] // No choice here.
-impl From<&Frame<'_>> for QuicFrame {
- fn from(frame: &Frame) -> Self {
+impl From<Frame<'_>> for QuicFrame {
+ fn from(frame: Frame) -> Self {
match frame {
- // TODO: Add payload length to `QuicFrame::Padding` once
- // https://github.com/cloudflare/quiche/pull/1745 is available via the qlog crate.
- Frame::Padding { .. } => QuicFrame::Padding,
- Frame::Ping => QuicFrame::Ping,
+ Frame::Padding(len) => QuicFrame::Padding {
+ length: None,
+ payload_length: u32::from(len),
+ },
+ Frame::Ping => QuicFrame::Ping {
+ length: None,
+ payload_length: None,
+ },
Frame::Ack {
largest_acknowledged,
ack_delay,
first_ack_range,
ack_ranges,
+ ecn_count,
} => {
let ranges =
- Frame::decode_ack_frame(*largest_acknowledged, *first_ack_range, ack_ranges)
+ Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)
.ok();
let acked_ranges = ranges.map(|all| {
@@ -413,11 +418,13 @@ impl From<&Frame<'_>> for QuicFrame {
});
QuicFrame::Ack {
- ack_delay: Some(*ack_delay as f32 / 1000.0),
+ ack_delay: Some(ack_delay as f32 / 1000.0),
acked_ranges,
- ect1: None,
- ect0: None,
- ce: None,
+ ect1: ecn_count.map(|c| c[IpTosEcn::Ect1]),
+ ect0: ecn_count.map(|c| c[IpTosEcn::Ect0]),
+ ce: ecn_count.map(|c| c[IpTosEcn::Ce]),
+ length: None,
+ payload_length: None,
}
}
Frame::ResetStream {
@@ -426,18 +433,22 @@ impl From<&Frame<'_>> for QuicFrame {
final_size,
} => QuicFrame::ResetStream {
stream_id: stream_id.as_u64(),
- error_code: *application_error_code,
- final_size: *final_size,
+ error_code: application_error_code,
+ final_size,
+ length: None,
+ payload_length: None,
},
Frame::StopSending {
stream_id,
application_error_code,
} => QuicFrame::StopSending {
stream_id: stream_id.as_u64(),
- error_code: *application_error_code,
+ error_code: application_error_code,
+ length: None,
+ payload_length: None,
},
Frame::Crypto { offset, data } => QuicFrame::Crypto {
- offset: *offset,
+ offset,
length: data.len() as u64,
},
Frame::NewToken { token } => QuicFrame::NewToken {
@@ -459,20 +470,20 @@ impl From<&Frame<'_>> for QuicFrame {
..
} => QuicFrame::Stream {
stream_id: stream_id.as_u64(),
- offset: *offset,
+ offset,
length: data.len() as u64,
- fin: Some(*fin),
+ fin: Some(fin),
raw: None,
},
Frame::MaxData { maximum_data } => QuicFrame::MaxData {
- maximum: *maximum_data,
+ maximum: maximum_data,
},
Frame::MaxStreamData {
stream_id,
maximum_stream_data,
} => QuicFrame::MaxStreamData {
stream_id: stream_id.as_u64(),
- maximum: *maximum_stream_data,
+ maximum: maximum_stream_data,
},
Frame::MaxStreams {
stream_type,
@@ -482,15 +493,15 @@ impl From<&Frame<'_>> for QuicFrame {
NeqoStreamType::BiDi => StreamType::Bidirectional,
NeqoStreamType::UniDi => StreamType::Unidirectional,
},
- maximum: *maximum_streams,
+ maximum: maximum_streams,
},
- Frame::DataBlocked { data_limit } => QuicFrame::DataBlocked { limit: *data_limit },
+ Frame::DataBlocked { data_limit } => QuicFrame::DataBlocked { limit: data_limit },
Frame::StreamDataBlocked {
stream_id,
stream_data_limit,
} => QuicFrame::StreamDataBlocked {
stream_id: stream_id.as_u64(),
- limit: *stream_data_limit,
+ limit: stream_data_limit,
},
Frame::StreamsBlocked {
stream_type,
@@ -500,7 +511,7 @@ impl From<&Frame<'_>> for QuicFrame {
NeqoStreamType::BiDi => StreamType::Bidirectional,
NeqoStreamType::UniDi => StreamType::Unidirectional,
},
- limit: *stream_limit,
+ limit: stream_limit,
},
Frame::NewConnectionId {
sequence_number,
@@ -508,14 +519,14 @@ impl From<&Frame<'_>> for QuicFrame {
connection_id,
stateless_reset_token,
} => QuicFrame::NewConnectionId {
- sequence_number: *sequence_number as u32,
- retire_prior_to: *retire_prior as u32,
+ sequence_number: sequence_number as u32,
+ retire_prior_to: retire_prior as u32,
connection_id_length: Some(connection_id.len() as u8),
connection_id: hex(connection_id),
stateless_reset_token: Some(hex(stateless_reset_token)),
},
Frame::RetireConnectionId { sequence_number } => QuicFrame::RetireConnectionId {
- sequence_number: *sequence_number as u32,
+ sequence_number: sequence_number as u32,
},
Frame::PathChallenge { data } => QuicFrame::PathChallenge {
data: Some(hex(data)),
@@ -534,8 +545,8 @@ impl From<&Frame<'_>> for QuicFrame {
},
error_code: Some(error_code.code()),
error_code_value: Some(0),
- reason: Some(String::from_utf8_lossy(reason_phrase).to_string()),
- trigger_frame_type: Some(*frame_type),
+ reason: Some(reason_phrase),
+ trigger_frame_type: Some(frame_type),
},
Frame::HandshakeDone => QuicFrame::HandshakeDone,
Frame::AckFrequency { .. } => QuicFrame::Unknown {
diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs
index dbea3aaf57..22a635d9f3 100644
--- a/third_party/rust/neqo-transport/src/recovery.rs
+++ b/third_party/rust/neqo-transport/src/recovery.rs
@@ -21,6 +21,7 @@ use crate::{
ackrate::AckRate,
cid::ConnectionIdEntry,
crypto::CryptoRecoveryToken,
+ ecn::EcnCount,
packet::PacketNumber,
path::{Path, PathRef},
qlog::{self, QlogMetric},
@@ -665,12 +666,14 @@ impl LossRecovery {
}
/// Returns (acked packets, lost packets)
+ #[allow(clippy::too_many_arguments)]
pub fn on_ack_received<R>(
&mut self,
primary_path: &PathRef,
pn_space: PacketNumberSpace,
largest_acked: u64,
acked_ranges: R,
+ ack_ecn: Option<EcnCount>,
ack_delay: Duration,
now: Instant,
) -> (Vec<SentPacket>, Vec<SentPacket>)
@@ -692,10 +695,10 @@ impl LossRecovery {
let (acked_packets, any_ack_eliciting) =
space.remove_acked(acked_ranges, &mut self.stats.borrow_mut());
- if acked_packets.is_empty() {
+ let Some(largest_acked_pkt) = acked_packets.first() else {
// No new information.
return (Vec::new(), Vec::new());
- }
+ };
// Track largest PN acked per space
let prev_largest_acked = space.largest_acked_sent_time;
@@ -704,7 +707,6 @@ impl LossRecovery {
// If the largest acknowledged is newly acked and any newly acked
// packet was ack-eliciting, update the RTT. (-recovery 5.1)
- let largest_acked_pkt = acked_packets.first().expect("must be there");
space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent);
if any_ack_eliciting && largest_acked_pkt.on_primary_path() {
self.rtt_sample(
@@ -744,7 +746,7 @@ impl LossRecovery {
// when it shouldn't.
primary_path
.borrow_mut()
- .on_packets_acked(&acked_packets, now);
+ .on_packets_acked(&acked_packets, ack_ecn, now);
self.pto_state = None;
@@ -1022,7 +1024,7 @@ mod tests {
time::{Duration, Instant},
};
- use neqo_common::qlog::NeqoQlog;
+ use neqo_common::{qlog::NeqoQlog, IpTosEcn};
use test_fixture::{now, DEFAULT_ADDR};
use super::{
@@ -1031,6 +1033,7 @@ mod tests {
use crate::{
cc::CongestionControlAlgorithm,
cid::{ConnectionId, ConnectionIdEntry},
+ ecn::EcnCount,
packet::PacketType,
path::{Path, PathRef},
rtt::RttEstimate,
@@ -1060,6 +1063,7 @@ mod tests {
pn_space: PacketNumberSpace,
largest_acked: u64,
acked_ranges: Vec<RangeInclusive<u64>>,
+ ack_ecn: Option<EcnCount>,
ack_delay: Duration,
now: Instant,
) -> (Vec<SentPacket>, Vec<SentPacket>) {
@@ -1068,6 +1072,7 @@ mod tests {
pn_space,
largest_acked,
acked_ranges,
+ ack_ecn,
ack_delay,
now,
)
@@ -1208,6 +1213,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Short,
pn,
+ IpTosEcn::default(),
pn_time(pn),
true,
Vec::new(),
@@ -1223,6 +1229,7 @@ mod tests {
PacketNumberSpace::ApplicationData,
pn,
vec![pn..=pn],
+ None,
ACK_DELAY,
pn_time(pn) + delay,
);
@@ -1233,6 +1240,7 @@ mod tests {
lrs.on_packet_sent(SentPacket::new(
PacketType::Short,
pn,
+ IpTosEcn::default(),
pn_time(pn),
true,
Vec::new(),
@@ -1353,6 +1361,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Short,
0,
+ IpTosEcn::default(),
pn_time(0),
true,
Vec::new(),
@@ -1361,6 +1370,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Short,
1,
+ IpTosEcn::default(),
pn_time(0) + TEST_RTT / 4,
true,
Vec::new(),
@@ -1370,6 +1380,7 @@ mod tests {
PacketNumberSpace::ApplicationData,
1,
vec![1..=1],
+ None,
ACK_DELAY,
pn_time(0) + (TEST_RTT * 5 / 4),
);
@@ -1393,6 +1404,7 @@ mod tests {
PacketNumberSpace::ApplicationData,
2,
vec![2..=2],
+ None,
ACK_DELAY,
pn2_ack_time,
);
@@ -1422,6 +1434,7 @@ mod tests {
PacketNumberSpace::ApplicationData,
4,
vec![2..=4],
+ None,
ACK_DELAY,
pn_time(4),
);
@@ -1450,6 +1463,7 @@ mod tests {
PacketNumberSpace::Initial,
0,
vec![],
+ None,
Duration::from_millis(0),
pn_time(0),
);
@@ -1463,6 +1477,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Initial,
0,
+ IpTosEcn::default(),
pn_time(0),
true,
Vec::new(),
@@ -1471,6 +1486,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Handshake,
0,
+ IpTosEcn::default(),
pn_time(1),
true,
Vec::new(),
@@ -1479,6 +1495,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Short,
0,
+ IpTosEcn::default(),
pn_time(2),
true,
Vec::new(),
@@ -1491,10 +1508,25 @@ mod tests {
PacketType::Handshake,
PacketType::Short,
] {
- let sent_pkt = SentPacket::new(*sp, 1, pn_time(3), true, Vec::new(), ON_SENT_SIZE);
+ let sent_pkt = SentPacket::new(
+ *sp,
+ 1,
+ IpTosEcn::default(),
+ pn_time(3),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ );
let pn_space = PacketNumberSpace::from(sent_pkt.pt);
lr.on_packet_sent(sent_pkt);
- lr.on_ack_received(pn_space, 1, vec![1..=1], Duration::from_secs(0), pn_time(3));
+ lr.on_ack_received(
+ pn_space,
+ 1,
+ vec![1..=1],
+ None,
+ Duration::from_secs(0),
+ pn_time(3),
+ );
let mut lost = Vec::new();
lr.spaces.get_mut(pn_space).unwrap().detect_lost_packets(
pn_time(3),
@@ -1516,6 +1548,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Initial,
0,
+ IpTosEcn::default(),
pn_time(3),
true,
Vec::new(),
@@ -1530,6 +1563,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Initial,
0,
+ IpTosEcn::default(),
now(),
true,
Vec::new(),
@@ -1542,6 +1576,7 @@ mod tests {
PacketNumberSpace::Initial,
0,
vec![0..=0],
+ None,
Duration::new(0, 0),
now() + rtt,
);
@@ -1549,6 +1584,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Handshake,
0,
+ IpTosEcn::default(),
now(),
true,
Vec::new(),
@@ -1557,6 +1593,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Short,
0,
+ IpTosEcn::default(),
now(),
true,
Vec::new(),
@@ -1594,6 +1631,7 @@ mod tests {
lr.on_packet_sent(SentPacket::new(
PacketType::Initial,
1,
+ IpTosEcn::default(),
now(),
true,
Vec::new(),
diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs
index 8771ec7765..98476e9d18 100644
--- a/third_party/rust/neqo-transport/src/send_stream.rs
+++ b/third_party/rust/neqo-transport/src/send_stream.rs
@@ -1269,7 +1269,7 @@ impl SendStream {
return Err(Error::FinalSizeError);
}
- let buf = if buf.is_empty() || (self.avail() == 0) {
+ let buf = if self.avail() == 0 {
return Ok(0);
} else if self.avail() < buf.len() {
if atomic {
@@ -1634,20 +1634,16 @@ impl SendStreams {
}
pub fn remove_terminal(&mut self) {
- let map: &mut IndexMap<StreamId, SendStream> = &mut self.map;
- let regular: &mut OrderGroup = &mut self.regular;
- let sendordered: &mut BTreeMap<SendOrder, OrderGroup> = &mut self.sendordered;
-
- // Take refs to all the items we need to modify instead of &mut
- // self to keep the compiler happy (if we use self.map.retain it
- // gets upset due to borrows)
- map.retain(|stream_id, stream| {
+ self.map.retain(|stream_id, stream| {
if stream.is_terminal() {
if stream.is_fair() {
match stream.sendorder() {
- None => regular.remove(*stream_id),
+ None => self.regular.remove(*stream_id),
Some(sendorder) => {
- sendordered.get_mut(&sendorder).unwrap().remove(*stream_id);
+ self.sendordered
+ .get_mut(&sendorder)
+ .unwrap()
+ .remove(*stream_id);
}
};
}
diff --git a/third_party/rust/neqo-transport/src/sender.rs b/third_party/rust/neqo-transport/src/sender.rs
index 3a54851533..abb14d0a25 100644
--- a/third_party/rust/neqo-transport/src/sender.rs
+++ b/third_party/rust/neqo-transport/src/sender.rs
@@ -97,6 +97,11 @@ impl PacketSender {
)
}
+ /// Called when ECN CE mark received. Returns true if the congestion window was reduced.
+ pub fn on_ecn_ce_received(&mut self, largest_acked_pkt: &SentPacket) -> bool {
+ self.cc.on_ecn_ce_received(largest_acked_pkt)
+ }
+
pub fn discard(&mut self, pkt: &SentPacket) {
self.cc.discard(pkt);
}
diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs
index 96a6244ef1..60909d71e1 100644
--- a/third_party/rust/neqo-transport/src/server.rs
+++ b/third_party/rust/neqo-transport/src/server.rs
@@ -689,6 +689,13 @@ impl Server {
mem::take(&mut self.active).into_iter().collect()
}
+ /// Whether any connections have received new events as a result of calling
+ /// `process()`.
+ #[must_use]
+ pub fn has_active_connections(&self) -> bool {
+ !self.active.is_empty()
+ }
+
pub fn add_to_waiting(&mut self, c: &ActiveConnectionRef) {
self.waiting.push_back(c.connection());
}
diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs
index bdd0f250c7..6643d516e3 100644
--- a/third_party/rust/neqo-transport/src/tracking.rs
+++ b/third_party/rust/neqo-transport/src/tracking.rs
@@ -13,18 +13,21 @@ use std::{
time::{Duration, Instant},
};
-use neqo_common::{qdebug, qinfo, qtrace, qwarn};
+use enum_map::Enum;
+use neqo_common::{qdebug, qinfo, qtrace, qwarn, IpTosEcn};
use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL};
use smallvec::{smallvec, SmallVec};
use crate::{
+ ecn::EcnCount,
+ frame::{FRAME_TYPE_ACK, FRAME_TYPE_ACK_ECN},
packet::{PacketBuilder, PacketNumber, PacketType},
recovery::RecoveryToken,
stats::FrameStats,
};
// TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574
-#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)]
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq, Enum)]
pub enum PacketNumberSpace {
Initial,
Handshake,
@@ -134,6 +137,7 @@ impl std::fmt::Debug for PacketNumberSpaceSet {
pub struct SentPacket {
pub pt: PacketType,
pub pn: PacketNumber,
+ pub ecn_mark: IpTosEcn,
ack_eliciting: bool,
pub time_sent: Instant,
primary_path: bool,
@@ -150,6 +154,7 @@ impl SentPacket {
pub fn new(
pt: PacketType,
pn: PacketNumber,
+ ecn_mark: IpTosEcn,
time_sent: Instant,
ack_eliciting: bool,
tokens: Vec<RecoveryToken>,
@@ -158,6 +163,7 @@ impl SentPacket {
Self {
pt,
pn,
+ ecn_mark,
time_sent,
ack_eliciting,
primary_path: true,
@@ -377,6 +383,8 @@ pub struct RecvdPackets {
/// Whether we are ignoring packets that arrive out of order
/// for the purposes of generating immediate acknowledgment.
ignore_order: bool,
+ // The counts of different ECN marks that have been received.
+ ecn_count: EcnCount,
}
impl RecvdPackets {
@@ -394,9 +402,15 @@ impl RecvdPackets {
unacknowledged_count: 0,
unacknowledged_tolerance: DEFAULT_ACK_PACKET_TOLERANCE,
ignore_order: false,
+ ecn_count: EcnCount::default(),
}
}
+ /// Get the ECN counts.
+ pub fn ecn_marks(&mut self) -> &mut EcnCount {
+ &mut self.ecn_count
+ }
+
/// Get the time at which the next ACK should be sent.
pub fn ack_time(&self) -> Option<Instant> {
self.ack_time
@@ -545,6 +559,10 @@ impl RecvdPackets {
}
}
+ /// Length of the worst possible ACK frame, assuming only one range and ECN counts.
+ /// Note that this assumes one byte for the type and count of extra ranges.
+ pub const USEFUL_ACK_LEN: usize = 1 + 8 + 8 + 1 + 8 + 3 * 8;
+
/// Generate an ACK frame for this packet number space.
///
/// Unlike other frame generators this doesn't modify the underlying instance
@@ -563,10 +581,6 @@ impl RecvdPackets {
tokens: &mut Vec<RecoveryToken>,
stats: &mut FrameStats,
) {
- // The worst possible ACK frame, assuming only one range.
- // Note that this assumes one byte for the type and count of extra ranges.
- const LONGEST_ACK_HEADER: usize = 1 + 8 + 8 + 1 + 8;
-
// Check that we aren't delaying ACKs.
if !self.ack_now(now, rtt) {
return;
@@ -578,7 +592,10 @@ impl RecvdPackets {
// When congestion limited, ACK-only packets are 255 bytes at most
// (`recovery::ACK_ONLY_SIZE_LIMIT - 1`). This results in limiting the
// ranges to 13 here.
- let max_ranges = if let Some(avail) = builder.remaining().checked_sub(LONGEST_ACK_HEADER) {
+ let max_ranges = if let Some(avail) = builder
+ .remaining()
+ .checked_sub(RecvdPackets::USEFUL_ACK_LEN)
+ {
// Apply a hard maximum to keep plenty of space for other stuff.
min(1 + (avail / 16), MAX_ACKS_PER_FRAME)
} else {
@@ -593,7 +610,11 @@ impl RecvdPackets {
.cloned()
.collect::<Vec<_>>();
- builder.encode_varint(crate::frame::FRAME_TYPE_ACK);
+ builder.encode_varint(if self.ecn_count.is_some() {
+ FRAME_TYPE_ACK_ECN
+ } else {
+ FRAME_TYPE_ACK
+ });
let mut iter = ranges.iter();
let Some(first) = iter.next() else { return };
builder.encode_varint(first.largest);
@@ -617,6 +638,12 @@ impl RecvdPackets {
last = r.smallest;
}
+ if self.ecn_count.is_some() {
+ builder.encode_varint(self.ecn_count[IpTosEcn::Ect0]);
+ builder.encode_varint(self.ecn_count[IpTosEcn::Ect1]);
+ builder.encode_varint(self.ecn_count[IpTosEcn::Ce]);
+ }
+
// We've sent an ACK, reset the timer.
self.ack_time = None;
self.last_ack_time = Some(now);
@@ -1134,7 +1161,9 @@ mod tests {
.is_some());
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
- builder.set_limit(32);
+ // The code pessimistically assumes that each range needs 16 bytes to express.
+ // So this won't be enough for a second range.
+ builder.set_limit(RecvdPackets::USEFUL_ACK_LEN + 8);
let mut stats = FrameStats::default();
tracker.write_frame(