diff options
Diffstat (limited to 'third_party/rust/neqo-transport/src/connection')
12 files changed, 690 insertions, 193 deletions
diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs index 8522507a69..f955381414 100644 --- a/third_party/rust/neqo-transport/src/connection/mod.rs +++ b/third_party/rust/neqo-transport/src/connection/mod.rs @@ -19,7 +19,7 @@ use std::{ use neqo_common::{ event::Provider as EventProvider, hex, hex_snip_middle, hrtime, qdebug, qerror, qinfo, - qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, IpTos, Role, + qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role, }; use neqo_crypto::{ agent::CertificateInfo, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group, @@ -35,6 +35,7 @@ use crate::{ ConnectionIdRef, ConnectionIdStore, LOCAL_ACTIVE_CID_LIMIT, }, crypto::{Crypto, CryptoDxState, CryptoSpace}, + ecn::EcnCount, events::{ConnectionEvent, ConnectionEvents, OutgoingDatagramOutcome}, frame::{ CloseError, Frame, FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, @@ -46,7 +47,7 @@ use crate::{ quic_datagrams::{DatagramTracking, QuicDatagrams}, recovery::{LossRecovery, RecoveryToken, SendProfile}, recv_stream::RecvStreamStats, - rtt::GRANULARITY, + rtt::{RttEstimate, GRANULARITY}, send_stream::SendStream, stats::{Stats, StatsCell}, stream_id::StreamType, @@ -55,9 +56,9 @@ use crate::{ self, TransportParameter, TransportParameterId, TransportParameters, TransportParametersHandler, }, - tracking::{AckTracker, PacketNumberSpace, SentPacket}, + tracking::{AckTracker, PacketNumberSpace, RecvdPackets, SentPacket}, version::{Version, WireVersion}, - AppError, ConnectionError, Error, Res, StreamId, + AppError, CloseReason, Error, Res, StreamId, }; mod dump; @@ -291,7 +292,7 @@ impl Debug for Connection { "{:?} Connection: {:?} {:?}", self.role, self.state, - self.paths.primary_fallible() + self.paths.primary() ) } } @@ -591,7 +592,11 @@ impl Connection { fn make_resumption_token(&mut self) -> ResumptionToken { debug_assert_eq!(self.role, Role::Client); debug_assert!(self.crypto.has_resumption_token()); - let rtt = self.paths.primary().borrow().rtt().estimate(); + let rtt = self.paths.primary().map_or_else( + || RttEstimate::default().estimate(), + |p| p.borrow().rtt().estimate(), + ); + self.crypto .create_resumption_token( self.new_token.take_token(), @@ -610,11 +615,10 @@ impl Connection { /// a value of this approximate order. Don't use this for loss recovery, /// only use it where a more precise value is not important. fn pto(&self) -> Duration { - self.paths - .primary() - .borrow() - .rtt() - .pto(PacketNumberSpace::ApplicationData) + self.paths.primary().map_or_else( + || RttEstimate::default().pto(PacketNumberSpace::ApplicationData), + |p| p.borrow().rtt().pto(PacketNumberSpace::ApplicationData), + ) } fn create_resumption_token(&mut self, now: Instant) { @@ -746,7 +750,12 @@ impl Connection { if !init_token.is_empty() { self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec()); } - self.paths.primary().borrow_mut().rtt_mut().set_initial(rtt); + self.paths + .primary() + .ok_or(Error::InternalError)? + .borrow_mut() + .rtt_mut() + .set_initial(rtt); self.set_initial_limits(); // Start up TLS, which has the effect of setting up all the necessary // state for 0-RTT. This only stages the CRYPTO frames. @@ -786,7 +795,7 @@ impl Connection { // If we are able, also send a NEW_TOKEN frame. // This should be recording all remote addresses that are valid, // but there are just 0 or 1 in the current implementation. - if let Some(path) = self.paths.primary_fallible() { + if let Some(path) = self.paths.primary() { if let Some(token) = self .address_validation .generate_new_token(path.borrow().remote_address(), now) @@ -858,7 +867,7 @@ impl Connection { #[must_use] pub fn stats(&self) -> Stats { let mut v = self.stats.borrow().clone(); - if let Some(p) = self.paths.primary_fallible() { + if let Some(p) = self.paths.primary() { let p = p.borrow(); v.rtt = p.rtt().estimate(); v.rttvar = p.rtt().rttvar(); @@ -880,7 +889,7 @@ impl Connection { let msg = format!("{v:?}"); #[cfg(not(debug_assertions))] let msg = ""; - let error = ConnectionError::Transport(v.clone()); + let error = CloseReason::Transport(v.clone()); match &self.state { State::Closing { error: err, .. } | State::Draining { error: err, .. } @@ -895,14 +904,14 @@ impl Connection { State::WaitInitial => { // We don't have any state yet, so don't bother with // the closing state, just send one CONNECTION_CLOSE. - if let Some(path) = path.or_else(|| self.paths.primary_fallible()) { + if let Some(path) = path.or_else(|| self.paths.primary()) { self.state_signaling .close(path, error.clone(), frame_type, msg); } self.set_state(State::Closed(error)); } _ => { - if let Some(path) = path.or_else(|| self.paths.primary_fallible()) { + if let Some(path) = path.or_else(|| self.paths.primary()) { self.state_signaling .close(path, error.clone(), frame_type, msg); if matches!(v, Error::KeysExhausted) { @@ -951,9 +960,7 @@ impl Connection { let pto = self.pto(); if self.idle_timeout.expired(now, pto) { qinfo!([self], "idle timeout expired"); - self.set_state(State::Closed(ConnectionError::Transport( - Error::IdleTimeout, - ))); + self.set_state(State::Closed(CloseReason::Transport(Error::IdleTimeout))); return; } @@ -962,9 +969,11 @@ impl Connection { let res = self.crypto.states.check_key_update(now); self.absorb_error(now, res); - let lost = self.loss_recovery.timeout(&self.paths.primary(), now); - self.handle_lost_packets(&lost); - qlog::packets_lost(&mut self.qlog, &lost); + if let Some(path) = self.paths.primary() { + let lost = self.loss_recovery.timeout(&path, now); + self.handle_lost_packets(&lost); + qlog::packets_lost(&mut self.qlog, &lost); + } if self.release_resumption_token_timer.is_some() { self.create_resumption_token(now); @@ -1014,7 +1023,7 @@ impl Connection { delays.push(ack_time); } - if let Some(p) = self.paths.primary_fallible() { + if let Some(p) = self.paths.primary() { let path = p.borrow(); let rtt = path.rtt(); let pto = rtt.pto(PacketNumberSpace::ApplicationData); @@ -1102,7 +1111,15 @@ impl Connection { self.input(d, now, now); self.process_saved(now); } - self.process_output(now) + #[allow(clippy::let_and_return)] + let output = self.process_output(now); + #[cfg(all(feature = "build-fuzzing-corpus", test))] + if self.test_frame_writer.is_none() { + if let Some(d) = output.clone().dgram() { + neqo_common::write_item_to_fuzzing_corpus("packet", &d); + } + } + output } fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) { @@ -1123,7 +1140,13 @@ impl Connection { } // At this point, we should only have the connection ID that we generated. // Update to the one that the server prefers. - let path = self.paths.primary(); + let Some(path) = self.paths.primary() else { + self.stats + .borrow_mut() + .pkt_dropped("Retry without an existing path"); + return; + }; + path.borrow_mut().set_remote_cid(packet.scid()); let retry_scid = ConnectionId::from(packet.scid()); @@ -1151,8 +1174,9 @@ impl Connection { fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) { if self.crypto.discard(space) { qdebug!([self], "Drop packet number space {}", space); - let primary = self.paths.primary(); - self.loss_recovery.discard(&primary, space, now); + if let Some(path) = self.paths.primary() { + self.loss_recovery.discard(&path, space, now); + } self.acks.drop_space(space); } } @@ -1180,7 +1204,7 @@ impl Connection { qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..])); self.state_signaling.reset(); self.set_state(State::Draining { - error: ConnectionError::Transport(Error::StatelessReset), + error: CloseReason::Transport(Error::StatelessReset), timeout: self.get_closing_period_time(now), }); Err(Error::StatelessReset) @@ -1227,8 +1251,9 @@ impl Connection { assert_ne!(self.version, version); qinfo!([self], "Version negotiation: trying {:?}", version); - let local_addr = self.paths.primary().borrow().local_address(); - let remote_addr = self.paths.primary().borrow().remote_address(); + let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; + let local_addr = path.borrow().local_address(); + let remote_addr = path.borrow().remote_address(); let conn_params = self .conn_params .clone() @@ -1256,7 +1281,7 @@ impl Connection { } else { qinfo!([self], "Version negotiation: failed with {:?}", supported); // This error goes straight to closed. - self.set_state(State::Closed(ConnectionError::Transport( + self.set_state(State::Closed(CloseReason::Transport( Error::VersionNegotiation, ))); Err(Error::VersionNegotiation) @@ -1417,6 +1442,13 @@ impl Connection { migrate: bool, now: Instant, ) { + let space = PacketNumberSpace::from(packet.packet_type()); + if let Some(space) = self.acks.get_mut(space) { + *space.ecn_marks() += d.tos().into(); + } else { + qtrace!("Not tracking ECN for dropped packet number space"); + } + if self.state == State::WaitInitial { self.start_handshake(path, packet, now); } @@ -1491,6 +1523,16 @@ impl Connection { d.tos(), ); + #[cfg(feature = "build-fuzzing-corpus")] + if packet.packet_type() == PacketType::Initial { + let target = if self.role == Role::Client { + "server_initial" + } else { + "client_initial" + }; + neqo_common::write_item_to_fuzzing_corpus(target, &payload[..]); + } + qlog::packet_received(&mut self.qlog, &packet, &payload); let space = PacketNumberSpace::from(payload.packet_type()); if self.acks.get_mut(space).unwrap().is_duplicate(payload.pn()) { @@ -1562,7 +1604,11 @@ impl Connection { let mut probing = true; let mut d = Decoder::from(&packet[..]); while d.remaining() > 0 { + #[cfg(feature = "build-fuzzing-corpus")] + let pos = d.offset(); let f = Frame::decode(&mut d)?; + #[cfg(feature = "build-fuzzing-corpus")] + neqo_common::write_item_to_fuzzing_corpus("frame", &packet[pos..d.offset()]); ack_eliciting |= f.ack_eliciting(); probing &= f.path_probing(); let t = f.get_type(); @@ -1623,10 +1669,15 @@ impl Connection { if let Some(cid) = self.connection_ids.next() { self.paths.make_permanent(path, None, cid); Ok(()) - } else if self.paths.primary().borrow().remote_cid().is_empty() { - self.paths - .make_permanent(path, None, ConnectionIdEntry::empty_remote()); - Ok(()) + } else if let Some(primary) = self.paths.primary() { + if primary.borrow().remote_cid().is_empty() { + self.paths + .make_permanent(path, None, ConnectionIdEntry::empty_remote()); + Ok(()) + } else { + qtrace!([self], "Unable to make path permanent: {}", path.borrow()); + Err(Error::InvalidMigration) + } } else { qtrace!([self], "Unable to make path permanent: {}", path.borrow()); Err(Error::InvalidMigration) @@ -1719,8 +1770,10 @@ impl Connection { // Pointless migration is pointless. return Err(Error::InvalidMigration); } - let local = local.unwrap_or_else(|| self.paths.primary().borrow().local_address()); - let remote = remote.unwrap_or_else(|| self.paths.primary().borrow().remote_address()); + + let path = self.paths.primary().ok_or(Error::InvalidMigration)?; + let local = local.unwrap_or_else(|| path.borrow().local_address()); + let remote = remote.unwrap_or_else(|| path.borrow().remote_address()); if mem::discriminant(&local.ip()) != mem::discriminant(&remote.ip()) { // Can't mix address families. @@ -1773,7 +1826,12 @@ impl Connection { // has to use the existing address. So only pay attention to a preferred // address from the same family as is currently in use. More thought will // be needed to work out how to get addresses from a different family. - let prev = self.paths.primary().borrow().remote_address(); + let prev = self + .paths + .primary() + .ok_or(Error::NoAvailablePath)? + .borrow() + .remote_address(); let remote = match prev.ip() { IpAddr::V4(_) => addr.ipv4().map(SocketAddr::V4), IpAddr::V6(_) => addr.ipv6().map(SocketAddr::V6), @@ -1937,20 +1995,15 @@ impl Connection { } } - self.streams - .write_frames(TransmissionPriority::Critical, builder, tokens, frame_stats); - if builder.is_full() { - return; - } - - self.streams.write_frames( + for prio in [ + TransmissionPriority::Critical, TransmissionPriority::Important, - builder, - tokens, - frame_stats, - ); - if builder.is_full() { - return; + ] { + self.streams + .write_frames(prio, builder, tokens, frame_stats); + if builder.is_full() { + return; + } } // NEW_CONNECTION_ID, RETIRE_CONNECTION_ID, and ACK_FREQUENCY. @@ -1958,21 +2011,18 @@ impl Connection { if builder.is_full() { return; } - self.paths.write_frames(builder, tokens, frame_stats); - if builder.is_full() { - return; - } - self.streams - .write_frames(TransmissionPriority::High, builder, tokens, frame_stats); + self.paths.write_frames(builder, tokens, frame_stats); if builder.is_full() { return; } - self.streams - .write_frames(TransmissionPriority::Normal, builder, tokens, frame_stats); - if builder.is_full() { - return; + for prio in [TransmissionPriority::High, TransmissionPriority::Normal] { + self.streams + .write_frames(prio, builder, tokens, &mut stats.frame_tx); + if builder.is_full() { + return; + } } // Datagrams are best-effort and unreliable. Let streams starve them for now. @@ -1981,9 +2031,9 @@ impl Connection { return; } - let frame_stats = &mut stats.frame_tx; // CRYPTO here only includes NewSessionTicket, plus NEW_TOKEN. // Both of these are only used for resumption and so can be relatively low priority. + let frame_stats = &mut stats.frame_tx; self.crypto.write_frame( PacketNumberSpace::ApplicationData, builder, @@ -1993,6 +2043,7 @@ impl Connection { if builder.is_full() { return; } + self.new_token.write_frames(builder, tokens, frame_stats); if builder.is_full() { return; @@ -2002,10 +2053,8 @@ impl Connection { .write_frames(TransmissionPriority::Low, builder, tokens, frame_stats); #[cfg(test)] - { - if let Some(w) = &mut self.test_frame_writer { - w.write_frames(builder); - } + if let Some(w) = &mut self.test_frame_writer { + w.write_frames(builder); } } @@ -2138,6 +2187,40 @@ impl Connection { (tokens, ack_eliciting, padded) } + fn write_closing_frames( + &mut self, + close: &ClosingFrame, + builder: &mut PacketBuilder, + space: PacketNumberSpace, + now: Instant, + path: &PathRef, + tokens: &mut Vec<RecoveryToken>, + ) { + if builder.remaining() > ClosingFrame::MIN_LENGTH + RecvdPackets::USEFUL_ACK_LEN { + // Include an ACK frame with the CONNECTION_CLOSE. + let limit = builder.limit(); + builder.set_limit(limit - ClosingFrame::MIN_LENGTH); + self.acks.immediate_ack(now); + self.acks.write_frame( + space, + now, + path.borrow().rtt().estimate(), + builder, + tokens, + &mut self.stats.borrow_mut().frame_tx, + ); + builder.set_limit(limit); + } + // CloseReason::Application is only allowed at 1RTT. + let sanitized = if space == PacketNumberSpace::ApplicationData { + None + } else { + close.sanitize() + }; + sanitized.as_ref().unwrap_or(close).write_frame(builder); + self.stats.borrow_mut().frame_tx.connection_close += 1; + } + /// Build a datagram, possibly from multiple packets (for different PN /// spaces) and each containing 1+ frames. #[allow(clippy::too_many_lines)] // Yeah, that's just the way it is. @@ -2201,17 +2284,7 @@ impl Connection { let payload_start = builder.len(); let (mut tokens, mut ack_eliciting, mut padded) = (Vec::new(), false, false); if let Some(ref close) = closing_frame { - // ConnectionError::Application is only allowed at 1RTT. - let sanitized = if *space == PacketNumberSpace::ApplicationData { - None - } else { - close.sanitize() - }; - sanitized - .as_ref() - .unwrap_or(close) - .write_frame(&mut builder); - self.stats.borrow_mut().frame_tx.connection_close += 1; + self.write_closing_frames(close, &mut builder, *space, now, path, &mut tokens); } else { (tokens, ack_eliciting, padded) = self.write_frames(path, *space, &profile, &mut builder, now); @@ -2229,7 +2302,7 @@ impl Connection { pt, pn, &builder.as_ref()[payload_start..], - IpTos::default(), // TODO: set from path + path.borrow().tos(), ); qlog::packet_sent( &mut self.qlog, @@ -2251,6 +2324,7 @@ impl Connection { let sent = SentPacket::new( pt, pn, + path.borrow().tos().into(), now, ack_eliciting, tokens, @@ -2303,7 +2377,7 @@ impl Connection { self.loss_recovery.on_packet_sent(path, initial); } path.borrow_mut().add_sent(packets.len()); - Ok(SendOption::Yes(path.borrow().datagram(packets))) + Ok(SendOption::Yes(path.borrow_mut().datagram(packets))) } } @@ -2330,7 +2404,9 @@ impl Connection { fn client_start(&mut self, now: Instant) -> Res<()> { qdebug!([self], "client_start"); debug_assert_eq!(self.role, Role::Client); - qlog::client_connection_started(&mut self.qlog, &self.paths.primary()); + if let Some(path) = self.paths.primary() { + qlog::client_connection_started(&mut self.qlog, &path); + } qlog::client_version_information_initiated(&mut self.qlog, self.conn_params.get_versions()); self.handshake(now, self.version, PacketNumberSpace::Initial, None)?; @@ -2351,9 +2427,9 @@ impl Connection { /// Close the connection. pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) { - let error = ConnectionError::Application(app_error); + let error = CloseReason::Application(app_error); let timeout = self.get_closing_period_time(now); - if let Some(path) = self.paths.primary_fallible() { + if let Some(path) = self.paths.primary() { self.state_signaling.close(path, error.clone(), 0, msg); self.set_state(State::Closing { error, timeout }); } else { @@ -2411,10 +2487,8 @@ impl Connection { // That's OK, they can try guessing this. ConnectionIdEntry::random_srt() }; - self.paths - .primary() - .borrow_mut() - .set_reset_token(reset_token); + let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; + path.borrow_mut().set_reset_token(reset_token); let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY)); let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) { @@ -2426,11 +2500,8 @@ impl Connection { } else { None }; - self.paths.primary().borrow_mut().set_ack_delay( - max_ad, - min_ad, - self.conn_params.get_ack_ratio(), - ); + path.borrow_mut() + .set_ack_delay(max_ad, min_ad, self.conn_params.get_ack_ratio()); let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT); self.cid_manager.set_limit(max_active_cids); @@ -2673,10 +2744,18 @@ impl Connection { ack_delay, first_ack_range, ack_ranges, + ecn_count, } => { let ranges = Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?; - self.handle_ack(space, largest_acknowledged, ranges, ack_delay, now); + self.handle_ack( + space, + largest_acknowledged, + ranges, + ecn_count, + ack_delay, + now, + ); } Frame::Crypto { offset, data } => { qtrace!( @@ -2747,7 +2826,6 @@ impl Connection { reason_phrase, } => { self.stats.borrow_mut().frame_rx.connection_close += 1; - let reason_phrase = String::from_utf8_lossy(&reason_phrase); qinfo!( [self], "ConnectionClose received. Error code: {:?} frame type {:x} reason {}", @@ -2768,7 +2846,7 @@ impl Connection { FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT, ) }; - let error = ConnectionError::Transport(detail); + let error = CloseReason::Transport(detail); self.state_signaling .drain(Rc::clone(path), error.clone(), frame_type, ""); self.set_state(State::Draining { @@ -2853,6 +2931,7 @@ impl Connection { space: PacketNumberSpace, largest_acknowledged: u64, ack_ranges: R, + ack_ecn: Option<EcnCount>, ack_delay: u64, now: Instant, ) where @@ -2861,11 +2940,15 @@ impl Connection { { qdebug!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges); + let Some(path) = self.paths.primary() else { + return; + }; let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received( - &self.paths.primary(), + &path, space, largest_acknowledged, ack_ranges, + ack_ecn, self.decode_ack_delay(ack_delay), now, ); @@ -2903,8 +2986,10 @@ impl Connection { qdebug!([self], "0-RTT rejected"); // Tell 0-RTT packets that they were "lost". - let dropped = self.loss_recovery.drop_0rtt(&self.paths.primary(), now); - self.handle_lost_packets(&dropped); + if let Some(path) = self.paths.primary() { + let dropped = self.loss_recovery.drop_0rtt(&path, now); + self.handle_lost_packets(&dropped); + } self.streams.zero_rtt_rejected(); @@ -2923,7 +3008,7 @@ impl Connection { // Remove the randomized client CID from the list of acceptable CIDs. self.cid_manager.remove_odcid(); // Mark the path as validated, if it isn't already. - let path = self.paths.primary(); + let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; path.borrow_mut().set_valid(now); // Generate a qlog event that the server connection started. qlog::server_connection_started(&mut self.qlog, &path); @@ -3191,7 +3276,7 @@ impl Connection { else { return Err(Error::NotAvailable); }; - let path = self.paths.primary_fallible().ok_or(Error::NotAvailable)?; + let path = self.paths.primary().ok_or(Error::NotAvailable)?; let mtu = path.borrow().mtu(); let encoder = Encoder::with_capacity(mtu); diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs index cc2f6e30d2..e76f937938 100644 --- a/third_party/rust/neqo-transport/src/connection/state.rs +++ b/third_party/rust/neqo-transport/src/connection/state.rs @@ -21,7 +21,7 @@ use crate::{ packet::PacketBuilder, path::PathRef, recovery::RecoveryToken, - ConnectionError, Error, + CloseReason, Error, }; #[derive(Clone, Debug, PartialEq, Eq)] @@ -42,14 +42,14 @@ pub enum State { Connected, Confirmed, Closing { - error: ConnectionError, + error: CloseReason, timeout: Instant, }, Draining { - error: ConnectionError, + error: CloseReason, timeout: Instant, }, - Closed(ConnectionError), + Closed(CloseReason), } impl State { @@ -67,7 +67,7 @@ impl State { } #[must_use] - pub fn error(&self) -> Option<&ConnectionError> { + pub fn error(&self) -> Option<&CloseReason> { if let Self::Closing { error, .. } | Self::Draining { error, .. } | Self::Closed(error) = self { @@ -116,7 +116,7 @@ impl Ord for State { #[derive(Debug, Clone)] pub struct ClosingFrame { path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, reason_phrase: Vec<u8>, } @@ -124,7 +124,7 @@ pub struct ClosingFrame { impl ClosingFrame { fn new( path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, message: impl AsRef<str>, ) -> Self { @@ -142,12 +142,12 @@ impl ClosingFrame { } pub fn sanitize(&self) -> Option<Self> { - if let ConnectionError::Application(_) = self.error { + if let CloseReason::Application(_) = self.error { // The default CONNECTION_CLOSE frame that is sent when an application // error code needs to be sent in an Initial or Handshake packet. Some(Self { path: Rc::clone(&self.path), - error: ConnectionError::Transport(Error::ApplicationError), + error: CloseReason::Transport(Error::ApplicationError), frame_type: 0, reason_phrase: Vec::new(), }) @@ -156,19 +156,22 @@ impl ClosingFrame { } } + /// Length of a closing frame with a truncated `reason_length`. Allow 8 bytes for the reason + /// phrase to ensure that if it needs to be truncated there is still at least a few bytes of + /// the value. + pub const MIN_LENGTH: usize = 1 + 8 + 8 + 2 + 8; + pub fn write_frame(&self, builder: &mut PacketBuilder) { - // Allow 8 bytes for the reason phrase to ensure that if it needs to be - // truncated there is still at least a few bytes of the value. - if builder.remaining() < 1 + 8 + 8 + 2 + 8 { + if builder.remaining() < ClosingFrame::MIN_LENGTH { return; } match &self.error { - ConnectionError::Transport(e) => { + CloseReason::Transport(e) => { builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT); builder.encode_varint(e.code()); builder.encode_varint(self.frame_type); } - ConnectionError::Application(code) => { + CloseReason::Application(code) => { builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_APPLICATION); builder.encode_varint(*code); } @@ -209,10 +212,6 @@ pub enum StateSignaling { impl StateSignaling { pub fn handshake_done(&mut self) { if !matches!(self, Self::Idle) { - debug_assert!( - false, - "StateSignaling must be in Idle state but is in {self:?} state.", - ); return; } *self = Self::HandshakeDone; @@ -231,7 +230,7 @@ impl StateSignaling { pub fn close( &mut self, path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, message: impl AsRef<str>, ) { @@ -243,7 +242,7 @@ impl StateSignaling { pub fn drain( &mut self, path: PathRef, - error: ConnectionError, + error: CloseReason, frame_type: FrameType, message: impl AsRef<str>, ) { diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs index b708bc421d..f21f4e184f 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/cc.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs @@ -6,7 +6,7 @@ use std::{mem, time::Duration}; -use neqo_common::{qdebug, qinfo, Datagram}; +use neqo_common::{qdebug, qinfo, Datagram, IpTosEcn}; use super::{ super::Output, ack_bytes, assert_full_cwnd, connect_rtt_idle, cwnd, cwnd_avail, cwnd_packets, @@ -36,9 +36,13 @@ fn cc_slow_start() { assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT); } -#[test] -/// Verify that CC moves to cong avoidance when a packet is marked lost. -fn cc_slow_start_to_cong_avoidance_recovery_period() { +#[derive(PartialEq, Eq, Clone, Copy)] +enum CongestionSignal { + PacketLoss, + EcnCe, +} + +fn cc_slow_start_to_cong_avoidance_recovery_period(congestion_signal: CongestionSignal) { let mut client = default_client(); let mut server = default_server(); let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT); @@ -78,9 +82,17 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() { assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2); let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap(); - // Server: Receive and generate ack again, but drop first packet + // Server: Receive and generate ack again, but this time add congestion + // signal first. now += DEFAULT_RTT / 2; - c_tx_dgrams.remove(0); + match congestion_signal { + CongestionSignal::PacketLoss => { + c_tx_dgrams.remove(0); + } + CongestionSignal::EcnCe => { + c_tx_dgrams.last_mut().unwrap().set_tos(IpTosEcn::Ce.into()); + } + } let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now); assert_eq!( server.stats().frame_tx.largest_acknowledged, @@ -98,6 +110,18 @@ fn cc_slow_start_to_cong_avoidance_recovery_period() { } #[test] +/// Verify that CC moves to cong avoidance when a packet is marked lost. +fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_packet_loss() { + cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::PacketLoss); +} + +/// Verify that CC moves to cong avoidance when ACK is marked with ECN CE. +#[test] +fn cc_slow_start_to_cong_avoidance_recovery_period_due_to_ecn_ce() { + cc_slow_start_to_cong_avoidance_recovery_period(CongestionSignal::EcnCe); +} + +#[test] /// Verify that CC stays in recovery period when packet sent before start of /// recovery period is acked. fn cc_cong_avoidance_recovery_period_unchanged() { diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs index 5351dd0d5c..7c620de17e 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/close.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs @@ -14,13 +14,13 @@ use super::{ }; use crate::{ tparams::{self, TransportParameter}, - AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE, + AppError, CloseReason, Error, ERROR_APPLICATION_CLOSE, }; fn assert_draining(c: &Connection, expected: &Error) { assert!(c.state().closed()); if let State::Draining { - error: ConnectionError::Transport(error), + error: CloseReason::Transport(error), .. } = c.state() { @@ -40,7 +40,14 @@ fn connection_close() { client.close(now, 42, ""); + let stats_before = client.stats().frame_tx; let out = client.process(None, now); + let stats_after = client.stats().frame_tx; + assert_eq!( + stats_after.connection_close, + stats_before.connection_close + 1 + ); + assert_eq!(stats_after.ack, stats_before.ack + 1); server.process_input(&out.dgram().unwrap(), now); assert_draining(&server, &Error::PeerApplicationError(42)); @@ -57,7 +64,14 @@ fn connection_close_with_long_reason_string() { let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap(); client.close(now, 42, long_reason); + let stats_before = client.stats().frame_tx; let out = client.process(None, now); + let stats_after = client.stats().frame_tx; + assert_eq!( + stats_after.connection_close, + stats_before.connection_close + 1 + ); + assert_eq!(stats_after.ack, stats_before.ack + 1); server.process_input(&out.dgram().unwrap(), now); assert_draining(&server, &Error::PeerApplicationError(42)); @@ -100,7 +114,7 @@ fn bad_tls_version() { let dgram = server.process(dgram.as_ref(), now()).dgram(); assert_eq!( *server.state(), - State::Closed(ConnectionError::Transport(Error::ProtocolViolation)) + State::Closed(CloseReason::Transport(Error::ProtocolViolation)) ); assert!(dgram.is_some()); client.process_input(&dgram.unwrap(), now()); @@ -154,7 +168,6 @@ fn closing_and_draining() { assert!(client_close.is_some()); let client_close_timer = client.process(None, now()).callback(); assert_ne!(client_close_timer, Duration::from_secs(0)); - // The client will spit out the same packet in response to anything it receives. let p3 = send_something(&mut server, now()); let client_close2 = client.process(Some(&p3), now()).dgram(); @@ -168,7 +181,7 @@ fn closing_and_draining() { assert_eq!(end, Output::None); assert_eq!( *client.state(), - State::Closed(ConnectionError::Application(APP_ERROR)) + State::Closed(CloseReason::Application(APP_ERROR)) ); // When the server receives the close, it too should generate CONNECTION_CLOSE. @@ -186,7 +199,7 @@ fn closing_and_draining() { assert_eq!(end, Output::None); assert_eq!( *server.state(), - State::Closed(ConnectionError::Transport(Error::PeerApplicationError( + State::Closed(CloseReason::Transport(Error::PeerApplicationError( APP_ERROR ))) ); diff --git a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs index ade8c753be..f1b64b3c8f 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs @@ -19,7 +19,7 @@ use crate::{ packet::PacketBuilder, quic_datagrams::MAX_QUIC_DATAGRAM, send_stream::{RetransmissionPriority, TransmissionPriority}, - Connection, ConnectionError, ConnectionParameters, Error, StreamType, + CloseReason, Connection, ConnectionParameters, Error, StreamType, }; const DATAGRAM_LEN_MTU: u64 = 1310; @@ -362,10 +362,7 @@ fn dgram_no_allowed() { client.process_input(&out, now()); - assert_error( - &client, - &ConnectionError::Transport(Error::ProtocolViolation), - ); + assert_error(&client, &CloseReason::Transport(Error::ProtocolViolation)); } #[test] @@ -383,10 +380,7 @@ fn dgram_too_big() { client.process_input(&out, now()); - assert_error( - &client, - &ConnectionError::Transport(Error::ProtocolViolation), - ); + assert_error(&client, &CloseReason::Transport(Error::ProtocolViolation)); } #[test] @@ -587,7 +581,7 @@ fn datagram_fill() { // Work out how much space we have for a datagram. let space = { - let p = client.paths.primary(); + let p = client.paths.primary().unwrap(); let path = p.borrow(); // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. diff --git a/third_party/rust/neqo-transport/src/connection/tests/ecn.rs b/third_party/rust/neqo-transport/src/connection/tests/ecn.rs new file mode 100644 index 0000000000..87957297e5 --- /dev/null +++ b/third_party/rust/neqo-transport/src/connection/tests/ecn.rs @@ -0,0 +1,392 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::time::Duration; + +use neqo_common::{Datagram, IpTos, IpTosEcn}; +use test_fixture::{ + assertions::{assert_v4_path, assert_v6_path}, + fixture_init, now, DEFAULT_ADDR_V4, +}; + +use super::send_something_with_modifier; +use crate::{ + connection::tests::{ + connect_force_idle, connect_force_idle_with_modifier, default_client, default_server, + migration::get_cid, new_client, new_server, send_something, + }, + ecn::ECN_TEST_COUNT, + ConnectionId, ConnectionParameters, StreamType, +}; + +fn assert_ecn_enabled(tos: IpTos) { + assert!(tos.is_ecn_marked()); +} + +fn assert_ecn_disabled(tos: IpTos) { + assert!(!tos.is_ecn_marked()); +} + +fn set_tos(mut d: Datagram, ecn: IpTosEcn) -> Datagram { + d.set_tos(ecn.into()); + d +} + +fn noop() -> fn(Datagram) -> Option<Datagram> { + Some +} + +fn bleach() -> fn(Datagram) -> Option<Datagram> { + |d| Some(set_tos(d, IpTosEcn::NotEct)) +} + +fn remark() -> fn(Datagram) -> Option<Datagram> { + |d| { + if d.tos().is_ecn_marked() { + Some(set_tos(d, IpTosEcn::Ect1)) + } else { + Some(d) + } + } +} + +fn ce() -> fn(Datagram) -> Option<Datagram> { + |d| { + if d.tos().is_ecn_marked() { + Some(set_tos(d, IpTosEcn::Ce)) + } else { + Some(d) + } + } +} + +fn drop() -> fn(Datagram) -> Option<Datagram> { + |_| None +} + +#[test] +fn disables_on_loss() { + let now = now(); + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + // Right after the handshake, the ECN validation should still be in progress. + let client_pkt = send_something(&mut client, now); + assert_ecn_enabled(client_pkt.tos()); + + for _ in 0..ECN_TEST_COUNT { + send_something(&mut client, now); + } + + // ECN should now be disabled. + let client_pkt = send_something(&mut client, now); + assert_ecn_disabled(client_pkt.tos()); +} + +/// This function performs a handshake over a path that modifies packets via `orig_path_modifier`. +/// It then sends `burst` packets on that path, and then migrates to a new path that +/// modifies packets via `new_path_modifier`. It sends `burst` packets on the new path. +/// The function returns the TOS value of the last packet sent on the old path and the TOS value +/// of the last packet sent on the new path to allow for verification of correct behavior. +pub fn migration_with_modifiers( + orig_path_modifier: fn(Datagram) -> Option<Datagram>, + new_path_modifier: fn(Datagram) -> Option<Datagram>, + burst: usize, +) -> (IpTos, IpTos, bool) { + fixture_init(); + let mut client = new_client(ConnectionParameters::default().max_streams(StreamType::UniDi, 64)); + let mut server = new_server(ConnectionParameters::default().max_streams(StreamType::UniDi, 64)); + + connect_force_idle_with_modifier(&mut client, &mut server, orig_path_modifier); + let mut now = now(); + + // Right after the handshake, the ECN validation should still be in progress. + let client_pkt = send_something(&mut client, now); + assert_ecn_enabled(client_pkt.tos()); + server.process_input(&orig_path_modifier(client_pkt).unwrap(), now); + + // Send some data on the current path. + for _ in 0..burst { + let client_pkt = send_something_with_modifier(&mut client, now, orig_path_modifier); + server.process_input(&client_pkt, now); + } + + if let Some(ack) = server.process_output(now).dgram() { + client.process_input(&ack, now); + } + + let client_pkt = send_something(&mut client, now); + let tos_before_migration = client_pkt.tos(); + server.process_input(&orig_path_modifier(client_pkt).unwrap(), now); + + client + .migrate(Some(DEFAULT_ADDR_V4), Some(DEFAULT_ADDR_V4), false, now) + .unwrap(); + + let mut migrated = false; + let probe = new_path_modifier(client.process_output(now).dgram().unwrap()); + if let Some(probe) = probe { + assert_v4_path(&probe, true); // Contains PATH_CHALLENGE. + assert_eq!(client.stats().frame_tx.path_challenge, 1); + let probe_cid = ConnectionId::from(get_cid(&probe)); + + let resp = new_path_modifier(server.process(Some(&probe), now).dgram().unwrap()).unwrap(); + assert_v4_path(&resp, true); + assert_eq!(server.stats().frame_tx.path_response, 1); + assert_eq!(server.stats().frame_tx.path_challenge, 1); + + // Data continues to be exchanged on the old path. + let client_data = send_something_with_modifier(&mut client, now, orig_path_modifier); + assert_ne!(get_cid(&client_data), probe_cid); + assert_v6_path(&client_data, false); + server.process_input(&client_data, now); + let server_data = send_something_with_modifier(&mut server, now, orig_path_modifier); + assert_v6_path(&server_data, false); + client.process_input(&server_data, now); + + // Once the client receives the probe response, it migrates to the new path. + client.process_input(&resp, now); + assert_eq!(client.stats().frame_rx.path_challenge, 1); + migrated = true; + + let migrate_client = send_something_with_modifier(&mut client, now, new_path_modifier); + assert_v4_path(&migrate_client, true); // Responds to server probe. + + // The server now sees the migration and will switch over. + // However, it will probe the old path again, even though it has just + // received a response to its last probe, because it needs to verify + // that the migration is genuine. + server.process_input(&migrate_client, now); + } + + let stream_before = server.stats().frame_tx.stream; + let probe_old_server = send_something_with_modifier(&mut server, now, orig_path_modifier); + // This is just the double-check probe; no STREAM frames. + assert_v6_path(&probe_old_server, migrated); + assert_eq!( + server.stats().frame_tx.path_challenge, + if migrated { 2 } else { 0 } + ); + assert_eq!( + server.stats().frame_tx.stream, + if migrated { stream_before } else { 1 } + ); + + if migrated { + // The server then sends data on the new path. + let migrate_server = + new_path_modifier(server.process_output(now).dgram().unwrap()).unwrap(); + assert_v4_path(&migrate_server, false); + assert_eq!(server.stats().frame_tx.path_challenge, 2); + assert_eq!(server.stats().frame_tx.stream, stream_before + 1); + + // The client receives these checks and responds to the probe, but uses the new path. + client.process_input(&migrate_server, now); + client.process_input(&probe_old_server, now); + let old_probe_resp = send_something_with_modifier(&mut client, now, new_path_modifier); + assert_v6_path(&old_probe_resp, true); + let client_confirmation = client.process_output(now).dgram().unwrap(); + assert_v4_path(&client_confirmation, false); + + // The server has now sent 2 packets, so it is blocked on the pacer. Wait. + let server_pacing = server.process_output(now).callback(); + assert_ne!(server_pacing, Duration::new(0, 0)); + // ... then confirm that the server sends on the new path still. + let server_confirmation = + send_something_with_modifier(&mut server, now + server_pacing, new_path_modifier); + assert_v4_path(&server_confirmation, false); + client.process_input(&server_confirmation, now); + + // Send some data on the new path. + for _ in 0..burst { + now += client.process_output(now).callback(); + let client_pkt = send_something_with_modifier(&mut client, now, new_path_modifier); + server.process_input(&client_pkt, now); + } + + if let Some(ack) = server.process_output(now).dgram() { + client.process_input(&ack, now); + } + } + + now += client.process_output(now).callback(); + let mut client_pkt = send_something(&mut client, now); + while !migrated && client_pkt.source() == DEFAULT_ADDR_V4 { + client_pkt = send_something(&mut client, now); + } + let tos_after_migration = client_pkt.tos(); + (tos_before_migration, tos_after_migration, migrated) +} + +#[test] +fn ecn_migration_zero_burst_all_cases() { + for orig_path_mod in &[noop(), bleach(), remark(), ce()] { + for new_path_mod in &[noop(), bleach(), remark(), ce(), drop()] { + let (before, after, migrated) = + migration_with_modifiers(*orig_path_mod, *new_path_mod, 0); + // Too few packets sent before and after migration to conclude ECN validation. + assert_ecn_enabled(before); + assert_ecn_enabled(after); + // Migration succeeds except if the new path drops ECN. + assert!(*new_path_mod == drop() || migrated); + } + } +} + +#[test] +fn ecn_migration_noop_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), bleach(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching. + assert!(migrated); +} + +#[test] +fn ecn_migration_noop_remark_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), remark(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_noop_ce_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), ce(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_noop_drop_data() { + let (before, after, migrated) = migration_with_modifiers(noop(), drop(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration. + assert_ecn_enabled(after); // Migration failed, ECN on original path is still validated. + assert!(!migrated); +} + +#[test] +fn ecn_migration_bleach_noop_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), noop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_enabled(after); // ECN validation concludes after migration. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), bleach(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_remark_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), remark(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_ce_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), ce(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_bleach_drop_data() { + let (before, after, migrated) = migration_with_modifiers(bleach(), drop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to bleaching. + // Migration failed, ECN on original path is still disabled. + assert_ecn_disabled(after); + assert!(!migrated); +} + +#[test] +fn ecn_migration_remark_noop_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), noop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_enabled(after); // ECN validation succeeds after migration. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), bleach(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_remark_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), remark(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_ce_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), ce(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_remark_drop_data() { + let (before, after, migrated) = migration_with_modifiers(remark(), drop(), ECN_TEST_COUNT); + assert_ecn_disabled(before); // ECN validation fails before migration due to remarking. + assert_ecn_disabled(after); // Migration failed, ECN on original path is still disabled. + assert!(!migrated); +} + +#[test] +fn ecn_migration_ce_noop_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), noop(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_enabled(after); // ECN validation concludes after migration. + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_bleach_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), bleach(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_disabled(after); // ECN validation fails after migration due to bleaching + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_remark_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), remark(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_disabled(after); // ECN validation fails after migration due to remarking. + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_ce_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), ce(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + assert_ecn_enabled(after); // ECN validation concludes after migration, despite all CE marks. + assert!(migrated); +} + +#[test] +fn ecn_migration_ce_drop_data() { + let (before, after, migrated) = migration_with_modifiers(ce(), drop(), ECN_TEST_COUNT); + assert_ecn_enabled(before); // ECN validation concludes before migration, despite all CE marks. + // Migration failed, ECN on original path is still enabled. + assert_ecn_enabled(after); + assert!(!migrated); +} diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs index f2103523ec..c908340616 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs @@ -35,7 +35,7 @@ use crate::{ server::ValidateAddress, tparams::{TransportParameter, MIN_ACK_DELAY}, tracking::DEFAULT_ACK_DELAY, - ConnectionError, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version, + CloseReason, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version, }; const ECH_CONFIG_ID: u8 = 7; @@ -111,8 +111,8 @@ fn handshake_failed_authentication() { qdebug!("---- server: Alert(certificate_revoked)"); let out = server.process(out.as_dgram_ref(), now()); assert!(out.as_dgram_ref().is_some()); - assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44))); - assert_error(&server, &ConnectionError::Transport(Error::PeerError(300))); + assert_error(&client, &CloseReason::Transport(Error::CryptoAlert(44))); + assert_error(&server, &CloseReason::Transport(Error::PeerError(300))); } #[test] @@ -133,11 +133,8 @@ fn no_alpn() { handshake(&mut client, &mut server, now(), Duration::new(0, 0)); // TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE // and the client never sees the server's rejection of its handshake. - // assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120))); - assert_error( - &server, - &ConnectionError::Transport(Error::CryptoAlert(120)), - ); + // assert_error(&client, CloseReason::Transport(Error::CryptoAlert(120))); + assert_error(&server, &CloseReason::Transport(Error::CryptoAlert(120))); } #[test] @@ -934,10 +931,10 @@ fn ech_retry() { server.process_input(&dgram.unwrap(), now()); assert_eq!( server.state().error(), - Some(&ConnectionError::Transport(Error::PeerError(0x100 + 121))) + Some(&CloseReason::Transport(Error::PeerError(0x100 + 121))) ); - let Some(ConnectionError::Transport(Error::EchRetry(updated_config))) = client.state().error() + let Some(CloseReason::Transport(Error::EchRetry(updated_config))) = client.state().error() else { panic!( "Client state should be failed with EchRetry, is {:?}", @@ -984,7 +981,7 @@ fn ech_retry_fallback_rejected() { client.authenticated(AuthenticationStatus::PolicyRejection, now()); assert!(client.state().error().is_some()); - if let Some(ConnectionError::Transport(Error::EchRetry(_))) = client.state().error() { + if let Some(CloseReason::Transport(Error::EchRetry(_))) = client.state().error() { panic!("Client should not get EchRetry error"); } @@ -993,14 +990,13 @@ fn ech_retry_fallback_rejected() { server.process_input(&dgram.unwrap(), now()); assert_eq!( server.state().error(), - Some(&ConnectionError::Transport(Error::PeerError(298))) + Some(&CloseReason::Transport(Error::PeerError(298))) ); // A bad_certificate alert. } #[test] fn bad_min_ack_delay() { - const EXPECTED_ERROR: ConnectionError = - ConnectionError::Transport(Error::TransportParameterError); + const EXPECTED_ERROR: CloseReason = CloseReason::Transport(Error::TransportParameterError); let mut server = default_server(); let max_ad = u64::try_from(DEFAULT_ACK_DELAY.as_micros()).unwrap(); server @@ -1018,7 +1014,7 @@ fn bad_min_ack_delay() { server.process_input(&dgram.unwrap(), now()); assert_eq!( server.state().error(), - Some(&ConnectionError::Transport(Error::PeerError( + Some(&CloseReason::Transport(Error::PeerError( Error::TransportParameterError.code() ))) ); diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs index 847b253284..c2ae9529bf 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/keys.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs @@ -11,7 +11,7 @@ use test_fixture::now; use super::{ super::{ - super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED}, + super::{CloseReason, ERROR_AEAD_LIMIT_REACHED}, Connection, ConnectionParameters, Error, Output, State, StreamType, }, connect, connect_force_idle, default_client, default_server, maybe_authenticate, @@ -269,7 +269,7 @@ fn exhaust_write_keys() { assert!(dgram.is_none()); assert!(matches!( client.state(), - State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + State::Closed(CloseReason::Transport(Error::KeysExhausted)) )); } @@ -285,14 +285,14 @@ fn exhaust_read_keys() { let dgram = server.process(Some(&dgram), now()).dgram(); assert!(matches!( server.state(), - State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + State::Closed(CloseReason::Transport(Error::KeysExhausted)) )); client.process_input(&dgram.unwrap(), now()); assert!(matches!( client.state(), State::Draining { - error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), + error: CloseReason::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), .. } )); @@ -341,6 +341,6 @@ fn automatic_update_write_keys_blocked() { assert!(dgram.is_none()); assert!(matches!( client.state(), - State::Closed(ConnectionError::Transport(Error::KeysExhausted)) + State::Closed(CloseReason::Transport(Error::KeysExhausted)) )); } diff --git a/third_party/rust/neqo-transport/src/connection/tests/migration.rs b/third_party/rust/neqo-transport/src/connection/tests/migration.rs index 405ae161a4..779cc78c53 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/migration.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/migration.rs @@ -30,7 +30,7 @@ use crate::{ packet::PacketBuilder, path::{PATH_MTU_V4, PATH_MTU_V6}, tparams::{self, PreferredAddress, TransportParameter}, - ConnectionError, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef, + CloseReason, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef, ConnectionParameters, EmptyConnectionIdGenerator, Error, }; @@ -357,13 +357,13 @@ fn migrate_same_fail() { assert!(matches!(res, Output::None)); assert!(matches!( client.state(), - State::Closed(ConnectionError::Transport(Error::NoAvailablePath)) + State::Closed(CloseReason::Transport(Error::NoAvailablePath)) )); } /// This gets the connection ID from a datagram using the default /// connection ID generator/decoder. -fn get_cid(d: &Datagram) -> ConnectionIdRef { +pub fn get_cid(d: &Datagram) -> ConnectionIdRef { let gen = CountingConnectionIdGenerator::default(); assert_eq!(d[0] & 0x80, 0); // Only support short packets for now. gen.decode_cid(&mut Decoder::from(&d[1..])).unwrap() @@ -894,7 +894,7 @@ fn retire_prior_to_migration_failure() { assert!(matches!( client.state(), State::Closing { - error: ConnectionError::Transport(Error::InvalidMigration), + error: CloseReason::Transport(Error::InvalidMigration), .. } )); diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs index c8c87a0df0..65283b8eb8 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/mod.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs @@ -17,7 +17,7 @@ use neqo_common::{event::Provider, qdebug, qtrace, Datagram, Decoder, Role}; use neqo_crypto::{random, AllowZeroRtt, AuthenticationStatus, ResumptionToken}; use test_fixture::{fixture_init, new_neqo_qlog, now, DEFAULT_ADDR}; -use super::{Connection, ConnectionError, ConnectionId, Output, State}; +use super::{CloseReason, Connection, ConnectionId, Output, State}; use crate::{ addr_valid::{AddressValidation, ValidateAddress}, cc::{CWND_INITIAL_PKTS, CWND_MIN}, @@ -37,6 +37,7 @@ mod ackrate; mod cc; mod close; mod datagram; +mod ecn; mod handshake; mod idle; mod keys; @@ -170,17 +171,13 @@ impl crate::connection::test_internal::FrameWriter for PingWriter { } } -trait DatagramModifier: FnMut(Datagram) -> Option<Datagram> {} - -impl<T> DatagramModifier for T where T: FnMut(Datagram) -> Option<Datagram> {} - /// Drive the handshake between the client and server. fn handshake_with_modifier( client: &mut Connection, server: &mut Connection, now: Instant, rtt: Duration, - mut modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Instant { let mut a = client; let mut b = server; @@ -248,8 +245,8 @@ fn connect_fail( server_error: Error, ) { handshake(client, server, now(), Duration::new(0, 0)); - assert_error(client, &ConnectionError::Transport(client_error)); - assert_error(server, &ConnectionError::Transport(server_error)); + assert_error(client, &CloseReason::Transport(client_error)); + assert_error(server, &CloseReason::Transport(server_error)); } fn connect_with_rtt_and_modifier( @@ -257,7 +254,7 @@ fn connect_with_rtt_and_modifier( server: &mut Connection, now: Instant, rtt: Duration, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Instant { fn check_rtt(stats: &Stats, rtt: Duration) { assert_eq!(stats.rtt, rtt); @@ -287,7 +284,7 @@ fn connect(client: &mut Connection, server: &mut Connection) { connect_with_rtt(client, server, now(), Duration::new(0, 0)); } -fn assert_error(c: &Connection, expected: &ConnectionError) { +fn assert_error(c: &Connection, expected: &CloseReason) { match c.state() { State::Closing { error, .. } | State::Draining { error, .. } | State::Closed(error) => { assert_eq!(*error, *expected, "{c} error mismatch"); @@ -333,7 +330,7 @@ fn connect_rtt_idle_with_modifier( client: &mut Connection, server: &mut Connection, rtt: Duration, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Instant { let now = connect_with_rtt_and_modifier(client, server, now(), rtt, modifier); assert_idle(client, server, rtt, now); @@ -351,7 +348,7 @@ fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Durat fn connect_force_idle_with_modifier( client: &mut Connection, server: &mut Connection, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) { connect_rtt_idle_with_modifier(client, server, Duration::new(0, 0), modifier); } @@ -380,7 +377,7 @@ fn fill_stream(c: &mut Connection, stream: StreamId) { fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec<Datagram>, Instant) { // Train wreck function to get the remaining congestion window on the primary path. fn cwnd(c: &Connection) -> usize { - c.paths.primary().borrow().sender().cwnd_avail() + c.paths.primary().unwrap().borrow().sender().cwnd_avail() } qtrace!("fill_cwnd starting cwnd: {}", cwnd(c)); @@ -478,10 +475,10 @@ where // Get the current congestion window for the connection. fn cwnd(c: &Connection) -> usize { - c.paths.primary().borrow().sender().cwnd() + c.paths.primary().unwrap().borrow().sender().cwnd() } fn cwnd_avail(c: &Connection) -> usize { - c.paths.primary().borrow().sender().cwnd_avail() + c.paths.primary().unwrap().borrow().sender().cwnd_avail() } fn induce_persistent_congestion( @@ -576,7 +573,7 @@ fn send_something_paced_with_modifier( sender: &mut Connection, mut now: Instant, allow_pacing: bool, - mut modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> (Datagram, Instant) { let stream_id = sender.stream_create(StreamType::UniDi).unwrap(); assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok()); @@ -608,7 +605,7 @@ fn send_something_paced( fn send_something_with_modifier( sender: &mut Connection, now: Instant, - modifier: impl DatagramModifier, + modifier: fn(Datagram) -> Option<Datagram>, ) -> Datagram { send_something_paced_with_modifier(sender, now, false, modifier).0 } diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs index 66d3bf32f3..f7472d917f 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/stream.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs @@ -19,9 +19,9 @@ use crate::{ send_stream::{OrderGroup, SendStreamState, SEND_BUFFER_SIZE}, streams::{SendOrder, StreamOrder}, tparams::{self, TransportParameter}, + CloseReason, // tracking::DEFAULT_ACK_PACKET_TOLERANCE, Connection, - ConnectionError, ConnectionParameters, Error, StreamId, @@ -494,12 +494,9 @@ fn exceed_max_data() { assert_error( &client, - &ConnectionError::Transport(Error::PeerError(Error::FlowControlError.code())), - ); - assert_error( - &server, - &ConnectionError::Transport(Error::FlowControlError), + &CloseReason::Transport(Error::PeerError(Error::FlowControlError.code())), ); + assert_error(&server, &CloseReason::Transport(Error::FlowControlError)); } #[test] diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs index 93872a94f4..815868d78d 100644 --- a/third_party/rust/neqo-transport/src/connection/tests/vn.rs +++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs @@ -10,7 +10,7 @@ use neqo_common::{event::Provider, Decoder, Encoder}; use test_fixture::{assertions, datagram, now}; use super::{ - super::{ConnectionError, ConnectionEvent, Output, State, ZeroRttState}, + super::{CloseReason, ConnectionEvent, Output, State, ZeroRttState}, connect, connect_fail, default_client, default_server, exchange_ticket, new_client, new_server, send_something, }; @@ -124,7 +124,7 @@ fn version_negotiation_only_reserved() { assert_eq!(client.process(Some(&dgram), now()), Output::None); match client.state() { State::Closed(err) => { - assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)); + assert_eq!(*err, CloseReason::Transport(Error::VersionNegotiation)); } _ => panic!("Invalid client state"), } @@ -183,7 +183,7 @@ fn version_negotiation_not_supported() { assert_eq!(client.process(Some(&dgram), now()), Output::None); match client.state() { State::Closed(err) => { - assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation)); + assert_eq!(*err, CloseReason::Transport(Error::VersionNegotiation)); } _ => panic!("Invalid client state"), } @@ -338,7 +338,7 @@ fn invalid_server_version() { // The server effectively hasn't reacted here. match server.state() { State::Closed(err) => { - assert_eq!(*err, ConnectionError::Transport(Error::CryptoAlert(47))); + assert_eq!(*err, CloseReason::Transport(Error::CryptoAlert(47))); } _ => panic!("invalid server state"), } |