summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-transport/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/rust/neqo-transport/src
parentInitial commit. (diff)
downloadfirefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
firefox-26a029d407be480d791972afb5975cf62c9360a6.zip
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-transport/src')
-rw-r--r--third_party/rust/neqo-transport/src/ackrate.rs213
-rw-r--r--third_party/rust/neqo-transport/src/addr_valid.rs508
-rw-r--r--third_party/rust/neqo-transport/src/cc/classic_cc.rs1186
-rw-r--r--third_party/rust/neqo-transport/src/cc/cubic.rs215
-rw-r--r--third_party/rust/neqo-transport/src/cc/mod.rs87
-rw-r--r--third_party/rust/neqo-transport/src/cc/new_reno.rs51
-rw-r--r--third_party/rust/neqo-transport/src/cc/tests/cubic.rs333
-rw-r--r--third_party/rust/neqo-transport/src/cc/tests/mod.rs7
-rw-r--r--third_party/rust/neqo-transport/src/cc/tests/new_reno.rs219
-rw-r--r--third_party/rust/neqo-transport/src/cid.rs609
-rw-r--r--third_party/rust/neqo-transport/src/connection/dump.rs46
-rw-r--r--third_party/rust/neqo-transport/src/connection/idle.rs120
-rw-r--r--third_party/rust/neqo-transport/src/connection/mod.rs3241
-rw-r--r--third_party/rust/neqo-transport/src/connection/params.rs392
-rw-r--r--third_party/rust/neqo-transport/src/connection/saved.rs68
-rw-r--r--third_party/rust/neqo-transport/src/connection/state.rs281
-rw-r--r--third_party/rust/neqo-transport/src/connection/test_internal.rs13
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/ackrate.rs194
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/cc.rs429
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/close.rs210
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/datagram.rs620
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs44
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/handshake.rs1137
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/idle.rs752
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/keys.rs346
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/migration.rs953
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/mod.rs614
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/priority.rs404
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/recovery.rs804
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/resumption.rs246
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/stream.rs1162
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/vn.rs482
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/zerortt.rs257
-rw-r--r--third_party/rust/neqo-transport/src/crypto.rs1583
-rw-r--r--third_party/rust/neqo-transport/src/events.rs321
-rw-r--r--third_party/rust/neqo-transport/src/fc.rs918
-rw-r--r--third_party/rust/neqo-transport/src/frame.rs977
-rw-r--r--third_party/rust/neqo-transport/src/lib.rs226
-rw-r--r--third_party/rust/neqo-transport/src/pace.rs165
-rw-r--r--third_party/rust/neqo-transport/src/packet/mod.rs1457
-rw-r--r--third_party/rust/neqo-transport/src/packet/retry.rs59
-rw-r--r--third_party/rust/neqo-transport/src/path.rs1032
-rw-r--r--third_party/rust/neqo-transport/src/qlog.rs563
-rw-r--r--third_party/rust/neqo-transport/src/quic_datagrams.rs185
-rw-r--r--third_party/rust/neqo-transport/src/recovery.rs1610
-rw-r--r--third_party/rust/neqo-transport/src/recv_stream.rs2149
-rw-r--r--third_party/rust/neqo-transport/src/rtt.rs211
-rw-r--r--third_party/rust/neqo-transport/src/send_stream.rs2636
-rw-r--r--third_party/rust/neqo-transport/src/sender.rs130
-rw-r--r--third_party/rust/neqo-transport/src/server.rs782
-rw-r--r--third_party/rust/neqo-transport/src/stats.rs235
-rw-r--r--third_party/rust/neqo-transport/src/stream_id.rs177
-rw-r--r--third_party/rust/neqo-transport/src/streams.rs547
-rw-r--r--third_party/rust/neqo-transport/src/tparams.rs1130
-rw-r--r--third_party/rust/neqo-transport/src/tracking.rs1228
-rw-r--r--third_party/rust/neqo-transport/src/version.rs235
56 files changed, 34799 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/ackrate.rs b/third_party/rust/neqo-transport/src/ackrate.rs
new file mode 100644
index 0000000000..cf68f9021f
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/ackrate.rs
@@ -0,0 +1,213 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Management of the peer's ack rate.
+#![deny(clippy::pedantic)]
+
+use std::{cmp::max, convert::TryFrom, time::Duration};
+
+use neqo_common::qtrace;
+
+use crate::{
+ connection::params::ACK_RATIO_SCALE, frame::FRAME_TYPE_ACK_FREQUENCY, packet::PacketBuilder,
+ recovery::RecoveryToken, stats::FrameStats,
+};
+
+#[derive(Debug, Clone)]
+pub struct AckRate {
+ /// The maximum number of packets that can be received without sending an ACK.
+ packets: usize,
+ /// The maximum delay before sending an ACK.
+ delay: Duration,
+}
+
+impl AckRate {
+ pub fn new(minimum: Duration, ratio: u8, cwnd: usize, mtu: usize, rtt: Duration) -> Self {
+ const PACKET_RATIO: usize = ACK_RATIO_SCALE as usize;
+ // At worst, ask for an ACK for every other packet.
+ const MIN_PACKETS: usize = 2;
+ // At worst, require an ACK every 256 packets.
+ const MAX_PACKETS: usize = 256;
+ const RTT_RATIO: u32 = ACK_RATIO_SCALE as u32;
+ const MAX_DELAY: Duration = Duration::from_millis(50);
+
+ let packets = cwnd * PACKET_RATIO / mtu / usize::from(ratio);
+ let packets = packets.clamp(MIN_PACKETS, MAX_PACKETS) - 1;
+ let delay = rtt * RTT_RATIO / u32::from(ratio);
+ let delay = delay.clamp(minimum, MAX_DELAY);
+ qtrace!("AckRate inputs: {}/{}/{}, {:?}", cwnd, mtu, ratio, rtt);
+ Self { packets, delay }
+ }
+
+ pub fn write_frame(&self, builder: &mut PacketBuilder, seqno: u64) -> bool {
+ builder.write_varint_frame(&[
+ FRAME_TYPE_ACK_FREQUENCY,
+ seqno,
+ u64::try_from(self.packets + 1).unwrap(),
+ u64::try_from(self.delay.as_micros()).unwrap(),
+ 0,
+ ])
+ }
+
+ /// Determine whether to send an update frame.
+ pub fn needs_update(&self, target: &Self) -> bool {
+ if self.packets != target.packets {
+ return true;
+ }
+ // Allow more flexibility for delays, as those can change
+ // by small amounts fairly easily.
+ let delta = target.delay / 4;
+ target.delay + delta < self.delay || target.delay > self.delay + delta
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct FlexibleAckRate {
+ current: AckRate,
+ target: AckRate,
+ next_frame_seqno: u64,
+ frame_outstanding: bool,
+ min_ack_delay: Duration,
+ ratio: u8,
+}
+
+impl FlexibleAckRate {
+ fn new(
+ max_ack_delay: Duration,
+ min_ack_delay: Duration,
+ ratio: u8,
+ cwnd: usize,
+ mtu: usize,
+ rtt: Duration,
+ ) -> Self {
+ qtrace!(
+ "FlexibleAckRate: {:?} {:?} {}",
+ max_ack_delay,
+ min_ack_delay,
+ ratio
+ );
+ let ratio = max(ACK_RATIO_SCALE, ratio); // clamp it
+ Self {
+ current: AckRate {
+ packets: 1,
+ delay: max_ack_delay,
+ },
+ target: AckRate::new(min_ack_delay, ratio, cwnd, mtu, rtt),
+ next_frame_seqno: 0,
+ frame_outstanding: false,
+ min_ack_delay,
+ ratio,
+ }
+ }
+
+ fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if !self.frame_outstanding
+ && self.current.needs_update(&self.target)
+ && self.target.write_frame(builder, self.next_frame_seqno)
+ {
+ qtrace!("FlexibleAckRate: write frame {:?}", self.target);
+ self.frame_outstanding = true;
+ self.next_frame_seqno += 1;
+ tokens.push(RecoveryToken::AckFrequency(self.target.clone()));
+ stats.ack_frequency += 1;
+ }
+ }
+
+ fn frame_acked(&mut self, acked: &AckRate) {
+ self.frame_outstanding = false;
+ self.current = acked.clone();
+ }
+
+ fn frame_lost(&mut self, _lost: &AckRate) {
+ self.frame_outstanding = false;
+ }
+
+ fn update(&mut self, cwnd: usize, mtu: usize, rtt: Duration) {
+ self.target = AckRate::new(self.min_ack_delay, self.ratio, cwnd, mtu, rtt);
+ qtrace!("FlexibleAckRate: {:?} -> {:?}", self.current, self.target);
+ }
+
+ fn peer_ack_delay(&self) -> Duration {
+ max(self.current.delay, self.target.delay)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum PeerAckDelay {
+ Fixed(Duration),
+ Flexible(FlexibleAckRate),
+}
+
+impl PeerAckDelay {
+ pub fn fixed(max_ack_delay: Duration) -> Self {
+ Self::Fixed(max_ack_delay)
+ }
+
+ pub fn flexible(
+ max_ack_delay: Duration,
+ min_ack_delay: Duration,
+ ratio: u8,
+ cwnd: usize,
+ mtu: usize,
+ rtt: Duration,
+ ) -> Self {
+ Self::Flexible(FlexibleAckRate::new(
+ max_ack_delay,
+ min_ack_delay,
+ ratio,
+ cwnd,
+ mtu,
+ rtt,
+ ))
+ }
+
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if let Self::Flexible(rate) = self {
+ rate.write_frames(builder, tokens, stats);
+ }
+ }
+
+ pub fn frame_acked(&mut self, r: &AckRate) {
+ if let Self::Flexible(rate) = self {
+ rate.frame_acked(r);
+ }
+ }
+
+ pub fn frame_lost(&mut self, r: &AckRate) {
+ if let Self::Flexible(rate) = self {
+ rate.frame_lost(r);
+ }
+ }
+
+ pub fn max(&self) -> Duration {
+ match self {
+ Self::Flexible(rate) => rate.peer_ack_delay(),
+ Self::Fixed(delay) => *delay,
+ }
+ }
+
+ pub fn update(&mut self, cwnd: usize, mtu: usize, rtt: Duration) {
+ if let Self::Flexible(rate) = self {
+ rate.update(cwnd, mtu, rtt);
+ }
+ }
+}
+
+impl Default for PeerAckDelay {
+ fn default() -> Self {
+ Self::fixed(Duration::from_millis(25))
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/addr_valid.rs b/third_party/rust/neqo-transport/src/addr_valid.rs
new file mode 100644
index 0000000000..b5ed2d07d1
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/addr_valid.rs
@@ -0,0 +1,508 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// This file implements functions necessary for address validation.
+
+use std::{
+ convert::TryFrom,
+ net::{IpAddr, SocketAddr},
+ time::{Duration, Instant},
+};
+
+use neqo_common::{qinfo, qtrace, Decoder, Encoder, Role};
+use neqo_crypto::{
+ constants::{TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3},
+ selfencrypt::SelfEncrypt,
+};
+use smallvec::SmallVec;
+
+use crate::{
+ cid::ConnectionId, packet::PacketBuilder, recovery::RecoveryToken, stats::FrameStats, Res,
+};
+
+/// A prefix we add to Retry tokens to distinguish them from NEW_TOKEN tokens.
+const TOKEN_IDENTIFIER_RETRY: &[u8] = &[0x52, 0x65, 0x74, 0x72, 0x79];
+/// A prefix on NEW_TOKEN tokens, that is maximally Hamming distant from NEW_TOKEN.
+/// Together, these need to have a low probability of collision, even if there is
+/// corruption of individual bits in transit.
+const TOKEN_IDENTIFIER_NEW_TOKEN: &[u8] = &[0xad, 0x9a, 0x8b, 0x8d, 0x86];
+
+/// The maximum number of tokens we'll save from NEW_TOKEN frames.
+/// This should be the same as the value of MAX_TICKETS in neqo-crypto.
+const MAX_NEW_TOKEN: usize = 4;
+/// The number of tokens we'll track for the purposes of looking for duplicates.
+/// This is based on how many might be received over a period where could be
+/// retransmissions. It should be at least `MAX_NEW_TOKEN`.
+const MAX_SAVED_TOKENS: usize = 8;
+
+/// `ValidateAddress` determines what sort of address validation is performed.
+/// In short, this determines when a Retry packet is sent.
+#[derive(Debug, PartialEq, Eq)]
+pub enum ValidateAddress {
+ /// Require address validation never.
+ Never,
+ /// Require address validation unless a NEW_TOKEN token is provided.
+ NoToken,
+ /// Require address validation even if a NEW_TOKEN token is provided.
+ Always,
+}
+
+pub enum AddressValidationResult {
+ Pass,
+ ValidRetry(ConnectionId),
+ Validate,
+ Invalid,
+}
+
+pub struct AddressValidation {
+ /// What sort of validation is performed.
+ validation: ValidateAddress,
+ /// A self-encryption object used for protecting Retry tokens.
+ self_encrypt: SelfEncrypt,
+ /// When this object was created.
+ start_time: Instant,
+}
+
+impl AddressValidation {
+ pub fn new(now: Instant, validation: ValidateAddress) -> Res<Self> {
+ Ok(Self {
+ validation,
+ self_encrypt: SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256)?,
+ start_time: now,
+ })
+ }
+
+ fn encode_aad(peer_address: SocketAddr, retry: bool) -> Encoder {
+ // Let's be "clever" by putting the peer's address in the AAD.
+ // We don't need to encode these into the token as they should be
+ // available when we need to check the token.
+ let mut aad = Encoder::default();
+ if retry {
+ aad.encode(TOKEN_IDENTIFIER_RETRY);
+ } else {
+ aad.encode(TOKEN_IDENTIFIER_NEW_TOKEN);
+ }
+ match peer_address.ip() {
+ IpAddr::V4(a) => {
+ aad.encode_byte(4);
+ aad.encode(&a.octets());
+ }
+ IpAddr::V6(a) => {
+ aad.encode_byte(6);
+ aad.encode(&a.octets());
+ }
+ }
+ if retry {
+ aad.encode_uint(2, peer_address.port());
+ }
+ aad
+ }
+
+ pub fn generate_token(
+ &self,
+ dcid: Option<&ConnectionId>,
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> Res<Vec<u8>> {
+ const EXPIRATION_RETRY: Duration = Duration::from_secs(5);
+ const EXPIRATION_NEW_TOKEN: Duration = Duration::from_secs(60 * 60 * 24);
+
+ // TODO(mt) rotate keys on a fixed schedule.
+ let retry = dcid.is_some();
+ let mut data = Encoder::default();
+ let end = now
+ + if retry {
+ EXPIRATION_RETRY
+ } else {
+ EXPIRATION_NEW_TOKEN
+ };
+ let end_millis = u32::try_from(end.duration_since(self.start_time).as_millis())?;
+ data.encode_uint(4, end_millis);
+ if let Some(dcid) = dcid {
+ data.encode(dcid);
+ }
+
+ // Include the token identifier ("Retry"/~) in the AAD, then keep it for plaintext.
+ let mut buf = Self::encode_aad(peer_address, retry);
+ let encrypted = self.self_encrypt.seal(buf.as_ref(), data.as_ref())?;
+ buf.truncate(TOKEN_IDENTIFIER_RETRY.len());
+ buf.encode(&encrypted);
+ Ok(buf.into())
+ }
+
+ /// This generates a token for use with Retry.
+ pub fn generate_retry_token(
+ &self,
+ dcid: &ConnectionId,
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> Res<Vec<u8>> {
+ self.generate_token(Some(dcid), peer_address, now)
+ }
+
+ /// This generates a token for use with NEW_TOKEN.
+ pub fn generate_new_token(&self, peer_address: SocketAddr, now: Instant) -> Res<Vec<u8>> {
+ self.generate_token(None, peer_address, now)
+ }
+
+ pub fn set_validation(&mut self, validation: ValidateAddress) {
+ qtrace!("AddressValidation {:p}: set to {:?}", self, validation);
+ self.validation = validation;
+ }
+
+ /// Decrypts `token` and returns the connection ID it contains.
+ /// Returns a tuple with a boolean indicating whether this thinks
+ /// that the token was a Retry token, and a connection ID, that is
+ /// None if the token wasn't successfully decrypted.
+ fn decrypt_token(
+ &self,
+ token: &[u8],
+ peer_address: SocketAddr,
+ retry: bool,
+ now: Instant,
+ ) -> Option<ConnectionId> {
+ let peer_addr = Self::encode_aad(peer_address, retry);
+ let data = self.self_encrypt.open(peer_addr.as_ref(), token).ok()?;
+ let mut dec = Decoder::new(&data);
+ match dec.decode_uint(4) {
+ Some(d) => {
+ let end = self.start_time + Duration::from_millis(d);
+ if end < now {
+ qtrace!("Expired token: {:?} vs. {:?}", end, now);
+ return None;
+ }
+ }
+ _ => return None,
+ }
+ Some(ConnectionId::from(dec.decode_remainder()))
+ }
+
+ /// Calculate the Hamming difference between our identifier and the target.
+ /// Less than one difference per byte indicates that it is likely not a Retry.
+ /// This generous interpretation allows for a lot of damage in transit.
+ /// Note that if this check fails, then the token will be treated like it came
+ /// from NEW_TOKEN instead. If there truly is corruption of packets that causes
+ /// validation failure, it will be a failure that we try to recover from.
+ fn is_likely_retry(token: &[u8]) -> bool {
+ let mut difference = 0;
+ for i in 0..TOKEN_IDENTIFIER_RETRY.len() {
+ difference += (token[i] ^ TOKEN_IDENTIFIER_RETRY[i]).count_ones();
+ }
+ usize::try_from(difference).unwrap() < TOKEN_IDENTIFIER_RETRY.len()
+ }
+
+ pub fn validate(
+ &self,
+ token: &[u8],
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> AddressValidationResult {
+ qtrace!(
+ "AddressValidation {:p}: validate {:?}",
+ self,
+ self.validation
+ );
+
+ if token.is_empty() {
+ if self.validation == ValidateAddress::Never {
+ qinfo!("AddressValidation: no token; accepting");
+ return AddressValidationResult::Pass;
+ } else {
+ qinfo!("AddressValidation: no token; validating");
+ return AddressValidationResult::Validate;
+ }
+ }
+ if token.len() <= TOKEN_IDENTIFIER_RETRY.len() {
+ // Treat bad tokens strictly.
+ qinfo!("AddressValidation: too short token");
+ return AddressValidationResult::Invalid;
+ }
+ let retry = Self::is_likely_retry(token);
+ let enc = &token[TOKEN_IDENTIFIER_RETRY.len()..];
+ // Note that this allows the token identifier part to be corrupted.
+ // That's OK here as we don't depend on that being authenticated.
+ if let Some(cid) = self.decrypt_token(enc, peer_address, retry, now) {
+ if retry {
+ // This is from Retry, so we should have an ODCID >= 8.
+ if cid.len() >= 8 {
+ qinfo!("AddressValidation: valid Retry token for {}", cid);
+ AddressValidationResult::ValidRetry(cid)
+ } else {
+ panic!("AddressValidation: Retry token with small CID {}", cid);
+ }
+ } else if cid.is_empty() {
+ // An empty connection ID means NEW_TOKEN.
+ if self.validation == ValidateAddress::Always {
+ qinfo!("AddressValidation: valid NEW_TOKEN token; validating again");
+ AddressValidationResult::Validate
+ } else {
+ qinfo!("AddressValidation: valid NEW_TOKEN token; accepting");
+ AddressValidationResult::Pass
+ }
+ } else {
+ panic!("AddressValidation: NEW_TOKEN token with CID {}", cid);
+ }
+ } else {
+ // From here on, we have a token that we couldn't decrypt.
+ // We've either lost the keys or we've received junk.
+ if retry {
+ // If this looked like a Retry, treat it as being bad.
+ qinfo!("AddressValidation: invalid Retry token; rejecting");
+ AddressValidationResult::Invalid
+ } else if self.validation == ValidateAddress::Never {
+ // We don't require validation, so OK.
+ qinfo!("AddressValidation: invalid NEW_TOKEN token; accepting");
+ AddressValidationResult::Pass
+ } else {
+ // This might be an invalid NEW_TOKEN token, or a valid one
+ // for which we have since lost the keys. Check again.
+ qinfo!("AddressValidation: invalid NEW_TOKEN token; validating again");
+ AddressValidationResult::Validate
+ }
+ }
+ }
+}
+
+// Note: these lint override can be removed in later versions where the lints
+// either don't trip a false positive or don't apply. rustc 1.46 is fine.
+#[allow(dead_code, clippy::large_enum_variant)]
+pub enum NewTokenState {
+ Client {
+ /// Tokens that haven't been taken yet.
+ pending: SmallVec<[Vec<u8>; MAX_NEW_TOKEN]>,
+ /// Tokens that have been taken, saved so that we can discard duplicates.
+ old: SmallVec<[Vec<u8>; MAX_SAVED_TOKENS]>,
+ },
+ Server(NewTokenSender),
+}
+
+impl NewTokenState {
+ pub fn new(role: Role) -> Self {
+ match role {
+ Role::Client => Self::Client {
+ pending: SmallVec::<[_; MAX_NEW_TOKEN]>::new(),
+ old: SmallVec::<[_; MAX_SAVED_TOKENS]>::new(),
+ },
+ Role::Server => Self::Server(NewTokenSender::default()),
+ }
+ }
+
+ /// Is there a token available?
+ pub fn has_token(&self) -> bool {
+ match self {
+ Self::Client { ref pending, .. } => !pending.is_empty(),
+ Self::Server(..) => false,
+ }
+ }
+
+ /// If this is a client, take a token if there is one.
+ /// If this is a server, panic.
+ pub fn take_token(&mut self) -> Option<&[u8]> {
+ if let Self::Client {
+ ref mut pending,
+ ref mut old,
+ } = self
+ {
+ if let Some(t) = pending.pop() {
+ if old.len() >= MAX_SAVED_TOKENS {
+ old.remove(0);
+ }
+ old.push(t);
+ Some(&old[old.len() - 1])
+ } else {
+ None
+ }
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this is a client, save a token.
+ /// If this is a server, panic.
+ pub fn save_token(&mut self, token: Vec<u8>) {
+ if let Self::Client {
+ ref mut pending,
+ ref old,
+ } = self
+ {
+ for t in old.iter().rev().chain(pending.iter().rev()) {
+ if t == &token {
+ qinfo!("NewTokenState discarding duplicate NEW_TOKEN");
+ return;
+ }
+ }
+
+ if pending.len() >= MAX_NEW_TOKEN {
+ pending.remove(0);
+ }
+ pending.push(token);
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this is a server, maybe send a frame.
+ /// If this is a client, do nothing.
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> Res<()> {
+ if let Self::Server(ref mut sender) = self {
+ sender.write_frames(builder, tokens, stats)?;
+ }
+ Ok(())
+ }
+
+ /// If this a server, buffer a NEW_TOKEN for sending.
+ /// If this is a client, panic.
+ pub fn send_new_token(&mut self, token: Vec<u8>) {
+ if let Self::Server(ref mut sender) = self {
+ sender.send_new_token(token);
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this a server, process a lost signal for a NEW_TOKEN frame.
+ /// If this is a client, panic.
+ pub fn lost(&mut self, seqno: usize) {
+ if let Self::Server(ref mut sender) = self {
+ sender.lost(seqno);
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this a server, process remove the acknowledged NEW_TOKEN frame.
+ /// If this is a client, panic.
+ pub fn acked(&mut self, seqno: usize) {
+ if let Self::Server(ref mut sender) = self {
+ sender.acked(seqno);
+ } else {
+ unreachable!();
+ }
+ }
+}
+
+struct NewTokenFrameStatus {
+ seqno: usize,
+ token: Vec<u8>,
+ needs_sending: bool,
+}
+
+impl NewTokenFrameStatus {
+ fn len(&self) -> usize {
+ 1 + Encoder::vvec_len(self.token.len())
+ }
+}
+
+#[derive(Default)]
+pub struct NewTokenSender {
+ /// The unacknowledged NEW_TOKEN frames we are yet to send.
+ tokens: Vec<NewTokenFrameStatus>,
+ /// A sequence number that is used to track individual tokens
+ /// by reference (so that recovery tokens can be simple).
+ next_seqno: usize,
+}
+
+impl NewTokenSender {
+ /// Add a token to be sent.
+ pub fn send_new_token(&mut self, token: Vec<u8>) {
+ self.tokens.push(NewTokenFrameStatus {
+ seqno: self.next_seqno,
+ token,
+ needs_sending: true,
+ });
+ self.next_seqno += 1;
+ }
+
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> Res<()> {
+ for t in self.tokens.iter_mut() {
+ if t.needs_sending && t.len() <= builder.remaining() {
+ t.needs_sending = false;
+
+ builder.encode_varint(crate::frame::FRAME_TYPE_NEW_TOKEN);
+ builder.encode_vvec(&t.token);
+
+ tokens.push(RecoveryToken::NewToken(t.seqno));
+ stats.new_token += 1;
+ }
+ }
+ Ok(())
+ }
+
+ pub fn lost(&mut self, seqno: usize) {
+ for t in self.tokens.iter_mut() {
+ if t.seqno == seqno {
+ t.needs_sending = true;
+ break;
+ }
+ }
+ }
+
+ pub fn acked(&mut self, seqno: usize) {
+ self.tokens.retain(|i| i.seqno != seqno);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use neqo_common::Role;
+
+ use super::NewTokenState;
+
+ const ONE: &[u8] = &[1, 2, 3];
+ const TWO: &[u8] = &[4, 5];
+
+ #[test]
+ fn duplicate_saved() {
+ let mut tokens = NewTokenState::new(Role::Client);
+ tokens.save_token(ONE.to_vec());
+ tokens.save_token(TWO.to_vec());
+ tokens.save_token(ONE.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably TWO
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably ONE
+ assert!(!tokens.has_token());
+ assert!(tokens.take_token().is_none());
+ }
+
+ #[test]
+ fn duplicate_after_take() {
+ let mut tokens = NewTokenState::new(Role::Client);
+ tokens.save_token(ONE.to_vec());
+ tokens.save_token(TWO.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably TWO
+ tokens.save_token(ONE.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably ONE
+ assert!(!tokens.has_token());
+ assert!(tokens.take_token().is_none());
+ }
+
+ #[test]
+ fn duplicate_after_empty() {
+ let mut tokens = NewTokenState::new(Role::Client);
+ tokens.save_token(ONE.to_vec());
+ tokens.save_token(TWO.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably TWO
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably ONE
+ tokens.save_token(ONE.to_vec());
+ assert!(!tokens.has_token());
+ assert!(tokens.take_token().is_none());
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs
new file mode 100644
index 0000000000..6f4a01d795
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs
@@ -0,0 +1,1186 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use std::{
+ cmp::{max, min},
+ fmt::{self, Debug, Display},
+ time::{Duration, Instant},
+};
+
+use super::CongestionControl;
+use crate::{
+ cc::MAX_DATAGRAM_SIZE,
+ packet::PacketNumber,
+ qlog::{self, QlogMetric},
+ rtt::RttEstimate,
+ sender::PACING_BURST_SIZE,
+ tracking::SentPacket,
+};
+#[rustfmt::skip] // to keep `::` and thus prevent conflict with `crate::qlog`
+use ::qlog::events::{quic::CongestionStateUpdated, EventData};
+use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace};
+
+pub const CWND_INITIAL_PKTS: usize = 10;
+pub const CWND_INITIAL: usize = const_min(
+ CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE,
+ const_max(2 * MAX_DATAGRAM_SIZE, 14720),
+);
+pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2;
+const PERSISTENT_CONG_THRESH: u32 = 3;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum State {
+ /// In either slow start or congestion avoidance, not recovery.
+ SlowStart,
+ /// In congestion avoidance.
+ CongestionAvoidance,
+ /// In a recovery period, but no packets have been sent yet. This is a
+ /// transient state because we want to exempt the first packet sent after
+ /// entering recovery from the congestion window.
+ RecoveryStart,
+ /// In a recovery period, with the first packet sent at this time.
+ Recovery,
+ /// Start of persistent congestion, which is transient, like `RecoveryStart`.
+ PersistentCongestion,
+}
+
+impl State {
+ pub fn in_recovery(self) -> bool {
+ matches!(self, Self::RecoveryStart | Self::Recovery)
+ }
+
+ pub fn in_slow_start(self) -> bool {
+ self == Self::SlowStart
+ }
+
+ /// These states are transient, we tell qlog on entry, but not on exit.
+ pub fn transient(self) -> bool {
+ matches!(self, Self::RecoveryStart | Self::PersistentCongestion)
+ }
+
+ /// Update a transient state to the true state.
+ pub fn update(&mut self) {
+ *self = match self {
+ Self::PersistentCongestion => Self::SlowStart,
+ Self::RecoveryStart => Self::Recovery,
+ _ => unreachable!(),
+ };
+ }
+
+ pub fn to_qlog(self) -> &'static str {
+ match self {
+ Self::SlowStart | Self::PersistentCongestion => "slow_start",
+ Self::CongestionAvoidance => "congestion_avoidance",
+ Self::Recovery | Self::RecoveryStart => "recovery",
+ }
+ }
+}
+
+pub trait WindowAdjustment: Display + Debug {
+ /// This is called when an ack is received.
+ /// The function calculates the amount of acked bytes congestion controller needs
+ /// to collect before increasing its cwnd by `MAX_DATAGRAM_SIZE`.
+ fn bytes_for_cwnd_increase(
+ &mut self,
+ curr_cwnd: usize,
+ new_acked_bytes: usize,
+ min_rtt: Duration,
+ now: Instant,
+ ) -> usize;
+ /// This function is called when a congestion event has beed detected and it
+ /// returns new (decreased) values of `curr_cwnd` and `acked_bytes`.
+ /// This value can be very small; the calling code is responsible for ensuring that the
+ /// congestion window doesn't drop below the minimum of `CWND_MIN`.
+ fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize);
+ /// Cubic needs this signal to reset its epoch.
+ fn on_app_limited(&mut self);
+ #[cfg(test)]
+ fn last_max_cwnd(&self) -> f64;
+ #[cfg(test)]
+ fn set_last_max_cwnd(&mut self, last_max_cwnd: f64);
+}
+
+#[derive(Debug)]
+pub struct ClassicCongestionControl<T> {
+ cc_algorithm: T,
+ state: State,
+ congestion_window: usize, // = kInitialWindow
+ bytes_in_flight: usize,
+ acked_bytes: usize,
+ ssthresh: usize,
+ recovery_start: Option<PacketNumber>,
+ /// `first_app_limited` indicates the packet number after which the application might be
+ /// underutilizing the congestion window. When underutilizing the congestion window due to not
+ /// sending out enough data, we SHOULD NOT increase the congestion window.[1] Packets sent
+ /// before this point are deemed to fully utilize the congestion window and count towards
+ /// increasing the congestion window.
+ ///
+ /// [1]: https://datatracker.ietf.org/doc/html/rfc9002#section-7.8
+ first_app_limited: PacketNumber,
+
+ qlog: NeqoQlog,
+}
+
+impl<T: WindowAdjustment> Display for ClassicCongestionControl<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{} CongCtrl {}/{} ssthresh {}",
+ self.cc_algorithm, self.bytes_in_flight, self.congestion_window, self.ssthresh,
+ )?;
+ Ok(())
+ }
+}
+
+impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> {
+ fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.qlog = qlog;
+ }
+
+ #[must_use]
+ fn cwnd(&self) -> usize {
+ self.congestion_window
+ }
+
+ #[must_use]
+ fn bytes_in_flight(&self) -> usize {
+ self.bytes_in_flight
+ }
+
+ #[must_use]
+ fn cwnd_avail(&self) -> usize {
+ // BIF can be higher than cwnd due to PTO packets, which are sent even
+ // if avail is 0, but still count towards BIF.
+ self.congestion_window.saturating_sub(self.bytes_in_flight)
+ }
+
+ // Multi-packet version of OnPacketAckedCC
+ fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant) {
+ let mut is_app_limited = true;
+ let mut new_acked = 0;
+ for pkt in acked_pkts {
+ qinfo!(
+ "packet_acked this={:p}, pn={}, ps={}, ignored={}, lost={}, rtt_est={:?}",
+ self,
+ pkt.pn,
+ pkt.size,
+ i32::from(!pkt.cc_outstanding()),
+ i32::from(pkt.lost()),
+ rtt_est,
+ );
+ if !pkt.cc_outstanding() {
+ continue;
+ }
+ if pkt.pn < self.first_app_limited {
+ is_app_limited = false;
+ }
+ assert!(self.bytes_in_flight >= pkt.size);
+ self.bytes_in_flight -= pkt.size;
+
+ if !self.after_recovery_start(pkt) {
+ // Do not increase congestion window for packets sent before
+ // recovery last started.
+ continue;
+ }
+
+ if self.state.in_recovery() {
+ self.set_state(State::CongestionAvoidance);
+ qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]);
+ }
+
+ new_acked += pkt.size;
+ }
+
+ if is_app_limited {
+ self.cc_algorithm.on_app_limited();
+ qinfo!("on_packets_acked this={:p}, limited=1, bytes_in_flight={}, cwnd={}, state={:?}, new_acked={}", self, self.bytes_in_flight, self.congestion_window, self.state, new_acked);
+ return;
+ }
+
+ // Slow start, up to the slow start threshold.
+ if self.congestion_window < self.ssthresh {
+ self.acked_bytes += new_acked;
+ let increase = min(self.ssthresh - self.congestion_window, self.acked_bytes);
+ self.congestion_window += increase;
+ self.acked_bytes -= increase;
+ qinfo!([self], "slow start += {}", increase);
+ if self.congestion_window == self.ssthresh {
+ // This doesn't look like it is necessary, but it can happen
+ // after persistent congestion.
+ self.set_state(State::CongestionAvoidance);
+ }
+ }
+ // Congestion avoidance, above the slow start threshold.
+ if self.congestion_window >= self.ssthresh {
+ // The following function return the amount acked bytes a controller needs
+ // to collect to be allowed to increase its cwnd by MAX_DATAGRAM_SIZE.
+ let bytes_for_increase = self.cc_algorithm.bytes_for_cwnd_increase(
+ self.congestion_window,
+ new_acked,
+ rtt_est.minimum(),
+ now,
+ );
+ debug_assert!(bytes_for_increase > 0);
+ // If enough credit has been accumulated already, apply them gradually.
+ // If we have sudden increase in allowed rate we actually increase cwnd gently.
+ if self.acked_bytes >= bytes_for_increase {
+ self.acked_bytes = 0;
+ self.congestion_window += MAX_DATAGRAM_SIZE;
+ }
+ self.acked_bytes += new_acked;
+ if self.acked_bytes >= bytes_for_increase {
+ self.acked_bytes -= bytes_for_increase;
+ self.congestion_window += MAX_DATAGRAM_SIZE; // or is this the current MTU?
+ }
+ // The number of bytes we require can go down over time with Cubic.
+ // That might result in an excessive rate of increase, so limit the number of unused
+ // acknowledged bytes after increasing the congestion window twice.
+ self.acked_bytes = min(bytes_for_increase, self.acked_bytes);
+ }
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[
+ QlogMetric::CongestionWindow(self.congestion_window),
+ QlogMetric::BytesInFlight(self.bytes_in_flight),
+ ],
+ );
+ qinfo!([self], "on_packets_acked this={:p}, limited=0, bytes_in_flight={}, cwnd={}, state={:?}, new_acked={}", self, self.bytes_in_flight, self.congestion_window, self.state, new_acked);
+ }
+
+ /// Update congestion controller state based on lost packets.
+ fn on_packets_lost(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) -> bool {
+ if lost_packets.is_empty() {
+ return false;
+ }
+
+ for pkt in lost_packets.iter().filter(|pkt| pkt.cc_in_flight()) {
+ qinfo!(
+ "packet_lost this={:p}, pn={}, ps={}",
+ self,
+ pkt.pn,
+ pkt.size
+ );
+ assert!(self.bytes_in_flight >= pkt.size);
+ self.bytes_in_flight -= pkt.size;
+ }
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+
+ let congestion = self.on_congestion_event(lost_packets.last().unwrap());
+ let persistent_congestion = self.detect_persistent_congestion(
+ first_rtt_sample_time,
+ prev_largest_acked_sent,
+ pto,
+ lost_packets,
+ );
+ qinfo!(
+ "on_packets_lost this={:p}, bytes_in_flight={}, cwnd={}, state={:?}",
+ self,
+ self.bytes_in_flight,
+ self.congestion_window,
+ self.state
+ );
+ congestion || persistent_congestion
+ }
+
+ fn discard(&mut self, pkt: &SentPacket) {
+ if pkt.cc_outstanding() {
+ assert!(self.bytes_in_flight >= pkt.size);
+ self.bytes_in_flight -= pkt.size;
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+ qtrace!([self], "Ignore pkt with size {}", pkt.size);
+ }
+ }
+
+ fn discard_in_flight(&mut self) {
+ self.bytes_in_flight = 0;
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+ }
+
+ fn on_packet_sent(&mut self, pkt: &SentPacket) {
+ // Record the recovery time and exit any transient state.
+ if self.state.transient() {
+ self.recovery_start = Some(pkt.pn);
+ self.state.update();
+ }
+
+ if !pkt.cc_in_flight() {
+ return;
+ }
+ if !self.app_limited() {
+ // Given the current non-app-limited condition, we're fully utilizing the congestion
+ // window. Assume that all in-flight packets up to this one are NOT app-limited.
+ // However, subsequent packets might be app-limited. Set `first_app_limited` to the
+ // next packet number.
+ self.first_app_limited = pkt.pn + 1;
+ }
+
+ self.bytes_in_flight += pkt.size;
+ qinfo!(
+ "packet_sent this={:p}, pn={}, ps={}",
+ self,
+ pkt.pn,
+ pkt.size
+ );
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+ }
+
+ /// Whether a packet can be sent immediately as a result of entering recovery.
+ fn recovery_packet(&self) -> bool {
+ self.state == State::RecoveryStart
+ }
+}
+
+impl<T: WindowAdjustment> ClassicCongestionControl<T> {
+ pub fn new(cc_algorithm: T) -> Self {
+ Self {
+ cc_algorithm,
+ state: State::SlowStart,
+ congestion_window: CWND_INITIAL,
+ bytes_in_flight: 0,
+ acked_bytes: 0,
+ ssthresh: usize::MAX,
+ recovery_start: None,
+ qlog: NeqoQlog::disabled(),
+ first_app_limited: 0,
+ }
+ }
+
+ #[cfg(test)]
+ #[must_use]
+ pub fn ssthresh(&self) -> usize {
+ self.ssthresh
+ }
+
+ #[cfg(test)]
+ pub fn set_ssthresh(&mut self, v: usize) {
+ self.ssthresh = v;
+ }
+
+ #[cfg(test)]
+ pub fn last_max_cwnd(&self) -> f64 {
+ self.cc_algorithm.last_max_cwnd()
+ }
+
+ #[cfg(test)]
+ pub fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) {
+ self.cc_algorithm.set_last_max_cwnd(last_max_cwnd);
+ }
+
+ #[cfg(test)]
+ pub fn acked_bytes(&self) -> usize {
+ self.acked_bytes
+ }
+
+ fn set_state(&mut self, state: State) {
+ if self.state != state {
+ qdebug!([self], "state -> {:?}", state);
+ let old_state = self.state;
+ self.qlog.add_event_data(|| {
+ // No need to tell qlog about exit from transient states.
+ if old_state.transient() {
+ None
+ } else {
+ let ev_data = EventData::CongestionStateUpdated(CongestionStateUpdated {
+ old: Some(old_state.to_qlog().to_owned()),
+ new: state.to_qlog().to_owned(),
+ trigger: None,
+ });
+ Some(ev_data)
+ }
+ });
+ self.state = state;
+ }
+ }
+
+ fn detect_persistent_congestion(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) -> bool {
+ if first_rtt_sample_time.is_none() {
+ return false;
+ }
+
+ let pc_period = pto * PERSISTENT_CONG_THRESH;
+
+ let mut last_pn = 1 << 62; // Impossibly large, but not enough to overflow.
+ let mut start = None;
+
+ // Look for the first lost packet after the previous largest acknowledged.
+ // Ignore packets that weren't ack-eliciting for the start of this range.
+ // Also, make sure to ignore any packets sent before we got an RTT estimate
+ // as we might not have sent PTO packets soon enough after those.
+ let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent);
+ for p in lost_packets
+ .iter()
+ .skip_while(|p| Some(p.time_sent) < cutoff)
+ {
+ if p.pn != last_pn + 1 {
+ // Not a contiguous range of lost packets, start over.
+ start = None;
+ }
+ last_pn = p.pn;
+ if !p.cc_in_flight() {
+ // Not interesting, keep looking.
+ continue;
+ }
+ if let Some(t) = start {
+ let elapsed = p
+ .time_sent
+ .checked_duration_since(t)
+ .expect("time is monotonic");
+ if elapsed > pc_period {
+ qinfo!([self], "persistent congestion");
+ self.congestion_window = CWND_MIN;
+ self.acked_bytes = 0;
+ self.set_state(State::PersistentCongestion);
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::CongestionWindow(self.congestion_window)],
+ );
+ return true;
+ }
+ } else {
+ start = Some(p.time_sent);
+ }
+ }
+ false
+ }
+
+ #[must_use]
+ fn after_recovery_start(&mut self, packet: &SentPacket) -> bool {
+ // At the start of the recovery period, the state is transient and
+ // all packets will have been sent before recovery. When sending out
+ // the first packet we transition to the non-transient `Recovery`
+ // state and update the variable `self.recovery_start`. Before the
+ // first recovery, all packets were sent after the recovery event,
+ // allowing to reduce the cwnd on congestion events.
+ !self.state.transient() && self.recovery_start.map_or(true, |pn| packet.pn >= pn)
+ }
+
+ /// Handle a congestion event.
+ /// Returns true if this was a true congestion event.
+ fn on_congestion_event(&mut self, last_packet: &SentPacket) -> bool {
+ // Start a new congestion event if lost packet was sent after the start
+ // of the previous congestion recovery period.
+ if !self.after_recovery_start(last_packet) {
+ return false;
+ }
+
+ let (cwnd, acked_bytes) = self
+ .cc_algorithm
+ .reduce_cwnd(self.congestion_window, self.acked_bytes);
+ self.congestion_window = max(cwnd, CWND_MIN);
+ self.acked_bytes = acked_bytes;
+ self.ssthresh = self.congestion_window;
+ qinfo!(
+ [self],
+ "Cong event -> recovery; cwnd {}, ssthresh {}",
+ self.congestion_window,
+ self.ssthresh
+ );
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[
+ QlogMetric::CongestionWindow(self.congestion_window),
+ QlogMetric::SsThresh(self.ssthresh),
+ QlogMetric::InRecovery(true),
+ ],
+ );
+ self.set_state(State::RecoveryStart);
+ true
+ }
+
+ #[allow(clippy::unused_self)]
+ fn app_limited(&self) -> bool {
+ if self.bytes_in_flight >= self.congestion_window {
+ false
+ } else if self.state.in_slow_start() {
+ // Allow for potential doubling of the congestion window during slow start.
+ // That is, the application might not have been able to send enough to respond
+ // to increases to the congestion window.
+ self.bytes_in_flight < self.congestion_window / 2
+ } else {
+ // We're not limited if the in-flight data is within a single burst of the
+ // congestion window.
+ (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{
+ convert::TryFrom,
+ time::{Duration, Instant},
+ };
+
+ use neqo_common::qinfo;
+ use test_fixture::now;
+
+ use super::{
+ ClassicCongestionControl, WindowAdjustment, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH,
+ };
+ use crate::{
+ cc::{
+ classic_cc::State,
+ cubic::{Cubic, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR},
+ new_reno::NewReno,
+ CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE,
+ },
+ packet::{PacketNumber, PacketType},
+ rtt::RttEstimate,
+ tracking::SentPacket,
+ };
+
+ const PTO: Duration = Duration::from_millis(100);
+ const RTT: Duration = Duration::from_millis(98);
+ const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98));
+ const ZERO: Duration = Duration::from_secs(0);
+ const EPSILON: Duration = Duration::from_nanos(1);
+ const GAP: Duration = Duration::from_secs(1);
+ /// The largest time between packets without causing persistent congestion.
+ const SUB_PC: Duration = Duration::from_millis(100 * PERSISTENT_CONG_THRESH as u64);
+ /// The minimum time between packets to cause persistent congestion.
+ /// Uses an odd expression because `Duration` arithmetic isn't `const`.
+ const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1);
+
+ fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) {
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+ assert_eq!(cc.ssthresh(), usize::MAX);
+ }
+
+ fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) {
+ assert_eq!(cc.cwnd(), CWND_INITIAL / 2);
+ assert_eq!(cc.ssthresh(), CWND_INITIAL / 2);
+ }
+
+ fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket {
+ SentPacket::new(
+ PacketType::Short,
+ pn,
+ now() + t,
+ ack_eliciting,
+ Vec::new(),
+ 100,
+ )
+ }
+
+ fn congestion_control(cc: CongestionControlAlgorithm) -> Box<dyn CongestionControl> {
+ match cc {
+ CongestionControlAlgorithm::NewReno => {
+ Box::new(ClassicCongestionControl::new(NewReno::default()))
+ }
+ CongestionControlAlgorithm::Cubic => {
+ Box::new(ClassicCongestionControl::new(Cubic::default()))
+ }
+ }
+ }
+
+ fn persistent_congestion_by_algorithm(
+ cc_alg: CongestionControlAlgorithm,
+ reduced_cwnd: usize,
+ lost_packets: &[SentPacket],
+ persistent_expected: bool,
+ ) {
+ let mut cc = congestion_control(cc_alg);
+ for p in lost_packets {
+ cc.on_packet_sent(p);
+ }
+
+ cc.on_packets_lost(Some(now()), None, PTO, lost_packets);
+
+ let persistent = if cc.cwnd() == reduced_cwnd {
+ false
+ } else if cc.cwnd() == CWND_MIN {
+ true
+ } else {
+ panic!("unexpected cwnd");
+ };
+ assert_eq!(persistent, persistent_expected);
+ }
+
+ fn persistent_congestion(lost_packets: &[SentPacket], persistent_expected: bool) {
+ persistent_congestion_by_algorithm(
+ CongestionControlAlgorithm::NewReno,
+ CWND_INITIAL / 2,
+ lost_packets,
+ persistent_expected,
+ );
+ persistent_congestion_by_algorithm(
+ CongestionControlAlgorithm::Cubic,
+ CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR,
+ lost_packets,
+ persistent_expected,
+ );
+ }
+
+ /// A span of exactly the PC threshold only reduces the window on loss.
+ #[test]
+ fn persistent_congestion_none() {
+ persistent_congestion(&[lost(1, true, ZERO), lost(2, true, SUB_PC)], false);
+ }
+
+ /// A span of just more than the PC threshold causes persistent congestion.
+ #[test]
+ fn persistent_congestion_simple() {
+ persistent_congestion(&[lost(1, true, ZERO), lost(2, true, PC)], true);
+ }
+
+ /// Both packets need to be ack-eliciting.
+ #[test]
+ fn persistent_congestion_non_ack_eliciting() {
+ persistent_congestion(&[lost(1, false, ZERO), lost(2, true, PC)], false);
+ persistent_congestion(&[lost(1, true, ZERO), lost(2, false, PC)], false);
+ }
+
+ /// Packets in the middle, of any type, are OK.
+ #[test]
+ fn persistent_congestion_middle() {
+ persistent_congestion(
+ &[lost(1, true, ZERO), lost(2, false, RTT), lost(3, true, PC)],
+ true,
+ );
+ persistent_congestion(
+ &[lost(1, true, ZERO), lost(2, true, RTT), lost(3, true, PC)],
+ true,
+ );
+ }
+
+ /// Leading non-ack-eliciting packets are skipped.
+ #[test]
+ fn persistent_congestion_leading_non_ack_eliciting() {
+ persistent_congestion(
+ &[lost(1, false, ZERO), lost(2, true, RTT), lost(3, true, PC)],
+ false,
+ );
+ persistent_congestion(
+ &[
+ lost(1, false, ZERO),
+ lost(2, true, RTT),
+ lost(3, true, RTT + PC),
+ ],
+ true,
+ );
+ }
+
+ /// Trailing non-ack-eliciting packets aren't relevant.
+ #[test]
+ fn persistent_congestion_trailing_non_ack_eliciting() {
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, PC),
+ lost(3, false, PC + EPSILON),
+ ],
+ true,
+ );
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, SUB_PC),
+ lost(3, false, PC),
+ ],
+ false,
+ );
+ }
+
+ /// Gaps in the middle, of any type, restart the count.
+ #[test]
+ fn persistent_congestion_gap_reset() {
+ persistent_congestion(&[lost(1, true, ZERO), lost(3, true, PC)], false);
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, RTT),
+ lost(4, true, GAP),
+ lost(5, true, GAP + PTO * PERSISTENT_CONG_THRESH),
+ ],
+ false,
+ );
+ }
+
+ /// A span either side of a gap will cause persistent congestion.
+ #[test]
+ fn persistent_congestion_gap_or() {
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, PC),
+ lost(4, true, GAP),
+ lost(5, true, GAP + PTO),
+ ],
+ true,
+ );
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, true, GAP),
+ lost(5, true, GAP + PC),
+ ],
+ true,
+ );
+ }
+
+ /// A gap only restarts after an ack-eliciting packet.
+ #[test]
+ fn persistent_congestion_gap_non_ack_eliciting() {
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, false, GAP),
+ lost(5, true, GAP + PC),
+ ],
+ false,
+ );
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, false, GAP),
+ lost(5, true, GAP + RTT),
+ lost(6, true, GAP + RTT + SUB_PC),
+ ],
+ false,
+ );
+ persistent_congestion(
+ &[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, false, GAP),
+ lost(5, true, GAP + RTT),
+ lost(6, true, GAP + RTT + PC),
+ ],
+ true,
+ );
+ }
+
+ /// Get a time, in multiples of `PTO`, relative to `now()`.
+ fn by_pto(t: u32) -> Instant {
+ now() + (PTO * t)
+ }
+
+ /// Make packets that will be made lost.
+ /// `times` is the time of sending, in multiples of `PTO`, relative to `now()`.
+ fn make_lost(times: &[u32]) -> Vec<SentPacket> {
+ times
+ .iter()
+ .enumerate()
+ .map(|(i, &t)| {
+ SentPacket::new(
+ PacketType::Short,
+ u64::try_from(i).unwrap(),
+ by_pto(t),
+ true,
+ Vec::new(),
+ 1000,
+ )
+ })
+ .collect::<Vec<_>>()
+ }
+
+ /// Call `detect_persistent_congestion` using times relative to now and the fixed PTO time.
+ /// `last_ack` and `rtt_time` are times in multiples of `PTO`, relative to `now()`,
+ /// for the time of the largest acknowledged and the first RTT sample, respectively.
+ fn persistent_congestion_by_pto<T: WindowAdjustment>(
+ mut cc: ClassicCongestionControl<T>,
+ last_ack: u32,
+ rtt_time: u32,
+ lost: &[SentPacket],
+ ) -> bool {
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+
+ let last_ack = Some(by_pto(last_ack));
+ let rtt_time = Some(by_pto(rtt_time));
+
+ // Persistent congestion is never declared if the RTT time is `None`.
+ cc.detect_persistent_congestion(None, None, PTO, lost);
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+ cc.detect_persistent_congestion(None, last_ack, PTO, lost);
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+
+ cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost);
+ cc.cwnd() == CWND_MIN
+ }
+
+ /// No persistent congestion can be had if there are no lost packets.
+ #[test]
+ fn persistent_congestion_no_lost() {
+ let lost = make_lost(&[]);
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ }
+
+ /// No persistent congestion can be had if there is only one lost packet.
+ #[test]
+ fn persistent_congestion_one_lost() {
+ let lost = make_lost(&[1]);
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ }
+
+ /// Persistent congestion can't happen based on old packets.
+ #[test]
+ fn persistent_congestion_past() {
+ // Packets sent prior to either the last acknowledged or the first RTT
+ // sample are not considered. So 0 is ignored.
+ let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]);
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 1,
+ 1,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 0,
+ 1,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 1,
+ 0,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 1,
+ 1,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 0,
+ 1,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 1,
+ 0,
+ &lost
+ ));
+ }
+
+ /// Persistent congestion doesn't start unless the packet is ack-eliciting.
+ #[test]
+ fn persistent_congestion_ack_eliciting() {
+ let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ lost[0] = SentPacket::new(
+ lost[0].pt,
+ lost[0].pn,
+ lost[0].time_sent,
+ false,
+ Vec::new(),
+ lost[0].size,
+ );
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ }
+
+ /// Detect persistent congestion. Note that the first lost packet needs to have a time
+ /// greater than the previously acknowledged packet AND the first RTT sample. And the
+ /// difference in times needs to be greater than the persistent congestion threshold.
+ #[test]
+ fn persistent_congestion_min() {
+ let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ assert!(persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ assert!(persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ }
+
+ /// Make sure that not having a previous largest acknowledged also results
+ /// in detecting persistent congestion. (This is not expected to happen, but
+ /// the code permits it).
+ #[test]
+ fn persistent_congestion_no_prev_ack_newreno() {
+ let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
+ assert_eq!(cc.cwnd(), CWND_MIN);
+ }
+
+ #[test]
+ fn persistent_congestion_no_prev_ack_cubic() {
+ let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ let mut cc = ClassicCongestionControl::new(Cubic::default());
+ cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
+ assert_eq!(cc.cwnd(), CWND_MIN);
+ }
+
+ /// The code asserts on ordering errors.
+ #[test]
+ #[should_panic(expected = "time is monotonic")]
+ fn persistent_congestion_unsorted_newreno() {
+ let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(NewReno::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ }
+
+ /// The code asserts on ordering errors.
+ #[test]
+ #[should_panic(expected = "time is monotonic")]
+ fn persistent_congestion_unsorted_cubic() {
+ let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
+ assert!(!persistent_congestion_by_pto(
+ ClassicCongestionControl::new(Cubic::default()),
+ 0,
+ 0,
+ &lost
+ ));
+ }
+
+ #[test]
+ fn app_limited_slow_start() {
+ const BELOW_APP_LIMIT_PKTS: usize = 5;
+ const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1;
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ let cwnd = cc.congestion_window;
+ let mut now = now();
+ let mut next_pn = 0;
+
+ // simulate packet bursts below app_limit
+ for packet_burst_size in 1..=BELOW_APP_LIMIT_PKTS {
+ // always stay below app_limit during sent.
+ let mut pkts = Vec::new();
+ for _ in 0..packet_burst_size {
+ let p = SentPacket::new(
+ PacketType::Short,
+ next_pn, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ next_pn += 1;
+ cc.on_packet_sent(&p);
+ pkts.push(p);
+ }
+ assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE);
+ now += RTT;
+ cc.on_packets_acked(&pkts, &RTT_ESTIMATE, now);
+ assert_eq!(cc.bytes_in_flight(), 0);
+ assert_eq!(cc.acked_bytes, 0);
+ assert_eq!(cwnd, cc.congestion_window); // CWND doesn't grow because we're app limited
+ }
+
+ // Fully utilize the congestion window by sending enough packets to
+ // have `bytes_in_flight` above the `app_limited` threshold.
+ let mut pkts = Vec::new();
+ for _ in 0..ABOVE_APP_LIMIT_PKTS {
+ let p = SentPacket::new(
+ PacketType::Short,
+ next_pn, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ next_pn += 1;
+ cc.on_packet_sent(&p);
+ pkts.push(p);
+ }
+ assert_eq!(
+ cc.bytes_in_flight(),
+ ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE
+ );
+ now += RTT;
+ // Check if congestion window gets increased for all packets currently in flight
+ for (i, pkt) in pkts.into_iter().enumerate() {
+ cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ // increase acked_bytes with each packet
+ qinfo!("{} {}", cc.congestion_window, cwnd + i * MAX_DATAGRAM_SIZE);
+ assert_eq!(cc.congestion_window, cwnd + (i + 1) * MAX_DATAGRAM_SIZE);
+ assert_eq!(cc.acked_bytes, 0);
+ }
+ }
+
+ #[test]
+ fn app_limited_congestion_avoidance() {
+ const CWND_PKTS_CA: usize = CWND_INITIAL_PKTS / 2;
+ const BELOW_APP_LIMIT_PKTS: usize = CWND_PKTS_CA - 2;
+ const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1;
+
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ let mut now = now();
+
+ // Change state to congestion avoidance by introducing loss.
+
+ let p_lost = SentPacket::new(
+ PacketType::Short,
+ 1, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&p_lost);
+ cwnd_is_default(&cc);
+ now += PTO;
+ cc.on_packets_lost(Some(now), None, PTO, &[p_lost]);
+ cwnd_is_halved(&cc);
+ let p_not_lost = SentPacket::new(
+ PacketType::Short,
+ 2, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&p_not_lost);
+ now += RTT;
+ cc.on_packets_acked(&[p_not_lost], &RTT_ESTIMATE, now);
+ cwnd_is_halved(&cc);
+ // cc is app limited therefore cwnd in not increased.
+ assert_eq!(cc.acked_bytes, 0);
+
+ // Now we are in the congestion avoidance state.
+ assert_eq!(cc.state, State::CongestionAvoidance);
+ // simulate packet bursts below app_limit
+ let mut next_pn = 3;
+ for packet_burst_size in 1..=BELOW_APP_LIMIT_PKTS {
+ // always stay below app_limit during sent.
+ let mut pkts = Vec::new();
+ for _ in 0..packet_burst_size {
+ let p = SentPacket::new(
+ PacketType::Short,
+ next_pn, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ next_pn += 1;
+ cc.on_packet_sent(&p);
+ pkts.push(p);
+ }
+ assert_eq!(cc.bytes_in_flight(), packet_burst_size * MAX_DATAGRAM_SIZE);
+ now += RTT;
+ for (i, pkt) in pkts.into_iter().enumerate() {
+ cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (packet_burst_size - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ cwnd_is_halved(&cc); // CWND doesn't grow because we're app limited
+ assert_eq!(cc.acked_bytes, 0);
+ }
+ }
+
+ // Fully utilize the congestion window by sending enough packets to
+ // have `bytes_in_flight` above the `app_limited` threshold.
+ let mut pkts = Vec::new();
+ for _ in 0..ABOVE_APP_LIMIT_PKTS {
+ let p = SentPacket::new(
+ PacketType::Short,
+ next_pn, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ next_pn += 1;
+ cc.on_packet_sent(&p);
+ pkts.push(p);
+ }
+ assert_eq!(
+ cc.bytes_in_flight(),
+ ABOVE_APP_LIMIT_PKTS * MAX_DATAGRAM_SIZE
+ );
+ now += RTT;
+ let mut last_acked_bytes = 0;
+ // Check if congestion window gets increased for all packets currently in flight
+ for (i, pkt) in pkts.into_iter().enumerate() {
+ cc.on_packets_acked(&[pkt], &RTT_ESTIMATE, now);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (ABOVE_APP_LIMIT_PKTS - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ // The cwnd doesn't increase, but the acked_bytes do, which will eventually lead to an
+ // increase, once the number of bytes reaches the necessary level
+ cwnd_is_halved(&cc);
+ // increase acked_bytes with each packet
+ assert_ne!(cc.acked_bytes, last_acked_bytes);
+ last_acked_bytes = cc.acked_bytes;
+ }
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/cc/cubic.rs b/third_party/rust/neqo-transport/src/cc/cubic.rs
new file mode 100644
index 0000000000..c04a29b443
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/cubic.rs
@@ -0,0 +1,215 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![deny(clippy::pedantic)]
+
+use std::{
+ convert::TryFrom,
+ fmt::{self, Display},
+ time::{Duration, Instant},
+};
+
+use neqo_common::qtrace;
+
+use crate::cc::{classic_cc::WindowAdjustment, MAX_DATAGRAM_SIZE_F64};
+
+// CUBIC congestion control
+
+// C is a constant fixed to determine the aggressiveness of window
+// increase in high BDP networks.
+pub const CUBIC_C: f64 = 0.4;
+pub const CUBIC_ALPHA: f64 = 3.0 * (1.0 - 0.7) / (1.0 + 0.7);
+
+// CUBIC_BETA = 0.7;
+pub const CUBIC_BETA_USIZE_DIVIDEND: usize = 7;
+pub const CUBIC_BETA_USIZE_DIVISOR: usize = 10;
+
+/// The fast convergence ratio further reduces the congestion window when a congestion event
+/// occurs before reaching the previous `W_max`.
+pub const CUBIC_FAST_CONVERGENCE: f64 = 0.85; // (1.0 + CUBIC_BETA) / 2.0;
+
+/// The minimum number of multiples of the datagram size that need
+/// to be received to cause an increase in the congestion window.
+/// When there is no loss, Cubic can return to exponential increase, but
+/// this value reduces the magnitude of the resulting growth by a constant factor.
+/// A value of 1.0 would mean a return to the rate used in slow start.
+const EXPONENTIAL_GROWTH_REDUCTION: f64 = 2.0;
+
+/// Convert an integer congestion window value into a floating point value.
+/// This has the effect of reducing larger values to `1<<53`.
+/// If you have a congestion window that large, something is probably wrong.
+fn convert_to_f64(v: usize) -> f64 {
+ let mut f_64 = f64::from(u32::try_from(v >> 21).unwrap_or(u32::MAX));
+ f_64 *= 2_097_152.0; // f_64 <<= 21
+ f_64 += f64::from(u32::try_from(v & 0x1f_ffff).unwrap());
+ f_64
+}
+
+#[derive(Debug)]
+pub struct Cubic {
+ last_max_cwnd: f64,
+ estimated_tcp_cwnd: f64,
+ k: f64,
+ w_max: f64,
+ ca_epoch_start: Option<Instant>,
+ tcp_acked_bytes: f64,
+}
+
+impl Default for Cubic {
+ fn default() -> Self {
+ Self {
+ last_max_cwnd: 0.0,
+ estimated_tcp_cwnd: 0.0,
+ k: 0.0,
+ w_max: 0.0,
+ ca_epoch_start: None,
+ tcp_acked_bytes: 0.0,
+ }
+ }
+}
+
+impl Display for Cubic {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Cubic [last_max_cwnd: {}, k: {}, w_max: {}, ca_epoch_start: {:?}]",
+ self.last_max_cwnd, self.k, self.w_max, self.ca_epoch_start
+ )?;
+ Ok(())
+ }
+}
+
+#[allow(clippy::doc_markdown)]
+impl Cubic {
+ /// Original equations is:
+ /// K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2 RFC8312)
+ /// W_max is number of segments of the maximum segment size (MSS).
+ ///
+ /// K is actually the time that W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) would
+ /// take to increase to W_max. We use bytes not MSS units, therefore this
+ /// equation will be: W_cubic(t) = C*MSS*(t-K)^3 + W_max.
+ ///
+ /// From that equation we can calculate K as:
+ /// K = cubic_root((W_max - W_cubic) / C / MSS);
+ fn calc_k(&self, curr_cwnd: f64) -> f64 {
+ ((self.w_max - curr_cwnd) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt()
+ }
+
+ /// W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ /// t is relative to the start of the congestion avoidance phase and it is in seconds.
+ fn w_cubic(&self, t: f64) -> f64 {
+ CUBIC_C * (t - self.k).powi(3) * MAX_DATAGRAM_SIZE_F64 + self.w_max
+ }
+
+ fn start_epoch(&mut self, curr_cwnd_f64: f64, new_acked_f64: f64, now: Instant) {
+ self.ca_epoch_start = Some(now);
+ // reset tcp_acked_bytes and estimated_tcp_cwnd;
+ self.tcp_acked_bytes = new_acked_f64;
+ self.estimated_tcp_cwnd = curr_cwnd_f64;
+ if self.last_max_cwnd <= curr_cwnd_f64 {
+ self.w_max = curr_cwnd_f64;
+ self.k = 0.0;
+ } else {
+ self.w_max = self.last_max_cwnd;
+ self.k = self.calc_k(curr_cwnd_f64);
+ }
+ qtrace!([self], "New epoch");
+ }
+}
+
+impl WindowAdjustment for Cubic {
+ // This is because of the cast in the last line from f64 to usize.
+ #[allow(clippy::cast_possible_truncation)]
+ #[allow(clippy::cast_sign_loss)]
+ fn bytes_for_cwnd_increase(
+ &mut self,
+ curr_cwnd: usize,
+ new_acked_bytes: usize,
+ min_rtt: Duration,
+ now: Instant,
+ ) -> usize {
+ let curr_cwnd_f64 = convert_to_f64(curr_cwnd);
+ let new_acked_f64 = convert_to_f64(new_acked_bytes);
+ if self.ca_epoch_start.is_none() {
+ // This is a start of a new congestion avoidance phase.
+ self.start_epoch(curr_cwnd_f64, new_acked_f64, now);
+ } else {
+ self.tcp_acked_bytes += new_acked_f64;
+ }
+
+ let time_ca = self
+ .ca_epoch_start
+ .map_or(min_rtt, |t| {
+ if now + min_rtt < t {
+ // This only happens when processing old packets
+ // that were saved and replayed with old timestamps.
+ min_rtt
+ } else {
+ now + min_rtt - t
+ }
+ })
+ .as_secs_f64();
+ let target_cubic = self.w_cubic(time_ca);
+
+ let tcp_cnt = self.estimated_tcp_cwnd / CUBIC_ALPHA;
+ while self.tcp_acked_bytes > tcp_cnt {
+ self.tcp_acked_bytes -= tcp_cnt;
+ self.estimated_tcp_cwnd += MAX_DATAGRAM_SIZE_F64;
+ }
+
+ let target_cwnd = target_cubic.max(self.estimated_tcp_cwnd);
+
+ // Calculate the number of bytes that would need to be acknowledged for an increase
+ // of `MAX_DATAGRAM_SIZE` to match the increase of `target - cwnd / cwnd` as defined
+ // in the specification (Sections 4.4 and 4.5).
+ // The amount of data required therefore reduces asymptotically as the target increases.
+ // If the target is not significantly higher than the congestion window, require a very
+ // large amount of acknowledged data (effectively block increases).
+ let mut acked_to_increase =
+ MAX_DATAGRAM_SIZE_F64 * curr_cwnd_f64 / (target_cwnd - curr_cwnd_f64).max(1.0);
+
+ // Limit increase to max 1 MSS per EXPONENTIAL_GROWTH_REDUCTION ack packets.
+ // This effectively limits target_cwnd to (1 + 1 / EXPONENTIAL_GROWTH_REDUCTION) cwnd.
+ acked_to_increase =
+ acked_to_increase.max(EXPONENTIAL_GROWTH_REDUCTION * MAX_DATAGRAM_SIZE_F64);
+ acked_to_increase as usize
+ }
+
+ fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) {
+ let curr_cwnd_f64 = convert_to_f64(curr_cwnd);
+ // Fast Convergence
+ // If congestion event occurs before the maximum congestion window before the last
+ // congestion event, we reduce the the maximum congestion window and thereby W_max.
+ // check cwnd + MAX_DATAGRAM_SIZE instead of cwnd because with cwnd in bytes, cwnd may be
+ // slightly off.
+ self.last_max_cwnd = if curr_cwnd_f64 + MAX_DATAGRAM_SIZE_F64 < self.last_max_cwnd {
+ curr_cwnd_f64 * CUBIC_FAST_CONVERGENCE
+ } else {
+ curr_cwnd_f64
+ };
+ self.ca_epoch_start = None;
+ (
+ curr_cwnd * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR,
+ acked_bytes * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR,
+ )
+ }
+
+ fn on_app_limited(&mut self) {
+ // Reset ca_epoch_start. Let it start again when the congestion controller
+ // exits the app-limited period.
+ self.ca_epoch_start = None;
+ }
+
+ #[cfg(test)]
+ fn last_max_cwnd(&self) -> f64 {
+ self.last_max_cwnd
+ }
+
+ #[cfg(test)]
+ fn set_last_max_cwnd(&mut self, last_max_cwnd: f64) {
+ self.last_max_cwnd = last_max_cwnd;
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs
new file mode 100644
index 0000000000..a1a43bd157
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/mod.rs
@@ -0,0 +1,87 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use std::{
+ fmt::{Debug, Display},
+ str::FromStr,
+ time::{Duration, Instant},
+};
+
+use neqo_common::qlog::NeqoQlog;
+
+use crate::{path::PATH_MTU_V6, rtt::RttEstimate, tracking::SentPacket, Error};
+
+mod classic_cc;
+mod cubic;
+mod new_reno;
+
+pub use classic_cc::ClassicCongestionControl;
+#[cfg(test)]
+pub use classic_cc::{CWND_INITIAL, CWND_INITIAL_PKTS, CWND_MIN};
+pub use cubic::Cubic;
+pub use new_reno::NewReno;
+
+pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6;
+#[allow(clippy::cast_precision_loss)]
+pub const MAX_DATAGRAM_SIZE_F64: f64 = MAX_DATAGRAM_SIZE as f64;
+
+pub trait CongestionControl: Display + Debug {
+ fn set_qlog(&mut self, qlog: NeqoQlog);
+
+ #[must_use]
+ fn cwnd(&self) -> usize;
+
+ #[must_use]
+ fn bytes_in_flight(&self) -> usize;
+
+ #[must_use]
+ fn cwnd_avail(&self) -> usize;
+
+ fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], rtt_est: &RttEstimate, now: Instant);
+
+ /// Returns true if the congestion window was reduced.
+ fn on_packets_lost(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) -> bool;
+
+ #[must_use]
+ fn recovery_packet(&self) -> bool;
+
+ fn discard(&mut self, pkt: &SentPacket);
+
+ fn on_packet_sent(&mut self, pkt: &SentPacket);
+
+ fn discard_in_flight(&mut self);
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum CongestionControlAlgorithm {
+ NewReno,
+ Cubic,
+}
+
+// A `FromStr` implementation so that this can be used in command-line interfaces.
+impl FromStr for CongestionControlAlgorithm {
+ type Err = Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s.trim().to_ascii_lowercase().as_str() {
+ "newreno" | "reno" => Ok(Self::NewReno),
+ "cubic" => Ok(Self::Cubic),
+ _ => Err(Error::InvalidInput),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/third_party/rust/neqo-transport/src/cc/new_reno.rs b/third_party/rust/neqo-transport/src/cc/new_reno.rs
new file mode 100644
index 0000000000..e51b3d6cc0
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/new_reno.rs
@@ -0,0 +1,51 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use std::{
+ fmt::{self, Display},
+ time::{Duration, Instant},
+};
+
+use crate::cc::classic_cc::WindowAdjustment;
+
+#[derive(Debug, Default)]
+pub struct NewReno {}
+
+impl Display for NewReno {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "NewReno")?;
+ Ok(())
+ }
+}
+
+impl WindowAdjustment for NewReno {
+ fn bytes_for_cwnd_increase(
+ &mut self,
+ curr_cwnd: usize,
+ _new_acked_bytes: usize,
+ _min_rtt: Duration,
+ _now: Instant,
+ ) -> usize {
+ curr_cwnd
+ }
+
+ fn reduce_cwnd(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) {
+ (curr_cwnd / 2, acked_bytes / 2)
+ }
+
+ fn on_app_limited(&mut self) {}
+
+ #[cfg(test)]
+ fn last_max_cwnd(&self) -> f64 {
+ 0.0
+ }
+
+ #[cfg(test)]
+ fn set_last_max_cwnd(&mut self, _last_max_cwnd: f64) {}
+}
diff --git a/third_party/rust/neqo-transport/src/cc/tests/cubic.rs b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs
new file mode 100644
index 0000000000..0c82e47817
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/tests/cubic.rs
@@ -0,0 +1,333 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(clippy::cast_possible_truncation)]
+#![allow(clippy::cast_sign_loss)]
+
+use std::{
+ convert::TryFrom,
+ ops::Sub,
+ time::{Duration, Instant},
+};
+
+use test_fixture::now;
+
+use crate::{
+ cc::{
+ classic_cc::{ClassicCongestionControl, CWND_INITIAL},
+ cubic::{
+ Cubic, CUBIC_ALPHA, CUBIC_BETA_USIZE_DIVIDEND, CUBIC_BETA_USIZE_DIVISOR, CUBIC_C,
+ CUBIC_FAST_CONVERGENCE,
+ },
+ CongestionControl, MAX_DATAGRAM_SIZE, MAX_DATAGRAM_SIZE_F64,
+ },
+ packet::PacketType,
+ rtt::RttEstimate,
+ tracking::SentPacket,
+};
+
+const RTT: Duration = Duration::from_millis(100);
+const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(100));
+const CWND_INITIAL_F64: f64 = 10.0 * MAX_DATAGRAM_SIZE_F64;
+const CWND_INITIAL_10_F64: f64 = 10.0 * CWND_INITIAL_F64;
+const CWND_INITIAL_10: usize = 10 * CWND_INITIAL;
+const CWND_AFTER_LOSS: usize = CWND_INITIAL * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR;
+const CWND_AFTER_LOSS_SLOW_START: usize =
+ (CWND_INITIAL + MAX_DATAGRAM_SIZE) * CUBIC_BETA_USIZE_DIVIDEND / CUBIC_BETA_USIZE_DIVISOR;
+
+fn fill_cwnd(cc: &mut ClassicCongestionControl<Cubic>, mut next_pn: u64, now: Instant) -> u64 {
+ while cc.bytes_in_flight() < cc.cwnd() {
+ let sent = SentPacket::new(
+ PacketType::Short,
+ next_pn, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&sent);
+ next_pn += 1;
+ }
+ next_pn
+}
+
+fn ack_packet(cc: &mut ClassicCongestionControl<Cubic>, pn: u64, now: Instant) {
+ let acked = SentPacket::new(
+ PacketType::Short,
+ pn, // pn
+ now, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packets_acked(&[acked], &RTT_ESTIMATE, now);
+}
+
+fn packet_lost(cc: &mut ClassicCongestionControl<Cubic>, pn: u64) {
+ const PTO: Duration = Duration::from_millis(120);
+ let p_lost = SentPacket::new(
+ PacketType::Short,
+ pn, // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packets_lost(None, None, PTO, &[p_lost]);
+}
+
+fn expected_tcp_acks(cwnd_rtt_start: usize) -> u64 {
+ (f64::from(i32::try_from(cwnd_rtt_start).unwrap()) / MAX_DATAGRAM_SIZE_F64 / CUBIC_ALPHA)
+ .round() as u64
+}
+
+#[test]
+fn tcp_phase() {
+ let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+ // change to congestion avoidance state.
+ cubic.set_ssthresh(1);
+
+ let mut now = now();
+ let start_time = now;
+ // helper variables to remember the next packet number to be sent/acked.
+ let mut next_pn_send = 0;
+ let mut next_pn_ack = 0;
+
+ next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+
+ // This will start with TCP phase.
+ // in this phase cwnd is increase by CUBIC_ALPHA every RTT. We can look at it as
+ // increase of MAX_DATAGRAM_SIZE every 1 / CUBIC_ALPHA RTTs.
+ // The phase will end when cwnd calculated with cubic equation is equal to TCP estimate:
+ // CUBIC_C * (n * RTT / CUBIC_ALPHA)^3 * MAX_DATAGRAM_SIZE = n * MAX_DATAGRAM_SIZE
+ // from this n = sqrt(CUBIC_ALPHA^3/ (CUBIC_C * RTT^3)).
+ let num_tcp_increases = (CUBIC_ALPHA.powi(3) / (CUBIC_C * RTT.as_secs_f64().powi(3)))
+ .sqrt()
+ .floor() as u64;
+
+ for _ in 0..num_tcp_increases {
+ let cwnd_rtt_start = cubic.cwnd();
+ // Expected acks during a period of RTT / CUBIC_ALPHA.
+ let acks = expected_tcp_acks(cwnd_rtt_start);
+ // The time between acks if they are ideally paced over a RTT.
+ let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap();
+
+ for _ in 0..acks {
+ now += time_increase;
+ ack_packet(&mut cubic, next_pn_ack, now);
+ next_pn_ack += 1;
+ next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+ }
+
+ assert_eq!(cubic.cwnd() - cwnd_rtt_start, MAX_DATAGRAM_SIZE);
+ }
+
+ // The next increase will be according to the cubic equation.
+
+ let cwnd_rtt_start = cubic.cwnd();
+ // cwnd_rtt_start has change, therefore calculate new time_increase (the time
+ // between acks if they are ideally paced over a RTT).
+ let time_increase = RTT / u32::try_from(cwnd_rtt_start / MAX_DATAGRAM_SIZE).unwrap();
+ let mut num_acks = 0; // count the number of acks. until cwnd is increased by MAX_DATAGRAM_SIZE.
+
+ while cwnd_rtt_start == cubic.cwnd() {
+ num_acks += 1;
+ now += time_increase;
+ ack_packet(&mut cubic, next_pn_ack, now);
+ next_pn_ack += 1;
+ next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+ }
+
+ // Make sure that the increase is not according to TCP equation, i.e., that it took
+ // less than RTT / CUBIC_ALPHA.
+ let expected_ack_tcp_increase = expected_tcp_acks(cwnd_rtt_start);
+ assert!(num_acks < expected_ack_tcp_increase);
+
+ // This first increase after a TCP phase may be shorter than what it would take by a regular
+ // cubic phase, because of the proper byte counting and the credit it already had before
+ // entering this phase. Therefore We will perform another round and compare it to expected
+ // increase using the cubic equation.
+
+ let cwnd_rtt_start_after_tcp = cubic.cwnd();
+ let elapsed_time = now - start_time;
+
+ // calculate new time_increase.
+ let time_increase = RTT / u32::try_from(cwnd_rtt_start_after_tcp / MAX_DATAGRAM_SIZE).unwrap();
+ let mut num_acks2 = 0; // count the number of acks. until cwnd is increased by MAX_DATAGRAM_SIZE.
+
+ while cwnd_rtt_start_after_tcp == cubic.cwnd() {
+ num_acks2 += 1;
+ now += time_increase;
+ ack_packet(&mut cubic, next_pn_ack, now);
+ next_pn_ack += 1;
+ next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+ }
+
+ let expected_ack_tcp_increase2 = expected_tcp_acks(cwnd_rtt_start_after_tcp);
+ assert!(num_acks2 < expected_ack_tcp_increase2);
+
+ // The time needed to increase cwnd by MAX_DATAGRAM_SIZE using the cubic equation will be
+ // calculates from: W_cubic(elapsed_time + t_to_increase) - W_cubis(elapsed_time) =
+ // MAX_DATAGRAM_SIZE => CUBIC_C * (elapsed_time + t_to_increase)^3 * MAX_DATAGRAM_SIZE +
+ // CWND_INITIAL - CUBIC_C * elapsed_time^3 * MAX_DATAGRAM_SIZE + CWND_INITIAL =
+ // MAX_DATAGRAM_SIZE => t_to_increase = cbrt((1 + CUBIC_C * elapsed_time^3) / CUBIC_C) -
+ // elapsed_time (t_to_increase is in seconds)
+ // number of ack needed is t_to_increase / time_increase.
+ let expected_ack_cubic_increase =
+ ((((1.0 + CUBIC_C * (elapsed_time).as_secs_f64().powi(3)) / CUBIC_C).cbrt()
+ - elapsed_time.as_secs_f64())
+ / time_increase.as_secs_f64())
+ .ceil() as u64;
+ // num_acks is very close to the calculated value. The exact value is hard to calculate
+ // because the proportional increase(i.e. curr_cwnd_f64 / (target - curr_cwnd_f64) *
+ // MAX_DATAGRAM_SIZE_F64) and the byte counting.
+ assert_eq!(num_acks2, expected_ack_cubic_increase + 2);
+}
+
+#[test]
+fn cubic_phase() {
+ let mut cubic = ClassicCongestionControl::new(Cubic::default());
+ // Set last_max_cwnd to a higher number make sure that cc is the cubic phase (cwnd is calculated
+ // by the cubic equation).
+ cubic.set_last_max_cwnd(CWND_INITIAL_10_F64);
+ // Set ssthresh to something small to make sure that cc is in the congection avoidance phase.
+ cubic.set_ssthresh(1);
+ let mut now = now();
+ let mut next_pn_send = 0;
+ let mut next_pn_ack = 0;
+
+ next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+
+ let k = ((CWND_INITIAL_10_F64 - CWND_INITIAL_F64) / CUBIC_C / MAX_DATAGRAM_SIZE_F64).cbrt();
+ let epoch_start = now;
+
+ // The number of RTT until W_max is reached.
+ let num_rtts_w_max = (k / RTT.as_secs_f64()).round() as u64;
+ for _ in 0..num_rtts_w_max {
+ let cwnd_rtt_start = cubic.cwnd();
+ // Expected acks
+ let acks = cwnd_rtt_start / MAX_DATAGRAM_SIZE;
+ let time_increase = RTT / u32::try_from(acks).unwrap();
+ for _ in 0..acks {
+ now += time_increase;
+ ack_packet(&mut cubic, next_pn_ack, now);
+ next_pn_ack += 1;
+ next_pn_send = fill_cwnd(&mut cubic, next_pn_send, now);
+ }
+
+ let expected =
+ (CUBIC_C * ((now - epoch_start).as_secs_f64() - k).powi(3) * MAX_DATAGRAM_SIZE_F64
+ + CWND_INITIAL_10_F64)
+ .round() as usize;
+
+ assert_within(cubic.cwnd(), expected, MAX_DATAGRAM_SIZE);
+ }
+ assert_eq!(cubic.cwnd(), CWND_INITIAL_10);
+}
+
+fn assert_within<T: Sub<Output = T> + PartialOrd + Copy>(value: T, expected: T, margin: T) {
+ if value >= expected {
+ assert!(value - expected < margin);
+ } else {
+ assert!(expected - value < margin);
+ }
+}
+
+#[test]
+fn congestion_event_slow_start() {
+ let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+ _ = fill_cwnd(&mut cubic, 0, now());
+ ack_packet(&mut cubic, 0, now());
+
+ assert_within(cubic.last_max_cwnd(), 0.0, f64::EPSILON);
+
+ // cwnd is increased by 1 in slow start phase, after an ack.
+ assert_eq!(cubic.cwnd(), CWND_INITIAL + MAX_DATAGRAM_SIZE);
+
+ // Trigger a congestion_event in slow start phase
+ packet_lost(&mut cubic, 1);
+
+ // last_max_cwnd is equal to cwnd before decrease.
+ assert_within(
+ cubic.last_max_cwnd(),
+ CWND_INITIAL_F64 + MAX_DATAGRAM_SIZE_F64,
+ f64::EPSILON,
+ );
+ assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS_SLOW_START);
+}
+
+#[test]
+fn congestion_event_congestion_avoidance() {
+ let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+ // Set ssthresh to something small to make sure that cc is in the congection avoidance phase.
+ cubic.set_ssthresh(1);
+
+ // Set last_max_cwnd to something smaller than cwnd so that the fast convergence is not
+ // triggered.
+ cubic.set_last_max_cwnd(3.0 * MAX_DATAGRAM_SIZE_F64);
+
+ _ = fill_cwnd(&mut cubic, 0, now());
+ ack_packet(&mut cubic, 0, now());
+
+ assert_eq!(cubic.cwnd(), CWND_INITIAL);
+
+ // Trigger a congestion_event in slow start phase
+ packet_lost(&mut cubic, 1);
+
+ assert_within(cubic.last_max_cwnd(), CWND_INITIAL_F64, f64::EPSILON);
+ assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS);
+}
+
+#[test]
+fn congestion_event_congestion_avoidance_2() {
+ let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+ // Set ssthresh to something small to make sure that cc is in the congection avoidance phase.
+ cubic.set_ssthresh(1);
+
+ // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered.
+ cubic.set_last_max_cwnd(CWND_INITIAL_10_F64);
+
+ _ = fill_cwnd(&mut cubic, 0, now());
+ ack_packet(&mut cubic, 0, now());
+
+ assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON);
+ assert_eq!(cubic.cwnd(), CWND_INITIAL);
+
+ // Trigger a congestion_event.
+ packet_lost(&mut cubic, 1);
+
+ assert_within(
+ cubic.last_max_cwnd(),
+ CWND_INITIAL_F64 * CUBIC_FAST_CONVERGENCE,
+ f64::EPSILON,
+ );
+ assert_eq!(cubic.cwnd(), CWND_AFTER_LOSS);
+}
+
+#[test]
+fn congestion_event_congestion_avoidance_test_no_overflow() {
+ const PTO: Duration = Duration::from_millis(120);
+ let mut cubic = ClassicCongestionControl::new(Cubic::default());
+
+ // Set ssthresh to something small to make sure that cc is in the congection avoidance phase.
+ cubic.set_ssthresh(1);
+
+ // Set last_max_cwnd to something higher than cwnd so that the fast convergence is triggered.
+ cubic.set_last_max_cwnd(CWND_INITIAL_10_F64);
+
+ _ = fill_cwnd(&mut cubic, 0, now());
+ ack_packet(&mut cubic, 1, now());
+
+ assert_within(cubic.last_max_cwnd(), CWND_INITIAL_10_F64, f64::EPSILON);
+ assert_eq!(cubic.cwnd(), CWND_INITIAL);
+
+ // Now ack packet that was send earlier.
+ ack_packet(&mut cubic, 0, now().checked_sub(PTO).unwrap());
+}
diff --git a/third_party/rust/neqo-transport/src/cc/tests/mod.rs b/third_party/rust/neqo-transport/src/cc/tests/mod.rs
new file mode 100644
index 0000000000..238a7ad012
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/tests/mod.rs
@@ -0,0 +1,7 @@
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+mod cubic;
+mod new_reno;
diff --git a/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs
new file mode 100644
index 0000000000..a73844a755
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/tests/new_reno.rs
@@ -0,0 +1,219 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use std::time::Duration;
+
+use test_fixture::now;
+
+use crate::{
+ cc::{
+ new_reno::NewReno, ClassicCongestionControl, CongestionControl, CWND_INITIAL,
+ MAX_DATAGRAM_SIZE,
+ },
+ packet::PacketType,
+ rtt::RttEstimate,
+ tracking::SentPacket,
+};
+
+const PTO: Duration = Duration::from_millis(100);
+const RTT: Duration = Duration::from_millis(98);
+const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98));
+
+fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) {
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+ assert_eq!(cc.ssthresh(), usize::MAX);
+}
+
+fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) {
+ assert_eq!(cc.cwnd(), CWND_INITIAL / 2);
+ assert_eq!(cc.ssthresh(), CWND_INITIAL / 2);
+}
+
+#[test]
+fn issue_876() {
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ let time_now = now();
+ let time_before = time_now.checked_sub(Duration::from_millis(100)).unwrap();
+ let time_after = time_now + Duration::from_millis(150);
+
+ let sent_packets = &[
+ SentPacket::new(
+ PacketType::Short,
+ 1, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE - 1, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 2, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE - 2, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 3, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 4, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 5, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 6, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 7, // pn
+ time_after, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE - 3, // size
+ ),
+ ];
+
+ // Send some more packets so that the cc is not app-limited.
+ for p in &sent_packets[..6] {
+ cc.on_packet_sent(p);
+ }
+ assert_eq!(cc.acked_bytes(), 0);
+ cwnd_is_default(&cc);
+ assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 3);
+
+ cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[0..1]);
+
+ // We are now in recovery
+ assert!(cc.recovery_packet());
+ assert_eq!(cc.acked_bytes(), 0);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2);
+
+ // Send a packet after recovery starts
+ cc.on_packet_sent(&sent_packets[6]);
+ assert!(!cc.recovery_packet());
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.acked_bytes(), 0);
+ assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 5);
+
+ // and ack it. cwnd increases slightly
+ cc.on_packets_acked(&sent_packets[6..], &RTT_ESTIMATE, time_now);
+ assert_eq!(cc.acked_bytes(), sent_packets[6].size);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2);
+
+ // Packet from before is lost. Should not hurt cwnd.
+ cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]);
+ assert!(!cc.recovery_packet());
+ assert_eq!(cc.acked_bytes(), sent_packets[6].size);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE);
+}
+
+#[test]
+// https://github.com/mozilla/neqo/pull/1465
+fn issue_1465() {
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ let mut pn = 0;
+ let mut now = now();
+ let mut next_packet = |now| {
+ let p = SentPacket::new(
+ PacketType::Short,
+ pn, // pn
+ now, // time_sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ pn += 1;
+ p
+ };
+ let mut send_next = |cc: &mut ClassicCongestionControl<NewReno>, now| {
+ let p = next_packet(now);
+ cc.on_packet_sent(&p);
+ p
+ };
+
+ let p1 = send_next(&mut cc, now);
+ let p2 = send_next(&mut cc, now);
+ let p3 = send_next(&mut cc, now);
+
+ assert_eq!(cc.acked_bytes(), 0);
+ cwnd_is_default(&cc);
+ assert_eq!(cc.bytes_in_flight(), 3 * MAX_DATAGRAM_SIZE);
+
+ // advance one rtt to detect lost packet there this simplifies the timers, because
+ // on_packet_loss would only be called after RTO, but that is not relevant to the problem
+ now += RTT;
+ cc.on_packets_lost(Some(now), None, PTO, &[p1]);
+
+ // We are now in recovery
+ assert!(cc.recovery_packet());
+ assert_eq!(cc.acked_bytes(), 0);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE);
+
+ // Don't reduce the cwnd again on second packet loss
+ cc.on_packets_lost(Some(now), None, PTO, &[p3]);
+ assert_eq!(cc.acked_bytes(), 0);
+ cwnd_is_halved(&cc); // still the same as after first packet loss
+ assert_eq!(cc.bytes_in_flight(), MAX_DATAGRAM_SIZE);
+
+ // the acked packets before on_packet_sent were the cause of
+ // https://github.com/mozilla/neqo/pull/1465
+ cc.on_packets_acked(&[p2], &RTT_ESTIMATE, now);
+
+ assert_eq!(cc.bytes_in_flight(), 0);
+
+ // send out recovery packet and get it acked to get out of recovery state
+ let p4 = send_next(&mut cc, now);
+ cc.on_packet_sent(&p4);
+ now += RTT;
+ cc.on_packets_acked(&[p4], &RTT_ESTIMATE, now);
+
+ // do the same as in the first rtt but now the bug appears
+ let p5 = send_next(&mut cc, now);
+ let p6 = send_next(&mut cc, now);
+ now += RTT;
+
+ let cur_cwnd = cc.cwnd();
+ cc.on_packets_lost(Some(now), None, PTO, &[p5]);
+
+ // go back into recovery
+ assert!(cc.recovery_packet());
+ assert_eq!(cc.cwnd(), cur_cwnd / 2);
+ assert_eq!(cc.acked_bytes(), 0);
+ assert_eq!(cc.bytes_in_flight(), 2 * MAX_DATAGRAM_SIZE);
+
+ // this shouldn't introduce further cwnd reduction, but it did before https://github.com/mozilla/neqo/pull/1465
+ cc.on_packets_lost(Some(now), None, PTO, &[p6]);
+ assert_eq!(cc.cwnd(), cur_cwnd / 2);
+}
diff --git a/third_party/rust/neqo-transport/src/cid.rs b/third_party/rust/neqo-transport/src/cid.rs
new file mode 100644
index 0000000000..be202daf25
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cid.rs
@@ -0,0 +1,609 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Representation and management of connection IDs.
+
+use std::{
+ borrow::Borrow,
+ cell::{Ref, RefCell},
+ cmp::{max, min},
+ convert::{AsRef, TryFrom},
+ ops::Deref,
+ rc::Rc,
+};
+
+use neqo_common::{hex, hex_with_len, qinfo, Decoder, Encoder};
+use neqo_crypto::random;
+use smallvec::SmallVec;
+
+use crate::{
+ frame::FRAME_TYPE_NEW_CONNECTION_ID, packet::PacketBuilder, recovery::RecoveryToken,
+ stats::FrameStats, Error, Res,
+};
+
+pub const MAX_CONNECTION_ID_LEN: usize = 20;
+pub const LOCAL_ACTIVE_CID_LIMIT: usize = 8;
+pub const CONNECTION_ID_SEQNO_INITIAL: u64 = 0;
+pub const CONNECTION_ID_SEQNO_PREFERRED: u64 = 1;
+/// A special value. See `ConnectionIdManager::add_odcid`.
+const CONNECTION_ID_SEQNO_ODCID: u64 = u64::MAX;
+/// A special value. See `ConnectionIdEntry::empty_remote`.
+const CONNECTION_ID_SEQNO_EMPTY: u64 = u64::MAX - 1;
+
+#[derive(Clone, Default, Eq, Hash, PartialEq)]
+pub struct ConnectionId {
+ pub(crate) cid: SmallVec<[u8; MAX_CONNECTION_ID_LEN]>,
+}
+
+impl ConnectionId {
+ pub fn generate(len: usize) -> Self {
+ assert!(matches!(len, 0..=MAX_CONNECTION_ID_LEN));
+ Self::from(random(len))
+ }
+
+ // Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long.
+ pub fn generate_initial() -> Self {
+ let v = random(1);
+ // Bias selection toward picking 8 (>50% of the time).
+ let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into();
+ Self::generate(len)
+ }
+
+ pub fn as_cid_ref(&self) -> ConnectionIdRef {
+ ConnectionIdRef::from(&self.cid[..])
+ }
+}
+
+impl AsRef<[u8]> for ConnectionId {
+ fn as_ref(&self) -> &[u8] {
+ self.borrow()
+ }
+}
+
+impl Borrow<[u8]> for ConnectionId {
+ fn borrow(&self) -> &[u8] {
+ &self.cid
+ }
+}
+
+impl From<SmallVec<[u8; MAX_CONNECTION_ID_LEN]>> for ConnectionId {
+ fn from(cid: SmallVec<[u8; MAX_CONNECTION_ID_LEN]>) -> Self {
+ Self { cid }
+ }
+}
+
+impl From<Vec<u8>> for ConnectionId {
+ fn from(cid: Vec<u8>) -> Self {
+ Self::from(SmallVec::from(cid))
+ }
+}
+
+impl<T: AsRef<[u8]> + ?Sized> From<&T> for ConnectionId {
+ fn from(buf: &T) -> Self {
+ Self::from(SmallVec::from(buf.as_ref()))
+ }
+}
+
+impl<'a> From<ConnectionIdRef<'a>> for ConnectionId {
+ fn from(cidref: ConnectionIdRef<'a>) -> Self {
+ Self::from(SmallVec::from(cidref.cid))
+ }
+}
+
+impl std::ops::Deref for ConnectionId {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.cid
+ }
+}
+
+impl ::std::fmt::Debug for ConnectionId {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "CID {}", hex_with_len(&self.cid))
+ }
+}
+
+impl ::std::fmt::Display for ConnectionId {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}", hex(&self.cid))
+ }
+}
+
+impl<'a> PartialEq<ConnectionIdRef<'a>> for ConnectionId {
+ fn eq(&self, other: &ConnectionIdRef<'a>) -> bool {
+ &self.cid[..] == other.cid
+ }
+}
+
+#[derive(Hash, Eq, PartialEq, Clone, Copy)]
+pub struct ConnectionIdRef<'a> {
+ cid: &'a [u8],
+}
+
+impl<'a> ::std::fmt::Debug for ConnectionIdRef<'a> {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "CID {}", hex_with_len(self.cid))
+ }
+}
+
+impl<'a> ::std::fmt::Display for ConnectionIdRef<'a> {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}", hex(self.cid))
+ }
+}
+
+impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a T> for ConnectionIdRef<'a> {
+ fn from(cid: &'a T) -> Self {
+ Self { cid: cid.as_ref() }
+ }
+}
+
+impl<'a> std::ops::Deref for ConnectionIdRef<'a> {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ self.cid
+ }
+}
+
+impl<'a> PartialEq<ConnectionId> for ConnectionIdRef<'a> {
+ fn eq(&self, other: &ConnectionId) -> bool {
+ self.cid == &other.cid[..]
+ }
+}
+
+pub trait ConnectionIdDecoder {
+ /// Decodes a connection ID from the provided decoder.
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>>;
+}
+
+pub trait ConnectionIdGenerator: ConnectionIdDecoder {
+ /// Generates a connection ID. This can return `None` if the generator
+ /// is exhausted.
+ fn generate_cid(&mut self) -> Option<ConnectionId>;
+ /// Indicates whether the connection IDs are zero-length.
+ /// If this returns true, `generate_cid` must always produce an empty value
+ /// and never `None`.
+ /// If this returns false, `generate_cid` must never produce an empty value,
+ /// though it can return `None`.
+ ///
+ /// You should not need to implement this: if you want zero-length connection IDs,
+ /// use `EmptyConnectionIdGenerator` instead.
+ fn generates_empty_cids(&self) -> bool {
+ false
+ }
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder;
+}
+
+/// An `EmptyConnectionIdGenerator` generates empty connection IDs.
+#[derive(Default)]
+pub struct EmptyConnectionIdGenerator {}
+
+impl ConnectionIdDecoder for EmptyConnectionIdGenerator {
+ fn decode_cid<'a>(&self, _: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
+ Some(ConnectionIdRef::from(&[]))
+ }
+}
+
+impl ConnectionIdGenerator for EmptyConnectionIdGenerator {
+ fn generate_cid(&mut self) -> Option<ConnectionId> {
+ Some(ConnectionId::from(&[]))
+ }
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
+ self
+ }
+ fn generates_empty_cids(&self) -> bool {
+ true
+ }
+}
+
+/// An RandomConnectionIdGenerator produces connection IDs of
+/// a fixed length and random content. No effort is made to
+/// prevent collisions.
+pub struct RandomConnectionIdGenerator {
+ len: usize,
+}
+
+impl RandomConnectionIdGenerator {
+ pub fn new(len: usize) -> Self {
+ Self { len }
+ }
+}
+
+impl ConnectionIdDecoder for RandomConnectionIdGenerator {
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
+ dec.decode(self.len).map(ConnectionIdRef::from)
+ }
+}
+
+impl ConnectionIdGenerator for RandomConnectionIdGenerator {
+ fn generate_cid(&mut self) -> Option<ConnectionId> {
+ Some(ConnectionId::from(&random(self.len)))
+ }
+
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
+ self
+ }
+
+ fn generates_empty_cids(&self) -> bool {
+ self.len == 0
+ }
+}
+
+/// A single connection ID, as saved from NEW_CONNECTION_ID.
+/// This is templated so that the connection ID entries from a peer can be
+/// saved with a stateless reset token. Local entries don't need that.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct ConnectionIdEntry<SRT: Clone + PartialEq> {
+ /// The sequence number.
+ seqno: u64,
+ /// The connection ID.
+ cid: ConnectionId,
+ /// The corresponding stateless reset token.
+ srt: SRT,
+}
+
+impl ConnectionIdEntry<[u8; 16]> {
+ /// Create a random stateless reset token so that it is hard to guess the correct
+ /// value and reset the connection.
+ fn random_srt() -> [u8; 16] {
+ <[u8; 16]>::try_from(&random(16)[..]).unwrap()
+ }
+
+ /// Create the first entry, which won't have a stateless reset token.
+ pub fn initial_remote(cid: ConnectionId) -> Self {
+ Self::new(CONNECTION_ID_SEQNO_INITIAL, cid, Self::random_srt())
+ }
+
+ /// Create an empty for when the peer chooses empty connection IDs.
+ /// This uses a special sequence number just because it can.
+ pub fn empty_remote() -> Self {
+ Self::new(
+ CONNECTION_ID_SEQNO_EMPTY,
+ ConnectionId::from(&[]),
+ Self::random_srt(),
+ )
+ }
+
+ fn token_equal(a: &[u8; 16], b: &[u8; 16]) -> bool {
+ // rustc might decide to optimize this and make this non-constant-time
+ // with respect to `t`, but it doesn't appear to currently.
+ let mut c = 0;
+ for (&a, &b) in a.iter().zip(b) {
+ c |= a ^ b;
+ }
+ c == 0
+ }
+
+ /// Determine whether this is a valid stateless reset.
+ pub fn is_stateless_reset(&self, token: &[u8; 16]) -> bool {
+ // A sequence number of 2^62 or more has no corresponding stateless reset token.
+ (self.seqno < (1 << 62)) && Self::token_equal(&self.srt, token)
+ }
+
+ /// Return true if the two contain any equal parts.
+ fn any_part_equal(&self, other: &Self) -> bool {
+ self.seqno == other.seqno || self.cid == other.cid || self.srt == other.srt
+ }
+
+ /// The sequence number of this entry.
+ pub fn sequence_number(&self) -> u64 {
+ self.seqno
+ }
+}
+
+impl ConnectionIdEntry<()> {
+ /// Create an initial entry.
+ pub fn initial_local(cid: ConnectionId) -> Self {
+ Self::new(0, cid, ())
+ }
+}
+
+impl<SRT: Clone + PartialEq> ConnectionIdEntry<SRT> {
+ pub fn new(seqno: u64, cid: ConnectionId, srt: SRT) -> Self {
+ Self { seqno, cid, srt }
+ }
+
+ /// Update the stateless reset token. This panics if the sequence number is non-zero.
+ pub fn set_stateless_reset_token(&mut self, srt: SRT) {
+ assert_eq!(self.seqno, CONNECTION_ID_SEQNO_INITIAL);
+ self.srt = srt;
+ }
+
+ /// Replace the connection ID. This panics if the sequence number is non-zero.
+ pub fn update_cid(&mut self, cid: ConnectionId) {
+ assert_eq!(self.seqno, CONNECTION_ID_SEQNO_INITIAL);
+ self.cid = cid;
+ }
+
+ pub fn connection_id(&self) -> &ConnectionId {
+ &self.cid
+ }
+
+ pub fn reset_token(&self) -> &SRT {
+ &self.srt
+ }
+}
+
+pub type RemoteConnectionIdEntry = ConnectionIdEntry<[u8; 16]>;
+
+/// A collection of connection IDs that are indexed by a sequence number.
+/// Used to store connection IDs that are provided by a peer.
+#[derive(Debug, Default)]
+pub struct ConnectionIdStore<SRT: Clone + PartialEq> {
+ cids: SmallVec<[ConnectionIdEntry<SRT>; 8]>,
+}
+
+impl<SRT: Clone + PartialEq> ConnectionIdStore<SRT> {
+ pub fn retire(&mut self, seqno: u64) {
+ self.cids.retain(|c| c.seqno != seqno);
+ }
+
+ pub fn contains(&self, cid: ConnectionIdRef) -> bool {
+ self.cids.iter().any(|c| c.cid == cid)
+ }
+
+ pub fn next(&mut self) -> Option<ConnectionIdEntry<SRT>> {
+ if self.cids.is_empty() {
+ None
+ } else {
+ Some(self.cids.remove(0))
+ }
+ }
+
+ pub fn len(&self) -> usize {
+ self.cids.len()
+ }
+}
+
+impl ConnectionIdStore<[u8; 16]> {
+ pub fn add_remote(&mut self, entry: ConnectionIdEntry<[u8; 16]>) -> Res<()> {
+ // It's OK if this perfectly matches an existing entry.
+ if self.cids.iter().any(|c| c == &entry) {
+ return Ok(());
+ }
+ // It's not OK if any individual piece matches though.
+ if self.cids.iter().any(|c| c.any_part_equal(&entry)) {
+ qinfo!("ConnectionIdStore found reused part in NEW_CONNECTION_ID");
+ return Err(Error::ProtocolViolation);
+ }
+
+ // Insert in order so that we use them in order where possible.
+ if let Err(idx) = self.cids.binary_search_by_key(&entry.seqno, |e| e.seqno) {
+ self.cids.insert(idx, entry);
+ Ok(())
+ } else {
+ Err(Error::ProtocolViolation)
+ }
+ }
+
+ // Retire connection IDs and return the sequence numbers of those that were retired.
+ pub fn retire_prior_to(&mut self, retire_prior: u64) -> Vec<u64> {
+ let mut retired = Vec::new();
+ self.cids.retain(|e| {
+ if e.seqno < retire_prior {
+ retired.push(e.seqno);
+ false
+ } else {
+ true
+ }
+ });
+ retired
+ }
+}
+
+impl ConnectionIdStore<()> {
+ fn add_local(&mut self, entry: ConnectionIdEntry<()>) {
+ self.cids.push(entry);
+ }
+}
+
+pub struct ConnectionIdDecoderRef<'a> {
+ generator: Ref<'a, dyn ConnectionIdGenerator>,
+}
+
+// Ideally this would be an implementation of `Deref`, but it doesn't
+// seem to be possible to convince the compiler to build anything useful.
+impl<'a: 'b, 'b> ConnectionIdDecoderRef<'a> {
+ pub fn as_ref(&'a self) -> &'b dyn ConnectionIdDecoder {
+ self.generator.as_decoder()
+ }
+}
+
+/// A connection ID manager looks after the generation of connection IDs,
+/// the set of connection IDs that are valid for the connection, and the
+/// generation of `NEW_CONNECTION_ID` frames.
+pub struct ConnectionIdManager {
+ /// The `ConnectionIdGenerator` instance that is used to create connection IDs.
+ generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ /// The connection IDs that we will accept.
+ /// This includes any we advertise in `NEW_CONNECTION_ID` that haven't been bound to a path
+ /// yet. During the handshake at the server, it also includes the randomized DCID pick by
+ /// the client.
+ connection_ids: ConnectionIdStore<()>,
+ /// The maximum number of connection IDs this will accept. This is at least 2 and won't
+ /// be more than `LOCAL_ACTIVE_CID_LIMIT`.
+ limit: usize,
+ /// The next sequence number that will be used for sending `NEW_CONNECTION_ID` frames.
+ next_seqno: u64,
+ /// Outstanding, but lost NEW_CONNECTION_ID frames will be stored here.
+ lost_new_connection_id: Vec<ConnectionIdEntry<[u8; 16]>>,
+}
+
+impl ConnectionIdManager {
+ pub fn new(generator: Rc<RefCell<dyn ConnectionIdGenerator>>, initial: ConnectionId) -> Self {
+ let mut connection_ids = ConnectionIdStore::default();
+ connection_ids.add_local(ConnectionIdEntry::initial_local(initial));
+ Self {
+ generator,
+ connection_ids,
+ // A note about initializing the limit to 2.
+ // For a server, the number of connection IDs that are tracked at the point that
+ // it is first possible to send `NEW_CONNECTION_ID` is 2. One is the client-generated
+ // destination connection (stored with a sequence number of `HANDSHAKE_SEQNO`); the
+ // other being the handshake value (seqno 0). As a result, `NEW_CONNECTION_ID`
+ // won't be sent until until after the handshake completes, because this initial
+ // value remains until the connection completes and transport parameters are handled.
+ limit: 2,
+ next_seqno: 1,
+ lost_new_connection_id: Vec::new(),
+ }
+ }
+
+ pub fn generator(&self) -> Rc<RefCell<dyn ConnectionIdGenerator>> {
+ Rc::clone(&self.generator)
+ }
+
+ pub fn decoder(&self) -> ConnectionIdDecoderRef {
+ ConnectionIdDecoderRef {
+ generator: self.generator.deref().borrow(),
+ }
+ }
+
+ /// Generate a connection ID and stateless reset token for a preferred address.
+ pub fn preferred_address_cid(&mut self) -> Res<(ConnectionId, [u8; 16])> {
+ if self.generator.deref().borrow().generates_empty_cids() {
+ return Err(Error::ConnectionIdsExhausted);
+ }
+ if let Some(cid) = self.generator.borrow_mut().generate_cid() {
+ assert_ne!(cid.len(), 0);
+ debug_assert_eq!(self.next_seqno, CONNECTION_ID_SEQNO_PREFERRED);
+ self.connection_ids
+ .add_local(ConnectionIdEntry::new(self.next_seqno, cid.clone(), ()));
+ self.next_seqno += 1;
+
+ let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap();
+ Ok((cid, srt))
+ } else {
+ Err(Error::ConnectionIdsExhausted)
+ }
+ }
+
+ pub fn is_valid(&self, cid: ConnectionIdRef) -> bool {
+ self.connection_ids.contains(cid)
+ }
+
+ pub fn retire(&mut self, seqno: u64) {
+ // TODO(mt) - consider keeping connection IDs around for a short while.
+
+ self.connection_ids.retire(seqno);
+ self.lost_new_connection_id.retain(|cid| cid.seqno != seqno);
+ }
+
+ /// During the handshake, a server needs to regard the client's choice of destination
+ /// connection ID as valid. This function saves it in the store in a special place.
+ /// Note that this is only done *after* an Initial packet from the client is
+ /// successfully processed.
+ pub fn add_odcid(&mut self, cid: ConnectionId) {
+ let entry = ConnectionIdEntry::new(CONNECTION_ID_SEQNO_ODCID, cid, ());
+ self.connection_ids.add_local(entry);
+ }
+
+ /// Stop treating the original destination connection ID as valid.
+ pub fn remove_odcid(&mut self) {
+ self.connection_ids.retire(CONNECTION_ID_SEQNO_ODCID);
+ }
+
+ pub fn set_limit(&mut self, limit: u64) {
+ debug_assert!(limit >= 2);
+ self.limit = min(
+ LOCAL_ACTIVE_CID_LIMIT,
+ usize::try_from(limit).unwrap_or(LOCAL_ACTIVE_CID_LIMIT),
+ );
+ }
+
+ fn write_entry(
+ &mut self,
+ entry: &ConnectionIdEntry<[u8; 16]>,
+ builder: &mut PacketBuilder,
+ stats: &mut FrameStats,
+ ) -> Res<bool> {
+ let len = 1 + Encoder::varint_len(entry.seqno) + 1 + 1 + entry.cid.len() + 16;
+ if builder.remaining() < len {
+ return Ok(false);
+ }
+
+ builder.encode_varint(FRAME_TYPE_NEW_CONNECTION_ID);
+ builder.encode_varint(entry.seqno);
+ builder.encode_varint(0u64);
+ builder.encode_vec(1, &entry.cid);
+ builder.encode(&entry.srt);
+ stats.new_connection_id += 1;
+ Ok(true)
+ }
+
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> Res<()> {
+ if self.generator.deref().borrow().generates_empty_cids() {
+ debug_assert_eq!(self.generator.borrow_mut().generate_cid().unwrap().len(), 0);
+ return Ok(());
+ }
+
+ while let Some(entry) = self.lost_new_connection_id.pop() {
+ if self.write_entry(&entry, builder, stats)? {
+ tokens.push(RecoveryToken::NewConnectionId(entry));
+ } else {
+ // This shouldn't happen often.
+ self.lost_new_connection_id.push(entry);
+ break;
+ }
+ }
+
+ // Keep writing while we have fewer than the limit of active connection IDs
+ // and while there is room for more. This uses the longest connection ID
+ // length to simplify (assuming Retire Prior To is just 1 byte).
+ while self.connection_ids.len() < self.limit && builder.remaining() >= 47 {
+ let maybe_cid = self.generator.borrow_mut().generate_cid();
+ if let Some(cid) = maybe_cid {
+ assert_ne!(cid.len(), 0);
+ // TODO: generate the stateless reset tokens from the connection ID and a key.
+ let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap();
+
+ let seqno = self.next_seqno;
+ self.next_seqno += 1;
+ self.connection_ids
+ .add_local(ConnectionIdEntry::new(seqno, cid.clone(), ()));
+
+ let entry = ConnectionIdEntry::new(seqno, cid, srt);
+ debug_assert!(self.write_entry(&entry, builder, stats)?);
+ tokens.push(RecoveryToken::NewConnectionId(entry));
+ }
+ }
+ Ok(())
+ }
+
+ pub fn lost(&mut self, entry: &ConnectionIdEntry<[u8; 16]>) {
+ self.lost_new_connection_id.push(entry.clone());
+ }
+
+ pub fn acked(&mut self, entry: &ConnectionIdEntry<[u8; 16]>) {
+ self.lost_new_connection_id
+ .retain(|e| e.seqno != entry.seqno);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use test_fixture::fixture_init;
+
+ use super::*;
+
+ #[test]
+ fn generate_initial_cid() {
+ fixture_init();
+ for _ in 0..100 {
+ let cid = ConnectionId::generate_initial();
+ if !matches!(cid.len(), 8..=MAX_CONNECTION_ID_LEN) {
+ panic!("connection ID {:?}", cid);
+ }
+ }
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/dump.rs b/third_party/rust/neqo-transport/src/connection/dump.rs
new file mode 100644
index 0000000000..77d51c605c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/dump.rs
@@ -0,0 +1,46 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Enable just this file for logging to just see packets.
+// e.g. "RUST_LOG=neqo_transport::dump neqo-client ..."
+
+use std::fmt::Write;
+
+use neqo_common::{qdebug, Decoder};
+
+use crate::{
+ connection::Connection,
+ frame::Frame,
+ packet::{PacketNumber, PacketType},
+ path::PathRef,
+};
+
+#[allow(clippy::module_name_repetitions)]
+pub fn dump_packet(
+ conn: &Connection,
+ path: &PathRef,
+ dir: &str,
+ pt: PacketType,
+ pn: PacketNumber,
+ payload: &[u8],
+) {
+ if !log::log_enabled!(log::Level::Debug) {
+ return;
+ }
+
+ let mut s = String::from("");
+ let mut d = Decoder::from(payload);
+ while d.remaining() > 0 {
+ let Ok(f) = Frame::decode(&mut d) else {
+ s.push_str(" [broken]...");
+ break;
+ };
+ if let Some(x) = f.dump() {
+ write!(&mut s, "\n {} {}", dir, &x).unwrap();
+ }
+ }
+ qdebug!([conn], "pn={} type={:?} {}{}", pn, pt, path.borrow(), s);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/idle.rs b/third_party/rust/neqo-transport/src/connection/idle.rs
new file mode 100644
index 0000000000..e33f3defb3
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/idle.rs
@@ -0,0 +1,120 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ cmp::{max, min},
+ time::{Duration, Instant},
+};
+
+use neqo_common::qtrace;
+
+use crate::recovery::RecoveryToken;
+
+#[derive(Debug, Clone)]
+/// There's a little bit of different behavior for resetting idle timeout. See
+/// -transport 10.2 ("Idle Timeout").
+enum IdleTimeoutState {
+ Init,
+ PacketReceived(Instant),
+ AckElicitingPacketSent(Instant),
+}
+
+#[derive(Debug, Clone)]
+/// There's a little bit of different behavior for resetting idle timeout. See
+/// -transport 10.2 ("Idle Timeout").
+pub struct IdleTimeout {
+ timeout: Duration,
+ state: IdleTimeoutState,
+ keep_alive_outstanding: bool,
+}
+
+impl IdleTimeout {
+ pub fn new(timeout: Duration) -> Self {
+ Self {
+ timeout,
+ state: IdleTimeoutState::Init,
+ keep_alive_outstanding: false,
+ }
+ }
+}
+
+impl IdleTimeout {
+ pub fn set_peer_timeout(&mut self, peer_timeout: Duration) {
+ self.timeout = min(self.timeout, peer_timeout);
+ }
+
+ pub fn expiry(&self, now: Instant, pto: Duration, keep_alive: bool) -> Instant {
+ let start = match self.state {
+ IdleTimeoutState::Init => now,
+ IdleTimeoutState::PacketReceived(t) | IdleTimeoutState::AckElicitingPacketSent(t) => t,
+ };
+ let delay = if keep_alive && !self.keep_alive_outstanding {
+ // For a keep-alive timer, wait for half the timeout interval, but be sure
+ // not to wait too little or we will send many unnecessary probes.
+ max(self.timeout / 2, pto)
+ } else {
+ max(self.timeout, pto * 3)
+ };
+ qtrace!(
+ "IdleTimeout::expiry@{now:?} pto={pto:?}, ka={keep_alive} => {t:?}",
+ t = start + delay
+ );
+ start + delay
+ }
+
+ pub fn on_packet_sent(&mut self, now: Instant) {
+ // Only reset idle timeout if we've received a packet since the last
+ // time we reset the timeout here.
+ match self.state {
+ IdleTimeoutState::AckElicitingPacketSent(_) => {}
+ IdleTimeoutState::Init | IdleTimeoutState::PacketReceived(_) => {
+ self.state = IdleTimeoutState::AckElicitingPacketSent(now);
+ }
+ }
+ }
+
+ pub fn on_packet_received(&mut self, now: Instant) {
+ // Only update if this doesn't rewind the idle timeout.
+ // We sometimes process packets after caching them, which uses
+ // the time the packet was received. That could be in the past.
+ let update = match self.state {
+ IdleTimeoutState::Init => true,
+ IdleTimeoutState::AckElicitingPacketSent(t) | IdleTimeoutState::PacketReceived(t) => {
+ t <= now
+ }
+ };
+ if update {
+ self.state = IdleTimeoutState::PacketReceived(now);
+ }
+ }
+
+ pub fn expired(&self, now: Instant, pto: Duration) -> bool {
+ now >= self.expiry(now, pto, false)
+ }
+
+ pub fn send_keep_alive(
+ &mut self,
+ now: Instant,
+ pto: Duration,
+ tokens: &mut Vec<RecoveryToken>,
+ ) -> bool {
+ if !self.keep_alive_outstanding && now >= self.expiry(now, pto, true) {
+ self.keep_alive_outstanding = true;
+ tokens.push(RecoveryToken::KeepAlive);
+ true
+ } else {
+ false
+ }
+ }
+
+ pub fn lost_keep_alive(&mut self) {
+ self.keep_alive_outstanding = false;
+ }
+
+ pub fn ack_keep_alive(&mut self) {
+ self.keep_alive_outstanding = false;
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs
new file mode 100644
index 0000000000..2de388418a
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/mod.rs
@@ -0,0 +1,3241 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// The class implementing a QUIC connection.
+
+use std::{
+ cell::RefCell,
+ cmp::{max, min},
+ convert::TryFrom,
+ fmt::{self, Debug},
+ mem,
+ net::{IpAddr, SocketAddr},
+ ops::RangeInclusive,
+ rc::{Rc, Weak},
+ time::{Duration, Instant},
+};
+
+use neqo_common::{
+ event::Provider as EventProvider, hex, hex_snip_middle, hrtime, qdebug, qerror, qinfo,
+ qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role,
+};
+use neqo_crypto::{
+ agent::CertificateInfo, random, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group,
+ HandshakeState, PrivateKey, PublicKey, ResumptionToken, SecretAgentInfo, SecretAgentPreInfo,
+ Server, ZeroRttChecker,
+};
+use smallvec::SmallVec;
+
+use crate::{
+ addr_valid::{AddressValidation, NewTokenState},
+ cid::{
+ ConnectionId, ConnectionIdEntry, ConnectionIdGenerator, ConnectionIdManager,
+ ConnectionIdRef, ConnectionIdStore, LOCAL_ACTIVE_CID_LIMIT,
+ },
+ crypto::{Crypto, CryptoDxState, CryptoSpace},
+ events::{ConnectionEvent, ConnectionEvents, OutgoingDatagramOutcome},
+ frame::{
+ CloseError, Frame, FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION,
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
+ },
+ packet::{DecryptedPacket, PacketBuilder, PacketNumber, PacketType, PublicPacket},
+ path::{Path, PathRef, Paths},
+ qlog,
+ quic_datagrams::{DatagramTracking, QuicDatagrams},
+ recovery::{LossRecovery, RecoveryToken, SendProfile},
+ recv_stream::RecvStreamStats,
+ rtt::GRANULARITY,
+ stats::{Stats, StatsCell},
+ stream_id::StreamType,
+ streams::{SendOrder, Streams},
+ tparams::{
+ self, TransportParameter, TransportParameterId, TransportParameters,
+ TransportParametersHandler,
+ },
+ tracking::{AckTracker, PacketNumberSpace, SentPacket},
+ version::{Version, WireVersion},
+ AppError, ConnectionError, Error, Res, StreamId,
+};
+mod dump;
+mod idle;
+pub mod params;
+mod saved;
+mod state;
+#[cfg(test)]
+pub mod test_internal;
+use dump::dump_packet;
+use idle::IdleTimeout;
+pub use params::ConnectionParameters;
+use params::PreferredAddressConfig;
+#[cfg(test)]
+pub use params::ACK_RATIO_SCALE;
+use saved::SavedDatagrams;
+use state::StateSignaling;
+pub use state::{ClosingFrame, State};
+
+pub use crate::send_stream::{RetransmissionPriority, SendStreamStats, TransmissionPriority};
+
+#[derive(Debug, Default)]
+struct Packet(Vec<u8>);
+
+/// The number of Initial packets that the client will send in response
+/// to receiving an undecryptable packet during the early part of the
+/// handshake. This is a hack, but a useful one.
+const EXTRA_INITIALS: usize = 4;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum ZeroRttState {
+ Init,
+ Sending,
+ AcceptedClient,
+ AcceptedServer,
+ Rejected,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+/// Type returned from process() and `process_output()`. Users are required to
+/// call these repeatedly until `Callback` or `None` is returned.
+pub enum Output {
+ /// Connection requires no action.
+ None,
+ /// Connection requires the datagram be sent.
+ Datagram(Datagram),
+ /// Connection requires `process_input()` be called when the `Duration`
+ /// elapses.
+ Callback(Duration),
+}
+
+impl Output {
+ /// Convert into an `Option<Datagram>`.
+ #[must_use]
+ pub fn dgram(self) -> Option<Datagram> {
+ match self {
+ Self::Datagram(dg) => Some(dg),
+ _ => None,
+ }
+ }
+
+ /// Get a reference to the Datagram, if any.
+ pub fn as_dgram_ref(&self) -> Option<&Datagram> {
+ match self {
+ Self::Datagram(dg) => Some(dg),
+ _ => None,
+ }
+ }
+
+ /// Ask how long the caller should wait before calling back.
+ #[must_use]
+ pub fn callback(&self) -> Duration {
+ match self {
+ Self::Callback(t) => *t,
+ _ => Duration::new(0, 0),
+ }
+ }
+}
+
+/// Used by inner functions like Connection::output.
+enum SendOption {
+ /// Yes, please send this datagram.
+ Yes(Datagram),
+ /// Don't send. If this was blocked on the pacer (the arg is true).
+ No(bool),
+}
+
+impl Default for SendOption {
+ fn default() -> Self {
+ Self::No(false)
+ }
+}
+
+/// Used by `Connection::preprocess` to determine what to do
+/// with an packet before attempting to remove protection.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum PreprocessResult {
+ /// End processing and return successfully.
+ End,
+ /// Stop processing this datagram and move on to the next.
+ Next,
+ /// Continue and process this packet.
+ Continue,
+}
+
+/// `AddressValidationInfo` holds information relevant to either
+/// responding to address validation (`NewToken`, `Retry`) or generating
+/// tokens for address validation (`Server`).
+enum AddressValidationInfo {
+ None,
+ // We are a client and have information from `NEW_TOKEN`.
+ NewToken(Vec<u8>),
+ // We are a client and have received a `Retry` packet.
+ Retry {
+ token: Vec<u8>,
+ retry_source_cid: ConnectionId,
+ },
+ // We are a server and can generate tokens.
+ Server(Weak<RefCell<AddressValidation>>),
+}
+
+impl AddressValidationInfo {
+ pub fn token(&self) -> &[u8] {
+ match self {
+ Self::NewToken(token) | Self::Retry { token, .. } => token,
+ _ => &[],
+ }
+ }
+
+ pub fn generate_new_token(
+ &mut self,
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> Option<Vec<u8>> {
+ match self {
+ Self::Server(ref w) => {
+ if let Some(validation) = w.upgrade() {
+ validation
+ .borrow()
+ .generate_new_token(peer_address, now)
+ .ok()
+ } else {
+ None
+ }
+ }
+ Self::None => None,
+ _ => unreachable!("called a server function on a client"),
+ }
+ }
+}
+
+/// A QUIC Connection
+///
+/// First, create a new connection using `new_client()` or `new_server()`.
+///
+/// For the life of the connection, handle activity in the following manner:
+/// 1. Perform operations using the `stream_*()` methods.
+/// 1. Call `process_input()` when a datagram is received or the timer
+/// expires. Obtain information on connection state changes by checking
+/// `events()`.
+/// 1. Having completed handling current activity, repeatedly call
+/// `process_output()` for packets to send, until it returns `Output::Callback`
+/// or `Output::None`.
+///
+/// After the connection is closed (either by calling `close()` or by the
+/// remote) continue processing until `state()` returns `Closed`.
+pub struct Connection {
+ role: Role,
+ version: Version,
+ state: State,
+ tps: Rc<RefCell<TransportParametersHandler>>,
+ /// What we are doing with 0-RTT.
+ zero_rtt_state: ZeroRttState,
+ /// All of the network paths that we are aware of.
+ paths: Paths,
+ /// This object will generate connection IDs for the connection.
+ cid_manager: ConnectionIdManager,
+ address_validation: AddressValidationInfo,
+ /// The connection IDs that were provided by the peer.
+ connection_ids: ConnectionIdStore<[u8; 16]>,
+
+ /// The source connection ID that this endpoint uses for the handshake.
+ /// Since we need to communicate this to our peer in tparams, setting this
+ /// value is part of constructing the struct.
+ local_initial_source_cid: ConnectionId,
+ /// The source connection ID from the first packet from the other end.
+ /// This is checked against the peer's transport parameters.
+ remote_initial_source_cid: Option<ConnectionId>,
+ /// The destination connection ID from the first packet from the client.
+ /// This is checked by the client against the server's transport parameters.
+ original_destination_cid: Option<ConnectionId>,
+
+ /// We sometimes save a datagram against the possibility that keys will later
+ /// become available. This avoids reporting packets as dropped during the handshake
+ /// when they are either just reordered or we haven't been able to install keys yet.
+ /// In particular, this occurs when asynchronous certificate validation happens.
+ saved_datagrams: SavedDatagrams,
+ /// Some packets were received, but not tracked.
+ received_untracked: bool,
+
+ /// This is responsible for the QuicDatagrams' handling:
+ /// <https://datatracker.ietf.org/doc/html/draft-ietf-quic-datagram>
+ quic_datagrams: QuicDatagrams,
+
+ pub(crate) crypto: Crypto,
+ pub(crate) acks: AckTracker,
+ idle_timeout: IdleTimeout,
+ streams: Streams,
+ state_signaling: StateSignaling,
+ loss_recovery: LossRecovery,
+ events: ConnectionEvents,
+ new_token: NewTokenState,
+ stats: StatsCell,
+ qlog: NeqoQlog,
+ /// A session ticket was received without NEW_TOKEN,
+ /// this is when that turns into an event without NEW_TOKEN.
+ release_resumption_token_timer: Option<Instant>,
+ conn_params: ConnectionParameters,
+ hrtime: hrtime::Handle,
+
+ /// For testing purposes it is sometimes necessary to inject frames that wouldn't
+ /// otherwise be sent, just to see how a connection handles them. Inserting them
+ /// into packets proper mean that the frames follow the entire processing path.
+ #[cfg(test)]
+ pub test_frame_writer: Option<Box<dyn test_internal::FrameWriter>>,
+}
+
+impl Debug for Connection {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{:?} Connection: {:?} {:?}",
+ self.role,
+ self.state,
+ self.paths.primary_fallible()
+ )
+ }
+}
+
+impl Connection {
+ /// A long default for timer resolution, so that we don't tax the
+ /// system too hard when we don't need to.
+ const LOOSE_TIMER_RESOLUTION: Duration = Duration::from_millis(50);
+
+ /// Create a new QUIC connection with Client role.
+ pub fn new_client(
+ server_name: impl Into<String>,
+ protocols: &[impl AsRef<str>],
+ cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ local_addr: SocketAddr,
+ remote_addr: SocketAddr,
+ conn_params: ConnectionParameters,
+ now: Instant,
+ ) -> Res<Self> {
+ let dcid = ConnectionId::generate_initial();
+ let mut c = Self::new(
+ Role::Client,
+ Agent::from(Client::new(server_name.into(), conn_params.is_greasing())?),
+ cid_generator,
+ protocols,
+ conn_params,
+ )?;
+ c.crypto.states.init(
+ c.conn_params.get_versions().compatible(),
+ Role::Client,
+ &dcid,
+ );
+ c.original_destination_cid = Some(dcid);
+ let path = Path::temporary(
+ local_addr,
+ remote_addr,
+ c.conn_params.get_cc_algorithm(),
+ c.conn_params.pacing_enabled(),
+ NeqoQlog::default(),
+ now,
+ );
+ c.setup_handshake_path(&Rc::new(RefCell::new(path)), now);
+ Ok(c)
+ }
+
+ /// Create a new QUIC connection with Server role.
+ pub fn new_server(
+ certs: &[impl AsRef<str>],
+ protocols: &[impl AsRef<str>],
+ cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ conn_params: ConnectionParameters,
+ ) -> Res<Self> {
+ Self::new(
+ Role::Server,
+ Agent::from(Server::new(certs)?),
+ cid_generator,
+ protocols,
+ conn_params,
+ )
+ }
+
+ fn new<P: AsRef<str>>(
+ role: Role,
+ agent: Agent,
+ cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ protocols: &[P],
+ conn_params: ConnectionParameters,
+ ) -> Res<Self> {
+ // Setup the local connection ID.
+ let local_initial_source_cid = cid_generator
+ .borrow_mut()
+ .generate_cid()
+ .ok_or(Error::ConnectionIdsExhausted)?;
+ let mut cid_manager =
+ ConnectionIdManager::new(cid_generator, local_initial_source_cid.clone());
+ let mut tps = conn_params.create_transport_parameter(role, &mut cid_manager)?;
+ tps.local.set_bytes(
+ tparams::INITIAL_SOURCE_CONNECTION_ID,
+ local_initial_source_cid.to_vec(),
+ );
+
+ let tphandler = Rc::new(RefCell::new(tps));
+ let crypto = Crypto::new(
+ conn_params.get_versions().initial(),
+ agent,
+ protocols.iter().map(P::as_ref).map(String::from).collect(),
+ Rc::clone(&tphandler),
+ conn_params.is_fuzzing(),
+ )?;
+
+ let stats = StatsCell::default();
+ let events = ConnectionEvents::default();
+ let quic_datagrams = QuicDatagrams::new(
+ conn_params.get_datagram_size(),
+ conn_params.get_outgoing_datagram_queue(),
+ conn_params.get_incoming_datagram_queue(),
+ events.clone(),
+ );
+
+ let c = Self {
+ role,
+ version: conn_params.get_versions().initial(),
+ state: State::Init,
+ paths: Paths::default(),
+ cid_manager,
+ tps: tphandler.clone(),
+ zero_rtt_state: ZeroRttState::Init,
+ address_validation: AddressValidationInfo::None,
+ local_initial_source_cid,
+ remote_initial_source_cid: None,
+ original_destination_cid: None,
+ saved_datagrams: SavedDatagrams::default(),
+ received_untracked: false,
+ crypto,
+ acks: AckTracker::default(),
+ idle_timeout: IdleTimeout::new(conn_params.get_idle_timeout()),
+ streams: Streams::new(tphandler, role, events.clone()),
+ connection_ids: ConnectionIdStore::default(),
+ state_signaling: StateSignaling::Idle,
+ loss_recovery: LossRecovery::new(stats.clone(), conn_params.get_fast_pto()),
+ events,
+ new_token: NewTokenState::new(role),
+ stats,
+ qlog: NeqoQlog::disabled(),
+ release_resumption_token_timer: None,
+ conn_params,
+ hrtime: hrtime::Time::get(Self::LOOSE_TIMER_RESOLUTION),
+ quic_datagrams,
+ #[cfg(test)]
+ test_frame_writer: None,
+ };
+ c.stats.borrow_mut().init(format!("{c}"));
+ Ok(c)
+ }
+
+ pub fn server_enable_0rtt(
+ &mut self,
+ anti_replay: &AntiReplay,
+ zero_rtt_checker: impl ZeroRttChecker + 'static,
+ ) -> Res<()> {
+ self.crypto
+ .server_enable_0rtt(self.tps.clone(), anti_replay, zero_rtt_checker)
+ }
+
+ pub fn server_enable_ech(
+ &mut self,
+ config: u8,
+ public_name: &str,
+ sk: &PrivateKey,
+ pk: &PublicKey,
+ ) -> Res<()> {
+ self.crypto.server_enable_ech(config, public_name, sk, pk)
+ }
+
+ /// Get the active ECH configuration, which is empty if ECH is disabled.
+ pub fn ech_config(&self) -> &[u8] {
+ self.crypto.ech_config()
+ }
+
+ pub fn client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> {
+ self.crypto.client_enable_ech(ech_config_list)
+ }
+
+ /// Set or clear the qlog for this connection.
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.loss_recovery.set_qlog(qlog.clone());
+ self.paths.set_qlog(qlog.clone());
+ self.qlog = qlog;
+ }
+
+ /// Get the qlog (if any) for this connection.
+ pub fn qlog_mut(&mut self) -> &mut NeqoQlog {
+ &mut self.qlog
+ }
+
+ /// Get the original destination connection id for this connection. This
+ /// will always be present for Role::Client but not if Role::Server is in
+ /// State::Init.
+ pub fn odcid(&self) -> Option<&ConnectionId> {
+ self.original_destination_cid.as_ref()
+ }
+
+ /// Set a local transport parameter, possibly overriding a default value.
+ /// This only sets transport parameters without dealing with other aspects of
+ /// setting the value.
+ ///
+ /// # Panics
+ ///
+ /// This panics if the transport parameter is known to this crate.
+ pub fn set_local_tparam(&self, tp: TransportParameterId, value: TransportParameter) -> Res<()> {
+ #[cfg(not(test))]
+ {
+ assert!(!tparams::INTERNAL_TRANSPORT_PARAMETERS.contains(&tp));
+ }
+ if *self.state() == State::Init {
+ self.tps.borrow_mut().local.set(tp, value);
+ Ok(())
+ } else {
+ qerror!("Current state: {:?}", self.state());
+ qerror!("Cannot set local tparam when not in an initial connection state.");
+ Err(Error::ConnectionState)
+ }
+ }
+
+ /// `odcid` is their original choice for our CID, which we get from the Retry token.
+ /// `remote_cid` is the value from the Source Connection ID field of an incoming packet: what
+ /// the peer wants us to use now. `retry_cid` is what we asked them to use when we sent the
+ /// Retry.
+ pub(crate) fn set_retry_cids(
+ &mut self,
+ odcid: ConnectionId,
+ remote_cid: ConnectionId,
+ retry_cid: ConnectionId,
+ ) {
+ debug_assert_eq!(self.role, Role::Server);
+ qtrace!(
+ [self],
+ "Retry CIDs: odcid={} remote={} retry={}",
+ odcid,
+ remote_cid,
+ retry_cid
+ );
+ // We advertise "our" choices in transport parameters.
+ let local_tps = &mut self.tps.borrow_mut().local;
+ local_tps.set_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID, odcid.to_vec());
+ local_tps.set_bytes(tparams::RETRY_SOURCE_CONNECTION_ID, retry_cid.to_vec());
+
+ // ...and save their choices for later validation.
+ self.remote_initial_source_cid = Some(remote_cid);
+ }
+
+ fn retry_sent(&self) -> bool {
+ self.tps
+ .borrow()
+ .local
+ .get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID)
+ .is_some()
+ }
+
+ /// Set ALPN preferences. Strings that appear earlier in the list are given
+ /// higher preference.
+ pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> {
+ self.crypto.tls.set_alpn(protocols)?;
+ Ok(())
+ }
+
+ /// Enable a set of ciphers.
+ pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> {
+ if self.state != State::Init {
+ qerror!([self], "Cannot enable ciphers in state {:?}", self.state);
+ return Err(Error::ConnectionState);
+ }
+ self.crypto.tls.set_ciphers(ciphers)?;
+ Ok(())
+ }
+
+ /// Enable a set of key exchange groups.
+ pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> {
+ if self.state != State::Init {
+ qerror!([self], "Cannot enable groups in state {:?}", self.state);
+ return Err(Error::ConnectionState);
+ }
+ self.crypto.tls.set_groups(groups)?;
+ Ok(())
+ }
+
+ /// Set the number of additional key shares to send in the client hello.
+ pub fn send_additional_key_shares(&mut self, count: usize) -> Res<()> {
+ if self.state != State::Init {
+ qerror!([self], "Cannot enable groups in state {:?}", self.state);
+ return Err(Error::ConnectionState);
+ }
+ self.crypto.tls.send_additional_key_shares(count)?;
+ Ok(())
+ }
+
+ fn make_resumption_token(&mut self) -> ResumptionToken {
+ debug_assert_eq!(self.role, Role::Client);
+ debug_assert!(self.crypto.has_resumption_token());
+ let rtt = self.paths.primary().borrow().rtt().estimate();
+ self.crypto
+ .create_resumption_token(
+ self.new_token.take_token(),
+ self.tps
+ .borrow()
+ .remote
+ .as_ref()
+ .expect("should have transport parameters"),
+ self.version,
+ u64::try_from(rtt.as_millis()).unwrap_or(0),
+ )
+ .unwrap()
+ }
+
+ /// Get the simplest PTO calculation for all those cases where we need
+ /// a value of this approximate order. Don't use this for loss recovery,
+ /// only use it where a more precise value is not important.
+ fn pto(&self) -> Duration {
+ self.paths
+ .primary()
+ .borrow()
+ .rtt()
+ .pto(PacketNumberSpace::ApplicationData)
+ }
+
+ fn create_resumption_token(&mut self, now: Instant) {
+ if self.role == Role::Server || self.state < State::Connected {
+ return;
+ }
+
+ qtrace!(
+ [self],
+ "Maybe create resumption token: {} {}",
+ self.crypto.has_resumption_token(),
+ self.new_token.has_token()
+ );
+
+ while self.crypto.has_resumption_token() && self.new_token.has_token() {
+ let token = self.make_resumption_token();
+ self.events.client_resumption_token(token);
+ }
+
+ // If we have a resumption ticket check or set a timer.
+ if self.crypto.has_resumption_token() {
+ let arm = if let Some(expiration_time) = self.release_resumption_token_timer {
+ if expiration_time <= now {
+ let token = self.make_resumption_token();
+ self.events.client_resumption_token(token);
+ self.release_resumption_token_timer = None;
+
+ // This means that we release one session ticket every 3 PTOs
+ // if no NEW_TOKEN frame is received.
+ self.crypto.has_resumption_token()
+ } else {
+ false
+ }
+ } else {
+ true
+ };
+
+ if arm {
+ self.release_resumption_token_timer = Some(now + 3 * self.pto());
+ }
+ }
+ }
+
+ /// The correct way to obtain a resumption token is to wait for the
+ /// `ConnectionEvent::ResumptionToken` event. To emit the event we are waiting for a
+ /// resumption token and a `NEW_TOKEN` frame to arrive. Some servers don't send `NEW_TOKEN`
+ /// frames and in this case, we wait for 3xPTO before emitting an event. This is especially a
+ /// problem for short-lived connections, where the connection is closed before any events are
+ /// released. This function retrieves the token, without waiting for a `NEW_TOKEN` frame to
+ /// arrive.
+ ///
+ /// # Panics
+ ///
+ /// If this is called on a server.
+ pub fn take_resumption_token(&mut self, now: Instant) -> Option<ResumptionToken> {
+ assert_eq!(self.role, Role::Client);
+
+ if self.crypto.has_resumption_token() {
+ let token = self.make_resumption_token();
+ if self.crypto.has_resumption_token() {
+ self.release_resumption_token_timer = Some(now + 3 * self.pto());
+ }
+ Some(token)
+ } else {
+ None
+ }
+ }
+
+ /// Enable resumption, using a token previously provided.
+ /// This can only be called once and only on the client.
+ /// After calling the function, it should be possible to attempt 0-RTT
+ /// if the token supports that.
+ pub fn enable_resumption(&mut self, now: Instant, token: impl AsRef<[u8]>) -> Res<()> {
+ if self.state != State::Init {
+ qerror!([self], "set token in state {:?}", self.state);
+ return Err(Error::ConnectionState);
+ }
+ if self.role == Role::Server {
+ return Err(Error::ConnectionState);
+ }
+
+ qinfo!(
+ [self],
+ "resumption token {}",
+ hex_snip_middle(token.as_ref())
+ );
+ let mut dec = Decoder::from(token.as_ref());
+
+ let version =
+ Version::try_from(dec.decode_uint(4).ok_or(Error::InvalidResumptionToken)? as u32)?;
+ qtrace!([self], " version {:?}", version);
+ if !self.conn_params.get_versions().all().contains(&version) {
+ return Err(Error::DisabledVersion);
+ }
+
+ let rtt = Duration::from_millis(dec.decode_varint().ok_or(Error::InvalidResumptionToken)?);
+ qtrace!([self], " RTT {:?}", rtt);
+
+ let tp_slice = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?;
+ qtrace!([self], " transport parameters {}", hex(tp_slice));
+ let mut dec_tp = Decoder::from(tp_slice);
+ let tp =
+ TransportParameters::decode(&mut dec_tp).map_err(|_| Error::InvalidResumptionToken)?;
+
+ let init_token = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?;
+ qtrace!([self], " Initial token {}", hex(init_token));
+
+ let tok = dec.decode_remainder();
+ qtrace!([self], " TLS token {}", hex(tok));
+
+ match self.crypto.tls {
+ Agent::Client(ref mut c) => {
+ let res = c.enable_resumption(tok);
+ if let Err(e) = res {
+ self.absorb_error::<Error>(now, Err(Error::from(e)));
+ return Ok(());
+ }
+ }
+ Agent::Server(_) => return Err(Error::WrongRole),
+ }
+
+ self.version = version;
+ self.conn_params.get_versions_mut().set_initial(version);
+ self.tps.borrow_mut().set_version(version);
+ self.tps.borrow_mut().remote_0rtt = Some(tp);
+ if !init_token.is_empty() {
+ self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec());
+ }
+ self.paths.primary().borrow_mut().rtt_mut().set_initial(rtt);
+ self.set_initial_limits();
+ // Start up TLS, which has the effect of setting up all the necessary
+ // state for 0-RTT. This only stages the CRYPTO frames.
+ let res = self.client_start(now);
+ self.absorb_error(now, res);
+ Ok(())
+ }
+
+ pub(crate) fn set_validation(&mut self, validation: Rc<RefCell<AddressValidation>>) {
+ qtrace!([self], "Enabling NEW_TOKEN");
+ assert_eq!(self.role, Role::Server);
+ self.address_validation = AddressValidationInfo::Server(Rc::downgrade(&validation));
+ }
+
+ /// Send a TLS session ticket AND a NEW_TOKEN frame (if possible).
+ pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<()> {
+ if self.role == Role::Client {
+ return Err(Error::WrongRole);
+ }
+
+ let tps = &self.tps;
+ if let Agent::Server(ref mut s) = self.crypto.tls {
+ let mut enc = Encoder::default();
+ enc.encode_vvec_with(|enc_inner| {
+ tps.borrow().local.encode(enc_inner);
+ });
+ enc.encode(extra);
+ let records = s.send_ticket(now, enc.as_ref())?;
+ qinfo!([self], "send session ticket {}", hex(&enc));
+ self.crypto.buffer_records(records)?;
+ } else {
+ unreachable!();
+ }
+
+ // If we are able, also send a NEW_TOKEN frame.
+ // This should be recording all remote addresses that are valid,
+ // but there are just 0 or 1 in the current implementation.
+ if let Some(path) = self.paths.primary_fallible() {
+ if let Some(token) = self
+ .address_validation
+ .generate_new_token(path.borrow().remote_address(), now)
+ {
+ self.new_token.send_new_token(token);
+ }
+ Ok(())
+ } else {
+ Err(Error::NotConnected)
+ }
+ }
+
+ pub fn tls_info(&self) -> Option<&SecretAgentInfo> {
+ self.crypto.tls.info()
+ }
+
+ pub fn tls_preinfo(&self) -> Res<SecretAgentPreInfo> {
+ Ok(self.crypto.tls.preinfo()?)
+ }
+
+ /// Get the peer's certificate chain and other info.
+ pub fn peer_certificate(&self) -> Option<CertificateInfo> {
+ self.crypto.tls.peer_certificate()
+ }
+
+ /// Call by application when the peer cert has been verified.
+ ///
+ /// This panics if there is no active peer. It's OK to call this
+ /// when authentication isn't needed, that will likely only cause
+ /// the connection to fail. However, if no packets have been
+ /// exchanged, it's not OK.
+ pub fn authenticated(&mut self, status: AuthenticationStatus, now: Instant) {
+ qinfo!([self], "Authenticated {:?}", status);
+ self.crypto.tls.authenticated(status);
+ let res = self.handshake(now, self.version, PacketNumberSpace::Handshake, None);
+ self.absorb_error(now, res);
+ self.process_saved(now);
+ }
+
+ /// Get the role of the connection.
+ pub fn role(&self) -> Role {
+ self.role
+ }
+
+ /// Get the state of the connection.
+ pub fn state(&self) -> &State {
+ &self.state
+ }
+
+ /// The QUIC version in use.
+ pub fn version(&self) -> Version {
+ self.version
+ }
+
+ /// Get the 0-RTT state of the connection.
+ pub fn zero_rtt_state(&self) -> ZeroRttState {
+ self.zero_rtt_state
+ }
+
+ /// Get a snapshot of collected statistics.
+ pub fn stats(&self) -> Stats {
+ let mut v = self.stats.borrow().clone();
+ if let Some(p) = self.paths.primary_fallible() {
+ let p = p.borrow();
+ v.rtt = p.rtt().estimate();
+ v.rttvar = p.rtt().rttvar();
+ }
+ v
+ }
+
+ // This function wraps a call to another function and sets the connection state
+ // properly if that call fails.
+ fn capture_error<T>(
+ &mut self,
+ path: Option<PathRef>,
+ now: Instant,
+ frame_type: FrameType,
+ res: Res<T>,
+ ) -> Res<T> {
+ if let Err(v) = &res {
+ #[cfg(debug_assertions)]
+ let msg = format!("{v:?}");
+ #[cfg(not(debug_assertions))]
+ let msg = "";
+ let error = ConnectionError::Transport(v.clone());
+ match &self.state {
+ State::Closing { error: err, .. }
+ | State::Draining { error: err, .. }
+ | State::Closed(err) => {
+ qwarn!([self], "Closing again after error {:?}", err);
+ }
+ State::Init => {
+ // We have not even sent anything just close the connection without sending any
+ // error. This may happen when client_start fails.
+ self.set_state(State::Closed(error));
+ }
+ State::WaitInitial => {
+ // We don't have any state yet, so don't bother with
+ // the closing state, just send one CONNECTION_CLOSE.
+ if let Some(path) = path.or_else(|| self.paths.primary_fallible()) {
+ self.state_signaling
+ .close(path, error.clone(), frame_type, msg);
+ }
+ self.set_state(State::Closed(error));
+ }
+ _ => {
+ if let Some(path) = path.or_else(|| self.paths.primary_fallible()) {
+ self.state_signaling
+ .close(path, error.clone(), frame_type, msg);
+ if matches!(v, Error::KeysExhausted) {
+ self.set_state(State::Closed(error));
+ } else {
+ self.set_state(State::Closing {
+ error,
+ timeout: self.get_closing_period_time(now),
+ });
+ }
+ } else {
+ self.set_state(State::Closed(error));
+ }
+ }
+ }
+ }
+ res
+ }
+
+ /// For use with process_input(). Errors there can be ignored, but this
+ /// needs to ensure that the state is updated.
+ fn absorb_error<T>(&mut self, now: Instant, res: Res<T>) -> Option<T> {
+ self.capture_error(None, now, 0, res).ok()
+ }
+
+ fn process_timer(&mut self, now: Instant) {
+ match &self.state {
+ // Only the client runs timers while waiting for Initial packets.
+ State::WaitInitial => debug_assert_eq!(self.role, Role::Client),
+ // If Closing or Draining, check if it is time to move to Closed.
+ State::Closing { error, timeout } | State::Draining { error, timeout } => {
+ if *timeout <= now {
+ let st = State::Closed(error.clone());
+ self.set_state(st);
+ qinfo!("Closing timer expired");
+ return;
+ }
+ }
+ State::Closed(_) => {
+ qdebug!("Timer fired while closed");
+ return;
+ }
+ _ => (),
+ }
+
+ let pto = self.pto();
+ if self.idle_timeout.expired(now, pto) {
+ qinfo!([self], "idle timeout expired");
+ self.set_state(State::Closed(ConnectionError::Transport(
+ Error::IdleTimeout,
+ )));
+ return;
+ }
+
+ self.streams.cleanup_closed_streams();
+
+ let res = self.crypto.states.check_key_update(now);
+ self.absorb_error(now, res);
+
+ let lost = self.loss_recovery.timeout(&self.paths.primary(), now);
+ self.handle_lost_packets(&lost);
+ qlog::packets_lost(&mut self.qlog, &lost);
+
+ if self.release_resumption_token_timer.is_some() {
+ self.create_resumption_token(now);
+ }
+
+ if !self.paths.process_timeout(now, pto) {
+ qinfo!([self], "last available path failed");
+ self.absorb_error::<Error>(now, Err(Error::NoAvailablePath));
+ }
+ }
+
+ /// Process new input datagrams on the connection.
+ pub fn process_input(&mut self, d: &Datagram, now: Instant) {
+ self.input(d, now, now);
+ self.process_saved(now);
+ self.streams.cleanup_closed_streams();
+ }
+
+ /// Process new input datagrams on the connection.
+ pub fn process_multiple_input<'a, I>(&mut self, dgrams: I, now: Instant)
+ where
+ I: IntoIterator<Item = &'a Datagram>,
+ I::IntoIter: ExactSizeIterator,
+ {
+ let dgrams = dgrams.into_iter();
+ if dgrams.len() == 0 {
+ return;
+ }
+
+ for d in dgrams {
+ self.input(d, now, now);
+ }
+ self.process_saved(now);
+ self.streams.cleanup_closed_streams();
+ }
+
+ /// Get the time that we next need to be called back, relative to `now`.
+ fn next_delay(&mut self, now: Instant, paced: bool) -> Duration {
+ qtrace!([self], "Get callback delay {:?}", now);
+
+ // Only one timer matters when closing...
+ if let State::Closing { timeout, .. } | State::Draining { timeout, .. } = self.state {
+ self.hrtime.update(Self::LOOSE_TIMER_RESOLUTION);
+ return timeout.duration_since(now);
+ }
+
+ let mut delays = SmallVec::<[_; 6]>::new();
+ if let Some(ack_time) = self.acks.ack_time(now) {
+ qtrace!([self], "Delayed ACK timer {:?}", ack_time);
+ delays.push(ack_time);
+ }
+
+ if let Some(p) = self.paths.primary_fallible() {
+ let path = p.borrow();
+ let rtt = path.rtt();
+ let pto = rtt.pto(PacketNumberSpace::ApplicationData);
+
+ let keep_alive = self.streams.need_keep_alive();
+ let idle_time = self.idle_timeout.expiry(now, pto, keep_alive);
+ qtrace!([self], "Idle/keepalive timer {:?}", idle_time);
+ delays.push(idle_time);
+
+ if let Some(lr_time) = self.loss_recovery.next_timeout(rtt) {
+ qtrace!([self], "Loss recovery timer {:?}", lr_time);
+ delays.push(lr_time);
+ }
+
+ if paced {
+ if let Some(pace_time) = path.sender().next_paced(rtt.estimate()) {
+ qtrace!([self], "Pacing timer {:?}", pace_time);
+ delays.push(pace_time);
+ }
+ }
+
+ if let Some(path_time) = self.paths.next_timeout(pto) {
+ qtrace!([self], "Path probe timer {:?}", path_time);
+ delays.push(path_time);
+ }
+ }
+
+ if let Some(key_update_time) = self.crypto.states.update_time() {
+ qtrace!([self], "Key update timer {:?}", key_update_time);
+ delays.push(key_update_time);
+ }
+
+ // `release_resumption_token_timer` is not considered here, because
+ // it is not important enough to force the application to set a
+ // timeout for it It is expected that other activities will
+ // drive it.
+
+ let earliest = delays.into_iter().min().unwrap();
+ // TODO(agrover, mt) - need to analyze and fix #47
+ // rather than just clamping to zero here.
+ debug_assert!(earliest > now);
+ let delay = earliest.saturating_duration_since(now);
+ qdebug!([self], "delay duration {:?}", delay);
+ self.hrtime.update(delay / 4);
+ delay
+ }
+
+ /// Get output packets, as a result of receiving packets, or actions taken
+ /// by the application.
+ /// Returns datagrams to send, and how long to wait before calling again
+ /// even if no incoming packets.
+ #[must_use = "Output of the process_output function must be handled"]
+ pub fn process_output(&mut self, now: Instant) -> Output {
+ qtrace!([self], "process_output {:?} {:?}", self.state, now);
+
+ match (&self.state, self.role) {
+ (State::Init, Role::Client) => {
+ let res = self.client_start(now);
+ self.absorb_error(now, res);
+ }
+ (State::Init | State::WaitInitial, Role::Server) => {
+ return Output::None;
+ }
+ _ => {
+ self.process_timer(now);
+ }
+ }
+
+ match self.output(now) {
+ SendOption::Yes(dgram) => Output::Datagram(dgram),
+ SendOption::No(paced) => match self.state {
+ State::Init | State::Closed(_) => Output::None,
+ State::Closing { timeout, .. } | State::Draining { timeout, .. } => {
+ Output::Callback(timeout.duration_since(now))
+ }
+ _ => Output::Callback(self.next_delay(now, paced)),
+ },
+ }
+ }
+
+ /// Process input and generate output.
+ #[must_use = "Output of the process function must be handled"]
+ pub fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output {
+ if let Some(d) = dgram {
+ self.input(d, now, now);
+ self.process_saved(now);
+ }
+ self.process_output(now)
+ }
+
+ fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) {
+ qinfo!([self], "received Retry");
+ if matches!(self.address_validation, AddressValidationInfo::Retry { .. }) {
+ self.stats.borrow_mut().pkt_dropped("Extra Retry");
+ return;
+ }
+ if packet.token().is_empty() {
+ self.stats.borrow_mut().pkt_dropped("Retry without a token");
+ return;
+ }
+ if !packet.is_valid_retry(self.original_destination_cid.as_ref().unwrap()) {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Retry with bad integrity tag");
+ return;
+ }
+ // At this point, we should only have the connection ID that we generated.
+ // Update to the one that the server prefers.
+ let path = self.paths.primary();
+ path.borrow_mut().set_remote_cid(packet.scid());
+
+ let retry_scid = ConnectionId::from(packet.scid());
+ qinfo!(
+ [self],
+ "Valid Retry received, token={} scid={}",
+ hex(packet.token()),
+ retry_scid
+ );
+
+ let lost_packets = self.loss_recovery.retry(&path, now);
+ self.handle_lost_packets(&lost_packets);
+
+ self.crypto.states.init(
+ self.conn_params.get_versions().compatible(),
+ self.role,
+ &retry_scid,
+ );
+ self.address_validation = AddressValidationInfo::Retry {
+ token: packet.token().to_vec(),
+ retry_source_cid: retry_scid,
+ };
+ }
+
+ fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) {
+ if self.crypto.discard(space) {
+ qinfo!([self], "Drop packet number space {}", space);
+ let primary = self.paths.primary();
+ self.loss_recovery.discard(&primary, space, now);
+ self.acks.drop_space(space);
+ }
+ }
+
+ fn is_stateless_reset(&self, path: &PathRef, d: &Datagram) -> bool {
+ // If the datagram is too small, don't try.
+ // If the connection is connected, then the reset token will be invalid.
+ if d.len() < 16 || !self.state.connected() {
+ return false;
+ }
+ let token = <&[u8; 16]>::try_from(&d[d.len() - 16..]).unwrap();
+ path.borrow().is_stateless_reset(token)
+ }
+
+ fn check_stateless_reset(
+ &mut self,
+ path: &PathRef,
+ d: &Datagram,
+ first: bool,
+ now: Instant,
+ ) -> Res<()> {
+ if first && self.is_stateless_reset(path, d) {
+ // Failing to process a packet in a datagram might
+ // indicate that there is a stateless reset present.
+ qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..]));
+ self.state_signaling.reset();
+ self.set_state(State::Draining {
+ error: ConnectionError::Transport(Error::StatelessReset),
+ timeout: self.get_closing_period_time(now),
+ });
+ Err(Error::StatelessReset)
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Process any saved datagrams that might be available for processing.
+ fn process_saved(&mut self, now: Instant) {
+ while let Some(cspace) = self.saved_datagrams.available() {
+ qdebug!([self], "process saved for space {:?}", cspace);
+ debug_assert!(self.crypto.states.rx_hp(self.version, cspace).is_some());
+ for saved in self.saved_datagrams.take_saved() {
+ qtrace!([self], "input saved @{:?}: {:?}", saved.t, saved.d);
+ self.input(&saved.d, saved.t, now);
+ }
+ }
+ }
+
+ /// In case a datagram arrives that we can only partially process, save any
+ /// part that we don't have keys for.
+ fn save_datagram(&mut self, cspace: CryptoSpace, d: &Datagram, remaining: usize, now: Instant) {
+ let d = if remaining < d.len() {
+ Datagram::new(
+ d.source(),
+ d.destination(),
+ d.tos(),
+ d.ttl(),
+ &d[d.len() - remaining..],
+ )
+ } else {
+ d.clone()
+ };
+ self.saved_datagrams.save(cspace, d, now);
+ self.stats.borrow_mut().saved_datagrams += 1;
+ }
+
+ /// Perform version negotiation.
+ fn version_negotiation(&mut self, supported: &[WireVersion], now: Instant) -> Res<()> {
+ debug_assert_eq!(self.role, Role::Client);
+
+ if let Some(version) = self.conn_params.get_versions().preferred(supported) {
+ assert_ne!(self.version, version);
+
+ qinfo!([self], "Version negotiation: trying {:?}", version);
+ let local_addr = self.paths.primary().borrow().local_address();
+ let remote_addr = self.paths.primary().borrow().remote_address();
+ let conn_params = self
+ .conn_params
+ .clone()
+ .versions(version, self.conn_params.get_versions().all().to_vec());
+ let mut c = Self::new_client(
+ self.crypto.server_name().unwrap(),
+ self.crypto.protocols(),
+ self.cid_manager.generator(),
+ local_addr,
+ remote_addr,
+ conn_params,
+ now,
+ )?;
+ c.conn_params
+ .get_versions_mut()
+ .set_initial(self.conn_params.get_versions().initial());
+ mem::swap(self, &mut c);
+ qlog::client_version_information_negotiated(
+ &mut self.qlog,
+ self.conn_params.get_versions().all(),
+ supported,
+ version,
+ );
+ Ok(())
+ } else {
+ qinfo!([self], "Version negotiation: failed with {:?}", supported);
+ // This error goes straight to closed.
+ self.set_state(State::Closed(ConnectionError::Transport(
+ Error::VersionNegotiation,
+ )));
+ Err(Error::VersionNegotiation)
+ }
+ }
+
+ /// Perform any processing that we might have to do on packets prior to
+ /// attempting to remove protection.
+ fn preprocess_packet(
+ &mut self,
+ packet: &PublicPacket,
+ path: &PathRef,
+ dcid: Option<&ConnectionId>,
+ now: Instant,
+ ) -> Res<PreprocessResult> {
+ if dcid.map_or(false, |d| d != &packet.dcid()) {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Coalesced packet has different DCID");
+ return Ok(PreprocessResult::Next);
+ }
+
+ if (packet.packet_type() == PacketType::Initial
+ || packet.packet_type() == PacketType::Handshake)
+ && self.role == Role::Client
+ && !path.borrow().is_primary()
+ {
+ // If we have received a packet from a different address than we have sent to
+ // we should ignore the packet. In such a case a path will be a newly created
+ // temporary path, not the primary path.
+ return Ok(PreprocessResult::Next);
+ }
+
+ match (packet.packet_type(), &self.state, &self.role) {
+ (PacketType::Initial, State::Init, Role::Server) => {
+ let version = *packet.version().as_ref().unwrap();
+ if !packet.is_valid_initial()
+ || !self.conn_params.get_versions().all().contains(&version)
+ {
+ self.stats.borrow_mut().pkt_dropped("Invalid Initial");
+ return Ok(PreprocessResult::Next);
+ }
+ qinfo!(
+ [self],
+ "Received valid Initial packet with scid {:?} dcid {:?}",
+ packet.scid(),
+ packet.dcid()
+ );
+ // Record the client's selected CID so that it can be accepted until
+ // the client starts using a real connection ID.
+ let dcid = ConnectionId::from(packet.dcid());
+ self.crypto.states.init_server(version, &dcid);
+ self.original_destination_cid = Some(dcid);
+ self.set_state(State::WaitInitial);
+
+ // We need to make sure that we set this transport parameter.
+ // This has to happen prior to processing the packet so that
+ // the TLS handshake has all it needs.
+ if !self.retry_sent() {
+ self.tps.borrow_mut().local.set_bytes(
+ tparams::ORIGINAL_DESTINATION_CONNECTION_ID,
+ packet.dcid().to_vec(),
+ );
+ }
+ }
+ (PacketType::VersionNegotiation, State::WaitInitial, Role::Client) => {
+ if let Ok(versions) = packet.supported_versions() {
+ if versions.is_empty()
+ || versions.contains(&self.version().wire_version())
+ || versions.contains(&0)
+ || &packet.scid() != self.odcid().unwrap()
+ || matches!(self.address_validation, AddressValidationInfo::Retry { .. })
+ {
+ // Ignore VersionNegotiation packets that contain the current version.
+ // Or don't have the right connection ID.
+ // Or are received after a Retry.
+ self.stats.borrow_mut().pkt_dropped("Invalid VN");
+ } else {
+ self.version_negotiation(&versions, now)?;
+ }
+ } else {
+ self.stats.borrow_mut().pkt_dropped("VN with no versions");
+ };
+ return Ok(PreprocessResult::End);
+ }
+ (PacketType::Retry, State::WaitInitial, Role::Client) => {
+ self.handle_retry(packet, now);
+ return Ok(PreprocessResult::Next);
+ }
+ (PacketType::Handshake | PacketType::Short, State::WaitInitial, Role::Client) => {
+ // This packet can't be processed now, but it could be a sign
+ // that Initial packets were lost.
+ // Resend Initial CRYPTO frames immediately a few times just
+ // in case. As we don't have an RTT estimate yet, this helps
+ // when there is a short RTT and losses.
+ if dcid.is_none()
+ && self.cid_manager.is_valid(packet.dcid())
+ && self.stats.borrow().saved_datagrams <= EXTRA_INITIALS
+ {
+ self.crypto.resend_unacked(PacketNumberSpace::Initial);
+ }
+ }
+ (PacketType::VersionNegotiation | PacketType::Retry | PacketType::OtherVersion, ..) => {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped(format!("{:?}", packet.packet_type()));
+ return Ok(PreprocessResult::Next);
+ }
+ _ => {}
+ }
+
+ let res = match self.state {
+ State::Init => {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Received while in Init state");
+ PreprocessResult::Next
+ }
+ State::WaitInitial => PreprocessResult::Continue,
+ State::WaitVersion | State::Handshaking | State::Connected | State::Confirmed => {
+ if !self.cid_manager.is_valid(packet.dcid()) {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped(format!("Invalid DCID {:?}", packet.dcid()));
+ PreprocessResult::Next
+ } else {
+ if self.role == Role::Server && packet.packet_type() == PacketType::Handshake {
+ // Server has received a Handshake packet -> discard Initial keys and states
+ self.discard_keys(PacketNumberSpace::Initial, now);
+ }
+ PreprocessResult::Continue
+ }
+ }
+ State::Closing { .. } => {
+ // Don't bother processing the packet. Instead ask to get a
+ // new close frame.
+ self.state_signaling.send_close();
+ PreprocessResult::Next
+ }
+ State::Draining { .. } | State::Closed(..) => {
+ // Do nothing.
+ self.stats
+ .borrow_mut()
+ .pkt_dropped(format!("State {:?}", self.state));
+ PreprocessResult::Next
+ }
+ };
+ Ok(res)
+ }
+
+ /// After a Initial, Handshake, ZeroRtt, or Short packet is successfully processed.
+ fn postprocess_packet(
+ &mut self,
+ path: &PathRef,
+ d: &Datagram,
+ packet: &PublicPacket,
+ migrate: bool,
+ now: Instant,
+ ) {
+ if self.state == State::WaitInitial {
+ self.start_handshake(path, packet, now);
+ }
+
+ if self.state.connected() {
+ self.handle_migration(path, d, migrate, now);
+ } else if self.role != Role::Client
+ && (packet.packet_type() == PacketType::Handshake
+ || (packet.dcid().len() >= 8 && packet.dcid() == self.local_initial_source_cid))
+ {
+ // We only allow one path during setup, so apply handshake
+ // path validation to this path.
+ path.borrow_mut().set_valid(now);
+ }
+ }
+
+ /// Take a datagram as input. This reports an error if the packet was bad.
+ /// This takes two times: when the datagram was received, and the current time.
+ fn input(&mut self, d: &Datagram, received: Instant, now: Instant) {
+ // First determine the path.
+ let path = self.paths.find_path_with_rebinding(
+ d.destination(),
+ d.source(),
+ self.conn_params.get_cc_algorithm(),
+ self.conn_params.pacing_enabled(),
+ now,
+ );
+ path.borrow_mut().add_received(d.len());
+ let res = self.input_path(&path, d, received);
+ self.capture_error(Some(path), now, 0, res).ok();
+ }
+
+ fn input_path(&mut self, path: &PathRef, d: &Datagram, now: Instant) -> Res<()> {
+ let mut slc = &d[..];
+ let mut dcid = None;
+
+ qtrace!([self], "{} input {}", path.borrow(), hex(&**d));
+ let pto = path.borrow().rtt().pto(PacketNumberSpace::ApplicationData);
+
+ // Handle each packet in the datagram.
+ while !slc.is_empty() {
+ self.stats.borrow_mut().packets_rx += 1;
+ let (packet, remainder) =
+ match PublicPacket::decode(slc, self.cid_manager.decoder().as_ref()) {
+ Ok((packet, remainder)) => (packet, remainder),
+ Err(e) => {
+ qinfo!([self], "Garbage packet: {}", e);
+ qtrace!([self], "Garbage packet contents: {}", hex(slc));
+ self.stats.borrow_mut().pkt_dropped("Garbage packet");
+ break;
+ }
+ };
+ match self.preprocess_packet(&packet, path, dcid.as_ref(), now)? {
+ PreprocessResult::Continue => (),
+ PreprocessResult::Next => break,
+ PreprocessResult::End => return Ok(()),
+ }
+
+ qtrace!([self], "Received unverified packet {:?}", packet);
+
+ match packet.decrypt(&mut self.crypto.states, now + pto) {
+ Ok(payload) => {
+ // OK, we have a valid packet.
+ self.idle_timeout.on_packet_received(now);
+ dump_packet(
+ self,
+ path,
+ "-> RX",
+ payload.packet_type(),
+ payload.pn(),
+ &payload[..],
+ );
+
+ qlog::packet_received(&mut self.qlog, &packet, &payload);
+ let space = PacketNumberSpace::from(payload.packet_type());
+ if self.acks.get_mut(space).unwrap().is_duplicate(payload.pn()) {
+ qdebug!([self], "Duplicate packet {}-{}", space, payload.pn());
+ self.stats.borrow_mut().dups_rx += 1;
+ } else {
+ match self.process_packet(path, &payload, now) {
+ Ok(migrate) => self.postprocess_packet(path, d, &packet, migrate, now),
+ Err(e) => {
+ self.ensure_error_path(path, &packet, now);
+ return Err(e);
+ }
+ }
+ }
+ }
+ Err(e) => {
+ match e {
+ Error::KeysPending(cspace) => {
+ // This packet can't be decrypted because we don't have the keys yet.
+ // Don't check this packet for a stateless reset, just return.
+ let remaining = slc.len();
+ self.save_datagram(cspace, d, remaining, now);
+ return Ok(());
+ }
+ Error::KeysExhausted => {
+ // Exhausting read keys is fatal.
+ return Err(e);
+ }
+ Error::KeysDiscarded(cspace) => {
+ // This was a valid-appearing Initial packet: maybe probe with
+ // a Handshake packet to keep the handshake moving.
+ self.received_untracked |=
+ self.role == Role::Client && cspace == CryptoSpace::Initial;
+ }
+ _ => (),
+ }
+ // Decryption failure, or not having keys is not fatal.
+ // If the state isn't available, or we can't decrypt the packet, drop
+ // the rest of the datagram on the floor, but don't generate an error.
+ self.check_stateless_reset(path, d, dcid.is_none(), now)?;
+ self.stats.borrow_mut().pkt_dropped("Decryption failure");
+ qlog::packet_dropped(&mut self.qlog, &packet);
+ }
+ }
+ slc = remainder;
+ dcid = Some(ConnectionId::from(packet.dcid()));
+ }
+ self.check_stateless_reset(path, d, dcid.is_none(), now)?;
+ Ok(())
+ }
+
+ /// Process a packet. Returns true if the packet might initiate migration.
+ fn process_packet(
+ &mut self,
+ path: &PathRef,
+ packet: &DecryptedPacket,
+ now: Instant,
+ ) -> Res<bool> {
+ // TODO(ekr@rtfm.com): Have the server blow away the initial
+ // crypto state if this fails? Otherwise, we will get a panic
+ // on the assert for doesn't exist.
+ // OK, we have a valid packet.
+
+ let mut ack_eliciting = false;
+ let mut probing = true;
+ let mut d = Decoder::from(&packet[..]);
+ let mut consecutive_padding = 0;
+ while d.remaining() > 0 {
+ let mut f = Frame::decode(&mut d)?;
+
+ // Skip padding
+ while f == Frame::Padding && d.remaining() > 0 {
+ consecutive_padding += 1;
+ f = Frame::decode(&mut d)?;
+ }
+ if consecutive_padding > 0 {
+ qdebug!(
+ [self],
+ "PADDING frame repeated {} times",
+ consecutive_padding
+ );
+ consecutive_padding = 0;
+ }
+
+ ack_eliciting |= f.ack_eliciting();
+ probing &= f.path_probing();
+ let t = f.get_type();
+ if let Err(e) = self.input_frame(path, packet.version(), packet.packet_type(), f, now) {
+ self.capture_error(Some(Rc::clone(path)), now, t, Err(e))?;
+ }
+ }
+
+ let largest_received = if let Some(space) = self
+ .acks
+ .get_mut(PacketNumberSpace::from(packet.packet_type()))
+ {
+ space.set_received(now, packet.pn(), ack_eliciting)
+ } else {
+ qdebug!(
+ [self],
+ "processed a {:?} packet without tracking it",
+ packet.packet_type(),
+ );
+ // This was a valid packet that caused the same packet number to be
+ // discarded. This happens when the client discards the Initial packet
+ // number space after receiving the ServerHello. Remember this so
+ // that we guarantee that we send a Handshake packet.
+ self.received_untracked = true;
+ // We don't migrate during the handshake, so return false.
+ false
+ };
+
+ Ok(largest_received && !probing)
+ }
+
+ /// During connection setup, the first path needs to be setup.
+ /// This uses the connection IDs that were provided during the handshake
+ /// to setup that path.
+ #[allow(clippy::or_fun_call)] // Remove when MSRV >= 1.59
+ fn setup_handshake_path(&mut self, path: &PathRef, now: Instant) {
+ self.paths.make_permanent(
+ path,
+ Some(self.local_initial_source_cid.clone()),
+ // Ideally we know what the peer wants us to use for the remote CID.
+ // But we will use our own guess if necessary.
+ ConnectionIdEntry::initial_remote(
+ self.remote_initial_source_cid
+ .as_ref()
+ .or(self.original_destination_cid.as_ref())
+ .unwrap()
+ .clone(),
+ ),
+ );
+ path.borrow_mut().set_valid(now);
+ }
+
+ /// If the path isn't permanent, assign it a connection ID to make it so.
+ fn ensure_permanent(&mut self, path: &PathRef) -> Res<()> {
+ if self.paths.is_temporary(path) {
+ // If there isn't a connection ID to use for this path, the packet
+ // will be processed, but it won't be attributed to a path. That means
+ // no path probes or PATH_RESPONSE. But it's not fatal.
+ if let Some(cid) = self.connection_ids.next() {
+ self.paths.make_permanent(path, None, cid);
+ Ok(())
+ } else if self.paths.primary().borrow().remote_cid().is_empty() {
+ self.paths
+ .make_permanent(path, None, ConnectionIdEntry::empty_remote());
+ Ok(())
+ } else {
+ qtrace!([self], "Unable to make path permanent: {}", path.borrow());
+ Err(Error::InvalidMigration)
+ }
+ } else {
+ Ok(())
+ }
+ }
+
+ /// After an error, a permanent path is needed to send the CONNECTION_CLOSE.
+ /// This attempts to ensure that this exists. As the connection is now
+ /// temporary, there is no reason to do anything special here.
+ fn ensure_error_path(&mut self, path: &PathRef, packet: &PublicPacket, now: Instant) {
+ path.borrow_mut().set_valid(now);
+ if self.paths.is_temporary(path) {
+ // First try to fill in handshake details.
+ if packet.packet_type() == PacketType::Initial {
+ self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid()));
+ self.setup_handshake_path(path, now);
+ } else {
+ // Otherwise try to get a usable connection ID.
+ mem::drop(self.ensure_permanent(path));
+ }
+ }
+ }
+
+ fn start_handshake(&mut self, path: &PathRef, packet: &PublicPacket, now: Instant) {
+ qtrace!([self], "starting handshake");
+ debug_assert_eq!(packet.packet_type(), PacketType::Initial);
+ self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid()));
+
+ let got_version = if self.role == Role::Server {
+ self.cid_manager
+ .add_odcid(self.original_destination_cid.as_ref().unwrap().clone());
+ // Make a path on which to run the handshake.
+ self.setup_handshake_path(path, now);
+
+ self.zero_rtt_state = match self.crypto.enable_0rtt(self.version, self.role) {
+ Ok(true) => {
+ qdebug!([self], "Accepted 0-RTT");
+ ZeroRttState::AcceptedServer
+ }
+ _ => ZeroRttState::Rejected,
+ };
+
+ // The server knows the final version if it has remote transport parameters.
+ self.tps.borrow().remote.is_some()
+ } else {
+ qdebug!([self], "Changing to use Server CID={}", packet.scid());
+ debug_assert!(path.borrow().is_primary());
+ path.borrow_mut().set_remote_cid(packet.scid());
+
+ // The client knows the final version if it processed a CRYPTO frame.
+ self.stats.borrow().frame_rx.crypto > 0
+ };
+ if got_version {
+ self.set_state(State::Handshaking);
+ } else {
+ self.set_state(State::WaitVersion);
+ }
+ }
+
+ /// Migrate to the provided path.
+ /// Either local or remote address (but not both) may be provided as `None` to have
+ /// the address from the current primary path used.
+ /// If `force` is true, then migration is immediate.
+ /// Otherwise, migration occurs after the path is probed successfully.
+ /// Either way, the path is probed and will be abandoned if the probe fails.
+ ///
+ /// # Errors
+ ///
+ /// Fails if this is not a client, not confirmed, or there are not enough connection
+ /// IDs available to use.
+ pub fn migrate(
+ &mut self,
+ local: Option<SocketAddr>,
+ remote: Option<SocketAddr>,
+ force: bool,
+ now: Instant,
+ ) -> Res<()> {
+ if self.role != Role::Client {
+ return Err(Error::InvalidMigration);
+ }
+ if !matches!(self.state(), State::Confirmed) {
+ return Err(Error::InvalidMigration);
+ }
+
+ // Fill in the blanks, using the current primary path.
+ if local.is_none() && remote.is_none() {
+ // Pointless migration is pointless.
+ return Err(Error::InvalidMigration);
+ }
+ let local = local.unwrap_or_else(|| self.paths.primary().borrow().local_address());
+ let remote = remote.unwrap_or_else(|| self.paths.primary().borrow().remote_address());
+
+ if mem::discriminant(&local.ip()) != mem::discriminant(&remote.ip()) {
+ // Can't mix address families.
+ return Err(Error::InvalidMigration);
+ }
+ if local.port() == 0 || remote.ip().is_unspecified() || remote.port() == 0 {
+ // All but the local address need to be specified.
+ return Err(Error::InvalidMigration);
+ }
+ if (local.ip().is_loopback() ^ remote.ip().is_loopback()) && !local.ip().is_unspecified() {
+ // Block attempts to migrate to a path with loopback on only one end, unless the local
+ // address is unspecified.
+ return Err(Error::InvalidMigration);
+ }
+
+ let path = self.paths.find_path(
+ local,
+ remote,
+ self.conn_params.get_cc_algorithm(),
+ self.conn_params.pacing_enabled(),
+ now,
+ );
+ self.ensure_permanent(&path)?;
+ qinfo!(
+ [self],
+ "Migrate to {} probe {}",
+ path.borrow(),
+ if force { "now" } else { "after" }
+ );
+ if self.paths.migrate(&path, force, now) {
+ self.loss_recovery.migrate();
+ }
+ Ok(())
+ }
+
+ fn migrate_to_preferred_address(&mut self, now: Instant) -> Res<()> {
+ let spa = if matches!(
+ self.conn_params.get_preferred_address(),
+ PreferredAddressConfig::Disabled
+ ) {
+ None
+ } else {
+ self.tps.borrow_mut().remote().get_preferred_address()
+ };
+ if let Some((addr, cid)) = spa {
+ // The connection ID isn't special, so just save it.
+ self.connection_ids.add_remote(cid)?;
+
+ // The preferred address doesn't dictate what the local address is, so this
+ // has to use the existing address. So only pay attention to a preferred
+ // address from the same family as is currently in use. More thought will
+ // be needed to work out how to get addresses from a different family.
+ let prev = self.paths.primary().borrow().remote_address();
+ let remote = match prev.ip() {
+ IpAddr::V4(_) => addr.ipv4().map(SocketAddr::V4),
+ IpAddr::V6(_) => addr.ipv6().map(SocketAddr::V6),
+ };
+
+ if let Some(remote) = remote {
+ // Ignore preferred address that move to loopback from non-loopback.
+ // `migrate` doesn't enforce this rule.
+ if !prev.ip().is_loopback() && remote.ip().is_loopback() {
+ qwarn!([self], "Ignoring a move to a loopback address: {}", remote);
+ return Ok(());
+ }
+
+ if self.migrate(None, Some(remote), false, now).is_err() {
+ qwarn!([self], "Ignoring bad preferred address: {}", remote);
+ }
+ } else {
+ qwarn!([self], "Unable to migrate to a different address family");
+ }
+ }
+ Ok(())
+ }
+
+ fn handle_migration(&mut self, path: &PathRef, d: &Datagram, migrate: bool, now: Instant) {
+ if !migrate {
+ return;
+ }
+ if self.role == Role::Client {
+ return;
+ }
+
+ if self.ensure_permanent(path).is_ok() {
+ self.paths.handle_migration(path, d.source(), now);
+ } else {
+ qinfo!(
+ [self],
+ "{} Peer migrated, but no connection ID available",
+ path.borrow()
+ );
+ }
+ }
+
+ fn output(&mut self, now: Instant) -> SendOption {
+ qtrace!([self], "output {:?}", now);
+ let res = match &self.state {
+ State::Init
+ | State::WaitInitial
+ | State::WaitVersion
+ | State::Handshaking
+ | State::Connected
+ | State::Confirmed => {
+ if let Some(path) = self.paths.select_path() {
+ let res = self.output_path(&path, now);
+ self.capture_error(Some(path), now, 0, res)
+ } else {
+ Ok(SendOption::default())
+ }
+ }
+ State::Closing { .. } | State::Draining { .. } | State::Closed(_) => {
+ if let Some(details) = self.state_signaling.close_frame() {
+ let path = Rc::clone(details.path());
+ let res = self.output_close(details);
+ self.capture_error(Some(path), now, 0, res)
+ } else {
+ Ok(SendOption::default())
+ }
+ }
+ };
+ res.unwrap_or_default()
+ }
+
+ fn build_packet_header(
+ path: &Path,
+ cspace: CryptoSpace,
+ encoder: Encoder,
+ tx: &CryptoDxState,
+ address_validation: &AddressValidationInfo,
+ version: Version,
+ grease_quic_bit: bool,
+ ) -> (PacketType, PacketBuilder) {
+ let pt = PacketType::from(cspace);
+ let mut builder = if pt == PacketType::Short {
+ qdebug!("Building Short dcid {}", path.remote_cid());
+ PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid())
+ } else {
+ qdebug!(
+ "Building {:?} dcid {} scid {}",
+ pt,
+ path.remote_cid(),
+ path.local_cid(),
+ );
+
+ PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid())
+ };
+ if builder.remaining() > 0 {
+ builder.scramble(grease_quic_bit);
+ if pt == PacketType::Initial {
+ builder.initial_token(address_validation.token());
+ }
+ }
+
+ (pt, builder)
+ }
+
+ #[must_use]
+ fn add_packet_number(
+ builder: &mut PacketBuilder,
+ tx: &CryptoDxState,
+ largest_acknowledged: Option<PacketNumber>,
+ ) -> PacketNumber {
+ // Get the packet number and work out how long it is.
+ let pn = tx.next_pn();
+ let unacked_range = if let Some(la) = largest_acknowledged {
+ // Double the range from this to the last acknowledged in this space.
+ (pn - la) << 1
+ } else {
+ pn + 1
+ };
+ // Count how many bytes in this range are non-zero.
+ let pn_len = mem::size_of::<PacketNumber>()
+ - usize::try_from(unacked_range.leading_zeros() / 8).unwrap();
+ // pn_len can't be zero (unacked_range is > 0)
+ // TODO(mt) also use `4*path CWND/path MTU` to set a minimum length.
+ builder.pn(pn, pn_len);
+ pn
+ }
+
+ fn can_grease_quic_bit(&self) -> bool {
+ let tph = self.tps.borrow();
+ if let Some(r) = &tph.remote {
+ r.get_empty(tparams::GREASE_QUIC_BIT)
+ } else if let Some(r) = &tph.remote_0rtt {
+ r.get_empty(tparams::GREASE_QUIC_BIT)
+ } else {
+ false
+ }
+ }
+
+ fn output_close(&mut self, close: ClosingFrame) -> Res<SendOption> {
+ let mut encoder = Encoder::with_capacity(256);
+ let grease_quic_bit = self.can_grease_quic_bit();
+ let version = self.version();
+ for space in PacketNumberSpace::iter() {
+ let Some((cspace, tx)) = self.crypto.states.select_tx_mut(self.version, *space) else {
+ continue;
+ };
+
+ let path = close.path().borrow();
+ let (_, mut builder) = Self::build_packet_header(
+ &path,
+ cspace,
+ encoder,
+ tx,
+ &AddressValidationInfo::None,
+ version,
+ grease_quic_bit,
+ );
+ _ = Self::add_packet_number(
+ &mut builder,
+ tx,
+ self.loss_recovery.largest_acknowledged_pn(*space),
+ );
+ // The builder will set the limit to 0 if there isn't enough space for the header.
+ if builder.is_full() {
+ encoder = builder.abort();
+ break;
+ }
+ builder.set_limit(min(path.amplification_limit(), path.mtu()) - tx.expansion());
+ debug_assert!(builder.limit() <= 2048);
+
+ // ConnectionError::Application is only allowed at 1RTT.
+ let sanitized = if *space == PacketNumberSpace::ApplicationData {
+ None
+ } else {
+ close.sanitize()
+ };
+ sanitized
+ .as_ref()
+ .unwrap_or(&close)
+ .write_frame(&mut builder);
+ encoder = builder.build(tx)?;
+ }
+
+ Ok(SendOption::Yes(close.path().borrow().datagram(encoder)))
+ }
+
+ /// Write the frames that are exchanged in the application data space.
+ /// The order of calls here determines the relative priority of frames.
+ fn write_appdata_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ ) -> Res<()> {
+ let stats = &mut self.stats.borrow_mut();
+ let frame_stats = &mut stats.frame_tx;
+ if self.role == Role::Server {
+ if let Some(t) = self.state_signaling.write_done(builder)? {
+ tokens.push(t);
+ frame_stats.handshake_done += 1;
+ }
+ }
+
+ self.streams
+ .write_frames(TransmissionPriority::Critical, builder, tokens, frame_stats);
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ self.streams.write_frames(
+ TransmissionPriority::Important,
+ builder,
+ tokens,
+ frame_stats,
+ );
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ // NEW_CONNECTION_ID, RETIRE_CONNECTION_ID, and ACK_FREQUENCY.
+ self.cid_manager
+ .write_frames(builder, tokens, frame_stats)?;
+ if builder.is_full() {
+ return Ok(());
+ }
+ self.paths.write_frames(builder, tokens, frame_stats);
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ self.streams
+ .write_frames(TransmissionPriority::High, builder, tokens, frame_stats);
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ self.streams
+ .write_frames(TransmissionPriority::Normal, builder, tokens, frame_stats);
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ // Datagrams are best-effort and unreliable. Let streams starve them for now.
+ self.quic_datagrams.write_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ let frame_stats = &mut stats.frame_tx;
+ // CRYPTO here only includes NewSessionTicket, plus NEW_TOKEN.
+ // Both of these are only used for resumption and so can be relatively low priority.
+ self.crypto.write_frame(
+ PacketNumberSpace::ApplicationData,
+ builder,
+ tokens,
+ frame_stats,
+ )?;
+ if builder.is_full() {
+ return Ok(());
+ }
+ self.new_token.write_frames(builder, tokens, frame_stats)?;
+ if builder.is_full() {
+ return Ok(());
+ }
+
+ self.streams
+ .write_frames(TransmissionPriority::Low, builder, tokens, frame_stats);
+
+ #[cfg(test)]
+ {
+ if let Some(w) = &mut self.test_frame_writer {
+ w.write_frames(builder);
+ }
+ }
+
+ Ok(())
+ }
+
+ // Maybe send a probe. Return true if the packet was ack-eliciting.
+ fn maybe_probe(
+ &mut self,
+ path: &PathRef,
+ force_probe: bool,
+ builder: &mut PacketBuilder,
+ ack_end: usize,
+ tokens: &mut Vec<RecoveryToken>,
+ now: Instant,
+ ) -> bool {
+ let untracked = self.received_untracked && !self.state.connected();
+ self.received_untracked = false;
+
+ // Anything written after an ACK already elicits acknowledgment.
+ // If we need to probe and nothing has been written, send a PING.
+ if builder.len() > ack_end {
+ return true;
+ }
+
+ let probe = if untracked && builder.packet_empty() || force_probe {
+ // If we received an untracked packet and we aren't probing already
+ // or the PTO timer fired: probe.
+ true
+ } else {
+ let pto = path.borrow().rtt().pto(PacketNumberSpace::ApplicationData);
+ if !builder.packet_empty() {
+ // The packet only contains an ACK. Check whether we want to
+ // force an ACK with a PING so we can stop tracking packets.
+ self.loss_recovery.should_probe(pto, now)
+ } else if self.streams.need_keep_alive() {
+ // We need to keep the connection alive, including sending
+ // a PING again.
+ self.idle_timeout.send_keep_alive(now, pto, tokens)
+ } else {
+ false
+ }
+ };
+ if probe {
+ // Nothing ack-eliciting and we need to probe; send PING.
+ debug_assert_ne!(builder.remaining(), 0);
+ builder.encode_varint(crate::frame::FRAME_TYPE_PING);
+ let stats = &mut self.stats.borrow_mut().frame_tx;
+ stats.ping += 1;
+ stats.all += 1;
+ }
+ probe
+ }
+
+ /// Write frames to the provided builder. Returns a list of tokens used for
+ /// tracking loss or acknowledgment, whether any frame was ACK eliciting, and
+ /// whether the packet was padded.
+ fn write_frames(
+ &mut self,
+ path: &PathRef,
+ space: PacketNumberSpace,
+ profile: &SendProfile,
+ builder: &mut PacketBuilder,
+ now: Instant,
+ ) -> Res<(Vec<RecoveryToken>, bool, bool)> {
+ let mut tokens = Vec::new();
+ let primary = path.borrow().is_primary();
+ let mut ack_eliciting = false;
+
+ if primary {
+ let stats = &mut self.stats.borrow_mut().frame_tx;
+ self.acks.write_frame(
+ space,
+ now,
+ path.borrow().rtt().estimate(),
+ builder,
+ &mut tokens,
+ stats,
+ );
+ }
+ let ack_end = builder.len();
+
+ // Avoid sending probes until the handshake completes,
+ // but send them even when we don't have space.
+ let full_mtu = profile.limit() == path.borrow().mtu();
+ if space == PacketNumberSpace::ApplicationData && self.state.connected() {
+ // Probes should only be padded if the full MTU is available.
+ // The probing code needs to know so it can track that.
+ if path.borrow_mut().write_frames(
+ builder,
+ &mut self.stats.borrow_mut().frame_tx,
+ full_mtu,
+ now,
+ ) {
+ builder.enable_padding(true);
+ }
+ }
+
+ if profile.ack_only(space) {
+ // If we are CC limited we can only send acks!
+ return Ok((tokens, false, false));
+ }
+
+ if primary {
+ if space == PacketNumberSpace::ApplicationData {
+ self.write_appdata_frames(builder, &mut tokens)?;
+ } else {
+ let stats = &mut self.stats.borrow_mut().frame_tx;
+ self.crypto
+ .write_frame(space, builder, &mut tokens, stats)?;
+ }
+ }
+
+ // Maybe send a probe now, either to probe for losses or to keep the connection live.
+ let force_probe = profile.should_probe(space);
+ ack_eliciting |= self.maybe_probe(path, force_probe, builder, ack_end, &mut tokens, now);
+ // If this is not the primary path, this should be ack-eliciting.
+ debug_assert!(primary || ack_eliciting);
+
+ // Add padding. Only pad 1-RTT packets so that we don't prevent coalescing.
+ // And avoid padding packets that otherwise only contain ACK because adding PADDING
+ // causes those packets to consume congestion window, which is not tracked (yet).
+ // And avoid padding if we don't have a full MTU available.
+ let stats = &mut self.stats.borrow_mut().frame_tx;
+ let padded = if ack_eliciting && full_mtu && builder.pad() {
+ stats.padding += 1;
+ stats.all += 1;
+ true
+ } else {
+ false
+ };
+
+ stats.all += tokens.len();
+ Ok((tokens, ack_eliciting, padded))
+ }
+
+ /// Build a datagram, possibly from multiple packets (for different PN
+ /// spaces) and each containing 1+ frames.
+ fn output_path(&mut self, path: &PathRef, now: Instant) -> Res<SendOption> {
+ let mut initial_sent = None;
+ let mut needs_padding = false;
+ let grease_quic_bit = self.can_grease_quic_bit();
+ let version = self.version();
+
+ // Determine how we are sending packets (PTO, etc..).
+ let mtu = path.borrow().mtu();
+ let profile = self.loss_recovery.send_profile(&path.borrow(), now);
+ qdebug!([self], "output_path send_profile {:?}", profile);
+
+ // Frames for different epochs must go in different packets, but then these
+ // packets can go in a single datagram
+ let mut encoder = Encoder::with_capacity(profile.limit());
+ for space in PacketNumberSpace::iter() {
+ // Ensure we have tx crypto state for this epoch, or skip it.
+ let Some((cspace, tx)) = self.crypto.states.select_tx_mut(self.version, *space) else {
+ continue;
+ };
+
+ let header_start = encoder.len();
+ let (pt, mut builder) = Self::build_packet_header(
+ &path.borrow(),
+ cspace,
+ encoder,
+ tx,
+ &self.address_validation,
+ version,
+ grease_quic_bit,
+ );
+ let pn = Self::add_packet_number(
+ &mut builder,
+ tx,
+ self.loss_recovery.largest_acknowledged_pn(*space),
+ );
+ // The builder will set the limit to 0 if there isn't enough space for the header.
+ if builder.is_full() {
+ encoder = builder.abort();
+ break;
+ }
+
+ // Configure the limits and padding for this packet.
+ let aead_expansion = tx.expansion();
+ builder.set_limit(profile.limit() - aead_expansion);
+ builder.enable_padding(needs_padding);
+ debug_assert!(builder.limit() <= 2048);
+ if builder.is_full() {
+ encoder = builder.abort();
+ break;
+ }
+
+ // Add frames to the packet.
+ let payload_start = builder.len();
+ let (tokens, ack_eliciting, padded) =
+ self.write_frames(path, *space, &profile, &mut builder, now)?;
+ if builder.packet_empty() {
+ // Nothing to include in this packet.
+ encoder = builder.abort();
+ continue;
+ }
+
+ dump_packet(
+ self,
+ path,
+ "TX ->",
+ pt,
+ pn,
+ &builder.as_ref()[payload_start..],
+ );
+ qlog::packet_sent(
+ &mut self.qlog,
+ pt,
+ pn,
+ builder.len() - header_start + aead_expansion,
+ &builder.as_ref()[payload_start..],
+ );
+
+ self.stats.borrow_mut().packets_tx += 1;
+ let tx = self.crypto.states.tx_mut(self.version, cspace).unwrap();
+ encoder = builder.build(tx)?;
+ debug_assert!(encoder.len() <= mtu);
+ self.crypto.states.auto_update()?;
+
+ if ack_eliciting {
+ self.idle_timeout.on_packet_sent(now);
+ }
+ let sent = SentPacket::new(
+ pt,
+ pn,
+ now,
+ ack_eliciting,
+ tokens,
+ encoder.len() - header_start,
+ );
+ if padded {
+ needs_padding = false;
+ self.loss_recovery.on_packet_sent(path, sent);
+ } else if pt == PacketType::Initial && (self.role == Role::Client || ack_eliciting) {
+ // Packets containing Initial packets might need padding, and we want to
+ // track that padding along with the Initial packet. So defer tracking.
+ initial_sent = Some(sent);
+ needs_padding = true;
+ } else {
+ if pt == PacketType::Handshake && self.role == Role::Client {
+ needs_padding = false;
+ }
+ self.loss_recovery.on_packet_sent(path, sent);
+ }
+
+ if *space == PacketNumberSpace::Handshake
+ && self.role == Role::Server
+ && self.state == State::Confirmed
+ {
+ // We could discard handshake keys in set_state,
+ // but wait until after sending an ACK.
+ self.discard_keys(PacketNumberSpace::Handshake, now);
+ }
+ }
+
+ if encoder.is_empty() {
+ qinfo!("TX blocked, profile={:?} ", profile);
+ Ok(SendOption::No(profile.paced()))
+ } else {
+ // Perform additional padding for Initial packets as necessary.
+ let mut packets: Vec<u8> = encoder.into();
+ if let Some(mut initial) = initial_sent.take() {
+ if needs_padding {
+ qdebug!(
+ [self],
+ "pad Initial from {} to path MTU {}",
+ packets.len(),
+ mtu
+ );
+ initial.size += mtu - packets.len();
+ packets.resize(mtu, 0);
+ }
+ self.loss_recovery.on_packet_sent(path, initial);
+ }
+ path.borrow_mut().add_sent(packets.len());
+ Ok(SendOption::Yes(path.borrow().datagram(packets)))
+ }
+ }
+
+ pub fn initiate_key_update(&mut self) -> Res<()> {
+ if self.state == State::Confirmed {
+ let la = self
+ .loss_recovery
+ .largest_acknowledged_pn(PacketNumberSpace::ApplicationData);
+ qinfo!([self], "Initiating key update");
+ self.crypto.states.initiate_key_update(la)
+ } else {
+ Err(Error::KeyUpdateBlocked)
+ }
+ }
+
+ #[cfg(test)]
+ pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
+ self.crypto.states.get_epochs()
+ }
+
+ fn client_start(&mut self, now: Instant) -> Res<()> {
+ qinfo!([self], "client_start");
+ debug_assert_eq!(self.role, Role::Client);
+ qlog::client_connection_started(&mut self.qlog, &self.paths.primary());
+ qlog::client_version_information_initiated(&mut self.qlog, self.conn_params.get_versions());
+
+ self.handshake(now, self.version, PacketNumberSpace::Initial, None)?;
+ self.set_state(State::WaitInitial);
+ self.zero_rtt_state = if self.crypto.enable_0rtt(self.version, self.role)? {
+ qdebug!([self], "Enabled 0-RTT");
+ ZeroRttState::Sending
+ } else {
+ ZeroRttState::Init
+ };
+ Ok(())
+ }
+
+ fn get_closing_period_time(&self, now: Instant) -> Instant {
+ // Spec says close time should be at least PTO times 3.
+ now + (self.pto() * 3)
+ }
+
+ /// Close the connection.
+ pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) {
+ let error = ConnectionError::Application(app_error);
+ let timeout = self.get_closing_period_time(now);
+ if let Some(path) = self.paths.primary_fallible() {
+ self.state_signaling.close(path, error.clone(), 0, msg);
+ self.set_state(State::Closing { error, timeout });
+ } else {
+ self.set_state(State::Closed(error));
+ }
+ }
+
+ fn set_initial_limits(&mut self) {
+ self.streams.set_initial_limits();
+ let peer_timeout = self
+ .tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::IDLE_TIMEOUT);
+ if peer_timeout > 0 {
+ self.idle_timeout
+ .set_peer_timeout(Duration::from_millis(peer_timeout));
+ }
+
+ self.quic_datagrams.set_remote_datagram_size(
+ self.tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::MAX_DATAGRAM_FRAME_SIZE),
+ );
+ }
+
+ pub fn is_stream_id_allowed(&self, stream_id: StreamId) -> bool {
+ self.streams.is_stream_id_allowed(stream_id)
+ }
+
+ /// Process the final set of transport parameters.
+ fn process_tps(&mut self) -> Res<()> {
+ self.validate_cids()?;
+ self.validate_versions()?;
+ {
+ let tps = self.tps.borrow();
+ let remote = tps.remote.as_ref().unwrap();
+
+ // If the peer provided a preferred address, then we have to be a client
+ // and they have to be using a non-empty connection ID.
+ if remote.get_preferred_address().is_some()
+ && (self.role == Role::Server
+ || self.remote_initial_source_cid.as_ref().unwrap().is_empty())
+ {
+ return Err(Error::TransportParameterError);
+ }
+
+ let reset_token = if let Some(token) = remote.get_bytes(tparams::STATELESS_RESET_TOKEN)
+ {
+ <[u8; 16]>::try_from(token).unwrap()
+ } else {
+ // The other side didn't provide a stateless reset token.
+ // That's OK, they can try guessing this.
+ <[u8; 16]>::try_from(&random(16)[..]).unwrap()
+ };
+ self.paths
+ .primary()
+ .borrow_mut()
+ .set_reset_token(reset_token);
+
+ let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY));
+ let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) {
+ let min_ad = Duration::from_micros(remote.get_integer(tparams::MIN_ACK_DELAY));
+ if min_ad > max_ad {
+ return Err(Error::TransportParameterError);
+ }
+ Some(min_ad)
+ } else {
+ None
+ };
+ self.paths.primary().borrow_mut().set_ack_delay(
+ max_ad,
+ min_ad,
+ self.conn_params.get_ack_ratio(),
+ );
+
+ let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT);
+ self.cid_manager.set_limit(max_active_cids);
+ }
+ self.set_initial_limits();
+ qlog::connection_tparams_set(&mut self.qlog, &self.tps.borrow());
+ Ok(())
+ }
+
+ fn validate_cids(&mut self) -> Res<()> {
+ let tph = self.tps.borrow();
+ let remote_tps = tph.remote.as_ref().unwrap();
+
+ let tp = remote_tps.get_bytes(tparams::INITIAL_SOURCE_CONNECTION_ID);
+ if self
+ .remote_initial_source_cid
+ .as_ref()
+ .map(ConnectionId::as_cid_ref)
+ != tp.map(ConnectionIdRef::from)
+ {
+ qwarn!(
+ [self],
+ "ISCID test failed: self cid {:?} != tp cid {:?}",
+ self.remote_initial_source_cid,
+ tp.map(hex),
+ );
+ return Err(Error::ProtocolViolation);
+ }
+
+ if self.role == Role::Client {
+ let tp = remote_tps.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID);
+ if self
+ .original_destination_cid
+ .as_ref()
+ .map(ConnectionId::as_cid_ref)
+ != tp.map(ConnectionIdRef::from)
+ {
+ qwarn!(
+ [self],
+ "ODCID test failed: self cid {:?} != tp cid {:?}",
+ self.original_destination_cid,
+ tp.map(hex),
+ );
+ return Err(Error::ProtocolViolation);
+ }
+
+ let tp = remote_tps.get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID);
+ let expected = if let AddressValidationInfo::Retry {
+ retry_source_cid, ..
+ } = &self.address_validation
+ {
+ Some(retry_source_cid.as_cid_ref())
+ } else {
+ None
+ };
+ if expected != tp.map(ConnectionIdRef::from) {
+ qwarn!(
+ [self],
+ "RSCID test failed. self cid {:?} != tp cid {:?}",
+ expected,
+ tp.map(hex),
+ );
+ return Err(Error::ProtocolViolation);
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Validate the `version_negotiation` transport parameter from the peer.
+ fn validate_versions(&mut self) -> Res<()> {
+ let tph = self.tps.borrow();
+ let remote_tps = tph.remote.as_ref().unwrap();
+ // `current` and `other` are the value from the peer's transport parameters.
+ // We're checking that these match our expectations.
+ if let Some((current, other)) = remote_tps.get_versions() {
+ qtrace!(
+ [self],
+ "validate_versions: current={:x} chosen={:x} other={:x?}",
+ self.version.wire_version(),
+ current,
+ other,
+ );
+ if self.role == Role::Server {
+ // 1. A server acts on transport parameters, with validation
+ // of `current` happening in the transport parameter handler.
+ // All we need to do is confirm that the transport parameter
+ // was provided.
+ Ok(())
+ } else if self.version().wire_version() != current {
+ qinfo!([self], "validate_versions: current version mismatch");
+ Err(Error::VersionNegotiation)
+ } else if self
+ .conn_params
+ .get_versions()
+ .initial()
+ .is_compatible(self.version)
+ {
+ // 2. The current version is compatible with what we attempted.
+ // That's a compatible upgrade and that's OK.
+ Ok(())
+ } else {
+ // 3. The initial version we attempted isn't compatible. Check that
+ // the one we would have chosen is compatible with this one.
+ let mut all_versions = other.to_owned();
+ all_versions.push(current);
+ if self
+ .conn_params
+ .get_versions()
+ .preferred(&all_versions)
+ .ok_or(Error::VersionNegotiation)?
+ .is_compatible(self.version)
+ {
+ Ok(())
+ } else {
+ qinfo!([self], "validate_versions: failed");
+ Err(Error::VersionNegotiation)
+ }
+ }
+ } else if self.version != Version::Version1 && !self.version.is_draft() {
+ qinfo!([self], "validate_versions: missing extension");
+ Err(Error::VersionNegotiation)
+ } else {
+ Ok(())
+ }
+ }
+
+ fn confirm_version(&mut self, v: Version) {
+ if self.version != v {
+ qinfo!([self], "Compatible upgrade {:?} ==> {:?}", self.version, v);
+ }
+ self.crypto.confirm_version(v);
+ self.version = v;
+ }
+
+ fn compatible_upgrade(&mut self, packet_version: Version) {
+ if !matches!(self.state, State::WaitInitial | State::WaitVersion) {
+ return;
+ }
+
+ if self.role == Role::Client {
+ self.confirm_version(packet_version);
+ } else if self.tps.borrow().remote.is_some() {
+ let version = self.tps.borrow().version();
+ let dcid = self.original_destination_cid.as_ref().unwrap();
+ self.crypto.states.init_server(version, dcid);
+ self.confirm_version(version);
+ }
+ }
+
+ fn handshake(
+ &mut self,
+ now: Instant,
+ packet_version: Version,
+ space: PacketNumberSpace,
+ data: Option<&[u8]>,
+ ) -> Res<()> {
+ qtrace!([self], "Handshake space={} data={:0x?}", space, data);
+
+ let try_update = data.is_some();
+ match self.crypto.handshake(now, space, data)? {
+ HandshakeState::Authenticated(_) | HandshakeState::InProgress => (),
+ HandshakeState::AuthenticationPending => self.events.authentication_needed(),
+ HandshakeState::EchFallbackAuthenticationPending(public_name) => self
+ .events
+ .ech_fallback_authentication_needed(public_name.clone()),
+ HandshakeState::Complete(_) => {
+ if !self.state.connected() {
+ self.set_connected(now)?;
+ }
+ }
+ _ => {
+ unreachable!("Crypto state should not be new or failed after successful handshake")
+ }
+ }
+
+ // There is a chance that this could be called less often, but getting the
+ // conditions right is a little tricky, so call whenever CRYPTO data is used.
+ if try_update {
+ self.compatible_upgrade(packet_version);
+ // We have transport parameters, it's go time.
+ if self.tps.borrow().remote.is_some() {
+ self.set_initial_limits();
+ }
+ if self.crypto.install_keys(self.role)? {
+ if self.role == Role::Client {
+ // We won't acknowledge Initial packets as a result of this, but the
+ // server can rely on implicit acknowledgment.
+ self.discard_keys(PacketNumberSpace::Initial, now);
+ }
+ self.saved_datagrams.make_available(CryptoSpace::Handshake);
+ }
+ }
+
+ Ok(())
+ }
+
+ fn input_frame(
+ &mut self,
+ path: &PathRef,
+ packet_version: Version,
+ packet_type: PacketType,
+ frame: Frame,
+ now: Instant,
+ ) -> Res<()> {
+ if !frame.is_allowed(packet_type) {
+ qinfo!("frame not allowed: {:?} {:?}", frame, packet_type);
+ return Err(Error::ProtocolViolation);
+ }
+ self.stats.borrow_mut().frame_rx.all += 1;
+ let space = PacketNumberSpace::from(packet_type);
+ if frame.is_stream() {
+ return self
+ .streams
+ .input_frame(frame, &mut self.stats.borrow_mut().frame_rx);
+ }
+ match frame {
+ Frame::Padding => {
+ // Note: This counts contiguous padding as a single frame.
+ self.stats.borrow_mut().frame_rx.padding += 1;
+ }
+ Frame::Ping => {
+ // If we get a PING and there are outstanding CRYPTO frames,
+ // prepare to resend them.
+ self.stats.borrow_mut().frame_rx.ping += 1;
+ self.crypto.resend_unacked(space);
+ if space == PacketNumberSpace::ApplicationData {
+ // Send an ACK immediately if we might not otherwise do so.
+ self.acks.immediate_ack(now);
+ }
+ }
+ Frame::Ack {
+ largest_acknowledged,
+ ack_delay,
+ first_ack_range,
+ ack_ranges,
+ } => {
+ let ranges =
+ Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?;
+ self.handle_ack(space, largest_acknowledged, ranges, ack_delay, now);
+ }
+ Frame::Crypto { offset, data } => {
+ qtrace!(
+ [self],
+ "Crypto frame on space={} offset={}, data={:0x?}",
+ space,
+ offset,
+ &data
+ );
+ self.stats.borrow_mut().frame_rx.crypto += 1;
+ self.crypto.streams.inbound_frame(space, offset, data)?;
+ if self.crypto.streams.data_ready(space) {
+ let mut buf = Vec::new();
+ let read = self.crypto.streams.read_to_end(space, &mut buf);
+ qdebug!("Read {} bytes", read);
+ self.handshake(now, packet_version, space, Some(&buf))?;
+ self.create_resumption_token(now);
+ } else {
+ // If we get a useless CRYPTO frame send outstanding CRYPTO frames again.
+ self.crypto.resend_unacked(space);
+ }
+ }
+ Frame::NewToken { token } => {
+ self.stats.borrow_mut().frame_rx.new_token += 1;
+ self.new_token.save_token(token.to_vec());
+ self.create_resumption_token(now);
+ }
+ Frame::NewConnectionId {
+ sequence_number,
+ connection_id,
+ stateless_reset_token,
+ retire_prior,
+ } => {
+ self.stats.borrow_mut().frame_rx.new_connection_id += 1;
+ self.connection_ids.add_remote(ConnectionIdEntry::new(
+ sequence_number,
+ ConnectionId::from(connection_id),
+ stateless_reset_token.to_owned(),
+ ))?;
+ self.paths
+ .retire_cids(retire_prior, &mut self.connection_ids);
+ if self.connection_ids.len() >= LOCAL_ACTIVE_CID_LIMIT {
+ qinfo!([self], "received too many connection IDs");
+ return Err(Error::ConnectionIdLimitExceeded);
+ }
+ }
+ Frame::RetireConnectionId { sequence_number } => {
+ self.stats.borrow_mut().frame_rx.retire_connection_id += 1;
+ self.cid_manager.retire(sequence_number);
+ }
+ Frame::PathChallenge { data } => {
+ self.stats.borrow_mut().frame_rx.path_challenge += 1;
+ // If we were challenged, try to make the path permanent.
+ // Report an error if we don't have enough connection IDs.
+ self.ensure_permanent(path)?;
+ path.borrow_mut().challenged(data);
+ }
+ Frame::PathResponse { data } => {
+ self.stats.borrow_mut().frame_rx.path_response += 1;
+ if self.paths.path_response(data, now) {
+ // This PATH_RESPONSE enabled migration; tell loss recovery.
+ self.loss_recovery.migrate();
+ }
+ }
+ Frame::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ } => {
+ self.stats.borrow_mut().frame_rx.connection_close += 1;
+ let reason_phrase = String::from_utf8_lossy(&reason_phrase);
+ qinfo!(
+ [self],
+ "ConnectionClose received. Error code: {:?} frame type {:x} reason {}",
+ error_code,
+ frame_type,
+ reason_phrase
+ );
+ let (detail, frame_type) = if let CloseError::Application(_) = error_code {
+ // Use a transport error here because we want to send
+ // NO_ERROR in this case.
+ (
+ Error::PeerApplicationError(error_code.code()),
+ FRAME_TYPE_CONNECTION_CLOSE_APPLICATION,
+ )
+ } else {
+ (
+ Error::PeerError(error_code.code()),
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
+ )
+ };
+ let error = ConnectionError::Transport(detail);
+ self.state_signaling
+ .drain(Rc::clone(path), error.clone(), frame_type, "");
+ self.set_state(State::Draining {
+ error,
+ timeout: self.get_closing_period_time(now),
+ });
+ }
+ Frame::HandshakeDone => {
+ self.stats.borrow_mut().frame_rx.handshake_done += 1;
+ if self.role == Role::Server || !self.state.connected() {
+ return Err(Error::ProtocolViolation);
+ }
+ self.set_state(State::Confirmed);
+ self.discard_keys(PacketNumberSpace::Handshake, now);
+ self.migrate_to_preferred_address(now)?;
+ }
+ Frame::AckFrequency {
+ seqno,
+ tolerance,
+ delay,
+ ignore_order,
+ } => {
+ self.stats.borrow_mut().frame_rx.ack_frequency += 1;
+ let delay = Duration::from_micros(delay);
+ if delay < GRANULARITY {
+ return Err(Error::ProtocolViolation);
+ }
+ self.acks
+ .ack_freq(seqno, tolerance - 1, delay, ignore_order);
+ }
+ Frame::Datagram { data, .. } => {
+ self.stats.borrow_mut().frame_rx.datagram += 1;
+ self.quic_datagrams
+ .handle_datagram(data, &mut self.stats.borrow_mut())?;
+ }
+ _ => unreachable!("All other frames are for streams"),
+ };
+
+ Ok(())
+ }
+
+ /// Given a set of `SentPacket` instances, ensure that the source of the packet
+ /// is told that they are lost. This gives the frame generation code a chance
+ /// to retransmit the frame as needed.
+ fn handle_lost_packets(&mut self, lost_packets: &[SentPacket]) {
+ for lost in lost_packets {
+ for token in &lost.tokens {
+ qdebug!([self], "Lost: {:?}", token);
+ match token {
+ RecoveryToken::Ack(_) => {}
+ RecoveryToken::Crypto(ct) => self.crypto.lost(ct),
+ RecoveryToken::HandshakeDone => self.state_signaling.handshake_done(),
+ RecoveryToken::NewToken(seqno) => self.new_token.lost(*seqno),
+ RecoveryToken::NewConnectionId(ncid) => self.cid_manager.lost(ncid),
+ RecoveryToken::RetireConnectionId(seqno) => self.paths.lost_retire_cid(*seqno),
+ RecoveryToken::AckFrequency(rate) => self.paths.lost_ack_frequency(rate),
+ RecoveryToken::KeepAlive => self.idle_timeout.lost_keep_alive(),
+ RecoveryToken::Stream(stream_token) => self.streams.lost(stream_token),
+ RecoveryToken::Datagram(dgram_tracker) => {
+ self.events
+ .datagram_outcome(dgram_tracker, OutgoingDatagramOutcome::Lost);
+ self.stats.borrow_mut().datagram_tx.lost += 1;
+ }
+ }
+ }
+ }
+ }
+
+ fn decode_ack_delay(&self, v: u64) -> Duration {
+ // If we have remote transport parameters, use them.
+ // Otherwise, ack delay should be zero (because it's the handshake).
+ if let Some(r) = self.tps.borrow().remote.as_ref() {
+ let exponent = u32::try_from(r.get_integer(tparams::ACK_DELAY_EXPONENT)).unwrap();
+ Duration::from_micros(v.checked_shl(exponent).unwrap_or(u64::MAX))
+ } else {
+ Duration::new(0, 0)
+ }
+ }
+
+ fn handle_ack<R>(
+ &mut self,
+ space: PacketNumberSpace,
+ largest_acknowledged: u64,
+ ack_ranges: R,
+ ack_delay: u64,
+ now: Instant,
+ ) where
+ R: IntoIterator<Item = RangeInclusive<u64>> + Debug,
+ R::IntoIter: ExactSizeIterator,
+ {
+ qinfo!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges);
+
+ let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received(
+ &self.paths.primary(),
+ space,
+ largest_acknowledged,
+ ack_ranges,
+ self.decode_ack_delay(ack_delay),
+ now,
+ );
+ for acked in acked_packets {
+ for token in &acked.tokens {
+ match token {
+ RecoveryToken::Stream(stream_token) => self.streams.acked(stream_token),
+ RecoveryToken::Ack(at) => self.acks.acked(at),
+ RecoveryToken::Crypto(ct) => self.crypto.acked(ct),
+ RecoveryToken::NewToken(seqno) => self.new_token.acked(*seqno),
+ RecoveryToken::NewConnectionId(entry) => self.cid_manager.acked(entry),
+ RecoveryToken::RetireConnectionId(seqno) => self.paths.acked_retire_cid(*seqno),
+ RecoveryToken::AckFrequency(rate) => self.paths.acked_ack_frequency(rate),
+ RecoveryToken::KeepAlive => self.idle_timeout.ack_keep_alive(),
+ RecoveryToken::Datagram(dgram_tracker) => self
+ .events
+ .datagram_outcome(dgram_tracker, OutgoingDatagramOutcome::Acked),
+ // We only worry when these are lost
+ RecoveryToken::HandshakeDone => (),
+ }
+ }
+ }
+ self.handle_lost_packets(&lost_packets);
+ qlog::packets_lost(&mut self.qlog, &lost_packets);
+ let stats = &mut self.stats.borrow_mut().frame_rx;
+ stats.ack += 1;
+ stats.largest_acknowledged = max(stats.largest_acknowledged, largest_acknowledged);
+ }
+
+ /// When the server rejects 0-RTT we need to drop a bunch of stuff.
+ fn client_0rtt_rejected(&mut self, now: Instant) {
+ if !matches!(self.zero_rtt_state, ZeroRttState::Sending) {
+ return;
+ }
+ qdebug!([self], "0-RTT rejected");
+
+ // Tell 0-RTT packets that they were "lost".
+ let dropped = self.loss_recovery.drop_0rtt(&self.paths.primary(), now);
+ self.handle_lost_packets(&dropped);
+
+ self.streams.zero_rtt_rejected();
+
+ self.crypto.states.discard_0rtt_keys();
+ self.events.client_0rtt_rejected();
+ }
+
+ fn set_connected(&mut self, now: Instant) -> Res<()> {
+ qinfo!([self], "TLS connection complete");
+ if self.crypto.tls.info().map(SecretAgentInfo::alpn).is_none() {
+ qwarn!([self], "No ALPN. Closing connection.");
+ // 120 = no_application_protocol
+ return Err(Error::CryptoAlert(120));
+ }
+ if self.role == Role::Server {
+ // Remove the randomized client CID from the list of acceptable CIDs.
+ self.cid_manager.remove_odcid();
+ // Mark the path as validated, if it isn't already.
+ let path = self.paths.primary();
+ path.borrow_mut().set_valid(now);
+ // Generate a qlog event that the server connection started.
+ qlog::server_connection_started(&mut self.qlog, &path);
+ } else {
+ self.zero_rtt_state = if self.crypto.tls.info().unwrap().early_data_accepted() {
+ ZeroRttState::AcceptedClient
+ } else {
+ self.client_0rtt_rejected(now);
+ ZeroRttState::Rejected
+ };
+ }
+
+ // Setting application keys has to occur after 0-RTT rejection.
+ let pto = self.pto();
+ self.crypto
+ .install_application_keys(self.version, now + pto)?;
+ self.process_tps()?;
+ self.set_state(State::Connected);
+ self.create_resumption_token(now);
+ self.saved_datagrams
+ .make_available(CryptoSpace::ApplicationData);
+ self.stats.borrow_mut().resumed = self.crypto.tls.info().unwrap().resumed();
+ if self.role == Role::Server {
+ self.state_signaling.handshake_done();
+ self.set_state(State::Confirmed);
+ }
+ qinfo!([self], "Connection established");
+ Ok(())
+ }
+
+ fn set_state(&mut self, state: State) {
+ if state > self.state {
+ qinfo!([self], "State change from {:?} -> {:?}", self.state, state);
+ self.state = state.clone();
+ if self.state.closed() {
+ self.streams.clear_streams();
+ }
+ self.events.connection_state_change(state);
+ qlog::connection_state_updated(&mut self.qlog, &self.state);
+ } else if mem::discriminant(&state) != mem::discriminant(&self.state) {
+ // Only tolerate a regression in state if the new state is closing
+ // and the connection is already closed.
+ debug_assert!(matches!(
+ state,
+ State::Closing { .. } | State::Draining { .. }
+ ));
+ debug_assert!(self.state.closed());
+ }
+ }
+
+ /// Create a stream.
+ /// Returns new stream id
+ ///
+ /// # Errors
+ ///
+ /// `ConnectionState` if the connecton stat does not allow to create streams.
+ /// `StreamLimitError` if we are limiied by server's stream concurence.
+ pub fn stream_create(&mut self, st: StreamType) -> Res<StreamId> {
+ // Can't make streams while closing, otherwise rely on the stream limits.
+ match self.state {
+ State::Closing { .. } | State::Draining { .. } | State::Closed { .. } => {
+ return Err(Error::ConnectionState);
+ }
+ State::WaitInitial | State::Handshaking => {
+ if self.role == Role::Client && self.zero_rtt_state != ZeroRttState::Sending {
+ return Err(Error::ConnectionState);
+ }
+ }
+ // In all other states, trust that the stream limits are correct.
+ _ => (),
+ }
+
+ self.streams.stream_create(st)
+ }
+
+ /// Set the priority of a stream.
+ ///
+ /// # Errors
+ ///
+ /// `InvalidStreamId` the stream does not exist.
+ pub fn stream_priority(
+ &mut self,
+ stream_id: StreamId,
+ transmission: TransmissionPriority,
+ retransmission: RetransmissionPriority,
+ ) -> Res<()> {
+ self.streams
+ .get_send_stream_mut(stream_id)?
+ .set_priority(transmission, retransmission);
+ Ok(())
+ }
+
+ /// Set the SendOrder of a stream. Re-enqueues to keep the ordering correct
+ ///
+ /// # Errors
+ ///
+ /// Returns InvalidStreamId if the stream id doesn't exist
+ pub fn stream_sendorder(
+ &mut self,
+ stream_id: StreamId,
+ sendorder: Option<SendOrder>,
+ ) -> Res<()> {
+ self.streams.set_sendorder(stream_id, sendorder)
+ }
+
+ /// Set the Fairness of a stream
+ ///
+ /// # Errors
+ ///
+ /// Returns InvalidStreamId if the stream id doesn't exist
+ pub fn stream_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> {
+ self.streams.set_fairness(stream_id, fairness)
+ }
+
+ pub fn send_stream_stats(&self, stream_id: StreamId) -> Res<SendStreamStats> {
+ self.streams.get_send_stream(stream_id).map(|s| s.stats())
+ }
+
+ pub fn recv_stream_stats(&mut self, stream_id: StreamId) -> Res<RecvStreamStats> {
+ let stream = self.streams.get_recv_stream_mut(stream_id)?;
+
+ Ok(stream.stats())
+ }
+
+ /// Send data on a stream.
+ /// Returns how many bytes were successfully sent. Could be less
+ /// than total, based on receiver credit space available, etc.
+ ///
+ /// # Errors
+ ///
+ /// `InvalidStreamId` the stream does not exist,
+ /// `InvalidInput` if length of `data` is zero,
+ /// `FinalSizeError` if the stream has already been closed.
+ pub fn stream_send(&mut self, stream_id: StreamId, data: &[u8]) -> Res<usize> {
+ self.streams.get_send_stream_mut(stream_id)?.send(data)
+ }
+
+ /// Send all data or nothing on a stream. May cause DATA_BLOCKED or
+ /// STREAM_DATA_BLOCKED frames to be sent.
+ /// Returns true if data was successfully sent, otherwise false.
+ ///
+ /// # Errors
+ ///
+ /// `InvalidStreamId` the stream does not exist,
+ /// `InvalidInput` if length of `data` is zero,
+ /// `FinalSizeError` if the stream has already been closed.
+ pub fn stream_send_atomic(&mut self, stream_id: StreamId, data: &[u8]) -> Res<bool> {
+ let val = self
+ .streams
+ .get_send_stream_mut(stream_id)?
+ .send_atomic(data);
+ if let Ok(val) = val {
+ debug_assert!(
+ val == 0 || val == data.len(),
+ "Unexpected value {} when trying to send {} bytes atomically",
+ val,
+ data.len()
+ );
+ }
+ val.map(|v| v == data.len())
+ }
+
+ /// Bytes that stream_send() is guaranteed to accept for sending.
+ /// i.e. that will not be blocked by flow credits or send buffer max
+ /// capacity.
+ pub fn stream_avail_send_space(&self, stream_id: StreamId) -> Res<usize> {
+ Ok(self.streams.get_send_stream(stream_id)?.avail())
+ }
+
+ /// Close the stream. Enqueued data will be sent.
+ pub fn stream_close_send(&mut self, stream_id: StreamId) -> Res<()> {
+ self.streams.get_send_stream_mut(stream_id)?.close();
+ Ok(())
+ }
+
+ /// Abandon transmission of in-flight and future stream data.
+ pub fn stream_reset_send(&mut self, stream_id: StreamId, err: AppError) -> Res<()> {
+ self.streams.get_send_stream_mut(stream_id)?.reset(err);
+ Ok(())
+ }
+
+ /// Read buffered data from stream. bool says whether read bytes includes
+ /// the final data on stream.
+ ///
+ /// # Errors
+ ///
+ /// `InvalidStreamId` if the stream does not exist.
+ /// `NoMoreData` if data and fin bit were previously read by the application.
+ pub fn stream_recv(&mut self, stream_id: StreamId, data: &mut [u8]) -> Res<(usize, bool)> {
+ let stream = self.streams.get_recv_stream_mut(stream_id)?;
+
+ let rb = stream.read(data)?;
+ Ok(rb)
+ }
+
+ /// Application is no longer interested in this stream.
+ pub fn stream_stop_sending(&mut self, stream_id: StreamId, err: AppError) -> Res<()> {
+ let stream = self.streams.get_recv_stream_mut(stream_id)?;
+
+ stream.stop_sending(err);
+ Ok(())
+ }
+
+ /// Increases `max_stream_data` for a `stream_id`.
+ ///
+ /// # Errors
+ ///
+ /// Returns `InvalidStreamId` if a stream does not exist or the receiving
+ /// side is closed.
+ pub fn set_stream_max_data(&mut self, stream_id: StreamId, max_data: u64) -> Res<()> {
+ let stream = self.streams.get_recv_stream_mut(stream_id)?;
+
+ stream.set_stream_max_data(max_data);
+ Ok(())
+ }
+
+ /// Mark a receive stream as being important enough to keep the connection alive
+ /// (if `keep` is `true`) or no longer important (if `keep` is `false`). If any
+ /// stream is marked this way, PING frames will be used to keep the connection
+ /// alive, even when there is no activity.
+ ///
+ /// # Errors
+ ///
+ /// Returns `InvalidStreamId` if a stream does not exist or the receiving
+ /// side is closed.
+ pub fn stream_keep_alive(&mut self, stream_id: StreamId, keep: bool) -> Res<()> {
+ self.streams.keep_alive(stream_id, keep)
+ }
+
+ pub fn remote_datagram_size(&self) -> u64 {
+ self.quic_datagrams.remote_datagram_size()
+ }
+
+ /// Returns the current max size of a datagram that can fit into a packet.
+ /// The value will change over time depending on the encoded size of the
+ /// packet number, ack frames, etc.
+ ///
+ /// # Error
+ ///
+ /// The function returns `NotAvailable` if datagrams are not enabled.
+ pub fn max_datagram_size(&self) -> Res<u64> {
+ let max_dgram_size = self.quic_datagrams.remote_datagram_size();
+ if max_dgram_size == 0 {
+ return Err(Error::NotAvailable);
+ }
+ let version = self.version();
+ let Some((cspace, tx)) = self
+ .crypto
+ .states
+ .select_tx(self.version, PacketNumberSpace::ApplicationData)
+ else {
+ return Err(Error::NotAvailable);
+ };
+ let path = self.paths.primary_fallible().ok_or(Error::NotAvailable)?;
+ let mtu = path.borrow().mtu();
+ let encoder = Encoder::with_capacity(mtu);
+
+ let (_, mut builder) = Self::build_packet_header(
+ &path.borrow(),
+ cspace,
+ encoder,
+ tx,
+ &self.address_validation,
+ version,
+ false,
+ );
+ _ = Self::add_packet_number(
+ &mut builder,
+ tx,
+ self.loss_recovery
+ .largest_acknowledged_pn(PacketNumberSpace::ApplicationData),
+ );
+
+ let data_len_possible =
+ u64::try_from(mtu.saturating_sub(tx.expansion() + builder.len() + 1)).unwrap();
+ Ok(min(data_len_possible, max_dgram_size))
+ }
+
+ /// Queue a datagram for sending.
+ ///
+ /// # Error
+ ///
+ /// The function returns `TooMuchData` if the supply buffer is bigger than
+ /// the allowed remote datagram size. The funcion does not check if the
+ /// datagram can fit into a packet (i.e. MTU limit). This is checked during
+ /// creation of an actual packet and the datagram will be dropped if it does
+ /// not fit into the packet. The app is encourage to use `max_datagram_size`
+ /// to check the estimated max datagram size and to use smaller datagrams.
+ /// `max_datagram_size` is just a current estimate and will change over
+ /// time depending on the encoded size of the packet number, ack frames, etc.
+
+ pub fn send_datagram(&mut self, buf: &[u8], id: impl Into<DatagramTracking>) -> Res<()> {
+ self.quic_datagrams
+ .add_datagram(buf, id.into(), &mut self.stats.borrow_mut())
+ }
+}
+
+impl EventProvider for Connection {
+ type Event = ConnectionEvent;
+
+ /// Return true if there are outstanding events.
+ fn has_events(&self) -> bool {
+ self.events.has_events()
+ }
+
+ /// Get events that indicate state changes on the connection. This method
+ /// correctly handles cases where handling one event can obsolete
+ /// previously-queued events, or cause new events to be generated.
+ fn next_event(&mut self) -> Option<Self::Event> {
+ self.events.next_event()
+ }
+}
+
+impl ::std::fmt::Display for Connection {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{:?} ", self.role)?;
+ if let Some(cid) = self.odcid() {
+ std::fmt::Display::fmt(&cid, f)
+ } else {
+ write!(f, "...")
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/third_party/rust/neqo-transport/src/connection/params.rs b/third_party/rust/neqo-transport/src/connection/params.rs
new file mode 100644
index 0000000000..48aba4303b
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/params.rs
@@ -0,0 +1,392 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{cmp::max, convert::TryFrom, time::Duration};
+
+pub use crate::recovery::FAST_PTO_SCALE;
+use crate::{
+ connection::{ConnectionIdManager, Role, LOCAL_ACTIVE_CID_LIMIT},
+ recv_stream::RECV_BUFFER_SIZE,
+ rtt::GRANULARITY,
+ stream_id::StreamType,
+ tparams::{self, PreferredAddress, TransportParameter, TransportParametersHandler},
+ tracking::DEFAULT_ACK_DELAY,
+ version::{Version, VersionConfig},
+ CongestionControlAlgorithm, Res,
+};
+
+const LOCAL_MAX_DATA: u64 = 0x3FFF_FFFF_FFFF_FFFF; // 2^62-1
+const LOCAL_STREAM_LIMIT_BIDI: u64 = 16;
+const LOCAL_STREAM_LIMIT_UNI: u64 = 16;
+/// See `ConnectionParameters.ack_ratio` for a discussion of this value.
+pub const ACK_RATIO_SCALE: u8 = 10;
+/// By default, aim to have the peer acknowledge 4 times per round trip time.
+/// See `ConnectionParameters.ack_ratio` for more.
+const DEFAULT_ACK_RATIO: u8 = 4 * ACK_RATIO_SCALE;
+/// The local value for the idle timeout period.
+const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
+const MAX_QUEUED_DATAGRAMS_DEFAULT: usize = 10;
+
+/// What to do with preferred addresses.
+#[derive(Debug, Clone)]
+pub enum PreferredAddressConfig {
+ /// Disabled, whether for client or server.
+ Disabled,
+ /// Enabled at a client, disabled at a server.
+ Default,
+ /// Enabled at both client and server.
+ Address(PreferredAddress),
+}
+
+/// ConnectionParameters use for setting intitial value for QUIC parameters.
+/// This collects configuration like initial limits, protocol version, and
+/// congestion control algorithm.
+#[derive(Debug, Clone)]
+pub struct ConnectionParameters {
+ versions: VersionConfig,
+ cc_algorithm: CongestionControlAlgorithm,
+ /// Initial connection-level flow control limit.
+ max_data: u64,
+ /// Initial flow control limit for receiving data on bidirectional streams that the peer
+ /// creates.
+ max_stream_data_bidi_remote: u64,
+ /// Initial flow control limit for receiving data on bidirectional streams that this endpoint
+ /// creates.
+ max_stream_data_bidi_local: u64,
+ /// Initial flow control limit for receiving data on unidirectional streams that the peer
+ /// creates.
+ max_stream_data_uni: u64,
+ /// Initial limit on bidirectional streams that the peer creates.
+ max_streams_bidi: u64,
+ /// Initial limit on unidirectional streams that this endpoint creates.
+ max_streams_uni: u64,
+ /// The ACK ratio determines how many acknowledgements we will request as a
+ /// fraction of both the current congestion window (expressed in packets) and
+ /// as a fraction of the current round trip time. This value is scaled by
+ /// `ACK_RATIO_SCALE`; that is, if the goal is to have at least five
+ /// acknowledgments every round trip, set the value to `5 * ACK_RATIO_SCALE`.
+ /// Values less than `ACK_RATIO_SCALE` are clamped to `ACK_RATIO_SCALE`.
+ ack_ratio: u8,
+ /// The duration of the idle timeout for the connection.
+ idle_timeout: Duration,
+ preferred_address: PreferredAddressConfig,
+ datagram_size: u64,
+ outgoing_datagram_queue: usize,
+ incoming_datagram_queue: usize,
+ fast_pto: u8,
+ fuzzing: bool,
+ grease: bool,
+ pacing: bool,
+}
+
+impl Default for ConnectionParameters {
+ fn default() -> Self {
+ Self {
+ versions: VersionConfig::default(),
+ cc_algorithm: CongestionControlAlgorithm::NewReno,
+ max_data: LOCAL_MAX_DATA,
+ max_stream_data_bidi_remote: u64::try_from(RECV_BUFFER_SIZE).unwrap(),
+ max_stream_data_bidi_local: u64::try_from(RECV_BUFFER_SIZE).unwrap(),
+ max_stream_data_uni: u64::try_from(RECV_BUFFER_SIZE).unwrap(),
+ max_streams_bidi: LOCAL_STREAM_LIMIT_BIDI,
+ max_streams_uni: LOCAL_STREAM_LIMIT_UNI,
+ ack_ratio: DEFAULT_ACK_RATIO,
+ idle_timeout: DEFAULT_IDLE_TIMEOUT,
+ preferred_address: PreferredAddressConfig::Default,
+ datagram_size: 0,
+ outgoing_datagram_queue: MAX_QUEUED_DATAGRAMS_DEFAULT,
+ incoming_datagram_queue: MAX_QUEUED_DATAGRAMS_DEFAULT,
+ fast_pto: FAST_PTO_SCALE,
+ fuzzing: false,
+ grease: true,
+ pacing: true,
+ }
+ }
+}
+
+impl ConnectionParameters {
+ pub fn get_versions(&self) -> &VersionConfig {
+ &self.versions
+ }
+
+ pub(crate) fn get_versions_mut(&mut self) -> &mut VersionConfig {
+ &mut self.versions
+ }
+
+ /// Describe the initial version that should be attempted and all the
+ /// versions that should be enabled. This list should contain the initial
+ /// version and be in order of preference, with more preferred versions
+ /// before less preferred.
+ pub fn versions(mut self, initial: Version, all: Vec<Version>) -> Self {
+ self.versions = VersionConfig::new(initial, all);
+ self
+ }
+
+ pub fn get_cc_algorithm(&self) -> CongestionControlAlgorithm {
+ self.cc_algorithm
+ }
+
+ pub fn cc_algorithm(mut self, v: CongestionControlAlgorithm) -> Self {
+ self.cc_algorithm = v;
+ self
+ }
+
+ pub fn get_max_data(&self) -> u64 {
+ self.max_data
+ }
+
+ pub fn max_data(mut self, v: u64) -> Self {
+ self.max_data = v;
+ self
+ }
+
+ pub fn get_max_streams(&self, stream_type: StreamType) -> u64 {
+ match stream_type {
+ StreamType::BiDi => self.max_streams_bidi,
+ StreamType::UniDi => self.max_streams_uni,
+ }
+ }
+
+ /// # Panics
+ ///
+ /// If v > 2^60 (the maximum allowed by the protocol).
+ pub fn max_streams(mut self, stream_type: StreamType, v: u64) -> Self {
+ assert!(v <= (1 << 60), "max_streams is too large");
+ match stream_type {
+ StreamType::BiDi => {
+ self.max_streams_bidi = v;
+ }
+ StreamType::UniDi => {
+ self.max_streams_uni = v;
+ }
+ }
+ self
+ }
+
+ /// Get the maximum stream data that we will accept on different types of streams.
+ ///
+ /// # Panics
+ ///
+ /// If `StreamType::UniDi` and `false` are passed as that is not a valid combination.
+ pub fn get_max_stream_data(&self, stream_type: StreamType, remote: bool) -> u64 {
+ match (stream_type, remote) {
+ (StreamType::BiDi, false) => self.max_stream_data_bidi_local,
+ (StreamType::BiDi, true) => self.max_stream_data_bidi_remote,
+ (StreamType::UniDi, false) => {
+ panic!("Can't get receive limit on a stream that can only be sent.")
+ }
+ (StreamType::UniDi, true) => self.max_stream_data_uni,
+ }
+ }
+
+ /// Set the maximum stream data that we will accept on different types of streams.
+ ///
+ /// # Panics
+ ///
+ /// If `StreamType::UniDi` and `false` are passed as that is not a valid combination
+ /// or if v >= 62 (the maximum allowed by the protocol).
+ pub fn max_stream_data(mut self, stream_type: StreamType, remote: bool, v: u64) -> Self {
+ assert!(v < (1 << 62), "max stream data is too large");
+ match (stream_type, remote) {
+ (StreamType::BiDi, false) => {
+ self.max_stream_data_bidi_local = v;
+ }
+ (StreamType::BiDi, true) => {
+ self.max_stream_data_bidi_remote = v;
+ }
+ (StreamType::UniDi, false) => {
+ panic!("Can't set receive limit on a stream that can only be sent.")
+ }
+ (StreamType::UniDi, true) => {
+ self.max_stream_data_uni = v;
+ }
+ }
+ self
+ }
+
+ /// Set a preferred address (which only has an effect for a server).
+ pub fn preferred_address(mut self, preferred: PreferredAddress) -> Self {
+ self.preferred_address = PreferredAddressConfig::Address(preferred);
+ self
+ }
+
+ /// Disable the use of preferred addresses.
+ pub fn disable_preferred_address(mut self) -> Self {
+ self.preferred_address = PreferredAddressConfig::Disabled;
+ self
+ }
+
+ pub fn get_preferred_address(&self) -> &PreferredAddressConfig {
+ &self.preferred_address
+ }
+
+ pub fn ack_ratio(mut self, ack_ratio: u8) -> Self {
+ self.ack_ratio = ack_ratio;
+ self
+ }
+
+ pub fn get_ack_ratio(&self) -> u8 {
+ self.ack_ratio
+ }
+
+ /// # Panics
+ ///
+ /// If `timeout` is 2^62 milliseconds or more.
+ pub fn idle_timeout(mut self, timeout: Duration) -> Self {
+ assert!(timeout.as_millis() < (1 << 62), "idle timeout is too long");
+ self.idle_timeout = timeout;
+ self
+ }
+
+ pub fn get_idle_timeout(&self) -> Duration {
+ self.idle_timeout
+ }
+
+ pub fn get_datagram_size(&self) -> u64 {
+ self.datagram_size
+ }
+
+ pub fn datagram_size(mut self, v: u64) -> Self {
+ self.datagram_size = v;
+ self
+ }
+
+ pub fn get_outgoing_datagram_queue(&self) -> usize {
+ self.outgoing_datagram_queue
+ }
+
+ pub fn outgoing_datagram_queue(mut self, v: usize) -> Self {
+ // The max queue length must be at least 1.
+ self.outgoing_datagram_queue = max(v, 1);
+ self
+ }
+
+ pub fn get_incoming_datagram_queue(&self) -> usize {
+ self.incoming_datagram_queue
+ }
+
+ pub fn incoming_datagram_queue(mut self, v: usize) -> Self {
+ // The max queue length must be at least 1.
+ self.incoming_datagram_queue = max(v, 1);
+ self
+ }
+
+ pub fn get_fast_pto(&self) -> u8 {
+ self.fast_pto
+ }
+
+ /// Scale the PTO timer. A value of `FAST_PTO_SCALE` follows the spec, a smaller
+ /// value does not, but produces more probes with the intent of ensuring lower
+ /// latency in the event of tail loss. A value of `FAST_PTO_SCALE/4` is quite
+ /// aggressive. Smaller values (other than zero) are not rejected, but could be
+ /// very wasteful. Values greater than `FAST_PTO_SCALE` delay probes and could
+ /// reduce performance. It should not be possible to increase the PTO timer by
+ /// too much based on the range of valid values, but a maximum value of 255 will
+ /// result in very poor performance.
+ /// Scaling PTO this way does not affect when persistent congestion is declared,
+ /// but may change how many retransmissions are sent before declaring persistent
+ /// congestion.
+ ///
+ /// # Panics
+ ///
+ /// A value of 0 is invalid and will cause a panic.
+ pub fn fast_pto(mut self, scale: u8) -> Self {
+ assert_ne!(scale, 0);
+ self.fast_pto = scale;
+ self
+ }
+
+ pub fn is_fuzzing(&self) -> bool {
+ self.fuzzing
+ }
+
+ pub fn fuzzing(mut self, enable: bool) -> Self {
+ self.fuzzing = enable;
+ self
+ }
+
+ pub fn is_greasing(&self) -> bool {
+ self.grease
+ }
+
+ pub fn grease(mut self, grease: bool) -> Self {
+ self.grease = grease;
+ self
+ }
+
+ pub fn pacing_enabled(&self) -> bool {
+ self.pacing
+ }
+
+ pub fn pacing(mut self, pacing: bool) -> Self {
+ self.pacing = pacing;
+ self
+ }
+
+ pub fn create_transport_parameter(
+ &self,
+ role: Role,
+ cid_manager: &mut ConnectionIdManager,
+ ) -> Res<TransportParametersHandler> {
+ let mut tps = TransportParametersHandler::new(role, self.versions.clone());
+ // default parameters
+ tps.local.set_integer(
+ tparams::ACTIVE_CONNECTION_ID_LIMIT,
+ u64::try_from(LOCAL_ACTIVE_CID_LIMIT).unwrap(),
+ );
+ tps.local.set_empty(tparams::DISABLE_MIGRATION);
+ tps.local.set_empty(tparams::GREASE_QUIC_BIT);
+ tps.local.set_integer(
+ tparams::MAX_ACK_DELAY,
+ u64::try_from(DEFAULT_ACK_DELAY.as_millis()).unwrap(),
+ );
+ tps.local.set_integer(
+ tparams::MIN_ACK_DELAY,
+ u64::try_from(GRANULARITY.as_micros()).unwrap(),
+ );
+
+ // set configurable parameters
+ tps.local
+ .set_integer(tparams::INITIAL_MAX_DATA, self.max_data);
+ tps.local.set_integer(
+ tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
+ self.max_stream_data_bidi_local,
+ );
+ tps.local.set_integer(
+ tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
+ self.max_stream_data_bidi_remote,
+ );
+ tps.local.set_integer(
+ tparams::INITIAL_MAX_STREAM_DATA_UNI,
+ self.max_stream_data_uni,
+ );
+ tps.local
+ .set_integer(tparams::INITIAL_MAX_STREAMS_BIDI, self.max_streams_bidi);
+ tps.local
+ .set_integer(tparams::INITIAL_MAX_STREAMS_UNI, self.max_streams_uni);
+ tps.local.set_integer(
+ tparams::IDLE_TIMEOUT,
+ u64::try_from(self.idle_timeout.as_millis()).unwrap_or(0),
+ );
+ if let PreferredAddressConfig::Address(preferred) = &self.preferred_address {
+ if role == Role::Server {
+ let (cid, srt) = cid_manager.preferred_address_cid()?;
+ tps.local.set(
+ tparams::PREFERRED_ADDRESS,
+ TransportParameter::PreferredAddress {
+ v4: preferred.ipv4(),
+ v6: preferred.ipv6(),
+ cid,
+ srt,
+ },
+ );
+ }
+ }
+ tps.local
+ .set_integer(tparams::MAX_DATAGRAM_FRAME_SIZE, self.datagram_size);
+ Ok(tps)
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/saved.rs b/third_party/rust/neqo-transport/src/connection/saved.rs
new file mode 100644
index 0000000000..f5616c732a
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/saved.rs
@@ -0,0 +1,68 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{mem, time::Instant};
+
+use neqo_common::{qdebug, qinfo, Datagram};
+
+use crate::crypto::CryptoSpace;
+
+/// The number of datagrams that are saved during the handshake when
+/// keys to decrypt them are not yet available.
+const MAX_SAVED_DATAGRAMS: usize = 4;
+
+pub struct SavedDatagram {
+ /// The datagram.
+ pub d: Datagram,
+ /// The time that the datagram was received.
+ pub t: Instant,
+}
+
+#[derive(Default)]
+pub struct SavedDatagrams {
+ handshake: Vec<SavedDatagram>,
+ application_data: Vec<SavedDatagram>,
+ available: Option<CryptoSpace>,
+}
+
+impl SavedDatagrams {
+ fn store(&mut self, cspace: CryptoSpace) -> &mut Vec<SavedDatagram> {
+ match cspace {
+ CryptoSpace::Handshake => &mut self.handshake,
+ CryptoSpace::ApplicationData => &mut self.application_data,
+ _ => panic!("unexpected space"),
+ }
+ }
+
+ pub fn save(&mut self, cspace: CryptoSpace, d: Datagram, t: Instant) {
+ let store = self.store(cspace);
+
+ if store.len() < MAX_SAVED_DATAGRAMS {
+ qdebug!("saving datagram of {} bytes", d.len());
+ store.push(SavedDatagram { d, t });
+ } else {
+ qinfo!("not saving datagram of {} bytes", d.len());
+ }
+ }
+
+ pub fn make_available(&mut self, cspace: CryptoSpace) {
+ debug_assert_ne!(cspace, CryptoSpace::ZeroRtt);
+ debug_assert_ne!(cspace, CryptoSpace::Initial);
+ if !self.store(cspace).is_empty() {
+ self.available = Some(cspace);
+ }
+ }
+
+ pub fn available(&self) -> Option<CryptoSpace> {
+ self.available
+ }
+
+ pub fn take_saved(&mut self) -> Vec<SavedDatagram> {
+ self.available
+ .take()
+ .map_or_else(Vec::new, |cspace| mem::take(self.store(cspace)))
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs
new file mode 100644
index 0000000000..9afb42174f
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/state.rs
@@ -0,0 +1,281 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ cmp::{min, Ordering},
+ mem,
+ rc::Rc,
+ time::Instant,
+};
+
+use neqo_common::Encoder;
+
+use crate::{
+ frame::{
+ FrameType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION, FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
+ FRAME_TYPE_HANDSHAKE_DONE,
+ },
+ packet::PacketBuilder,
+ path::PathRef,
+ recovery::RecoveryToken,
+ ConnectionError, Error, Res,
+};
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+/// The state of the Connection.
+pub enum State {
+ /// A newly created connection.
+ Init,
+ /// Waiting for the first Initial packet.
+ WaitInitial,
+ /// Waiting to confirm which version was selected.
+ /// For a client, this is confirmed when a CRYPTO frame is received;
+ /// the version of the packet determines the version.
+ /// For a server, this is confirmed when transport parameters are
+ /// received and processed.
+ WaitVersion,
+ /// Exchanging Handshake packets.
+ Handshaking,
+ Connected,
+ Confirmed,
+ Closing {
+ error: ConnectionError,
+ timeout: Instant,
+ },
+ Draining {
+ error: ConnectionError,
+ timeout: Instant,
+ },
+ Closed(ConnectionError),
+}
+
+impl State {
+ #[must_use]
+ pub fn connected(&self) -> bool {
+ matches!(self, Self::Connected | Self::Confirmed)
+ }
+
+ #[must_use]
+ pub fn closed(&self) -> bool {
+ matches!(
+ self,
+ Self::Closing { .. } | Self::Draining { .. } | Self::Closed(_)
+ )
+ }
+
+ pub fn error(&self) -> Option<&ConnectionError> {
+ if let Self::Closing { error, .. } | Self::Draining { error, .. } | Self::Closed(error) =
+ self
+ {
+ Some(error)
+ } else {
+ None
+ }
+ }
+}
+
+// Implement `PartialOrd` so that we can enforce monotonic state progression.
+impl PartialOrd for State {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl Ord for State {
+ fn cmp(&self, other: &Self) -> Ordering {
+ if mem::discriminant(self) == mem::discriminant(other) {
+ return Ordering::Equal;
+ }
+ #[allow(clippy::match_same_arms)] // Lint bug: rust-lang/rust-clippy#860
+ match (self, other) {
+ (Self::Init, _) => Ordering::Less,
+ (_, Self::Init) => Ordering::Greater,
+ (Self::WaitInitial, _) => Ordering::Less,
+ (_, Self::WaitInitial) => Ordering::Greater,
+ (Self::WaitVersion, _) => Ordering::Less,
+ (_, Self::WaitVersion) => Ordering::Greater,
+ (Self::Handshaking, _) => Ordering::Less,
+ (_, Self::Handshaking) => Ordering::Greater,
+ (Self::Connected, _) => Ordering::Less,
+ (_, Self::Connected) => Ordering::Greater,
+ (Self::Confirmed, _) => Ordering::Less,
+ (_, Self::Confirmed) => Ordering::Greater,
+ (Self::Closing { .. }, _) => Ordering::Less,
+ (_, Self::Closing { .. }) => Ordering::Greater,
+ (Self::Draining { .. }, _) => Ordering::Less,
+ (_, Self::Draining { .. }) => Ordering::Greater,
+ (Self::Closed(_), _) => unreachable!(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct ClosingFrame {
+ path: PathRef,
+ error: ConnectionError,
+ frame_type: FrameType,
+ reason_phrase: Vec<u8>,
+}
+
+impl ClosingFrame {
+ fn new(
+ path: PathRef,
+ error: ConnectionError,
+ frame_type: FrameType,
+ message: impl AsRef<str>,
+ ) -> Self {
+ let reason_phrase = message.as_ref().as_bytes().to_vec();
+ Self {
+ path,
+ error,
+ frame_type,
+ reason_phrase,
+ }
+ }
+
+ pub fn path(&self) -> &PathRef {
+ &self.path
+ }
+
+ pub fn sanitize(&self) -> Option<Self> {
+ if let ConnectionError::Application(_) = self.error {
+ // The default CONNECTION_CLOSE frame that is sent when an application
+ // error code needs to be sent in an Initial or Handshake packet.
+ Some(Self {
+ path: Rc::clone(&self.path),
+ error: ConnectionError::Transport(Error::ApplicationError),
+ frame_type: 0,
+ reason_phrase: Vec::new(),
+ })
+ } else {
+ None
+ }
+ }
+
+ pub fn write_frame(&self, builder: &mut PacketBuilder) {
+ // Allow 8 bytes for the reason phrase to ensure that if it needs to be
+ // truncated there is still at least a few bytes of the value.
+ if builder.remaining() < 1 + 8 + 8 + 2 + 8 {
+ return;
+ }
+ match &self.error {
+ ConnectionError::Transport(e) => {
+ builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT);
+ builder.encode_varint(e.code());
+ builder.encode_varint(self.frame_type);
+ }
+ ConnectionError::Application(code) => {
+ builder.encode_varint(FRAME_TYPE_CONNECTION_CLOSE_APPLICATION);
+ builder.encode_varint(*code);
+ }
+ }
+ // Truncate the reason phrase if it doesn't fit. As we send this frame in
+ // multiple packet number spaces, limit the overall size to 256.
+ let available = min(256, builder.remaining());
+ let reason = if available < Encoder::vvec_len(self.reason_phrase.len()) {
+ &self.reason_phrase[..available - 2]
+ } else {
+ &self.reason_phrase
+ };
+ builder.encode_vvec(reason);
+ }
+}
+
+/// `StateSignaling` manages whether we need to send HANDSHAKE_DONE and CONNECTION_CLOSE.
+/// Valid state transitions are:
+/// * Idle -> HandshakeDone: at the server when the handshake completes
+/// * HandshakeDone -> Idle: when a HANDSHAKE_DONE frame is sent
+/// * Idle/HandshakeDone -> Closing/Draining: when closing or draining
+/// * Closing/Draining -> CloseSent: after sending CONNECTION_CLOSE
+/// * CloseSent -> Closing: any time a new CONNECTION_CLOSE is needed
+/// * -> Reset: from any state in case of a stateless reset
+#[derive(Debug, Clone)]
+pub enum StateSignaling {
+ Idle,
+ HandshakeDone,
+ /// These states save the frame that needs to be sent.
+ Closing(ClosingFrame),
+ Draining(ClosingFrame),
+ /// This state saves the frame that might need to be sent again.
+ /// If it is `None`, then we are draining and don't send.
+ CloseSent(Option<ClosingFrame>),
+ Reset,
+}
+
+impl StateSignaling {
+ pub fn handshake_done(&mut self) {
+ if !matches!(self, Self::Idle) {
+ debug_assert!(false, "StateSignaling must be in Idle state.");
+ return;
+ }
+ *self = Self::HandshakeDone;
+ }
+
+ pub fn write_done(&mut self, builder: &mut PacketBuilder) -> Res<Option<RecoveryToken>> {
+ if matches!(self, Self::HandshakeDone) && builder.remaining() >= 1 {
+ *self = Self::Idle;
+ builder.encode_varint(FRAME_TYPE_HANDSHAKE_DONE);
+ Ok(Some(RecoveryToken::HandshakeDone))
+ } else {
+ Ok(None)
+ }
+ }
+
+ pub fn close(
+ &mut self,
+ path: PathRef,
+ error: ConnectionError,
+ frame_type: FrameType,
+ message: impl AsRef<str>,
+ ) {
+ if !matches!(self, Self::Reset) {
+ *self = Self::Closing(ClosingFrame::new(path, error, frame_type, message));
+ }
+ }
+
+ pub fn drain(
+ &mut self,
+ path: PathRef,
+ error: ConnectionError,
+ frame_type: FrameType,
+ message: impl AsRef<str>,
+ ) {
+ if !matches!(self, Self::Reset) {
+ *self = Self::Draining(ClosingFrame::new(path, error, frame_type, message));
+ }
+ }
+
+ /// If a close is pending, take a frame.
+ pub fn close_frame(&mut self) -> Option<ClosingFrame> {
+ match self {
+ Self::Closing(frame) => {
+ // When we are closing, we might need to send the close frame again.
+ let res = Some(frame.clone());
+ *self = Self::CloseSent(Some(frame.clone()));
+ res
+ }
+ Self::Draining(frame) => {
+ // When we are draining, just send once.
+ let res = Some(frame.clone());
+ *self = Self::CloseSent(None);
+ res
+ }
+ _ => None,
+ }
+ }
+
+ /// If a close can be sent again, prepare to send it again.
+ pub fn send_close(&mut self) {
+ if let Self::CloseSent(Some(frame)) = self {
+ *self = Self::Closing(frame.clone());
+ }
+ }
+
+ /// We just got a stateless reset. Terminate.
+ pub fn reset(&mut self) {
+ *self = Self::Reset;
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/test_internal.rs b/third_party/rust/neqo-transport/src/connection/test_internal.rs
new file mode 100644
index 0000000000..353c38e526
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/test_internal.rs
@@ -0,0 +1,13 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Some access to internal connection stuff for testing purposes.
+
+use crate::packet::PacketBuilder;
+
+pub trait FrameWriter {
+ fn write_frames(&mut self, builder: &mut PacketBuilder);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/ackrate.rs b/third_party/rust/neqo-transport/src/connection/tests/ackrate.rs
new file mode 100644
index 0000000000..1b83d42acd
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/ackrate.rs
@@ -0,0 +1,194 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{mem, time::Duration};
+
+use test_fixture::{addr_v4, assertions};
+
+use super::{
+ super::{ConnectionParameters, ACK_RATIO_SCALE},
+ ack_bytes, connect_rtt_idle, default_client, default_server, fill_cwnd, increase_cwnd,
+ induce_persistent_congestion, new_client, new_server, send_something, DEFAULT_RTT,
+};
+use crate::stream_id::StreamType;
+
+/// With the default RTT here (100ms) and default ratio (4), endpoints won't send
+/// `ACK_FREQUENCY` as the ACK delay isn't different enough from the default.
+#[test]
+fn ack_rate_default() {
+ let mut client = default_client();
+ let mut server = default_server();
+ _ = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ assert_eq!(client.stats().frame_tx.ack_frequency, 0);
+ assert_eq!(server.stats().frame_tx.ack_frequency, 0);
+}
+
+/// When the congestion window increases, the rate doesn't change.
+#[test]
+fn ack_rate_slow_start() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Increase the congestion window a few times.
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+ _ = increase_cwnd(&mut client, &mut server, stream, now);
+
+ // The client should not have sent an ACK_FREQUENCY frame, even
+ // though the value would have updated.
+ assert_eq!(client.stats().frame_tx.ack_frequency, 0);
+ assert_eq!(server.stats().frame_rx.ack_frequency, 0);
+}
+
+/// When the congestion window decreases, a frame is sent.
+#[test]
+fn ack_rate_exit_slow_start() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Increase the congestion window a few times, enough that after a loss,
+ // there are enough packets in the window to increase the packet
+ // count in ACK_FREQUENCY frames.
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+
+ // Now fill the congestion window and drop the first packet.
+ let (mut pkts, mut now) = fill_cwnd(&mut client, stream, now);
+ pkts.remove(0);
+
+ // After acknowledging the other packets the client will notice the loss.
+ now += DEFAULT_RTT / 2;
+ let ack = ack_bytes(&mut server, stream, pkts, now);
+
+ // Receiving the ACK will cause the client to reduce its congestion window
+ // and to send ACK_FREQUENCY.
+ now += DEFAULT_RTT / 2;
+ assert_eq!(client.stats().frame_tx.ack_frequency, 0);
+ let af = client.process(Some(&ack), now).dgram();
+ assert!(af.is_some());
+ assert_eq!(client.stats().frame_tx.ack_frequency, 1);
+}
+
+/// When the congestion window collapses, `ACK_FREQUENCY` is updated.
+#[test]
+fn ack_rate_persistent_congestion() {
+ // Use a configuration that results in the value being set after exiting
+ // the handshake.
+ const RTT: Duration = Duration::from_millis(3);
+ let mut client = new_client(ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE));
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, RTT);
+
+ // The client should have sent a frame.
+ assert_eq!(client.stats().frame_tx.ack_frequency, 1);
+
+ // Now crash the congestion window.
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ let (dgrams, mut now) = fill_cwnd(&mut client, stream, now);
+ now += RTT / 2;
+ mem::drop(ack_bytes(&mut server, stream, dgrams, now));
+
+ let now = induce_persistent_congestion(&mut client, &mut server, stream, now);
+
+ // The client sends a second ACK_FREQUENCY frame with an increased rate.
+ let af = client.process_output(now).dgram();
+ assert!(af.is_some());
+ assert_eq!(client.stats().frame_tx.ack_frequency, 2);
+}
+
+/// Validate that the configuration works for the client.
+#[test]
+fn ack_rate_client_one_rtt() {
+ // This has to be chosen so that the resulting ACK delay is between 1ms and 50ms.
+ // We also have to avoid values between 20..30ms (approximately). The default
+ // maximum ACK delay is 25ms and an ACK_FREQUENCY frame won't be sent when the
+ // change to the maximum ACK delay is too small.
+ const RTT: Duration = Duration::from_millis(3);
+ let mut client = new_client(ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE));
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, RTT);
+
+ // A single packet from the client will cause the server to engage its delayed
+ // acknowledgment timer, which should now be equal to RTT.
+ // The first packet will elicit an immediate ACK however, so do this twice.
+ let d = send_something(&mut client, now);
+ now += RTT / 2;
+ let ack = server.process(Some(&d), now).dgram();
+ assert!(ack.is_some());
+ let d = send_something(&mut client, now);
+ now += RTT / 2;
+ let delay = server.process(Some(&d), now).callback();
+ assert_eq!(delay, RTT);
+
+ assert_eq!(client.stats().frame_tx.ack_frequency, 1);
+}
+
+/// Validate that the configuration works for the server.
+#[test]
+fn ack_rate_server_half_rtt() {
+ const RTT: Duration = Duration::from_millis(10);
+ let mut client = default_client();
+ let mut server = new_server(ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE * 2));
+ let mut now = connect_rtt_idle(&mut client, &mut server, RTT);
+
+ // The server now sends something.
+ let d = send_something(&mut server, now);
+ now += RTT / 2;
+ // The client now will acknowledge immediately because it has been more than
+ // an RTT since it last sent an acknowledgment.
+ let ack = client.process(Some(&d), now);
+ assert!(ack.as_dgram_ref().is_some());
+ let d = send_something(&mut server, now);
+ now += RTT / 2;
+ let delay = client.process(Some(&d), now).callback();
+ assert_eq!(delay, RTT / 2);
+
+ assert_eq!(server.stats().frame_tx.ack_frequency, 1);
+}
+
+/// ACK delay calculations are path-specific,
+/// so check that they can be sent on new paths.
+#[test]
+fn migrate_ack_delay() {
+ // Have the client send ACK_FREQUENCY frames at a normal-ish rate.
+ let mut client = new_client(ConnectionParameters::default().ack_ratio(ACK_RATIO_SCALE));
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), true, now)
+ .unwrap();
+
+ let client1 = send_something(&mut client, now);
+ assertions::assert_v4_path(&client1, true); // Contains PATH_CHALLENGE.
+ let client2 = send_something(&mut client, now);
+ assertions::assert_v4_path(&client2, false); // Doesn't. Is dropped.
+ now += DEFAULT_RTT / 2;
+ server.process_input(&client1, now);
+
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+ let now = increase_cwnd(&mut client, &mut server, stream, now);
+
+ // Now lose a packet and force the client to update
+ let (mut pkts, mut now) = fill_cwnd(&mut client, stream, now);
+ pkts.remove(0);
+ now += DEFAULT_RTT / 2;
+ let ack = ack_bytes(&mut server, stream, pkts, now);
+
+ // After noticing this new loss, the client sends ACK_FREQUENCY.
+ // It has sent a few before (as we dropped `client2`), so ignore those.
+ let ad_before = client.stats().frame_tx.ack_frequency;
+ let af = client.process(Some(&ack), now).dgram();
+ assert!(af.is_some());
+ assert_eq!(client.stats().frame_tx.ack_frequency, ad_before + 1);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs
new file mode 100644
index 0000000000..b3467ea67c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs
@@ -0,0 +1,429 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{convert::TryFrom, mem, time::Duration};
+
+use neqo_common::{qdebug, qinfo, Datagram};
+
+use super::{
+ super::Output, ack_bytes, assert_full_cwnd, connect_rtt_idle, cwnd, cwnd_avail, cwnd_packets,
+ default_client, default_server, fill_cwnd, induce_persistent_congestion, send_something,
+ CLIENT_HANDSHAKE_1RTT_PACKETS, DEFAULT_RTT, POST_HANDSHAKE_CWND,
+};
+use crate::{
+ cc::MAX_DATAGRAM_SIZE,
+ packet::PacketNumber,
+ recovery::{ACK_ONLY_SIZE_LIMIT, PACKET_THRESHOLD},
+ sender::PACING_BURST_SIZE,
+ stream_id::StreamType,
+ tracking::DEFAULT_ACK_PACKET_TOLERANCE,
+};
+
+#[test]
+/// Verify initial CWND is honored.
+fn cc_slow_start() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Try to send a lot of data
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ let (c_tx_dgrams, _) = fill_cwnd(&mut client, stream_id, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+ assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT);
+}
+
+#[test]
+/// Verify that CC moves to cong avoidance when a packet is marked lost.
+fn cc_slow_start_to_cong_avoidance_recovery_period() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(stream_id, 0);
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream_id, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+ // Predict the packet number of the last packet sent.
+ // We have already sent packets in `connect_rtt_idle`,
+ // so include a fudge factor.
+ let flight1_largest =
+ PacketNumber::try_from(c_tx_dgrams.len() + CLIENT_HANDSHAKE_1RTT_PACKETS).unwrap();
+
+ // Server: Receive and generate ack
+ now += DEFAULT_RTT / 2;
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
+ assert_eq!(
+ server.stats().frame_tx.largest_acknowledged,
+ flight1_largest
+ );
+
+ // Client: Process ack
+ now += DEFAULT_RTT / 2;
+ client.process_input(&s_ack, now);
+ assert_eq!(
+ client.stats().frame_rx.largest_acknowledged,
+ flight1_largest
+ );
+
+ // Client: send more
+ let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream_id, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2);
+ let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap();
+
+ // Server: Receive and generate ack again, but drop first packet
+ now += DEFAULT_RTT / 2;
+ c_tx_dgrams.remove(0);
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
+ assert_eq!(
+ server.stats().frame_tx.largest_acknowledged,
+ flight2_largest
+ );
+
+ // Client: Process ack
+ now += DEFAULT_RTT / 2;
+ client.process_input(&s_ack, now);
+ assert_eq!(
+ client.stats().frame_rx.largest_acknowledged,
+ flight2_largest
+ );
+}
+
+#[test]
+/// Verify that CC stays in recovery period when packet sent before start of
+/// recovery period is acked.
+fn cc_cong_avoidance_recovery_period_unchanged() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(stream_id, 0);
+
+ // Buffer up lot of data and generate packets
+ let (mut c_tx_dgrams, now) = fill_cwnd(&mut client, stream_id, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Drop 0th packet. When acked, this should put client into CARP.
+ c_tx_dgrams.remove(0);
+
+ let c_tx_dgrams2 = c_tx_dgrams.split_off(5);
+
+ // Server: Receive and generate ack
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
+ client.process_input(&s_ack, now);
+
+ let cwnd1 = cwnd(&client);
+
+ // Generate ACK for more received packets
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams2, now);
+
+ // ACK more packets but they were sent before end of recovery period
+ client.process_input(&s_ack, now);
+
+ // cwnd should not have changed since ACKed packets were sent before
+ // recovery period expired
+ let cwnd2 = cwnd(&client);
+ assert_eq!(cwnd1, cwnd2);
+}
+
+#[test]
+/// Ensure that a single packet is sent after entering recovery, even
+/// when that exceeds the available congestion window.
+fn single_packet_on_recovery() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Drop a few packets, up to the reordering threshold.
+ for _ in 0..PACKET_THRESHOLD {
+ let _dropped = send_something(&mut client, now);
+ }
+ let delivered = send_something(&mut client, now);
+
+ // Now fill the congestion window.
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(stream_id, 0);
+ let (_, now) = fill_cwnd(&mut client, stream_id, now);
+ assert!(cwnd_avail(&client) < ACK_ONLY_SIZE_LIMIT);
+
+ // Acknowledge just one packet and cause one packet to be declared lost.
+ // The length is the amount of credit the client should have.
+ let ack = server.process(Some(&delivered), now).dgram();
+ assert!(ack.is_some());
+
+ // The client should see the loss and enter recovery.
+ // As there are many outstanding packets, there should be no available cwnd.
+ client.process_input(&ack.unwrap(), now);
+ assert_eq!(cwnd_avail(&client), 0);
+
+ // The client should send one packet, ignoring the cwnd.
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+}
+
+#[test]
+/// Verify that CC moves out of recovery period when packet sent after start
+/// of recovery period is acked.
+fn cc_cong_avoidance_recovery_period_to_cong_avoidance() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(stream_id, 0);
+
+ // Buffer up lot of data and generate packets
+ let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream_id, now);
+
+ // Drop 0th packet. When acked, this should put client into CARP.
+ c_tx_dgrams.remove(0);
+
+ // Server: Receive and generate ack
+ now += DEFAULT_RTT / 2;
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
+
+ // Client: Process ack
+ now += DEFAULT_RTT / 2;
+ client.process_input(&s_ack, now);
+
+ // Should be in CARP now.
+ now += DEFAULT_RTT / 2;
+ qinfo!("moving to congestion avoidance {}", cwnd(&client));
+
+ // Now make sure that we increase congestion window according to the
+ // accurate byte counting version of congestion avoidance.
+ // Check over several increases to be sure.
+ let mut expected_cwnd = cwnd(&client);
+ // Fill cwnd.
+ let (mut c_tx_dgrams, next_now) = fill_cwnd(&mut client, stream_id, now);
+ now = next_now;
+ for i in 0..5 {
+ qinfo!("iteration {}", i);
+
+ let c_tx_size: usize = c_tx_dgrams.iter().map(|d| d.len()).sum();
+ qinfo!(
+ "client sending {} bytes into cwnd of {}",
+ c_tx_size,
+ cwnd(&client)
+ );
+ assert_eq!(c_tx_size, expected_cwnd);
+
+ // As acks arrive we will continue filling cwnd and save all packets
+ // from this cycle will be stored in next_c_tx_dgrams.
+ let mut next_c_tx_dgrams: Vec<Datagram> = Vec::new();
+
+ // Until we process all the packets, the congestion window remains the same.
+ // Note that we need the client to process ACK frames in stages, so split the
+ // datagrams into two, ensuring that we allow for an ACK for each batch.
+ let most = c_tx_dgrams.len() - usize::try_from(DEFAULT_ACK_PACKET_TOLERANCE).unwrap() - 1;
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams.drain(..most), now);
+ assert_eq!(cwnd(&client), expected_cwnd);
+ client.process_input(&s_ack, now);
+ // make sure to fill cwnd again.
+ let (mut new_pkts, next_now) = fill_cwnd(&mut client, stream_id, now);
+ now = next_now;
+ next_c_tx_dgrams.append(&mut new_pkts);
+
+ let s_ack = ack_bytes(&mut server, stream_id, c_tx_dgrams, now);
+ assert_eq!(cwnd(&client), expected_cwnd);
+ client.process_input(&s_ack, now);
+ // make sure to fill cwnd again.
+ let (mut new_pkts, next_now) = fill_cwnd(&mut client, stream_id, now);
+ now = next_now;
+ next_c_tx_dgrams.append(&mut new_pkts);
+
+ expected_cwnd += MAX_DATAGRAM_SIZE;
+ assert_eq!(cwnd(&client), expected_cwnd);
+ c_tx_dgrams = next_c_tx_dgrams;
+ }
+}
+
+#[test]
+/// Verify transition to persistent congestion state if conditions are met.
+fn cc_slow_start_to_persistent_congestion_no_acks() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Server: Receive and generate ack
+ now += DEFAULT_RTT / 2;
+ mem::drop(ack_bytes(&mut server, stream, c_tx_dgrams, now));
+
+ // ACK lost.
+ induce_persistent_congestion(&mut client, &mut server, stream, now);
+}
+
+#[test]
+/// Verify transition to persistent congestion state if conditions are met.
+fn cc_slow_start_to_persistent_congestion_some_acks() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Server: Receive and generate ack
+ now += Duration::from_millis(100);
+ let s_ack = ack_bytes(&mut server, stream, c_tx_dgrams, now);
+
+ now += Duration::from_millis(100);
+ client.process_input(&s_ack, now);
+
+ // send bytes that will be lost
+ let (_, next_now) = fill_cwnd(&mut client, stream, now);
+ now = next_now + Duration::from_millis(100);
+
+ induce_persistent_congestion(&mut client, &mut server, stream, now);
+}
+
+#[test]
+/// Verify persistent congestion moves to slow start after recovery period
+/// ends.
+fn cc_persistent_congestion_to_slow_start() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, stream, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Server: Receive and generate ack
+ now += Duration::from_millis(10);
+ mem::drop(ack_bytes(&mut server, stream, c_tx_dgrams, now));
+
+ // ACK lost.
+
+ now = induce_persistent_congestion(&mut client, &mut server, stream, now);
+
+ // New part of test starts here
+
+ now += Duration::from_millis(10);
+
+ // Send packets from after start of CARP
+ let (c_tx_dgrams, next_now) = fill_cwnd(&mut client, stream, now);
+ assert_eq!(c_tx_dgrams.len(), 2);
+
+ // Server: Receive and generate ack
+ now = next_now + Duration::from_millis(100);
+ let s_ack = ack_bytes(&mut server, stream, c_tx_dgrams, now);
+
+ // No longer in CARP. (pkts acked from after start of CARP)
+ // Should be in slow start now.
+ client.process_input(&s_ack, now);
+
+ // ACKing 2 packets should let client send 4.
+ let (c_tx_dgrams, _) = fill_cwnd(&mut client, stream, now);
+ assert_eq!(c_tx_dgrams.len(), 4);
+}
+
+#[test]
+fn ack_are_not_cc() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create a stream
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(stream, 0);
+
+ // Buffer up lot of data and generate packets, so that cc window is filled.
+ let (c_tx_dgrams, now) = fill_cwnd(&mut client, stream, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // The server hasn't received any of these packets yet, the server
+ // won't ACK, but if it sends an ack-eliciting packet instead.
+ qdebug!([server], "Sending ack-eliciting");
+ let other_stream = server.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(other_stream, 1);
+ server.stream_send(other_stream, b"dropped").unwrap();
+ let dropped_packet = server.process(None, now).dgram();
+ assert!(dropped_packet.is_some()); // Now drop this one.
+
+ // Now the server sends a packet that will force an ACK,
+ // because the client will detect a gap.
+ server.stream_send(other_stream, b"sent").unwrap();
+ let ack_eliciting_packet = server.process(None, now).dgram();
+ assert!(ack_eliciting_packet.is_some());
+
+ // The client can ack the server packet even if cc windows is full.
+ qdebug!([client], "Process ack-eliciting");
+ let ack_pkt = client.process(ack_eliciting_packet.as_ref(), now).dgram();
+ assert!(ack_pkt.is_some());
+ qdebug!([server], "Handle ACK");
+ let prev_ack_count = server.stats().frame_rx.ack;
+ server.process_input(&ack_pkt.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.ack, prev_ack_count + 1);
+}
+
+#[test]
+fn pace() {
+ const DATA: &[u8] = &[0xcc; 4_096];
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Now fill up the pipe and watch it trickle out.
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+ loop {
+ let written = client.stream_send(stream, DATA).unwrap();
+ if written < DATA.len() {
+ break;
+ }
+ }
+ let mut count = 0;
+ // We should get a burst at first.
+ // The first packet is not subject to pacing as there are no bytes in flight.
+ // After that we allow the burst to continue up to a number of packets (2).
+ for _ in 0..=PACING_BURST_SIZE {
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+ count += 1;
+ }
+ let gap = client.process_output(now).callback();
+ assert_ne!(gap, Duration::new(0, 0));
+ for _ in (1 + PACING_BURST_SIZE)..cwnd_packets(POST_HANDSHAKE_CWND) {
+ match client.process_output(now) {
+ Output::Callback(t) => assert_eq!(t, gap),
+ Output::Datagram(_) => {
+ // The last packet might not be paced.
+ count += 1;
+ break;
+ }
+ Output::None => panic!(),
+ }
+ now += gap;
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+ count += 1;
+ }
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_none());
+ assert_eq!(count, cwnd_packets(POST_HANDSHAKE_CWND));
+ let fin = client.process_output(now).callback();
+ assert_ne!(fin, Duration::new(0, 0));
+ assert_ne!(fin, gap);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs
new file mode 100644
index 0000000000..f45e77e549
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs
@@ -0,0 +1,210 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::time::Duration;
+
+use test_fixture::{self, datagram, now};
+
+use super::{
+ super::{Connection, Output, State},
+ connect, connect_force_idle, default_client, default_server, send_something,
+};
+use crate::{
+ tparams::{self, TransportParameter},
+ AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE,
+};
+
+fn assert_draining(c: &Connection, expected: &Error) {
+ assert!(c.state().closed());
+ if let State::Draining {
+ error: ConnectionError::Transport(error),
+ ..
+ } = c.state()
+ {
+ assert_eq!(error, expected);
+ } else {
+ panic!();
+ }
+}
+
+#[test]
+fn connection_close() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let now = now();
+
+ client.close(now, 42, "");
+
+ let out = client.process(None, now);
+
+ server.process_input(&out.dgram().unwrap(), now);
+ assert_draining(&server, &Error::PeerApplicationError(42));
+}
+
+#[test]
+fn connection_close_with_long_reason_string() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let now = now();
+ // Create a long string and use it as the close reason.
+ let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap();
+ client.close(now, 42, long_reason);
+
+ let out = client.process(None, now);
+
+ server.process_input(&out.dgram().unwrap(), now);
+ assert_draining(&server, &Error::PeerApplicationError(42));
+}
+
+// During the handshake, an application close should be sanitized.
+#[test]
+fn early_application_close() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // One flight each.
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ assert!(dgram.is_some());
+
+ server.close(now(), 77, String::new());
+ assert!(server.state().closed());
+ let dgram = server.process(None, now()).dgram();
+ assert!(dgram.is_some());
+
+ client.process_input(&dgram.unwrap(), now());
+ assert_draining(&client, &Error::PeerError(ERROR_APPLICATION_CLOSE));
+}
+
+#[test]
+fn bad_tls_version() {
+ let mut client = default_client();
+ // Do a bad, bad thing.
+ client
+ .crypto
+ .tls
+ .set_option(neqo_crypto::Opt::Tls13CompatMode, true)
+ .unwrap();
+ let mut server = default_server();
+
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ assert_eq!(
+ *server.state(),
+ State::Closed(ConnectionError::Transport(Error::ProtocolViolation))
+ );
+ assert!(dgram.is_some());
+ client.process_input(&dgram.unwrap(), now());
+ assert_draining(&client, &Error::PeerError(Error::ProtocolViolation.code()));
+}
+
+/// Test the interaction between the loss recovery timer
+/// and the closing timer.
+#[test]
+fn closing_timers_interation() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let mut now = now();
+
+ // We're going to induce time-based loss recovery so that timer is set.
+ let _p1 = send_something(&mut client, now);
+ let p2 = send_something(&mut client, now);
+ let ack = server.process(Some(&p2), now).dgram();
+ assert!(ack.is_some()); // This is an ACK.
+
+ // After processing the ACK, we should be on the loss recovery timer.
+ let cb = client.process(ack.as_ref(), now).callback();
+ assert_ne!(cb, Duration::from_secs(0));
+ now += cb;
+
+ // Rather than let the timer pop, close the connection.
+ client.close(now, 0, "");
+ let client_close = client.process(None, now).dgram();
+ assert!(client_close.is_some());
+ // This should now report the end of the closing period, not a
+ // zero-duration wait driven by the (now defunct) loss recovery timer.
+ let client_close_timer = client.process(None, now).callback();
+ assert_ne!(client_close_timer, Duration::from_secs(0));
+}
+
+#[test]
+fn closing_and_draining() {
+ const APP_ERROR: AppError = 7;
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // Save a packet from the client for later.
+ let p1 = send_something(&mut client, now());
+
+ // Close the connection.
+ client.close(now(), APP_ERROR, "");
+ let client_close = client.process(None, now()).dgram();
+ assert!(client_close.is_some());
+ let client_close_timer = client.process(None, now()).callback();
+ assert_ne!(client_close_timer, Duration::from_secs(0));
+
+ // The client will spit out the same packet in response to anything it receives.
+ let p3 = send_something(&mut server, now());
+ let client_close2 = client.process(Some(&p3), now()).dgram();
+ assert_eq!(
+ client_close.as_ref().unwrap().len(),
+ client_close2.as_ref().unwrap().len()
+ );
+
+ // After this time, the client should transition to closed.
+ let end = client.process(None, now() + client_close_timer);
+ assert_eq!(end, Output::None);
+ assert_eq!(
+ *client.state(),
+ State::Closed(ConnectionError::Application(APP_ERROR))
+ );
+
+ // When the server receives the close, it too should generate CONNECTION_CLOSE.
+ let server_close = server.process(client_close.as_ref(), now()).dgram();
+ assert!(server.state().closed());
+ assert!(server_close.is_some());
+ // .. but it ignores any further close packets.
+ let server_close_timer = server.process(client_close2.as_ref(), now()).callback();
+ assert_ne!(server_close_timer, Duration::from_secs(0));
+ // Even a legitimate packet without a close in it.
+ let server_close_timer2 = server.process(Some(&p1), now()).callback();
+ assert_eq!(server_close_timer, server_close_timer2);
+
+ let end = server.process(None, now() + server_close_timer);
+ assert_eq!(end, Output::None);
+ assert_eq!(
+ *server.state(),
+ State::Closed(ConnectionError::Transport(Error::PeerApplicationError(
+ APP_ERROR
+ )))
+ );
+}
+
+/// Test that a client can handle a stateless reset correctly.
+#[test]
+fn stateless_reset_client() {
+ let mut client = default_client();
+ let mut server = default_server();
+ server
+ .set_local_tparam(
+ tparams::STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(vec![77; 16]),
+ )
+ .unwrap();
+ connect_force_idle(&mut client, &mut server);
+
+ client.process_input(&datagram(vec![77; 21]), now());
+ assert_draining(&client, &Error::StatelessReset);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/datagram.rs b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs
new file mode 100644
index 0000000000..5b7b8dc0b4
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/datagram.rs
@@ -0,0 +1,620 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{cell::RefCell, convert::TryFrom, rc::Rc};
+
+use neqo_common::event::Provider;
+use test_fixture::now;
+
+use super::{
+ assert_error, connect_force_idle, default_client, default_server, new_client, new_server,
+ AT_LEAST_PTO,
+};
+use crate::{
+ events::{ConnectionEvent, OutgoingDatagramOutcome},
+ frame::FRAME_TYPE_DATAGRAM,
+ packet::PacketBuilder,
+ quic_datagrams::MAX_QUIC_DATAGRAM,
+ send_stream::{RetransmissionPriority, TransmissionPriority},
+ Connection, ConnectionError, ConnectionParameters, Error, StreamType,
+};
+
+const DATAGRAM_LEN_MTU: u64 = 1310;
+const DATA_MTU: &[u8] = &[1; 1310];
+const DATA_BIGGER_THAN_MTU: &[u8] = &[0; 2620];
+const DATAGRAM_LEN_SMALLER_THAN_MTU: u64 = 1200;
+const DATA_SMALLER_THAN_MTU: &[u8] = &[0; 1200];
+const DATA_SMALLER_THAN_MTU_2: &[u8] = &[0; 600];
+const OUTGOING_QUEUE: usize = 2;
+
+struct InsertDatagram<'a> {
+ data: &'a [u8],
+}
+
+impl crate::connection::test_internal::FrameWriter for InsertDatagram<'_> {
+ fn write_frames(&mut self, builder: &mut PacketBuilder) {
+ builder.encode_varint(FRAME_TYPE_DATAGRAM);
+ builder.encode(self.data);
+ }
+}
+
+#[test]
+fn datagram_disabled_both() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ assert_eq!(client.max_datagram_size(), Err(Error::NotAvailable));
+ assert_eq!(server.max_datagram_size(), Err(Error::NotAvailable));
+ assert_eq!(
+ client.send_datagram(DATA_SMALLER_THAN_MTU, None),
+ Err(Error::TooMuchData)
+ );
+ assert_eq!(server.stats().frame_tx.datagram, 0);
+ assert_eq!(
+ server.send_datagram(DATA_SMALLER_THAN_MTU, None),
+ Err(Error::TooMuchData)
+ );
+ assert_eq!(server.stats().frame_tx.datagram, 0);
+}
+
+#[test]
+fn datagram_enabled_on_client() {
+ let mut client =
+ new_client(ConnectionParameters::default().datagram_size(DATAGRAM_LEN_SMALLER_THAN_MTU));
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ assert_eq!(client.max_datagram_size(), Err(Error::NotAvailable));
+ assert_eq!(
+ server.max_datagram_size(),
+ Ok(DATAGRAM_LEN_SMALLER_THAN_MTU)
+ );
+ assert_eq!(
+ client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)),
+ Err(Error::TooMuchData)
+ );
+ let dgram_sent = server.stats().frame_tx.datagram;
+ assert_eq!(server.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(()));
+ let out = server.process_output(now()).dgram().unwrap();
+ assert_eq!(server.stats().frame_tx.datagram, dgram_sent + 1);
+
+ client.process_input(&out, now());
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU
+ ));
+}
+
+#[test]
+fn datagram_enabled_on_server() {
+ let mut client = default_client();
+ let mut server =
+ new_server(ConnectionParameters::default().datagram_size(DATAGRAM_LEN_SMALLER_THAN_MTU));
+ connect_force_idle(&mut client, &mut server);
+
+ assert_eq!(
+ client.max_datagram_size(),
+ Ok(DATAGRAM_LEN_SMALLER_THAN_MTU)
+ );
+ assert_eq!(server.max_datagram_size(), Err(Error::NotAvailable));
+ assert_eq!(
+ server.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)),
+ Err(Error::TooMuchData)
+ );
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(()));
+ let out = client.process_output(now()).dgram().unwrap();
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+
+ server.process_input(&out, now());
+ assert!(matches!(
+ server.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU
+ ));
+}
+
+fn connect_datagram() -> (Connection, Connection) {
+ let mut client = new_client(
+ ConnectionParameters::default()
+ .datagram_size(MAX_QUIC_DATAGRAM)
+ .outgoing_datagram_queue(OUTGOING_QUEUE),
+ );
+ let mut server = new_server(ConnectionParameters::default().datagram_size(MAX_QUIC_DATAGRAM));
+ connect_force_idle(&mut client, &mut server);
+ (client, server)
+}
+
+#[test]
+fn mtu_limit() {
+ let (client, server) = connect_datagram();
+
+ assert_eq!(client.max_datagram_size(), Ok(DATAGRAM_LEN_MTU));
+ assert_eq!(server.max_datagram_size(), Ok(DATAGRAM_LEN_MTU));
+}
+
+#[test]
+fn limit_data_size() {
+ let (mut client, mut server) = connect_datagram();
+
+ assert!(u64::try_from(DATA_BIGGER_THAN_MTU.len()).unwrap() > DATAGRAM_LEN_MTU);
+ // Datagram can be queued because they are smaller than allowed by the peer,
+ // but they cannot be sent.
+ assert_eq!(server.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(()));
+
+ let dgram_dropped_s = server.stats().datagram_tx.dropped_too_big;
+ let dgram_sent_s = server.stats().frame_tx.datagram;
+ assert!(server.process_output(now()).dgram().is_none());
+ assert_eq!(
+ server.stats().datagram_tx.dropped_too_big,
+ dgram_dropped_s + 1
+ );
+ assert_eq!(server.stats().frame_tx.datagram, dgram_sent_s);
+ assert!(matches!(
+ server.next_event().unwrap(),
+ ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedTooBig
+ ));
+
+ // The same test for the client side.
+ assert_eq!(client.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(()));
+ let dgram_sent_c = client.stats().frame_tx.datagram;
+ assert!(client.process_output(now()).dgram().is_none());
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent_c);
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedTooBig
+ ));
+}
+
+#[test]
+fn after_dgram_dropped_continue_writing_frames() {
+ let (mut client, _) = connect_datagram();
+
+ assert!(u64::try_from(DATA_BIGGER_THAN_MTU.len()).unwrap() > DATAGRAM_LEN_MTU);
+ // Datagram can be queued because they are smaller than allowed by the peer,
+ // but they cannot be sent.
+ assert_eq!(client.send_datagram(DATA_BIGGER_THAN_MTU, Some(1)), Ok(()));
+ assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(2)), Ok(()));
+
+ let datagram_dropped = |e| {
+ matches!(
+ e,
+ ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedTooBig)
+ };
+
+ let dgram_dropped_c = client.stats().datagram_tx.dropped_too_big;
+ let dgram_sent_c = client.stats().frame_tx.datagram;
+
+ assert!(client.process_output(now()).dgram().is_some());
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent_c + 1);
+ assert_eq!(
+ client.stats().datagram_tx.dropped_too_big,
+ dgram_dropped_c + 1
+ );
+ assert!(client.events().any(datagram_dropped));
+}
+
+#[test]
+fn datagram_acked() {
+ let (mut client, mut server) = connect_datagram();
+
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(()));
+ let out = client.process_output(now()).dgram();
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+
+ let dgram_received = server.stats().frame_rx.datagram;
+ server.process_input(&out.unwrap(), now());
+ assert_eq!(server.stats().frame_rx.datagram, dgram_received + 1);
+ let now = now() + AT_LEAST_PTO;
+ // Ack should be sent
+ let ack_sent = server.stats().frame_tx.ack;
+ let out = server.process_output(now).dgram();
+ assert_eq!(server.stats().frame_tx.ack, ack_sent + 1);
+
+ assert!(matches!(
+ server.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU
+ ));
+
+ client.process_input(&out.unwrap(), now);
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::Acked
+ ));
+}
+
+fn send_packet_and_get_server_event(
+ client: &mut Connection,
+ server: &mut Connection,
+) -> ConnectionEvent {
+ let out = client.process_output(now()).dgram();
+ server.process_input(&out.unwrap(), now());
+ let mut events: Vec<_> = server
+ .events()
+ .filter_map(|evt| match evt {
+ ConnectionEvent::RecvStreamReadable { .. } | ConnectionEvent::Datagram { .. } => {
+ Some(evt)
+ }
+ _ => None,
+ })
+ .collect();
+ // We should only get one event - either RecvStreamReadable or Datagram.
+ assert_eq!(events.len(), 1);
+ events.remove(0)
+}
+
+/// Write a datagram that is big enough to fill a packet, but then see that
+/// normal priority stream data is sent first.
+#[test]
+fn datagram_after_stream_data() {
+ let (mut client, mut server) = connect_datagram();
+
+ // Write a datagram first.
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_MTU, Some(1)), Ok(()));
+
+ // Create a stream with normal priority and send some data.
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[6; 1200]).unwrap();
+
+ assert!(
+ matches!(send_packet_and_get_server_event(&mut client, &mut server), ConnectionEvent::RecvStreamReadable { stream_id: s } if s == stream_id)
+ );
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent);
+
+ if let ConnectionEvent::Datagram(data) =
+ &send_packet_and_get_server_event(&mut client, &mut server)
+ {
+ assert_eq!(data, DATA_MTU);
+ } else {
+ panic!();
+ }
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+}
+
+#[test]
+fn datagram_before_stream_data() {
+ let (mut client, mut server) = connect_datagram();
+
+ // Create a stream with low priority and send some data before datagram.
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client
+ .stream_priority(
+ stream_id,
+ TransmissionPriority::Low,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+ client.stream_send(stream_id, &[6; 1200]).unwrap();
+
+ // Write a datagram.
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_MTU, Some(1)), Ok(()));
+
+ if let ConnectionEvent::Datagram(data) =
+ &send_packet_and_get_server_event(&mut client, &mut server)
+ {
+ assert_eq!(data, DATA_MTU);
+ } else {
+ panic!();
+ }
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+
+ assert!(
+ matches!(send_packet_and_get_server_event(&mut client, &mut server), ConnectionEvent::RecvStreamReadable { stream_id: s } if s == stream_id)
+ );
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+}
+
+#[test]
+fn datagram_lost() {
+ let (mut client, _) = connect_datagram();
+
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(()));
+ let _out = client.process_output(now()).dgram(); // This packet will be lost.
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+
+ // Wait for PTO
+ let now = now() + AT_LEAST_PTO;
+ let dgram_sent2 = client.stats().frame_tx.datagram;
+ let pings_sent = client.stats().frame_tx.ping;
+ let dgram_lost = client.stats().datagram_tx.lost;
+ let out = client.process_output(now).dgram();
+ assert!(out.is_some()); // PING probing
+ // Datagram is not sent again.
+ assert_eq!(client.stats().frame_tx.ping, pings_sent + 1);
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent2);
+ assert_eq!(client.stats().datagram_tx.lost, dgram_lost + 1);
+
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::Lost
+ ));
+}
+
+#[test]
+fn datagram_sent_once() {
+ let (mut client, _) = connect_datagram();
+
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(()));
+ let _out = client.process_output(now()).dgram();
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+
+ // Call process_output again should not send any new Datagram.
+ assert!(client.process_output(now()).dgram().is_none());
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+}
+
+#[test]
+fn dgram_no_allowed() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ server.test_frame_writer = Some(Box::new(InsertDatagram { data: DATA_MTU }));
+ let out = server.process_output(now()).dgram().unwrap();
+ server.test_frame_writer = None;
+
+ client.process_input(&out, now());
+
+ assert_error(
+ &client,
+ &ConnectionError::Transport(Error::ProtocolViolation),
+ );
+}
+
+#[test]
+#[allow(clippy::assertions_on_constants)] // this is a static assert, thanks
+fn dgram_too_big() {
+ let mut client =
+ new_client(ConnectionParameters::default().datagram_size(DATAGRAM_LEN_SMALLER_THAN_MTU));
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ assert!(DATAGRAM_LEN_MTU > DATAGRAM_LEN_SMALLER_THAN_MTU);
+ server.test_frame_writer = Some(Box::new(InsertDatagram { data: DATA_MTU }));
+ let out = server.process_output(now()).dgram().unwrap();
+ server.test_frame_writer = None;
+
+ client.process_input(&out, now());
+
+ assert_error(
+ &client,
+ &ConnectionError::Transport(Error::ProtocolViolation),
+ );
+}
+
+#[test]
+fn outgoing_datagram_queue_full() {
+ let (mut client, mut server) = connect_datagram();
+
+ let dgram_sent = client.stats().frame_tx.datagram;
+ assert_eq!(client.send_datagram(DATA_SMALLER_THAN_MTU, Some(1)), Ok(()));
+ assert_eq!(
+ client.send_datagram(DATA_SMALLER_THAN_MTU_2, Some(2)),
+ Ok(())
+ );
+
+ // The outgoing datagram queue limit is 2, therefore the datagram with id 1
+ // will be dropped after adding one more datagram.
+ let dgram_dropped = client.stats().datagram_tx.dropped_queue_full;
+ assert_eq!(client.send_datagram(DATA_MTU, Some(3)), Ok(()));
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::OutgoingDatagramOutcome { id, outcome } if id == 1 && outcome == OutgoingDatagramOutcome::DroppedQueueFull
+ ));
+ assert_eq!(
+ client.stats().datagram_tx.dropped_queue_full,
+ dgram_dropped + 1
+ );
+
+ // Send DATA_SMALLER_THAN_MTU_2 datagram
+ let out = client.process_output(now()).dgram();
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 1);
+ server.process_input(&out.unwrap(), now());
+ assert!(matches!(
+ server.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == DATA_SMALLER_THAN_MTU_2
+ ));
+
+ // Send DATA_SMALLER_THAN_MTU_2 datagram
+ let dgram_sent2 = client.stats().frame_tx.datagram;
+ let out = client.process_output(now()).dgram();
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent2 + 1);
+ server.process_input(&out.unwrap(), now());
+ assert!(matches!(
+ server.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == DATA_MTU
+ ));
+}
+
+fn send_datagram(sender: &mut Connection, receiver: &mut Connection, data: &[u8]) {
+ let dgram_sent = sender.stats().frame_tx.datagram;
+ assert_eq!(sender.send_datagram(data, Some(1)), Ok(()));
+ let out = sender.process_output(now()).dgram().unwrap();
+ assert_eq!(sender.stats().frame_tx.datagram, dgram_sent + 1);
+
+ let dgram_received = receiver.stats().frame_rx.datagram;
+ receiver.process_input(&out, now());
+ assert_eq!(receiver.stats().frame_rx.datagram, dgram_received + 1);
+}
+
+#[test]
+fn multiple_datagram_events() {
+ const DATA_SIZE: usize = 1200;
+ const MAX_QUEUE: usize = 3;
+ const FIRST_DATAGRAM: &[u8] = &[0; DATA_SIZE];
+ const SECOND_DATAGRAM: &[u8] = &[1; DATA_SIZE];
+ const THIRD_DATAGRAM: &[u8] = &[2; DATA_SIZE];
+ const FOURTH_DATAGRAM: &[u8] = &[3; DATA_SIZE];
+
+ let mut client = new_client(
+ ConnectionParameters::default()
+ .datagram_size(u64::try_from(DATA_SIZE).unwrap())
+ .incoming_datagram_queue(MAX_QUEUE),
+ );
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ send_datagram(&mut server, &mut client, FIRST_DATAGRAM);
+ send_datagram(&mut server, &mut client, SECOND_DATAGRAM);
+ send_datagram(&mut server, &mut client, THIRD_DATAGRAM);
+
+ let mut datagrams = client.events().filter_map(|evt| {
+ if let ConnectionEvent::Datagram(d) = evt {
+ Some(d)
+ } else {
+ None
+ }
+ });
+ assert_eq!(datagrams.next().unwrap(), FIRST_DATAGRAM);
+ assert_eq!(datagrams.next().unwrap(), SECOND_DATAGRAM);
+ assert_eq!(datagrams.next().unwrap(), THIRD_DATAGRAM);
+ assert!(datagrams.next().is_none());
+
+ // New events can be queued.
+ send_datagram(&mut server, &mut client, FOURTH_DATAGRAM);
+ let mut datagrams = client.events().filter_map(|evt| {
+ if let ConnectionEvent::Datagram(d) = evt {
+ Some(d)
+ } else {
+ None
+ }
+ });
+ assert_eq!(datagrams.next().unwrap(), FOURTH_DATAGRAM);
+ assert!(datagrams.next().is_none());
+}
+
+#[test]
+fn too_many_datagram_events() {
+ const DATA_SIZE: usize = 1200;
+ const MAX_QUEUE: usize = 2;
+ const FIRST_DATAGRAM: &[u8] = &[0; DATA_SIZE];
+ const SECOND_DATAGRAM: &[u8] = &[1; DATA_SIZE];
+ const THIRD_DATAGRAM: &[u8] = &[2; DATA_SIZE];
+ const FOURTH_DATAGRAM: &[u8] = &[3; DATA_SIZE];
+
+ let mut client = new_client(
+ ConnectionParameters::default()
+ .datagram_size(u64::try_from(DATA_SIZE).unwrap())
+ .incoming_datagram_queue(MAX_QUEUE),
+ );
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ send_datagram(&mut server, &mut client, FIRST_DATAGRAM);
+ send_datagram(&mut server, &mut client, SECOND_DATAGRAM);
+ send_datagram(&mut server, &mut client, THIRD_DATAGRAM);
+
+ // Datagram with FIRST_DATAGRAM data will be dropped.
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == SECOND_DATAGRAM
+ ));
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::IncomingDatagramDropped
+ ));
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == THIRD_DATAGRAM
+ ));
+ assert!(client.next_event().is_none());
+ assert_eq!(client.stats().incoming_datagram_dropped, 1);
+
+ // New events can be queued.
+ send_datagram(&mut server, &mut client, FOURTH_DATAGRAM);
+ assert!(matches!(
+ client.next_event().unwrap(),
+ ConnectionEvent::Datagram(data) if data == FOURTH_DATAGRAM
+ ));
+ assert!(client.next_event().is_none());
+ assert_eq!(client.stats().incoming_datagram_dropped, 1);
+}
+
+#[test]
+fn multiple_quic_datagrams_in_one_packet() {
+ let (mut client, mut server) = connect_datagram();
+
+ let dgram_sent = client.stats().frame_tx.datagram;
+ // Enqueue 2 datagrams that can fit in a single packet.
+ assert_eq!(
+ client.send_datagram(DATA_SMALLER_THAN_MTU_2, Some(1)),
+ Ok(())
+ );
+ assert_eq!(
+ client.send_datagram(DATA_SMALLER_THAN_MTU_2, Some(2)),
+ Ok(())
+ );
+
+ let out = client.process_output(now()).dgram();
+ assert_eq!(client.stats().frame_tx.datagram, dgram_sent + 2);
+ server.process_input(&out.unwrap(), now());
+ let datagram = |e: &_| matches!(e, ConnectionEvent::Datagram(..));
+ assert_eq!(server.events().filter(datagram).count(), 2);
+}
+
+/// Datagrams that are close to the capacity of the packet need special
+/// handling. They need to use the packet-filling frame type and
+/// they cannot allow other frames to follow.
+#[test]
+fn datagram_fill() {
+ struct PanickingFrameWriter {}
+ impl crate::connection::test_internal::FrameWriter for PanickingFrameWriter {
+ fn write_frames(&mut self, builder: &mut PacketBuilder) {
+ panic!(
+ "builder invoked with {} bytes remaining",
+ builder.remaining()
+ );
+ }
+ }
+ struct TrackingFrameWriter {
+ called: Rc<RefCell<bool>>,
+ }
+ impl crate::connection::test_internal::FrameWriter for TrackingFrameWriter {
+ fn write_frames(&mut self, builder: &mut PacketBuilder) {
+ assert_eq!(builder.remaining(), 2);
+ *self.called.borrow_mut() = true;
+ }
+ }
+
+ let (mut client, mut server) = connect_datagram();
+
+ // Work out how much space we have for a datagram.
+ let space = {
+ let p = client.paths.primary();
+ let path = p.borrow();
+ // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number,
+ // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD.
+ path.mtu() - path.remote_cid().len() - 19
+ };
+ assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large.
+
+ // This should not be called.
+ client.test_frame_writer = Some(Box::new(PanickingFrameWriter {}));
+
+ let buf = vec![9; space];
+ // This will completely fill available space.
+ send_datagram(&mut client, &mut server, &buf);
+ // This will leave 1 byte free, but more frames won't be added in this space.
+ send_datagram(&mut client, &mut server, &buf[..buf.len() - 1]);
+ // This will leave 2 bytes free, which is enough space for a length field,
+ // but not enough space for another frame after that.
+ send_datagram(&mut client, &mut server, &buf[..buf.len() - 2]);
+ // Three bytes free will be space enough for a length frame, but not enough
+ // space left over for another frame (we need 2 bytes).
+ send_datagram(&mut client, &mut server, &buf[..buf.len() - 3]);
+
+ // Four bytes free is enough space for another frame.
+ let called = Rc::new(RefCell::new(false));
+ client.test_frame_writer = Some(Box::new(TrackingFrameWriter {
+ called: Rc::clone(&called),
+ }));
+ send_datagram(&mut client, &mut server, &buf[..buf.len() - 4]);
+ assert!(*called.borrow());
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs b/third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs
new file mode 100644
index 0000000000..5425e1a16e
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/fuzzing.rs
@@ -0,0 +1,44 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![cfg_attr(feature = "deny-warnings", deny(warnings))]
+#![warn(clippy::pedantic)]
+#![cfg(feature = "fuzzing")]
+
+use neqo_crypto::FIXED_TAG_FUZZING;
+use test_fixture::now;
+
+use super::{connect_force_idle, default_client, default_server};
+use crate::StreamType;
+
+#[test]
+fn no_encryption() {
+ const DATA_CLIENT: &[u8] = &[2; 40];
+ const DATA_SERVER: &[u8] = &[3; 50];
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+
+ client.stream_send(stream_id, DATA_CLIENT).unwrap();
+ let client_pkt = client.process_output(now()).dgram().unwrap();
+ assert!(client_pkt[..client_pkt.len() - FIXED_TAG_FUZZING.len()].ends_with(DATA_CLIENT));
+
+ server.process_input(&client_pkt, now());
+ let mut buf = vec![0; 100];
+ let (len, _) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(len, DATA_CLIENT.len());
+ assert_eq!(&buf[..len], DATA_CLIENT);
+ server.stream_send(stream_id, DATA_SERVER).unwrap();
+ let server_pkt = server.process_output(now()).dgram().unwrap();
+ assert!(server_pkt[..server_pkt.len() - FIXED_TAG_FUZZING.len()].ends_with(DATA_SERVER));
+
+ client.process_input(&server_pkt, now());
+ let (len, _) = client.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(len, DATA_SERVER.len());
+ assert_eq!(&buf[..len], DATA_SERVER);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
new file mode 100644
index 0000000000..93385ac1bc
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
@@ -0,0 +1,1137 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ cell::RefCell,
+ convert::TryFrom,
+ mem,
+ net::{IpAddr, Ipv6Addr, SocketAddr},
+ rc::Rc,
+ time::Duration,
+};
+
+use neqo_common::{event::Provider, qdebug, Datagram};
+use neqo_crypto::{
+ constants::TLS_CHACHA20_POLY1305_SHA256, generate_ech_keys, AuthenticationStatus,
+};
+use test_fixture::{
+ self, addr, assertions, assertions::assert_coalesced_0rtt, datagram, fixture_init, now,
+ split_datagram,
+};
+
+use super::{
+ super::{Connection, Output, State},
+ assert_error, connect, connect_force_idle, connect_with_rtt, default_client, default_server,
+ get_tokens, handshake, maybe_authenticate, resumed_server, send_something,
+ CountingConnectionIdGenerator, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA,
+};
+use crate::{
+ connection::AddressValidation,
+ events::ConnectionEvent,
+ path::PATH_MTU_V6,
+ server::ValidateAddress,
+ tparams::{TransportParameter, MIN_ACK_DELAY},
+ tracking::DEFAULT_ACK_DELAY,
+ ConnectionError, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamType, Version,
+};
+
+const ECH_CONFIG_ID: u8 = 7;
+const ECH_PUBLIC_NAME: &str = "public.example";
+
+#[test]
+fn full_handshake() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let out = client.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(*client.state(), State::Connected);
+
+ qdebug!("---- server: FIN -> ACKS");
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(*server.state(), State::Confirmed);
+
+ qdebug!("---- client: ACKS -> 0");
+ let out = client.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_none());
+ assert_eq!(*client.state(), State::Confirmed);
+}
+
+#[test]
+fn handshake_failed_authentication() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded);
+ assert!(client.events().any(authentication_needed));
+ qdebug!("---- client: Alert(certificate_revoked)");
+ client.authenticated(AuthenticationStatus::CertRevoked, now());
+
+ qdebug!("---- client: -> Alert(certificate_revoked)");
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+
+ qdebug!("---- server: Alert(certificate_revoked)");
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44)));
+ assert_error(&server, &ConnectionError::Transport(Error::PeerError(300)));
+}
+
+#[test]
+fn no_alpn() {
+ fixture_init();
+ let mut client = Connection::new_client(
+ "example.com",
+ &["bad-alpn"],
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ addr(),
+ addr(),
+ ConnectionParameters::default(),
+ now(),
+ )
+ .unwrap();
+ let mut server = default_server();
+
+ handshake(&mut client, &mut server, now(), Duration::new(0, 0));
+ // TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE
+ // and the client never sees the server's rejection of its handshake.
+ // assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120)));
+ assert_error(
+ &server,
+ &ConnectionError::Transport(Error::CryptoAlert(120)),
+ );
+}
+
+#[test]
+fn dup_server_flight1() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let out_to_rep = server.process(out.as_dgram_ref(), now());
+ assert!(out_to_rep.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out_to_rep.as_dgram_ref());
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(Some(out_to_rep.as_dgram_ref().unwrap()), now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ assert_eq!(3, client.stats().packets_rx);
+ assert_eq!(0, client.stats().dups_rx);
+ assert_eq!(1, client.stats().dropped_rx);
+
+ qdebug!("---- Dup, ignored");
+ let out = client.process(out_to_rep.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_none());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ // Four packets total received, 1 of them is a dup and one has been dropped because Initial keys
+ // are dropped. Add 2 counts of the padding that the server adds to Initial packets.
+ assert_eq!(6, client.stats().packets_rx);
+ assert_eq!(1, client.stats().dups_rx);
+ assert_eq!(3, client.stats().dropped_rx);
+}
+
+// Test that we split crypto data if they cannot fit into one packet.
+// To test this we will use a long server certificate.
+#[test]
+fn crypto_frame_split() {
+ let mut client = default_client();
+
+ let mut server = Connection::new_server(
+ test_fixture::LONG_CERT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ ConnectionParameters::default(),
+ )
+ .expect("create a server");
+
+ let client1 = client.process(None, now());
+ assert!(client1.as_dgram_ref().is_some());
+
+ // The entire server flight doesn't fit in a single packet because the
+ // certificate is large, therefore the server will produce 2 packets.
+ let server1 = server.process(client1.as_dgram_ref(), now());
+ assert!(server1.as_dgram_ref().is_some());
+ let server2 = server.process(None, now());
+ assert!(server2.as_dgram_ref().is_some());
+
+ let client2 = client.process(server1.as_dgram_ref(), now());
+ // This is an ack.
+ assert!(client2.as_dgram_ref().is_some());
+ // The client might have the certificate now, so we can't guarantee that
+ // this will work.
+ let auth1 = maybe_authenticate(&mut client);
+ assert_eq!(*client.state(), State::Handshaking);
+
+ // let server process the ack for the first packet.
+ let server3 = server.process(client2.as_dgram_ref(), now());
+ assert!(server3.as_dgram_ref().is_none());
+
+ // Consume the second packet from the server.
+ let client3 = client.process(server2.as_dgram_ref(), now());
+
+ // Check authentication.
+ let auth2 = maybe_authenticate(&mut client);
+ assert!(auth1 ^ auth2);
+ // Now client has all data to finish handshake.
+ assert_eq!(*client.state(), State::Connected);
+
+ let client4 = client.process(server3.as_dgram_ref(), now());
+ // One of these will contain data depending on whether Authentication was completed
+ // after the first or second server packet.
+ assert!(client3.as_dgram_ref().is_some() ^ client4.as_dgram_ref().is_some());
+
+ mem::drop(server.process(client3.as_dgram_ref(), now()));
+ mem::drop(server.process(client4.as_dgram_ref(), now()));
+
+ assert_eq!(*client.state(), State::Connected);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+/// Run a single ChaCha20-Poly1305 test and get a PTO.
+#[test]
+fn chacha20poly1305() {
+ let mut server = default_server();
+ let mut client = Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())),
+ addr(),
+ addr(),
+ ConnectionParameters::default(),
+ now(),
+ )
+ .expect("create a default client");
+ client.set_ciphers(&[TLS_CHACHA20_POLY1305_SHA256]).unwrap();
+ connect_force_idle(&mut client, &mut server);
+}
+
+/// Test that a server can send 0.5 RTT application data.
+#[test]
+fn send_05rtt() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let c1 = client.process(None, now()).dgram();
+ assert!(c1.is_some());
+ let s1 = server.process(c1.as_ref(), now()).dgram().unwrap();
+ assert_eq!(s1.len(), PATH_MTU_V6);
+
+ // The server should accept writes at this point.
+ let s2 = send_something(&mut server, now());
+
+ // Complete the handshake at the client.
+ client.process_input(&s1, now());
+ maybe_authenticate(&mut client);
+ assert_eq!(*client.state(), State::Connected);
+
+ // The client should receive the 0.5-RTT data now.
+ client.process_input(&s2, now());
+ let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+ let stream_id = client
+ .events()
+ .find_map(|e| {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ Some(stream_id)
+ } else {
+ None
+ }
+ })
+ .unwrap();
+ let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(&buf[..l], DEFAULT_STREAM_DATA);
+ assert!(ended);
+}
+
+/// Test that a client buffers 0.5-RTT data when it arrives early.
+#[test]
+fn reorder_05rtt() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let c1 = client.process(None, now()).dgram();
+ assert!(c1.is_some());
+ let s1 = server.process(c1.as_ref(), now()).dgram().unwrap();
+
+ // The server should accept writes at this point.
+ let s2 = send_something(&mut server, now());
+
+ // We can't use the standard facility to complete the handshake, so
+ // drive it as aggressively as possible.
+ client.process_input(&s2, now());
+ assert_eq!(client.stats().saved_datagrams, 1);
+
+ // After processing the first packet, the client should go back and
+ // process the 0.5-RTT packet data, which should make data available.
+ client.process_input(&s1, now());
+ // We can't use `maybe_authenticate` here as that consumes events.
+ client.authenticated(AuthenticationStatus::Ok, now());
+ assert_eq!(*client.state(), State::Connected);
+
+ let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+ let stream_id = client
+ .events()
+ .find_map(|e| {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ Some(stream_id)
+ } else {
+ None
+ }
+ })
+ .unwrap();
+ let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(&buf[..l], DEFAULT_STREAM_DATA);
+ assert!(ended);
+}
+
+#[test]
+fn reorder_05rtt_with_0rtt() {
+ const RTT: Duration = Duration::from_millis(100);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ let validation = AddressValidation::new(now(), ValidateAddress::NoToken).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+ // Include RTT in sending the ticket or the ticket age reported by the
+ // client is wrong, which causes the server to reject 0-RTT.
+ now += RTT / 2;
+ server.send_ticket(now, &[]).unwrap();
+ let ticket = server.process_output(now).dgram().unwrap();
+ now += RTT / 2;
+ client.process_input(&ticket, now);
+
+ let token = get_tokens(&mut client).pop().unwrap();
+ let mut client = default_client();
+ client.enable_resumption(now, token).unwrap();
+ let mut server = resumed_server(&client);
+
+ // Send ClientHello and some 0-RTT.
+ let c1 = send_something(&mut client, now);
+ assertions::assert_coalesced_0rtt(&c1[..]);
+ // Drop the 0-RTT from the coalesced datagram, so that the server
+ // acknowledges the next 0-RTT packet.
+ let (c1, _) = split_datagram(&c1);
+ let c2 = send_something(&mut client, now);
+
+ // Handle the first packet and send 0.5-RTT in response. Drop the response.
+ now += RTT / 2;
+ mem::drop(server.process(Some(&c1), now).dgram().unwrap());
+ // The gap in 0-RTT will result in this 0.5 RTT containing an ACK.
+ server.process_input(&c2, now);
+ let s2 = send_something(&mut server, now);
+
+ // Save the 0.5 RTT.
+ now += RTT / 2;
+ client.process_input(&s2, now);
+ assert_eq!(client.stats().saved_datagrams, 1);
+
+ // Now PTO at the client and cause the server to re-send handshake packets.
+ now += AT_LEAST_PTO;
+ let c3 = client.process(None, now).dgram();
+ assert_coalesced_0rtt(c3.as_ref().unwrap());
+
+ now += RTT / 2;
+ let s3 = server.process(c3.as_ref(), now).dgram().unwrap();
+
+ // The client should be able to process the 0.5 RTT now.
+ // This should contain an ACK, so we are processing an ACK from the past.
+ now += RTT / 2;
+ client.process_input(&s3, now);
+ maybe_authenticate(&mut client);
+ let c4 = client.process(None, now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+ assert_eq!(client.paths.rtt(), RTT);
+
+ now += RTT / 2;
+ server.process_input(&c4.unwrap(), now);
+ assert_eq!(*server.state(), State::Confirmed);
+ // Don't check server RTT as it will be massively inflated by a
+ // poor initial estimate received when the server dropped the
+ // Initial packet number space.
+}
+
+/// Test that a server that coalesces 0.5 RTT with handshake packets
+/// doesn't cause the client to drop application data.
+#[test]
+fn coalesce_05rtt() {
+ const RTT: Duration = Duration::from_millis(100);
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ // The first exchange doesn't offer a chance for the server to send.
+ // So drop the server flight and wait for the PTO.
+ let c1 = client.process(None, now).dgram();
+ assert!(c1.is_some());
+ now += RTT / 2;
+ let s1 = server.process(c1.as_ref(), now).dgram();
+ assert!(s1.is_some());
+
+ // Drop the server flight. Then send some data.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+ assert!(server.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok());
+ assert!(server.stream_close_send(stream_id).is_ok());
+
+ // Now after a PTO the client can send another packet.
+ // The server should then send its entire flight again,
+ // including the application data, which it sends in a 1-RTT packet.
+ now += AT_LEAST_PTO;
+ let c2 = client.process(None, now).dgram();
+ assert!(c2.is_some());
+ now += RTT / 2;
+ let s2 = server.process(c2.as_ref(), now).dgram();
+ // Even though there is a 1-RTT packet at the end of the datagram, the
+ // flight should be padded to full size.
+ assert_eq!(s2.as_ref().unwrap().len(), PATH_MTU_V6);
+
+ // The client should process the datagram. It can't process the 1-RTT
+ // packet until authentication completes though. So it saves it.
+ now += RTT / 2;
+ assert_eq!(client.stats().dropped_rx, 0);
+ mem::drop(client.process(s2.as_ref(), now).dgram());
+ // This packet will contain an ACK, but we can ignore it.
+ assert_eq!(client.stats().dropped_rx, 0);
+ assert_eq!(client.stats().packets_rx, 3);
+ assert_eq!(client.stats().saved_datagrams, 1);
+
+ // After (successful) authentication, the packet is processed.
+ maybe_authenticate(&mut client);
+ let c3 = client.process(None, now).dgram();
+ assert!(c3.is_some());
+ assert_eq!(client.stats().dropped_rx, 0); // No Initial padding.
+ assert_eq!(client.stats().packets_rx, 4);
+ assert_eq!(client.stats().saved_datagrams, 1);
+ assert_eq!(client.stats().frame_rx.padding, 1); // Padding uses frames.
+
+ // Allow the handshake to complete.
+ now += RTT / 2;
+ let s3 = server.process(c3.as_ref(), now).dgram();
+ assert!(s3.is_some());
+ assert_eq!(*server.state(), State::Confirmed);
+ now += RTT / 2;
+ mem::drop(client.process(s3.as_ref(), now).dgram());
+ assert_eq!(*client.state(), State::Confirmed);
+
+ assert_eq!(client.stats().dropped_rx, 0); // No dropped packets.
+}
+
+#[test]
+fn reorder_handshake() {
+ const RTT: Duration = Duration::from_millis(100);
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c1 = client.process(None, now).dgram();
+ assert!(c1.is_some());
+
+ now += RTT / 2;
+ let s1 = server.process(c1.as_ref(), now).dgram();
+ assert!(s1.is_some());
+
+ // Drop the Initial packet from this.
+ let (_, s_hs) = split_datagram(&s1.unwrap());
+ assert!(s_hs.is_some());
+
+ // Pass just the handshake packet in and the client can't handle it yet.
+ // It can only send another Initial packet.
+ now += RTT / 2;
+ let dgram = client.process(s_hs.as_ref(), now).dgram();
+ assertions::assert_initial(dgram.as_ref().unwrap(), false);
+ assert_eq!(client.stats().saved_datagrams, 1);
+ assert_eq!(client.stats().packets_rx, 1);
+
+ // Get the server to try again.
+ // Though we currently allow the server to arm its PTO timer, use
+ // a second client Initial packet to cause it to send again.
+ now += AT_LEAST_PTO;
+ let c2 = client.process(None, now).dgram();
+ now += RTT / 2;
+ let s2 = server.process(c2.as_ref(), now).dgram();
+ assert!(s2.is_some());
+
+ let (s_init, s_hs) = split_datagram(&s2.unwrap());
+ assert!(s_hs.is_some());
+
+ // Processing the Handshake packet first should save it.
+ now += RTT / 2;
+ client.process_input(&s_hs.unwrap(), now);
+ assert_eq!(client.stats().saved_datagrams, 2);
+ assert_eq!(client.stats().packets_rx, 2);
+
+ client.process_input(&s_init, now);
+ // Each saved packet should now be "received" again.
+ assert_eq!(client.stats().packets_rx, 7);
+ maybe_authenticate(&mut client);
+ let c3 = client.process(None, now).dgram();
+ assert!(c3.is_some());
+
+ // Note that though packets were saved and processed very late,
+ // they don't cause the RTT to change.
+ now += RTT / 2;
+ let s3 = server.process(c3.as_ref(), now).dgram();
+ assert_eq!(*server.state(), State::Confirmed);
+ // Don't check server RTT estimate as it will be inflated due to
+ // it making a guess based on retransmissions when it dropped
+ // the Initial packet number space.
+
+ now += RTT / 2;
+ client.process_input(&s3.unwrap(), now);
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(client.paths.rtt(), RTT);
+}
+
+#[test]
+fn reorder_1rtt() {
+ const RTT: Duration = Duration::from_millis(100);
+ const PACKETS: usize = 4; // Many, but not enough to overflow cwnd.
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c1 = client.process(None, now).dgram();
+ assert!(c1.is_some());
+
+ now += RTT / 2;
+ let s1 = server.process(c1.as_ref(), now).dgram();
+ assert!(s1.is_some());
+
+ now += RTT / 2;
+ client.process_input(&s1.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let c2 = client.process(None, now).dgram();
+ assert!(c2.is_some());
+
+ // Now get a bunch of packets from the client.
+ // Give them to the server before giving it `c2`.
+ for _ in 0..PACKETS {
+ let d = send_something(&mut client, now);
+ server.process_input(&d, now + RTT / 2);
+ }
+ // The server has now received those packets, and saved them.
+ // The two extra received are Initial + the junk we use for padding.
+ assert_eq!(server.stats().packets_rx, PACKETS + 2);
+ assert_eq!(server.stats().saved_datagrams, PACKETS);
+ assert_eq!(server.stats().dropped_rx, 1);
+
+ now += RTT / 2;
+ let s2 = server.process(c2.as_ref(), now).dgram();
+ // The server has now received those packets, and saved them.
+ // The two additional are a Handshake and a 1-RTT (w/ NEW_CONNECTION_ID).
+ assert_eq!(server.stats().packets_rx, PACKETS * 2 + 4);
+ assert_eq!(server.stats().saved_datagrams, PACKETS);
+ assert_eq!(server.stats().dropped_rx, 1);
+ assert_eq!(*server.state(), State::Confirmed);
+ assert_eq!(server.paths.rtt(), RTT);
+
+ now += RTT / 2;
+ client.process_input(&s2.unwrap(), now);
+ assert_eq!(client.paths.rtt(), RTT);
+
+ // All the stream data that was sent should now be available.
+ let streams = server
+ .events()
+ .filter_map(|e| {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ Some(stream_id)
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(streams.len(), PACKETS);
+ for stream_id in streams {
+ let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+ let (recvd, fin) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(recvd, DEFAULT_STREAM_DATA.len());
+ assert!(fin);
+ }
+}
+
+#[cfg(not(feature = "fuzzing"))]
+#[test]
+fn corrupted_initial() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let d = client.process(None, now()).dgram().unwrap();
+ let mut corrupted = Vec::from(&d[..]);
+ // Find the last non-zero value and corrupt that.
+ let (idx, _) = corrupted
+ .iter()
+ .enumerate()
+ .rev()
+ .find(|(_, &v)| v != 0)
+ .unwrap();
+ corrupted[idx] ^= 0x76;
+ let dgram = Datagram::new(d.source(), d.destination(), d.tos(), d.ttl(), corrupted);
+ server.process_input(&dgram, now());
+ // The server should have received two packets,
+ // the first should be dropped, the second saved.
+ assert_eq!(server.stats().packets_rx, 2);
+ assert_eq!(server.stats().dropped_rx, 2);
+ assert_eq!(server.stats().saved_datagrams, 0);
+}
+
+#[test]
+// Absent path PTU discovery, max v6 packet size should be PATH_MTU_V6.
+fn verify_pkt_honors_mtu() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ let idle_timeout = ConnectionParameters::default().get_idle_timeout();
+ assert_eq!(res, Output::Callback(idle_timeout));
+
+ // Try to send a large stream and verify first packet is correctly sized
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.stream_send(stream_id, &[0xbb; 2000]).unwrap(), 2000);
+ let pkt0 = client.process(None, now);
+ assert!(matches!(pkt0, Output::Datagram(_)));
+ assert_eq!(pkt0.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+}
+
+#[test]
+fn extra_initial_hs() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c_init = client.process(None, now).dgram();
+ assert!(c_init.is_some());
+ now += DEFAULT_RTT / 2;
+ let s_init = server.process(c_init.as_ref(), now).dgram();
+ assert!(s_init.is_some());
+ now += DEFAULT_RTT / 2;
+
+ // Drop the Initial packet, keep only the Handshake.
+ let (_, undecryptable) = split_datagram(&s_init.unwrap());
+ assert!(undecryptable.is_some());
+
+ // Feed the same undecryptable packet into the client a few times.
+ // Do that EXTRA_INITIALS times and each time the client will emit
+ // another Initial packet.
+ for _ in 0..=super::super::EXTRA_INITIALS {
+ let c_init = client.process(undecryptable.as_ref(), now).dgram();
+ assertions::assert_initial(c_init.as_ref().unwrap(), false);
+ now += DEFAULT_RTT / 10;
+ }
+
+ // After EXTRA_INITIALS, the client stops sending Initial packets.
+ let nothing = client.process(undecryptable.as_ref(), now).dgram();
+ assert!(nothing.is_none());
+
+ // Until PTO, where another Initial can be used to complete the handshake.
+ now += AT_LEAST_PTO;
+ let c_init = client.process(None, now).dgram();
+ assertions::assert_initial(c_init.as_ref().unwrap(), false);
+ now += DEFAULT_RTT / 2;
+ let s_init = server.process(c_init.as_ref(), now).dgram();
+ now += DEFAULT_RTT / 2;
+ client.process_input(&s_init.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let c_fin = client.process_output(now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+ now += DEFAULT_RTT / 2;
+ server.process_input(&c_fin.unwrap(), now);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+#[test]
+fn extra_initial_invalid_cid() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c_init = client.process(None, now).dgram();
+ assert!(c_init.is_some());
+ now += DEFAULT_RTT / 2;
+ let s_init = server.process(c_init.as_ref(), now).dgram();
+ assert!(s_init.is_some());
+ now += DEFAULT_RTT / 2;
+
+ // If the client receives a packet that contains the wrong connection
+ // ID, it won't send another Initial.
+ let (_, hs) = split_datagram(&s_init.unwrap());
+ let hs = hs.unwrap();
+ let mut copy = hs.to_vec();
+ assert_ne!(copy[5], 0); // The DCID should be non-zero length.
+ copy[6] ^= 0xc4;
+ let dgram_copy = Datagram::new(hs.destination(), hs.source(), hs.tos(), hs.ttl(), copy);
+ let nothing = client.process(Some(&dgram_copy), now).dgram();
+ assert!(nothing.is_none());
+}
+
+#[test]
+fn connect_one_version() {
+ fn connect_v(version: Version) {
+ fixture_init();
+ let mut client = Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ addr(),
+ addr(),
+ ConnectionParameters::default().versions(version, vec![version]),
+ now(),
+ )
+ .unwrap();
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ ConnectionParameters::default().versions(version, vec![version]),
+ )
+ .unwrap();
+ connect_force_idle(&mut client, &mut server);
+ assert_eq!(client.version(), version);
+ assert_eq!(server.version(), version);
+ }
+
+ for v in Version::all() {
+ println!("Connecting with {v:?}");
+ connect_v(v);
+ }
+}
+
+#[test]
+fn anti_amplification() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ // With a gigantic transport parameter, the server is unable to complete
+ // the handshake within the amplification limit.
+ let very_big = TransportParameter::Bytes(vec![0; PATH_MTU_V6 * 3]);
+ server.set_local_tparam(0xce16, very_big).unwrap();
+
+ let c_init = client.process_output(now).dgram();
+ now += DEFAULT_RTT / 2;
+ let s_init1 = server.process(c_init.as_ref(), now).dgram().unwrap();
+ assert_eq!(s_init1.len(), PATH_MTU_V6);
+ let s_init2 = server.process_output(now).dgram().unwrap();
+ assert_eq!(s_init2.len(), PATH_MTU_V6);
+
+ // Skip the gap for pacing here.
+ let s_pacing = server.process_output(now).callback();
+ assert_ne!(s_pacing, Duration::new(0, 0));
+ now += s_pacing;
+
+ let s_init3 = server.process_output(now).dgram().unwrap();
+ assert_eq!(s_init3.len(), PATH_MTU_V6);
+ let cb = server.process_output(now).callback();
+ assert_ne!(cb, Duration::new(0, 0));
+
+ now += DEFAULT_RTT / 2;
+ client.process_input(&s_init1, now);
+ client.process_input(&s_init2, now);
+ let ack_count = client.stats().frame_tx.ack;
+ let frame_count = client.stats().frame_tx.all;
+ let ack = client.process(Some(&s_init3), now).dgram().unwrap();
+ assert!(!maybe_authenticate(&mut client)); // No need yet.
+
+ // The client sends a padded datagram, with just ACK for Handshake.
+ assert_eq!(client.stats().frame_tx.ack, ack_count + 1);
+ assert_eq!(client.stats().frame_tx.all, frame_count + 1);
+ assert_ne!(ack.len(), PATH_MTU_V6); // Not padded (it includes Handshake).
+
+ now += DEFAULT_RTT / 2;
+ let remainder = server.process(Some(&ack), now).dgram();
+
+ now += DEFAULT_RTT / 2;
+ client.process_input(&remainder.unwrap(), now);
+ assert!(maybe_authenticate(&mut client)); // OK, we have all of it.
+ let fin = client.process_output(now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+
+ now += DEFAULT_RTT / 2;
+ server.process_input(&fin.unwrap(), now);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+#[cfg(not(feature = "fuzzing"))]
+#[test]
+fn garbage_initial() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let dgram = client.process_output(now()).dgram().unwrap();
+ let (initial, rest) = split_datagram(&dgram);
+ let mut corrupted = Vec::from(&initial[..initial.len() - 1]);
+ corrupted.push(initial[initial.len() - 1] ^ 0xb7);
+ corrupted.extend_from_slice(rest.as_ref().map_or(&[], |r| &r[..]));
+ let garbage = datagram(corrupted);
+ assert_eq!(Output::None, server.process(Some(&garbage), now()));
+}
+
+#[test]
+fn drop_initial_packet_from_wrong_address() {
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let mut server = default_server();
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let p = out.dgram().unwrap();
+ let dgram = Datagram::new(
+ SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2)), 443),
+ p.destination(),
+ p.tos(),
+ p.ttl(),
+ &p[..],
+ );
+
+ let out = client.process(Some(&dgram), now());
+ assert!(out.as_dgram_ref().is_none());
+}
+
+#[test]
+fn drop_handshake_packet_from_wrong_address() {
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let mut server = default_server();
+ let out = server.process(out.as_dgram_ref(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let (s_in, s_hs) = split_datagram(&out.dgram().unwrap());
+
+ // Pass the initial packet.
+ mem::drop(client.process(Some(&s_in), now()).dgram());
+
+ let p = s_hs.unwrap();
+ let dgram = Datagram::new(
+ SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2)), 443),
+ p.destination(),
+ p.tos(),
+ p.ttl(),
+ &p[..],
+ );
+
+ let out = client.process(Some(&dgram), now());
+ assert!(out.as_dgram_ref().is_none());
+}
+
+#[test]
+fn ech() {
+ let mut server = default_server();
+ let (sk, pk) = generate_ech_keys().unwrap();
+ server
+ .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk)
+ .unwrap();
+
+ let mut client = default_client();
+ client.client_enable_ech(server.ech_config()).unwrap();
+
+ connect(&mut client, &mut server);
+
+ assert!(client.tls_info().unwrap().ech_accepted());
+ assert!(server.tls_info().unwrap().ech_accepted());
+ assert!(client.tls_preinfo().unwrap().ech_accepted().unwrap());
+ assert!(server.tls_preinfo().unwrap().ech_accepted().unwrap());
+}
+
+fn damaged_ech_config(config: &[u8]) -> Vec<u8> {
+ let mut cfg = Vec::from(config);
+ // Ensure that the version and config_id is correct.
+ assert_eq!(cfg[2], 0xfe);
+ assert_eq!(cfg[3], 0x0d);
+ assert_eq!(cfg[6], ECH_CONFIG_ID);
+ // Change the config_id so that the server doesn't recognize it.
+ cfg[6] ^= 0x94;
+ cfg
+}
+
+#[test]
+fn ech_retry() {
+ fixture_init();
+ let mut server = default_server();
+ let (sk, pk) = generate_ech_keys().unwrap();
+ server
+ .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk)
+ .unwrap();
+
+ let mut client = default_client();
+ client
+ .client_enable_ech(&damaged_ech_config(server.ech_config()))
+ .unwrap();
+
+ let dgram = client.process_output(now()).dgram();
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+ let auth_event = ConnectionEvent::EchFallbackAuthenticationNeeded {
+ public_name: String::from(ECH_PUBLIC_NAME),
+ };
+ assert!(client.events().any(|e| e == auth_event));
+ client.authenticated(AuthenticationStatus::Ok, now());
+ assert!(client.state().error().is_some());
+
+ // Tell the server about the error.
+ let dgram = client.process_output(now()).dgram();
+ server.process_input(&dgram.unwrap(), now());
+ assert_eq!(
+ server.state().error(),
+ Some(&ConnectionError::Transport(Error::PeerError(0x100 + 121)))
+ );
+
+ let Some(ConnectionError::Transport(Error::EchRetry(updated_config))) = client.state().error()
+ else {
+ panic!(
+ "Client state should be failed with EchRetry, is {:?}",
+ client.state()
+ );
+ };
+
+ let mut server = default_server();
+ server
+ .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk)
+ .unwrap();
+ let mut client = default_client();
+ client.client_enable_ech(updated_config).unwrap();
+
+ connect(&mut client, &mut server);
+
+ assert!(client.tls_info().unwrap().ech_accepted());
+ assert!(server.tls_info().unwrap().ech_accepted());
+ assert!(client.tls_preinfo().unwrap().ech_accepted().unwrap());
+ assert!(server.tls_preinfo().unwrap().ech_accepted().unwrap());
+}
+
+#[test]
+fn ech_retry_fallback_rejected() {
+ fixture_init();
+ let mut server = default_server();
+ let (sk, pk) = generate_ech_keys().unwrap();
+ server
+ .server_enable_ech(ECH_CONFIG_ID, ECH_PUBLIC_NAME, &sk, &pk)
+ .unwrap();
+
+ let mut client = default_client();
+ client
+ .client_enable_ech(&damaged_ech_config(server.ech_config()))
+ .unwrap();
+
+ let dgram = client.process_output(now()).dgram();
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+ let auth_event = ConnectionEvent::EchFallbackAuthenticationNeeded {
+ public_name: String::from(ECH_PUBLIC_NAME),
+ };
+ assert!(client.events().any(|e| e == auth_event));
+ client.authenticated(AuthenticationStatus::PolicyRejection, now());
+ assert!(client.state().error().is_some());
+
+ if let Some(ConnectionError::Transport(Error::EchRetry(_))) = client.state().error() {
+ panic!("Client should not get EchRetry error");
+ }
+
+ // Pass the error on.
+ let dgram = client.process_output(now()).dgram();
+ server.process_input(&dgram.unwrap(), now());
+ assert_eq!(
+ server.state().error(),
+ Some(&ConnectionError::Transport(Error::PeerError(298)))
+ ); // A bad_certificate alert.
+}
+
+#[test]
+fn bad_min_ack_delay() {
+ const EXPECTED_ERROR: ConnectionError =
+ ConnectionError::Transport(Error::TransportParameterError);
+ let mut server = default_server();
+ let max_ad = u64::try_from(DEFAULT_ACK_DELAY.as_micros()).unwrap();
+ server
+ .set_local_tparam(MIN_ACK_DELAY, TransportParameter::Integer(max_ad + 1))
+ .unwrap();
+ let mut client = default_client();
+
+ let dgram = client.process_output(now()).dgram();
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+ client.authenticated(AuthenticationStatus::Ok, now());
+ assert_eq!(client.state().error(), Some(&EXPECTED_ERROR));
+ let dgram = client.process_output(now()).dgram();
+
+ server.process_input(&dgram.unwrap(), now());
+ assert_eq!(
+ server.state().error(),
+ Some(&ConnectionError::Transport(Error::PeerError(
+ Error::TransportParameterError.code()
+ )))
+ );
+}
+
+/// Ensure that the client probes correctly if it only receives Initial packets
+/// from the server.
+#[test]
+fn only_server_initial() {
+ let mut server = default_server();
+ let mut client = default_client();
+ let mut now = now();
+
+ let client_dgram = client.process_output(now).dgram();
+
+ // Now fetch two flights of messages from the server.
+ let server_dgram1 = server.process(client_dgram.as_ref(), now).dgram();
+ let server_dgram2 = server.process_output(now + AT_LEAST_PTO).dgram();
+
+ // Only pass on the Initial from the first. We should get a Handshake in return.
+ let (initial, handshake) = split_datagram(&server_dgram1.unwrap());
+ assert!(handshake.is_some());
+
+ // The client will not acknowledge the Initial as it discards keys.
+ // It sends a Handshake probe instead, containing just a PING frame.
+ assert_eq!(client.stats().frame_tx.ping, 0);
+ let probe = client.process(Some(&initial), now).dgram();
+ assertions::assert_handshake(&probe.unwrap());
+ assert_eq!(client.stats().dropped_rx, 0);
+ assert_eq!(client.stats().frame_tx.ping, 1);
+
+ let (initial, handshake) = split_datagram(&server_dgram2.unwrap());
+ assert!(handshake.is_some());
+
+ // The same happens after a PTO, even though the client will discard the Initial packet.
+ now += AT_LEAST_PTO;
+ assert_eq!(client.stats().frame_tx.ping, 1);
+ let discarded = client.stats().dropped_rx;
+ let probe = client.process(Some(&initial), now).dgram();
+ assertions::assert_handshake(&probe.unwrap());
+ assert_eq!(client.stats().frame_tx.ping, 2);
+ assert_eq!(client.stats().dropped_rx, discarded + 1);
+
+ // Pass the Handshake packet and complete the handshake.
+ client.process_input(&handshake.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let dgram = client.process_output(now).dgram();
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ client.process_input(&dgram.unwrap(), now);
+
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+// Collect a few spare Initial packets as the handshake is exchanged.
+// Later, replay those packets to see if they result in additional probes; they should not.
+#[test]
+fn no_extra_probes_after_confirmed() {
+ let mut server = default_server();
+ let mut client = default_client();
+ let mut now = now();
+
+ // First, collect a client Initial.
+ let spare_initial = client.process_output(now).dgram();
+ assert!(spare_initial.is_some());
+
+ // Collect ANOTHER client Initial.
+ now += AT_LEAST_PTO;
+ let dgram = client.process_output(now).dgram();
+ let (replay_initial, _) = split_datagram(dgram.as_ref().unwrap());
+
+ // Finally, run the handshake.
+ now += AT_LEAST_PTO * 2;
+ let dgram = client.process_output(now).dgram();
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+
+ // The server should have dropped the Initial keys now, so passing in the Initial
+ // should elicit a retransmit rather than having it completely ignored.
+ let spare_handshake = server.process(Some(&replay_initial), now).dgram();
+ assert!(spare_handshake.is_some());
+
+ client.process_input(&dgram.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let dgram = client.process_output(now).dgram();
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ client.process_input(&dgram.unwrap(), now);
+
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(*server.state(), State::Confirmed);
+
+ let probe = server.process(spare_initial.as_ref(), now).dgram();
+ assert!(probe.is_none());
+ let probe = client.process(spare_handshake.as_ref(), now).dgram();
+ assert!(probe.is_none());
+}
+
+#[test]
+fn implicit_rtt_server() {
+ const RTT: Duration = Duration::from_secs(2);
+ let mut server = default_server();
+ let mut client = default_client();
+ let mut now = now();
+
+ let dgram = client.process_output(now).dgram();
+ now += RTT / 2;
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ now += RTT / 2;
+ let dgram = client.process(dgram.as_ref(), now).dgram();
+ assertions::assert_handshake(dgram.as_ref().unwrap());
+ now += RTT / 2;
+ server.process_input(&dgram.unwrap(), now);
+
+ // The server doesn't receive any acknowledgments, but it can infer
+ // an RTT estimate from having discarded the Initial packet number space.
+ assert_eq!(server.stats().rtt, RTT);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/idle.rs b/third_party/rust/neqo-transport/src/connection/tests/idle.rs
new file mode 100644
index 0000000000..c33726917a
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/idle.rs
@@ -0,0 +1,752 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ mem,
+ time::{Duration, Instant},
+};
+
+use neqo_common::{qtrace, Encoder};
+use test_fixture::{self, now, split_datagram};
+
+use super::{
+ super::{Connection, ConnectionParameters, IdleTimeout, Output, State},
+ connect, connect_force_idle, connect_rtt_idle, connect_with_rtt, default_client,
+ default_server, maybe_authenticate, new_client, new_server, send_and_receive, send_something,
+ AT_LEAST_PTO, DEFAULT_STREAM_DATA,
+};
+use crate::{
+ packet::PacketBuilder,
+ stats::FrameStats,
+ stream_id::{StreamId, StreamType},
+ tparams::{self, TransportParameter},
+ tracking::PacketNumberSpace,
+};
+
+fn default_timeout() -> Duration {
+ ConnectionParameters::default().get_idle_timeout()
+}
+
+fn test_idle_timeout(client: &mut Connection, server: &mut Connection, timeout: Duration) {
+ assert!(timeout > Duration::from_secs(1));
+ connect_force_idle(client, server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(timeout));
+
+ // Still connected after timeout-1 seconds. Idle timer not reset
+ mem::drop(client.process(
+ None,
+ now + timeout.checked_sub(Duration::from_secs(1)).unwrap(),
+ ));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ mem::drop(client.process(None, now + timeout));
+
+ // Not connected after timeout.
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+#[test]
+fn idle_timeout() {
+ let mut client = default_client();
+ let mut server = default_server();
+ test_idle_timeout(&mut client, &mut server, default_timeout());
+}
+
+#[test]
+fn idle_timeout_custom_client() {
+ const IDLE_TIMEOUT: Duration = Duration::from_secs(5);
+ let mut client = new_client(ConnectionParameters::default().idle_timeout(IDLE_TIMEOUT));
+ let mut server = default_server();
+ test_idle_timeout(&mut client, &mut server, IDLE_TIMEOUT);
+}
+
+#[test]
+fn idle_timeout_custom_server() {
+ const IDLE_TIMEOUT: Duration = Duration::from_secs(5);
+ let mut client = default_client();
+ let mut server = new_server(ConnectionParameters::default().idle_timeout(IDLE_TIMEOUT));
+ test_idle_timeout(&mut client, &mut server, IDLE_TIMEOUT);
+}
+
+#[test]
+fn idle_timeout_custom_both() {
+ const LOWER_TIMEOUT: Duration = Duration::from_secs(5);
+ const HIGHER_TIMEOUT: Duration = Duration::from_secs(10);
+ let mut client = new_client(ConnectionParameters::default().idle_timeout(HIGHER_TIMEOUT));
+ let mut server = new_server(ConnectionParameters::default().idle_timeout(LOWER_TIMEOUT));
+ test_idle_timeout(&mut client, &mut server, LOWER_TIMEOUT);
+}
+
+#[test]
+fn asymmetric_idle_timeout() {
+ const LOWER_TIMEOUT_MS: u64 = 1000;
+ const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS);
+ // Sanity check the constant.
+ assert!(LOWER_TIMEOUT < default_timeout());
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // Overwrite the default at the server.
+ server
+ .tps
+ .borrow_mut()
+ .local
+ .set_integer(tparams::IDLE_TIMEOUT, LOWER_TIMEOUT_MS);
+ server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT);
+
+ // Now connect and force idleness manually.
+ // We do that by following what `force_idle` does and have each endpoint
+ // send two packets, which are delivered out of order. See `force_idle`.
+ connect(&mut client, &mut server);
+ let c1 = send_something(&mut client, now());
+ let c2 = send_something(&mut client, now());
+ server.process_input(&c2, now());
+ server.process_input(&c1, now());
+ let s1 = send_something(&mut server, now());
+ let s2 = send_something(&mut server, now());
+ client.process_input(&s2, now());
+ let ack = client.process(Some(&s1), now()).dgram();
+ assert!(ack.is_some());
+ // Now both should have received ACK frames so should be idle.
+ assert_eq!(
+ server.process(ack.as_ref(), now()),
+ Output::Callback(LOWER_TIMEOUT)
+ );
+ assert_eq!(client.process(None, now()), Output::Callback(LOWER_TIMEOUT));
+}
+
+#[test]
+fn tiny_idle_timeout() {
+ const RTT: Duration = Duration::from_millis(500);
+ const LOWER_TIMEOUT_MS: u64 = 100;
+ const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS);
+ // We won't respect a value that is lower than 3*PTO, sanity check.
+ assert!(LOWER_TIMEOUT < 3 * RTT);
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // Overwrite the default at the server.
+ server
+ .set_local_tparam(
+ tparams::IDLE_TIMEOUT,
+ TransportParameter::Integer(LOWER_TIMEOUT_MS),
+ )
+ .unwrap();
+ server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT);
+
+ // Now connect with an RTT and force idleness manually.
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+ let c1 = send_something(&mut client, now);
+ let c2 = send_something(&mut client, now);
+ now += RTT / 2;
+ server.process_input(&c2, now);
+ server.process_input(&c1, now);
+ let s1 = send_something(&mut server, now);
+ let s2 = send_something(&mut server, now);
+ now += RTT / 2;
+ client.process_input(&s2, now);
+ let ack = client.process(Some(&s1), now).dgram();
+ assert!(ack.is_some());
+
+ // The client should be idle now, but with a different timer.
+ if let Output::Callback(t) = client.process(None, now) {
+ assert!(t > LOWER_TIMEOUT);
+ } else {
+ panic!("Client not idle");
+ }
+
+ // The server should go idle after the ACK, but again with a larger timeout.
+ now += RTT / 2;
+ if let Output::Callback(t) = client.process(ack.as_ref(), now) {
+ assert!(t > LOWER_TIMEOUT);
+ } else {
+ panic!("Client not idle");
+ }
+}
+
+#[test]
+fn idle_send_packet1() {
+ const DELTA: Duration = Duration::from_millis(10);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+ connect_force_idle(&mut client, &mut server);
+
+ let timeout = client.process(None, now).callback();
+ assert_eq!(timeout, default_timeout());
+
+ now += Duration::from_secs(10);
+ let dgram = send_and_receive(&mut client, &mut server, now);
+ assert!(dgram.is_some()); // the server will want to ACK, we can drop that.
+
+ // Still connected after 39 seconds because idle timer reset by the
+ // outgoing packet.
+ now += default_timeout() - DELTA;
+ let dgram = client.process(None, now).dgram();
+ assert!(dgram.is_some()); // PTO
+ assert!(client.state().connected());
+
+ // Not connected after 40 seconds.
+ now += DELTA;
+ let out = client.process(None, now);
+ assert!(matches!(out, Output::None));
+ assert!(client.state().closed());
+}
+
+#[test]
+fn idle_send_packet2() {
+ const GAP: Duration = Duration::from_secs(10);
+ const DELTA: Duration = Duration::from_millis(10);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut now = now();
+
+ let timeout = client.process(None, now).callback();
+ assert_eq!(timeout, default_timeout());
+
+ // First transmission at t=GAP.
+ now += GAP;
+ mem::drop(send_something(&mut client, now));
+
+ // Second transmission at t=2*GAP.
+ mem::drop(send_something(&mut client, now + GAP));
+ assert!((GAP * 2 + DELTA) < default_timeout());
+
+ // Still connected just before GAP + default_timeout().
+ now += default_timeout() - DELTA;
+ let dgram = client.process(None, now).dgram();
+ assert!(dgram.is_some()); // PTO
+ assert!(matches!(client.state(), State::Confirmed));
+
+ // Not connected after 40 seconds because timer not reset by second
+ // outgoing packet
+ now += DELTA;
+ let out = client.process(None, now);
+ assert!(matches!(out, Output::None));
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+#[test]
+fn idle_recv_packet() {
+ const FUDGE: Duration = Duration::from_millis(10);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(default_timeout()));
+
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+ assert_eq!(stream, 0);
+ assert_eq!(client.stream_send(stream, b"hello").unwrap(), 5);
+
+ // Respond with another packet.
+ // Note that it is important that this not result in the RTT increasing above 0.
+ // Otherwise, the eventual timeout will be extended (and we're not testing that).
+ now += Duration::from_secs(10);
+ let out = client.process(None, now);
+ server.process_input(&out.dgram().unwrap(), now);
+ assert_eq!(server.stream_send(stream, b"world").unwrap(), 5);
+ let out = server.process_output(now);
+ assert_ne!(out.as_dgram_ref(), None);
+ mem::drop(client.process(out.as_dgram_ref(), now));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ // Add a little less than the idle timeout and we're still connected.
+ now += default_timeout() - FUDGE;
+ mem::drop(client.process(None, now));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ now += FUDGE;
+ mem::drop(client.process(None, now));
+
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+/// Caching packets should not cause the connection to become idle.
+/// This requires a few tricks to keep the connection from going
+/// idle while preventing any progress on the handshake.
+#[test]
+fn idle_caching() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let start = now();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+
+ // Perform the first round trip, but drop the Initial from the server.
+ // The client then caches the Handshake packet.
+ let dgram = client.process_output(start).dgram();
+ let dgram = server.process(dgram.as_ref(), start).dgram();
+ let (_, handshake) = split_datagram(&dgram.unwrap());
+ client.process_input(&handshake.unwrap(), start);
+
+ // Perform an exchange and keep the connection alive.
+ // Only allow a packet containing a PING to pass.
+ let middle = start + AT_LEAST_PTO;
+ mem::drop(client.process_output(middle));
+ let dgram = client.process_output(middle).dgram();
+
+ // Get the server to send its first probe and throw that away.
+ mem::drop(server.process_output(middle).dgram());
+ // Now let the server process the client PING. This causes the server
+ // to send CRYPTO frames again, so manually extract and discard those.
+ let ping_before_s = server.stats().frame_rx.ping;
+ server.process_input(&dgram.unwrap(), middle);
+ assert_eq!(server.stats().frame_rx.ping, ping_before_s + 1);
+ let mut tokens = Vec::new();
+ server
+ .crypto
+ .streams
+ .write_frame(
+ PacketNumberSpace::Initial,
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ )
+ .unwrap();
+ assert_eq!(tokens.len(), 1);
+ tokens.clear();
+ server
+ .crypto
+ .streams
+ .write_frame(
+ PacketNumberSpace::Initial,
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ )
+ .unwrap();
+ assert!(tokens.is_empty());
+ let dgram = server.process_output(middle).dgram();
+
+ // Now only allow the Initial packet from the server through;
+ // it shouldn't contain a CRYPTO frame.
+ let (initial, _) = split_datagram(&dgram.unwrap());
+ let ping_before_c = client.stats().frame_rx.ping;
+ let ack_before = client.stats().frame_rx.ack;
+ client.process_input(&initial, middle);
+ assert_eq!(client.stats().frame_rx.ping, ping_before_c + 1);
+ assert_eq!(client.stats().frame_rx.ack, ack_before + 1);
+
+ let end = start + default_timeout() + (AT_LEAST_PTO / 2);
+ // Now let the server Initial through, with the CRYPTO frame.
+ let dgram = server.process_output(end).dgram();
+ let (initial, _) = split_datagram(&dgram.unwrap());
+ neqo_common::qwarn!("client ingests initial, finally");
+ mem::drop(client.process(Some(&initial), end));
+ maybe_authenticate(&mut client);
+ let dgram = client.process_output(end).dgram();
+ let dgram = server.process(dgram.as_ref(), end).dgram();
+ client.process_input(&dgram.unwrap(), end);
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+/// This function opens a bidirectional stream and leaves both endpoints
+/// idle, with the stream left open.
+/// The stream ID of that stream is returned (along with the new time).
+fn create_stream_idle_rtt(
+ initiator: &mut Connection,
+ responder: &mut Connection,
+ mut now: Instant,
+ rtt: Duration,
+) -> (Instant, StreamId) {
+ let check_idle = |endpoint: &mut Connection, now: Instant| {
+ let delay = endpoint.process_output(now).callback();
+ qtrace!([endpoint], "idle timeout {:?}", delay);
+ if rtt < default_timeout() / 4 {
+ assert_eq!(default_timeout(), delay);
+ } else {
+ assert!(delay > default_timeout());
+ }
+ };
+
+ // Exchange a message each way on a stream.
+ let stream = initiator.stream_create(StreamType::BiDi).unwrap();
+ _ = initiator.stream_send(stream, DEFAULT_STREAM_DATA).unwrap();
+ let req = initiator.process_output(now).dgram();
+ now += rtt / 2;
+ responder.process_input(&req.unwrap(), now);
+
+ // Reordering two packets from the responder forces the initiator to be idle.
+ _ = responder.stream_send(stream, DEFAULT_STREAM_DATA).unwrap();
+ let resp1 = responder.process_output(now).dgram();
+ _ = responder.stream_send(stream, DEFAULT_STREAM_DATA).unwrap();
+ let resp2 = responder.process_output(now).dgram();
+
+ now += rtt / 2;
+ initiator.process_input(&resp2.unwrap(), now);
+ initiator.process_input(&resp1.unwrap(), now);
+ let ack = initiator.process_output(now).dgram();
+ assert!(ack.is_some());
+ check_idle(initiator, now);
+
+ // Receiving the ACK should return the responder to idle too.
+ now += rtt / 2;
+ responder.process_input(&ack.unwrap(), now);
+ check_idle(responder, now);
+
+ (now, stream)
+}
+
+fn create_stream_idle(initiator: &mut Connection, responder: &mut Connection) -> StreamId {
+ let (_, stream) = create_stream_idle_rtt(initiator, responder, now(), Duration::new(0, 0));
+ stream
+}
+
+fn assert_idle(endpoint: &mut Connection, now: Instant, expected: Duration) {
+ assert_eq!(endpoint.process_output(now).callback(), expected);
+}
+
+/// The creator of a stream marks it as important enough to use a keep-alive.
+#[test]
+fn keep_alive_initiator() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut server, &mut client);
+ let mut now = now();
+
+ // Marking the stream for keep-alive changes the idle timeout.
+ server.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut server, now, default_timeout() / 2);
+
+ // Wait that long and the server should send a PING frame.
+ now += default_timeout() / 2;
+ let pings_before = server.stats().frame_tx.ping;
+ let ping = server.process_output(now).dgram();
+ assert!(ping.is_some());
+ assert_eq!(server.stats().frame_tx.ping, pings_before + 1);
+
+ // Exchange ack for the PING.
+ let out = client.process(ping.as_ref(), now).dgram();
+ let out = server.process(out.as_ref(), now).dgram();
+ assert!(client.process(out.as_ref(), now).dgram().is_none());
+
+ // Check that there will be next keep-alive ping after default_timeout() / 2.
+ assert_idle(&mut server, now, default_timeout() / 2);
+ now += default_timeout() / 2;
+ let pings_before2 = server.stats().frame_tx.ping;
+ let ping = server.process_output(now).dgram();
+ assert!(ping.is_some());
+ assert_eq!(server.stats().frame_tx.ping, pings_before2 + 1);
+}
+
+/// Test a keep-alive ping is retransmitted if lost.
+#[test]
+fn keep_alive_lost() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut server, &mut client);
+ let mut now = now();
+
+ // Marking the stream for keep-alive changes the idle timeout.
+ server.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut server, now, default_timeout() / 2);
+
+ // Wait that long and the server should send a PING frame.
+ now += default_timeout() / 2;
+ let pings_before = server.stats().frame_tx.ping;
+ let ping = server.process_output(now).dgram();
+ assert!(ping.is_some());
+ assert_eq!(server.stats().frame_tx.ping, pings_before + 1);
+
+ // Wait for ping to be marked lost.
+ assert!(server.process_output(now).callback() < AT_LEAST_PTO);
+ now += AT_LEAST_PTO;
+ let pings_before2 = server.stats().frame_tx.ping;
+ let ping = server.process_output(now).dgram();
+ assert!(ping.is_some());
+ assert_eq!(server.stats().frame_tx.ping, pings_before2 + 1);
+
+ // Exchange ack for the PING.
+ let out = client.process(ping.as_ref(), now).dgram();
+
+ now += Duration::from_millis(20);
+ let out = server.process(out.as_ref(), now).dgram();
+
+ assert!(client.process(out.as_ref(), now).dgram().is_none());
+
+ // TODO: if we run server.process with current value of now, the server will
+ // return some small timeout for the recovry although it does not have
+ // any outstanding data. Therefore we call it after AT_LEAST_PTO.
+ now += AT_LEAST_PTO;
+ assert_idle(&mut server, now, default_timeout() / 2 - AT_LEAST_PTO);
+}
+
+/// The other peer can also keep it alive.
+#[test]
+fn keep_alive_responder() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut server, &mut client);
+ let mut now = now();
+
+ // Marking the stream for keep-alive changes the idle timeout.
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now, default_timeout() / 2);
+
+ // Wait that long and the client should send a PING frame.
+ now += default_timeout() / 2;
+ let pings_before = client.stats().frame_tx.ping;
+ let ping = client.process_output(now).dgram();
+ assert!(ping.is_some());
+ assert_eq!(client.stats().frame_tx.ping, pings_before + 1);
+}
+
+/// Unmark a stream as being keep-alive.
+#[test]
+fn keep_alive_unmark() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut client, &mut server);
+
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ client.stream_keep_alive(stream, false).unwrap();
+ assert_idle(&mut client, now(), default_timeout());
+}
+
+/// The sender has something to send. Make it send it
+/// and cause the receiver to become idle by sending something
+/// else, reordering the packets, and consuming the ACK.
+/// Note that the sender might not be idle if the thing that it
+/// sends results in something in addition to an ACK.
+fn transfer_force_idle(sender: &mut Connection, receiver: &mut Connection) {
+ let dgram = sender.process_output(now()).dgram();
+ let chaff = send_something(sender, now());
+ receiver.process_input(&chaff, now());
+ receiver.process_input(&dgram.unwrap(), now());
+ let ack = receiver.process_output(now()).dgram();
+ sender.process_input(&ack.unwrap(), now());
+}
+
+/// Receiving the end of the stream stops keep-alives for that stream.
+/// Even if that data hasn't been read.
+#[test]
+fn keep_alive_close() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut client, &mut server);
+
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ client.stream_close_send(stream).unwrap();
+ transfer_force_idle(&mut client, &mut server);
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ server.stream_close_send(stream).unwrap();
+ transfer_force_idle(&mut server, &mut client);
+ assert_idle(&mut client, now(), default_timeout());
+}
+
+/// Receiving `RESET_STREAM` stops keep-alives for that stream, but only once
+/// the sending side is also closed.
+#[test]
+fn keep_alive_reset() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut client, &mut server);
+
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ client.stream_close_send(stream).unwrap();
+ transfer_force_idle(&mut client, &mut server);
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ server.stream_reset_send(stream, 0).unwrap();
+ transfer_force_idle(&mut server, &mut client);
+ assert_idle(&mut client, now(), default_timeout());
+
+ // The client will fade away from here.
+ let t = now() + (default_timeout() / 2);
+ assert_eq!(client.process_output(t).callback(), default_timeout() / 2);
+ let t = now() + default_timeout();
+ assert_eq!(client.process_output(t), Output::None);
+}
+
+/// Stopping sending also cancels the keep-alive.
+#[test]
+fn keep_alive_stop_sending() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut client, &mut server);
+
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ client.stream_close_send(stream).unwrap();
+ client.stream_stop_sending(stream, 0).unwrap();
+ transfer_force_idle(&mut client, &mut server);
+ // The server will have sent RESET_STREAM, which the client will
+ // want to acknowledge, so force that out.
+ let junk = send_something(&mut server, now());
+ let ack = client.process(Some(&junk), now()).dgram();
+ assert!(ack.is_some());
+
+ // Now the client should be idle.
+ assert_idle(&mut client, now(), default_timeout());
+}
+
+/// Multiple active streams are tracked properly.
+#[test]
+fn keep_alive_multiple_stop() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let stream = create_stream_idle(&mut client, &mut server);
+
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ let other = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_keep_alive(other, true).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ client.stream_keep_alive(stream, false).unwrap();
+ assert_idle(&mut client, now(), default_timeout() / 2);
+
+ client.stream_keep_alive(other, false).unwrap();
+ assert_idle(&mut client, now(), default_timeout());
+}
+
+/// If the RTT is too long relative to the idle timeout, the keep-alive is large too.
+#[test]
+fn keep_alive_large_rtt() {
+ let mut client = default_client();
+ let mut server = default_server();
+ // Use an RTT that is large enough to cause the PTO timer to exceed half
+ // the idle timeout.
+ let rtt = default_timeout() * 3 / 4;
+ let now = connect_with_rtt(&mut client, &mut server, now(), rtt);
+ let (now, stream) = create_stream_idle_rtt(&mut server, &mut client, now, rtt);
+
+ // Calculating PTO here is tricky as RTTvar has eroded after multiple round trips.
+ // Just check that the delay is larger than the baseline and the RTT.
+ for endpoint in &mut [client, server] {
+ endpoint.stream_keep_alive(stream, true).unwrap();
+ let delay = endpoint.process_output(now).callback();
+ qtrace!([endpoint], "new delay {:?}", delay);
+ assert!(delay > default_timeout() / 2);
+ assert!(delay > rtt);
+ }
+}
+
+/// Only the recipient of a unidirectional stream can keep it alive.
+#[test]
+fn keep_alive_uni() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_keep_alive(stream, true).unwrap_err();
+ _ = client.stream_send(stream, DEFAULT_STREAM_DATA).unwrap();
+ let dgram = client.process_output(now()).dgram();
+
+ server.process_input(&dgram.unwrap(), now());
+ server.stream_keep_alive(stream, true).unwrap();
+}
+
+/// Test a keep-alive ping is send if there are outstading ack-eliciting packets and that
+/// the connection is closed after the idle timeout passes.
+#[test]
+fn keep_alive_with_ack_eliciting_packet_lost() {
+ const RTT: Duration = Duration::from_millis(500); // PTO will be ~1.1125s
+
+ // The idle time out will be set to ~ 5 * PTO. (IDLE_TIMEOUT/2 > pto and IDLE_TIMEOUT/2 < pto
+ // + 2pto) After handshake all packets will be lost. The following steps will happen after
+ // the handshake:
+ // - data will be sent on a stream that is marked for keep-alive, (at start time)
+ // - PTO timer will trigger first, and the data will be retransmited toghether with a PING, (at
+ // the start time + pto)
+ // - keep-alive timer will trigger and a keep-alive PING will be sent, (at the start time +
+ // IDLE_TIMEOUT / 2)
+ // - PTO timer will trigger again. (at the start time + pto + 2*pto)
+ // - Idle time out will trigger (at the timeout + IDLE_TIMEOUT)
+ const IDLE_TIMEOUT: Duration = Duration::from_millis(6000);
+
+ let mut client = new_client(ConnectionParameters::default().idle_timeout(IDLE_TIMEOUT));
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, RTT);
+ // connect_rtt_idle increase now by RTT / 2;
+ now -= RTT / 2;
+ assert_idle(&mut client, now, IDLE_TIMEOUT);
+
+ // Create a stream.
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+ // Marking the stream for keep-alive changes the idle timeout.
+ client.stream_keep_alive(stream, true).unwrap();
+ assert_idle(&mut client, now, IDLE_TIMEOUT / 2);
+
+ // Send data on the stream that will be lost.
+ _ = client.stream_send(stream, DEFAULT_STREAM_DATA).unwrap();
+ let _lost_packet = client.process_output(now).dgram();
+
+ let pto = client.process_output(now).callback();
+ // Wait for packet to be marked lost.
+ assert!(pto < IDLE_TIMEOUT / 2);
+ now += pto;
+ let retransmit = client.process_output(now).dgram();
+ assert!(retransmit.is_some());
+ let retransmit = client.process_output(now).dgram();
+ assert!(retransmit.is_some());
+
+ // The next callback should be for an idle PING.
+ assert_eq!(
+ client.process_output(now).callback(),
+ IDLE_TIMEOUT / 2 - pto
+ );
+
+ // Wait that long and the client should send a PING frame.
+ now += IDLE_TIMEOUT / 2 - pto;
+ let pings_before = client.stats().frame_tx.ping;
+ let ping = client.process_output(now).dgram();
+ assert!(ping.is_some());
+ assert_eq!(client.stats().frame_tx.ping, pings_before + 1);
+
+ // The next callback is for a PTO, the PTO timer is 2 * pto now.
+ assert_eq!(client.process_output(now).callback(), pto * 2);
+ now += pto * 2;
+ // Now we will retransmit stream data.
+ let retransmit = client.process_output(now).dgram();
+ assert!(retransmit.is_some());
+ let retransmit = client.process_output(now).dgram();
+ assert!(retransmit.is_some());
+
+ // The next callback will be an idle timeout.
+ assert_eq!(
+ client.process_output(now).callback(),
+ IDLE_TIMEOUT / 2 - 2 * pto
+ );
+
+ now += IDLE_TIMEOUT / 2 - 2 * pto;
+ let out = client.process_output(now);
+ assert!(matches!(out, Output::None));
+ assert!(matches!(client.state(), State::Closed(_)));
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs
new file mode 100644
index 0000000000..c247bba670
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs
@@ -0,0 +1,346 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::mem;
+
+use neqo_common::{qdebug, Datagram};
+use test_fixture::{self, now};
+
+use super::{
+ super::{
+ super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED},
+ Connection, ConnectionParameters, Error, Output, State, StreamType,
+ },
+ connect, connect_force_idle, default_client, default_server, maybe_authenticate,
+ send_and_receive, send_something, AT_LEAST_PTO,
+};
+use crate::{
+ crypto::{OVERWRITE_INVOCATIONS, UPDATE_WRITE_KEYS_AT},
+ packet::PacketNumber,
+ path::PATH_MTU_V6,
+};
+
+fn check_discarded(
+ peer: &mut Connection,
+ pkt: &Datagram,
+ response: bool,
+ dropped: usize,
+ dups: usize,
+) {
+ // Make sure to flush any saved datagrams before doing this.
+ mem::drop(peer.process_output(now()));
+
+ let before = peer.stats();
+ let out = peer.process(Some(pkt), now());
+ assert_eq!(out.as_dgram_ref().is_some(), response);
+ let after = peer.stats();
+ assert_eq!(dropped, after.dropped_rx - before.dropped_rx);
+ assert_eq!(dups, after.dups_rx - before.dups_rx);
+}
+
+fn assert_update_blocked(c: &mut Connection) {
+ assert_eq!(
+ c.initiate_key_update().unwrap_err(),
+ Error::KeyUpdateBlocked
+ );
+}
+
+fn overwrite_invocations(n: PacketNumber) {
+ OVERWRITE_INVOCATIONS.with(|v| {
+ *v.borrow_mut() = Some(n);
+ });
+}
+
+#[test]
+fn discarded_initial_keys() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let init_pkt_c = client.process(None, now()).dgram();
+ assert!(init_pkt_c.is_some());
+ assert_eq!(init_pkt_c.as_ref().unwrap().len(), PATH_MTU_V6);
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let init_pkt_s = server.process(init_pkt_c.as_ref(), now()).dgram();
+ assert!(init_pkt_s.is_some());
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(init_pkt_s.as_ref(), now()).dgram();
+ assert!(out.is_some());
+
+ // The client has received a handshake packet. It will remove the Initial keys.
+ // We will check this by processing init_pkt_s a second time.
+ // The initial packet should be dropped. The packet contains a Handshake packet as well, which
+ // will be marked as dup. And it will contain padding, which will be "dropped".
+ // The client will generate a Handshake packet here to avoid stalling.
+ check_discarded(&mut client, &init_pkt_s.unwrap(), true, 2, 1);
+
+ assert!(maybe_authenticate(&mut client));
+
+ // The server has not removed the Initial keys yet, because it has not yet received a Handshake
+ // packet from the client.
+ // We will check this by processing init_pkt_c a second time.
+ // The dropped packet is padding. The Initial packet has been mark dup.
+ check_discarded(&mut server, &init_pkt_c.clone().unwrap(), false, 1, 1);
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let out = client.process(None, now()).dgram();
+ assert!(out.is_some());
+
+ // The server will process the first Handshake packet.
+ // After this the Initial keys will be dropped.
+ let out = server.process(out.as_ref(), now()).dgram();
+ assert!(out.is_some());
+
+ // Check that the Initial keys are dropped at the server
+ // We will check this by processing init_pkt_c a third time.
+ // The Initial packet has been dropped and padding that follows it.
+ // There is no dups, everything has been dropped.
+ check_discarded(&mut server, &init_pkt_c.unwrap(), false, 1, 0);
+}
+
+#[test]
+fn key_update_client() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let mut now = now();
+
+ assert_eq!(client.get_epochs(), (Some(3), Some(3))); // (write, read)
+ assert_eq!(server.get_epochs(), (Some(3), Some(3)));
+
+ assert!(client.initiate_key_update().is_ok());
+ assert_update_blocked(&mut client);
+
+ // Initiating an update should only increase the write epoch.
+ let idle_timeout = ConnectionParameters::default().get_idle_timeout();
+ assert_eq!(Output::Callback(idle_timeout), client.process(None, now));
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+
+ // Send something to propagate the update.
+ // Note that the server will acknowledge immediately when RTT is zero.
+ assert!(send_and_receive(&mut client, &mut server, now).is_some());
+
+ // The server should now be waiting to discharge read keys.
+ assert_eq!(server.get_epochs(), (Some(4), Some(3)));
+ let res = server.process(None, now);
+ if let Output::Callback(t) = res {
+ assert!(t < idle_timeout);
+ } else {
+ panic!("server should now be waiting to clear keys");
+ }
+
+ // Without having had time to purge old keys, more updates are blocked.
+ // The spec would permits it at this point, but we are more conservative.
+ assert_update_blocked(&mut client);
+ // The server can't update until it receives an ACK for a packet.
+ assert_update_blocked(&mut server);
+
+ // Waiting now for at least a PTO should cause the server to drop old keys.
+ // But at this point the client hasn't received a key update from the server.
+ // It will be stuck with old keys.
+ now += AT_LEAST_PTO;
+ let dgram = client.process(None, now).dgram();
+ assert!(dgram.is_some()); // Drop this packet.
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+ mem::drop(server.process(None, now));
+ assert_eq!(server.get_epochs(), (Some(4), Some(4)));
+
+ // Even though the server has updated, it hasn't received an ACK yet.
+ assert_update_blocked(&mut server);
+
+ // Now get an ACK from the server.
+ // The previous PTO packet (see above) was dropped, so we should get an ACK here.
+ let dgram = send_and_receive(&mut client, &mut server, now);
+ assert!(dgram.is_some());
+ let res = client.process(dgram.as_ref(), now);
+ // This is the first packet that the client has received from the server
+ // with new keys, so its read timer just started.
+ if let Output::Callback(t) = res {
+ assert!(t < ConnectionParameters::default().get_idle_timeout());
+ } else {
+ panic!("client should now be waiting to clear keys");
+ }
+
+ assert_update_blocked(&mut client);
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+ // The server can't update until it gets something from the client.
+ assert_update_blocked(&mut server);
+
+ now += AT_LEAST_PTO;
+ mem::drop(client.process(None, now));
+ assert_eq!(client.get_epochs(), (Some(4), Some(4)));
+}
+
+#[test]
+fn key_update_consecutive() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let now = now();
+
+ assert!(server.initiate_key_update().is_ok());
+ assert_eq!(server.get_epochs(), (Some(4), Some(3)));
+
+ // Server sends something.
+ // Send twice and drop the first to induce an ACK from the client.
+ mem::drop(send_something(&mut server, now)); // Drop this.
+
+ // Another packet from the server will cause the client to ACK and update keys.
+ let dgram = send_and_receive(&mut server, &mut client, now);
+ assert!(dgram.is_some());
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+
+ // Have the server process the ACK.
+ if let Output::Callback(_) = server.process(dgram.as_ref(), now) {
+ assert_eq!(server.get_epochs(), (Some(4), Some(3)));
+ // Now move the server temporarily into the future so that it
+ // rotates the keys. The client stays in the present.
+ mem::drop(server.process(None, now + AT_LEAST_PTO));
+ assert_eq!(server.get_epochs(), (Some(4), Some(4)));
+ } else {
+ panic!("server should have a timer set");
+ }
+
+ // Now update keys on the server again.
+ assert!(server.initiate_key_update().is_ok());
+ assert_eq!(server.get_epochs(), (Some(5), Some(4)));
+
+ let dgram = send_something(&mut server, now + AT_LEAST_PTO);
+
+ // However, as the server didn't wait long enough to update again, the
+ // client hasn't rotated its keys, so the packet gets dropped.
+ check_discarded(&mut client, &dgram, false, 1, 0);
+}
+
+// Key updates can't be initiated too early.
+#[test]
+fn key_update_before_confirmed() {
+ let mut client = default_client();
+ assert_update_blocked(&mut client);
+ let mut server = default_server();
+ assert_update_blocked(&mut server);
+
+ // Client Initial
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ assert_update_blocked(&mut client);
+
+ // Server Initial + Handshake
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ assert!(dgram.is_some());
+ assert_update_blocked(&mut server);
+
+ // Client Handshake
+ client.process_input(&dgram.unwrap(), now());
+ assert_update_blocked(&mut client);
+
+ assert!(maybe_authenticate(&mut client));
+ assert_update_blocked(&mut client);
+
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ assert_update_blocked(&mut client);
+
+ // Server HANDSHAKE_DONE
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ assert!(dgram.is_some());
+ assert!(server.initiate_key_update().is_ok());
+
+ // Client receives HANDSHAKE_DONE
+ let dgram = client.process(dgram.as_ref(), now()).dgram();
+ assert!(dgram.is_none());
+ assert!(client.initiate_key_update().is_ok());
+}
+
+#[test]
+fn exhaust_write_keys() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ overwrite_invocations(0);
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert!(client.stream_send(stream_id, b"explode!").is_ok());
+ let dgram = client.process_output(now()).dgram();
+ assert!(dgram.is_none());
+ assert!(matches!(
+ client.state(),
+ State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ ));
+}
+
+#[test]
+fn exhaust_read_keys() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let dgram = send_something(&mut client, now());
+
+ overwrite_invocations(0);
+ let dgram = server.process(Some(&dgram), now()).dgram();
+ assert!(matches!(
+ server.state(),
+ State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ ));
+
+ client.process_input(&dgram.unwrap(), now());
+ assert!(matches!(
+ client.state(),
+ State::Draining {
+ error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)),
+ ..
+ }
+ ));
+}
+
+#[test]
+fn automatic_update_write_keys() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ overwrite_invocations(UPDATE_WRITE_KEYS_AT);
+ mem::drop(send_something(&mut client, now()));
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+}
+
+#[test]
+fn automatic_update_write_keys_later() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ overwrite_invocations(UPDATE_WRITE_KEYS_AT + 2);
+ // No update after the first.
+ mem::drop(send_something(&mut client, now()));
+ assert_eq!(client.get_epochs(), (Some(3), Some(3)));
+ // The second will update though.
+ mem::drop(send_something(&mut client, now()));
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+}
+
+#[test]
+fn automatic_update_write_keys_blocked() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // An outstanding key update will block the automatic update.
+ client.initiate_key_update().unwrap();
+
+ overwrite_invocations(UPDATE_WRITE_KEYS_AT);
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert!(client.stream_send(stream_id, b"explode!").is_ok());
+ let dgram = client.process_output(now()).dgram();
+ // Not being able to update is fatal.
+ assert!(dgram.is_none());
+ assert!(matches!(
+ client.state(),
+ State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ ));
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/migration.rs b/third_party/rust/neqo-transport/src/connection/tests/migration.rs
new file mode 100644
index 0000000000..8307a7dd84
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/migration.rs
@@ -0,0 +1,953 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ cell::RefCell,
+ net::{IpAddr, Ipv6Addr, SocketAddr},
+ rc::Rc,
+ time::{Duration, Instant},
+};
+
+use neqo_common::{Datagram, Decoder};
+use test_fixture::{
+ self, addr, addr_v4,
+ assertions::{assert_v4_path, assert_v6_path},
+ fixture_init, new_neqo_qlog, now,
+};
+
+use super::{
+ super::{Connection, Output, State, StreamType},
+ connect_fail, connect_force_idle, connect_rtt_idle, default_client, default_server,
+ maybe_authenticate, new_client, new_server, send_something, CountingConnectionIdGenerator,
+};
+use crate::{
+ cid::LOCAL_ACTIVE_CID_LIMIT,
+ connection::tests::send_something_paced,
+ frame::FRAME_TYPE_NEW_CONNECTION_ID,
+ packet::PacketBuilder,
+ path::{PATH_MTU_V4, PATH_MTU_V6},
+ tparams::{self, PreferredAddress, TransportParameter},
+ ConnectionError, ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef,
+ ConnectionParameters, EmptyConnectionIdGenerator, Error,
+};
+
+/// This should be a valid-seeming transport parameter.
+/// And it should have different values to `addr` and `addr_v4`.
+const SAMPLE_PREFERRED_ADDRESS: &[u8] = &[
+ 0xc0, 0x00, 0x02, 0x02, 0x01, 0xbb, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0xbb, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, 0x03, 0x03,
+ 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
+];
+
+// These tests generally use two paths:
+// The connection is established on a path with the same IPv6 address on both ends.
+// Migrations move to a path with the same IPv4 address on both ends.
+// This simplifies validation as the same assertions can be used for client and server.
+// The risk is that there is a place where source/destination local/remote is inverted.
+
+fn loopback() -> SocketAddr {
+ SocketAddr::new(IpAddr::V6(Ipv6Addr::from(1)), 443)
+}
+
+fn change_path(d: &Datagram, a: SocketAddr) -> Datagram {
+ Datagram::new(a, a, d.tos(), d.ttl(), &d[..])
+}
+
+fn new_port(a: SocketAddr) -> SocketAddr {
+ let (port, _) = a.port().overflowing_add(410);
+ SocketAddr::new(a.ip(), port)
+}
+
+fn change_source_port(d: &Datagram) -> Datagram {
+ Datagram::new(
+ new_port(d.source()),
+ d.destination(),
+ d.tos(),
+ d.ttl(),
+ &d[..],
+ )
+}
+
+/// As these tests use a new path, that path often has a non-zero RTT.
+/// Pacing can be a problem when testing that path. This skips time forward.
+fn skip_pacing(c: &mut Connection, now: Instant) -> Instant {
+ let pacing = c.process_output(now).callback();
+ assert_ne!(pacing, Duration::new(0, 0));
+ now + pacing
+}
+
+#[test]
+fn rebinding_port() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let dgram = send_something(&mut client, now());
+ let dgram = change_source_port(&dgram);
+
+ server.process_input(&dgram, now());
+ // Have the server send something so that it generates a packet.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+ server.stream_close_send(stream_id).unwrap();
+ let dgram = server.process_output(now()).dgram();
+ let dgram = dgram.unwrap();
+ assert_eq!(dgram.source(), addr());
+ assert_eq!(dgram.destination(), new_port(addr()));
+}
+
+/// This simulates an attack where a valid packet is forwarded on
+/// a different path. This shows how both paths are probed and the
+/// server eventually returns to the original path.
+#[test]
+fn path_forwarding_attack() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let mut now = now();
+
+ let dgram = send_something(&mut client, now);
+ let dgram = change_path(&dgram, addr_v4());
+ server.process_input(&dgram, now);
+
+ // The server now probes the new (primary) path.
+ let new_probe = server.process_output(now).dgram().unwrap();
+ assert_eq!(server.stats().frame_tx.path_challenge, 1);
+ assert_v4_path(&new_probe, false); // Can't be padded.
+
+ // The server also probes the old path.
+ let old_probe = server.process_output(now).dgram().unwrap();
+ assert_eq!(server.stats().frame_tx.path_challenge, 2);
+ assert_v6_path(&old_probe, true);
+
+ // New data from the server is sent on the new path, but that is
+ // now constrained by the amplification limit.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+ server.stream_close_send(stream_id).unwrap();
+ assert!(server.process_output(now).dgram().is_none());
+
+ // The client should respond to the challenge on the new path.
+ // The server couldn't pad, so the client is also amplification limited.
+ let new_resp = client.process(Some(&new_probe), now).dgram().unwrap();
+ assert_eq!(client.stats().frame_rx.path_challenge, 1);
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ assert_eq!(client.stats().frame_tx.path_response, 1);
+ assert_v4_path(&new_resp, false);
+
+ // The client also responds to probes on the old path.
+ let old_resp = client.process(Some(&old_probe), now).dgram().unwrap();
+ assert_eq!(client.stats().frame_rx.path_challenge, 2);
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ assert_eq!(client.stats().frame_tx.path_response, 2);
+ assert_v6_path(&old_resp, true);
+
+ // But the client still sends data on the old path.
+ let client_data1 = send_something(&mut client, now);
+ assert_v6_path(&client_data1, false); // Just data.
+
+ // Receiving the PATH_RESPONSE from the client opens the amplification
+ // limit enough for the server to respond.
+ // This is padded because it includes PATH_CHALLENGE.
+ let server_data1 = server.process(Some(&new_resp), now).dgram().unwrap();
+ assert_v4_path(&server_data1, true);
+ assert_eq!(server.stats().frame_tx.path_challenge, 3);
+
+ // The client responds to this probe on the new path.
+ client.process_input(&server_data1, now);
+ let stream_before = client.stats().frame_tx.stream;
+ let padded_resp = send_something(&mut client, now);
+ assert_eq!(stream_before, client.stats().frame_tx.stream);
+ assert_v4_path(&padded_resp, true); // This is padded!
+
+ // But new data from the client stays on the old path.
+ let client_data2 = client.process_output(now).dgram().unwrap();
+ assert_v6_path(&client_data2, false);
+
+ // The server keeps sending on the new path.
+ now = skip_pacing(&mut server, now);
+ let server_data2 = send_something(&mut server, now);
+ assert_v4_path(&server_data2, false);
+
+ // Until new data is received from the client on the old path.
+ server.process_input(&client_data2, now);
+ // The server sends a probe on the "old" path.
+ let server_data3 = send_something(&mut server, now);
+ assert_v4_path(&server_data3, true);
+ // But switches data transmission to the "new" path.
+ let server_data4 = server.process_output(now).dgram().unwrap();
+ assert_v6_path(&server_data4, false);
+}
+
+#[test]
+fn migrate_immediate() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let now = now();
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), true, now)
+ .unwrap();
+
+ let client1 = send_something(&mut client, now);
+ assert_v4_path(&client1, true); // Contains PATH_CHALLENGE.
+ let client2 = send_something(&mut client, now);
+ assert_v4_path(&client2, false); // Doesn't.
+
+ let server_delayed = send_something(&mut server, now);
+
+ // The server accepts the first packet and migrates (but probes).
+ let server1 = server.process(Some(&client1), now).dgram().unwrap();
+ assert_v4_path(&server1, true);
+ let server2 = server.process_output(now).dgram().unwrap();
+ assert_v6_path(&server2, true);
+
+ // The second packet has no real effect, it just elicits an ACK.
+ let all_before = server.stats().frame_tx.all;
+ let ack_before = server.stats().frame_tx.ack;
+ let server3 = server.process(Some(&client2), now).dgram();
+ assert!(server3.is_some());
+ assert_eq!(server.stats().frame_tx.all, all_before + 1);
+ assert_eq!(server.stats().frame_tx.ack, ack_before + 1);
+
+ // Receiving a packet sent by the server before migration doesn't change path.
+ client.process_input(&server_delayed, now);
+ // The client has sent two unpaced packets and this new path has no RTT estimate
+ // so this might be paced.
+ let (client3, _t) = send_something_paced(&mut client, now, true);
+ assert_v4_path(&client3, false);
+}
+
+/// RTT estimates for paths should be preserved across migrations.
+#[test]
+fn migrate_rtt() {
+ const RTT: Duration = Duration::from_millis(20);
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, RTT);
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), true, now)
+ .unwrap();
+ // The RTT might be increased for the new path, so allow a little flexibility.
+ let rtt = client.paths.rtt();
+ assert!(rtt > RTT);
+ assert!(rtt < RTT * 2);
+}
+
+#[test]
+fn migrate_immediate_fail() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let mut now = now();
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), true, now)
+ .unwrap();
+
+ let probe = client.process_output(now).dgram().unwrap();
+ assert_v4_path(&probe, true); // Contains PATH_CHALLENGE.
+
+ for _ in 0..2 {
+ let cb = client.process_output(now).callback();
+ assert_ne!(cb, Duration::new(0, 0));
+ now += cb;
+
+ let before = client.stats().frame_tx;
+ let probe = client.process_output(now).dgram().unwrap();
+ assert_v4_path(&probe, true); // Contains PATH_CHALLENGE.
+ let after = client.stats().frame_tx;
+ assert_eq!(after.path_challenge, before.path_challenge + 1);
+ assert_eq!(after.padding, before.padding + 1);
+ assert_eq!(after.all, before.all + 2);
+
+ // This might be a PTO, which will result in sending a probe.
+ if let Some(probe) = client.process_output(now).dgram() {
+ assert_v4_path(&probe, false); // Contains PATH_CHALLENGE.
+ let after = client.stats().frame_tx;
+ assert_eq!(after.ping, before.ping + 1);
+ assert_eq!(after.all, before.all + 3);
+ }
+ }
+
+ let pto = client.process_output(now).callback();
+ assert_ne!(pto, Duration::new(0, 0));
+ now += pto;
+
+ // The client should fall back to the original path and retire the connection ID.
+ let fallback = client.process_output(now).dgram();
+ assert_v6_path(&fallback.unwrap(), false);
+ assert_eq!(client.stats().frame_tx.retire_connection_id, 1);
+}
+
+/// Migrating to the same path shouldn't do anything special,
+/// except that the path is probed.
+#[test]
+fn migrate_same() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let now = now();
+
+ client
+ .migrate(Some(addr()), Some(addr()), true, now)
+ .unwrap();
+
+ let probe = client.process_output(now).dgram().unwrap();
+ assert_v6_path(&probe, true); // Contains PATH_CHALLENGE.
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+
+ let resp = server.process(Some(&probe), now).dgram().unwrap();
+ assert_v6_path(&resp, true);
+ assert_eq!(server.stats().frame_tx.path_response, 1);
+ assert_eq!(server.stats().frame_tx.path_challenge, 0);
+
+ // Everything continues happily.
+ client.process_input(&resp, now);
+ let contd = send_something(&mut client, now);
+ assert_v6_path(&contd, false);
+}
+
+/// Migrating to the same path, if it fails, causes the connection to fail.
+#[test]
+fn migrate_same_fail() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let mut now = now();
+
+ client
+ .migrate(Some(addr()), Some(addr()), true, now)
+ .unwrap();
+
+ let probe = client.process_output(now).dgram().unwrap();
+ assert_v6_path(&probe, true); // Contains PATH_CHALLENGE.
+
+ for _ in 0..2 {
+ let cb = client.process_output(now).callback();
+ assert_ne!(cb, Duration::new(0, 0));
+ now += cb;
+
+ let before = client.stats().frame_tx;
+ let probe = client.process_output(now).dgram().unwrap();
+ assert_v6_path(&probe, true); // Contains PATH_CHALLENGE.
+ let after = client.stats().frame_tx;
+ assert_eq!(after.path_challenge, before.path_challenge + 1);
+ assert_eq!(after.padding, before.padding + 1);
+ assert_eq!(after.all, before.all + 2);
+
+ // This might be a PTO, which will result in sending a probe.
+ if let Some(probe) = client.process_output(now).dgram() {
+ assert_v6_path(&probe, false); // Contains PATH_CHALLENGE.
+ let after = client.stats().frame_tx;
+ assert_eq!(after.ping, before.ping + 1);
+ assert_eq!(after.all, before.all + 3);
+ }
+ }
+
+ let pto = client.process_output(now).callback();
+ assert_ne!(pto, Duration::new(0, 0));
+ now += pto;
+
+ // The client should mark this path as failed and close immediately.
+ let res = client.process_output(now);
+ assert!(matches!(res, Output::None));
+ assert!(matches!(
+ client.state(),
+ State::Closed(ConnectionError::Transport(Error::NoAvailablePath))
+ ));
+}
+
+/// This gets the connection ID from a datagram using the default
+/// connection ID generator/decoder.
+fn get_cid(d: &Datagram) -> ConnectionIdRef {
+ let gen = CountingConnectionIdGenerator::default();
+ assert_eq!(d[0] & 0x80, 0); // Only support short packets for now.
+ gen.decode_cid(&mut Decoder::from(&d[1..])).unwrap()
+}
+
+fn migration(mut client: Connection) {
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let now = now();
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), false, now)
+ .unwrap();
+
+ let probe = client.process_output(now).dgram().unwrap();
+ assert_v4_path(&probe, true); // Contains PATH_CHALLENGE.
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ let probe_cid = ConnectionId::from(get_cid(&probe));
+
+ let resp = server.process(Some(&probe), now).dgram().unwrap();
+ assert_v4_path(&resp, true);
+ assert_eq!(server.stats().frame_tx.path_response, 1);
+ assert_eq!(server.stats().frame_tx.path_challenge, 1);
+
+ // Data continues to be exchanged on the new path.
+ let client_data = send_something(&mut client, now);
+ assert_ne!(get_cid(&client_data), probe_cid);
+ assert_v6_path(&client_data, false);
+ server.process_input(&client_data, now);
+ let server_data = send_something(&mut server, now);
+ assert_v6_path(&server_data, false);
+
+ // Once the client receives the probe response, it migrates to the new path.
+ client.process_input(&resp, now);
+ assert_eq!(client.stats().frame_rx.path_challenge, 1);
+ let migrate_client = send_something(&mut client, now);
+ assert_v4_path(&migrate_client, true); // Responds to server probe.
+
+ // The server now sees the migration and will switch over.
+ // However, it will probe the old path again, even though it has just
+ // received a response to its last probe, because it needs to verify
+ // that the migration is genuine.
+ server.process_input(&migrate_client, now);
+ let stream_before = server.stats().frame_tx.stream;
+ let probe_old_server = send_something(&mut server, now);
+ // This is just the double-check probe; no STREAM frames.
+ assert_v6_path(&probe_old_server, true);
+ assert_eq!(server.stats().frame_tx.path_challenge, 2);
+ assert_eq!(server.stats().frame_tx.stream, stream_before);
+
+ // The server then sends data on the new path.
+ let migrate_server = server.process_output(now).dgram().unwrap();
+ assert_v4_path(&migrate_server, false);
+ assert_eq!(server.stats().frame_tx.path_challenge, 2);
+ assert_eq!(server.stats().frame_tx.stream, stream_before + 1);
+
+ // The client receives these checks and responds to the probe, but uses the new path.
+ client.process_input(&migrate_server, now);
+ client.process_input(&probe_old_server, now);
+ let old_probe_resp = send_something(&mut client, now);
+ assert_v6_path(&old_probe_resp, true);
+ let client_confirmation = client.process_output(now).dgram().unwrap();
+ assert_v4_path(&client_confirmation, false);
+
+ // The server has now sent 2 packets, so it is blocked on the pacer. Wait.
+ let server_pacing = server.process_output(now).callback();
+ assert_ne!(server_pacing, Duration::new(0, 0));
+ // ... then confirm that the server sends on the new path still.
+ let server_confirmation = send_something(&mut server, now + server_pacing);
+ assert_v4_path(&server_confirmation, false);
+}
+
+#[test]
+fn migration_graceful() {
+ migration(default_client());
+}
+
+/// A client should be able to migrate when it has a zero-length connection ID.
+#[test]
+fn migration_client_empty_cid() {
+ fixture_init();
+ let client = Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())),
+ addr(),
+ addr(),
+ ConnectionParameters::default(),
+ now(),
+ )
+ .unwrap();
+ migration(client);
+}
+
+/// Drive the handshake in the most expeditious fashion.
+/// Returns the packet containing `HANDSHAKE_DONE` from the server.
+fn fast_handshake(client: &mut Connection, server: &mut Connection) -> Option<Datagram> {
+ let dgram = client.process_output(now()).dgram();
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+ assert!(maybe_authenticate(client));
+ let dgram = client.process_output(now()).dgram();
+ server.process(dgram.as_ref(), now()).dgram()
+}
+
+fn preferred_address(hs_client: SocketAddr, hs_server: SocketAddr, preferred: SocketAddr) {
+ let mtu = match hs_client.ip() {
+ IpAddr::V4(_) => PATH_MTU_V4,
+ IpAddr::V6(_) => PATH_MTU_V6,
+ };
+ let assert_orig_path = |d: &Datagram, full_mtu: bool| {
+ assert_eq!(
+ d.destination(),
+ if d.source() == hs_client {
+ hs_server
+ } else if d.source() == hs_server {
+ hs_client
+ } else {
+ panic!();
+ }
+ );
+ if full_mtu {
+ assert_eq!(d.len(), mtu);
+ }
+ };
+ let assert_toward_spa = |d: &Datagram, full_mtu: bool| {
+ assert_eq!(d.destination(), preferred);
+ assert_eq!(d.source(), hs_client);
+ if full_mtu {
+ assert_eq!(d.len(), mtu);
+ }
+ };
+ let assert_from_spa = |d: &Datagram, full_mtu: bool| {
+ assert_eq!(d.destination(), hs_client);
+ assert_eq!(d.source(), preferred);
+ if full_mtu {
+ assert_eq!(d.len(), mtu);
+ }
+ };
+
+ fixture_init();
+ let (log, _contents) = new_neqo_qlog();
+ let mut client = Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())),
+ hs_client,
+ hs_server,
+ ConnectionParameters::default(),
+ now(),
+ )
+ .unwrap();
+ client.set_qlog(log);
+ let spa = match preferred {
+ SocketAddr::V6(v6) => PreferredAddress::new(None, Some(v6)),
+ SocketAddr::V4(v4) => PreferredAddress::new(Some(v4), None),
+ };
+ let mut server = new_server(ConnectionParameters::default().preferred_address(spa));
+
+ let dgram = fast_handshake(&mut client, &mut server);
+
+ // The client is about to process HANDSHAKE_DONE.
+ // It should start probing toward the server's preferred address.
+ let probe = client.process(dgram.as_ref(), now()).dgram().unwrap();
+ assert_toward_spa(&probe, true);
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ assert_ne!(client.process_output(now()).callback(), Duration::new(0, 0));
+
+ // Data continues on the main path for the client.
+ let data = send_something(&mut client, now());
+ assert_orig_path(&data, false);
+
+ // The server responds to the probe.
+ let resp = server.process(Some(&probe), now()).dgram().unwrap();
+ assert_from_spa(&resp, true);
+ assert_eq!(server.stats().frame_tx.path_challenge, 1);
+ assert_eq!(server.stats().frame_tx.path_response, 1);
+
+ // Data continues on the main path for the server.
+ server.process_input(&data, now());
+ let data = send_something(&mut server, now());
+ assert_orig_path(&data, false);
+
+ // Client gets the probe response back and it migrates.
+ client.process_input(&resp, now());
+ client.process_input(&data, now());
+ let data = send_something(&mut client, now());
+ assert_toward_spa(&data, true);
+ assert_eq!(client.stats().frame_tx.stream, 2);
+ assert_eq!(client.stats().frame_tx.path_response, 1);
+
+ // The server sees the migration and probes the old path.
+ let probe = server.process(Some(&data), now()).dgram().unwrap();
+ assert_orig_path(&probe, true);
+ assert_eq!(server.stats().frame_tx.path_challenge, 2);
+
+ // But data now goes on the new path.
+ let data = send_something(&mut server, now());
+ assert_from_spa(&data, false);
+}
+
+/// Migration works for a new port number.
+#[test]
+fn preferred_address_new_port() {
+ let a = addr();
+ preferred_address(a, a, new_port(a));
+}
+
+/// Migration works for a new address too.
+#[test]
+fn preferred_address_new_address() {
+ let mut preferred = addr();
+ preferred.set_ip(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2)));
+ preferred_address(addr(), addr(), preferred);
+}
+
+/// Migration works for IPv4 addresses.
+#[test]
+fn preferred_address_new_port_v4() {
+ let a = addr_v4();
+ preferred_address(a, a, new_port(a));
+}
+
+/// Migrating to a loopback address is OK if we started there.
+#[test]
+fn preferred_address_loopback() {
+ let a = loopback();
+ preferred_address(a, a, new_port(a));
+}
+
+fn expect_no_migration(client: &mut Connection, server: &mut Connection) {
+ let dgram = fast_handshake(client, server);
+
+ // The client won't probe now, though it could; it remains idle.
+ let out = client.process(dgram.as_ref(), now());
+ assert_ne!(out.callback(), Duration::new(0, 0));
+
+ // Data continues on the main path for the client.
+ let data = send_something(client, now());
+ assert_v6_path(&data, false);
+ assert_eq!(client.stats().frame_tx.path_challenge, 0);
+}
+
+fn preferred_address_ignored(spa: PreferredAddress) {
+ let mut client = default_client();
+ let mut server = new_server(ConnectionParameters::default().preferred_address(spa));
+
+ expect_no_migration(&mut client, &mut server);
+}
+
+/// Using a loopback address in the preferred address is ignored.
+#[test]
+fn preferred_address_ignore_loopback() {
+ preferred_address_ignored(PreferredAddress::new_any(None, Some(loopback())));
+}
+
+/// A preferred address in the wrong address family is ignored.
+#[test]
+fn preferred_address_ignore_different_family() {
+ preferred_address_ignored(PreferredAddress::new_any(Some(addr_v4()), None));
+}
+
+/// Disabling preferred addresses at the client means that it ignores a perfectly
+/// good preferred address.
+#[test]
+fn preferred_address_disabled_client() {
+ let mut client = new_client(ConnectionParameters::default().disable_preferred_address());
+ let mut preferred = addr();
+ preferred.set_ip(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2)));
+ let spa = PreferredAddress::new_any(None, Some(preferred));
+ let mut server = new_server(ConnectionParameters::default().preferred_address(spa));
+
+ expect_no_migration(&mut client, &mut server);
+}
+
+#[test]
+fn preferred_address_empty_cid() {
+ fixture_init();
+
+ let spa = PreferredAddress::new_any(None, Some(new_port(addr())));
+ let res = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())),
+ ConnectionParameters::default().preferred_address(spa),
+ );
+ assert_eq!(res.unwrap_err(), Error::ConnectionIdsExhausted);
+}
+
+/// A server cannot include a preferred address if it chooses an empty connection ID.
+#[test]
+fn preferred_address_server_empty_cid() {
+ let mut client = default_client();
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(EmptyConnectionIdGenerator::default())),
+ ConnectionParameters::default(),
+ )
+ .unwrap();
+
+ server
+ .set_local_tparam(
+ tparams::PREFERRED_ADDRESS,
+ TransportParameter::Bytes(SAMPLE_PREFERRED_ADDRESS.to_vec()),
+ )
+ .unwrap();
+
+ connect_fail(
+ &mut client,
+ &mut server,
+ Error::TransportParameterError,
+ Error::PeerError(Error::TransportParameterError.code()),
+ );
+}
+
+/// A client shouldn't send a preferred address transport parameter.
+#[test]
+fn preferred_address_client() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ client
+ .set_local_tparam(
+ tparams::PREFERRED_ADDRESS,
+ TransportParameter::Bytes(SAMPLE_PREFERRED_ADDRESS.to_vec()),
+ )
+ .unwrap();
+
+ connect_fail(
+ &mut client,
+ &mut server,
+ Error::PeerError(Error::TransportParameterError.code()),
+ Error::TransportParameterError,
+ );
+}
+
+/// Test that migration isn't permitted if the connection isn't in the right state.
+#[test]
+fn migration_invalid_state() {
+ let mut client = default_client();
+ assert!(client
+ .migrate(Some(addr()), Some(addr()), false, now())
+ .is_err());
+
+ let mut server = default_server();
+ assert!(server
+ .migrate(Some(addr()), Some(addr()), false, now())
+ .is_err());
+ connect_force_idle(&mut client, &mut server);
+
+ assert!(server
+ .migrate(Some(addr()), Some(addr()), false, now())
+ .is_err());
+
+ client.close(now(), 0, "closing");
+ assert!(client
+ .migrate(Some(addr()), Some(addr()), false, now())
+ .is_err());
+ let close = client.process(None, now()).dgram();
+
+ let dgram = server.process(close.as_ref(), now()).dgram();
+ assert!(server
+ .migrate(Some(addr()), Some(addr()), false, now())
+ .is_err());
+
+ client.process_input(&dgram.unwrap(), now());
+ assert!(client
+ .migrate(Some(addr()), Some(addr()), false, now())
+ .is_err());
+}
+
+#[test]
+fn migration_invalid_address() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut cant_migrate = |local, remote| {
+ assert_eq!(
+ client.migrate(local, remote, true, now()).unwrap_err(),
+ Error::InvalidMigration
+ );
+ };
+
+ // Providing neither address is pointless and therefore an error.
+ cant_migrate(None, None);
+
+ // Providing a zero port number isn't valid.
+ let mut zero_port = addr();
+ zero_port.set_port(0);
+ cant_migrate(None, Some(zero_port));
+ cant_migrate(Some(zero_port), None);
+
+ // An unspecified remote address is bad.
+ let mut remote_unspecified = addr();
+ remote_unspecified.set_ip(IpAddr::V6(Ipv6Addr::from(0)));
+ cant_migrate(None, Some(remote_unspecified));
+
+ // Mixed address families is bad.
+ cant_migrate(Some(addr()), Some(addr_v4()));
+ cant_migrate(Some(addr_v4()), Some(addr()));
+
+ // Loopback to non-loopback is bad.
+ cant_migrate(Some(addr()), Some(loopback()));
+ cant_migrate(Some(loopback()), Some(addr()));
+ assert_eq!(
+ client
+ .migrate(Some(addr()), Some(loopback()), true, now())
+ .unwrap_err(),
+ Error::InvalidMigration
+ );
+ assert_eq!(
+ client
+ .migrate(Some(loopback()), Some(addr()), true, now())
+ .unwrap_err(),
+ Error::InvalidMigration
+ );
+}
+
+/// This inserts a frame into packets that provides a single new
+/// connection ID and retires all others.
+struct RetireAll {
+ cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>>,
+}
+
+impl crate::connection::test_internal::FrameWriter for RetireAll {
+ fn write_frames(&mut self, builder: &mut PacketBuilder) {
+ // Use a sequence number that is large enough that all existing values
+ // will be lower (so they get retired). As the code doesn't care about
+ // gaps in sequence numbers, this is safe, even though the gap might
+ // hint that there are more outstanding connection IDs that are allowed.
+ const SEQNO: u64 = 100;
+ let cid = self.cid_gen.borrow_mut().generate_cid().unwrap();
+ builder
+ .encode_varint(FRAME_TYPE_NEW_CONNECTION_ID)
+ .encode_varint(SEQNO)
+ .encode_varint(SEQNO) // Retire Prior To
+ .encode_vec(1, &cid)
+ .encode(&[0x7f; 16]);
+ }
+}
+
+/// Test that forcing retirement of connection IDs forces retirement of all active
+/// connection IDs and the use of of newer one.
+#[test]
+fn retire_all() {
+ let mut client = default_client();
+ let cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>> =
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default()));
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::clone(&cid_gen),
+ ConnectionParameters::default(),
+ )
+ .unwrap();
+ connect_force_idle(&mut client, &mut server);
+
+ let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now())));
+
+ server.test_frame_writer = Some(Box::new(RetireAll { cid_gen }));
+ let ncid = send_something(&mut server, now());
+ server.test_frame_writer = None;
+
+ let new_cid_before = client.stats().frame_rx.new_connection_id;
+ let retire_cid_before = client.stats().frame_tx.retire_connection_id;
+ client.process_input(&ncid, now());
+ let retire = send_something(&mut client, now());
+ assert_eq!(
+ client.stats().frame_rx.new_connection_id,
+ new_cid_before + 1
+ );
+ assert_eq!(
+ client.stats().frame_tx.retire_connection_id,
+ retire_cid_before + LOCAL_ACTIVE_CID_LIMIT
+ );
+
+ assert_ne!(get_cid(&retire), original_cid);
+}
+
+/// During a graceful migration, if the probed path can't get a new connection ID due
+/// to being forced to retire the one it is using, the migration will fail.
+#[test]
+fn retire_prior_to_migration_failure() {
+ let mut client = default_client();
+ let cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>> =
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default()));
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::clone(&cid_gen),
+ ConnectionParameters::default(),
+ )
+ .unwrap();
+ connect_force_idle(&mut client, &mut server);
+
+ let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now())));
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), false, now())
+ .unwrap();
+
+ // The client now probes the new path.
+ let probe = client.process_output(now()).dgram().unwrap();
+ assert_v4_path(&probe, true);
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ let probe_cid = ConnectionId::from(get_cid(&probe));
+ assert_ne!(original_cid, probe_cid);
+
+ // Have the server receive the probe, but separately have it decide to
+ // retire all of the available connection IDs.
+ server.test_frame_writer = Some(Box::new(RetireAll { cid_gen }));
+ let retire_all = send_something(&mut server, now());
+ server.test_frame_writer = None;
+
+ let resp = server.process(Some(&probe), now()).dgram().unwrap();
+ assert_v4_path(&resp, true);
+ assert_eq!(server.stats().frame_tx.path_response, 1);
+ assert_eq!(server.stats().frame_tx.path_challenge, 1);
+
+ // Have the client receive the NEW_CONNECTION_ID with Retire Prior To.
+ client.process_input(&retire_all, now());
+ // This packet contains the probe response, which should be fine, but it
+ // also includes PATH_CHALLENGE for the new path, and the client can't
+ // respond without a connection ID. We treat this as a connection error.
+ client.process_input(&resp, now());
+ assert!(matches!(
+ client.state(),
+ State::Closing {
+ error: ConnectionError::Transport(Error::InvalidMigration),
+ ..
+ }
+ ));
+}
+
+/// The timing of when frames arrive can mean that the migration path can
+/// get the last available connection ID.
+#[test]
+fn retire_prior_to_migration_success() {
+ let mut client = default_client();
+ let cid_gen: Rc<RefCell<dyn ConnectionIdGenerator>> =
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default()));
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::clone(&cid_gen),
+ ConnectionParameters::default(),
+ )
+ .unwrap();
+ connect_force_idle(&mut client, &mut server);
+
+ let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now())));
+
+ client
+ .migrate(Some(addr_v4()), Some(addr_v4()), false, now())
+ .unwrap();
+
+ // The client now probes the new path.
+ let probe = client.process_output(now()).dgram().unwrap();
+ assert_v4_path(&probe, true);
+ assert_eq!(client.stats().frame_tx.path_challenge, 1);
+ let probe_cid = ConnectionId::from(get_cid(&probe));
+ assert_ne!(original_cid, probe_cid);
+
+ // Have the server receive the probe, but separately have it decide to
+ // retire all of the available connection IDs.
+ server.test_frame_writer = Some(Box::new(RetireAll { cid_gen }));
+ let retire_all = send_something(&mut server, now());
+ server.test_frame_writer = None;
+
+ let resp = server.process(Some(&probe), now()).dgram().unwrap();
+ assert_v4_path(&resp, true);
+ assert_eq!(server.stats().frame_tx.path_response, 1);
+ assert_eq!(server.stats().frame_tx.path_challenge, 1);
+
+ // Have the client receive the NEW_CONNECTION_ID with Retire Prior To second.
+ // As this occurs in a very specific order, migration succeeds.
+ client.process_input(&resp, now());
+ client.process_input(&retire_all, now());
+
+ // Migration succeeds and the new path gets the last connection ID.
+ let dgram = send_something(&mut client, now());
+ assert_v4_path(&dgram, false);
+ assert_ne!(get_cid(&dgram), original_cid);
+ assert_ne!(get_cid(&dgram), probe_cid);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs
new file mode 100644
index 0000000000..8a999f4048
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs
@@ -0,0 +1,614 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![deny(clippy::pedantic)]
+
+use std::{
+ cell::RefCell,
+ cmp::min,
+ convert::TryFrom,
+ mem,
+ rc::Rc,
+ time::{Duration, Instant},
+};
+
+use enum_map::enum_map;
+use neqo_common::{event::Provider, qdebug, qtrace, Datagram, Decoder, Role};
+use neqo_crypto::{random, AllowZeroRtt, AuthenticationStatus, ResumptionToken};
+use test_fixture::{self, addr, fixture_init, new_neqo_qlog, now};
+
+use super::{Connection, ConnectionError, ConnectionId, Output, State};
+use crate::{
+ addr_valid::{AddressValidation, ValidateAddress},
+ cc::{CWND_INITIAL_PKTS, CWND_MIN},
+ cid::ConnectionIdRef,
+ events::ConnectionEvent,
+ frame::FRAME_TYPE_PING,
+ packet::PacketBuilder,
+ path::PATH_MTU_V6,
+ recovery::ACK_ONLY_SIZE_LIMIT,
+ stats::{FrameStats, Stats, MAX_PTO_COUNTS},
+ ConnectionIdDecoder, ConnectionIdGenerator, ConnectionParameters, Error, StreamId, StreamType,
+ Version,
+};
+
+// All the tests.
+mod ackrate;
+mod cc;
+mod close;
+mod datagram;
+mod fuzzing;
+mod handshake;
+mod idle;
+mod keys;
+mod migration;
+mod priority;
+mod recovery;
+mod resumption;
+mod stream;
+mod vn;
+mod zerortt;
+
+const DEFAULT_RTT: Duration = Duration::from_millis(100);
+const AT_LEAST_PTO: Duration = Duration::from_secs(1);
+const DEFAULT_STREAM_DATA: &[u8] = b"message";
+/// The number of 1-RTT packets sent in `force_idle` by a client.
+const CLIENT_HANDSHAKE_1RTT_PACKETS: usize = 1;
+
+/// WARNING! In this module, this version of the generator needs to be used.
+/// This copies the implementation from
+/// `test_fixture::CountingConnectionIdGenerator`, but it uses the different
+/// types that are exposed to this module. See also `default_client`.
+///
+/// This version doesn't randomize the length; as the congestion control tests
+/// count the amount of data sent precisely.
+#[derive(Debug, Default)]
+pub struct CountingConnectionIdGenerator {
+ counter: u32,
+}
+
+impl ConnectionIdDecoder for CountingConnectionIdGenerator {
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
+ let len = usize::from(dec.peek_byte().unwrap());
+ dec.decode(len).map(ConnectionIdRef::from)
+ }
+}
+
+impl ConnectionIdGenerator for CountingConnectionIdGenerator {
+ fn generate_cid(&mut self) -> Option<ConnectionId> {
+ let mut r = random(20);
+ r[0] = 8;
+ r[1] = u8::try_from(self.counter >> 24).unwrap();
+ r[2] = u8::try_from((self.counter >> 16) & 0xff).unwrap();
+ r[3] = u8::try_from((self.counter >> 8) & 0xff).unwrap();
+ r[4] = u8::try_from(self.counter & 0xff).unwrap();
+ self.counter += 1;
+ Some(ConnectionId::from(&r[..8]))
+ }
+
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
+ self
+ }
+}
+
+// This is fabulous: because test_fixture uses the public API for Connection,
+// it gets a different type to the ones that are referenced via super::super::*.
+// Thus, this code can't use default_client() and default_server() from
+// test_fixture because they produce different - and incompatible - types.
+//
+// These are a direct copy of those functions.
+pub fn new_client(params: ConnectionParameters) -> Connection {
+ fixture_init();
+ let (log, _contents) = new_neqo_qlog();
+ let mut client = Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ addr(),
+ addr(),
+ params,
+ now(),
+ )
+ .expect("create a default client");
+ client.set_qlog(log);
+ client
+}
+
+pub fn default_client() -> Connection {
+ new_client(ConnectionParameters::default())
+}
+
+pub fn new_server(params: ConnectionParameters) -> Connection {
+ fixture_init();
+ let (log, _contents) = new_neqo_qlog();
+ let mut c = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ params,
+ )
+ .expect("create a default server");
+ c.set_qlog(log);
+ c.server_enable_0rtt(&test_fixture::anti_replay(), AllowZeroRtt {})
+ .expect("enable 0-RTT");
+ c
+}
+pub fn default_server() -> Connection {
+ new_server(ConnectionParameters::default())
+}
+pub fn resumed_server(client: &Connection) -> Connection {
+ new_server(ConnectionParameters::default().versions(client.version(), Version::all()))
+}
+
+/// If state is `AuthenticationNeeded` call `authenticated()`. This function will
+/// consume all outstanding events on the connection.
+pub fn maybe_authenticate(conn: &mut Connection) -> bool {
+ let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded);
+ if conn.events().any(authentication_needed) {
+ conn.authenticated(AuthenticationStatus::Ok, now());
+ return true;
+ }
+ false
+}
+
+/// Compute the RTT variance after `n` ACKs or other RTT updates.
+pub fn rttvar_after_n_updates(n: usize, rtt: Duration) -> Duration {
+ assert!(n > 0);
+ let mut rttvar = rtt / 2;
+ for _ in 1..n {
+ rttvar = rttvar * 3 / 4;
+ }
+ rttvar
+}
+
+/// This inserts a PING frame into packets.
+struct PingWriter {}
+
+impl crate::connection::test_internal::FrameWriter for PingWriter {
+ fn write_frames(&mut self, builder: &mut PacketBuilder) {
+ builder.encode_varint(FRAME_TYPE_PING);
+ }
+}
+
+/// Drive the handshake between the client and server.
+fn handshake(
+ client: &mut Connection,
+ server: &mut Connection,
+ now: Instant,
+ rtt: Duration,
+) -> Instant {
+ let mut a = client;
+ let mut b = server;
+ let mut now = now;
+
+ let mut input = None;
+ let is_done = |c: &mut Connection| {
+ matches!(
+ c.state(),
+ State::Confirmed | State::Closing { .. } | State::Closed(..)
+ )
+ };
+
+ let mut did_ping = enum_map! {_ => false};
+ while !is_done(a) {
+ _ = maybe_authenticate(a);
+ let had_input = input.is_some();
+ // Insert a PING frame into the first application data packet an endpoint sends,
+ // in order to force the peer to ACK it. For the server, this is depending on the
+ // client's connection state, which is accessible during the tests.
+ //
+ // We're doing this to prevent packet loss from delaying ACKs, which would cause
+ // cwnd to shrink, and also to prevent the delayed ACK timer from being armed after
+ // the handshake, which is not something the tests are written to account for.
+ let should_ping = !did_ping[a.role()]
+ && (a.role() == Role::Client && *a.state() == State::Connected
+ || (a.role() == Role::Server && *b.state() == State::Connected));
+ if should_ping {
+ a.test_frame_writer = Some(Box::new(PingWriter {}));
+ }
+ let output = a.process(input.as_ref(), now).dgram();
+ if should_ping {
+ a.test_frame_writer = None;
+ did_ping[a.role()] = true;
+ }
+ assert!(had_input || output.is_some());
+ input = output;
+ qtrace!("handshake: t += {:?}", rtt / 2);
+ now += rtt / 2;
+ mem::swap(&mut a, &mut b);
+ }
+ if let Some(d) = input {
+ a.process_input(&d, now);
+ }
+ now
+}
+
+fn connect_fail(
+ client: &mut Connection,
+ server: &mut Connection,
+ client_error: Error,
+ server_error: Error,
+) {
+ handshake(client, server, now(), Duration::new(0, 0));
+ assert_error(client, &ConnectionError::Transport(client_error));
+ assert_error(server, &ConnectionError::Transport(server_error));
+}
+
+fn connect_with_rtt(
+ client: &mut Connection,
+ server: &mut Connection,
+ now: Instant,
+ rtt: Duration,
+) -> Instant {
+ fn check_rtt(stats: &Stats, rtt: Duration) {
+ assert_eq!(stats.rtt, rtt);
+ // Validate that rttvar has been computed correctly based on the number of RTT updates.
+ let n = stats.frame_rx.ack + usize::from(stats.rtt_init_guess);
+ assert_eq!(stats.rttvar, rttvar_after_n_updates(n, rtt));
+ }
+ let now = handshake(client, server, now, rtt);
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(*server.state(), State::Confirmed);
+
+ check_rtt(&client.stats(), rtt);
+ check_rtt(&server.stats(), rtt);
+ now
+}
+
+fn connect(client: &mut Connection, server: &mut Connection) {
+ connect_with_rtt(client, server, now(), Duration::new(0, 0));
+}
+
+fn assert_error(c: &Connection, expected: &ConnectionError) {
+ match c.state() {
+ State::Closing { error, .. } | State::Draining { error, .. } | State::Closed(error) => {
+ assert_eq!(*error, *expected, "{c} error mismatch");
+ }
+ _ => panic!("bad state {:?}", c.state()),
+ }
+}
+
+fn exchange_ticket(
+ client: &mut Connection,
+ server: &mut Connection,
+ now: Instant,
+) -> ResumptionToken {
+ let validation = AddressValidation::new(now, ValidateAddress::NoToken).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ server.send_ticket(now, &[]).expect("can send ticket");
+ let ticket = server.process_output(now).dgram();
+ assert!(ticket.is_some());
+ client.process_input(&ticket.unwrap(), now);
+ assert_eq!(*client.state(), State::Confirmed);
+ get_tokens(client).pop().expect("should have token")
+}
+
+/// The `handshake` method inserts PING frames into the first application data packets,
+/// which forces each peer to ACK them. As a side effect, that causes both sides of the
+/// connection to be idle aftwerwards. This method simply verifies that this is the case.
+fn assert_idle(client: &mut Connection, server: &mut Connection, rtt: Duration, now: Instant) {
+ let idle_timeout = min(
+ client.conn_params.get_idle_timeout(),
+ server.conn_params.get_idle_timeout(),
+ );
+ // Client started its idle period half an RTT before now.
+ assert_eq!(
+ client.process_output(now),
+ Output::Callback(idle_timeout - rtt / 2)
+ );
+ assert_eq!(server.process_output(now), Output::Callback(idle_timeout));
+}
+
+/// Connect with an RTT and then force both peers to be idle.
+fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Duration) -> Instant {
+ let now = connect_with_rtt(client, server, now(), rtt);
+ assert_idle(client, server, rtt, now);
+ // Drain events from both as well.
+ _ = client.events().count();
+ _ = server.events().count();
+ qtrace!("----- connected and idle with RTT {:?}", rtt);
+ now
+}
+
+fn connect_force_idle(client: &mut Connection, server: &mut Connection) {
+ connect_rtt_idle(client, server, Duration::new(0, 0));
+}
+
+fn fill_stream(c: &mut Connection, stream: StreamId) {
+ const BLOCK_SIZE: usize = 4_096;
+ loop {
+ let bytes_sent = c.stream_send(stream, &[0x42; BLOCK_SIZE]).unwrap();
+ qtrace!("fill_cwnd wrote {} bytes", bytes_sent);
+ if bytes_sent < BLOCK_SIZE {
+ break;
+ }
+ }
+}
+
+/// This fills the congestion window from a single source.
+/// As the pacer will interfere with this, this moves time forward
+/// as `Output::Callback` is received. Because it is hard to tell
+/// from the return value whether a timeout is an ACK delay, PTO, or
+/// pacing, this looks at the congestion window to tell when to stop.
+/// Returns a list of datagrams and the new time.
+fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec<Datagram>, Instant) {
+ // Train wreck function to get the remaining congestion window on the primary path.
+ fn cwnd(c: &Connection) -> usize {
+ c.paths.primary().borrow().sender().cwnd_avail()
+ }
+
+ qtrace!("fill_cwnd starting cwnd: {}", cwnd(c));
+ fill_stream(c, stream);
+
+ let mut total_dgrams = Vec::new();
+ loop {
+ let pkt = c.process_output(now);
+ qtrace!("fill_cwnd cwnd remaining={}, output: {:?}", cwnd(c), pkt);
+ match pkt {
+ Output::Datagram(dgram) => {
+ total_dgrams.push(dgram);
+ }
+ Output::Callback(t) => {
+ if cwnd(c) < ACK_ONLY_SIZE_LIMIT {
+ break;
+ }
+ now += t;
+ }
+ Output::None => panic!(),
+ }
+ }
+
+ qtrace!(
+ "fill_cwnd sent {} bytes",
+ total_dgrams.iter().map(|d| d.len()).sum::<usize>()
+ );
+ (total_dgrams, now)
+}
+
+/// This function is like the combination of `fill_cwnd` and `ack_bytes`.
+/// However, it acknowledges everything inline and preserves an RTT of `DEFAULT_RTT`.
+fn increase_cwnd(
+ sender: &mut Connection,
+ receiver: &mut Connection,
+ stream: StreamId,
+ mut now: Instant,
+) -> Instant {
+ fill_stream(sender, stream);
+ loop {
+ let pkt = sender.process_output(now);
+ match pkt {
+ Output::Datagram(dgram) => {
+ receiver.process_input(&dgram, now + DEFAULT_RTT / 2);
+ }
+ Output::Callback(t) => {
+ if t < DEFAULT_RTT {
+ now += t;
+ } else {
+ break; // We're on PTO now.
+ }
+ }
+ Output::None => panic!(),
+ }
+ }
+
+ // Now acknowledge all those packets at once.
+ now += DEFAULT_RTT / 2;
+ let ack = receiver.process_output(now).dgram();
+ now += DEFAULT_RTT / 2;
+ sender.process_input(&ack.unwrap(), now);
+ now
+}
+
+/// Receive multiple packets and generate an ack-only packet.
+///
+/// # Panics
+///
+/// The caller is responsible for ensuring that `dest` has received
+/// enough data that it wants to generate an ACK. This panics if
+/// no ACK frame is generated.
+fn ack_bytes<D>(dest: &mut Connection, stream: StreamId, in_dgrams: D, now: Instant) -> Datagram
+where
+ D: IntoIterator<Item = Datagram>,
+ D::IntoIter: ExactSizeIterator,
+{
+ let mut srv_buf = [0; 4_096];
+
+ let in_dgrams = in_dgrams.into_iter();
+ qdebug!([dest], "ack_bytes {} datagrams", in_dgrams.len());
+ for dgram in in_dgrams {
+ dest.process_input(&dgram, now);
+ }
+
+ loop {
+ let (bytes_read, _fin) = dest.stream_recv(stream, &mut srv_buf).unwrap();
+ qtrace!([dest], "ack_bytes read {} bytes", bytes_read);
+ if bytes_read == 0 {
+ break;
+ }
+ }
+
+ dest.process_output(now).dgram().unwrap()
+}
+
+// Get the current congestion window for the connection.
+fn cwnd(c: &Connection) -> usize {
+ c.paths.primary().borrow().sender().cwnd()
+}
+fn cwnd_avail(c: &Connection) -> usize {
+ c.paths.primary().borrow().sender().cwnd_avail()
+}
+
+fn induce_persistent_congestion(
+ client: &mut Connection,
+ server: &mut Connection,
+ stream: StreamId,
+ mut now: Instant,
+) -> Instant {
+ // Note: wait some arbitrary time that should be longer than pto
+ // timer. This is rather brittle.
+ qtrace!([client], "induce_persistent_congestion");
+ now += AT_LEAST_PTO;
+
+ let mut pto_counts = [0; MAX_PTO_COUNTS];
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qtrace!([client], "first PTO");
+ let (c_tx_dgrams, next_now) = fill_cwnd(client, stream, now);
+ now = next_now;
+ assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qtrace!([client], "second PTO");
+ now += AT_LEAST_PTO * 2;
+ let (c_tx_dgrams, next_now) = fill_cwnd(client, stream, now);
+ now = next_now;
+ assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets
+
+ pto_counts[0] = 0;
+ pto_counts[1] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qtrace!([client], "third PTO");
+ now += AT_LEAST_PTO * 4;
+ let (c_tx_dgrams, next_now) = fill_cwnd(client, stream, now);
+ now = next_now;
+ assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets
+
+ pto_counts[1] = 0;
+ pto_counts[2] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // An ACK for the third PTO causes persistent congestion.
+ let s_ack = ack_bytes(server, stream, c_tx_dgrams, now);
+ client.process_input(&s_ack, now);
+ assert_eq!(cwnd(client), CWND_MIN);
+ now
+}
+
+/// This magic number is the size of the client's CWND after the handshake completes.
+/// This is the same as the initial congestion window, because during the handshake
+/// the cc is app limited and cwnd is not increased.
+///
+/// As we change how we build packets, or even as NSS changes,
+/// this number might be different. The tests that depend on this
+/// value could fail as a result of variations, so it's OK to just
+/// change this value, but it is good to first understand where the
+/// change came from.
+const POST_HANDSHAKE_CWND: usize = PATH_MTU_V6 * CWND_INITIAL_PKTS;
+
+/// Determine the number of packets required to fill the CWND.
+const fn cwnd_packets(data: usize) -> usize {
+ // Add one if the last chunk is >= ACK_ONLY_SIZE_LIMIT.
+ (data + PATH_MTU_V6 - ACK_ONLY_SIZE_LIMIT) / PATH_MTU_V6
+}
+
+/// Determine the size of the last packet.
+/// The minimal size of a packet is `ACK_ONLY_SIZE_LIMIT`.
+fn last_packet(cwnd: usize) -> usize {
+ if (cwnd % PATH_MTU_V6) > ACK_ONLY_SIZE_LIMIT {
+ cwnd % PATH_MTU_V6
+ } else {
+ PATH_MTU_V6
+ }
+}
+
+/// Assert that the set of packets fill the CWND.
+fn assert_full_cwnd(packets: &[Datagram], cwnd: usize) {
+ assert_eq!(packets.len(), cwnd_packets(cwnd));
+ let (last, rest) = packets.split_last().unwrap();
+ assert!(rest.iter().all(|d| d.len() == PATH_MTU_V6));
+ assert_eq!(last.len(), last_packet(cwnd));
+}
+
+/// Send something on a stream from `sender` to `receiver`, maybe allowing for pacing.
+/// Return the resulting datagram and the new time.
+#[must_use]
+fn send_something_paced(
+ sender: &mut Connection,
+ mut now: Instant,
+ allow_pacing: bool,
+) -> (Datagram, Instant) {
+ let stream_id = sender.stream_create(StreamType::UniDi).unwrap();
+ assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok());
+ assert!(sender.stream_close_send(stream_id).is_ok());
+ qdebug!([sender], "send_something on {}", stream_id);
+ let dgram = match sender.process_output(now) {
+ Output::Callback(t) => {
+ assert!(allow_pacing, "send_something: unexpected delay");
+ now += t;
+ sender
+ .process_output(now)
+ .dgram()
+ .expect("send_something: should have something to send")
+ }
+ Output::Datagram(d) => d,
+ Output::None => panic!("send_something: got Output::None"),
+ };
+ (dgram, now)
+}
+
+/// Send something on a stream from `sender` to `receiver`.
+/// Return the resulting datagram.
+fn send_something(sender: &mut Connection, now: Instant) -> Datagram {
+ send_something_paced(sender, now, false).0
+}
+
+/// Send something on a stream from `sender` to `receiver`.
+/// Return any ACK that might result.
+fn send_and_receive(
+ sender: &mut Connection,
+ receiver: &mut Connection,
+ now: Instant,
+) -> Option<Datagram> {
+ let dgram = send_something(sender, now);
+ receiver.process(Some(&dgram), now).dgram()
+}
+
+fn get_tokens(client: &mut Connection) -> Vec<ResumptionToken> {
+ client
+ .events()
+ .filter_map(|e| {
+ if let ConnectionEvent::ResumptionToken(token) = e {
+ Some(token)
+ } else {
+ None
+ }
+ })
+ .collect()
+}
+
+fn assert_default_stats(stats: &Stats) {
+ assert_eq!(stats.packets_rx, 0);
+ assert_eq!(stats.packets_tx, 0);
+ let dflt_frames = FrameStats::default();
+ assert_eq!(stats.frame_rx, dflt_frames);
+ assert_eq!(stats.frame_tx, dflt_frames);
+}
+
+#[test]
+fn create_client() {
+ let client = default_client();
+ assert_eq!(client.role(), Role::Client);
+ assert!(matches!(client.state(), State::Init));
+ let stats = client.stats();
+ assert_default_stats(&stats);
+ assert_eq!(stats.rtt, crate::rtt::INITIAL_RTT);
+ assert_eq!(stats.rttvar, crate::rtt::INITIAL_RTT / 2);
+}
+
+#[test]
+fn create_server() {
+ let server = default_server();
+ assert_eq!(server.role(), Role::Server);
+ assert!(matches!(server.state(), State::Init));
+ let stats = server.stats();
+ assert_default_stats(&stats);
+ // Server won't have a default path, so no RTT.
+ assert_eq!(stats.rtt, Duration::from_secs(0));
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/priority.rs b/third_party/rust/neqo-transport/src/connection/tests/priority.rs
new file mode 100644
index 0000000000..1f86aa22e5
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/priority.rs
@@ -0,0 +1,404 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{cell::RefCell, mem, rc::Rc};
+
+use neqo_common::event::Provider;
+use test_fixture::{self, now};
+
+use super::{
+ super::{Connection, Error, Output},
+ connect, default_client, default_server, fill_cwnd, maybe_authenticate,
+};
+use crate::{
+ addr_valid::{AddressValidation, ValidateAddress},
+ send_stream::{RetransmissionPriority, TransmissionPriority},
+ ConnectionEvent, StreamId, StreamType,
+};
+
+const BLOCK_SIZE: usize = 4_096;
+
+fn fill_stream(c: &mut Connection, id: StreamId) {
+ loop {
+ if c.stream_send(id, &[0x42; BLOCK_SIZE]).unwrap() < BLOCK_SIZE {
+ return;
+ }
+ }
+}
+
+/// A receive stream cannot be prioritized (yet).
+#[test]
+fn receive_stream() {
+ const MESSAGE: &[u8] = b"hello";
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(MESSAGE.len(), client.stream_send(id, MESSAGE).unwrap());
+ let dgram = client.process_output(now()).dgram();
+
+ server.process_input(&dgram.unwrap(), now());
+ assert_eq!(
+ server
+ .stream_priority(
+ id,
+ TransmissionPriority::default(),
+ RetransmissionPriority::default()
+ )
+ .unwrap_err(),
+ Error::InvalidStreamId,
+ "Priority doesn't apply to inbound unidirectional streams"
+ );
+
+ // But the stream does exist and can be read.
+ let mut buf = [0; 10];
+ let (len, end) = server.stream_recv(id, &mut buf).unwrap();
+ assert_eq!(MESSAGE, &buf[..len]);
+ assert!(!end);
+}
+
+/// Higher priority streams get sent ahead of lower ones, even when
+/// the higher priority stream is written to later.
+#[test]
+fn relative() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // id_normal is created first, but it is lower priority.
+ let id_normal = client.stream_create(StreamType::UniDi).unwrap();
+ fill_stream(&mut client, id_normal);
+ let high = client.stream_create(StreamType::UniDi).unwrap();
+ fill_stream(&mut client, high);
+ client
+ .stream_priority(
+ high,
+ TransmissionPriority::High,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+
+ let dgram = client.process_output(now()).dgram();
+ server.process_input(&dgram.unwrap(), now());
+
+ // The "id_normal" stream will get a `NewStream` event, but no data.
+ for e in server.events() {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ assert_ne!(stream_id, id_normal);
+ }
+ }
+}
+
+/// Check that changing priority has effect on the next packet that is sent.
+#[test]
+fn reprioritize() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // id_normal is created first, but it is lower priority.
+ let id_normal = client.stream_create(StreamType::UniDi).unwrap();
+ fill_stream(&mut client, id_normal);
+ let id_high = client.stream_create(StreamType::UniDi).unwrap();
+ fill_stream(&mut client, id_high);
+ client
+ .stream_priority(
+ id_high,
+ TransmissionPriority::High,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+
+ let dgram = client.process_output(now()).dgram();
+ server.process_input(&dgram.unwrap(), now());
+
+ // The "id_normal" stream will get a `NewStream` event, but no data.
+ for e in server.events() {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ assert_ne!(stream_id, id_normal);
+ }
+ }
+
+ // When the high priority stream drops in priority, the streams are equal
+ // priority and so their stream ID determines what is sent.
+ client
+ .stream_priority(
+ id_high,
+ TransmissionPriority::Normal,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+ let dgram = client.process_output(now()).dgram();
+ server.process_input(&dgram.unwrap(), now());
+
+ for e in server.events() {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ assert_ne!(stream_id, id_high);
+ }
+ }
+}
+
+/// Retransmission can be prioritized differently (usually higher).
+#[test]
+fn repairing_loss() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let mut now = now();
+
+ // Send a few packets at low priority, lose one.
+ let id_low = client.stream_create(StreamType::UniDi).unwrap();
+ fill_stream(&mut client, id_low);
+ client
+ .stream_priority(
+ id_low,
+ TransmissionPriority::Low,
+ RetransmissionPriority::Higher,
+ )
+ .unwrap();
+
+ let _lost = client.process_output(now).dgram();
+ for _ in 0..5 {
+ match client.process_output(now) {
+ Output::Datagram(d) => server.process_input(&d, now),
+ Output::Callback(delay) => now += delay,
+ Output::None => unreachable!(),
+ }
+ }
+
+ // Generate an ACK. The first packet is now considered lost.
+ let ack = server.process_output(now).dgram();
+ _ = server.events().count(); // Drain events.
+
+ let id_normal = client.stream_create(StreamType::UniDi).unwrap();
+ fill_stream(&mut client, id_normal);
+
+ let dgram = client.process(ack.as_ref(), now).dgram();
+ assert_eq!(client.stats().lost, 1); // Client should have noticed the loss.
+ server.process_input(&dgram.unwrap(), now);
+
+ // Only the low priority stream has data as the retransmission of the data from
+ // the lost packet is now more important than new data from the high priority stream.
+ for e in server.events() {
+ println!("Event: {e:?}");
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ assert_eq!(stream_id, id_low);
+ }
+ }
+
+ // However, only the retransmission is prioritized.
+ // Though this might contain some retransmitted data, as other frames might push
+ // the retransmitted data into a second packet, it will also contain data from the
+ // normal priority stream.
+ let dgram = client.process_output(now).dgram();
+ server.process_input(&dgram.unwrap(), now);
+ assert!(server.events().any(
+ |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id_normal),
+ ));
+}
+
+#[test]
+fn critical() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = now();
+
+ // Rather than connect, send stream data in 0.5-RTT.
+ // That allows this to test that critical streams pre-empt most frame types.
+ let dgram = client.process_output(now).dgram();
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ client.process_input(&dgram.unwrap(), now);
+ maybe_authenticate(&mut client);
+
+ let id = server.stream_create(StreamType::UniDi).unwrap();
+ server
+ .stream_priority(
+ id,
+ TransmissionPriority::Critical,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+
+ // Can't use fill_cwnd here because the server is blocked on the amplification
+ // limit, so it can't fill the congestion window.
+ while server.stream_create(StreamType::UniDi).is_ok() {}
+
+ fill_stream(&mut server, id);
+ let stats_before = server.stats().frame_tx;
+ let dgram = server.process_output(now).dgram();
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto);
+ assert_eq!(stats_after.streams_blocked, 0);
+ assert_eq!(stats_after.new_connection_id, 0);
+ assert_eq!(stats_after.new_token, 0);
+ assert_eq!(stats_after.handshake_done, 0);
+
+ // Complete the handshake.
+ let dgram = client.process(dgram.as_ref(), now).dgram();
+ server.process_input(&dgram.unwrap(), now);
+
+ // Critical beats everything but HANDSHAKE_DONE.
+ let stats_before = server.stats().frame_tx;
+ mem::drop(fill_cwnd(&mut server, id, now));
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto);
+ assert_eq!(stats_after.streams_blocked, 0);
+ assert_eq!(stats_after.new_connection_id, 0);
+ assert_eq!(stats_after.new_token, 0);
+ assert_eq!(stats_after.handshake_done, 1);
+}
+
+#[test]
+fn important() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = now();
+
+ // Rather than connect, send stream data in 0.5-RTT.
+ // That allows this to test that important streams pre-empt most frame types.
+ let dgram = client.process_output(now).dgram();
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ client.process_input(&dgram.unwrap(), now);
+ maybe_authenticate(&mut client);
+
+ let id = server.stream_create(StreamType::UniDi).unwrap();
+ server
+ .stream_priority(
+ id,
+ TransmissionPriority::Important,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+ fill_stream(&mut server, id);
+
+ // Important beats everything but flow control.
+ // Make enough streams to get a STREAMS_BLOCKED frame out.
+ while server.stream_create(StreamType::UniDi).is_ok() {}
+
+ let stats_before = server.stats().frame_tx;
+ let dgram = server.process_output(now).dgram();
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto);
+ assert_eq!(stats_after.streams_blocked, 1);
+ assert_eq!(stats_after.new_connection_id, 0);
+ assert_eq!(stats_after.new_token, 0);
+ assert_eq!(stats_after.handshake_done, 0);
+ assert_eq!(stats_after.stream, stats_before.stream + 1);
+
+ // Complete the handshake.
+ let dgram = client.process(dgram.as_ref(), now).dgram();
+ server.process_input(&dgram.unwrap(), now);
+
+ // Important beats everything but flow control.
+ let stats_before = server.stats().frame_tx;
+ mem::drop(fill_cwnd(&mut server, id, now));
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto);
+ assert_eq!(stats_after.streams_blocked, 1);
+ assert_eq!(stats_after.new_connection_id, 0);
+ assert_eq!(stats_after.new_token, 0);
+ assert_eq!(stats_after.handshake_done, 1);
+ assert!(stats_after.stream > stats_before.stream);
+}
+
+#[test]
+fn high_normal() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = now();
+
+ // Rather than connect, send stream data in 0.5-RTT.
+ // That allows this to test that important streams pre-empt most frame types.
+ let dgram = client.process_output(now).dgram();
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ client.process_input(&dgram.unwrap(), now);
+ maybe_authenticate(&mut client);
+
+ let id = server.stream_create(StreamType::UniDi).unwrap();
+ server
+ .stream_priority(
+ id,
+ TransmissionPriority::High,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+ fill_stream(&mut server, id);
+
+ // Important beats everything but flow control.
+ // Make enough streams to get a STREAMS_BLOCKED frame out.
+ while server.stream_create(StreamType::UniDi).is_ok() {}
+
+ let stats_before = server.stats().frame_tx;
+ let dgram = server.process_output(now).dgram();
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto);
+ assert_eq!(stats_after.streams_blocked, 1);
+ assert_eq!(stats_after.new_connection_id, 0);
+ assert_eq!(stats_after.new_token, 0);
+ assert_eq!(stats_after.handshake_done, 0);
+ assert_eq!(stats_after.stream, stats_before.stream + 1);
+
+ // Complete the handshake.
+ let dgram = client.process(dgram.as_ref(), now).dgram();
+ server.process_input(&dgram.unwrap(), now);
+
+ // High or Normal doesn't beat NEW_CONNECTION_ID,
+ // but they beat CRYPTO/NEW_TOKEN.
+ let stats_before = server.stats().frame_tx;
+ server.send_ticket(now, &[]).unwrap();
+ mem::drop(fill_cwnd(&mut server, id, now));
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto);
+ assert_eq!(stats_after.streams_blocked, 1);
+ assert_ne!(stats_after.new_connection_id, 0); // Note: > 0
+ assert_eq!(stats_after.new_token, 0);
+ assert_eq!(stats_after.handshake_done, 1);
+ assert!(stats_after.stream > stats_before.stream);
+}
+
+#[test]
+fn low() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = now();
+ // Use address validation; note that we need to hold a strong reference
+ // as the server will only hold a weak reference.
+ let validation = Rc::new(RefCell::new(
+ AddressValidation::new(now, ValidateAddress::Never).unwrap(),
+ ));
+ server.set_validation(Rc::clone(&validation));
+ connect(&mut client, &mut server);
+
+ let id = server.stream_create(StreamType::UniDi).unwrap();
+ server
+ .stream_priority(
+ id,
+ TransmissionPriority::Low,
+ RetransmissionPriority::default(),
+ )
+ .unwrap();
+ fill_stream(&mut server, id);
+
+ // Send a session ticket and make it big enough to require a whole packet.
+ // The resulting CRYPTO frame beats out the stream data.
+ let stats_before = server.stats().frame_tx;
+ server.send_ticket(now, &[0; 2048]).unwrap();
+ mem::drop(server.process_output(now));
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto + 1);
+ assert_eq!(stats_after.stream, stats_before.stream);
+
+ // The above can't test if NEW_TOKEN wins because once that fits in a packet,
+ // it is very hard to ensure that the STREAM frame won't also fit.
+ // However, we can ensure that the next packet doesn't consist of just STREAM.
+ let stats_before = server.stats().frame_tx;
+ mem::drop(server.process_output(now));
+ let stats_after = server.stats().frame_tx;
+ assert_eq!(stats_after.crypto, stats_before.crypto + 1);
+ assert_eq!(stats_after.new_token, 1);
+ assert_eq!(stats_after.stream, stats_before.stream + 1);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/recovery.rs b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs
new file mode 100644
index 0000000000..0f12d03107
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs
@@ -0,0 +1,804 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ mem,
+ time::{Duration, Instant},
+};
+
+use neqo_common::qdebug;
+use neqo_crypto::AuthenticationStatus;
+use test_fixture::{
+ assertions::{assert_handshake, assert_initial},
+ now, split_datagram,
+};
+
+use super::{
+ super::{Connection, ConnectionParameters, Output, State},
+ assert_full_cwnd, connect, connect_force_idle, connect_rtt_idle, connect_with_rtt, cwnd,
+ default_client, default_server, fill_cwnd, maybe_authenticate, new_client, send_and_receive,
+ send_something, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA, POST_HANDSHAKE_CWND,
+};
+use crate::{
+ cc::CWND_MIN,
+ path::PATH_MTU_V6,
+ recovery::{
+ FAST_PTO_SCALE, MAX_OUTSTANDING_UNACK, MAX_PTO_PACKET_COUNT, MIN_OUTSTANDING_UNACK,
+ },
+ rtt::GRANULARITY,
+ stats::MAX_PTO_COUNTS,
+ tparams::TransportParameter,
+ tracking::DEFAULT_ACK_DELAY,
+ StreamType,
+};
+
+#[test]
+fn pto_works_basic() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut now = now();
+
+ let res = client.process(None, now);
+ let idle_timeout = ConnectionParameters::default().get_idle_timeout();
+ assert_eq!(res, Output::Callback(idle_timeout));
+
+ // Send data on two streams
+ let stream1 = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.stream_send(stream1, b"hello").unwrap(), 5);
+ assert_eq!(client.stream_send(stream1, b" world!").unwrap(), 7);
+
+ let stream2 = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.stream_send(stream2, b"there!").unwrap(), 6);
+
+ // Send a packet after some time.
+ now += Duration::from_secs(10);
+ let out = client.process(None, now);
+ assert!(out.dgram().is_some());
+
+ // Nothing to do, should return callback
+ let out = client.process(None, now);
+ assert!(matches!(out, Output::Callback(_)));
+
+ // One second later, it should want to send PTO packet
+ now += AT_LEAST_PTO;
+ let out = client.process(None, now);
+
+ let stream_before = server.stats().frame_rx.stream;
+ server.process_input(&out.dgram().unwrap(), now);
+ assert_eq!(server.stats().frame_rx.stream, stream_before + 2);
+}
+
+#[test]
+fn pto_works_full_cwnd() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Send lots of data.
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ let (dgrams, now) = fill_cwnd(&mut client, stream_id, now);
+ assert_full_cwnd(&dgrams, POST_HANDSHAKE_CWND);
+
+ // Fill the CWND after waiting for a PTO.
+ let (dgrams, now) = fill_cwnd(&mut client, stream_id, now + AT_LEAST_PTO);
+ // Two packets in the PTO.
+ // The first should be full sized; the second might be small.
+ assert_eq!(dgrams.len(), 2);
+ assert_eq!(dgrams[0].len(), PATH_MTU_V6);
+
+ // Both datagrams contain one or more STREAM frames.
+ for d in dgrams {
+ let stream_before = server.stats().frame_rx.stream;
+ server.process_input(&d, now);
+ assert!(server.stats().frame_rx.stream > stream_before);
+ }
+}
+
+#[test]
+fn pto_works_ping() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let mut now = now() + Duration::from_secs(10);
+
+ // Send a few packets from the client.
+ let pkt0 = send_something(&mut client, now);
+ let pkt1 = send_something(&mut client, now);
+ let pkt2 = send_something(&mut client, now);
+ let pkt3 = send_something(&mut client, now);
+
+ // Nothing to do, should return callback
+ let cb = client.process(None, now).callback();
+ // The PTO timer is calculated with:
+ // RTT + max(rttvar * 4, GRANULARITY) + max_ack_delay
+ // With zero RTT and rttvar, max_ack_delay is minimum too (GRANULARITY)
+ assert_eq!(cb, GRANULARITY * 2);
+
+ // Process these by server, skipping pkt0
+ let srv0 = server.process(Some(&pkt1), now).dgram();
+ assert!(srv0.is_some()); // ooo, ack client pkt1
+
+ now += Duration::from_millis(20);
+
+ // process pkt2 (immediate ack because last ack was more than an RTT ago; RTT=0)
+ let srv1 = server.process(Some(&pkt2), now).dgram();
+ assert!(srv1.is_some()); // this is now dropped
+
+ now += Duration::from_millis(20);
+ // process pkt3 (acked for same reason)
+ let srv2 = server.process(Some(&pkt3), now).dgram();
+ // ack client pkt 2 & 3
+ assert!(srv2.is_some());
+
+ // client processes ack
+ let pkt4 = client.process(srv2.as_ref(), now).dgram();
+ // client resends data from pkt0
+ assert!(pkt4.is_some());
+
+ // server sees ooo pkt0 and generates immediate ack
+ let srv3 = server.process(Some(&pkt0), now).dgram();
+ assert!(srv3.is_some());
+
+ // Accept the acknowledgment.
+ let pkt5 = client.process(srv3.as_ref(), now).dgram();
+ assert!(pkt5.is_none());
+
+ now += Duration::from_millis(70);
+ // PTO expires. No unacked data. Only send PING.
+ let client_pings = client.stats().frame_tx.ping;
+ let pkt6 = client.process(None, now).dgram();
+ assert_eq!(client.stats().frame_tx.ping, client_pings + 1);
+
+ let server_pings = server.stats().frame_rx.ping;
+ server.process_input(&pkt6.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.ping, server_pings + 1);
+}
+
+#[test]
+fn pto_initial() {
+ const INITIAL_PTO: Duration = Duration::from_millis(300);
+ let mut now = now();
+
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let pkt1 = client.process(None, now).dgram();
+ assert!(pkt1.is_some());
+ assert_eq!(pkt1.clone().unwrap().len(), PATH_MTU_V6);
+
+ let delay = client.process(None, now).callback();
+ assert_eq!(delay, INITIAL_PTO);
+
+ // Resend initial after PTO.
+ now += delay;
+ let pkt2 = client.process(None, now).dgram();
+ assert!(pkt2.is_some());
+ assert_eq!(pkt2.unwrap().len(), PATH_MTU_V6);
+
+ let delay = client.process(None, now).callback();
+ // PTO has doubled.
+ assert_eq!(delay, INITIAL_PTO * 2);
+
+ // Server process the first initial pkt.
+ let mut server = default_server();
+ let out = server.process(pkt1.as_ref(), now).dgram();
+ assert!(out.is_some());
+
+ // Client receives ack for the first initial packet as well a Handshake packet.
+ // After the handshake packet the initial keys and the crypto stream for the initial
+ // packet number space will be discarded.
+ // Here only an ack for the Handshake packet will be sent.
+ let out = client.process(out.as_ref(), now).dgram();
+ assert!(out.is_some());
+
+ // We do not have PTO for the resent initial packet any more, but
+ // the Handshake PTO timer should be armed. As the RTT is apparently
+ // the same as the initial PTO value, and there is only one sample,
+ // the PTO will be 3x the INITIAL PTO.
+ let delay = client.process(None, now).callback();
+ assert_eq!(delay, INITIAL_PTO * 3);
+}
+
+/// A complete handshake that involves a PTO in the Handshake space.
+#[test]
+fn pto_handshake_complete() {
+ const HALF_RTT: Duration = Duration::from_millis(10);
+
+ let mut now = now();
+ // start handshake
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let pkt = client.process(None, now).dgram();
+ assert_initial(pkt.as_ref().unwrap(), false);
+ let cb = client.process(None, now).callback();
+ assert_eq!(cb, Duration::from_millis(300));
+
+ now += HALF_RTT;
+ let pkt = server.process(pkt.as_ref(), now).dgram();
+ assert_initial(pkt.as_ref().unwrap(), false);
+
+ now += HALF_RTT;
+ let pkt = client.process(pkt.as_ref(), now).dgram();
+ assert_handshake(pkt.as_ref().unwrap());
+
+ let cb = client.process(None, now).callback();
+ // The client now has a single RTT estimate (20ms), so
+ // the handshake PTO is set based on that.
+ assert_eq!(cb, HALF_RTT * 6);
+
+ now += HALF_RTT;
+ let pkt = server.process(pkt.as_ref(), now).dgram();
+ assert!(pkt.is_none());
+
+ now += HALF_RTT;
+ client.authenticated(AuthenticationStatus::Ok, now);
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let pkt1 = client.process(None, now).dgram();
+ assert_handshake(pkt1.as_ref().unwrap());
+ assert_eq!(*client.state(), State::Connected);
+
+ let cb = client.process(None, now).callback();
+ assert_eq!(cb, HALF_RTT * 6);
+
+ let mut pto_counts = [0; MAX_PTO_COUNTS];
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Wait for PTO to expire and resend a handshake packet.
+ // Wait long enough that the 1-RTT PTO also fires.
+ qdebug!("---- client: PTO");
+ now += HALF_RTT * 6;
+ let pkt2 = client.process(None, now).dgram();
+ assert_handshake(pkt2.as_ref().unwrap());
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Get a second PTO packet.
+ // Add some application data to this datagram, then split the 1-RTT off.
+ // We'll use that packet to force the server to acknowledge 1-RTT.
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_close_send(stream_id).unwrap();
+ let pkt3 = client.process(None, now).dgram();
+ assert_handshake(pkt3.as_ref().unwrap());
+ let (pkt3_hs, pkt3_1rtt) = split_datagram(&pkt3.unwrap());
+ assert_handshake(&pkt3_hs);
+ assert!(pkt3_1rtt.is_some());
+
+ // PTO has been doubled.
+ let cb = client.process(None, now).callback();
+ assert_eq!(cb, HALF_RTT * 12);
+
+ // We still have only a single PTO
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qdebug!("---- server: receive FIN and send ACK");
+ now += HALF_RTT;
+ // Now let the server have pkt1 and expect an immediate Handshake ACK.
+ // The output will be a Handshake packet with ACK and 1-RTT packet with
+ // HANDSHAKE_DONE and (because of pkt3_1rtt) an ACK.
+ // This should remove the 1-RTT PTO from messing this test up.
+ let server_acks = server.stats().frame_tx.ack;
+ let server_done = server.stats().frame_tx.handshake_done;
+ server.process_input(&pkt3_1rtt.unwrap(), now);
+ let ack = server.process(pkt1.as_ref(), now).dgram();
+ assert!(ack.is_some());
+ assert_eq!(server.stats().frame_tx.ack, server_acks + 2);
+ assert_eq!(server.stats().frame_tx.handshake_done, server_done + 1);
+
+ // Check that the other packets (pkt2, pkt3) are Handshake packets.
+ // The server discarded the Handshake keys already, therefore they are dropped.
+ // Note that these don't include 1-RTT packets, because 1-RTT isn't send on PTO.
+ let (pkt2_hs, pkt2_1rtt) = split_datagram(&pkt2.unwrap());
+ assert_handshake(&pkt2_hs);
+ assert!(pkt2_1rtt.is_some());
+ let dropped_before1 = server.stats().dropped_rx;
+ let server_frames = server.stats().frame_rx.all;
+ server.process_input(&pkt2_hs, now);
+ assert_eq!(1, server.stats().dropped_rx - dropped_before1);
+ assert_eq!(server.stats().frame_rx.all, server_frames);
+
+ server.process_input(&pkt2_1rtt.unwrap(), now);
+ let server_frames2 = server.stats().frame_rx.all;
+ let dropped_before2 = server.stats().dropped_rx;
+ server.process_input(&pkt3_hs, now);
+ assert_eq!(1, server.stats().dropped_rx - dropped_before2);
+ assert_eq!(server.stats().frame_rx.all, server_frames2);
+
+ now += HALF_RTT;
+
+ // Let the client receive the ACK.
+ // It should now be wait to acknowledge the HANDSHAKE_DONE.
+ let cb = client.process(ack.as_ref(), now).callback();
+ // The default ack delay is the RTT divided by the default ACK ratio of 4.
+ let expected_ack_delay = HALF_RTT * 2 / 4;
+ assert_eq!(cb, expected_ack_delay);
+
+ // Let the ACK delay timer expire.
+ now += cb;
+ let out = client.process(None, now).dgram();
+ assert!(out.is_some());
+}
+
+/// Test that PTO in the Handshake space contains the right frames.
+#[test]
+fn pto_handshake_frames() {
+ let mut now = now();
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let pkt = client.process(None, now);
+
+ now += Duration::from_millis(10);
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let pkt = server.process(pkt.as_dgram_ref(), now);
+
+ now += Duration::from_millis(10);
+ qdebug!("---- client: cert verification");
+ let pkt = client.process(pkt.as_dgram_ref(), now);
+
+ now += Duration::from_millis(10);
+ mem::drop(server.process(pkt.as_dgram_ref(), now));
+
+ now += Duration::from_millis(10);
+ client.authenticated(AuthenticationStatus::Ok, now);
+
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(stream, 2);
+ assert_eq!(client.stream_send(stream, b"zero").unwrap(), 4);
+ qdebug!("---- client: SH..FIN -> FIN and 1RTT packet");
+ let pkt1 = client.process(None, now).dgram();
+ assert!(pkt1.is_some());
+
+ // Get PTO timer.
+ let out = client.process(None, now);
+ assert_eq!(out, Output::Callback(Duration::from_millis(60)));
+
+ // Wait for PTO to expire and resend a handshake packet.
+ now += Duration::from_millis(60);
+ let pkt2 = client.process(None, now).dgram();
+ assert!(pkt2.is_some());
+
+ now += Duration::from_millis(10);
+ let crypto_before = server.stats().frame_rx.crypto;
+ server.process_input(&pkt2.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.crypto, crypto_before + 1);
+}
+
+/// In the case that the Handshake takes too many packets, the server might
+/// be stalled on the anti-amplification limit. If a Handshake ACK from the
+/// client is lost, the client has to keep the PTO timer armed or the server
+/// might be unable to send anything, causing a deadlock.
+#[test]
+fn handshake_ack_pto() {
+ const RTT: Duration = Duration::from_millis(10);
+ let mut now = now();
+ let mut client = default_client();
+ let mut server = default_server();
+ // This is a greasing transport parameter, and large enough that the
+ // server needs to send two Handshake packets.
+ let big = TransportParameter::Bytes(vec![0; PATH_MTU_V6]);
+ server.set_local_tparam(0xce16, big).unwrap();
+
+ let c1 = client.process(None, now).dgram();
+
+ now += RTT / 2;
+ let s1 = server.process(c1.as_ref(), now).dgram();
+ assert!(s1.is_some());
+ let s2 = server.process(None, now).dgram();
+ assert!(s1.is_some());
+
+ // Now let the client have the Initial, but drop the first coalesced Handshake packet.
+ now += RTT / 2;
+ let (initial, _) = split_datagram(&s1.unwrap());
+ client.process_input(&initial, now);
+ let c2 = client.process(s2.as_ref(), now).dgram();
+ assert!(c2.is_some()); // This is an ACK. Drop it.
+ let delay = client.process(None, now).callback();
+ assert_eq!(delay, RTT * 3);
+
+ let mut pto_counts = [0; MAX_PTO_COUNTS];
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Wait for the PTO and ensure that the client generates a packet.
+ now += delay;
+ let c3 = client.process(None, now).dgram();
+ assert!(c3.is_some());
+
+ now += RTT / 2;
+ let ping_before = server.stats().frame_rx.ping;
+ server.process_input(&c3.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.ping, ping_before + 1);
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Now complete the handshake as cheaply as possible.
+ let dgram = server.process(None, now).dgram();
+ client.process_input(&dgram.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let dgram = client.process(None, now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+ let dgram = server.process(dgram.as_ref(), now).dgram();
+ assert_eq!(*server.state(), State::Confirmed);
+ client.process_input(&dgram.unwrap(), now);
+ assert_eq!(*client.state(), State::Confirmed);
+
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+}
+
+#[test]
+fn loss_recovery_crash() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let now = now();
+
+ // The server sends something, but we will drop this.
+ mem::drop(send_something(&mut server, now));
+
+ // Then send something again, but let it through.
+ let ack = send_and_receive(&mut server, &mut client, now);
+ assert!(ack.is_some());
+
+ // Have the server process the ACK.
+ let cb = server.process(ack.as_ref(), now).callback();
+ assert!(cb > Duration::from_secs(0));
+
+ // Now we leap into the future. The server should regard the first
+ // packet as lost based on time alone.
+ let dgram = server.process(None, now + AT_LEAST_PTO).dgram();
+ assert!(dgram.is_some());
+
+ // This crashes.
+ mem::drop(send_something(&mut server, now + AT_LEAST_PTO));
+}
+
+// If we receive packets after the PTO timer has fired, we won't clear
+// the PTO state, but we might need to acknowledge those packets.
+// This shouldn't happen, but we found that some implementations do this.
+#[test]
+fn ack_after_pto() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut now = now();
+
+ // The client sends and is forced into a PTO.
+ mem::drop(send_something(&mut client, now));
+
+ // Jump forward to the PTO and drain the PTO packets.
+ now += AT_LEAST_PTO;
+ // We can use MAX_PTO_PACKET_COUNT, because we know the handshake is over.
+ for _ in 0..MAX_PTO_PACKET_COUNT {
+ let dgram = client.process(None, now).dgram();
+ assert!(dgram.is_some());
+ }
+ assert!(client.process(None, now).dgram().is_none());
+
+ // The server now needs to send something that will cause the
+ // client to want to acknowledge it. A little out of order
+ // delivery is just the thing.
+ // Note: The server can't ACK anything here, but none of what
+ // the client has sent so far has been transferred.
+ mem::drop(send_something(&mut server, now));
+ let dgram = send_something(&mut server, now);
+
+ // The client is now after a PTO, but if it receives something
+ // that demands acknowledgment, it will send just the ACK.
+ let ack = client.process(Some(&dgram), now).dgram();
+ assert!(ack.is_some());
+
+ // Make sure that the packet only contained an ACK frame.
+ let all_frames_before = server.stats().frame_rx.all;
+ let ack_before = server.stats().frame_rx.ack;
+ server.process_input(&ack.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.all, all_frames_before + 1);
+ assert_eq!(server.stats().frame_rx.ack, ack_before + 1);
+}
+
+/// When we declare a packet as lost, we keep it around for a while for another loss period.
+/// Those packets should not affect how we report the loss recovery timer.
+/// As the loss recovery timer based on RTT we use that to drive the state.
+#[test]
+fn lost_but_kept_and_lr_timer() {
+ const RTT: Duration = Duration::from_secs(1);
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+ // Two packets (p1, p2) are sent at around t=0. The first is lost.
+ let _p1 = send_something(&mut client, now);
+ let p2 = send_something(&mut client, now);
+
+ // At t=RTT/2 the server receives the packet and ACKs it.
+ now += RTT / 2;
+ let ack = server.process(Some(&p2), now).dgram();
+ assert!(ack.is_some());
+ // The client also sends another two packets (p3, p4), again losing the first.
+ let _p3 = send_something(&mut client, now);
+ let p4 = send_something(&mut client, now);
+
+ // At t=RTT the client receives the ACK and goes into timed loss recovery.
+ // The client doesn't call p1 lost at this stage, but it will soon.
+ now += RTT / 2;
+ let res = client.process(ack.as_ref(), now);
+ // The client should be on a loss recovery timer as p1 is missing.
+ let lr_timer = res.callback();
+ // Loss recovery timer should be RTT/8, but only check for 0 or >=RTT/2.
+ assert_ne!(lr_timer, Duration::from_secs(0));
+ assert!(lr_timer < (RTT / 2));
+ // The server also receives and acknowledges p4, again sending an ACK.
+ let ack = server.process(Some(&p4), now).dgram();
+ assert!(ack.is_some());
+
+ // At t=RTT*3/2 the client should declare p1 to be lost.
+ now += RTT / 2;
+ // So the client will send the data from p1 again.
+ let res = client.process(None, now);
+ assert!(res.dgram().is_some());
+ // When the client processes the ACK, it should engage the
+ // loss recovery timer for p3, not p1 (even though it still tracks p1).
+ let res = client.process(ack.as_ref(), now);
+ let lr_timer2 = res.callback();
+ assert_eq!(lr_timer, lr_timer2);
+}
+
+/// We should not be setting the loss recovery timer based on packets
+/// that are sent prior to the largest acknowledged.
+/// Testing this requires that we construct a case where one packet
+/// number space causes the loss recovery timer to be engaged. At the same time,
+/// there is a packet in another space that hasn't been acknowledged AND
+/// that packet number space has not received acknowledgments for later packets.
+#[test]
+fn loss_time_past_largest_acked() {
+ const RTT: Duration = Duration::from_secs(10);
+ const INCR: Duration = Duration::from_millis(1);
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let mut now = now();
+
+ // Start the handshake.
+ let c_in = client.process(None, now).dgram();
+ now += RTT / 2;
+ let s_hs1 = server.process(c_in.as_ref(), now).dgram();
+
+ // Get some spare server handshake packets for the client to ACK.
+ // This involves a time machine, so be a little cautious.
+ // This test uses an RTT of 10s, but our server starts
+ // with a much lower RTT estimate, so the PTO at this point should
+ // be much smaller than an RTT and so the server shouldn't see
+ // time go backwards.
+ let s_pto = server.process(None, now).callback();
+ assert_ne!(s_pto, Duration::from_secs(0));
+ assert!(s_pto < RTT);
+ let s_hs2 = server.process(None, now + s_pto).dgram();
+ assert!(s_hs2.is_some());
+ let s_hs3 = server.process(None, now + s_pto).dgram();
+ assert!(s_hs3.is_some());
+
+ // Get some Handshake packets from the client.
+ // We need one to be left unacknowledged before one that is acknowledged.
+ // So that the client engages the loss recovery timer.
+ // This is complicated by the fact that it is hard to cause the client
+ // to generate an ack-eliciting packet. For that, we use the Finished message.
+ // Reordering delivery ensures that the later packet is also acknowledged.
+ now += RTT / 2;
+ let c_hs1 = client.process(s_hs1.as_ref(), now).dgram();
+ assert!(c_hs1.is_some()); // This comes first, so it's useless.
+ maybe_authenticate(&mut client);
+ let c_hs2 = client.process(None, now).dgram();
+ assert!(c_hs2.is_some()); // This one will elicit an ACK.
+
+ // The we need the outstanding packet to be sent after the
+ // application data packet, so space these out a tiny bit.
+ let _p1 = send_something(&mut client, now + INCR);
+ let c_hs3 = client.process(s_hs2.as_ref(), now + (INCR * 2)).dgram();
+ assert!(c_hs3.is_some()); // This will be left outstanding.
+ let c_hs4 = client.process(s_hs3.as_ref(), now + (INCR * 3)).dgram();
+ assert!(c_hs4.is_some()); // This will be acknowledged.
+
+ // Process c_hs2 and c_hs4, but skip c_hs3.
+ // Then get an ACK for the client.
+ now += RTT / 2;
+ // Deliver c_hs4 first, but don't generate a packet.
+ server.process_input(&c_hs4.unwrap(), now);
+ let s_ack = server.process(c_hs2.as_ref(), now).dgram();
+ assert!(s_ack.is_some());
+ // This includes an ACK, but it also includes HANDSHAKE_DONE,
+ // which we need to remove because that will cause the Handshake loss
+ // recovery state to be dropped.
+ let (s_hs_ack, _s_ap_ack) = split_datagram(&s_ack.unwrap());
+
+ // Now the client should start its loss recovery timer based on the ACK.
+ now += RTT / 2;
+ let c_ack = client.process(Some(&s_hs_ack), now).dgram();
+ assert!(c_ack.is_none());
+ // The client should now have the loss recovery timer active.
+ let lr_time = client.process(None, now).callback();
+ assert_ne!(lr_time, Duration::from_secs(0));
+ assert!(lr_time < (RTT / 2));
+}
+
+/// `sender` sends a little, `receiver` acknowledges it.
+/// Repeat until `count` acknowledgements are sent.
+/// Returns the last packet containing acknowledgements, if any.
+fn trickle(sender: &mut Connection, receiver: &mut Connection, mut count: usize, now: Instant) {
+ let id = sender.stream_create(StreamType::UniDi).unwrap();
+ let mut maybe_ack = None;
+ while count > 0 {
+ qdebug!("trickle: remaining={}", count);
+ assert_eq!(sender.stream_send(id, &[9]).unwrap(), 1);
+ let dgram = sender.process(maybe_ack.as_ref(), now).dgram();
+
+ maybe_ack = receiver.process(dgram.as_ref(), now).dgram();
+ count -= usize::from(maybe_ack.is_some());
+ }
+ sender.process_input(&maybe_ack.unwrap(), now);
+}
+
+/// Ensure that a PING frame is sent with ACK sometimes.
+/// `fast` allows testing of when `MAX_OUTSTANDING_UNACK` packets are
+/// outstanding (`fast` is `true`) within 1 PTO and when only
+/// `MIN_OUTSTANDING_UNACK` packets arrive after 2 PTOs (`fast` is `false`).
+fn ping_with_ack(fast: bool) {
+ let mut sender = default_client();
+ let mut receiver = default_server();
+ let mut now = now();
+ connect_force_idle(&mut sender, &mut receiver);
+ let sender_acks_before = sender.stats().frame_tx.ack;
+ let receiver_acks_before = receiver.stats().frame_tx.ack;
+ let count = if fast {
+ MAX_OUTSTANDING_UNACK
+ } else {
+ MIN_OUTSTANDING_UNACK
+ };
+ trickle(&mut sender, &mut receiver, count, now);
+ assert_eq!(sender.stats().frame_tx.ack, sender_acks_before);
+ assert_eq!(receiver.stats().frame_tx.ack, receiver_acks_before + count);
+ assert_eq!(receiver.stats().frame_tx.ping, 0);
+
+ if !fast {
+ // Wait at least one PTO, from the reciever's perspective.
+ // A receiver that hasn't received MAX_OUTSTANDING_UNACK won't send PING.
+ now += receiver.pto() + Duration::from_micros(1);
+ trickle(&mut sender, &mut receiver, 1, now);
+ assert_eq!(receiver.stats().frame_tx.ping, 0);
+ }
+
+ // After a second PTO (or the first if fast), new acknowledgements come
+ // with a PING frame and cause an ACK to be sent by the sender.
+ now += receiver.pto() + Duration::from_micros(1);
+ trickle(&mut sender, &mut receiver, 1, now);
+ assert_eq!(receiver.stats().frame_tx.ping, 1);
+ if let Output::Callback(t) = sender.process_output(now) {
+ assert_eq!(t, DEFAULT_ACK_DELAY);
+ assert!(sender.process_output(now + t).dgram().is_some());
+ }
+ assert_eq!(sender.stats().frame_tx.ack, sender_acks_before + 1);
+}
+
+#[test]
+fn ping_with_ack_fast() {
+ ping_with_ack(true);
+}
+
+#[test]
+fn ping_with_ack_slow() {
+ ping_with_ack(false);
+}
+
+#[test]
+fn ping_with_ack_min() {
+ const COUNT: usize = MIN_OUTSTANDING_UNACK - 2;
+ let mut sender = default_client();
+ let mut receiver = default_server();
+ let mut now = now();
+ connect_force_idle(&mut sender, &mut receiver);
+ let sender_acks_before = sender.stats().frame_tx.ack;
+ let receiver_acks_before = receiver.stats().frame_tx.ack;
+ trickle(&mut sender, &mut receiver, COUNT, now);
+ assert_eq!(sender.stats().frame_tx.ack, sender_acks_before);
+ assert_eq!(receiver.stats().frame_tx.ack, receiver_acks_before + COUNT);
+ assert_eq!(receiver.stats().frame_tx.ping, 0);
+
+ // After 3 PTO, no PING because there are too few outstanding packets.
+ now += receiver.pto() * 3 + Duration::from_micros(1);
+ trickle(&mut sender, &mut receiver, 1, now);
+ assert_eq!(receiver.stats().frame_tx.ping, 0);
+}
+
+/// This calculates the PTO timer immediately after connection establishment.
+/// It depends on there only being 2 RTT samples in the handshake.
+fn expected_pto(rtt: Duration) -> Duration {
+ // PTO calculation is rtt + 4rttvar + ack delay.
+ // rttvar should be (rtt + 4 * (rtt / 2) * (3/4)^n + 25ms)/2
+ // where n is the number of round trips
+ // This uses a 25ms ack delay as the ACK delay extension
+ // is negotiated and no ACK_DELAY frame has been received.
+ rtt + rtt * 9 / 8 + Duration::from_millis(25)
+}
+
+#[test]
+fn fast_pto() {
+ let mut client = new_client(ConnectionParameters::default().fast_pto(FAST_PTO_SCALE / 2));
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ let res = client.process(None, now);
+ let idle_timeout = ConnectionParameters::default().get_idle_timeout() - (DEFAULT_RTT / 2);
+ assert_eq!(res, Output::Callback(idle_timeout));
+
+ // Send data on two streams
+ let stream = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_send(stream, DEFAULT_STREAM_DATA).unwrap(),
+ DEFAULT_STREAM_DATA.len()
+ );
+
+ // Send a packet after some time.
+ now += idle_timeout / 2;
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+
+ // Nothing to do, should return a callback.
+ let cb = client.process_output(now).callback();
+ assert_eq!(expected_pto(DEFAULT_RTT) / 2, cb);
+
+ // Once the PTO timer expires, a PTO packet should be sent should want to send PTO packet.
+ now += cb;
+ let dgram = client.process(None, now).dgram();
+
+ let stream_before = server.stats().frame_rx.stream;
+ server.process_input(&dgram.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.stream, stream_before + 1);
+}
+
+/// Even if the PTO timer is slowed right down, persistent congestion is declared
+/// based on the "true" value of the timer.
+#[test]
+fn fast_pto_persistent_congestion() {
+ let mut client = new_client(ConnectionParameters::default().fast_pto(FAST_PTO_SCALE * 2));
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ let res = client.process(None, now);
+ let idle_timeout = ConnectionParameters::default().get_idle_timeout() - (DEFAULT_RTT / 2);
+ assert_eq!(res, Output::Callback(idle_timeout));
+
+ // Send packets spaced by the PTO timer. And lose them.
+ // Note: This timing is a tiny bit higher than the client will use
+ // to determine persistent congestion. The ACK below adds another RTT
+ // estimate, which will reduce rttvar by 3/4, so persistent congestion
+ // will occur at `rtt + rtt*27/32 + 25ms`.
+ // That is OK as we're still showing that this interval is less than
+ // six times the PTO, which is what would be used if the scaling
+ // applied to the PTO used to determine persistent congestion.
+ let pc_interval = expected_pto(DEFAULT_RTT) * 3;
+ println!("pc_interval {pc_interval:?}");
+ let _drop1 = send_something(&mut client, now);
+
+ // Check that the PTO matches expectations.
+ let cb = client.process_output(now).callback();
+ assert_eq!(expected_pto(DEFAULT_RTT) * 2, cb);
+
+ now += pc_interval;
+ let _drop2 = send_something(&mut client, now);
+ let _drop3 = send_something(&mut client, now);
+ let _drop4 = send_something(&mut client, now);
+ let dgram = send_something(&mut client, now);
+
+ // Now acknowledge the tail packet and enter persistent congestion.
+ now += DEFAULT_RTT / 2;
+ let ack = server.process(Some(&dgram), now).dgram();
+ now += DEFAULT_RTT / 2;
+ client.process_input(&ack.unwrap(), now);
+ assert_eq!(cwnd(&client), CWND_MIN);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/resumption.rs b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs
new file mode 100644
index 0000000000..a8c45a9f06
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs
@@ -0,0 +1,246 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{cell::RefCell, mem, rc::Rc, time::Duration};
+
+use test_fixture::{self, assertions, now};
+
+use super::{
+ connect, connect_with_rtt, default_client, default_server, exchange_ticket, get_tokens,
+ new_client, resumed_server, send_something, AT_LEAST_PTO,
+};
+use crate::{
+ addr_valid::{AddressValidation, ValidateAddress},
+ ConnectionParameters, Error, Version,
+};
+
+#[test]
+fn resume() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = resumed_server(&client);
+ connect(&mut client, &mut server);
+ assert!(client.tls_info().unwrap().resumed());
+ assert!(server.tls_info().unwrap().resumed());
+}
+
+#[test]
+fn remember_smoothed_rtt() {
+ const RTT1: Duration = Duration::from_millis(130);
+ const RTT2: Duration = Duration::from_millis(70);
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT1);
+ assert_eq!(client.paths.rtt(), RTT1);
+
+ // We can't use exchange_ticket here because it doesn't respect RTT.
+ // Also, connect_with_rtt() ends with the server receiving a packet it
+ // wants to acknowledge; so the ticket will include an ACK frame too.
+ let validation = AddressValidation::new(now, ValidateAddress::NoToken).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ server.send_ticket(now, &[]).expect("can send ticket");
+ let ticket = server.process_output(now).dgram();
+ assert!(ticket.is_some());
+ now += RTT1 / 2;
+ client.process_input(&ticket.unwrap(), now);
+ let token = get_tokens(&mut client).pop().unwrap();
+
+ let mut client = default_client();
+ client.enable_resumption(now, token).unwrap();
+ assert_eq!(
+ client.paths.rtt(),
+ RTT1,
+ "client should remember previous RTT"
+ );
+ let mut server = resumed_server(&client);
+
+ connect_with_rtt(&mut client, &mut server, now, RTT2);
+ assert_eq!(
+ client.paths.rtt(),
+ RTT2,
+ "previous RTT should be completely erased"
+ );
+}
+
+/// Check that a resumed connection uses a token on Initial packets.
+#[test]
+fn address_validation_token_resume() {
+ const RTT: Duration = Duration::from_millis(10);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+ let token = exchange_ticket(&mut client, &mut server, now);
+ let mut client = default_client();
+ client.enable_resumption(now, token).unwrap();
+ let mut server = resumed_server(&client);
+
+ // Grab an Initial packet from the client.
+ let dgram = client.process(None, now).dgram();
+ assertions::assert_initial(dgram.as_ref().unwrap(), true);
+
+ // Now try to complete the handshake after giving time for a client PTO.
+ now += AT_LEAST_PTO;
+ connect_with_rtt(&mut client, &mut server, now, RTT);
+ assert!(client.crypto.tls.info().unwrap().resumed());
+ assert!(server.crypto.tls.info().unwrap().resumed());
+}
+
+fn can_resume(token: impl AsRef<[u8]>, initial_has_token: bool) {
+ let mut client = default_client();
+ client.enable_resumption(now(), token).unwrap();
+ let initial = client.process_output(now()).dgram();
+ assertions::assert_initial(initial.as_ref().unwrap(), initial_has_token);
+}
+
+#[test]
+fn two_tickets_on_timer() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // Send two tickets and then bundle those into a packet.
+ server.send_ticket(now(), &[]).expect("send ticket1");
+ server.send_ticket(now(), &[]).expect("send ticket2");
+ let pkt = send_something(&mut server, now());
+
+ // process() will return an ack first
+ assert!(client.process(Some(&pkt), now()).dgram().is_some());
+ // We do not have a ResumptionToken event yet, because NEW_TOKEN was not sent.
+ assert_eq!(get_tokens(&mut client).len(), 0);
+
+ // We need to wait for release_resumption_token_timer to expire. The timer will be
+ // set to 3 * PTO
+ let mut now = now() + 3 * client.pto();
+ mem::drop(client.process(None, now));
+ let mut recv_tokens = get_tokens(&mut client);
+ assert_eq!(recv_tokens.len(), 1);
+ let token1 = recv_tokens.pop().unwrap();
+ // Wai for anottheer 3 * PTO to get the nex okeen.
+ now += 3 * client.pto();
+ mem::drop(client.process(None, now));
+ let mut recv_tokens = get_tokens(&mut client);
+ assert_eq!(recv_tokens.len(), 1);
+ let token2 = recv_tokens.pop().unwrap();
+ // Wait for 3 * PTO, but now there are no more tokens.
+ now += 3 * client.pto();
+ mem::drop(client.process(None, now));
+ assert_eq!(get_tokens(&mut client).len(), 0);
+ assert_ne!(token1.as_ref(), token2.as_ref());
+
+ can_resume(token1, false);
+ can_resume(token2, false);
+}
+
+#[test]
+fn two_tickets_with_new_token() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ connect(&mut client, &mut server);
+
+ // Send two tickets with tokens and then bundle those into a packet.
+ server.send_ticket(now(), &[]).expect("send ticket1");
+ server.send_ticket(now(), &[]).expect("send ticket2");
+ let pkt = send_something(&mut server, now());
+
+ client.process_input(&pkt, now());
+ let mut all_tokens = get_tokens(&mut client);
+ assert_eq!(all_tokens.len(), 2);
+ let token1 = all_tokens.pop().unwrap();
+ let token2 = all_tokens.pop().unwrap();
+ assert_ne!(token1.as_ref(), token2.as_ref());
+
+ can_resume(token1, true);
+ can_resume(token2, true);
+}
+
+/// By disabling address validation, the server won't send `NEW_TOKEN`, but
+/// we can take the session ticket still.
+#[test]
+fn take_token() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ server.send_ticket(now(), &[]).unwrap();
+ let dgram = server.process(None, now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+
+ // There should be no ResumptionToken event here.
+ let tokens = get_tokens(&mut client);
+ assert_eq!(tokens.len(), 0);
+
+ // But we should be able to get the token directly, and use it.
+ let token = client.take_resumption_token(now()).unwrap();
+ can_resume(token, false);
+}
+
+/// If a version is selected and subsequently disabled, resumption fails.
+#[test]
+fn resume_disabled_version() {
+ let mut client = new_client(
+ ConnectionParameters::default().versions(Version::Version1, vec![Version::Version1]),
+ );
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let token = exchange_ticket(&mut client, &mut server, now());
+
+ let mut client = new_client(
+ ConnectionParameters::default().versions(Version::Version2, vec![Version::Version2]),
+ );
+ assert_eq!(
+ client.enable_resumption(now(), token).unwrap_err(),
+ Error::DisabledVersion
+ );
+}
+
+/// It's not possible to resume once a packet has been sent.
+#[test]
+fn resume_after_packet() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let token = exchange_ticket(&mut client, &mut server, now());
+
+ let mut client = default_client();
+ mem::drop(client.process_output(now()).dgram().unwrap());
+ assert_eq!(
+ client.enable_resumption(now(), token).unwrap_err(),
+ Error::ConnectionState
+ );
+}
+
+/// It's not possible to resume at the server.
+#[test]
+fn resume_server() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let token = exchange_ticket(&mut client, &mut server, now());
+
+ let mut server = default_server();
+ assert_eq!(
+ server.enable_resumption(now(), token).unwrap_err(),
+ Error::ConnectionState
+ );
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs
new file mode 100644
index 0000000000..586a537b9d
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs
@@ -0,0 +1,1162 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{cmp::max, collections::HashMap, convert::TryFrom, mem};
+
+use neqo_common::{event::Provider, qdebug};
+use test_fixture::now;
+
+use super::{
+ super::State, assert_error, connect, connect_force_idle, default_client, default_server,
+ maybe_authenticate, new_client, new_server, send_something, DEFAULT_STREAM_DATA,
+};
+use crate::{
+ events::ConnectionEvent,
+ recv_stream::RECV_BUFFER_SIZE,
+ send_stream::{OrderGroup, SendStreamState, SEND_BUFFER_SIZE},
+ streams::{SendOrder, StreamOrder},
+ tparams::{self, TransportParameter},
+ // tracking::DEFAULT_ACK_PACKET_TOLERANCE,
+ Connection,
+ ConnectionError,
+ ConnectionParameters,
+ Error,
+ StreamId,
+ StreamType,
+};
+
+#[test]
+fn stream_create() {
+ let mut client = default_client();
+
+ let out = client.process(None, now());
+ let mut server = default_server();
+ let out = server.process(out.as_dgram_ref(), now());
+
+ let out = client.process(out.as_dgram_ref(), now());
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+ assert!(maybe_authenticate(&mut client));
+ let out = client.process(None, now());
+
+ // client now in State::Connected
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 6);
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 4);
+
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+ // server now in State::Connected
+ assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 3);
+ assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 7);
+ assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 1);
+ assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 5);
+}
+
+#[test]
+// tests stream send/recv after connection is established.
+fn transfer() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ qdebug!("---- client sends");
+ // Send
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[6; 100]).unwrap();
+ client.stream_send(client_stream_id, &[7; 40]).unwrap();
+ client.stream_send(client_stream_id, &[8; 4000]).unwrap();
+
+ // Send to another stream but some data after fin has been set
+ let client_stream_id2 = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id2, &[6; 60]).unwrap();
+ client.stream_close_send(client_stream_id2).unwrap();
+ client.stream_send(client_stream_id2, &[7; 50]).unwrap_err();
+ // Sending this much takes a few datagrams.
+ let mut datagrams = vec![];
+ let mut out = client.process_output(now());
+ while let Some(d) = out.dgram() {
+ datagrams.push(d);
+ out = client.process_output(now());
+ }
+ assert_eq!(datagrams.len(), 4);
+ assert_eq!(*client.state(), State::Confirmed);
+
+ qdebug!("---- server receives");
+ for d in datagrams {
+ let out = server.process(Some(&d), now());
+ // With an RTT of zero, the server will acknowledge every packet immediately.
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ }
+ assert_eq!(*server.state(), State::Confirmed);
+
+ let mut buf = vec![0; 4000];
+
+ let mut stream_ids = server.events().filter_map(|evt| match evt {
+ ConnectionEvent::NewStream { stream_id, .. } => Some(stream_id),
+ _ => None,
+ });
+ let first_stream = stream_ids.next().expect("should have a new stream event");
+ let second_stream = stream_ids
+ .next()
+ .expect("should have a second new stream event");
+ assert!(stream_ids.next().is_none());
+ let (received1, fin1) = server.stream_recv(first_stream, &mut buf).unwrap();
+ assert_eq!(received1, 4000);
+ assert!(!fin1);
+ let (received2, fin2) = server.stream_recv(first_stream, &mut buf).unwrap();
+ assert_eq!(received2, 140);
+ assert!(!fin2);
+
+ let (received3, fin3) = server.stream_recv(second_stream, &mut buf).unwrap();
+ assert_eq!(received3, 60);
+ assert!(fin3);
+}
+
+#[derive(PartialEq, Eq, PartialOrd, Ord)]
+struct IdEntry {
+ sendorder: StreamOrder,
+ stream_id: StreamId,
+}
+
+// tests stream sendorder priorization
+fn sendorder_test(order_of_sendorder: &[Option<SendOrder>]) {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ qdebug!("---- client sends");
+ // open all streams and set the sendorders
+ let mut ordered = Vec::new();
+ let mut streams = Vec::<StreamId>::new();
+ for sendorder in order_of_sendorder {
+ let id = client.stream_create(StreamType::UniDi).unwrap();
+ streams.push(id);
+ ordered.push((id, *sendorder));
+ // must be set before sendorder
+ client.streams.set_fairness(id, true).ok();
+ client.streams.set_sendorder(id, *sendorder).ok();
+ }
+ // Write some data to all the streams
+ for stream_id in streams {
+ client.stream_send(stream_id, &[6; 100]).unwrap();
+ }
+
+ // Sending this much takes a few datagrams.
+ // Note: this test uses an RTT of 0 which simplifies things (no pacing)
+ let mut datagrams = Vec::new();
+ let mut out = client.process_output(now());
+ while let Some(d) = out.dgram() {
+ datagrams.push(d);
+ out = client.process_output(now());
+ }
+ assert_eq!(*client.state(), State::Confirmed);
+
+ qdebug!("---- server receives");
+ for d in datagrams {
+ let out = server.process(Some(&d), now());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ }
+ assert_eq!(*server.state(), State::Confirmed);
+
+ let stream_ids = server
+ .events()
+ .filter_map(|evt| match evt {
+ ConnectionEvent::RecvStreamReadable { stream_id, .. } => Some(stream_id),
+ _ => None,
+ })
+ .enumerate()
+ .map(|(a, b)| (b, a))
+ .collect::<HashMap<_, _>>();
+
+ // streams should arrive in priority order, not order of creation, if sendorder prioritization
+ // is working correctly
+
+ // 'ordered' has the send order currently. Re-sort it by sendorder, but
+ // if two items from the same sendorder exist, secondarily sort by the ordering in
+ // the stream_ids vector (HashMap<StreamId, index: usize>)
+ ordered.sort_unstable_by_key(|(stream_id, sendorder)| {
+ (
+ StreamOrder {
+ sendorder: *sendorder,
+ },
+ stream_ids[stream_id],
+ )
+ });
+ // make sure everything now is in the same order, since we modified the order of
+ // same-sendorder items to match the ordering of those we saw in reception
+ for (i, (stream_id, _sendorder)) in ordered.iter().enumerate() {
+ assert_eq!(i, stream_ids[stream_id]);
+ }
+}
+
+#[test]
+fn sendorder_0() {
+ sendorder_test(&[None, Some(1), Some(2), Some(3)]);
+}
+#[test]
+fn sendorder_1() {
+ sendorder_test(&[Some(3), Some(2), Some(1), None]);
+}
+#[test]
+fn sendorder_2() {
+ sendorder_test(&[Some(3), None, Some(2), Some(1)]);
+}
+#[test]
+fn sendorder_3() {
+ sendorder_test(&[Some(1), Some(2), None, Some(3)]);
+}
+#[test]
+fn sendorder_4() {
+ sendorder_test(&[
+ Some(1),
+ Some(2),
+ Some(1),
+ None,
+ Some(3),
+ Some(1),
+ Some(3),
+ None,
+ ]);
+}
+
+// Tests stream sendorder priorization
+// Converts Vecs of u64's into StreamIds
+fn fairness_test<S, R>(source: S, number_iterates: usize, truncate_to: usize, result_array: &R)
+where
+ S: IntoIterator,
+ S::Item: Into<StreamId>,
+ R: IntoIterator + std::fmt::Debug,
+ R::Item: Into<StreamId>,
+ Vec<u64>: PartialEq<R>,
+{
+ // test the OrderGroup code used for fairness
+ let mut group: OrderGroup = OrderGroup::default();
+ for stream_id in source {
+ group.insert(stream_id.into());
+ }
+ {
+ let mut iterator1 = group.iter();
+ // advance_by() would help here
+ let mut n = number_iterates;
+ while n > 0 {
+ iterator1.next();
+ n -= 1;
+ }
+ // let iterator1 go out of scope
+ }
+ group.truncate(truncate_to);
+
+ let iterator2 = group.iter();
+ let result: Vec<u64> = iterator2.map(StreamId::as_u64).collect();
+ assert_eq!(result, *result_array);
+}
+
+#[test]
+fn ordergroup_0() {
+ let source: [u64; 0] = [];
+ let result: [u64; 0] = [];
+ fairness_test(source, 1, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_1() {
+ let source: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ let result: [u64; 6] = [1, 2, 3, 4, 5, 0];
+ fairness_test(source, 1, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_2() {
+ let source: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ let result: [u64; 6] = [2, 3, 4, 5, 0, 1];
+ fairness_test(source, 2, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_3() {
+ let source: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ let result: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ fairness_test(source, 10, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_4() {
+ let source: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ let result: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ fairness_test(source, 0, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_5() {
+ let source: [u64; 1] = [0];
+ let result: [u64; 1] = [0];
+ fairness_test(source, 1, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_6() {
+ let source: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ let result: [u64; 6] = [5, 0, 1, 2, 3, 4];
+ fairness_test(source, 5, usize::MAX, &result);
+}
+
+#[test]
+fn ordergroup_7() {
+ let source: [u64; 6] = [0, 1, 2, 3, 4, 5];
+ let result: [u64; 3] = [0, 1, 2];
+ fairness_test(source, 5, 3, &result);
+}
+
+#[test]
+// Send fin even if a peer closes a reomte bidi send stream before sending any data.
+fn report_fin_when_stream_closed_wo_data() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+
+ server.stream_close_send(stream_id).unwrap();
+ let out = server.process(None, now());
+ mem::drop(client.process(out.as_dgram_ref(), now()));
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. });
+ assert!(client.events().any(stream_readable));
+}
+
+fn exchange_data(client: &mut Connection, server: &mut Connection) {
+ let mut input = None;
+ loop {
+ let out = client.process(input.as_ref(), now()).dgram();
+ let c_done = out.is_none();
+ let out = server.process(out.as_ref(), now()).dgram();
+ if out.is_none() && c_done {
+ break;
+ }
+ input = out;
+ }
+}
+
+#[test]
+fn sending_max_data() {
+ const SMALL_MAX_DATA: usize = 2048;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ );
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected)
+ assert_eq!(stream_id, 2);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1])
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ exchange_data(&mut client, &mut server);
+
+ let mut buf = vec![0; 40000];
+ let (received, fin) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(received, SMALL_MAX_DATA);
+ assert!(!fin);
+
+ let out = server.process(None, now()).dgram();
+ client.process_input(&out.unwrap(), now());
+
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1])
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+}
+
+#[test]
+fn max_data() {
+ const SMALL_MAX_DATA: usize = 16383;
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ server
+ .set_local_tparam(
+ tparams::INITIAL_MAX_DATA,
+ TransportParameter::Integer(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ )
+ .unwrap();
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected)
+ assert_eq!(stream_id, 2);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1])
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+ assert_eq!(client.events().count(), 0);
+
+ assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0);
+ client
+ .streams
+ .get_send_stream_mut(stream_id)
+ .unwrap()
+ .mark_as_sent(0, 4096, false);
+ assert_eq!(client.events().count(), 0);
+ client
+ .streams
+ .get_send_stream_mut(stream_id)
+ .unwrap()
+ .mark_as_acked(0, 4096, false);
+ assert_eq!(client.events().count(), 0);
+
+ assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0);
+ // no event because still limited by conn max data
+ assert_eq!(client.events().count(), 0);
+
+ // Increase max data. Avail space now limited by stream credit
+ client.streams.handle_max_data(100_000_000);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SEND_BUFFER_SIZE - SMALL_MAX_DATA
+ );
+
+ // Increase max stream data. Avail space now limited by tx buffer
+ client
+ .streams
+ .get_send_stream_mut(stream_id)
+ .unwrap()
+ .set_max_stream_data(100_000_000);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SEND_BUFFER_SIZE - SMALL_MAX_DATA + 4096
+ );
+
+ let evts = client.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 1);
+ assert!(matches!(
+ evts[0],
+ ConnectionEvent::SendStreamWritable { .. }
+ ));
+}
+
+#[test]
+fn exceed_max_data() {
+ const SMALL_MAX_DATA: usize = 1024;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ );
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected)
+ assert_eq!(stream_id, 2);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1])
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0);
+
+ // Artificially trick the client to think that it has more flow control credit.
+ client.streams.handle_max_data(100_000_000);
+ assert_eq!(client.stream_send(stream_id, b"h").unwrap(), 1);
+
+ exchange_data(&mut client, &mut server);
+
+ assert_error(
+ &client,
+ &ConnectionError::Transport(Error::PeerError(Error::FlowControlError.code())),
+ );
+ assert_error(
+ &server,
+ &ConnectionError::Transport(Error::FlowControlError),
+ );
+}
+
+#[test]
+// If we send a stop_sending to the peer, we should not accept more data from the peer.
+fn do_not_accept_data_after_stop_sending() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. });
+ assert!(server.events().any(stream_readable));
+
+ // Send one more packet from client. The packet should arrive after the server
+ // has already requested stop_sending.
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out_second_data_frame = client.process(None, now());
+ // Call stop sending.
+ assert_eq!(
+ Ok(()),
+ server.stream_stop_sending(stream_id, Error::NoError.code())
+ );
+
+ // Receive the second data frame. The frame should be ignored and
+ // DataReadable events shouldn't be posted.
+ let out = server.process(out_second_data_frame.as_dgram_ref(), now());
+ assert!(!server.events().any(stream_readable));
+
+ mem::drop(client.process(out.as_dgram_ref(), now()));
+ assert_eq!(
+ Err(Error::FinalSizeError),
+ client.stream_send(stream_id, &[0x00])
+ );
+}
+
+#[test]
+// Server sends stop_sending, the client simultaneous sends reset.
+fn simultaneous_stop_sending_and_reset() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ let ack = server.process(out.as_dgram_ref(), now()).dgram();
+
+ let stream_readable =
+ |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id: id } if id == stream_id);
+ assert!(server.events().any(stream_readable));
+
+ // The client resets the stream. The packet with reset should arrive after the server
+ // has already requested stop_sending.
+ client.stream_reset_send(stream_id, 0).unwrap();
+ let out_reset_frame = client.process(ack.as_ref(), now()).dgram();
+
+ // Send something out of order to force the server to generate an
+ // acknowledgment at the next opportunity.
+ let force_ack = send_something(&mut client, now());
+ server.process_input(&force_ack, now());
+
+ // Call stop sending.
+ server.stream_stop_sending(stream_id, 0).unwrap();
+ // Receive the second data frame. The frame should be ignored and
+ // DataReadable events shouldn't be posted.
+ let ack = server.process(out_reset_frame.as_ref(), now()).dgram();
+ assert!(ack.is_some());
+ assert!(!server.events().any(stream_readable));
+
+ // The client gets the STOP_SENDING frame.
+ client.process_input(&ack.unwrap(), now());
+ assert_eq!(
+ Err(Error::InvalidStreamId),
+ client.stream_send(stream_id, &[0x00])
+ );
+}
+
+#[test]
+fn client_fin_reorder() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // Send ClientHello.
+ let client_hs = client.process(None, now());
+ assert!(client_hs.as_dgram_ref().is_some());
+
+ let server_hs = server.process(client_hs.as_dgram_ref(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc...
+
+ let client_ack = client.process(server_hs.as_dgram_ref(), now());
+ assert!(client_ack.as_dgram_ref().is_some());
+
+ let server_out = server.process(client_ack.as_dgram_ref(), now());
+ assert!(server_out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+ assert_eq!(*client.state(), State::Connected);
+
+ let client_fin = client.process(None, now());
+ assert!(client_fin.as_dgram_ref().is_some());
+
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[1, 2, 3]).unwrap();
+ let client_stream_data = client.process(None, now());
+ assert!(client_stream_data.as_dgram_ref().is_some());
+
+ // Now stream data gets before client_fin
+ let server_out = server.process(client_stream_data.as_dgram_ref(), now());
+ assert!(server_out.as_dgram_ref().is_none()); // the packet will be discarded
+
+ assert_eq!(*server.state(), State::Handshaking);
+ let server_out = server.process(client_fin.as_dgram_ref(), now());
+ assert!(server_out.as_dgram_ref().is_some());
+}
+
+#[test]
+fn after_fin_is_read_conn_events_for_stream_should_be_removed() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let id = server.stream_create(StreamType::BiDi).unwrap();
+ server.stream_send(id, &[6; 10]).unwrap();
+ server.stream_close_send(id).unwrap();
+ let out = server.process(None, now()).dgram();
+ assert!(out.is_some());
+
+ mem::drop(client.process(out.as_ref(), now()));
+
+ // read from the stream before checking connection events.
+ let mut buf = vec![0; 4000];
+ let (_, fin) = client.stream_recv(id, &mut buf).unwrap();
+ assert!(fin);
+
+ // Make sure we do not have RecvStreamReadable events for the stream when fin has been read.
+ let readable_stream_evt =
+ |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id);
+ assert!(!client.events().any(readable_stream_evt));
+}
+
+#[test]
+fn after_stream_stop_sending_is_called_conn_events_for_stream_should_be_removed() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let id = server.stream_create(StreamType::BiDi).unwrap();
+ server.stream_send(id, &[6; 10]).unwrap();
+ server.stream_close_send(id).unwrap();
+ let out = server.process(None, now()).dgram();
+ assert!(out.is_some());
+
+ mem::drop(client.process(out.as_ref(), now()));
+
+ // send stop seending.
+ client
+ .stream_stop_sending(id, Error::NoError.code())
+ .unwrap();
+
+ // Make sure we do not have RecvStreamReadable events for the stream after stream_stop_sending
+ // has been called.
+ let readable_stream_evt =
+ |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id);
+ assert!(!client.events().any(readable_stream_evt));
+}
+
+#[test]
+fn stream_data_blocked_generates_max_stream_data() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let now = now();
+
+ // Send some data and consume some flow control.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+ _ = server.stream_send(stream_id, DEFAULT_STREAM_DATA).unwrap();
+ let dgram = server.process(None, now).dgram();
+ assert!(dgram.is_some());
+
+ // Consume the data.
+ client.process_input(&dgram.unwrap(), now);
+ let mut buf = [0; 10];
+ let (count, end) = client.stream_recv(stream_id, &mut buf[..]).unwrap();
+ assert_eq!(count, DEFAULT_STREAM_DATA.len());
+ assert!(!end);
+
+ // Now send `STREAM_DATA_BLOCKED`.
+ let internal_stream = server.streams.get_send_stream_mut(stream_id).unwrap();
+ if let SendStreamState::Send { fc, .. } = internal_stream.state() {
+ fc.blocked();
+ } else {
+ panic!("unexpected stream state");
+ }
+ let dgram = server.process_output(now).dgram();
+ assert!(dgram.is_some());
+
+ let sdb_before = client.stats().frame_rx.stream_data_blocked;
+ let dgram = client.process(dgram.as_ref(), now).dgram();
+ assert_eq!(client.stats().frame_rx.stream_data_blocked, sdb_before + 1);
+ assert!(dgram.is_some());
+
+ // Client should have sent a MAX_STREAM_DATA frame with just a small increase
+ // on the default window size.
+ let msd_before = server.stats().frame_rx.max_stream_data;
+ server.process_input(&dgram.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.max_stream_data, msd_before + 1);
+
+ // Test that the entirety of the receive buffer is available now.
+ let mut written = 0;
+ loop {
+ const LARGE_BUFFER: &[u8] = &[0; 1024];
+ let amount = server.stream_send(stream_id, LARGE_BUFFER).unwrap();
+ if amount == 0 {
+ break;
+ }
+ written += amount;
+ }
+ assert_eq!(written, RECV_BUFFER_SIZE);
+}
+
+/// See <https://github.com/mozilla/neqo/issues/871>
+#[test]
+fn max_streams_after_bidi_closed() {
+ const REQUEST: &[u8] = b"ping";
+ const RESPONSE: &[u8] = b"pong";
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ while client.stream_create(StreamType::BiDi).is_ok() {
+ // Exhaust the stream limit.
+ }
+ // Write on the one stream and send that out.
+ _ = client.stream_send(stream_id, REQUEST).unwrap();
+ client.stream_close_send(stream_id).unwrap();
+ let dgram = client.process(None, now()).dgram();
+
+ // Now handle the stream and send an incomplete response.
+ server.process_input(&dgram.unwrap(), now());
+ server.stream_send(stream_id, RESPONSE).unwrap();
+ let dgram = server.process_output(now()).dgram();
+
+ // The server shouldn't have released more stream credit.
+ client.process_input(&dgram.unwrap(), now());
+ let e = client.stream_create(StreamType::BiDi).unwrap_err();
+ assert!(matches!(e, Error::StreamLimitError));
+
+ // Closing the stream isn't enough.
+ server.stream_close_send(stream_id).unwrap();
+ let dgram = server.process_output(now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+ assert!(client.stream_create(StreamType::BiDi).is_err());
+
+ // The server needs to see an acknowledgment from the client for its
+ // response AND the server has to read all of the request.
+ // and the server needs to read all the data. Read first.
+ let mut buf = [0; REQUEST.len()];
+ let (count, fin) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(&buf[..count], REQUEST);
+ assert!(fin);
+
+ // We need an ACK from the client now, but that isn't guaranteed,
+ // so give the client one more packet just in case.
+ let dgram = send_something(&mut server, now());
+ client.process_input(&dgram, now());
+
+ // Now get the client to send the ACK and have the server handle that.
+ let dgram = send_something(&mut client, now());
+ let dgram = server.process(Some(&dgram), now()).dgram();
+ client.process_input(&dgram.unwrap(), now());
+ assert!(client.stream_create(StreamType::BiDi).is_ok());
+ assert!(client.stream_create(StreamType::BiDi).is_err());
+}
+
+#[test]
+fn no_dupdata_readable_events() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+
+ // We have a data_readable event.
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. });
+ assert!(server.events().any(stream_readable));
+
+ // Send one more data frame from client. The previous stream data has not been read yet,
+ // therefore there should not be a new DataReadable event.
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out_second_data_frame = client.process(None, now());
+ mem::drop(server.process(out_second_data_frame.as_dgram_ref(), now()));
+ assert!(!server.events().any(stream_readable));
+
+ // One more frame with a fin will not produce a new DataReadable event, because the
+ // previous stream data has not been read yet.
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ client.stream_close_send(stream_id).unwrap();
+ let out_third_data_frame = client.process(None, now());
+ mem::drop(server.process(out_third_data_frame.as_dgram_ref(), now()));
+ assert!(!server.events().any(stream_readable));
+}
+
+#[test]
+fn no_dupdata_readable_events_empty_last_frame() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+
+ // We have a data_readable event.
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable { .. });
+ assert!(server.events().any(stream_readable));
+
+ // An empty frame with a fin will not produce a new DataReadable event, because
+ // the previous stream data has not been read yet.
+ client.stream_close_send(stream_id).unwrap();
+ let out_second_data_frame = client.process(None, now());
+ mem::drop(server.process(out_second_data_frame.as_dgram_ref(), now()));
+ assert!(!server.events().any(stream_readable));
+}
+
+fn change_flow_control(stream_type: StreamType, new_fc: u64) {
+ const RECV_BUFFER_START: u64 = 300;
+
+ let mut client = new_client(
+ ConnectionParameters::default()
+ .max_stream_data(StreamType::BiDi, true, RECV_BUFFER_START)
+ .max_stream_data(StreamType::UniDi, true, RECV_BUFFER_START),
+ );
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = server.stream_create(stream_type).unwrap();
+ let written1 = server.stream_send(stream_id, &[0x0; 10000]).unwrap();
+ assert_eq!(u64::try_from(written1).unwrap(), RECV_BUFFER_START);
+
+ // Send the stream to the client.
+ let out = server.process(None, now());
+ mem::drop(client.process(out.as_dgram_ref(), now()));
+
+ // change max_stream_data for stream_id.
+ client.set_stream_max_data(stream_id, new_fc).unwrap();
+
+ // server should receive a MAX_SREAM_DATA frame if the flow control window is updated.
+ let out2 = client.process(None, now());
+ let out3 = server.process(out2.as_dgram_ref(), now());
+ let expected = usize::from(RECV_BUFFER_START < new_fc);
+ assert_eq!(server.stats().frame_rx.max_stream_data, expected);
+
+ // If the flow control window has been increased, server can write more data.
+ let written2 = server.stream_send(stream_id, &[0x0; 10000]).unwrap();
+ if RECV_BUFFER_START < new_fc {
+ assert_eq!(u64::try_from(written2).unwrap(), new_fc - RECV_BUFFER_START);
+ } else {
+ assert_eq!(written2, 0);
+ }
+
+ // Exchange packets so that client gets all data.
+ let out4 = client.process(out3.as_dgram_ref(), now());
+ let out5 = server.process(out4.as_dgram_ref(), now());
+ mem::drop(client.process(out5.as_dgram_ref(), now()));
+
+ // read all data by client
+ let mut buf = [0x0; 10000];
+ let (read, _) = client.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(u64::try_from(read).unwrap(), max(RECV_BUFFER_START, new_fc));
+
+ let out4 = client.process(None, now());
+ mem::drop(server.process(out4.as_dgram_ref(), now()));
+
+ let written3 = server.stream_send(stream_id, &[0x0; 10000]).unwrap();
+ assert_eq!(u64::try_from(written3).unwrap(), new_fc);
+}
+
+#[test]
+fn increase_decrease_flow_control() {
+ const RECV_BUFFER_NEW_BIGGER: u64 = 400;
+ const RECV_BUFFER_NEW_SMALLER: u64 = 200;
+
+ change_flow_control(StreamType::UniDi, RECV_BUFFER_NEW_BIGGER);
+ change_flow_control(StreamType::BiDi, RECV_BUFFER_NEW_BIGGER);
+
+ change_flow_control(StreamType::UniDi, RECV_BUFFER_NEW_SMALLER);
+ change_flow_control(StreamType::BiDi, RECV_BUFFER_NEW_SMALLER);
+}
+
+#[test]
+fn session_flow_control_stop_sending_state_recv() {
+ const SMALL_MAX_DATA: usize = 1024;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ );
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ // send 1 byte so that the server learns about the stream.
+ assert_eq!(client.stream_send(stream_id, b"a").unwrap(), 1);
+
+ exchange_data(&mut client, &mut server);
+
+ server
+ .stream_stop_sending(stream_id, Error::NoError.code())
+ .unwrap();
+
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA])
+ .unwrap(),
+ SMALL_MAX_DATA - 1
+ );
+
+ // In this case the final size is only known after RESET frame is received.
+ // The server sends STOP_SENDING -> the client sends RESET -> the server
+ // sends MAX_DATA.
+ let out = server.process(None, now()).dgram();
+ let out = client.process(out.as_ref(), now()).dgram();
+ // the client is still limited.
+ let stream_id2 = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.stream_avail_send_space(stream_id2).unwrap(), 0);
+ let out = server.process(out.as_ref(), now()).dgram();
+ client.process_input(&out.unwrap(), now());
+ assert_eq!(
+ client.stream_avail_send_space(stream_id2).unwrap(),
+ SMALL_MAX_DATA
+ );
+}
+
+#[test]
+fn session_flow_control_stop_sending_state_size_known() {
+ const SMALL_MAX_DATA: usize = 1024;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ );
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ // send 1 byte so that the server learns about the stream.
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1])
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ let out1 = client.process(None, now()).dgram();
+ // Delay this packet and let the server receive fin first (it will enter SizeKnown state).
+ client.stream_close_send(stream_id).unwrap();
+ let out2 = client.process(None, now()).dgram();
+
+ server.process_input(&out2.unwrap(), now());
+
+ server
+ .stream_stop_sending(stream_id, Error::NoError.code())
+ .unwrap();
+
+ // In this case the final size is known when stream_stop_sending is called
+ // and the server releases flow control immediately and sends STOP_SENDING and
+ // MAX_DATA in the same packet.
+ let out = server.process(out1.as_ref(), now()).dgram();
+ client.process_input(&out.unwrap(), now());
+
+ // The flow control should have been updated and the client can again send
+ // SMALL_MAX_DATA.
+ let stream_id2 = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id2).unwrap(),
+ SMALL_MAX_DATA
+ );
+}
+
+#[test]
+fn session_flow_control_stop_sending_state_data_recvd() {
+ const SMALL_MAX_DATA: usize = 1024;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ );
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ // send 1 byte so that the server learns about the stream.
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA + 1])
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ client.stream_close_send(stream_id).unwrap();
+
+ exchange_data(&mut client, &mut server);
+
+ // The stream is DataRecvd state
+ server
+ .stream_stop_sending(stream_id, Error::NoError.code())
+ .unwrap();
+
+ exchange_data(&mut client, &mut server);
+
+ // The flow control should have been updated and the client can again send
+ // SMALL_MAX_DATA.
+ let stream_id2 = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id2).unwrap(),
+ SMALL_MAX_DATA
+ );
+}
+
+#[test]
+fn session_flow_control_affects_all_streams() {
+ const SMALL_MAX_DATA: usize = 1024;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().max_data(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ );
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ let stream_id2 = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(
+ client.stream_avail_send_space(stream_id2).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ assert_eq!(
+ client
+ .stream_send(stream_id, &[b'a'; SMALL_MAX_DATA / 2 + 1])
+ .unwrap(),
+ SMALL_MAX_DATA / 2 + 1
+ );
+
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA / 2 - 1
+ );
+ assert_eq!(
+ client.stream_avail_send_space(stream_id2).unwrap(),
+ SMALL_MAX_DATA / 2 - 1
+ );
+
+ exchange_data(&mut client, &mut server);
+
+ let mut buf = [0x0; SMALL_MAX_DATA];
+ let (read, _) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(read, SMALL_MAX_DATA / 2 + 1);
+
+ exchange_data(&mut client, &mut server);
+
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+
+ assert_eq!(
+ client.stream_avail_send_space(stream_id2).unwrap(),
+ SMALL_MAX_DATA
+ );
+}
+
+fn connect_w_different_limit(bidi_limit: u64, unidi_limit: u64) {
+ let mut client = default_client();
+ let out = client.process(None, now());
+ let mut server = new_server(
+ ConnectionParameters::default()
+ .max_streams(StreamType::BiDi, bidi_limit)
+ .max_streams(StreamType::UniDi, unidi_limit),
+ );
+ let out = server.process(out.as_dgram_ref(), now());
+
+ let out = client.process(out.as_dgram_ref(), now());
+ mem::drop(server.process(out.as_dgram_ref(), now()));
+
+ assert!(maybe_authenticate(&mut client));
+
+ let mut bidi_events = 0;
+ let mut unidi_events = 0;
+ let mut connected_events = 0;
+ for e in client.events() {
+ match e {
+ ConnectionEvent::SendStreamCreatable { stream_type } => {
+ if stream_type == StreamType::BiDi {
+ bidi_events += 1;
+ } else {
+ unidi_events += 1;
+ }
+ }
+ ConnectionEvent::StateChange(State::Connected) => {
+ connected_events += 1;
+ }
+ _ => {}
+ }
+ }
+ assert_eq!(bidi_events, usize::from(bidi_limit > 0));
+ assert_eq!(unidi_events, usize::from(unidi_limit > 0));
+ assert_eq!(connected_events, 1);
+}
+
+#[test]
+fn client_stream_creatable_event() {
+ connect_w_different_limit(0, 0);
+ connect_w_different_limit(0, 1);
+ connect_w_different_limit(1, 0);
+ connect_w_different_limit(1, 1);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs
new file mode 100644
index 0000000000..22f15c991c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs
@@ -0,0 +1,482 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{mem, time::Duration};
+
+use neqo_common::{event::Provider, Decoder, Encoder};
+use test_fixture::{self, assertions, datagram, now};
+
+use super::{
+ super::{ConnectionError, ConnectionEvent, Output, State, ZeroRttState},
+ connect, connect_fail, default_client, default_server, exchange_ticket, new_client, new_server,
+ send_something,
+};
+use crate::{
+ packet::PACKET_BIT_LONG,
+ tparams::{self, TransportParameter},
+ ConnectionParameters, Error, Version,
+};
+
+// The expected PTO duration after the first Initial is sent.
+const INITIAL_PTO: Duration = Duration::from_millis(300);
+
+#[test]
+fn unknown_version() {
+ let mut client = default_client();
+ // Start the handshake.
+ mem::drop(client.process(None, now()).dgram());
+
+ let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a];
+ unknown_version_packet.resize(1200, 0x0);
+ mem::drop(client.process(Some(&datagram(unknown_version_packet)), now()));
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn server_receive_unknown_first_packet() {
+ let mut server = default_server();
+
+ let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a];
+ unknown_version_packet.resize(1200, 0x0);
+
+ assert_eq!(
+ server.process(Some(&datagram(unknown_version_packet,)), now(),),
+ Output::None
+ );
+
+ assert_eq!(1, server.stats().dropped_rx);
+}
+
+fn create_vn(initial_pkt: &[u8], versions: &[u32]) -> Vec<u8> {
+ let mut dec = Decoder::from(&initial_pkt[5..]); // Skip past version.
+ let dst_cid = dec.decode_vec(1).expect("client DCID");
+ let src_cid = dec.decode_vec(1).expect("client SCID");
+
+ let mut encoder = Encoder::default();
+ encoder.encode_byte(PACKET_BIT_LONG);
+ encoder.encode(&[0; 4]); // Zero version == VN.
+ encoder.encode_vec(1, src_cid);
+ encoder.encode_vec(1, dst_cid);
+
+ for v in versions {
+ encoder.encode_uint(4, *v);
+ }
+ encoder.into()
+}
+
+#[test]
+fn version_negotiation_current_version() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(
+ &initial_pkt,
+ &[0x1a1a_1a1a, Version::default().wire_version()],
+ );
+
+ let dgram = datagram(vn);
+ let delay = client.process(Some(&dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_version0() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0, 0x1a1a_1a1a]);
+
+ let dgram = datagram(vn);
+ let delay = client.process(Some(&dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_only_reserved() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]);
+
+ let dgram = datagram(vn);
+ assert_eq!(client.process(Some(&dgram), now()), Output::None);
+ match client.state() {
+ State::Closed(err) => {
+ assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation));
+ }
+ _ => panic!("Invalid client state"),
+ }
+}
+
+#[test]
+fn version_negotiation_corrupted() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]);
+
+ let dgram = datagram(vn[..vn.len() - 1].to_vec());
+ let delay = client.process(Some(&dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_empty() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[]);
+
+ let dgram = datagram(vn);
+ let delay = client.process(Some(&dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_not_supported() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]);
+ let dgram = datagram(vn);
+ assert_eq!(client.process(Some(&dgram), now()), Output::None);
+ match client.state() {
+ State::Closed(err) => {
+ assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation));
+ }
+ _ => panic!("Invalid client state"),
+ }
+}
+
+#[test]
+fn version_negotiation_bad_cid() {
+ let mut client = default_client();
+ // Start the handshake.
+ let mut initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ initial_pkt[6] ^= 0xc4;
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]);
+
+ let dgram = datagram(vn);
+ let delay = client.process(Some(&dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn compatible_upgrade() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ connect(&mut client, &mut server);
+ assert_eq!(client.version(), Version::Version2);
+ assert_eq!(server.version(), Version::Version2);
+}
+
+/// When the first packet from the client is gigantic, the server might generate acknowledgment
+/// packets in version 1. Both client and server need to handle that gracefully.
+#[test]
+fn compatible_upgrade_large_initial() {
+ let params = ConnectionParameters::default().versions(
+ Version::Version1,
+ vec![Version::Version2, Version::Version1],
+ );
+ let mut client = new_client(params.clone());
+ client
+ .set_local_tparam(
+ 0x0845_de37_00ac_a5f9,
+ TransportParameter::Bytes(vec![0; 2048]),
+ )
+ .unwrap();
+ let mut server = new_server(params);
+
+ // Client Initial should take 2 packets.
+ // Each should elicit a Version 1 ACK from the server.
+ let dgram = client.process_output(now()).dgram();
+ assert!(dgram.is_some());
+ let dgram = server.process(dgram.as_ref(), now()).dgram();
+ assert!(dgram.is_some());
+ // The following uses the Version from *outside* this crate.
+ assertions::assert_version(dgram.as_ref().unwrap(), Version::Version1.wire_version());
+ client.process_input(&dgram.unwrap(), now());
+
+ connect(&mut client, &mut server);
+ assert_eq!(client.version(), Version::Version2);
+ assert_eq!(server.version(), Version::Version2);
+ // Only handshake padding is "dropped".
+ assert_eq!(client.stats().dropped_rx, 1);
+ assert_eq!(server.stats().dropped_rx, 1);
+}
+
+/// A server that supports versions 1 and 2 might prefer version 1 and that's OK.
+/// This one starts with version 1 and stays there.
+#[test]
+fn compatible_no_upgrade() {
+ let mut client = new_client(ConnectionParameters::default().versions(
+ Version::Version1,
+ vec![Version::Version2, Version::Version1],
+ ));
+ let mut server = new_server(ConnectionParameters::default().versions(
+ Version::Version1,
+ vec![Version::Version1, Version::Version2],
+ ));
+
+ connect(&mut client, &mut server);
+ assert_eq!(client.version(), Version::Version1);
+ assert_eq!(server.version(), Version::Version1);
+}
+
+/// A server that supports versions 1 and 2 might prefer version 1 and that's OK.
+/// This one starts with version 2 and downgrades to version 1.
+#[test]
+fn compatible_downgrade() {
+ let mut client = new_client(ConnectionParameters::default().versions(
+ Version::Version2,
+ vec![Version::Version2, Version::Version1],
+ ));
+ let mut server = new_server(ConnectionParameters::default().versions(
+ Version::Version2,
+ vec![Version::Version1, Version::Version2],
+ ));
+
+ connect(&mut client, &mut server);
+ assert_eq!(client.version(), Version::Version1);
+ assert_eq!(server.version(), Version::Version1);
+}
+
+/// Inject a Version Negotiation packet, which the client detects when it validates the
+/// server `version_negotiation` transport parameter.
+#[test]
+fn version_negotiation_downgrade() {
+ const DOWNGRADE: Version = Version::Draft29;
+
+ let mut client = default_client();
+ // The server sets the current version in the transport parameter and
+ // protects Initial packets with the version in its configuration.
+ // When a server `Connection` is created by a `Server`, the configuration is set
+ // to match the version of the packet it first receives. This replicates that.
+ let mut server =
+ new_server(ConnectionParameters::default().versions(DOWNGRADE, Version::all()));
+
+ // Start the handshake and spoof a VN packet.
+ let initial = client.process_output(now()).dgram().unwrap();
+ let vn = create_vn(&initial, &[DOWNGRADE.wire_version()]);
+ let dgram = datagram(vn);
+ client.process_input(&dgram, now());
+
+ connect_fail(
+ &mut client,
+ &mut server,
+ Error::VersionNegotiation,
+ Error::PeerError(Error::VersionNegotiation.code()),
+ );
+}
+
+/// A server connection needs to be configured with the version that the client attempts.
+/// Otherwise, it will object to the client transport parameters and not do anything.
+#[test]
+fn invalid_server_version() {
+ let mut client =
+ new_client(ConnectionParameters::default().versions(Version::Version1, Version::all()));
+ let mut server =
+ new_server(ConnectionParameters::default().versions(Version::Version2, Version::all()));
+
+ let dgram = client.process_output(now()).dgram();
+ server.process_input(&dgram.unwrap(), now());
+
+ // One packet received.
+ assert_eq!(server.stats().packets_rx, 1);
+ // None dropped; the server will have decrypted it successfully.
+ assert_eq!(server.stats().dropped_rx, 0);
+ assert_eq!(server.stats().saved_datagrams, 0);
+ // The server effectively hasn't reacted here.
+ match server.state() {
+ State::Closed(err) => {
+ assert_eq!(*err, ConnectionError::Transport(Error::CryptoAlert(47)));
+ }
+ _ => panic!("invalid server state"),
+ }
+}
+
+#[test]
+fn invalid_current_version_client() {
+ const OTHER_VERSION: Version = Version::Draft29;
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ assert_ne!(OTHER_VERSION, client.version());
+ client
+ .set_local_tparam(
+ tparams::VERSION_INFORMATION,
+ TransportParameter::Versions {
+ current: OTHER_VERSION.wire_version(),
+ other: Version::all()
+ .iter()
+ .copied()
+ .map(Version::wire_version)
+ .collect(),
+ },
+ )
+ .unwrap();
+
+ connect_fail(
+ &mut client,
+ &mut server,
+ Error::PeerError(Error::CryptoAlert(47).code()),
+ Error::CryptoAlert(47),
+ );
+}
+
+/// To test this, we need to disable compatible upgrade so that the server doesn't update
+/// its transport parameters. Then, we can overwrite its transport parameters without
+/// them being overwritten. Otherwise, it would be hard to find a window during which
+/// the transport parameter can be modified.
+#[test]
+fn invalid_current_version_server() {
+ const OTHER_VERSION: Version = Version::Draft29;
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default().versions(Version::default(), vec![Version::default()]),
+ );
+
+ assert!(!Version::default().is_compatible(OTHER_VERSION));
+ server
+ .set_local_tparam(
+ tparams::VERSION_INFORMATION,
+ TransportParameter::Versions {
+ current: OTHER_VERSION.wire_version(),
+ other: vec![OTHER_VERSION.wire_version()],
+ },
+ )
+ .unwrap();
+
+ connect_fail(
+ &mut client,
+ &mut server,
+ Error::CryptoAlert(47),
+ Error::PeerError(Error::CryptoAlert(47).code()),
+ );
+}
+
+#[test]
+fn no_compatible_version() {
+ const OTHER_VERSION: Version = Version::Draft29;
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ assert_ne!(OTHER_VERSION, client.version());
+ client
+ .set_local_tparam(
+ tparams::VERSION_INFORMATION,
+ TransportParameter::Versions {
+ current: Version::default().wire_version(),
+ other: vec![OTHER_VERSION.wire_version()],
+ },
+ )
+ .unwrap();
+
+ connect_fail(
+ &mut client,
+ &mut server,
+ Error::PeerError(Error::CryptoAlert(47).code()),
+ Error::CryptoAlert(47),
+ );
+}
+
+/// When a compatible upgrade chooses a different version, 0-RTT is rejected.
+#[test]
+fn compatible_upgrade_0rtt_rejected() {
+ // This is the baseline configuration where v1 is attempted and v2 preferred.
+ let prefer_v2 = ConnectionParameters::default().versions(
+ Version::Version1,
+ vec![Version::Version2, Version::Version1],
+ );
+ let mut client = new_client(prefer_v2.clone());
+ // The server will start with this so that the client resumes with v1.
+ let just_v1 =
+ ConnectionParameters::default().versions(Version::Version1, vec![Version::Version1]);
+ let mut server = new_server(just_v1);
+
+ connect(&mut client, &mut server);
+ assert_eq!(client.version(), Version::Version1);
+ let token = exchange_ticket(&mut client, &mut server, now());
+
+ // Now upgrade the server to the preferred configuration.
+ let mut client = new_client(prefer_v2.clone());
+ let mut server = new_server(prefer_v2);
+ client.enable_resumption(now(), token).unwrap();
+
+ // Create a packet with 0-RTT from the client.
+ let initial = send_something(&mut client, now());
+ assertions::assert_version(&initial, Version::Version1.wire_version());
+ assertions::assert_coalesced_0rtt(&initial);
+ server.process_input(&initial, now());
+ assert!(!server
+ .events()
+ .any(|e| matches!(e, ConnectionEvent::NewStream { .. })));
+
+ // Finalize the connection. Don't use connect() because it uses
+ // maybe_authenticate() too liberally and that eats the events we want to check.
+ let dgram = server.process_output(now()).dgram(); // ServerHello flight
+ let dgram = client.process(dgram.as_ref(), now()).dgram(); // Client Finished (note: no authentication)
+ let dgram = server.process(dgram.as_ref(), now()).dgram(); // HANDSHAKE_DONE
+ client.process_input(&dgram.unwrap(), now());
+
+ assert!(matches!(client.state(), State::Confirmed));
+ assert!(matches!(server.state(), State::Confirmed));
+
+ assert!(client.events().any(|e| {
+ println!(" client event: {e:?}");
+ matches!(e, ConnectionEvent::ZeroRttRejected)
+ }));
+ assert_eq!(client.zero_rtt_state(), ZeroRttState::Rejected);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs
new file mode 100644
index 0000000000..0aa5573c98
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs
@@ -0,0 +1,257 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{cell::RefCell, rc::Rc};
+
+use neqo_common::event::Provider;
+use neqo_crypto::{AllowZeroRtt, AntiReplay};
+use test_fixture::{self, assertions, now};
+
+use super::{
+ super::Connection, connect, default_client, default_server, exchange_ticket, new_server,
+ resumed_server, CountingConnectionIdGenerator,
+};
+use crate::{events::ConnectionEvent, ConnectionParameters, Error, StreamType, Version};
+
+#[test]
+fn zero_rtt_negotiate() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = resumed_server(&client);
+ connect(&mut client, &mut server);
+ assert!(client.tls_info().unwrap().early_data_accepted());
+ assert!(server.tls_info().unwrap().early_data_accepted());
+}
+
+#[test]
+fn zero_rtt_send_recv() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = resumed_server(&client);
+
+ // Send ClientHello.
+ let client_hs = client.process(None, now());
+ assert!(client_hs.as_dgram_ref().is_some());
+
+ // Now send a 0-RTT packet.
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[1, 2, 3]).unwrap();
+ let client_0rtt = client.process(None, now());
+ assert!(client_0rtt.as_dgram_ref().is_some());
+ // 0-RTT packets on their own shouldn't be padded to 1200.
+ assert!(client_0rtt.as_dgram_ref().unwrap().len() < 1200);
+
+ let server_hs = server.process(client_hs.as_dgram_ref(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc...
+
+ let all_frames = server.stats().frame_tx.all;
+ let ack_frames = server.stats().frame_tx.ack;
+ let server_process_0rtt = server.process(client_0rtt.as_dgram_ref(), now());
+ assert!(server_process_0rtt.as_dgram_ref().is_some());
+ assert_eq!(server.stats().frame_tx.all, all_frames + 1);
+ assert_eq!(server.stats().frame_tx.ack, ack_frames + 1);
+
+ let server_stream_id = server
+ .events()
+ .find_map(|evt| match evt {
+ ConnectionEvent::NewStream { stream_id, .. } => Some(stream_id),
+ _ => None,
+ })
+ .expect("should have received a new stream event");
+ assert_eq!(client_stream_id, server_stream_id.as_u64());
+}
+
+#[test]
+fn zero_rtt_send_coalesce() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = resumed_server(&client);
+
+ // Write 0-RTT before generating any packets.
+ // This should result in a datagram that coalesces Initial and 0-RTT.
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[1, 2, 3]).unwrap();
+ let client_0rtt = client.process(None, now());
+ assert!(client_0rtt.as_dgram_ref().is_some());
+
+ assertions::assert_coalesced_0rtt(&client_0rtt.as_dgram_ref().unwrap()[..]);
+
+ let server_hs = server.process(client_0rtt.as_dgram_ref(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc...
+
+ let server_stream_id = server
+ .events()
+ .find_map(|evt| match evt {
+ ConnectionEvent::NewStream { stream_id } => Some(stream_id),
+ _ => None,
+ })
+ .expect("should have received a new stream event");
+ assert_eq!(client_stream_id, server_stream_id.as_u64());
+}
+
+#[test]
+fn zero_rtt_before_resumption_token() {
+ let mut client = default_client();
+ assert!(client.stream_create(StreamType::BiDi).is_err());
+}
+
+#[test]
+fn zero_rtt_send_reject() {
+ const MESSAGE: &[u8] = &[1, 2, 3];
+
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ ConnectionParameters::default().versions(client.version(), Version::all()),
+ )
+ .unwrap();
+ // Using a freshly initialized anti-replay context
+ // should result in the server rejecting 0-RTT.
+ let ar =
+ AntiReplay::new(now(), test_fixture::ANTI_REPLAY_WINDOW, 1, 3).expect("setup anti-replay");
+ server
+ .server_enable_0rtt(&ar, AllowZeroRtt {})
+ .expect("enable 0-RTT");
+
+ // Send ClientHello.
+ let client_hs = client.process(None, now());
+ assert!(client_hs.as_dgram_ref().is_some());
+
+ // Write some data on the client.
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(stream_id, MESSAGE).unwrap();
+ let client_0rtt = client.process(None, now());
+ assert!(client_0rtt.as_dgram_ref().is_some());
+
+ let server_hs = server.process(client_hs.as_dgram_ref(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc...
+ let server_ignored = server.process(client_0rtt.as_dgram_ref(), now());
+ assert!(server_ignored.as_dgram_ref().is_none());
+
+ // The server shouldn't receive that 0-RTT data.
+ let recvd_stream_evt = |e| matches!(e, ConnectionEvent::NewStream { .. });
+ assert!(!server.events().any(recvd_stream_evt));
+
+ // Client should get a rejection.
+ let client_fin = client.process(server_hs.as_dgram_ref(), now());
+ let recvd_0rtt_reject = |e| e == ConnectionEvent::ZeroRttRejected;
+ assert!(client.events().any(recvd_0rtt_reject));
+
+ // Server consume client_fin
+ let server_ack = server.process(client_fin.as_dgram_ref(), now());
+ assert!(server_ack.as_dgram_ref().is_some());
+ let client_out = client.process(server_ack.as_dgram_ref(), now());
+ assert!(client_out.as_dgram_ref().is_none());
+
+ // ...and the client stream should be gone.
+ let res = client.stream_send(stream_id, MESSAGE);
+ assert!(res.is_err());
+ assert_eq!(res.unwrap_err(), Error::InvalidStreamId);
+
+ // Open a new stream and send data. StreamId should start with 0.
+ let stream_id_after_reject = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(stream_id, stream_id_after_reject);
+ client.stream_send(stream_id_after_reject, MESSAGE).unwrap();
+ let client_after_reject = client.process(None, now()).dgram();
+ assert!(client_after_reject.is_some());
+
+ // The server should receive new stream
+ server.process_input(&client_after_reject.unwrap(), now());
+ assert!(server.events().any(recvd_stream_evt));
+}
+
+#[test]
+fn zero_rtt_update_flow_control() {
+ const LOW: u64 = 3;
+ const HIGH: u64 = 10;
+ #[allow(clippy::cast_possible_truncation)]
+ const MESSAGE: &[u8] = &[0; HIGH as usize];
+
+ let mut client = default_client();
+ let mut server = new_server(
+ ConnectionParameters::default()
+ .max_stream_data(StreamType::UniDi, true, LOW)
+ .max_stream_data(StreamType::BiDi, true, LOW),
+ );
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = new_server(
+ ConnectionParameters::default()
+ .max_stream_data(StreamType::UniDi, true, HIGH)
+ .max_stream_data(StreamType::BiDi, true, HIGH)
+ .versions(client.version, Version::all()),
+ );
+
+ // Stream limits should be low for 0-RTT.
+ let client_hs = client.process(None, now()).dgram();
+ let uni_stream = client.stream_create(StreamType::UniDi).unwrap();
+ assert!(!client.stream_send_atomic(uni_stream, MESSAGE).unwrap());
+ let bidi_stream = client.stream_create(StreamType::BiDi).unwrap();
+ assert!(!client.stream_send_atomic(bidi_stream, MESSAGE).unwrap());
+
+ // Now get the server transport parameters.
+ let server_hs = server.process(client_hs.as_ref(), now()).dgram();
+ client.process_input(&server_hs.unwrap(), now());
+
+ // The streams should report a writeable event.
+ let mut uni_stream_event = false;
+ let mut bidi_stream_event = false;
+ for e in client.events() {
+ if let ConnectionEvent::SendStreamWritable { stream_id } = e {
+ if stream_id.is_uni() {
+ uni_stream_event = true;
+ } else {
+ bidi_stream_event = true;
+ }
+ }
+ }
+ assert!(uni_stream_event);
+ assert!(bidi_stream_event);
+ // But no MAX_STREAM_DATA frame was received.
+ assert_eq!(client.stats().frame_rx.max_stream_data, 0);
+
+ // And the new limit applies.
+ assert!(client.stream_send_atomic(uni_stream, MESSAGE).unwrap());
+ assert!(client.stream_send_atomic(bidi_stream, MESSAGE).unwrap());
+}
diff --git a/third_party/rust/neqo-transport/src/crypto.rs b/third_party/rust/neqo-transport/src/crypto.rs
new file mode 100644
index 0000000000..f6cc7c0e2f
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/crypto.rs
@@ -0,0 +1,1583 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::{
+ cell::RefCell,
+ cmp::{max, min},
+ collections::HashMap,
+ convert::TryFrom,
+ mem,
+ ops::{Index, IndexMut, Range},
+ rc::Rc,
+ time::Instant,
+};
+
+use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role};
+use neqo_crypto::{
+ hkdf, hp::HpKey, Aead, Agent, AntiReplay, Cipher, Epoch, Error as CryptoError, HandshakeState,
+ PrivateKey, PublicKey, Record, RecordList, ResumptionToken, SymKey, ZeroRttChecker,
+ TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE,
+ TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT,
+ TLS_GRP_EC_SECP256R1, TLS_GRP_EC_SECP384R1, TLS_GRP_EC_SECP521R1, TLS_GRP_EC_X25519,
+ TLS_VERSION_1_3,
+};
+
+use crate::{
+ cid::ConnectionIdRef,
+ packet::{PacketBuilder, PacketNumber},
+ recovery::RecoveryToken,
+ recv_stream::RxStreamOrderer,
+ send_stream::TxBuffer,
+ stats::FrameStats,
+ tparams::{TpZeroRttChecker, TransportParameters, TransportParametersHandler},
+ tracking::PacketNumberSpace,
+ version::Version,
+ Error, Res,
+};
+
+const MAX_AUTH_TAG: usize = 32;
+/// The number of invocations remaining on a write cipher before we try
+/// to update keys. This has to be much smaller than the number returned
+/// by `CryptoDxState::limit` or updates will happen too often. As we don't
+/// need to ask permission to update, this can be quite small.
+pub(crate) const UPDATE_WRITE_KEYS_AT: PacketNumber = 100;
+
+// This is a testing kludge that allows for overwriting the number of
+// invocations of the next cipher to operate. With this, it is possible
+// to test what happens when the number of invocations reaches 0, or
+// when it hits `UPDATE_WRITE_KEYS_AT` and an automatic update should occur.
+// This is a little crude, but it saves a lot of plumbing.
+#[cfg(test)]
+thread_local!(pub(crate) static OVERWRITE_INVOCATIONS: RefCell<Option<PacketNumber>> = RefCell::default());
+
+#[derive(Debug)]
+pub struct Crypto {
+ version: Version,
+ protocols: Vec<String>,
+ pub(crate) tls: Agent,
+ pub(crate) streams: CryptoStreams,
+ pub(crate) states: CryptoStates,
+}
+
+type TpHandler = Rc<RefCell<TransportParametersHandler>>;
+
+impl Crypto {
+ pub fn new(
+ version: Version,
+ mut agent: Agent,
+ protocols: Vec<String>,
+ tphandler: TpHandler,
+ fuzzing: bool,
+ ) -> Res<Self> {
+ agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
+ agent.set_ciphers(&[
+ TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384,
+ TLS_CHACHA20_POLY1305_SHA256,
+ ])?;
+ agent.set_groups(&[
+ TLS_GRP_EC_X25519,
+ TLS_GRP_EC_SECP256R1,
+ TLS_GRP_EC_SECP384R1,
+ TLS_GRP_EC_SECP521R1,
+ ])?;
+ agent.send_additional_key_shares(1)?;
+ agent.set_alpn(&protocols)?;
+ agent.disable_end_of_early_data()?;
+ // Always enable 0-RTT on the client, but the server needs
+ // more configuration passed to server_enable_0rtt.
+ if let Agent::Client(c) = &mut agent {
+ c.enable_0rtt()?;
+ }
+ let extension = match version {
+ Version::Version2 | Version::Version1 => 0x39,
+ Version::Draft29 | Version::Draft30 | Version::Draft31 | Version::Draft32 => 0xffa5,
+ };
+ agent.extension_handler(extension, tphandler)?;
+ Ok(Self {
+ version,
+ protocols,
+ tls: agent,
+ streams: Default::default(),
+ states: CryptoStates {
+ fuzzing,
+ ..Default::default()
+ },
+ })
+ }
+
+ /// Get the name of the server. (Only works for the client currently).
+ pub fn server_name(&self) -> Option<&str> {
+ if let Agent::Client(c) = &self.tls {
+ Some(c.server_name())
+ } else {
+ None
+ }
+ }
+
+ /// Get the set of enabled protocols.
+ pub fn protocols(&self) -> &[String] {
+ &self.protocols
+ }
+
+ pub fn server_enable_0rtt(
+ &mut self,
+ tphandler: TpHandler,
+ anti_replay: &AntiReplay,
+ zero_rtt_checker: impl ZeroRttChecker + 'static,
+ ) -> Res<()> {
+ if let Agent::Server(s) = &mut self.tls {
+ Ok(s.enable_0rtt(
+ anti_replay,
+ 0xffff_ffff,
+ TpZeroRttChecker::wrap(tphandler, zero_rtt_checker),
+ )?)
+ } else {
+ panic!("not a server");
+ }
+ }
+
+ pub fn server_enable_ech(
+ &mut self,
+ config: u8,
+ public_name: &str,
+ sk: &PrivateKey,
+ pk: &PublicKey,
+ ) -> Res<()> {
+ if let Agent::Server(s) = &mut self.tls {
+ s.enable_ech(config, public_name, sk, pk)?;
+ Ok(())
+ } else {
+ panic!("not a client");
+ }
+ }
+
+ pub fn client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> {
+ if let Agent::Client(c) = &mut self.tls {
+ c.enable_ech(ech_config_list)?;
+ Ok(())
+ } else {
+ panic!("not a client");
+ }
+ }
+
+ /// Get the active ECH configuration, which is empty if ECH is disabled.
+ pub fn ech_config(&self) -> &[u8] {
+ self.tls.ech_config()
+ }
+
+ pub fn handshake(
+ &mut self,
+ now: Instant,
+ space: PacketNumberSpace,
+ data: Option<&[u8]>,
+ ) -> Res<&HandshakeState> {
+ let input = data.map(|d| {
+ qtrace!("Handshake record received {:0x?} ", d);
+ let epoch = match space {
+ PacketNumberSpace::Initial => TLS_EPOCH_INITIAL,
+ PacketNumberSpace::Handshake => TLS_EPOCH_HANDSHAKE,
+ // Our epoch progresses forward, but the TLS epoch is fixed to 3.
+ PacketNumberSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA,
+ };
+ Record {
+ ct: TLS_CT_HANDSHAKE,
+ epoch,
+ data: d.to_vec(),
+ }
+ });
+
+ match self.tls.handshake_raw(now, input) {
+ Ok(output) => {
+ self.buffer_records(output)?;
+ Ok(self.tls.state())
+ }
+ Err(CryptoError::EchRetry(v)) => Err(Error::EchRetry(v)),
+ Err(e) => {
+ qinfo!("Handshake failed {:?}", e);
+ Err(match self.tls.alert() {
+ Some(a) => Error::CryptoAlert(*a),
+ _ => Error::CryptoError(e),
+ })
+ }
+ }
+ }
+
+ /// Enable 0-RTT and return `true` if it is enabled successfully.
+ pub fn enable_0rtt(&mut self, version: Version, role: Role) -> Res<bool> {
+ let info = self.tls.preinfo()?;
+ // `info.early_data()` returns false for a server,
+ // so use `early_data_cipher()` to tell if 0-RTT is enabled.
+ let cipher = info.early_data_cipher();
+ if cipher.is_none() {
+ return Ok(false);
+ }
+ let (dir, secret) = match role {
+ Role::Client => (
+ CryptoDxDirection::Write,
+ self.tls.write_secret(TLS_EPOCH_ZERO_RTT),
+ ),
+ Role::Server => (
+ CryptoDxDirection::Read,
+ self.tls.read_secret(TLS_EPOCH_ZERO_RTT),
+ ),
+ };
+ let secret = secret.ok_or(Error::InternalError)?;
+ self.states
+ .set_0rtt_keys(version, dir, &secret, cipher.unwrap());
+ Ok(true)
+ }
+
+ /// Lock in a compatible upgrade.
+ pub fn confirm_version(&mut self, confirmed: Version) {
+ self.states.confirm_version(self.version, confirmed);
+ self.version = confirmed;
+ }
+
+ /// Returns true if new handshake keys were installed.
+ pub fn install_keys(&mut self, role: Role) -> Res<bool> {
+ if !self.tls.state().is_final() {
+ let installed_hs = self.install_handshake_keys()?;
+ if role == Role::Server {
+ self.maybe_install_application_write_key(self.version)?;
+ }
+ Ok(installed_hs)
+ } else {
+ Ok(false)
+ }
+ }
+
+ fn install_handshake_keys(&mut self) -> Res<bool> {
+ qtrace!([self], "Attempt to install handshake keys");
+ let Some(write_secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) else {
+ // No keys is fine.
+ return Ok(false);
+ };
+ let read_secret = self
+ .tls
+ .read_secret(TLS_EPOCH_HANDSHAKE)
+ .ok_or(Error::InternalError)?;
+ let cipher = match self.tls.info() {
+ None => self.tls.preinfo()?.cipher_suite(),
+ Some(info) => Some(info.cipher_suite()),
+ }
+ .ok_or(Error::InternalError)?;
+ self.states
+ .set_handshake_keys(self.version, &write_secret, &read_secret, cipher);
+ qdebug!([self], "Handshake keys installed");
+ Ok(true)
+ }
+
+ fn maybe_install_application_write_key(&mut self, version: Version) -> Res<()> {
+ qtrace!([self], "Attempt to install application write key");
+ if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) {
+ self.states.set_application_write_key(version, secret)?;
+ qdebug!([self], "Application write key installed");
+ }
+ Ok(())
+ }
+
+ pub fn install_application_keys(&mut self, version: Version, expire_0rtt: Instant) -> Res<()> {
+ self.maybe_install_application_write_key(version)?;
+ // The write key might have been installed earlier, but it should
+ // always be installed now.
+ debug_assert!(self.states.app_write.is_some());
+ let read_secret = self
+ .tls
+ .read_secret(TLS_EPOCH_APPLICATION_DATA)
+ .ok_or(Error::InternalError)?;
+ self.states
+ .set_application_read_key(version, read_secret, expire_0rtt)?;
+ qdebug!([self], "application read keys installed");
+ Ok(())
+ }
+
+ /// Buffer crypto records for sending.
+ pub fn buffer_records(&mut self, records: RecordList) -> Res<()> {
+ for r in records {
+ if r.ct != TLS_CT_HANDSHAKE {
+ return Err(Error::ProtocolViolation);
+ }
+ qtrace!([self], "Adding CRYPTO data {:?}", r);
+ self.streams.send(PacketNumberSpace::from(r.epoch), &r.data);
+ }
+ Ok(())
+ }
+
+ pub fn write_frame(
+ &mut self,
+ space: PacketNumberSpace,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> Res<()> {
+ self.streams.write_frame(space, builder, tokens, stats)
+ }
+
+ pub fn acked(&mut self, token: &CryptoRecoveryToken) {
+ qinfo!(
+ "Acked crypto frame space={} offset={} length={}",
+ token.space,
+ token.offset,
+ token.length
+ );
+ self.streams.acked(token);
+ }
+
+ pub fn lost(&mut self, token: &CryptoRecoveryToken) {
+ qinfo!(
+ "Lost crypto frame space={} offset={} length={}",
+ token.space,
+ token.offset,
+ token.length
+ );
+ self.streams.lost(token);
+ }
+
+ /// Mark any outstanding frames in the indicated space as "lost" so
+ /// that they can be sent again.
+ pub fn resend_unacked(&mut self, space: PacketNumberSpace) {
+ self.streams.resend_unacked(space);
+ }
+
+ /// Discard state for a packet number space and return true
+ /// if something was discarded.
+ pub fn discard(&mut self, space: PacketNumberSpace) -> bool {
+ self.streams.discard(space);
+ self.states.discard(space)
+ }
+
+ pub fn create_resumption_token(
+ &mut self,
+ new_token: Option<&[u8]>,
+ tps: &TransportParameters,
+ version: Version,
+ rtt: u64,
+ ) -> Option<ResumptionToken> {
+ if let Agent::Client(ref mut c) = self.tls {
+ if let Some(ref t) = c.resumption_token() {
+ qtrace!("TLS token {}", hex(t.as_ref()));
+ let mut enc = Encoder::default();
+ enc.encode_uint(4, version.wire_version());
+ enc.encode_varint(rtt);
+ enc.encode_vvec_with(|enc_inner| {
+ tps.encode(enc_inner);
+ });
+ enc.encode_vvec(new_token.unwrap_or(&[]));
+ enc.encode(t.as_ref());
+ qinfo!("resumption token {}", hex_snip_middle(enc.as_ref()));
+ Some(ResumptionToken::new(enc.into(), t.expiration_time()))
+ } else {
+ None
+ }
+ } else {
+ unreachable!("It is a server.");
+ }
+ }
+
+ pub fn has_resumption_token(&self) -> bool {
+ if let Agent::Client(c) = &self.tls {
+ c.has_resumption_token()
+ } else {
+ unreachable!("It is a server.");
+ }
+ }
+}
+
+impl ::std::fmt::Display for Crypto {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Crypto")
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum CryptoDxDirection {
+ Read,
+ Write,
+}
+
+#[derive(Debug)]
+pub struct CryptoDxState {
+ /// The QUIC version.
+ version: Version,
+ /// Whether packets protected with this state will be read or written.
+ direction: CryptoDxDirection,
+ /// The epoch of this crypto state. This initially tracks TLS epochs
+ /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application.
+ /// But we don't need to keep that, and QUIC isn't limited in how
+ /// many times keys can be updated, so we don't use `u16` for this.
+ epoch: usize,
+ aead: Aead,
+ hpkey: HpKey,
+ /// This tracks the range of packet numbers that have been seen. This allows
+ /// for verifying that packet numbers before a key update are strictly lower
+ /// than packet numbers after a key update.
+ used_pn: Range<PacketNumber>,
+ /// This is the minimum packet number that is allowed.
+ min_pn: PacketNumber,
+ /// The total number of operations that are remaining before the keys
+ /// become exhausted and can't be used any more.
+ invocations: PacketNumber,
+ fuzzing: bool,
+}
+
+impl CryptoDxState {
+ #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range.
+ pub fn new(
+ version: Version,
+ direction: CryptoDxDirection,
+ epoch: Epoch,
+ secret: &SymKey,
+ cipher: Cipher,
+ fuzzing: bool,
+ ) -> Self {
+ qinfo!(
+ "Making {:?} {} CryptoDxState, v={:?} cipher={}",
+ direction,
+ epoch,
+ version,
+ cipher,
+ );
+ let hplabel = String::from(version.label_prefix()) + "hp";
+ Self {
+ version,
+ direction,
+ epoch: usize::from(epoch),
+ aead: Aead::new(
+ fuzzing,
+ TLS_VERSION_1_3,
+ cipher,
+ secret,
+ version.label_prefix(),
+ )
+ .unwrap(),
+ hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, &hplabel).unwrap(),
+ used_pn: 0..0,
+ min_pn: 0,
+ invocations: Self::limit(direction, cipher),
+ fuzzing,
+ }
+ }
+
+ pub fn new_initial(
+ version: Version,
+ direction: CryptoDxDirection,
+ label: &str,
+ dcid: &[u8],
+ fuzzing: bool,
+ ) -> Self {
+ qtrace!("new_initial {:?} {}", version, ConnectionIdRef::from(dcid));
+ let salt = version.initial_salt();
+ let cipher = TLS_AES_128_GCM_SHA256;
+ let initial_secret = hkdf::extract(
+ TLS_VERSION_1_3,
+ cipher,
+ Some(hkdf::import_key(TLS_VERSION_1_3, salt).as_ref().unwrap()),
+ hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(),
+ )
+ .unwrap();
+
+ let secret =
+ hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap();
+
+ Self::new(
+ version,
+ direction,
+ TLS_EPOCH_INITIAL,
+ &secret,
+ cipher,
+ fuzzing,
+ )
+ }
+
+ /// Determine the confidentiality and integrity limits for the cipher.
+ fn limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber {
+ match direction {
+ // This uses the smaller limits for 2^16 byte packets
+ // as we don't control incoming packet size.
+ CryptoDxDirection::Read => match cipher {
+ TLS_AES_128_GCM_SHA256 => 1 << 52,
+ TLS_AES_256_GCM_SHA384 => PacketNumber::MAX,
+ TLS_CHACHA20_POLY1305_SHA256 => 1 << 36,
+ _ => unreachable!(),
+ },
+ // This uses the larger limits for 2^11 byte packets.
+ CryptoDxDirection::Write => match cipher {
+ TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => 1 << 28,
+ TLS_CHACHA20_POLY1305_SHA256 => PacketNumber::MAX,
+ _ => unreachable!(),
+ },
+ }
+ }
+
+ fn invoked(&mut self) -> Res<()> {
+ #[cfg(test)]
+ OVERWRITE_INVOCATIONS.with(|v| {
+ if let Some(i) = v.borrow_mut().take() {
+ neqo_common::qwarn!("Setting {:?} invocations to {}", self.direction, i);
+ self.invocations = i;
+ }
+ });
+ self.invocations = self
+ .invocations
+ .checked_sub(1)
+ .ok_or(Error::KeysExhausted)?;
+ Ok(())
+ }
+
+ /// Determine whether we should initiate a key update.
+ pub fn should_update(&self) -> bool {
+ // There is no point in updating read keys as the limit is global.
+ debug_assert_eq!(self.direction, CryptoDxDirection::Write);
+ self.invocations <= UPDATE_WRITE_KEYS_AT
+ }
+
+ pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self {
+ let pn = self.next_pn();
+ // We count invocations of each write key just for that key, but all
+ // attempts to invocations to read count toward a single limit.
+ // This doesn't count use of Handshake keys.
+ let invocations = if self.direction == CryptoDxDirection::Read {
+ self.invocations
+ } else {
+ Self::limit(CryptoDxDirection::Write, cipher)
+ };
+ Self {
+ version: self.version,
+ direction: self.direction,
+ epoch: self.epoch + 1,
+ aead: Aead::new(
+ self.fuzzing,
+ TLS_VERSION_1_3,
+ cipher,
+ next_secret,
+ self.version.label_prefix(),
+ )
+ .unwrap(),
+ hpkey: self.hpkey.clone(),
+ used_pn: pn..pn,
+ min_pn: pn,
+ invocations,
+ fuzzing: self.fuzzing,
+ }
+ }
+
+ #[must_use]
+ pub fn version(&self) -> Version {
+ self.version
+ }
+
+ #[must_use]
+ pub fn key_phase(&self) -> bool {
+ // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ...
+ self.epoch & 1 != 1
+ }
+
+ /// This is a continuation of a previous, so adjust the range accordingly.
+ /// Fail if the two ranges overlap. Do nothing if the directions don't match.
+ pub fn continuation(&mut self, prev: &Self) -> Res<()> {
+ debug_assert_eq!(self.direction, prev.direction);
+ let next = prev.next_pn();
+ self.min_pn = next;
+ if self.used_pn.is_empty() {
+ self.used_pn = next..next;
+ Ok(())
+ } else if prev.used_pn.end > self.used_pn.start {
+ qdebug!(
+ [self],
+ "Found packet with too new packet number {} > {}, compared to {}",
+ self.used_pn.start,
+ prev.used_pn.end,
+ prev,
+ );
+ Err(Error::PacketNumberOverlap)
+ } else {
+ self.used_pn.start = next;
+ Ok(())
+ }
+ }
+
+ /// Mark a packet number as used. If this is too low, reject it.
+ /// Note that this won't catch a value that is too high if packets protected with
+ /// old keys are received after a key update. That needs to be caught elsewhere.
+ pub fn used(&mut self, pn: PacketNumber) -> Res<()> {
+ if pn < self.min_pn {
+ qdebug!(
+ [self],
+ "Found packet with too old packet number: {} < {}",
+ pn,
+ self.min_pn
+ );
+ return Err(Error::PacketNumberOverlap);
+ }
+ if self.used_pn.start == self.used_pn.end {
+ self.used_pn.start = pn;
+ }
+ self.used_pn.end = max(pn + 1, self.used_pn.end);
+ Ok(())
+ }
+
+ #[must_use]
+ pub fn needs_update(&self) -> bool {
+ // Only initiate a key update if we have processed exactly one packet
+ // and we are in an epoch greater than 3.
+ self.used_pn.start + 1 == self.used_pn.end
+ && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA)
+ }
+
+ #[must_use]
+ pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool {
+ if let Some(la) = largest_acknowledged {
+ self.used_pn.contains(&la)
+ } else {
+ // If we haven't received any acknowledgments, it's OK to update
+ // the first application data epoch.
+ self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA)
+ }
+ }
+
+ pub fn compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>> {
+ let mask = self.hpkey.mask(sample)?;
+ qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask));
+ Ok(mask)
+ }
+
+ #[must_use]
+ pub fn next_pn(&self) -> PacketNumber {
+ self.used_pn.end
+ }
+
+ pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
+ debug_assert_eq!(self.direction, CryptoDxDirection::Write);
+ qtrace!(
+ [self],
+ "encrypt pn={} hdr={} body={}",
+ pn,
+ hex(hdr),
+ hex(body)
+ );
+ // The numbers in `Self::limit` assume a maximum packet size of 2^11.
+ if body.len() > 2048 {
+ debug_assert!(false);
+ return Err(Error::InternalError);
+ }
+ self.invoked()?;
+
+ let size = body.len() + MAX_AUTH_TAG;
+ let mut out = vec![0; size];
+ let res = self.aead.encrypt(pn, hdr, body, &mut out)?;
+
+ qtrace!([self], "encrypt ct={}", hex(res));
+ debug_assert_eq!(pn, self.next_pn());
+ self.used(pn)?;
+ Ok(res.to_vec())
+ }
+
+ #[must_use]
+ pub fn expansion(&self) -> usize {
+ self.aead.expansion()
+ }
+
+ pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
+ debug_assert_eq!(self.direction, CryptoDxDirection::Read);
+ qtrace!(
+ [self],
+ "decrypt pn={} hdr={} body={}",
+ pn,
+ hex(hdr),
+ hex(body)
+ );
+ self.invoked()?;
+ let mut out = vec![0; body.len()];
+ let res = self.aead.decrypt(pn, hdr, body, &mut out)?;
+ self.used(pn)?;
+ Ok(res.to_vec())
+ }
+
+ #[cfg(all(test, not(feature = "fuzzing")))]
+ pub(crate) fn test_default() -> Self {
+ // This matches the value in packet.rs
+ const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
+ Self::new_initial(
+ Version::default(),
+ CryptoDxDirection::Write,
+ "server in",
+ CLIENT_CID,
+ false,
+ )
+ }
+
+ /// Get the amount of extra padding packets protected with this profile need.
+ /// This is the difference between the size of the header protection sample
+ /// and the AEAD expansion.
+ pub fn extra_padding(&self) -> usize {
+ self.hpkey
+ .sample_size()
+ .saturating_sub(self.aead.expansion())
+ }
+}
+
+impl std::fmt::Display for CryptoDxState {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "epoch {} {:?}", self.epoch, self.direction)
+ }
+}
+
+#[derive(Debug)]
+pub struct CryptoState {
+ tx: CryptoDxState,
+ rx: CryptoDxState,
+}
+
+impl Index<CryptoDxDirection> for CryptoState {
+ type Output = CryptoDxState;
+
+ fn index(&self, dir: CryptoDxDirection) -> &Self::Output {
+ match dir {
+ CryptoDxDirection::Read => &self.rx,
+ CryptoDxDirection::Write => &self.tx,
+ }
+ }
+}
+
+impl IndexMut<CryptoDxDirection> for CryptoState {
+ fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output {
+ match dir {
+ CryptoDxDirection::Read => &mut self.rx,
+ CryptoDxDirection::Write => &mut self.tx,
+ }
+ }
+}
+
+/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys.
+/// This includes the secret needed to generate the next set of keys.
+#[derive(Debug)]
+pub(crate) struct CryptoDxAppData {
+ dx: CryptoDxState,
+ cipher: Cipher,
+ // Not the secret used to create `self.dx`, but the one needed for the next iteration.
+ next_secret: SymKey,
+ fuzzing: bool,
+}
+
+impl CryptoDxAppData {
+ pub fn new(
+ version: Version,
+ dir: CryptoDxDirection,
+ secret: SymKey,
+ cipher: Cipher,
+ fuzzing: bool,
+ ) -> Res<Self> {
+ Ok(Self {
+ dx: CryptoDxState::new(
+ version,
+ dir,
+ TLS_EPOCH_APPLICATION_DATA,
+ &secret,
+ cipher,
+ fuzzing,
+ ),
+ cipher,
+ next_secret: Self::update_secret(cipher, &secret)?,
+ fuzzing,
+ })
+ }
+
+ fn update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey> {
+ let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?;
+ Ok(next)
+ }
+
+ pub fn next(&self) -> Res<Self> {
+ if self.dx.epoch == usize::max_value() {
+ // Guard against too many key updates.
+ return Err(Error::KeysExhausted);
+ }
+ let next_secret = Self::update_secret(self.cipher, &self.next_secret)?;
+ Ok(Self {
+ dx: self.dx.next(&self.next_secret, self.cipher),
+ cipher: self.cipher,
+ next_secret,
+ fuzzing: self.fuzzing,
+ })
+ }
+
+ pub fn epoch(&self) -> usize {
+ self.dx.epoch
+ }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub enum CryptoSpace {
+ Initial,
+ ZeroRtt,
+ Handshake,
+ ApplicationData,
+}
+
+/// All of the keying material needed for a connection.
+///
+/// Note that the methods on this struct take a version but those are only ever
+/// used for Initial keys; a version has been selected at the time we need to
+/// get other keys, so those have fixed versions.
+#[derive(Debug, Default)]
+pub struct CryptoStates {
+ initials: HashMap<Version, CryptoState>,
+ handshake: Option<CryptoState>,
+ zero_rtt: Option<CryptoDxState>, // One direction only!
+ cipher: Cipher,
+ app_write: Option<CryptoDxAppData>,
+ app_read: Option<CryptoDxAppData>,
+ app_read_next: Option<CryptoDxAppData>,
+ // If this is set, then we have noticed a genuine update.
+ // Once this time passes, we should switch in new keys.
+ read_update_time: Option<Instant>,
+ fuzzing: bool,
+}
+
+impl CryptoStates {
+ /// Select a `CryptoDxState` and `CryptoSpace` for the given `PacketNumberSpace`.
+ /// This selects 0-RTT keys for `PacketNumberSpace::ApplicationData` if 1-RTT keys are
+ /// not yet available.
+ pub fn select_tx_mut(
+ &mut self,
+ version: Version,
+ space: PacketNumberSpace,
+ ) -> Option<(CryptoSpace, &mut CryptoDxState)> {
+ match space {
+ PacketNumberSpace::Initial => self
+ .tx_mut(version, CryptoSpace::Initial)
+ .map(|dx| (CryptoSpace::Initial, dx)),
+ PacketNumberSpace::Handshake => self
+ .tx_mut(version, CryptoSpace::Handshake)
+ .map(|dx| (CryptoSpace::Handshake, dx)),
+ PacketNumberSpace::ApplicationData => {
+ if let Some(app) = self.app_write.as_mut() {
+ Some((CryptoSpace::ApplicationData, &mut app.dx))
+ } else {
+ self.zero_rtt.as_mut().map(|dx| (CryptoSpace::ZeroRtt, dx))
+ }
+ }
+ }
+ }
+
+ pub fn tx_mut<'a>(
+ &'a mut self,
+ version: Version,
+ cspace: CryptoSpace,
+ ) -> Option<&'a mut CryptoDxState> {
+ let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx);
+ match cspace {
+ CryptoSpace::Initial => tx(self.initials.get_mut(&version)),
+ CryptoSpace::ZeroRtt => self
+ .zero_rtt
+ .as_mut()
+ .filter(|z| z.direction == CryptoDxDirection::Write),
+ CryptoSpace::Handshake => tx(self.handshake.as_mut()),
+ CryptoSpace::ApplicationData => self.app_write.as_mut().map(|app| &mut app.dx),
+ }
+ }
+
+ pub fn tx<'a>(&'a self, version: Version, cspace: CryptoSpace) -> Option<&'a CryptoDxState> {
+ let tx = |k: Option<&'a CryptoState>| k.map(|dx| &dx.tx);
+ match cspace {
+ CryptoSpace::Initial => tx(self.initials.get(&version)),
+ CryptoSpace::ZeroRtt => self
+ .zero_rtt
+ .as_ref()
+ .filter(|z| z.direction == CryptoDxDirection::Write),
+ CryptoSpace::Handshake => tx(self.handshake.as_ref()),
+ CryptoSpace::ApplicationData => self.app_write.as_ref().map(|app| &app.dx),
+ }
+ }
+
+ pub fn select_tx(
+ &self,
+ version: Version,
+ space: PacketNumberSpace,
+ ) -> Option<(CryptoSpace, &CryptoDxState)> {
+ match space {
+ PacketNumberSpace::Initial => self
+ .tx(version, CryptoSpace::Initial)
+ .map(|dx| (CryptoSpace::Initial, dx)),
+ PacketNumberSpace::Handshake => self
+ .tx(version, CryptoSpace::Handshake)
+ .map(|dx| (CryptoSpace::Handshake, dx)),
+ PacketNumberSpace::ApplicationData => {
+ if let Some(app) = self.app_write.as_ref() {
+ Some((CryptoSpace::ApplicationData, &app.dx))
+ } else {
+ self.zero_rtt.as_ref().map(|dx| (CryptoSpace::ZeroRtt, dx))
+ }
+ }
+ }
+ }
+
+ pub fn rx_hp(&mut self, version: Version, cspace: CryptoSpace) -> Option<&mut CryptoDxState> {
+ if let CryptoSpace::ApplicationData = cspace {
+ self.app_read.as_mut().map(|ar| &mut ar.dx)
+ } else {
+ self.rx(version, cspace, false)
+ }
+ }
+
+ pub fn rx<'a>(
+ &'a mut self,
+ version: Version,
+ cspace: CryptoSpace,
+ key_phase: bool,
+ ) -> Option<&'a mut CryptoDxState> {
+ let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx);
+ match cspace {
+ CryptoSpace::Initial => rx(self.initials.get_mut(&version)),
+ CryptoSpace::ZeroRtt => self
+ .zero_rtt
+ .as_mut()
+ .filter(|z| z.direction == CryptoDxDirection::Read),
+ CryptoSpace::Handshake => rx(self.handshake.as_mut()),
+ CryptoSpace::ApplicationData => {
+ let f = |a: Option<&'a mut CryptoDxAppData>| {
+ a.filter(|ar| ar.dx.key_phase() == key_phase)
+ };
+ // XOR to reduce the leakage about which key is chosen.
+ f(self.app_read.as_mut())
+ .xor(f(self.app_read_next.as_mut()))
+ .map(|ar| &mut ar.dx)
+ }
+ }
+ }
+
+ /// Whether keys for processing packets in the indicated space are pending.
+ /// This allows the caller to determine whether to save a packet for later
+ /// when keys are not available.
+ /// NOTE: 0-RTT keys are not considered here. The expectation is that a
+ /// server will have to save 0-RTT packets in a different place. Though it
+ /// is possible to attribute 0-RTT packets to an existing connection if there
+ /// is a multi-packet Initial, that is an unusual circumstance, so we
+ /// don't do caching for that in those places that call this function.
+ pub fn rx_pending(&self, space: CryptoSpace) -> bool {
+ match space {
+ CryptoSpace::Initial | CryptoSpace::ZeroRtt => false,
+ CryptoSpace::Handshake => self.handshake.is_none() && !self.initials.is_empty(),
+ CryptoSpace::ApplicationData => self.app_read.is_none(),
+ }
+ }
+
+ /// Create the initial crypto state.
+ /// Note that the version here can change and that's OK.
+ pub fn init<'v, V>(&mut self, versions: V, role: Role, dcid: &[u8])
+ where
+ V: IntoIterator<Item = &'v Version>,
+ {
+ const CLIENT_INITIAL_LABEL: &str = "client in";
+ const SERVER_INITIAL_LABEL: &str = "server in";
+
+ let (write, read) = match role {
+ Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL),
+ Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL),
+ };
+
+ for v in versions {
+ qinfo!(
+ [self],
+ "Creating initial cipher state v={:?}, role={:?} dcid={}",
+ v,
+ role,
+ hex(dcid)
+ );
+
+ let mut initial = CryptoState {
+ tx: CryptoDxState::new_initial(
+ *v,
+ CryptoDxDirection::Write,
+ write,
+ dcid,
+ self.fuzzing,
+ ),
+ rx: CryptoDxState::new_initial(
+ *v,
+ CryptoDxDirection::Read,
+ read,
+ dcid,
+ self.fuzzing,
+ ),
+ };
+ if let Some(prev) = self.initials.get(v) {
+ qinfo!(
+ [self],
+ "Continue packet numbers for initial after retry (write is {:?})",
+ prev.rx.used_pn,
+ );
+ initial.tx.continuation(&prev.tx).unwrap();
+ }
+ self.initials.insert(*v, initial);
+ }
+ }
+
+ /// At a server, we can be more targeted in initializing.
+ /// Initialize on demand: either to decrypt Initial packets that we receive
+ /// or after a version has been selected.
+ /// This is maybe slightly inefficient in the first case, because we might
+ /// not need the send keys if the packet is subsequently discarded, but
+ /// the overall effort is small enough to write off.
+ pub fn init_server(&mut self, version: Version, dcid: &[u8]) {
+ if !self.initials.contains_key(&version) {
+ self.init(&[version], Role::Server, dcid);
+ }
+ }
+
+ pub fn confirm_version(&mut self, orig: Version, confirmed: Version) {
+ if orig != confirmed {
+ // This part where the old data is removed and then re-added is to
+ // appease the borrow checker.
+ // Note that on the server, we might not have initials for |orig| if it
+ // was configured for |orig| and only |confirmed| Initial packets arrived.
+ if let Some(prev) = self.initials.remove(&orig) {
+ let next = self.initials.get_mut(&confirmed).unwrap();
+ next.tx.continuation(&prev.tx).unwrap();
+ self.initials.insert(orig, prev);
+ }
+ }
+ }
+
+ pub fn set_0rtt_keys(
+ &mut self,
+ version: Version,
+ dir: CryptoDxDirection,
+ secret: &SymKey,
+ cipher: Cipher,
+ ) {
+ qtrace!([self], "install 0-RTT keys");
+ self.zero_rtt = Some(CryptoDxState::new(
+ version,
+ dir,
+ TLS_EPOCH_ZERO_RTT,
+ secret,
+ cipher,
+ self.fuzzing,
+ ));
+ }
+
+ /// Discard keys and return true if that happened.
+ pub fn discard(&mut self, space: PacketNumberSpace) -> bool {
+ match space {
+ PacketNumberSpace::Initial => {
+ let empty = self.initials.is_empty();
+ self.initials.clear();
+ !empty
+ }
+ PacketNumberSpace::Handshake => self.handshake.take().is_some(),
+ PacketNumberSpace::ApplicationData => panic!("Can't drop application data keys"),
+ }
+ }
+
+ pub fn discard_0rtt_keys(&mut self) {
+ qtrace!([self], "discard 0-RTT keys");
+ assert!(
+ self.app_read.is_none(),
+ "Can't discard 0-RTT after setting application keys"
+ );
+ self.zero_rtt = None;
+ }
+
+ pub fn set_handshake_keys(
+ &mut self,
+ version: Version,
+ write_secret: &SymKey,
+ read_secret: &SymKey,
+ cipher: Cipher,
+ ) {
+ self.cipher = cipher;
+ self.handshake = Some(CryptoState {
+ tx: CryptoDxState::new(
+ version,
+ CryptoDxDirection::Write,
+ TLS_EPOCH_HANDSHAKE,
+ write_secret,
+ cipher,
+ self.fuzzing,
+ ),
+ rx: CryptoDxState::new(
+ version,
+ CryptoDxDirection::Read,
+ TLS_EPOCH_HANDSHAKE,
+ read_secret,
+ cipher,
+ self.fuzzing,
+ ),
+ });
+ }
+
+ pub fn set_application_write_key(&mut self, version: Version, secret: SymKey) -> Res<()> {
+ debug_assert!(self.app_write.is_none());
+ debug_assert_ne!(self.cipher, 0);
+ let mut app = CryptoDxAppData::new(
+ version,
+ CryptoDxDirection::Write,
+ secret,
+ self.cipher,
+ self.fuzzing,
+ )?;
+ if let Some(z) = &self.zero_rtt {
+ if z.direction == CryptoDxDirection::Write {
+ app.dx.continuation(z)?;
+ }
+ }
+ self.zero_rtt = None;
+ self.app_write = Some(app);
+ Ok(())
+ }
+
+ pub fn set_application_read_key(
+ &mut self,
+ version: Version,
+ secret: SymKey,
+ expire_0rtt: Instant,
+ ) -> Res<()> {
+ debug_assert!(self.app_write.is_some(), "should have write keys installed");
+ debug_assert!(self.app_read.is_none());
+ let mut app = CryptoDxAppData::new(
+ version,
+ CryptoDxDirection::Read,
+ secret,
+ self.cipher,
+ self.fuzzing,
+ )?;
+ if let Some(z) = &self.zero_rtt {
+ if z.direction == CryptoDxDirection::Read {
+ app.dx.continuation(z)?;
+ }
+ self.read_update_time = Some(expire_0rtt);
+ }
+ self.app_read_next = Some(app.next()?);
+ self.app_read = Some(app);
+ Ok(())
+ }
+
+ /// Update the write keys.
+ pub fn initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()> {
+ // Only update if we are able to. We can only do this if we have
+ // received an acknowledgement for a packet in the current phase.
+ // Also, skip this if we are waiting for read keys on the existing
+ // key update to be rolled over.
+ let write = &self.app_write.as_ref().unwrap().dx;
+ if write.can_update(largest_acknowledged) && self.read_update_time.is_none() {
+ // This call additionally checks that we don't advance to the next
+ // epoch while a key update is in progress.
+ if self.maybe_update_write()? {
+ Ok(())
+ } else {
+ qdebug!([self], "Write keys already updated");
+ Err(Error::KeyUpdateBlocked)
+ }
+ } else {
+ qdebug!([self], "Waiting for ACK or blocked on read key timer");
+ Err(Error::KeyUpdateBlocked)
+ }
+ }
+
+ /// Try to update, and return true if it happened.
+ fn maybe_update_write(&mut self) -> Res<bool> {
+ // Update write keys. But only do so if the write keys are not already
+ // ahead of the read keys. If we initiated the key update, the write keys
+ // will already be ahead.
+ debug_assert!(self.read_update_time.is_none());
+ let write = &self.app_write.as_ref().unwrap();
+ let read = &self.app_read.as_ref().unwrap();
+ if write.epoch() == read.epoch() {
+ qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1);
+ self.app_write = Some(write.next()?);
+ Ok(true)
+ } else {
+ Ok(false)
+ }
+ }
+
+ /// Check whether write keys are close to running out of invocations.
+ /// If that is close, update them if possible. Failing to update at
+ /// this stage is cause for a fatal error.
+ pub fn auto_update(&mut self) -> Res<()> {
+ if let Some(app_write) = self.app_write.as_ref() {
+ if app_write.dx.should_update() {
+ qinfo!([self], "Initiating automatic key update");
+ if !self.maybe_update_write()? {
+ return Err(Error::KeysExhausted);
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn has_0rtt_read(&self) -> bool {
+ self.zero_rtt
+ .as_ref()
+ .filter(|z| z.direction == CryptoDxDirection::Read)
+ .is_some()
+ }
+
+ /// Prepare to update read keys. This doesn't happen immediately as
+ /// we want to ensure that we can continue to receive any delayed
+ /// packets that use the old keys. So we just set a timer.
+ pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> {
+ qtrace!([self], "Key update received");
+ // If we received a key update, then we assume that the peer has
+ // acknowledged a packet we sent in this epoch. It's OK to do that
+ // because they aren't allowed to update without first having received
+ // something from us. If the ACK isn't in the packet that triggered this
+ // key update, it must be in some other packet they have sent.
+ _ = self.maybe_update_write()?;
+
+ // We shouldn't have 0-RTT keys at this point, but if we do, dump them.
+ debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read());
+ if self.has_0rtt_read() {
+ self.zero_rtt = None;
+ }
+ self.read_update_time = Some(expiration);
+ Ok(())
+ }
+
+ #[must_use]
+ pub fn update_time(&self) -> Option<Instant> {
+ self.read_update_time
+ }
+
+ /// Check if time has passed for updating key update parameters.
+ /// If it has, then swap keys over and allow more key updates to be initiated.
+ /// This is also used to discard 0-RTT read keys at the server in the same way.
+ pub fn check_key_update(&mut self, now: Instant) -> Res<()> {
+ if let Some(expiry) = self.read_update_time {
+ // If enough time has passed, then install new keys and clear the timer.
+ if now >= expiry {
+ if self.has_0rtt_read() {
+ qtrace!([self], "Discarding 0-RTT keys");
+ self.zero_rtt = None;
+ } else {
+ qtrace!([self], "Rotating read keys");
+ mem::swap(&mut self.app_read, &mut self.app_read_next);
+ self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?);
+ }
+ self.read_update_time = None;
+ }
+ }
+ Ok(())
+ }
+
+ /// Get the current/highest epoch. This returns (write, read) epochs.
+ #[cfg(test)]
+ pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
+ let to_epoch = |app: &Option<CryptoDxAppData>| app.as_ref().map(|a| a.dx.epoch);
+ (to_epoch(&self.app_write), to_epoch(&self.app_read))
+ }
+
+ /// While we are awaiting the completion of a key update, we might receive
+ /// valid packets that are protected with old keys. We need to ensure that
+ /// these don't carry packet numbers higher than those in packets protected
+ /// with the newer keys. To ensure that, this is called after every decryption.
+ pub fn check_pn_overlap(&mut self) -> Res<()> {
+ // We only need to do the check while we are waiting for read keys to be updated.
+ if self.read_update_time.is_some() {
+ qtrace!([self], "Checking for PN overlap");
+ let next_dx = &mut self.app_read_next.as_mut().unwrap().dx;
+ next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?;
+ }
+ Ok(())
+ }
+
+ /// Make some state for removing protection in tests.
+ #[cfg(not(feature = "fuzzing"))]
+ #[cfg(test)]
+ pub(crate) fn test_default() -> Self {
+ let read = |epoch| {
+ let mut dx = CryptoDxState::test_default();
+ dx.direction = CryptoDxDirection::Read;
+ dx.epoch = epoch;
+ dx
+ };
+ let app_read = |epoch| CryptoDxAppData {
+ dx: read(epoch),
+ cipher: TLS_AES_128_GCM_SHA256,
+ next_secret: hkdf::import_key(TLS_VERSION_1_3, &[0xaa; 32]).unwrap(),
+ fuzzing: false,
+ };
+ let mut initials = HashMap::new();
+ initials.insert(
+ Version::Version1,
+ CryptoState {
+ tx: CryptoDxState::test_default(),
+ rx: read(0),
+ },
+ );
+ Self {
+ initials,
+ handshake: None,
+ zero_rtt: None,
+ cipher: TLS_AES_128_GCM_SHA256,
+ // This isn't used, but the epoch is read to check for a key update.
+ app_write: Some(app_read(3)),
+ app_read: Some(app_read(3)),
+ app_read_next: Some(app_read(4)),
+ read_update_time: None,
+ fuzzing: false,
+ }
+ }
+
+ #[cfg(all(not(feature = "fuzzing"), test))]
+ pub(crate) fn test_chacha() -> Self {
+ const SECRET: &[u8] = &[
+ 0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad,
+ 0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3,
+ 0x0f, 0x21, 0x63, 0x2b,
+ ];
+ let secret = hkdf::import_key(TLS_VERSION_1_3, SECRET).unwrap();
+ let app_read = |epoch| CryptoDxAppData {
+ dx: CryptoDxState {
+ version: Version::Version1,
+ direction: CryptoDxDirection::Read,
+ epoch,
+ aead: Aead::new(
+ false,
+ TLS_VERSION_1_3,
+ TLS_CHACHA20_POLY1305_SHA256,
+ &secret,
+ "quic ", // This is a v1 test so hard-code the label.
+ )
+ .unwrap(),
+ hpkey: HpKey::extract(
+ TLS_VERSION_1_3,
+ TLS_CHACHA20_POLY1305_SHA256,
+ &secret,
+ "quic hp",
+ )
+ .unwrap(),
+ used_pn: 0..645_971_972,
+ min_pn: 0,
+ invocations: 10,
+ fuzzing: false,
+ },
+ cipher: TLS_CHACHA20_POLY1305_SHA256,
+ next_secret: secret.clone(),
+ fuzzing: false,
+ };
+ Self {
+ initials: HashMap::new(),
+ handshake: None,
+ zero_rtt: None,
+ cipher: TLS_CHACHA20_POLY1305_SHA256,
+ app_write: Some(app_read(3)),
+ app_read: Some(app_read(3)),
+ app_read_next: Some(app_read(4)),
+ read_update_time: None,
+ fuzzing: false,
+ }
+ }
+}
+
+impl std::fmt::Display for CryptoStates {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "CryptoStates")
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct CryptoStream {
+ tx: TxBuffer,
+ rx: RxStreamOrderer,
+}
+
+#[derive(Debug)]
+#[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408
+pub enum CryptoStreams {
+ Initial {
+ initial: CryptoStream,
+ handshake: CryptoStream,
+ application: CryptoStream,
+ },
+ Handshake {
+ handshake: CryptoStream,
+ application: CryptoStream,
+ },
+ ApplicationData {
+ application: CryptoStream,
+ },
+}
+
+impl CryptoStreams {
+ /// Keep around 64k if a server wants to push excess data at us.
+ const BUFFER_LIMIT: u64 = 65536;
+
+ pub fn discard(&mut self, space: PacketNumberSpace) {
+ match space {
+ PacketNumberSpace::Initial => {
+ if let Self::Initial {
+ handshake,
+ application,
+ ..
+ } = self
+ {
+ *self = Self::Handshake {
+ handshake: mem::take(handshake),
+ application: mem::take(application),
+ };
+ }
+ }
+ PacketNumberSpace::Handshake => {
+ if let Self::Handshake { application, .. } = self {
+ *self = Self::ApplicationData {
+ application: mem::take(application),
+ };
+ } else if matches!(self, Self::Initial { .. }) {
+ panic!("Discarding handshake before initial discarded");
+ }
+ }
+ PacketNumberSpace::ApplicationData => {
+ panic!("Discarding application data crypto streams")
+ }
+ }
+ }
+
+ pub fn send(&mut self, space: PacketNumberSpace, data: &[u8]) {
+ self.get_mut(space).unwrap().tx.send(data);
+ }
+
+ pub fn inbound_frame(&mut self, space: PacketNumberSpace, offset: u64, data: &[u8]) -> Res<()> {
+ let rx = &mut self.get_mut(space).unwrap().rx;
+ rx.inbound_frame(offset, data);
+ if rx.received() - rx.retired() <= Self::BUFFER_LIMIT {
+ Ok(())
+ } else {
+ Err(Error::CryptoBufferExceeded)
+ }
+ }
+
+ pub fn data_ready(&self, space: PacketNumberSpace) -> bool {
+ self.get(space).map_or(false, |cs| cs.rx.data_ready())
+ }
+
+ pub fn read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec<u8>) -> usize {
+ self.get_mut(space).unwrap().rx.read_to_end(buf)
+ }
+
+ pub fn acked(&mut self, token: &CryptoRecoveryToken) {
+ self.get_mut(token.space)
+ .unwrap()
+ .tx
+ .mark_as_acked(token.offset, token.length);
+ }
+
+ pub fn lost(&mut self, token: &CryptoRecoveryToken) {
+ // See BZ 1624800, ignore lost packets in spaces we've dropped keys
+ if let Some(cs) = self.get_mut(token.space) {
+ cs.tx.mark_as_lost(token.offset, token.length);
+ }
+ }
+
+ /// Resend any Initial or Handshake CRYPTO frames that might be outstanding.
+ /// This can help speed up handshake times.
+ pub fn resend_unacked(&mut self, space: PacketNumberSpace) {
+ if space != PacketNumberSpace::ApplicationData {
+ if let Some(cs) = self.get_mut(space) {
+ cs.tx.unmark_sent();
+ }
+ }
+ }
+
+ fn get(&self, space: PacketNumberSpace) -> Option<&CryptoStream> {
+ let (initial, hs, app) = match self {
+ Self::Initial {
+ initial,
+ handshake,
+ application,
+ } => (Some(initial), Some(handshake), Some(application)),
+ Self::Handshake {
+ handshake,
+ application,
+ } => (None, Some(handshake), Some(application)),
+ Self::ApplicationData { application } => (None, None, Some(application)),
+ };
+ match space {
+ PacketNumberSpace::Initial => initial,
+ PacketNumberSpace::Handshake => hs,
+ PacketNumberSpace::ApplicationData => app,
+ }
+ }
+
+ fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut CryptoStream> {
+ let (initial, hs, app) = match self {
+ Self::Initial {
+ initial,
+ handshake,
+ application,
+ } => (Some(initial), Some(handshake), Some(application)),
+ Self::Handshake {
+ handshake,
+ application,
+ } => (None, Some(handshake), Some(application)),
+ Self::ApplicationData { application } => (None, None, Some(application)),
+ };
+ match space {
+ PacketNumberSpace::Initial => initial,
+ PacketNumberSpace::Handshake => hs,
+ PacketNumberSpace::ApplicationData => app,
+ }
+ }
+
+ pub fn write_frame(
+ &mut self,
+ space: PacketNumberSpace,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> Res<()> {
+ let cs = self.get_mut(space).unwrap();
+ if let Some((offset, data)) = cs.tx.next_bytes() {
+ let mut header_len = 1 + Encoder::varint_len(offset) + 1;
+
+ // Don't bother if there isn't room for the header and some data.
+ if builder.remaining() < header_len + 1 {
+ return Ok(());
+ }
+ // Calculate length of data based on the minimum of:
+ // - available data
+ // - remaining space, less the header, which counts only one byte for the length at
+ // first to avoid underestimating length
+ let length = min(data.len(), builder.remaining() - header_len);
+ header_len += Encoder::varint_len(u64::try_from(length).unwrap()) - 1;
+ let length = min(data.len(), builder.remaining() - header_len);
+
+ builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO);
+ builder.encode_varint(offset);
+ builder.encode_vvec(&data[..length]);
+
+ cs.tx.mark_as_sent(offset, length);
+
+ qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length);
+ tokens.push(RecoveryToken::Crypto(CryptoRecoveryToken {
+ space,
+ offset,
+ length,
+ }));
+ stats.crypto += 1;
+ }
+ Ok(())
+ }
+}
+
+impl Default for CryptoStreams {
+ fn default() -> Self {
+ Self::Initial {
+ initial: CryptoStream::default(),
+ handshake: CryptoStream::default(),
+ application: CryptoStream::default(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct CryptoRecoveryToken {
+ space: PacketNumberSpace,
+ offset: u64,
+ length: usize,
+}
diff --git a/third_party/rust/neqo-transport/src/events.rs b/third_party/rust/neqo-transport/src/events.rs
new file mode 100644
index 0000000000..88a85250ee
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/events.rs
@@ -0,0 +1,321 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Collecting a list of events relevant to whoever is using the Connection.
+
+use std::{cell::RefCell, collections::VecDeque, rc::Rc};
+
+use neqo_common::event::Provider as EventProvider;
+use neqo_crypto::ResumptionToken;
+
+use crate::{
+ connection::State,
+ quic_datagrams::DatagramTracking,
+ stream_id::{StreamId, StreamType},
+ AppError, Stats,
+};
+
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
+pub enum OutgoingDatagramOutcome {
+ DroppedTooBig,
+ DroppedQueueFull,
+ Lost,
+ Acked,
+}
+
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
+pub enum ConnectionEvent {
+ /// Cert authentication needed
+ AuthenticationNeeded,
+ /// Encrypted client hello fallback occurred. The certificate for the
+ /// public name needs to be authenticated.
+ EchFallbackAuthenticationNeeded {
+ public_name: String,
+ },
+ /// A new uni (read) or bidi stream has been opened by the peer.
+ NewStream {
+ stream_id: StreamId,
+ },
+ /// Space available in the buffer for an application write to succeed.
+ SendStreamWritable {
+ stream_id: StreamId,
+ },
+ /// New bytes available for reading.
+ RecvStreamReadable {
+ stream_id: StreamId,
+ },
+ /// Peer reset the stream.
+ RecvStreamReset {
+ stream_id: StreamId,
+ app_error: AppError,
+ },
+ /// Peer has sent STOP_SENDING
+ SendStreamStopSending {
+ stream_id: StreamId,
+ app_error: AppError,
+ },
+ /// Peer has acked everything sent on the stream.
+ SendStreamComplete {
+ stream_id: StreamId,
+ },
+ /// Peer increased MAX_STREAMS
+ SendStreamCreatable {
+ stream_type: StreamType,
+ },
+ /// Connection state change.
+ StateChange(State),
+ /// The server rejected 0-RTT.
+ /// This event invalidates all state in streams that has been created.
+ /// Any data written to streams needs to be written again.
+ ZeroRttRejected,
+ ResumptionToken(ResumptionToken),
+ Datagram(Vec<u8>),
+ OutgoingDatagramOutcome {
+ id: u64,
+ outcome: OutgoingDatagramOutcome,
+ },
+ IncomingDatagramDropped,
+}
+
+#[derive(Debug, Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct ConnectionEvents {
+ events: Rc<RefCell<VecDeque<ConnectionEvent>>>,
+}
+
+impl ConnectionEvents {
+ pub fn authentication_needed(&self) {
+ self.insert(ConnectionEvent::AuthenticationNeeded);
+ }
+
+ pub fn ech_fallback_authentication_needed(&self, public_name: String) {
+ self.insert(ConnectionEvent::EchFallbackAuthenticationNeeded { public_name });
+ }
+
+ pub fn new_stream(&self, stream_id: StreamId) {
+ self.insert(ConnectionEvent::NewStream { stream_id });
+ }
+
+ pub fn recv_stream_readable(&self, stream_id: StreamId) {
+ self.insert(ConnectionEvent::RecvStreamReadable { stream_id });
+ }
+
+ pub fn recv_stream_reset(&self, stream_id: StreamId, app_error: AppError) {
+ // If reset, no longer readable.
+ self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64()));
+
+ self.insert(ConnectionEvent::RecvStreamReset {
+ stream_id,
+ app_error,
+ });
+ }
+
+ pub fn send_stream_writable(&self, stream_id: StreamId) {
+ self.insert(ConnectionEvent::SendStreamWritable { stream_id });
+ }
+
+ pub fn send_stream_stop_sending(&self, stream_id: StreamId, app_error: AppError) {
+ // If stopped, no longer writable.
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id));
+
+ self.insert(ConnectionEvent::SendStreamStopSending {
+ stream_id,
+ app_error,
+ });
+ }
+
+ pub fn send_stream_complete(&self, stream_id: StreamId) {
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id));
+
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. } if *x == stream_id.as_u64()));
+
+ self.insert(ConnectionEvent::SendStreamComplete { stream_id });
+ }
+
+ pub fn send_stream_creatable(&self, stream_type: StreamType) {
+ self.insert(ConnectionEvent::SendStreamCreatable { stream_type });
+ }
+
+ pub fn connection_state_change(&self, state: State) {
+ // If closing, existing events no longer relevant.
+ match state {
+ State::Closing { .. } | State::Closed(_) => self.events.borrow_mut().clear(),
+ _ => (),
+ }
+ self.insert(ConnectionEvent::StateChange(state));
+ }
+
+ pub fn client_resumption_token(&self, token: ResumptionToken) {
+ self.insert(ConnectionEvent::ResumptionToken(token));
+ }
+
+ pub fn client_0rtt_rejected(&self) {
+ // If 0rtt rejected, must start over and existing events are no longer
+ // relevant.
+ self.events.borrow_mut().clear();
+ self.insert(ConnectionEvent::ZeroRttRejected);
+ }
+
+ pub fn recv_stream_complete(&self, stream_id: StreamId) {
+ // If stopped, no longer readable.
+ self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64()));
+ }
+
+ // The number of datagrams in the events queue is limited to max_queued_datagrams.
+ // This function ensure this and deletes the oldest datagrams if needed.
+ fn check_datagram_queued(&self, max_queued_datagrams: usize, stats: &mut Stats) {
+ let mut q = self.events.borrow_mut();
+ let mut remove = None;
+ if q.iter()
+ .filter(|evt| matches!(evt, ConnectionEvent::Datagram(_)))
+ .count()
+ == max_queued_datagrams
+ {
+ if let Some(d) = q
+ .iter()
+ .rev()
+ .enumerate()
+ .filter(|(_, evt)| matches!(evt, ConnectionEvent::Datagram(_)))
+ .take(1)
+ .next()
+ {
+ remove = Some(d.0);
+ }
+ }
+ if let Some(r) = remove {
+ q.remove(r);
+ q.push_back(ConnectionEvent::IncomingDatagramDropped);
+ stats.incoming_datagram_dropped += 1;
+ }
+ }
+
+ pub fn add_datagram(&self, max_queued_datagrams: usize, data: &[u8], stats: &mut Stats) {
+ self.check_datagram_queued(max_queued_datagrams, stats);
+ self.events
+ .borrow_mut()
+ .push_back(ConnectionEvent::Datagram(data.to_vec()));
+ }
+
+ pub fn datagram_outcome(
+ &self,
+ dgram_tracker: &DatagramTracking,
+ outcome: OutgoingDatagramOutcome,
+ ) {
+ if let DatagramTracking::Id(id) = dgram_tracker {
+ self.events
+ .borrow_mut()
+ .push_back(ConnectionEvent::OutgoingDatagramOutcome { id: *id, outcome });
+ }
+ }
+
+ fn insert(&self, event: ConnectionEvent) {
+ let mut q = self.events.borrow_mut();
+
+ // Special-case two enums that are not strictly PartialEq equal but that
+ // we wish to avoid inserting duplicates.
+ let already_present = match &event {
+ ConnectionEvent::SendStreamStopSending { stream_id, .. } => q.iter().any(|evt| {
+ matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. }
+ if *x == *stream_id)
+ }),
+ ConnectionEvent::RecvStreamReset { stream_id, .. } => q.iter().any(|evt| {
+ matches!(evt, ConnectionEvent::RecvStreamReset { stream_id: x, .. }
+ if *x == *stream_id)
+ }),
+ _ => q.contains(&event),
+ };
+ if !already_present {
+ q.push_back(event);
+ }
+ }
+
+ fn remove<F>(&self, f: F)
+ where
+ F: Fn(&ConnectionEvent) -> bool,
+ {
+ self.events.borrow_mut().retain(|evt| !f(evt));
+ }
+}
+
+impl EventProvider for ConnectionEvents {
+ type Event = ConnectionEvent;
+
+ fn has_events(&self) -> bool {
+ !self.events.borrow().is_empty()
+ }
+
+ fn next_event(&mut self) -> Option<Self::Event> {
+ self.events.borrow_mut().pop_front()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{ConnectionError, Error};
+
+ #[test]
+ fn event_culling() {
+ let mut evts = ConnectionEvents::default();
+
+ evts.client_0rtt_rejected();
+ evts.client_0rtt_rejected();
+ assert_eq!(evts.events().count(), 1);
+ assert_eq!(evts.events().count(), 0);
+
+ evts.new_stream(4.into());
+ evts.new_stream(4.into());
+ assert_eq!(evts.events().count(), 1);
+
+ evts.recv_stream_readable(6.into());
+ evts.recv_stream_reset(6.into(), 66);
+ evts.recv_stream_reset(6.into(), 65);
+ assert_eq!(evts.events().count(), 1);
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(8.into());
+ evts.send_stream_stop_sending(8.into(), 55);
+ evts.send_stream_stop_sending(8.into(), 56);
+ let events = evts.events().collect::<Vec<_>>();
+ assert_eq!(events.len(), 1);
+ assert_eq!(
+ events[0],
+ ConnectionEvent::SendStreamStopSending {
+ stream_id: StreamId::new(8),
+ app_error: 55
+ }
+ );
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(8.into());
+ evts.send_stream_stop_sending(8.into(), 55);
+ evts.send_stream_stop_sending(8.into(), 56);
+ evts.send_stream_complete(8.into());
+ assert_eq!(evts.events().count(), 1);
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(9.into());
+ evts.send_stream_stop_sending(10.into(), 55);
+ evts.send_stream_stop_sending(11.into(), 56);
+ evts.send_stream_complete(12.into());
+ assert_eq!(evts.events().count(), 5);
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(9.into());
+ evts.send_stream_stop_sending(10.into(), 55);
+ evts.send_stream_stop_sending(11.into(), 56);
+ evts.send_stream_complete(12.into());
+ evts.client_0rtt_rejected();
+ assert_eq!(evts.events().count(), 1);
+
+ evts.send_stream_writable(9.into());
+ evts.send_stream_stop_sending(10.into(), 55);
+ evts.connection_state_change(State::Closed(ConnectionError::Transport(
+ Error::StreamStateError,
+ )));
+ assert_eq!(evts.events().count(), 1);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/fc.rs b/third_party/rust/neqo-transport/src/fc.rs
new file mode 100644
index 0000000000..a219ca7e8d
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/fc.rs
@@ -0,0 +1,918 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracks possibly-redundant flow control signals from other code and converts
+// into flow control frames needing to be sent to the remote.
+
+use std::{
+ convert::TryFrom,
+ fmt::Debug,
+ ops::{Deref, DerefMut, Index, IndexMut},
+};
+
+use neqo_common::{qtrace, Role};
+
+use crate::{
+ frame::{
+ FRAME_TYPE_DATA_BLOCKED, FRAME_TYPE_MAX_DATA, FRAME_TYPE_MAX_STREAMS_BIDI,
+ FRAME_TYPE_MAX_STREAMS_UNIDI, FRAME_TYPE_MAX_STREAM_DATA, FRAME_TYPE_STREAMS_BLOCKED_BIDI,
+ FRAME_TYPE_STREAMS_BLOCKED_UNIDI, FRAME_TYPE_STREAM_DATA_BLOCKED,
+ },
+ packet::PacketBuilder,
+ recovery::{RecoveryToken, StreamRecoveryToken},
+ stats::FrameStats,
+ stream_id::{StreamId, StreamType},
+ Error, Res,
+};
+
+#[derive(Debug)]
+pub struct SenderFlowControl<T>
+where
+ T: Debug + Sized,
+{
+ /// The thing that we're counting for.
+ subject: T,
+ /// The limit.
+ limit: u64,
+ /// How much of that limit we've used.
+ used: u64,
+ /// The point at which blocking occurred. This is updated each time
+ /// the sender decides that it is blocked. It only ever changes
+ /// when blocking occurs. This ensures that blocking at any given limit
+ /// is only reported once.
+ /// Note: All values are one greater than the corresponding `limit` to
+ /// allow distinguishing between blocking at a limit of 0 and no blocking.
+ blocked_at: u64,
+ /// Whether a blocked frame should be sent.
+ blocked_frame: bool,
+}
+
+impl<T> SenderFlowControl<T>
+where
+ T: Debug + Sized,
+{
+ /// Make a new instance with the initial value and subject.
+ pub fn new(subject: T, initial: u64) -> Self {
+ Self {
+ subject,
+ limit: initial,
+ used: 0,
+ blocked_at: 0,
+ blocked_frame: false,
+ }
+ }
+
+ /// Update the maximum. Returns `true` if the change was an increase.
+ pub fn update(&mut self, limit: u64) -> bool {
+ debug_assert!(limit < u64::MAX);
+ if limit > self.limit {
+ self.limit = limit;
+ self.blocked_frame = false;
+ true
+ } else {
+ false
+ }
+ }
+
+ /// Consume flow control.
+ pub fn consume(&mut self, count: usize) {
+ let amt = u64::try_from(count).unwrap();
+ debug_assert!(self.used + amt <= self.limit);
+ self.used += amt;
+ }
+
+ /// Get available flow control.
+ pub fn available(&self) -> usize {
+ usize::try_from(self.limit - self.used).unwrap_or(usize::MAX)
+ }
+
+ /// How much data has been written.
+ pub fn used(&self) -> u64 {
+ self.used
+ }
+
+ /// Mark flow control as blocked.
+ /// This only does something if the current limit exceeds the last reported blocking limit.
+ pub fn blocked(&mut self) {
+ if self.limit >= self.blocked_at {
+ self.blocked_at = self.limit + 1;
+ self.blocked_frame = true;
+ }
+ }
+
+ /// Return whether a blocking frame needs to be sent.
+ /// This is `Some` with the active limit if `blocked` has been called,
+ /// if a blocking frame has not been sent (or it has been lost), and
+ /// if the blocking condition remains.
+ fn blocked_needed(&self) -> Option<u64> {
+ if self.blocked_frame && self.limit < self.blocked_at {
+ Some(self.blocked_at - 1)
+ } else {
+ None
+ }
+ }
+
+ /// Clear the need to send a blocked frame.
+ fn blocked_sent(&mut self) {
+ self.blocked_frame = false;
+ }
+
+ /// Mark a blocked frame as having been lost.
+ /// Only send again if value of `self.blocked_at` hasn't increased since sending.
+ /// That would imply that the limit has since increased.
+ pub fn frame_lost(&mut self, limit: u64) {
+ if self.blocked_at == limit + 1 {
+ self.blocked_frame = true;
+ }
+ }
+}
+
+impl SenderFlowControl<()> {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if let Some(limit) = self.blocked_needed() {
+ if builder.write_varint_frame(&[FRAME_TYPE_DATA_BLOCKED, limit]) {
+ stats.data_blocked += 1;
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::DataBlocked(
+ limit,
+ )));
+ self.blocked_sent();
+ }
+ }
+ }
+}
+
+impl SenderFlowControl<StreamId> {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if let Some(limit) = self.blocked_needed() {
+ if builder.write_varint_frame(&[
+ FRAME_TYPE_STREAM_DATA_BLOCKED,
+ self.subject.as_u64(),
+ limit,
+ ]) {
+ stats.stream_data_blocked += 1;
+ tokens.push(RecoveryToken::Stream(
+ StreamRecoveryToken::StreamDataBlocked {
+ stream_id: self.subject,
+ limit,
+ },
+ ));
+ self.blocked_sent();
+ }
+ }
+ }
+}
+
+impl SenderFlowControl<StreamType> {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if let Some(limit) = self.blocked_needed() {
+ let frame = match self.subject {
+ StreamType::BiDi => FRAME_TYPE_STREAMS_BLOCKED_BIDI,
+ StreamType::UniDi => FRAME_TYPE_STREAMS_BLOCKED_UNIDI,
+ };
+ if builder.write_varint_frame(&[frame, limit]) {
+ stats.streams_blocked += 1;
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::StreamsBlocked {
+ stream_type: self.subject,
+ limit,
+ }));
+ self.blocked_sent();
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct ReceiverFlowControl<T>
+where
+ T: Debug + Sized,
+{
+ /// The thing that we're counting for.
+ subject: T,
+ /// The maximum amount of items that can be active (e.g., the size of the receive buffer).
+ max_active: u64,
+ /// Last max allowed sent.
+ max_allowed: u64,
+ /// Item received, but not retired yet.
+ /// This will be used for byte flow control: each stream will remember is largest byte
+ /// offset received and session flow control will remember the sum of all bytes consumed
+ /// by all streams.
+ consumed: u64,
+ /// Retired items.
+ retired: u64,
+ frame_pending: bool,
+}
+
+impl<T> ReceiverFlowControl<T>
+where
+ T: Debug + Sized,
+{
+ /// Make a new instance with the initial value and subject.
+ pub fn new(subject: T, max: u64) -> Self {
+ Self {
+ subject,
+ max_active: max,
+ max_allowed: max,
+ consumed: 0,
+ retired: 0,
+ frame_pending: false,
+ }
+ }
+
+ /// Retired some items and maybe send flow control
+ /// update.
+ pub fn retire(&mut self, retired: u64) {
+ if retired <= self.retired {
+ return;
+ }
+
+ self.retired = retired;
+ if self.retired + self.max_active / 2 > self.max_allowed {
+ self.frame_pending = true;
+ }
+ }
+
+ /// This function is called when STREAM_DATA_BLOCKED frame is received.
+ /// The flow control will try to send an update if possible.
+ pub fn send_flowc_update(&mut self) {
+ if self.retired + self.max_active > self.max_allowed {
+ self.frame_pending = true;
+ }
+ }
+
+ pub fn frame_needed(&self) -> bool {
+ self.frame_pending
+ }
+
+ pub fn next_limit(&self) -> u64 {
+ self.retired + self.max_active
+ }
+
+ pub fn max_active(&self) -> u64 {
+ self.max_active
+ }
+
+ pub fn frame_lost(&mut self, maximum_data: u64) {
+ if maximum_data == self.max_allowed {
+ self.frame_pending = true;
+ }
+ }
+
+ fn frame_sent(&mut self, new_max: u64) {
+ self.max_allowed = new_max;
+ self.frame_pending = false;
+ }
+
+ pub fn set_max_active(&mut self, max: u64) {
+ // If max_active has been increased, send an update immediately.
+ self.frame_pending |= self.max_active < max;
+ self.max_active = max;
+ }
+
+ pub fn retired(&self) -> u64 {
+ self.retired
+ }
+
+ pub fn consumed(&self) -> u64 {
+ self.consumed
+ }
+}
+
+impl ReceiverFlowControl<()> {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if !self.frame_needed() {
+ return;
+ }
+ let max_allowed = self.next_limit();
+ if builder.write_varint_frame(&[FRAME_TYPE_MAX_DATA, max_allowed]) {
+ stats.max_data += 1;
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::MaxData(
+ max_allowed,
+ )));
+ self.frame_sent(max_allowed);
+ }
+ }
+
+ pub fn add_retired(&mut self, count: u64) {
+ debug_assert!(self.retired + count <= self.consumed);
+ self.retired += count;
+ if self.retired + self.max_active / 2 > self.max_allowed {
+ self.frame_pending = true;
+ }
+ }
+
+ pub fn consume(&mut self, count: u64) -> Res<()> {
+ if self.consumed + count > self.max_allowed {
+ qtrace!(
+ "Session RX window exceeded: consumed:{} new:{} limit:{}",
+ self.consumed,
+ count,
+ self.max_allowed
+ );
+ return Err(Error::FlowControlError);
+ }
+ self.consumed += count;
+ Ok(())
+ }
+}
+
+impl Default for ReceiverFlowControl<()> {
+ fn default() -> Self {
+ Self::new((), 0)
+ }
+}
+
+impl ReceiverFlowControl<StreamId> {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if !self.frame_needed() {
+ return;
+ }
+ let max_allowed = self.next_limit();
+ if builder.write_varint_frame(&[
+ FRAME_TYPE_MAX_STREAM_DATA,
+ self.subject.as_u64(),
+ max_allowed,
+ ]) {
+ stats.max_stream_data += 1;
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::MaxStreamData {
+ stream_id: self.subject,
+ max_data: max_allowed,
+ }));
+ self.frame_sent(max_allowed);
+ }
+ }
+
+ pub fn add_retired(&mut self, count: u64) {
+ debug_assert!(self.retired + count <= self.consumed);
+ self.retired += count;
+ if self.retired + self.max_active / 2 > self.max_allowed {
+ self.frame_pending = true;
+ }
+ }
+
+ pub fn set_consumed(&mut self, consumed: u64) -> Res<u64> {
+ if consumed <= self.consumed {
+ return Ok(0);
+ }
+
+ if consumed > self.max_allowed {
+ qtrace!("Stream RX window exceeded: {}", consumed);
+ return Err(Error::FlowControlError);
+ }
+ let new_consumed = consumed - self.consumed;
+ self.consumed = consumed;
+ Ok(new_consumed)
+ }
+}
+
+impl Default for ReceiverFlowControl<StreamId> {
+ fn default() -> Self {
+ Self::new(StreamId::new(0), 0)
+ }
+}
+
+impl ReceiverFlowControl<StreamType> {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if !self.frame_needed() {
+ return;
+ }
+ let max_streams = self.next_limit();
+ let frame = match self.subject {
+ StreamType::BiDi => FRAME_TYPE_MAX_STREAMS_BIDI,
+ StreamType::UniDi => FRAME_TYPE_MAX_STREAMS_UNIDI,
+ };
+ if builder.write_varint_frame(&[frame, max_streams]) {
+ stats.max_streams += 1;
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::MaxStreams {
+ stream_type: self.subject,
+ max_streams,
+ }));
+ self.frame_sent(max_streams);
+ }
+ }
+
+ /// Check if received item exceeds the allowed flow control limit.
+ pub fn check_allowed(&self, new_end: u64) -> bool {
+ new_end < self.max_allowed
+ }
+
+ /// Retire given amount of additional data.
+ /// This function will send flow updates immediately.
+ pub fn add_retired(&mut self, count: u64) {
+ self.retired += count;
+ if count > 0 {
+ self.send_flowc_update();
+ }
+ }
+}
+
+pub struct RemoteStreamLimit {
+ streams_fc: ReceiverFlowControl<StreamType>,
+ next_stream: StreamId,
+}
+
+impl RemoteStreamLimit {
+ pub fn new(stream_type: StreamType, max_streams: u64, role: Role) -> Self {
+ Self {
+ streams_fc: ReceiverFlowControl::new(stream_type, max_streams),
+ // // This is for a stream created by a peer, therefore we use role.remote().
+ next_stream: StreamId::init(stream_type, role.remote()),
+ }
+ }
+
+ pub fn is_allowed(&self, stream_id: StreamId) -> bool {
+ let stream_idx = stream_id.as_u64() >> 2;
+ self.streams_fc.check_allowed(stream_idx)
+ }
+
+ pub fn is_new_stream(&self, stream_id: StreamId) -> Res<bool> {
+ if !self.is_allowed(stream_id) {
+ return Err(Error::StreamLimitError);
+ }
+ Ok(stream_id >= self.next_stream)
+ }
+
+ pub fn take_stream_id(&mut self) -> StreamId {
+ let new_stream = self.next_stream;
+ self.next_stream.next();
+ assert!(self.is_allowed(new_stream));
+ new_stream
+ }
+}
+
+impl Deref for RemoteStreamLimit {
+ type Target = ReceiverFlowControl<StreamType>;
+ fn deref(&self) -> &Self::Target {
+ &self.streams_fc
+ }
+}
+
+impl DerefMut for RemoteStreamLimit {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.streams_fc
+ }
+}
+
+pub struct RemoteStreamLimits {
+ bidirectional: RemoteStreamLimit,
+ unidirectional: RemoteStreamLimit,
+}
+
+impl RemoteStreamLimits {
+ pub fn new(local_max_stream_bidi: u64, local_max_stream_uni: u64, role: Role) -> Self {
+ Self {
+ bidirectional: RemoteStreamLimit::new(StreamType::BiDi, local_max_stream_bidi, role),
+ unidirectional: RemoteStreamLimit::new(StreamType::UniDi, local_max_stream_uni, role),
+ }
+ }
+}
+
+impl Index<StreamType> for RemoteStreamLimits {
+ type Output = RemoteStreamLimit;
+
+ fn index(&self, idx: StreamType) -> &Self::Output {
+ match idx {
+ StreamType::BiDi => &self.bidirectional,
+ StreamType::UniDi => &self.unidirectional,
+ }
+ }
+}
+
+impl IndexMut<StreamType> for RemoteStreamLimits {
+ fn index_mut(&mut self, idx: StreamType) -> &mut Self::Output {
+ match idx {
+ StreamType::BiDi => &mut self.bidirectional,
+ StreamType::UniDi => &mut self.unidirectional,
+ }
+ }
+}
+
+pub struct LocalStreamLimits {
+ bidirectional: SenderFlowControl<StreamType>,
+ unidirectional: SenderFlowControl<StreamType>,
+ role_bit: u64,
+}
+
+impl LocalStreamLimits {
+ pub fn new(role: Role) -> Self {
+ Self {
+ bidirectional: SenderFlowControl::new(StreamType::BiDi, 0),
+ unidirectional: SenderFlowControl::new(StreamType::UniDi, 0),
+ role_bit: StreamId::role_bit(role),
+ }
+ }
+
+ pub fn take_stream_id(&mut self, stream_type: StreamType) -> Option<StreamId> {
+ let fc = match stream_type {
+ StreamType::BiDi => &mut self.bidirectional,
+ StreamType::UniDi => &mut self.unidirectional,
+ };
+ if fc.available() > 0 {
+ let new_stream = fc.used();
+ fc.consume(1);
+ let type_bit = match stream_type {
+ StreamType::BiDi => 0,
+ StreamType::UniDi => 2,
+ };
+ Some(StreamId::from((new_stream << 2) + type_bit + self.role_bit))
+ } else {
+ fc.blocked();
+ None
+ }
+ }
+}
+
+impl Index<StreamType> for LocalStreamLimits {
+ type Output = SenderFlowControl<StreamType>;
+
+ fn index(&self, idx: StreamType) -> &Self::Output {
+ match idx {
+ StreamType::BiDi => &self.bidirectional,
+ StreamType::UniDi => &self.unidirectional,
+ }
+ }
+}
+
+impl IndexMut<StreamType> for LocalStreamLimits {
+ fn index_mut(&mut self, idx: StreamType) -> &mut Self::Output {
+ match idx {
+ StreamType::BiDi => &mut self.bidirectional,
+ StreamType::UniDi => &mut self.unidirectional,
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use neqo_common::{Encoder, Role};
+
+ use super::{LocalStreamLimits, ReceiverFlowControl, RemoteStreamLimits, SenderFlowControl};
+ use crate::{
+ packet::PacketBuilder,
+ stats::FrameStats,
+ stream_id::{StreamId, StreamType},
+ Error,
+ };
+
+ #[test]
+ fn blocked_at_zero() {
+ let mut fc = SenderFlowControl::new((), 0);
+ fc.blocked();
+ assert_eq!(fc.blocked_needed(), Some(0));
+ }
+
+ #[test]
+ fn blocked() {
+ let mut fc = SenderFlowControl::new((), 10);
+ fc.blocked();
+ assert_eq!(fc.blocked_needed(), Some(10));
+ }
+
+ #[test]
+ fn update_consume() {
+ let mut fc = SenderFlowControl::new((), 10);
+ fc.consume(10);
+ assert_eq!(fc.available(), 0);
+ fc.update(5); // An update lower than the current limit does nothing.
+ assert_eq!(fc.available(), 0);
+ fc.update(15);
+ assert_eq!(fc.available(), 5);
+ fc.consume(3);
+ assert_eq!(fc.available(), 2);
+ }
+
+ #[test]
+ fn update_clears_blocked() {
+ let mut fc = SenderFlowControl::new((), 10);
+ fc.blocked();
+ assert_eq!(fc.blocked_needed(), Some(10));
+ fc.update(5); // An update lower than the current limit does nothing.
+ assert_eq!(fc.blocked_needed(), Some(10));
+ fc.update(11);
+ assert_eq!(fc.blocked_needed(), None);
+ }
+
+ #[test]
+ fn lost_blocked_resent() {
+ let mut fc = SenderFlowControl::new((), 10);
+ fc.blocked();
+ fc.blocked_sent();
+ assert_eq!(fc.blocked_needed(), None);
+ fc.frame_lost(10);
+ assert_eq!(fc.blocked_needed(), Some(10));
+ }
+
+ #[test]
+ fn lost_after_increase() {
+ let mut fc = SenderFlowControl::new((), 10);
+ fc.blocked();
+ fc.blocked_sent();
+ assert_eq!(fc.blocked_needed(), None);
+ fc.update(11);
+ fc.frame_lost(10);
+ assert_eq!(fc.blocked_needed(), None);
+ }
+
+ #[test]
+ fn lost_after_higher_blocked() {
+ let mut fc = SenderFlowControl::new((), 10);
+ fc.blocked();
+ fc.blocked_sent();
+ fc.update(11);
+ fc.blocked();
+ assert_eq!(fc.blocked_needed(), Some(11));
+ fc.blocked_sent();
+ fc.frame_lost(10);
+ assert_eq!(fc.blocked_needed(), None);
+ }
+
+ #[test]
+ fn do_no_need_max_allowed_frame_at_start() {
+ let fc = ReceiverFlowControl::new((), 0);
+ assert!(!fc.frame_needed());
+ }
+
+ #[test]
+ fn max_allowed_after_items_retired() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.retire(49);
+ assert!(!fc.frame_needed());
+ fc.retire(51);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 151);
+ }
+
+ #[test]
+ fn need_max_allowed_frame_after_loss() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.retire(100);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 200);
+ fc.frame_sent(200);
+ assert!(!fc.frame_needed());
+ fc.frame_lost(200);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 200);
+ }
+
+ #[test]
+ fn no_max_allowed_frame_after_old_loss() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.retire(51);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 151);
+ fc.frame_sent(151);
+ assert!(!fc.frame_needed());
+ fc.retire(102);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 202);
+ fc.frame_sent(202);
+ assert!(!fc.frame_needed());
+ fc.frame_lost(151);
+ assert!(!fc.frame_needed());
+ }
+
+ #[test]
+ fn force_send_max_allowed() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.retire(10);
+ assert!(!fc.frame_needed());
+ }
+
+ #[test]
+ fn multiple_retries_after_frame_pending_is_set() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.retire(51);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 151);
+ fc.retire(61);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 161);
+ fc.retire(88);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 188);
+ fc.retire(90);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 190);
+ fc.frame_sent(190);
+ assert!(!fc.frame_needed());
+ fc.retire(141);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 241);
+ fc.frame_sent(241);
+ assert!(!fc.frame_needed());
+ }
+
+ #[test]
+ fn new_retired_before_loss() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.retire(51);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 151);
+ fc.frame_sent(151);
+ assert!(!fc.frame_needed());
+ fc.retire(62);
+ assert!(!fc.frame_needed());
+ fc.frame_lost(151);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 162);
+ }
+
+ #[test]
+ fn changing_max_active() {
+ let mut fc = ReceiverFlowControl::new((), 100);
+ fc.set_max_active(50);
+ // There is no MAX_STREAM_DATA frame needed.
+ assert!(!fc.frame_needed());
+ // We can still retire more than 50.
+ fc.retire(60);
+ // There is no MAX_STREAM_DATA fame needed yet.
+ assert!(!fc.frame_needed());
+ fc.retire(76);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 126);
+
+ // Increase max_active.
+ fc.set_max_active(60);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 136);
+
+ // We can retire more than 60.
+ fc.retire(136);
+ assert!(fc.frame_needed());
+ assert_eq!(fc.next_limit(), 196);
+ }
+
+ fn remote_stream_limits(role: Role, bidi: u64, unidi: u64) {
+ let mut fc = RemoteStreamLimits::new(2, 1, role);
+ assert!(fc[StreamType::BiDi]
+ .is_new_stream(StreamId::from(bidi))
+ .unwrap());
+ assert!(fc[StreamType::BiDi]
+ .is_new_stream(StreamId::from(bidi + 4))
+ .unwrap());
+ assert!(fc[StreamType::UniDi]
+ .is_new_stream(StreamId::from(unidi))
+ .unwrap());
+
+ // Exceed limits
+ assert_eq!(
+ fc[StreamType::BiDi].is_new_stream(StreamId::from(bidi + 8)),
+ Err(Error::StreamLimitError)
+ );
+ assert_eq!(
+ fc[StreamType::UniDi].is_new_stream(StreamId::from(unidi + 4)),
+ Err(Error::StreamLimitError)
+ );
+
+ assert_eq!(fc[StreamType::BiDi].take_stream_id(), StreamId::from(bidi));
+ assert_eq!(
+ fc[StreamType::BiDi].take_stream_id(),
+ StreamId::from(bidi + 4)
+ );
+ assert_eq!(
+ fc[StreamType::UniDi].take_stream_id(),
+ StreamId::from(unidi)
+ );
+
+ fc[StreamType::BiDi].add_retired(1);
+ fc[StreamType::BiDi].send_flowc_update();
+ // consume the frame
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut tokens = Vec::new();
+ fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(tokens.len(), 1);
+
+ // Now 9 can be a new StreamId.
+ assert!(fc[StreamType::BiDi]
+ .is_new_stream(StreamId::from(bidi + 8))
+ .unwrap());
+ assert_eq!(
+ fc[StreamType::BiDi].take_stream_id(),
+ StreamId::from(bidi + 8)
+ );
+ // 13 still exceeds limits
+ assert_eq!(
+ fc[StreamType::BiDi].is_new_stream(StreamId::from(bidi + 12)),
+ Err(Error::StreamLimitError)
+ );
+
+ fc[StreamType::UniDi].add_retired(1);
+ fc[StreamType::UniDi].send_flowc_update();
+ // consume the frame
+ fc[StreamType::UniDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(tokens.len(), 2);
+
+ // Now 7 can be a new StreamId.
+ assert!(fc[StreamType::UniDi]
+ .is_new_stream(StreamId::from(unidi + 4))
+ .unwrap());
+ assert_eq!(
+ fc[StreamType::UniDi].take_stream_id(),
+ StreamId::from(unidi + 4)
+ );
+ // 11 exceeds limits
+ assert_eq!(
+ fc[StreamType::UniDi].is_new_stream(StreamId::from(unidi + 8)),
+ Err(Error::StreamLimitError)
+ );
+ }
+
+ #[test]
+ fn remote_stream_limits_new_stream_client() {
+ remote_stream_limits(Role::Client, 1, 3);
+ }
+
+ #[test]
+ fn remote_stream_limits_new_stream_server() {
+ remote_stream_limits(Role::Server, 0, 2);
+ }
+
+ #[should_panic(expected = ".is_allowed")]
+ #[test]
+ fn remote_stream_limits_asserts_if_limit_exceeded() {
+ let mut fc = RemoteStreamLimits::new(2, 1, Role::Client);
+ assert_eq!(fc[StreamType::BiDi].take_stream_id(), StreamId::from(1));
+ assert_eq!(fc[StreamType::BiDi].take_stream_id(), StreamId::from(5));
+ _ = fc[StreamType::BiDi].take_stream_id();
+ }
+
+ fn local_stream_limits(role: Role, bidi: u64, unidi: u64) {
+ let mut fc = LocalStreamLimits::new(role);
+
+ fc[StreamType::BiDi].update(2);
+ fc[StreamType::UniDi].update(1);
+
+ // Add streams
+ assert_eq!(
+ fc.take_stream_id(StreamType::BiDi).unwrap(),
+ StreamId::from(bidi)
+ );
+ assert_eq!(
+ fc.take_stream_id(StreamType::BiDi).unwrap(),
+ StreamId::from(bidi + 4)
+ );
+ assert_eq!(fc.take_stream_id(StreamType::BiDi), None);
+ assert_eq!(
+ fc.take_stream_id(StreamType::UniDi).unwrap(),
+ StreamId::from(unidi)
+ );
+ assert_eq!(fc.take_stream_id(StreamType::UniDi), None);
+
+ // Increase limit
+ fc[StreamType::BiDi].update(3);
+ fc[StreamType::UniDi].update(2);
+ assert_eq!(
+ fc.take_stream_id(StreamType::BiDi).unwrap(),
+ StreamId::from(bidi + 8)
+ );
+ assert_eq!(fc.take_stream_id(StreamType::BiDi), None);
+ assert_eq!(
+ fc.take_stream_id(StreamType::UniDi).unwrap(),
+ StreamId::from(unidi + 4)
+ );
+ assert_eq!(fc.take_stream_id(StreamType::UniDi), None);
+ }
+
+ #[test]
+ fn local_stream_limits_new_stream_client() {
+ local_stream_limits(Role::Client, 0, 2);
+ }
+
+ #[test]
+ fn local_stream_limits_new_stream_server() {
+ local_stream_limits(Role::Server, 1, 3);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs
new file mode 100644
index 0000000000..f3d567ac7c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/frame.rs
@@ -0,0 +1,977 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Directly relating to QUIC frames.
+
+use std::{convert::TryFrom, ops::RangeInclusive};
+
+use neqo_common::{qtrace, Decoder};
+
+use crate::{
+ cid::MAX_CONNECTION_ID_LEN,
+ packet::PacketType,
+ stream_id::{StreamId, StreamType},
+ AppError, ConnectionError, Error, Res, TransportError,
+};
+
+#[allow(clippy::module_name_repetitions)]
+pub type FrameType = u64;
+
+const FRAME_TYPE_PADDING: FrameType = 0x0;
+pub const FRAME_TYPE_PING: FrameType = 0x1;
+pub const FRAME_TYPE_ACK: FrameType = 0x2;
+const FRAME_TYPE_ACK_ECN: FrameType = 0x3;
+pub const FRAME_TYPE_RESET_STREAM: FrameType = 0x4;
+pub const FRAME_TYPE_STOP_SENDING: FrameType = 0x5;
+pub const FRAME_TYPE_CRYPTO: FrameType = 0x6;
+pub const FRAME_TYPE_NEW_TOKEN: FrameType = 0x7;
+const FRAME_TYPE_STREAM: FrameType = 0x8;
+const FRAME_TYPE_STREAM_MAX: FrameType = 0xf;
+pub const FRAME_TYPE_MAX_DATA: FrameType = 0x10;
+pub const FRAME_TYPE_MAX_STREAM_DATA: FrameType = 0x11;
+pub const FRAME_TYPE_MAX_STREAMS_BIDI: FrameType = 0x12;
+pub const FRAME_TYPE_MAX_STREAMS_UNIDI: FrameType = 0x13;
+pub const FRAME_TYPE_DATA_BLOCKED: FrameType = 0x14;
+pub const FRAME_TYPE_STREAM_DATA_BLOCKED: FrameType = 0x15;
+pub const FRAME_TYPE_STREAMS_BLOCKED_BIDI: FrameType = 0x16;
+pub const FRAME_TYPE_STREAMS_BLOCKED_UNIDI: FrameType = 0x17;
+pub const FRAME_TYPE_NEW_CONNECTION_ID: FrameType = 0x18;
+pub const FRAME_TYPE_RETIRE_CONNECTION_ID: FrameType = 0x19;
+pub const FRAME_TYPE_PATH_CHALLENGE: FrameType = 0x1a;
+pub const FRAME_TYPE_PATH_RESPONSE: FrameType = 0x1b;
+pub const FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT: FrameType = 0x1c;
+pub const FRAME_TYPE_CONNECTION_CLOSE_APPLICATION: FrameType = 0x1d;
+pub const FRAME_TYPE_HANDSHAKE_DONE: FrameType = 0x1e;
+// draft-ietf-quic-ack-delay
+pub const FRAME_TYPE_ACK_FREQUENCY: FrameType = 0xaf;
+// draft-ietf-quic-datagram
+pub const FRAME_TYPE_DATAGRAM: FrameType = 0x30;
+pub const FRAME_TYPE_DATAGRAM_WITH_LEN: FrameType = 0x31;
+const DATAGRAM_FRAME_BIT_LEN: u64 = 0x01;
+
+const STREAM_FRAME_BIT_FIN: u64 = 0x01;
+const STREAM_FRAME_BIT_LEN: u64 = 0x02;
+const STREAM_FRAME_BIT_OFF: u64 = 0x04;
+
+#[derive(PartialEq, Eq, Debug, PartialOrd, Ord, Clone, Copy)]
+pub enum CloseError {
+ Transport(TransportError),
+ Application(AppError),
+}
+
+impl CloseError {
+ fn frame_type_bit(self) -> u64 {
+ match self {
+ Self::Transport(_) => 0,
+ Self::Application(_) => 1,
+ }
+ }
+
+ fn from_type_bit(bit: u64, code: u64) -> Self {
+ if (bit & 0x01) == 0 {
+ Self::Transport(code)
+ } else {
+ Self::Application(code)
+ }
+ }
+
+ pub fn code(&self) -> u64 {
+ match self {
+ Self::Transport(c) | Self::Application(c) => *c,
+ }
+ }
+}
+
+impl From<ConnectionError> for CloseError {
+ fn from(err: ConnectionError) -> Self {
+ match err {
+ ConnectionError::Transport(c) => Self::Transport(c.code()),
+ ConnectionError::Application(c) => Self::Application(c),
+ }
+ }
+}
+
+#[derive(PartialEq, Eq, Debug, Default, Clone)]
+pub struct AckRange {
+ pub(crate) gap: u64,
+ pub(crate) range: u64,
+}
+
+#[derive(PartialEq, Eq, Debug, Clone)]
+pub enum Frame<'a> {
+ Padding,
+ Ping,
+ Ack {
+ largest_acknowledged: u64,
+ ack_delay: u64,
+ first_ack_range: u64,
+ ack_ranges: Vec<AckRange>,
+ },
+ ResetStream {
+ stream_id: StreamId,
+ application_error_code: AppError,
+ final_size: u64,
+ },
+ StopSending {
+ stream_id: StreamId,
+ application_error_code: AppError,
+ },
+ Crypto {
+ offset: u64,
+ data: &'a [u8],
+ },
+ NewToken {
+ token: &'a [u8],
+ },
+ Stream {
+ stream_id: StreamId,
+ offset: u64,
+ data: &'a [u8],
+ fin: bool,
+ fill: bool,
+ },
+ MaxData {
+ maximum_data: u64,
+ },
+ MaxStreamData {
+ stream_id: StreamId,
+ maximum_stream_data: u64,
+ },
+ MaxStreams {
+ stream_type: StreamType,
+ maximum_streams: u64,
+ },
+ DataBlocked {
+ data_limit: u64,
+ },
+ StreamDataBlocked {
+ stream_id: StreamId,
+ stream_data_limit: u64,
+ },
+ StreamsBlocked {
+ stream_type: StreamType,
+ stream_limit: u64,
+ },
+ NewConnectionId {
+ sequence_number: u64,
+ retire_prior: u64,
+ connection_id: &'a [u8],
+ stateless_reset_token: &'a [u8; 16],
+ },
+ RetireConnectionId {
+ sequence_number: u64,
+ },
+ PathChallenge {
+ data: [u8; 8],
+ },
+ PathResponse {
+ data: [u8; 8],
+ },
+ ConnectionClose {
+ error_code: CloseError,
+ frame_type: u64,
+ // Not a reference as we use this to hold the value.
+ // This is not used in optimized builds anyway.
+ reason_phrase: Vec<u8>,
+ },
+ HandshakeDone,
+ AckFrequency {
+ /// The current ACK frequency sequence number.
+ seqno: u64,
+ /// The number of contiguous packets that can be received without
+ /// acknowledging immediately.
+ tolerance: u64,
+ /// The time to delay after receiving the first packet that is
+ /// not immediately acknowledged.
+ delay: u64,
+ /// Ignore reordering when deciding to immediately acknowledge.
+ ignore_order: bool,
+ },
+ Datagram {
+ data: &'a [u8],
+ fill: bool,
+ },
+}
+
+impl<'a> Frame<'a> {
+ fn get_stream_type_bit(stream_type: StreamType) -> u64 {
+ match stream_type {
+ StreamType::BiDi => 0,
+ StreamType::UniDi => 1,
+ }
+ }
+
+ fn stream_type_from_bit(bit: u64) -> StreamType {
+ if (bit & 0x01) == 0 {
+ StreamType::BiDi
+ } else {
+ StreamType::UniDi
+ }
+ }
+
+ pub fn get_type(&self) -> FrameType {
+ match self {
+ Self::Padding => FRAME_TYPE_PADDING,
+ Self::Ping => FRAME_TYPE_PING,
+ Self::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN.
+ Self::ResetStream { .. } => FRAME_TYPE_RESET_STREAM,
+ Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING,
+ Self::Crypto { .. } => FRAME_TYPE_CRYPTO,
+ Self::NewToken { .. } => FRAME_TYPE_NEW_TOKEN,
+ Self::Stream {
+ fin, offset, fill, ..
+ } => Self::stream_type(*fin, *offset > 0, *fill),
+ Self::MaxData { .. } => FRAME_TYPE_MAX_DATA,
+ Self::MaxStreamData { .. } => FRAME_TYPE_MAX_STREAM_DATA,
+ Self::MaxStreams { stream_type, .. } => {
+ FRAME_TYPE_MAX_STREAMS_BIDI + Self::get_stream_type_bit(*stream_type)
+ }
+ Self::DataBlocked { .. } => FRAME_TYPE_DATA_BLOCKED,
+ Self::StreamDataBlocked { .. } => FRAME_TYPE_STREAM_DATA_BLOCKED,
+ Self::StreamsBlocked { stream_type, .. } => {
+ FRAME_TYPE_STREAMS_BLOCKED_BIDI + Self::get_stream_type_bit(*stream_type)
+ }
+ Self::NewConnectionId { .. } => FRAME_TYPE_NEW_CONNECTION_ID,
+ Self::RetireConnectionId { .. } => FRAME_TYPE_RETIRE_CONNECTION_ID,
+ Self::PathChallenge { .. } => FRAME_TYPE_PATH_CHALLENGE,
+ Self::PathResponse { .. } => FRAME_TYPE_PATH_RESPONSE,
+ Self::ConnectionClose { error_code, .. } => {
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT + error_code.frame_type_bit()
+ }
+ Self::HandshakeDone => FRAME_TYPE_HANDSHAKE_DONE,
+ Self::AckFrequency { .. } => FRAME_TYPE_ACK_FREQUENCY,
+ Self::Datagram { fill, .. } => {
+ if *fill {
+ FRAME_TYPE_DATAGRAM
+ } else {
+ FRAME_TYPE_DATAGRAM_WITH_LEN
+ }
+ }
+ }
+ }
+
+ pub fn is_stream(&self) -> bool {
+ matches!(
+ self,
+ Self::ResetStream { .. }
+ | Self::StopSending { .. }
+ | Self::Stream { .. }
+ | Self::MaxData { .. }
+ | Self::MaxStreamData { .. }
+ | Self::MaxStreams { .. }
+ | Self::DataBlocked { .. }
+ | Self::StreamDataBlocked { .. }
+ | Self::StreamsBlocked { .. }
+ )
+ }
+
+ pub fn stream_type(fin: bool, nonzero_offset: bool, fill: bool) -> u64 {
+ let mut t = FRAME_TYPE_STREAM;
+ if fin {
+ t |= STREAM_FRAME_BIT_FIN;
+ }
+ if nonzero_offset {
+ t |= STREAM_FRAME_BIT_OFF;
+ }
+ if !fill {
+ t |= STREAM_FRAME_BIT_LEN;
+ }
+ t
+ }
+
+ /// If the frame causes a recipient to generate an ACK within its
+ /// advertised maximum acknowledgement delay.
+ pub fn ack_eliciting(&self) -> bool {
+ !matches!(
+ self,
+ Self::Ack { .. } | Self::Padding | Self::ConnectionClose { .. }
+ )
+ }
+
+ /// If the frame can be sent in a path probe
+ /// without initiating migration to that path.
+ pub fn path_probing(&self) -> bool {
+ matches!(
+ self,
+ Self::Padding
+ | Self::NewConnectionId { .. }
+ | Self::PathChallenge { .. }
+ | Self::PathResponse { .. }
+ )
+ }
+
+ /// Converts AckRanges as encoded in a ACK frame (see -transport
+ /// 19.3.1) into ranges of acked packets (end, start), inclusive of
+ /// start and end values.
+ pub fn decode_ack_frame(
+ largest_acked: u64,
+ first_ack_range: u64,
+ ack_ranges: &[AckRange],
+ ) -> Res<Vec<RangeInclusive<u64>>> {
+ let mut acked_ranges = Vec::with_capacity(ack_ranges.len() + 1);
+
+ if largest_acked < first_ack_range {
+ return Err(Error::FrameEncodingError);
+ }
+ acked_ranges.push((largest_acked - first_ack_range)..=largest_acked);
+ if !ack_ranges.is_empty() && largest_acked < first_ack_range + 1 {
+ return Err(Error::FrameEncodingError);
+ }
+ let mut cur = if ack_ranges.is_empty() {
+ 0
+ } else {
+ largest_acked - first_ack_range - 1
+ };
+ for r in ack_ranges {
+ if cur < r.gap + 1 {
+ return Err(Error::FrameEncodingError);
+ }
+ cur = cur - r.gap - 1;
+
+ if cur < r.range {
+ return Err(Error::FrameEncodingError);
+ }
+ acked_ranges.push((cur - r.range)..=cur);
+
+ if cur > r.range + 1 {
+ cur -= r.range + 1;
+ } else {
+ cur -= r.range;
+ }
+ }
+
+ Ok(acked_ranges)
+ }
+
+ pub fn dump(&self) -> Option<String> {
+ match self {
+ Self::Crypto { offset, data } => Some(format!(
+ "Crypto {{ offset: {}, len: {} }}",
+ offset,
+ data.len()
+ )),
+ Self::Stream {
+ stream_id,
+ offset,
+ fill,
+ data,
+ fin,
+ } => Some(format!(
+ "Stream {{ stream_id: {}, offset: {}, len: {}{}, fin: {} }}",
+ stream_id.as_u64(),
+ offset,
+ if *fill { ">>" } else { "" },
+ data.len(),
+ fin,
+ )),
+ Self::Padding => None,
+ Self::Datagram { data, .. } => Some(format!("Datagram {{ len: {} }}", data.len())),
+ _ => Some(format!("{self:?}")),
+ }
+ }
+
+ pub fn is_allowed(&self, pt: PacketType) -> bool {
+ match self {
+ Self::Padding | Self::Ping => true,
+ Self::Crypto { .. }
+ | Self::Ack { .. }
+ | Self::ConnectionClose {
+ error_code: CloseError::Transport(_),
+ ..
+ } => pt != PacketType::ZeroRtt,
+ Self::NewToken { .. } | Self::ConnectionClose { .. } => pt == PacketType::Short,
+ _ => pt == PacketType::ZeroRtt || pt == PacketType::Short,
+ }
+ }
+
+ pub fn decode(dec: &mut Decoder<'a>) -> Res<Self> {
+ /// Maximum ACK Range Count in ACK Frame
+ ///
+ /// Given a max UDP datagram size of 64k bytes and a minimum ACK Range size of 2
+ /// bytes (2 QUIC varints), a single datagram can at most contain 32k ACK
+ /// Ranges.
+ ///
+ /// Note that the maximum (jumbogram) Ethernet MTU of 9216 or on the
+ /// Internet the regular Ethernet MTU of 1518 are more realistically to
+ /// be the limiting factor. Though for simplicity the higher limit is chosen.
+ const MAX_ACK_RANGE_COUNT: u64 = 32 * 1024;
+
+ fn d<T>(v: Option<T>) -> Res<T> {
+ v.ok_or(Error::NoMoreData)
+ }
+ fn dv(dec: &mut Decoder) -> Res<u64> {
+ d(dec.decode_varint())
+ }
+
+ // TODO(ekr@rtfm.com): check for minimal encoding
+ let t = d(dec.decode_varint())?;
+ match t {
+ FRAME_TYPE_PADDING => Ok(Self::Padding),
+ FRAME_TYPE_PING => Ok(Self::Ping),
+ FRAME_TYPE_RESET_STREAM => Ok(Self::ResetStream {
+ stream_id: StreamId::from(dv(dec)?),
+ application_error_code: d(dec.decode_varint())?,
+ final_size: match dec.decode_varint() {
+ Some(v) => v,
+ _ => return Err(Error::NoMoreData),
+ },
+ }),
+ FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => {
+ let la = dv(dec)?;
+ let ad = dv(dec)?;
+ let nr = dv(dec).and_then(|nr| {
+ if nr < MAX_ACK_RANGE_COUNT {
+ Ok(nr)
+ } else {
+ Err(Error::TooMuchData)
+ }
+ })?;
+ let fa = dv(dec)?;
+ let mut arr: Vec<AckRange> = Vec::with_capacity(nr as usize);
+ for _ in 0..nr {
+ let ar = AckRange {
+ gap: dv(dec)?,
+ range: dv(dec)?,
+ };
+ arr.push(ar);
+ }
+
+ // Now check for the values for ACK_ECN.
+ if t == FRAME_TYPE_ACK_ECN {
+ dv(dec)?;
+ dv(dec)?;
+ dv(dec)?;
+ }
+
+ Ok(Self::Ack {
+ largest_acknowledged: la,
+ ack_delay: ad,
+ first_ack_range: fa,
+ ack_ranges: arr,
+ })
+ }
+ FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending {
+ stream_id: StreamId::from(dv(dec)?),
+ application_error_code: d(dec.decode_varint())?,
+ }),
+ FRAME_TYPE_CRYPTO => {
+ let offset = dv(dec)?;
+ let data = d(dec.decode_vvec())?;
+ if offset + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) {
+ return Err(Error::FrameEncodingError);
+ }
+ Ok(Self::Crypto { offset, data })
+ }
+ FRAME_TYPE_NEW_TOKEN => {
+ let token = d(dec.decode_vvec())?;
+ if token.is_empty() {
+ return Err(Error::FrameEncodingError);
+ }
+ Ok(Self::NewToken { token })
+ }
+ FRAME_TYPE_STREAM..=FRAME_TYPE_STREAM_MAX => {
+ let s = dv(dec)?;
+ let o = if t & STREAM_FRAME_BIT_OFF == 0 {
+ 0
+ } else {
+ dv(dec)?
+ };
+ let fill = (t & STREAM_FRAME_BIT_LEN) == 0;
+ let data = if fill {
+ qtrace!("STREAM frame, extends to the end of the packet");
+ dec.decode_remainder()
+ } else {
+ qtrace!("STREAM frame, with length");
+ d(dec.decode_vvec())?
+ };
+ if o + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) {
+ return Err(Error::FrameEncodingError);
+ }
+ Ok(Self::Stream {
+ fin: (t & STREAM_FRAME_BIT_FIN) != 0,
+ stream_id: StreamId::from(s),
+ offset: o,
+ data,
+ fill,
+ })
+ }
+ FRAME_TYPE_MAX_DATA => Ok(Self::MaxData {
+ maximum_data: dv(dec)?,
+ }),
+ FRAME_TYPE_MAX_STREAM_DATA => Ok(Self::MaxStreamData {
+ stream_id: StreamId::from(dv(dec)?),
+ maximum_stream_data: dv(dec)?,
+ }),
+ FRAME_TYPE_MAX_STREAMS_BIDI | FRAME_TYPE_MAX_STREAMS_UNIDI => {
+ let m = dv(dec)?;
+ if m > (1 << 60) {
+ return Err(Error::StreamLimitError);
+ }
+ Ok(Self::MaxStreams {
+ stream_type: Self::stream_type_from_bit(t),
+ maximum_streams: m,
+ })
+ }
+ FRAME_TYPE_DATA_BLOCKED => Ok(Self::DataBlocked {
+ data_limit: dv(dec)?,
+ }),
+ FRAME_TYPE_STREAM_DATA_BLOCKED => Ok(Self::StreamDataBlocked {
+ stream_id: dv(dec)?.into(),
+ stream_data_limit: dv(dec)?,
+ }),
+ FRAME_TYPE_STREAMS_BLOCKED_BIDI | FRAME_TYPE_STREAMS_BLOCKED_UNIDI => {
+ Ok(Self::StreamsBlocked {
+ stream_type: Self::stream_type_from_bit(t),
+ stream_limit: dv(dec)?,
+ })
+ }
+ FRAME_TYPE_NEW_CONNECTION_ID => {
+ let sequence_number = dv(dec)?;
+ let retire_prior = dv(dec)?;
+ let connection_id = d(dec.decode_vec(1))?;
+ if connection_id.len() > MAX_CONNECTION_ID_LEN {
+ return Err(Error::DecodingFrame);
+ }
+ let srt = d(dec.decode(16))?;
+ let stateless_reset_token = <&[_; 16]>::try_from(srt).unwrap();
+
+ Ok(Self::NewConnectionId {
+ sequence_number,
+ retire_prior,
+ connection_id,
+ stateless_reset_token,
+ })
+ }
+ FRAME_TYPE_RETIRE_CONNECTION_ID => Ok(Self::RetireConnectionId {
+ sequence_number: dv(dec)?,
+ }),
+ FRAME_TYPE_PATH_CHALLENGE => {
+ let data = d(dec.decode(8))?;
+ let mut datav: [u8; 8] = [0; 8];
+ datav.copy_from_slice(data);
+ Ok(Self::PathChallenge { data: datav })
+ }
+ FRAME_TYPE_PATH_RESPONSE => {
+ let data = d(dec.decode(8))?;
+ let mut datav: [u8; 8] = [0; 8];
+ datav.copy_from_slice(data);
+ Ok(Self::PathResponse { data: datav })
+ }
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT | FRAME_TYPE_CONNECTION_CLOSE_APPLICATION => {
+ let error_code = CloseError::from_type_bit(t, d(dec.decode_varint())?);
+ let frame_type = if t == FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT {
+ dv(dec)?
+ } else {
+ 0
+ };
+ // We can tolerate this copy for now.
+ let reason_phrase = d(dec.decode_vvec())?.to_vec();
+ Ok(Self::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ })
+ }
+ FRAME_TYPE_HANDSHAKE_DONE => Ok(Self::HandshakeDone),
+ FRAME_TYPE_ACK_FREQUENCY => {
+ let seqno = dv(dec)?;
+ let tolerance = dv(dec)?;
+ if tolerance == 0 {
+ return Err(Error::FrameEncodingError);
+ }
+ let delay = dv(dec)?;
+ let ignore_order = match d(dec.decode_uint(1))? {
+ 0 => false,
+ 1 => true,
+ _ => return Err(Error::FrameEncodingError),
+ };
+ Ok(Self::AckFrequency {
+ seqno,
+ tolerance,
+ delay,
+ ignore_order,
+ })
+ }
+ FRAME_TYPE_DATAGRAM | FRAME_TYPE_DATAGRAM_WITH_LEN => {
+ let fill = (t & DATAGRAM_FRAME_BIT_LEN) == 0;
+ let data = if fill {
+ qtrace!("DATAGRAM frame, extends to the end of the packet");
+ dec.decode_remainder()
+ } else {
+ qtrace!("DATAGRAM frame, with length");
+ d(dec.decode_vvec())?
+ };
+ Ok(Self::Datagram { data, fill })
+ }
+ _ => Err(Error::UnknownFrameType),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use neqo_common::{Decoder, Encoder};
+
+ use super::*;
+
+ fn just_dec(f: &Frame, s: &str) {
+ let encoded = Encoder::from_hex(s);
+ let decoded = Frame::decode(&mut encoded.as_decoder()).unwrap();
+ assert_eq!(*f, decoded);
+ }
+
+ #[test]
+ fn padding() {
+ let f = Frame::Padding;
+ just_dec(&f, "00");
+ }
+
+ #[test]
+ fn ping() {
+ let f = Frame::Ping;
+ just_dec(&f, "01");
+ }
+
+ #[test]
+ fn ack() {
+ let ar = vec![AckRange { gap: 1, range: 2 }, AckRange { gap: 3, range: 4 }];
+
+ let f = Frame::Ack {
+ largest_acknowledged: 0x1234,
+ ack_delay: 0x1235,
+ first_ack_range: 0x1236,
+ ack_ranges: ar,
+ };
+
+ just_dec(&f, "025234523502523601020304");
+
+ // Try to parse ACK_ECN without ECN values
+ let enc = Encoder::from_hex("035234523502523601020304");
+ let mut dec = enc.as_decoder();
+ assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData);
+
+ // Try to parse ACK_ECN without ECN values
+ let enc = Encoder::from_hex("035234523502523601020304010203");
+ let mut dec = enc.as_decoder();
+ assert_eq!(Frame::decode(&mut dec).unwrap(), f);
+ }
+
+ #[test]
+ fn reset_stream() {
+ let f = Frame::ResetStream {
+ stream_id: StreamId::from(0x1234),
+ application_error_code: 0x77,
+ final_size: 0x3456,
+ };
+
+ just_dec(&f, "04523440777456");
+ }
+
+ #[test]
+ fn stop_sending() {
+ let f = Frame::StopSending {
+ stream_id: StreamId::from(63),
+ application_error_code: 0x77,
+ };
+
+ just_dec(&f, "053F4077");
+ }
+
+ #[test]
+ fn crypto() {
+ let f = Frame::Crypto {
+ offset: 1,
+ data: &[1, 2, 3],
+ };
+
+ just_dec(&f, "060103010203");
+ }
+
+ #[test]
+ fn new_token() {
+ let f = Frame::NewToken {
+ token: &[0x12, 0x34, 0x56],
+ };
+
+ just_dec(&f, "0703123456");
+ }
+
+ #[test]
+ fn empty_new_token() {
+ let mut dec = Decoder::from(&[0x07, 0x00][..]);
+ assert_eq!(
+ Frame::decode(&mut dec).unwrap_err(),
+ Error::FrameEncodingError
+ );
+ }
+
+ #[test]
+ fn stream() {
+ // First, just set the length bit.
+ let f = Frame::Stream {
+ fin: false,
+ stream_id: StreamId::from(5),
+ offset: 0,
+ data: &[1, 2, 3],
+ fill: false,
+ };
+
+ just_dec(&f, "0a0503010203");
+
+ // Now with offset != 0 and FIN
+ let f = Frame::Stream {
+ fin: true,
+ stream_id: StreamId::from(5),
+ offset: 1,
+ data: &[1, 2, 3],
+ fill: false,
+ };
+ just_dec(&f, "0f050103010203");
+
+ // Now to fill the packet.
+ let f = Frame::Stream {
+ fin: true,
+ stream_id: StreamId::from(5),
+ offset: 0,
+ data: &[1, 2, 3],
+ fill: true,
+ };
+ just_dec(&f, "0905010203");
+ }
+
+ #[test]
+ fn max_data() {
+ let f = Frame::MaxData {
+ maximum_data: 0x1234,
+ };
+
+ just_dec(&f, "105234");
+ }
+
+ #[test]
+ fn max_stream_data() {
+ let f = Frame::MaxStreamData {
+ stream_id: StreamId::from(5),
+ maximum_stream_data: 0x1234,
+ };
+
+ just_dec(&f, "11055234");
+ }
+
+ #[test]
+ fn max_streams() {
+ let mut f = Frame::MaxStreams {
+ stream_type: StreamType::BiDi,
+ maximum_streams: 0x1234,
+ };
+
+ just_dec(&f, "125234");
+
+ f = Frame::MaxStreams {
+ stream_type: StreamType::UniDi,
+ maximum_streams: 0x1234,
+ };
+
+ just_dec(&f, "135234");
+ }
+
+ #[test]
+ fn data_blocked() {
+ let f = Frame::DataBlocked { data_limit: 0x1234 };
+
+ just_dec(&f, "145234");
+ }
+
+ #[test]
+ fn stream_data_blocked() {
+ let f = Frame::StreamDataBlocked {
+ stream_id: StreamId::from(5),
+ stream_data_limit: 0x1234,
+ };
+
+ just_dec(&f, "15055234");
+ }
+
+ #[test]
+ fn streams_blocked() {
+ let mut f = Frame::StreamsBlocked {
+ stream_type: StreamType::BiDi,
+ stream_limit: 0x1234,
+ };
+
+ just_dec(&f, "165234");
+
+ f = Frame::StreamsBlocked {
+ stream_type: StreamType::UniDi,
+ stream_limit: 0x1234,
+ };
+
+ just_dec(&f, "175234");
+ }
+
+ #[test]
+ fn new_connection_id() {
+ let f = Frame::NewConnectionId {
+ sequence_number: 0x1234,
+ retire_prior: 0,
+ connection_id: &[0x01, 0x02],
+ stateless_reset_token: &[9; 16],
+ };
+
+ just_dec(&f, "1852340002010209090909090909090909090909090909");
+ }
+
+ #[test]
+ fn too_large_new_connection_id() {
+ let mut enc = Encoder::from_hex("18523400"); // up to the CID
+ enc.encode_vvec(&[0x0c; MAX_CONNECTION_ID_LEN + 10]);
+ enc.encode(&[0x11; 16][..]);
+ assert_eq!(
+ Frame::decode(&mut enc.as_decoder()).unwrap_err(),
+ Error::DecodingFrame
+ );
+ }
+
+ #[test]
+ fn retire_connection_id() {
+ let f = Frame::RetireConnectionId {
+ sequence_number: 0x1234,
+ };
+
+ just_dec(&f, "195234");
+ }
+
+ #[test]
+ fn path_challenge() {
+ let f = Frame::PathChallenge { data: [9; 8] };
+
+ just_dec(&f, "1a0909090909090909");
+ }
+
+ #[test]
+ fn path_response() {
+ let f = Frame::PathResponse { data: [9; 8] };
+
+ just_dec(&f, "1b0909090909090909");
+ }
+
+ #[test]
+ fn connection_close_transport() {
+ let f = Frame::ConnectionClose {
+ error_code: CloseError::Transport(0x5678),
+ frame_type: 0x1234,
+ reason_phrase: vec![0x01, 0x02, 0x03],
+ };
+
+ just_dec(&f, "1c80005678523403010203");
+ }
+
+ #[test]
+ fn connection_close_application() {
+ let f = Frame::ConnectionClose {
+ error_code: CloseError::Application(0x5678),
+ frame_type: 0,
+ reason_phrase: vec![0x01, 0x02, 0x03],
+ };
+
+ just_dec(&f, "1d8000567803010203");
+ }
+
+ #[test]
+ fn test_compare() {
+ let f1 = Frame::Padding;
+ let f2 = Frame::Padding;
+ let f3 = Frame::Crypto {
+ offset: 0,
+ data: &[1, 2, 3],
+ };
+ let f4 = Frame::Crypto {
+ offset: 0,
+ data: &[1, 2, 3],
+ };
+ let f5 = Frame::Crypto {
+ offset: 1,
+ data: &[1, 2, 3],
+ };
+ let f6 = Frame::Crypto {
+ offset: 0,
+ data: &[1, 2, 4],
+ };
+
+ assert_eq!(f1, f2);
+ assert_ne!(f1, f3);
+ assert_eq!(f3, f4);
+ assert_ne!(f3, f5);
+ assert_ne!(f3, f6);
+ }
+
+ #[test]
+ fn decode_ack_frame() {
+ let res = Frame::decode_ack_frame(7, 2, &[AckRange { gap: 0, range: 3 }]);
+ assert!(res.is_ok());
+ assert_eq!(res.unwrap(), vec![5..=7, 0..=3]);
+ }
+
+ #[test]
+ fn ack_frequency() {
+ let f = Frame::AckFrequency {
+ seqno: 10,
+ tolerance: 5,
+ delay: 2000,
+ ignore_order: true,
+ };
+ just_dec(&f, "40af0a0547d001");
+ }
+
+ #[test]
+ fn ack_frequency_ignore_error_error() {
+ let enc = Encoder::from_hex("40af0a0547d003"); // ignore_order of 3
+ assert_eq!(
+ Frame::decode(&mut enc.as_decoder()).unwrap_err(),
+ Error::FrameEncodingError
+ );
+ }
+
+ /// Hopefully this test is eventually redundant.
+ #[test]
+ fn ack_frequency_zero_packets() {
+ let enc = Encoder::from_hex("40af0a000101"); // packets of 0
+ assert_eq!(
+ Frame::decode(&mut enc.as_decoder()).unwrap_err(),
+ Error::FrameEncodingError
+ );
+ }
+
+ #[test]
+ fn datagram() {
+ // Without the length bit.
+ let f = Frame::Datagram {
+ data: &[1, 2, 3],
+ fill: true,
+ };
+
+ just_dec(&f, "4030010203");
+
+ // With the length bit.
+ let f = Frame::Datagram {
+ data: &[1, 2, 3],
+ fill: false,
+ };
+ just_dec(&f, "403103010203");
+ }
+
+ #[test]
+ fn frame_decode_enforces_bound_on_ack_range() {
+ let mut e = Encoder::new();
+
+ e.encode_varint(FRAME_TYPE_ACK);
+ e.encode_varint(0u64); // largest acknowledged
+ e.encode_varint(0u64); // ACK delay
+ e.encode_varint(u32::MAX); // ACK range count = huge, but maybe available for allocation
+
+ assert_eq!(Err(Error::TooMuchData), Frame::decode(&mut e.as_decoder()));
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/lib.rs b/third_party/rust/neqo-transport/src/lib.rs
new file mode 100644
index 0000000000..ecf7ee2f73
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/lib.rs
@@ -0,0 +1,226 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![cfg_attr(feature = "deny-warnings", deny(warnings))]
+#![warn(clippy::use_self)]
+
+use neqo_common::qinfo;
+use neqo_crypto::Error as CryptoError;
+
+mod ackrate;
+mod addr_valid;
+mod cc;
+mod cid;
+mod connection;
+mod crypto;
+mod events;
+mod fc;
+mod frame;
+mod pace;
+mod packet;
+mod path;
+mod qlog;
+mod quic_datagrams;
+mod recovery;
+#[cfg(feature = "bench")]
+pub mod recv_stream;
+#[cfg(not(feature = "bench"))]
+mod recv_stream;
+mod rtt;
+mod send_stream;
+mod sender;
+pub mod server;
+mod stats;
+pub mod stream_id;
+pub mod streams;
+pub mod tparams;
+mod tracking;
+pub mod version;
+
+pub use self::{
+ cc::CongestionControlAlgorithm,
+ cid::{
+ ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef,
+ EmptyConnectionIdGenerator, RandomConnectionIdGenerator,
+ },
+ connection::{
+ params::{ConnectionParameters, ACK_RATIO_SCALE},
+ Connection, Output, State, ZeroRttState,
+ },
+ events::{ConnectionEvent, ConnectionEvents},
+ frame::CloseError,
+ quic_datagrams::DatagramTracking,
+ recv_stream::{RecvStreamStats, RECV_BUFFER_SIZE},
+ send_stream::{SendStreamStats, SEND_BUFFER_SIZE},
+ stats::Stats,
+ stream_id::{StreamId, StreamType},
+ version::Version,
+};
+
+pub type TransportError = u64;
+const ERROR_APPLICATION_CLOSE: TransportError = 12;
+const ERROR_CRYPTO_BUFFER_EXCEEDED: TransportError = 13;
+const ERROR_AEAD_LIMIT_REACHED: TransportError = 15;
+
+#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
+pub enum Error {
+ NoError,
+ // Each time tihe error is return a different parameter is supply.
+ // This will be use to distinguish each occurance of this error.
+ InternalError,
+ ConnectionRefused,
+ FlowControlError,
+ StreamLimitError,
+ StreamStateError,
+ FinalSizeError,
+ FrameEncodingError,
+ TransportParameterError,
+ ProtocolViolation,
+ InvalidToken,
+ ApplicationError,
+ CryptoBufferExceeded,
+ CryptoError(CryptoError),
+ QlogError,
+ CryptoAlert(u8),
+ EchRetry(Vec<u8>),
+
+ // All internal errors from here. Please keep these sorted.
+ AckedUnsentPacket,
+ ConnectionIdLimitExceeded,
+ ConnectionIdsExhausted,
+ ConnectionState,
+ DecodingFrame,
+ DecryptError,
+ DisabledVersion,
+ HandshakeFailed,
+ IdleTimeout,
+ IntegerOverflow,
+ InvalidInput,
+ InvalidMigration,
+ InvalidPacket,
+ InvalidResumptionToken,
+ InvalidRetry,
+ InvalidStreamId,
+ KeysDiscarded(crypto::CryptoSpace),
+ /// Packet protection keys are exhausted.
+ /// Also used when too many key updates have happened.
+ KeysExhausted,
+ /// Packet protection keys aren't available yet for the identified space.
+ KeysPending(crypto::CryptoSpace),
+ /// An attempt to update keys can be blocked if
+ /// a packet sent with the current keys hasn't been acknowledged.
+ KeyUpdateBlocked,
+ NoAvailablePath,
+ NoMoreData,
+ NotConnected,
+ PacketNumberOverlap,
+ PeerApplicationError(AppError),
+ PeerError(TransportError),
+ StatelessReset,
+ TooMuchData,
+ UnexpectedMessage,
+ UnknownConnectionId,
+ UnknownFrameType,
+ VersionNegotiation,
+ WrongRole,
+ NotAvailable,
+}
+
+impl Error {
+ pub fn code(&self) -> TransportError {
+ match self {
+ Self::NoError
+ | Self::IdleTimeout
+ | Self::PeerError(_)
+ | Self::PeerApplicationError(_) => 0,
+ Self::ConnectionRefused => 2,
+ Self::FlowControlError => 3,
+ Self::StreamLimitError => 4,
+ Self::StreamStateError => 5,
+ Self::FinalSizeError => 6,
+ Self::FrameEncodingError => 7,
+ Self::TransportParameterError => 8,
+ Self::ProtocolViolation => 10,
+ Self::InvalidToken => 11,
+ Self::KeysExhausted => ERROR_AEAD_LIMIT_REACHED,
+ Self::ApplicationError => ERROR_APPLICATION_CLOSE,
+ Self::NoAvailablePath => 16,
+ Self::CryptoBufferExceeded => ERROR_CRYPTO_BUFFER_EXCEEDED,
+ Self::CryptoAlert(a) => 0x100 + u64::from(*a),
+ // As we have a special error code for ECH fallbacks, we lose the alert.
+ // Send the server "ech_required" directly.
+ Self::EchRetry(_) => 0x100 + 121,
+ Self::VersionNegotiation => 0x53f8,
+ // All the rest are internal errors.
+ _ => 1,
+ }
+ }
+}
+
+impl From<CryptoError> for Error {
+ fn from(err: CryptoError) -> Self {
+ qinfo!("Crypto operation failed {:?}", err);
+ match err {
+ CryptoError::EchRetry(config) => Self::EchRetry(config),
+ _ => Self::CryptoError(err),
+ }
+ }
+}
+
+impl From<::qlog::Error> for Error {
+ fn from(_err: ::qlog::Error) -> Self {
+ Self::QlogError
+ }
+}
+
+impl From<std::num::TryFromIntError> for Error {
+ fn from(_: std::num::TryFromIntError) -> Self {
+ Self::IntegerOverflow
+ }
+}
+
+impl ::std::error::Error for Error {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ match self {
+ Self::CryptoError(e) => Some(e),
+ _ => None,
+ }
+ }
+}
+
+impl ::std::fmt::Display for Error {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Transport error: {self:?}")
+ }
+}
+
+pub type AppError = u64;
+
+#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
+pub enum ConnectionError {
+ Transport(Error),
+ Application(AppError),
+}
+
+impl ConnectionError {
+ pub fn app_code(&self) -> Option<AppError> {
+ match self {
+ Self::Application(e) => Some(*e),
+ Self::Transport(_) => None,
+ }
+ }
+}
+
+impl From<CloseError> for ConnectionError {
+ fn from(err: CloseError) -> Self {
+ match err {
+ CloseError::Transport(c) => Self::Transport(Error::PeerError(c)),
+ CloseError::Application(c) => Self::Application(c),
+ }
+ }
+}
+
+pub type Res<T> = std::result::Result<T, Error>;
diff --git a/third_party/rust/neqo-transport/src/pace.rs b/third_party/rust/neqo-transport/src/pace.rs
new file mode 100644
index 0000000000..e5214c1bc8
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/pace.rs
@@ -0,0 +1,165 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Pacer
+#![deny(clippy::pedantic)]
+
+use std::{
+ cmp::min,
+ convert::TryFrom,
+ fmt::{Debug, Display},
+ time::{Duration, Instant},
+};
+
+use neqo_common::qtrace;
+
+/// This value determines how much faster the pacer operates than the
+/// congestion window.
+///
+/// A value of 1 would cause all packets to be spaced over the entire RTT,
+/// which is a little slow and might act as an additional restriction in
+/// the case the congestion controller increases the congestion window.
+/// This value spaces packets over half the congestion window, which matches
+/// our current congestion controller, which double the window every RTT.
+const PACER_SPEEDUP: usize = 2;
+
+/// A pacer that uses a leaky bucket.
+pub struct Pacer {
+ /// Whether pacing is enabled.
+ enabled: bool,
+ /// The last update time.
+ t: Instant,
+ /// The maximum capacity, or burst size, in bytes.
+ m: usize,
+ /// The current capacity, in bytes.
+ c: usize,
+ /// The packet size or minimum capacity for sending, in bytes.
+ p: usize,
+}
+
+impl Pacer {
+ /// Create a new `Pacer`. This takes the current time, the maximum burst size,
+ /// and the packet size.
+ ///
+ /// The value of `m` is the maximum capacity in bytes. `m` primes the pacer
+ /// with credit and determines the burst size. `m` must not exceed
+ /// the initial congestion window, but it should probably be lower.
+ ///
+ /// The value of `p` is the packet size in bytes, which determines the minimum
+ /// credit needed before a packet is sent. This should be a substantial
+ /// fraction of the maximum packet size, if not the packet size.
+ pub fn new(enabled: bool, now: Instant, m: usize, p: usize) -> Self {
+ assert!(m >= p, "maximum capacity has to be at least one packet");
+ Self {
+ enabled,
+ t: now,
+ m,
+ c: m,
+ p,
+ }
+ }
+
+ /// Determine when the next packet will be available based on the provided RTT
+ /// and congestion window. This doesn't update state.
+ /// This returns a time, which could be in the past (this object doesn't know what
+ /// the current time is).
+ pub fn next(&self, rtt: Duration, cwnd: usize) -> Instant {
+ if self.c >= self.p {
+ qtrace!([self], "next {}/{:?} no wait = {:?}", cwnd, rtt, self.t);
+ self.t
+ } else {
+ // This is the inverse of the function in `spend`:
+ // self.t + rtt * (self.p - self.c) / (PACER_SPEEDUP * cwnd)
+ let r = rtt.as_nanos();
+ let d = r.saturating_mul(u128::try_from(self.p - self.c).unwrap());
+ let add = d / u128::try_from(cwnd * PACER_SPEEDUP).unwrap();
+ let w = u64::try_from(add).map(Duration::from_nanos).unwrap_or(rtt);
+ let nxt = self.t + w;
+ qtrace!([self], "next {}/{:?} wait {:?} = {:?}", cwnd, rtt, w, nxt);
+ nxt
+ }
+ }
+
+ /// Spend credit. This cannot fail; users of this API are expected to call
+ /// `next()` to determine when to spend. This takes the current time (`now`),
+ /// an estimate of the round trip time (`rtt`), the estimated congestion
+ /// window (`cwnd`), and the number of bytes that were sent (`count`).
+ pub fn spend(&mut self, now: Instant, rtt: Duration, cwnd: usize, count: usize) {
+ if !self.enabled {
+ self.t = now;
+ return;
+ }
+
+ qtrace!([self], "spend {} over {}, {:?}", count, cwnd, rtt);
+ // Increase the capacity by:
+ // `(now - self.t) * PACER_SPEEDUP * cwnd / rtt`
+ // That is, the elapsed fraction of the RTT times rate that data is added.
+ let incr = now
+ .saturating_duration_since(self.t)
+ .as_nanos()
+ .saturating_mul(u128::try_from(cwnd * PACER_SPEEDUP).unwrap())
+ .checked_div(rtt.as_nanos())
+ .and_then(|i| usize::try_from(i).ok())
+ .unwrap_or(self.m);
+
+ // Add the capacity up to a limit of `self.m`, then subtract `count`.
+ self.c = min(self.m, (self.c + incr).saturating_sub(count));
+ self.t = now;
+ }
+}
+
+impl Display for Pacer {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "Pacer {}/{}", self.c, self.p)
+ }
+}
+
+impl Debug for Pacer {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "Pacer@{:?} {}/{}..{}", self.t, self.c, self.p, self.m)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use test_fixture::now;
+
+ use super::Pacer;
+
+ const RTT: Duration = Duration::from_millis(1000);
+ const PACKET: usize = 1000;
+ const CWND: usize = PACKET * 10;
+
+ #[test]
+ fn even() {
+ let n = now();
+ let mut p = Pacer::new(true, n, PACKET, PACKET);
+ assert_eq!(p.next(RTT, CWND), n);
+ p.spend(n, RTT, CWND, PACKET);
+ assert_eq!(p.next(RTT, CWND), n + (RTT / 20));
+ }
+
+ #[test]
+ fn backwards_in_time() {
+ let n = now();
+ let mut p = Pacer::new(true, n + RTT, PACKET, PACKET);
+ assert_eq!(p.next(RTT, CWND), n + RTT);
+ // Now spend some credit in the past using a time machine.
+ p.spend(n, RTT, CWND, PACKET);
+ assert_eq!(p.next(RTT, CWND), n + (RTT / 20));
+ }
+
+ #[test]
+ fn pacing_disabled() {
+ let n = now();
+ let mut p = Pacer::new(false, n, PACKET, PACKET);
+ assert_eq!(p.next(RTT, CWND), n);
+ p.spend(n, RTT, CWND, PACKET);
+ assert_eq!(p.next(RTT, CWND), n);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs
new file mode 100644
index 0000000000..ccfd212d5f
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/packet/mod.rs
@@ -0,0 +1,1457 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Encoding and decoding packets off the wire.
+use std::{
+ cmp::min,
+ convert::TryFrom,
+ fmt,
+ iter::ExactSizeIterator,
+ ops::{Deref, DerefMut, Range},
+ time::Instant,
+};
+
+use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder};
+use neqo_crypto::random;
+
+use crate::{
+ cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN},
+ crypto::{CryptoDxState, CryptoSpace, CryptoStates},
+ version::{Version, WireVersion},
+ Error, Res,
+};
+
+pub const PACKET_BIT_LONG: u8 = 0x80;
+const PACKET_BIT_SHORT: u8 = 0x00;
+const PACKET_BIT_FIXED_QUIC: u8 = 0x40;
+const PACKET_BIT_SPIN: u8 = 0x20;
+const PACKET_BIT_KEY_PHASE: u8 = 0x04;
+
+const PACKET_HP_MASK_LONG: u8 = 0x0f;
+const PACKET_HP_MASK_SHORT: u8 = 0x1f;
+
+const SAMPLE_SIZE: usize = 16;
+const SAMPLE_OFFSET: usize = 4;
+const MAX_PACKET_NUMBER_LEN: usize = 4;
+
+mod retry;
+
+pub type PacketNumber = u64;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum PacketType {
+ VersionNegotiation,
+ Initial,
+ Handshake,
+ ZeroRtt,
+ Retry,
+ Short,
+ OtherVersion,
+}
+
+impl PacketType {
+ #[must_use]
+ fn from_byte(t: u8, v: Version) -> Self {
+ // Version2 adds one to the type, modulo 4
+ match t.wrapping_sub(u8::from(v == Version::Version2)) & 3 {
+ 0 => Self::Initial,
+ 1 => Self::ZeroRtt,
+ 2 => Self::Handshake,
+ 3 => Self::Retry,
+ _ => panic!("packet type out of range"),
+ }
+ }
+
+ #[must_use]
+ fn to_byte(self, v: Version) -> u8 {
+ let t = match self {
+ Self::Initial => 0,
+ Self::ZeroRtt => 1,
+ Self::Handshake => 2,
+ Self::Retry => 3,
+ _ => panic!("not a long header packet type"),
+ };
+ // Version2 adds one to the type, modulo 4
+ (t + u8::from(v == Version::Version2)) & 3
+ }
+}
+
+impl From<PacketType> for CryptoSpace {
+ fn from(v: PacketType) -> Self {
+ match v {
+ PacketType::Initial => Self::Initial,
+ PacketType::ZeroRtt => Self::ZeroRtt,
+ PacketType::Handshake => Self::Handshake,
+ PacketType::Short => Self::ApplicationData,
+ _ => panic!("shouldn't be here"),
+ }
+ }
+}
+
+impl From<CryptoSpace> for PacketType {
+ fn from(cs: CryptoSpace) -> Self {
+ match cs {
+ CryptoSpace::Initial => Self::Initial,
+ CryptoSpace::ZeroRtt => Self::ZeroRtt,
+ CryptoSpace::Handshake => Self::Handshake,
+ CryptoSpace::ApplicationData => Self::Short,
+ }
+ }
+}
+
+struct PacketBuilderOffsets {
+ /// The bits of the first octet that need masking.
+ first_byte_mask: u8,
+ /// The offset of the length field.
+ len: usize,
+ /// The location of the packet number field.
+ pn: Range<usize>,
+}
+
+/// A packet builder that can be used to produce short packets and long packets.
+/// This does not produce Retry or Version Negotiation.
+pub struct PacketBuilder {
+ encoder: Encoder,
+ pn: PacketNumber,
+ header: Range<usize>,
+ offsets: PacketBuilderOffsets,
+ limit: usize,
+ /// Whether to pad the packet before construction.
+ padding: bool,
+}
+
+impl PacketBuilder {
+ /// The minimum useful frame size. If space is less than this, we will claim to be full.
+ pub const MINIMUM_FRAME_SIZE: usize = 2;
+
+ fn infer_limit(encoder: &Encoder) -> usize {
+ if encoder.capacity() > 64 {
+ encoder.capacity()
+ } else {
+ 2048
+ }
+ }
+
+ /// Start building a short header packet.
+ ///
+ /// This doesn't fail if there isn't enough space; instead it returns a builder that
+ /// has no available space left. This allows the caller to extract the encoder
+ /// and any packets that might have been added before as adding a packet header is
+ /// only likely to fail if there are other packets already written.
+ ///
+ /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get
+ /// the encoder back.
+ #[allow(clippy::reversed_empty_ranges)]
+ pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self {
+ let mut limit = Self::infer_limit(&encoder);
+ let header_start = encoder.len();
+ // Check that there is enough space for the header.
+ // 5 = 1 (first byte) + 4 (packet number)
+ if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() {
+ encoder
+ .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2));
+ encoder.encode(dcid.as_ref());
+ } else {
+ limit = 0;
+ }
+ Self {
+ encoder,
+ pn: u64::max_value(),
+ header: header_start..header_start,
+ offsets: PacketBuilderOffsets {
+ first_byte_mask: PACKET_HP_MASK_SHORT,
+ pn: 0..0,
+ len: 0,
+ },
+ limit,
+ padding: false,
+ }
+ }
+
+ /// Start building a long header packet.
+ /// For an Initial packet you will need to call initial_token(),
+ /// even if the token is empty.
+ ///
+ /// See `short()` for more on how to handle this in cases where there is no space.
+ #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range.
+ pub fn long(
+ mut encoder: Encoder,
+ pt: PacketType,
+ version: Version,
+ dcid: impl AsRef<[u8]>,
+ scid: impl AsRef<[u8]>,
+ ) -> Self {
+ let mut limit = Self::infer_limit(&encoder);
+ let header_start = encoder.len();
+ // Check that there is enough space for the header.
+ // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number)
+ if limit > encoder.len()
+ && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len()
+ {
+ encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4);
+ encoder.encode_uint(4, version.wire_version());
+ encoder.encode_vec(1, dcid.as_ref());
+ encoder.encode_vec(1, scid.as_ref());
+ } else {
+ limit = 0;
+ }
+
+ Self {
+ encoder,
+ pn: u64::max_value(),
+ header: header_start..header_start,
+ offsets: PacketBuilderOffsets {
+ first_byte_mask: PACKET_HP_MASK_LONG,
+ pn: 0..0,
+ len: 0,
+ },
+ limit,
+ padding: false,
+ }
+ }
+
+ fn is_long(&self) -> bool {
+ self.as_ref()[self.header.start] & 0x80 == PACKET_BIT_LONG
+ }
+
+ /// This stores a value that can be used as a limit. This does not cause
+ /// this limit to be enforced until encryption occurs. Prior to that, it
+ /// is only used voluntarily by users of the builder, through `remaining()`.
+ pub fn set_limit(&mut self, limit: usize) {
+ self.limit = limit;
+ }
+
+ /// Get the current limit.
+ #[must_use]
+ pub fn limit(&mut self) -> usize {
+ self.limit
+ }
+
+ /// How many bytes remain against the size limit for the builder.
+ #[must_use]
+ pub fn remaining(&self) -> usize {
+ self.limit.saturating_sub(self.encoder.len())
+ }
+
+ /// Returns true if the packet has no more space for frames.
+ #[must_use]
+ pub fn is_full(&self) -> bool {
+ // No useful frame is smaller than 2 bytes long.
+ self.limit < self.encoder.len() + Self::MINIMUM_FRAME_SIZE
+ }
+
+ /// Adjust the limit to ensure that no more data is added.
+ pub fn mark_full(&mut self) {
+ self.limit = self.encoder.len();
+ }
+
+ /// Mark the packet as needing padding (or not).
+ pub fn enable_padding(&mut self, needs_padding: bool) {
+ self.padding = needs_padding;
+ }
+
+ /// Maybe pad with "PADDING" frames.
+ /// Only does so if padding was needed and this is a short packet.
+ /// Returns true if padding was added.
+ pub fn pad(&mut self) -> bool {
+ if self.padding && !self.is_long() {
+ self.encoder.pad_to(self.limit, 0);
+ true
+ } else {
+ false
+ }
+ }
+
+ /// Add unpredictable values for unprotected parts of the packet.
+ pub fn scramble(&mut self, quic_bit: bool) {
+ debug_assert!(self.len() > self.header.start);
+ let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 }
+ | if self.is_long() { 0 } else { PACKET_BIT_SPIN };
+ let first = self.header.start;
+ self.encoder.as_mut()[first] ^= random(1)[0] & mask;
+ }
+
+ /// For an Initial packet, encode the token.
+ /// If you fail to do this, then you will not get a valid packet.
+ pub fn initial_token(&mut self, token: &[u8]) {
+ if Encoder::vvec_len(token.len()) < self.remaining() {
+ self.encoder.encode_vvec(token);
+ } else {
+ self.limit = 0;
+ }
+ }
+
+ /// Add a packet number of the given size.
+ /// For a long header packet, this also inserts a dummy length.
+ /// The length is filled in after calling `build`.
+ /// Does nothing if there isn't 4 bytes available other than render this builder
+ /// unusable; if `remaining()` returns 0 at any point, call `abort()`.
+ pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) {
+ if self.remaining() < 4 {
+ self.limit = 0;
+ return;
+ }
+
+ // Reserve space for a length in long headers.
+ if self.is_long() {
+ self.offsets.len = self.encoder.len();
+ self.encoder.encode(&[0; 2]);
+ }
+
+ // This allows the input to be >4, which is absurd, but we can eat that.
+ let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len);
+ debug_assert_ne!(pn_len, 0);
+ // Encode the packet number and save its offset.
+ let pn_offset = self.encoder.len();
+ self.encoder.encode_uint(pn_len, pn);
+ self.offsets.pn = pn_offset..self.encoder.len();
+
+ // Now encode the packet number length and save the header length.
+ self.encoder.as_mut()[self.header.start] |= u8::try_from(pn_len - 1).unwrap();
+ self.header.end = self.encoder.len();
+ self.pn = pn;
+ }
+
+ fn write_len(&mut self, expansion: usize) {
+ let len = self.encoder.len() - (self.offsets.len + 2) + expansion;
+ self.encoder.as_mut()[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8;
+ self.encoder.as_mut()[self.offsets.len + 1] = (len & 0xff) as u8;
+ }
+
+ fn pad_for_crypto(&mut self, crypto: &mut CryptoDxState) {
+ // Make sure that there is enough data in the packet.
+ // The length of the packet number plus the payload length needs to
+ // be at least 4 (MAX_PACKET_NUMBER_LEN) plus any amount by which
+ // the header protection sample exceeds the AEAD expansion.
+ let crypto_pad = crypto.extra_padding();
+ self.encoder.pad_to(
+ self.offsets.pn.start + MAX_PACKET_NUMBER_LEN + crypto_pad,
+ 0,
+ );
+ }
+
+ /// A lot of frames here are just a collection of varints.
+ /// This helper functions writes a frame like that safely, returning `true` if
+ /// a frame was written.
+ pub fn write_varint_frame(&mut self, values: &[u64]) -> bool {
+ let write = self.remaining()
+ >= values
+ .iter()
+ .map(|&v| Encoder::varint_len(v))
+ .sum::<usize>();
+ if write {
+ for v in values {
+ self.encode_varint(*v);
+ }
+ debug_assert!(self.len() <= self.limit());
+ };
+ write
+ }
+
+ /// Build the packet and return the encoder.
+ pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> {
+ if self.len() > self.limit {
+ qwarn!("Packet contents are more than the limit");
+ debug_assert!(false);
+ return Err(Error::InternalError);
+ }
+
+ self.pad_for_crypto(crypto);
+ if self.offsets.len > 0 {
+ self.write_len(crypto.expansion());
+ }
+
+ let hdr = &self.encoder.as_ref()[self.header.clone()];
+ let body = &self.encoder.as_ref()[self.header.end..];
+ qtrace!(
+ "Packet build pn={} hdr={} body={}",
+ self.pn,
+ hex(hdr),
+ hex(body)
+ );
+ let ciphertext = crypto.encrypt(self.pn, hdr, body)?;
+
+ // Calculate the mask.
+ let offset = SAMPLE_OFFSET - self.offsets.pn.len();
+ assert!(offset + SAMPLE_SIZE <= ciphertext.len());
+ let sample = &ciphertext[offset..offset + SAMPLE_SIZE];
+ let mask = crypto.compute_mask(sample)?;
+
+ // Apply the mask.
+ self.encoder.as_mut()[self.header.start] ^= mask[0] & self.offsets.first_byte_mask;
+ for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) {
+ self.encoder.as_mut()[j] ^= mask[i];
+ }
+
+ // Finally, cut off the plaintext and add back the ciphertext.
+ self.encoder.truncate(self.header.end);
+ self.encoder.encode(&ciphertext);
+ qtrace!("Packet built {}", hex(&self.encoder));
+ Ok(self.encoder)
+ }
+
+ /// Abort writing of this packet and return the encoder.
+ #[must_use]
+ pub fn abort(mut self) -> Encoder {
+ self.encoder.truncate(self.header.start);
+ self.encoder
+ }
+
+ /// Work out if nothing was added after the header.
+ #[must_use]
+ pub fn packet_empty(&self) -> bool {
+ self.encoder.len() == self.header.end
+ }
+
+ /// Make a retry packet.
+ /// As this is a simple packet, this is just an associated function.
+ /// As Retry is odd (it has to be constructed with leading bytes),
+ /// this returns a [`Vec<u8>`] rather than building on an encoder.
+ pub fn retry(
+ version: Version,
+ dcid: &[u8],
+ scid: &[u8],
+ token: &[u8],
+ odcid: &[u8],
+ ) -> Res<Vec<u8>> {
+ let mut encoder = Encoder::default();
+ encoder.encode_vec(1, odcid);
+ let start = encoder.len();
+ encoder.encode_byte(
+ PACKET_BIT_LONG
+ | PACKET_BIT_FIXED_QUIC
+ | (PacketType::Retry.to_byte(version) << 4)
+ | (random(1)[0] & 0xf),
+ );
+ encoder.encode_uint(4, version.wire_version());
+ encoder.encode_vec(1, dcid);
+ encoder.encode_vec(1, scid);
+ debug_assert_ne!(token.len(), 0);
+ encoder.encode(token);
+ let tag = retry::use_aead(version, |aead| {
+ let mut buf = vec![0; aead.expansion()];
+ Ok(aead.encrypt(0, encoder.as_ref(), &[], &mut buf)?.to_vec())
+ })?;
+ encoder.encode(&tag);
+ let mut complete: Vec<u8> = encoder.into();
+ Ok(complete.split_off(start))
+ }
+
+ /// Make a Version Negotiation packet.
+ pub fn version_negotiation(
+ dcid: &[u8],
+ scid: &[u8],
+ client_version: u32,
+ versions: &[Version],
+ ) -> Vec<u8> {
+ let mut encoder = Encoder::default();
+ let mut grease = random(4);
+ // This will not include the "QUIC bit" sometimes. Intentionally.
+ encoder.encode_byte(PACKET_BIT_LONG | (grease[3] & 0x7f));
+ encoder.encode(&[0; 4]); // Zero version == VN.
+ encoder.encode_vec(1, dcid);
+ encoder.encode_vec(1, scid);
+
+ for v in versions {
+ encoder.encode_uint(4, v.wire_version());
+ }
+ // Add a greased version, using the randomness already generated.
+ for g in &mut grease[..3] {
+ *g = *g & 0xf0 | 0x0a;
+ }
+
+ // Ensure our greased version does not collide with the client version
+ // by making the last byte differ from the client initial.
+ grease[3] = (client_version.wrapping_add(0x10) & 0xf0) as u8 | 0x0a;
+ encoder.encode(&grease[..4]);
+
+ Vec::from(encoder)
+ }
+}
+
+impl Deref for PacketBuilder {
+ type Target = Encoder;
+
+ fn deref(&self) -> &Self::Target {
+ &self.encoder
+ }
+}
+
+impl DerefMut for PacketBuilder {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.encoder
+ }
+}
+
+impl From<PacketBuilder> for Encoder {
+ fn from(v: PacketBuilder) -> Self {
+ v.encoder
+ }
+}
+
+/// PublicPacket holds information from packets that is public only. This allows for
+/// processing of packets prior to decryption.
+pub struct PublicPacket<'a> {
+ /// The packet type.
+ packet_type: PacketType,
+ /// The recovered destination connection ID.
+ dcid: ConnectionIdRef<'a>,
+ /// The source connection ID, if this is a long header packet.
+ scid: Option<ConnectionIdRef<'a>>,
+ /// Any token that is included in the packet (Retry always has a token; Initial sometimes
+ /// does). This is empty when there is no token.
+ token: &'a [u8],
+ /// The size of the header, not including the packet number.
+ header_len: usize,
+ /// Protocol version, if present in header.
+ version: Option<WireVersion>,
+ /// A reference to the entire packet, including the header.
+ data: &'a [u8],
+}
+
+impl<'a> PublicPacket<'a> {
+ fn opt<T>(v: Option<T>) -> Res<T> {
+ if let Some(v) = v {
+ Ok(v)
+ } else {
+ Err(Error::NoMoreData)
+ }
+ }
+
+ /// Decode the type-specific portions of a long header.
+ /// This includes reading the length and the remainder of the packet.
+ /// Returns a tuple of any token and the length of the header.
+ fn decode_long(
+ decoder: &mut Decoder<'a>,
+ packet_type: PacketType,
+ version: Version,
+ ) -> Res<(&'a [u8], usize)> {
+ if packet_type == PacketType::Retry {
+ let header_len = decoder.offset();
+ let expansion = retry::expansion(version);
+ let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?;
+ if token.is_empty() {
+ return Err(Error::InvalidPacket);
+ }
+ Self::opt(decoder.decode(expansion))?;
+ return Ok((token, header_len));
+ }
+ let token = if packet_type == PacketType::Initial {
+ Self::opt(decoder.decode_vvec())?
+ } else {
+ &[]
+ };
+ let len = Self::opt(decoder.decode_varint())?;
+ let header_len = decoder.offset();
+ let _body = Self::opt(decoder.decode(usize::try_from(len)?))?;
+ Ok((token, header_len))
+ }
+
+ /// Decode the common parts of a packet. This provides minimal parsing and validation.
+ /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram.
+ pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
+ let mut decoder = Decoder::new(data);
+ let first = Self::opt(decoder.decode_byte())?;
+
+ if first & 0x80 == PACKET_BIT_SHORT {
+ // Conveniently, this also guarantees that there is enough space
+ // for a connection ID of any size.
+ if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
+ return Err(Error::InvalidPacket);
+ }
+ let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?;
+ if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
+ return Err(Error::InvalidPacket);
+ }
+ let header_len = decoder.offset();
+ return Ok((
+ Self {
+ packet_type: PacketType::Short,
+ dcid,
+ scid: None,
+ token: &[],
+ header_len,
+ version: None,
+ data,
+ },
+ &[],
+ ));
+ }
+
+ // Generic long header.
+ let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?).unwrap();
+ let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
+ let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
+
+ // Version negotiation.
+ if version == 0 {
+ return Ok((
+ Self {
+ packet_type: PacketType::VersionNegotiation,
+ dcid,
+ scid: Some(scid),
+ token: &[],
+ header_len: decoder.offset(),
+ version: None,
+ data,
+ },
+ &[],
+ ));
+ }
+
+ // Check that this is a long header from a supported version.
+ let Ok(version) = Version::try_from(version) else {
+ return Ok((
+ Self {
+ packet_type: PacketType::OtherVersion,
+ dcid,
+ scid: Some(scid),
+ token: &[],
+ header_len: decoder.offset(),
+ version: Some(version),
+ data,
+ },
+ &[],
+ ));
+ };
+
+ if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN {
+ return Err(Error::InvalidPacket);
+ }
+ let packet_type = PacketType::from_byte((first >> 4) & 3, version);
+
+ // The type-specific code includes a token. This consumes the remainder of the packet.
+ let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?;
+ let end = data.len() - decoder.remaining();
+ let (data, remainder) = data.split_at(end);
+ Ok((
+ Self {
+ packet_type,
+ dcid,
+ scid: Some(scid),
+ token,
+ header_len,
+ version: Some(version.wire_version()),
+ data,
+ },
+ remainder,
+ ))
+ }
+
+ /// Validate the given packet as though it were a retry.
+ pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool {
+ if self.packet_type != PacketType::Retry {
+ return false;
+ }
+ let version = self.version().unwrap();
+ let expansion = retry::expansion(version);
+ if self.data.len() <= expansion {
+ return false;
+ }
+ let (header, tag) = self.data.split_at(self.data.len() - expansion);
+ let mut encoder = Encoder::with_capacity(self.data.len());
+ encoder.encode_vec(1, odcid);
+ encoder.encode(header);
+ retry::use_aead(version, |aead| {
+ let mut buf = vec![0; expansion];
+ Ok(aead.decrypt(0, encoder.as_ref(), tag, &mut buf)?.is_empty())
+ })
+ .unwrap_or(false)
+ }
+
+ pub fn is_valid_initial(&self) -> bool {
+ // Packet has to be an initial, with a DCID of 8 bytes, or a token.
+ // Note: the Server class validates the token and checks the length.
+ self.packet_type == PacketType::Initial
+ && (self.dcid().len() >= 8 || !self.token.is_empty())
+ }
+
+ pub fn packet_type(&self) -> PacketType {
+ self.packet_type
+ }
+
+ pub fn dcid(&self) -> ConnectionIdRef<'a> {
+ self.dcid
+ }
+
+ pub fn scid(&self) -> ConnectionIdRef<'a> {
+ self.scid
+ .expect("should only be called for long header packets")
+ }
+
+ pub fn token(&self) -> &'a [u8] {
+ self.token
+ }
+
+ pub fn version(&self) -> Option<Version> {
+ self.version.and_then(|v| Version::try_from(v).ok())
+ }
+
+ pub fn wire_version(&self) -> WireVersion {
+ debug_assert!(self.version.is_some());
+ self.version.unwrap_or(0)
+ }
+
+ pub fn len(&self) -> usize {
+ self.data.len()
+ }
+
+ fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber {
+ let window = 1_u64 << (w * 8);
+ let candidate = (expected & !(window - 1)) | pn;
+ if candidate + (window / 2) <= expected {
+ candidate + window
+ } else if candidate > expected + (window / 2) {
+ match candidate.checked_sub(window) {
+ Some(pn_sub) => pn_sub,
+ None => candidate,
+ }
+ } else {
+ candidate
+ }
+ }
+
+ /// Decrypt the header of the packet.
+ fn decrypt_header(
+ &self,
+ crypto: &mut CryptoDxState,
+ ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> {
+ assert_ne!(self.packet_type, PacketType::Retry);
+ assert_ne!(self.packet_type, PacketType::VersionNegotiation);
+
+ qtrace!(
+ "unmask hdr={}",
+ hex(&self.data[..self.header_len + SAMPLE_OFFSET])
+ );
+
+ let sample_offset = self.header_len + SAMPLE_OFFSET;
+ let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE))
+ {
+ crypto.compute_mask(sample)
+ } else {
+ Err(Error::NoMoreData)
+ }?;
+
+ // Un-mask the leading byte.
+ let bits = if self.packet_type == PacketType::Short {
+ PACKET_HP_MASK_SHORT
+ } else {
+ PACKET_HP_MASK_LONG
+ };
+ let first_byte = self.data[0] ^ (mask[0] & bits);
+
+ // Make a copy of the header to work on.
+ let mut hdrbytes = self.data[..self.header_len + 4].to_vec();
+ hdrbytes[0] = first_byte;
+
+ // Unmask the PN.
+ let mut pn_encoded: u64 = 0;
+ for i in 0..MAX_PACKET_NUMBER_LEN {
+ hdrbytes[self.header_len + i] ^= mask[1 + i];
+ pn_encoded <<= 8;
+ pn_encoded += u64::from(hdrbytes[self.header_len + i]);
+ }
+
+ // Now decode the packet number length and apply it, hopefully in constant time.
+ let pn_len = usize::from((first_byte & 0x3) + 1);
+ hdrbytes.truncate(self.header_len + pn_len);
+ pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len);
+
+ qtrace!("unmasked hdr={}", hex(&hdrbytes));
+
+ let key_phase = self.packet_type == PacketType::Short
+ && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE;
+ let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len);
+ Ok((
+ key_phase,
+ pn,
+ hdrbytes,
+ &self.data[self.header_len + pn_len..],
+ ))
+ }
+
+ pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> {
+ let cspace: CryptoSpace = self.packet_type.into();
+ // When we don't have a version, the crypto code doesn't need a version
+ // for lookup, so use the default, but fix it up if decryption succeeds.
+ let version = self.version().unwrap_or_default();
+ // This has to work in two stages because we need to remove header protection
+ // before picking the keys to use.
+ if let Some(rx) = crypto.rx_hp(version, cspace) {
+ // Note that this will dump early, which creates a side-channel.
+ // This is OK in this case because we the only reason this can
+ // fail is if the cryptographic module is bad or the packet is
+ // too small (which is public information).
+ let (key_phase, pn, header, body) = self.decrypt_header(rx)?;
+ qtrace!([rx], "decoded header: {:?}", header);
+ let rx = crypto.rx(version, cspace, key_phase).unwrap();
+ let version = rx.version(); // Version fixup; see above.
+ let d = rx.decrypt(pn, &header, body)?;
+ // If this is the first packet ever successfully decrypted
+ // using `rx`, make sure to initiate a key update.
+ if rx.needs_update() {
+ crypto.key_update_received(release_at)?;
+ }
+ crypto.check_pn_overlap()?;
+ Ok(DecryptedPacket {
+ version,
+ pt: self.packet_type,
+ pn,
+ data: d,
+ })
+ } else if crypto.rx_pending(cspace) {
+ Err(Error::KeysPending(cspace))
+ } else {
+ qtrace!("keys for {:?} already discarded", cspace);
+ Err(Error::KeysDiscarded(cspace))
+ }
+ }
+
+ pub fn supported_versions(&self) -> Res<Vec<WireVersion>> {
+ assert_eq!(self.packet_type, PacketType::VersionNegotiation);
+ let mut decoder = Decoder::new(&self.data[self.header_len..]);
+ let mut res = Vec::new();
+ while decoder.remaining() > 0 {
+ let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
+ res.push(version);
+ }
+ Ok(res)
+ }
+}
+
+impl fmt::Debug for PublicPacket<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{:?}: {} {}",
+ self.packet_type(),
+ hex_with_len(&self.data[..self.header_len]),
+ hex_with_len(&self.data[self.header_len..])
+ )
+ }
+}
+
+pub struct DecryptedPacket {
+ version: Version,
+ pt: PacketType,
+ pn: PacketNumber,
+ data: Vec<u8>,
+}
+
+impl DecryptedPacket {
+ pub fn version(&self) -> Version {
+ self.version
+ }
+
+ pub fn packet_type(&self) -> PacketType {
+ self.pt
+ }
+
+ pub fn pn(&self) -> PacketNumber {
+ self.pn
+ }
+}
+
+impl Deref for DecryptedPacket {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.data[..]
+ }
+}
+
+#[cfg(all(test, not(feature = "fuzzing")))]
+mod tests {
+ use neqo_common::Encoder;
+ use test_fixture::{fixture_init, now};
+
+ use super::*;
+ use crate::{
+ crypto::{CryptoDxState, CryptoStates},
+ EmptyConnectionIdGenerator, RandomConnectionIdGenerator, Version,
+ };
+
+ const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
+ const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5];
+
+ /// This is a connection ID manager, which is only used for decoding short header packets.
+ fn cid_mgr() -> RandomConnectionIdGenerator {
+ RandomConnectionIdGenerator::new(SERVER_CID.len())
+ }
+
+ const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[
+ 0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03,
+ 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd,
+ 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04,
+ 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00,
+ 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14,
+ 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b,
+ 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04,
+ ];
+ const SAMPLE_INITIAL: &[u8] = &[
+ 0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x00, 0x40, 0x75, 0xc0, 0xd9, 0x5a, 0x48, 0x2c, 0xd0, 0x99, 0x1c, 0xd2, 0x5b, 0x0a, 0xac,
+ 0x40, 0x6a, 0x58, 0x16, 0xb6, 0x39, 0x41, 0x00, 0xf3, 0x7a, 0x1c, 0x69, 0x79, 0x75, 0x54,
+ 0x78, 0x0b, 0xb3, 0x8c, 0xc5, 0xa9, 0x9f, 0x5e, 0xde, 0x4c, 0xf7, 0x3c, 0x3e, 0xc2, 0x49,
+ 0x3a, 0x18, 0x39, 0xb3, 0xdb, 0xcb, 0xa3, 0xf6, 0xea, 0x46, 0xc5, 0xb7, 0x68, 0x4d, 0xf3,
+ 0x54, 0x8e, 0x7d, 0xde, 0xb9, 0xc3, 0xbf, 0x9c, 0x73, 0xcc, 0x3f, 0x3b, 0xde, 0xd7, 0x4b,
+ 0x56, 0x2b, 0xfb, 0x19, 0xfb, 0x84, 0x02, 0x2f, 0x8e, 0xf4, 0xcd, 0xd9, 0x37, 0x95, 0xd7,
+ 0x7d, 0x06, 0xed, 0xbb, 0x7a, 0xaf, 0x2f, 0x58, 0x89, 0x18, 0x50, 0xab, 0xbd, 0xca, 0x3d,
+ 0x20, 0x39, 0x8c, 0x27, 0x64, 0x56, 0xcb, 0xc4, 0x21, 0x58, 0x40, 0x7d, 0xd0, 0x74, 0xee,
+ ];
+
+ #[test]
+ fn sample_server_initial() {
+ fixture_init();
+ let mut prot = CryptoDxState::test_default();
+
+ // The spec uses PN=1, but our crypto refuses to skip packet numbers.
+ // So burn an encryption:
+ let burn = prot.encrypt(0, &[], &[]).expect("burn OK");
+ assert_eq!(burn.len(), prot.expansion());
+
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Initial,
+ Version::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(SERVER_CID),
+ );
+ builder.initial_token(&[]);
+ builder.pn(1, 2);
+ builder.encode(SAMPLE_INITIAL_PAYLOAD);
+ let packet = builder.build(&mut prot).expect("build");
+ assert_eq!(packet.as_ref(), SAMPLE_INITIAL);
+ }
+
+ #[test]
+ fn decrypt_initial() {
+ const EXTRA: &[u8] = &[0xce; 33];
+
+ fixture_init();
+ let mut padded = SAMPLE_INITIAL.to_vec();
+ padded.extend_from_slice(EXTRA);
+ let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap();
+ assert_eq!(packet.packet_type(), PacketType::Initial);
+ assert_eq!(&packet.dcid()[..], &[] as &[u8]);
+ assert_eq!(&packet.scid()[..], SERVER_CID);
+ assert!(packet.token().is_empty());
+ assert_eq!(remainder, EXTRA);
+
+ let decrypted = packet
+ .decrypt(&mut CryptoStates::test_default(), now())
+ .unwrap();
+ assert_eq!(decrypted.pn(), 1);
+ }
+
+ #[test]
+ fn disallow_long_dcid() {
+ let mut enc = Encoder::new();
+ enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
+ enc.encode_uint(4, Version::default().wire_version());
+ enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]);
+ enc.encode_vec(1, &[]);
+ enc.encode(&[0xff; 40]); // junk
+
+ assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err());
+ }
+
+ #[test]
+ fn disallow_long_scid() {
+ let mut enc = Encoder::new();
+ enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
+ enc.encode_uint(4, Version::default().wire_version());
+ enc.encode_vec(1, &[]);
+ enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]);
+ enc.encode(&[0xff; 40]); // junk
+
+ assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err());
+ }
+
+ const SAMPLE_SHORT: &[u8] = &[
+ 0x40, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0xf4, 0xa8, 0x30, 0x39, 0xc4, 0x7d,
+ 0x99, 0xe3, 0x94, 0x1c, 0x9b, 0xb9, 0x7a, 0x30, 0x1d, 0xd5, 0x8f, 0xf3, 0xdd, 0xa9,
+ ];
+ const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3];
+
+ #[test]
+ fn build_short() {
+ fixture_init();
+ let mut builder =
+ PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
+ builder.pn(0, 1);
+ builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling.
+ let packet = builder
+ .build(&mut CryptoDxState::test_default())
+ .expect("build");
+ assert_eq!(packet.as_ref(), SAMPLE_SHORT);
+ }
+
+ #[test]
+ fn scramble_short() {
+ fixture_init();
+ let mut firsts = Vec::new();
+ for _ in 0..64 {
+ let mut builder =
+ PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
+ builder.scramble(true);
+ builder.pn(0, 1);
+ firsts.push(builder.as_ref()[0]);
+ }
+ let is_set = |bit| move |v| v & bit == bit;
+ // There should be at least one value with the QUIC bit set:
+ assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC)));
+ // ... but not all:
+ assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC)));
+ // There should be at least one value with the spin bit set:
+ assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN)));
+ // ... but not all:
+ assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN)));
+ }
+
+ #[test]
+ fn decode_short() {
+ fixture_init();
+ let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap();
+ assert_eq!(packet.packet_type(), PacketType::Short);
+ assert!(remainder.is_empty());
+ let decrypted = packet
+ .decrypt(&mut CryptoStates::test_default(), now())
+ .unwrap();
+ assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD);
+ }
+
+ /// By telling the decoder that the connection ID is shorter than it really is, we get a
+ /// decryption error.
+ #[test]
+ fn decode_short_bad_cid() {
+ fixture_init();
+ let (packet, remainder) = PublicPacket::decode(
+ SAMPLE_SHORT,
+ &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1),
+ )
+ .unwrap();
+ assert_eq!(packet.packet_type(), PacketType::Short);
+ assert!(remainder.is_empty());
+ assert!(packet
+ .decrypt(&mut CryptoStates::test_default(), now())
+ .is_err());
+ }
+
+ /// Saying that the connection ID is longer causes the initial decode to fail.
+ #[test]
+ fn decode_short_long_cid() {
+ assert!(PublicPacket::decode(
+ SAMPLE_SHORT,
+ &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1)
+ )
+ .is_err());
+ }
+
+ #[test]
+ fn build_two() {
+ fixture_init();
+ let mut prot = CryptoDxState::test_default();
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Handshake,
+ Version::default(),
+ &ConnectionId::from(SERVER_CID),
+ &ConnectionId::from(CLIENT_CID),
+ );
+ builder.pn(0, 1);
+ builder.encode(&[0; 3]);
+ let encoder = builder.build(&mut prot).expect("build");
+ assert_eq!(encoder.len(), 45);
+ let first = encoder.clone();
+
+ let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID));
+ builder.pn(1, 3);
+ builder.encode(&[0]); // Minimal size (packet number is big enough).
+ let encoder = builder.build(&mut prot).expect("build");
+ assert_eq!(
+ first.as_ref(),
+ &encoder.as_ref()[..first.len()],
+ "the first packet should be a prefix"
+ );
+ assert_eq!(encoder.len(), 45 + 29);
+ }
+
+ #[test]
+ fn build_long() {
+ const EXPECTED: &[u8] = &[
+ 0xe4, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x40, 0x14, 0xfb, 0xa9, 0x32, 0x3a, 0xf8,
+ 0xbb, 0x18, 0x63, 0xc6, 0xbd, 0x78, 0x0e, 0xba, 0x0c, 0x98, 0x65, 0x58, 0xc9, 0x62,
+ 0x31,
+ ];
+
+ fixture_init();
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Handshake,
+ Version::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(&[][..]),
+ );
+ builder.pn(0, 1);
+ builder.encode(&[1, 2, 3]);
+ let packet = builder.build(&mut CryptoDxState::test_default()).unwrap();
+ assert_eq!(packet.as_ref(), EXPECTED);
+ }
+
+ #[test]
+ fn scramble_long() {
+ fixture_init();
+ let mut found_unset = false;
+ let mut found_set = false;
+ for _ in 1..64 {
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Handshake,
+ Version::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(&[][..]),
+ );
+ builder.pn(0, 1);
+ builder.scramble(true);
+ if (builder.as_ref()[0] & PACKET_BIT_FIXED_QUIC) == 0 {
+ found_unset = true;
+ } else {
+ found_set = true;
+ }
+ }
+ assert!(found_unset);
+ assert!(found_set);
+ }
+
+ #[test]
+ fn build_abort() {
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Initial,
+ Version::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(SERVER_CID),
+ );
+ assert_ne!(builder.remaining(), 0);
+ builder.initial_token(&[]);
+ assert_ne!(builder.remaining(), 0);
+ builder.pn(1, 2);
+ assert_ne!(builder.remaining(), 0);
+ let encoder = builder.abort();
+ assert!(encoder.is_empty());
+ }
+
+ #[test]
+ fn build_insufficient_space() {
+ fixture_init();
+
+ let mut builder = PacketBuilder::short(
+ Encoder::with_capacity(100),
+ true,
+ &ConnectionId::from(SERVER_CID),
+ );
+ builder.pn(0, 1);
+ // Pad, but not up to the full capacity. Leave enough space for the
+ // AEAD expansion and some extra, but not for an entire long header.
+ builder.set_limit(75);
+ builder.enable_padding(true);
+ assert!(builder.pad());
+ let encoder = builder.build(&mut CryptoDxState::test_default()).unwrap();
+ let encoder_copy = encoder.clone();
+
+ let builder = PacketBuilder::long(
+ encoder,
+ PacketType::Initial,
+ Version::default(),
+ &ConnectionId::from(SERVER_CID),
+ &ConnectionId::from(SERVER_CID),
+ );
+ assert_eq!(builder.remaining(), 0);
+ assert_eq!(builder.abort(), encoder_copy);
+ }
+
+ const SAMPLE_RETRY_V2: &[u8] = &[
+ 0xcf, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc8, 0x64, 0x6c, 0xe8, 0xbf, 0xe3, 0x39, 0x52, 0xd9, 0x55,
+ 0x54, 0x36, 0x65, 0xdc, 0xc7, 0xb6,
+ ];
+
+ const SAMPLE_RETRY_V1: &[u8] = &[
+ 0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58,
+ 0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba,
+ ];
+
+ const SAMPLE_RETRY_29: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a,
+ 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49,
+ ];
+
+ const SAMPLE_RETRY_30: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94,
+ 0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61,
+ ];
+
+ const SAMPLE_RETRY_31: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1,
+ 0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86,
+ ];
+
+ const SAMPLE_RETRY_32: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e,
+ 0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85,
+ ];
+
+ const RETRY_TOKEN: &[u8] = b"token";
+
+ fn build_retry_single(version: Version, sample_retry: &[u8]) {
+ fixture_init();
+ let retry =
+ PacketBuilder::retry(version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap();
+
+ let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap();
+ assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
+ assert!(remainder.is_empty());
+
+ // The builder adds randomness, which makes expectations hard.
+ // So only do a full check when that randomness matches up.
+ if retry[0] == sample_retry[0] {
+ assert_eq!(&retry, &sample_retry);
+ } else {
+ // Otherwise, just check that the header is OK.
+ assert_eq!(
+ retry[0] & 0xf0,
+ 0xc0 | (PacketType::Retry.to_byte(version) << 4)
+ );
+ let header_range = 1..retry.len() - 16;
+ assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]);
+ }
+ }
+
+ #[test]
+ fn build_retry_v2() {
+ build_retry_single(Version::Version2, SAMPLE_RETRY_V2);
+ }
+
+ #[test]
+ fn build_retry_v1() {
+ build_retry_single(Version::Version1, SAMPLE_RETRY_V1);
+ }
+
+ #[test]
+ fn build_retry_29() {
+ build_retry_single(Version::Draft29, SAMPLE_RETRY_29);
+ }
+
+ #[test]
+ fn build_retry_30() {
+ build_retry_single(Version::Draft30, SAMPLE_RETRY_30);
+ }
+
+ #[test]
+ fn build_retry_31() {
+ build_retry_single(Version::Draft31, SAMPLE_RETRY_31);
+ }
+
+ #[test]
+ fn build_retry_32() {
+ build_retry_single(Version::Draft32, SAMPLE_RETRY_32);
+ }
+
+ #[test]
+ fn build_retry_multiple() {
+ // Run the build_retry test a few times.
+ // Odds are approximately 1 in 8 that the full comparison doesn't happen
+ // for a given version.
+ for _ in 0..32 {
+ build_retry_v2();
+ build_retry_v1();
+ build_retry_29();
+ build_retry_30();
+ build_retry_31();
+ build_retry_32();
+ }
+ }
+
+ fn decode_retry(version: Version, sample_retry: &[u8]) {
+ fixture_init();
+ let (packet, remainder) =
+ PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap();
+ assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
+ assert_eq!(Some(version), packet.version());
+ assert!(packet.dcid().is_empty());
+ assert_eq!(&packet.scid()[..], SERVER_CID);
+ assert_eq!(packet.token(), RETRY_TOKEN);
+ assert!(remainder.is_empty());
+ }
+
+ #[test]
+ fn decode_retry_v2() {
+ decode_retry(Version::Version2, SAMPLE_RETRY_V2);
+ }
+
+ #[test]
+ fn decode_retry_v1() {
+ decode_retry(Version::Version1, SAMPLE_RETRY_V1);
+ }
+
+ #[test]
+ fn decode_retry_29() {
+ decode_retry(Version::Draft29, SAMPLE_RETRY_29);
+ }
+
+ #[test]
+ fn decode_retry_30() {
+ decode_retry(Version::Draft30, SAMPLE_RETRY_30);
+ }
+
+ #[test]
+ fn decode_retry_31() {
+ decode_retry(Version::Draft31, SAMPLE_RETRY_31);
+ }
+
+ #[test]
+ fn decode_retry_32() {
+ decode_retry(Version::Draft32, SAMPLE_RETRY_32);
+ }
+
+ /// Check some packets that are clearly not valid Retry packets.
+ #[test]
+ fn invalid_retry() {
+ fixture_init();
+ let cid_mgr = RandomConnectionIdGenerator::new(5);
+ let odcid = ConnectionId::from(CLIENT_CID);
+
+ assert!(PublicPacket::decode(&[], &cid_mgr).is_err());
+
+ let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_V1, &cid_mgr).unwrap();
+ assert!(remainder.is_empty());
+ assert!(packet.is_valid_retry(&odcid));
+
+ let mut damaged_retry = SAMPLE_RETRY_V1.to_vec();
+ let last = damaged_retry.len() - 1;
+ damaged_retry[last] ^= 66;
+ let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
+ assert!(remainder.is_empty());
+ assert!(!packet.is_valid_retry(&odcid));
+
+ damaged_retry.truncate(last);
+ let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
+ assert!(remainder.is_empty());
+ assert!(!packet.is_valid_retry(&odcid));
+
+ // An invalid token should be rejected sooner.
+ damaged_retry.truncate(last - 4);
+ assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
+
+ damaged_retry.truncate(last - 1);
+ assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
+ }
+
+ const SAMPLE_VN: &[u8] = &[
+ 0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08,
+ 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x00, 0x00,
+ 0x01, 0xff, 0x00, 0x00, 0x20, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00,
+ 0x00, 0x1d, 0x0a, 0x0a, 0x0a, 0x0a,
+ ];
+
+ #[test]
+ fn build_vn() {
+ fixture_init();
+ let mut vn =
+ PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID, 0x0a0a0a0a, &Version::all());
+ // Erase randomness from greasing...
+ assert_eq!(vn.len(), SAMPLE_VN.len());
+ vn[0] &= 0x80;
+ for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) {
+ *v &= 0x0f;
+ }
+ assert_eq!(&vn, &SAMPLE_VN);
+ }
+
+ #[test]
+ fn vn_do_not_repeat_client_grease() {
+ fixture_init();
+ let vn =
+ PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID, 0x0a0a0a0a, &Version::all());
+ assert_ne!(&vn[SAMPLE_VN.len() - 4..], &[0x0a, 0x0a, 0x0a, 0x0a]);
+ }
+
+ #[test]
+ fn parse_vn() {
+ let (packet, remainder) =
+ PublicPacket::decode(SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap();
+ assert!(remainder.is_empty());
+ assert_eq!(&packet.dcid[..], SERVER_CID);
+ assert!(packet.scid.is_some());
+ assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID);
+ }
+
+ /// A Version Negotiation packet can have a long connection ID.
+ #[test]
+ fn parse_vn_big_cid() {
+ const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1];
+ const BIG_SCID: &[u8] = &[0xee; 255];
+
+ let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]);
+ enc.encode_vec(1, BIG_DCID);
+ enc.encode_vec(1, BIG_SCID);
+ enc.encode_uint(4, 0x1a2a_3a4a_u64);
+ enc.encode_uint(4, Version::default().wire_version());
+ enc.encode_uint(4, 0x5a6a_7a8a_u64);
+
+ let (packet, remainder) =
+ PublicPacket::decode(enc.as_ref(), &EmptyConnectionIdGenerator::default()).unwrap();
+ assert!(remainder.is_empty());
+ assert_eq!(&packet.dcid[..], BIG_DCID);
+ assert!(packet.scid.is_some());
+ assert_eq!(&packet.scid.unwrap()[..], BIG_SCID);
+ }
+
+ #[test]
+ fn decode_pn() {
+ // When the expected value is low, the value doesn't go negative.
+ assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0);
+ assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff);
+ assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0);
+ assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0);
+ assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100);
+ assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2);
+ assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff);
+ assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe);
+
+ // This is invalid by spec, as we are expected to check for overflow around 2^62-1,
+ // but we don't need to worry about overflow
+ // and hitting this is basically impossible in practice.
+ assert_eq!(
+ PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4),
+ 0x4000_0000_0000_0002
+ );
+ }
+
+ #[test]
+ fn chacha20_sample() {
+ const PACKET: &[u8] = &[
+ 0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
+ 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
+ ];
+ fixture_init();
+ let (packet, slice) =
+ PublicPacket::decode(PACKET, &EmptyConnectionIdGenerator::default()).unwrap();
+ assert!(slice.is_empty());
+ let decrypted = packet
+ .decrypt(&mut CryptoStates::test_chacha(), now())
+ .unwrap();
+ assert_eq!(decrypted.packet_type(), PacketType::Short);
+ assert_eq!(decrypted.pn(), 654_360_564);
+ assert_eq!(&decrypted[..], &[0x01]);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/packet/retry.rs b/third_party/rust/neqo-transport/src/packet/retry.rs
new file mode 100644
index 0000000000..004e9de6e7
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/packet/retry.rs
@@ -0,0 +1,59 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![deny(clippy::pedantic)]
+
+use std::cell::RefCell;
+
+use neqo_common::qerror;
+use neqo_crypto::{hkdf, Aead, TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3};
+
+use crate::{version::Version, Error, Res};
+
+/// The AEAD used for Retry is fixed, so use thread local storage.
+fn make_aead(version: Version) -> Aead {
+ #[cfg(debug_assertions)]
+ ::neqo_crypto::assert_initialized();
+
+ let secret = hkdf::import_key(TLS_VERSION_1_3, version.retry_secret()).unwrap();
+ Aead::new(
+ false,
+ TLS_VERSION_1_3,
+ TLS_AES_128_GCM_SHA256,
+ &secret,
+ version.label_prefix(),
+ )
+ .unwrap()
+}
+thread_local!(static RETRY_AEAD_29: RefCell<Aead> = RefCell::new(make_aead(Version::Draft29)));
+thread_local!(static RETRY_AEAD_V1: RefCell<Aead> = RefCell::new(make_aead(Version::Version1)));
+thread_local!(static RETRY_AEAD_V2: RefCell<Aead> = RefCell::new(make_aead(Version::Version2)));
+
+/// Run a function with the appropriate Retry AEAD.
+pub fn use_aead<F, T>(version: Version, f: F) -> Res<T>
+where
+ F: FnOnce(&Aead) -> Res<T>,
+{
+ match version {
+ Version::Version2 => &RETRY_AEAD_V2,
+ Version::Version1 => &RETRY_AEAD_V1,
+ Version::Draft29 | Version::Draft30 | Version::Draft31 | Version::Draft32 => &RETRY_AEAD_29,
+ }
+ .try_with(|aead| f(&aead.borrow()))
+ .map_err(|e| {
+ qerror!("Unable to access Retry AEAD: {:?}", e);
+ Error::InternalError
+ })?
+}
+
+/// Determine how large the expansion is for a given key.
+pub fn expansion(version: Version) -> usize {
+ if let Ok(ex) = use_aead(version, |aead| Ok(aead.expansion())) {
+ ex
+ } else {
+ panic!("Unable to access Retry AEAD")
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/path.rs b/third_party/rust/neqo-transport/src/path.rs
new file mode 100644
index 0000000000..d6920c8d94
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/path.rs
@@ -0,0 +1,1032 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![deny(clippy::pedantic)]
+#![allow(clippy::module_name_repetitions)]
+
+use std::{
+ cell::RefCell,
+ convert::TryFrom,
+ fmt::{self, Display},
+ mem,
+ net::{IpAddr, SocketAddr},
+ rc::Rc,
+ time::{Duration, Instant},
+};
+
+use neqo_common::{hex, qdebug, qinfo, qlog::NeqoQlog, qtrace, Datagram, Encoder, IpTos};
+use neqo_crypto::random;
+
+use crate::{
+ ackrate::{AckRate, PeerAckDelay},
+ cc::CongestionControlAlgorithm,
+ cid::{ConnectionId, ConnectionIdRef, ConnectionIdStore, RemoteConnectionIdEntry},
+ frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID},
+ packet::PacketBuilder,
+ recovery::RecoveryToken,
+ rtt::RttEstimate,
+ sender::PacketSender,
+ stats::FrameStats,
+ tracking::{PacketNumberSpace, SentPacket},
+ Stats,
+};
+
+/// This is the MTU that we assume when using IPv6.
+/// We use this size for Initial packets, so we don't need to worry about probing for support.
+/// If the path doesn't support this MTU, we will assume that it doesn't support QUIC.
+///
+/// This is a multiple of 16 greater than the largest possible short header (1 + 20 + 4).
+pub const PATH_MTU_V6: usize = 1337;
+/// The path MTU for IPv4 can be 20 bytes larger than for v6.
+pub const PATH_MTU_V4: usize = PATH_MTU_V6 + 20;
+/// The number of times that a path will be probed before it is considered failed.
+const MAX_PATH_PROBES: usize = 3;
+/// The maximum number of paths that `Paths` will track.
+const MAX_PATHS: usize = 15;
+
+pub type PathRef = Rc<RefCell<Path>>;
+
+/// A collection for network paths.
+/// This holds a collection of paths that have been used for sending or
+/// receiving, plus an additional "temporary" path that is held only while
+/// processing a packet.
+/// This structure limits its storage and will forget about paths if it
+/// is exposed to too many paths.
+#[derive(Debug, Default)]
+pub struct Paths {
+ /// All of the paths. All of these paths will be permanent.
+ #[allow(unknown_lints)] // available with Rust v1.75
+ #[allow(clippy::struct_field_names)]
+ paths: Vec<PathRef>,
+ /// This is the primary path. This will only be `None` initially, so
+ /// care needs to be taken regarding that only during the handshake.
+ /// This path will also be in `paths`.
+ primary: Option<PathRef>,
+
+ /// The path that we would prefer to migrate to.
+ migration_target: Option<PathRef>,
+
+ /// Connection IDs that need to be retired.
+ to_retire: Vec<u64>,
+
+ /// QLog handler.
+ qlog: NeqoQlog,
+}
+
+impl Paths {
+ /// Find the path for the given addresses.
+ /// This might be a temporary path.
+ pub fn find_path(
+ &self,
+ local: SocketAddr,
+ remote: SocketAddr,
+ cc: CongestionControlAlgorithm,
+ pacing: bool,
+ now: Instant,
+ ) -> PathRef {
+ self.paths
+ .iter()
+ .find_map(|p| {
+ if p.borrow().received_on(local, remote, false) {
+ Some(Rc::clone(p))
+ } else {
+ None
+ }
+ })
+ .unwrap_or_else(|| {
+ let mut p = Path::temporary(local, remote, cc, pacing, self.qlog.clone(), now);
+ if let Some(primary) = self.primary.as_ref() {
+ p.prime_rtt(primary.borrow().rtt());
+ }
+ Rc::new(RefCell::new(p))
+ })
+ }
+
+ /// Find the path, but allow for rebinding. That matches the pair of addresses
+ /// to paths that match the remote address only based on IP addres, not port.
+ /// We use this when the other side migrates to skip address validation and
+ /// creating a new path.
+ pub fn find_path_with_rebinding(
+ &self,
+ local: SocketAddr,
+ remote: SocketAddr,
+ cc: CongestionControlAlgorithm,
+ pacing: bool,
+ now: Instant,
+ ) -> PathRef {
+ self.paths
+ .iter()
+ .find_map(|p| {
+ if p.borrow().received_on(local, remote, false) {
+ Some(Rc::clone(p))
+ } else {
+ None
+ }
+ })
+ .or_else(|| {
+ self.paths.iter().find_map(|p| {
+ if p.borrow().received_on(local, remote, true) {
+ Some(Rc::clone(p))
+ } else {
+ None
+ }
+ })
+ })
+ .unwrap_or_else(|| {
+ Rc::new(RefCell::new(Path::temporary(
+ local,
+ remote,
+ cc,
+ pacing,
+ self.qlog.clone(),
+ now,
+ )))
+ })
+ }
+
+ /// Get a reference to the primary path. This will assert if there is no primary
+ /// path, which happens at a server prior to receiving a valid Initial packet
+ /// from a client. So be careful using this method.
+ pub fn primary(&self) -> PathRef {
+ self.primary_fallible().unwrap()
+ }
+
+ /// Get a reference to the primary path. Use this prior to handshake completion.
+ pub fn primary_fallible(&self) -> Option<PathRef> {
+ self.primary.as_ref().map(Rc::clone)
+ }
+
+ /// Returns true if the path is not permanent.
+ pub fn is_temporary(&self, path: &PathRef) -> bool {
+ // Ask the path first, which is simpler.
+ path.borrow().is_temporary() || !self.paths.iter().any(|p| Rc::ptr_eq(p, path))
+ }
+
+ fn retire(to_retire: &mut Vec<u64>, retired: &PathRef) {
+ let seqno = retired
+ .borrow()
+ .remote_cid
+ .as_ref()
+ .unwrap()
+ .sequence_number();
+ to_retire.push(seqno);
+ }
+
+ /// Adopt a temporary path as permanent.
+ /// The first path that is made permanent is made primary.
+ pub fn make_permanent(
+ &mut self,
+ path: &PathRef,
+ local_cid: Option<ConnectionId>,
+ remote_cid: RemoteConnectionIdEntry,
+ ) {
+ debug_assert!(self.is_temporary(path));
+
+ // Make sure not to track too many paths.
+ // This protects index 0, which contains the primary path.
+ if self.paths.len() >= MAX_PATHS {
+ debug_assert_eq!(self.paths.len(), MAX_PATHS);
+ let removed = self.paths.remove(1);
+ Self::retire(&mut self.to_retire, &removed);
+ if self
+ .migration_target
+ .as_ref()
+ .map_or(false, |target| Rc::ptr_eq(target, &removed))
+ {
+ qinfo!(
+ [path.borrow()],
+ "The migration target path had to be removed"
+ );
+ self.migration_target = None;
+ }
+ debug_assert_eq!(Rc::strong_count(&removed), 1);
+ }
+
+ qdebug!([path.borrow()], "Make permanent");
+ path.borrow_mut().make_permanent(local_cid, remote_cid);
+ self.paths.push(Rc::clone(path));
+ if self.primary.is_none() {
+ assert!(self.select_primary(path).is_none());
+ }
+ }
+
+ /// Select a path as the primary. Returns the old primary path.
+ /// Using the old path is only necessary if this change in path is a reaction
+ /// to a migration from a peer, in which case the old path needs to be probed.
+ #[must_use]
+ fn select_primary(&mut self, path: &PathRef) -> Option<PathRef> {
+ qinfo!([path.borrow()], "set as primary path");
+ let old_path = self.primary.replace(Rc::clone(path)).map(|old| {
+ old.borrow_mut().set_primary(false);
+ old
+ });
+
+ // Swap the primary path into slot 0, so that it is protected from eviction.
+ let idx = self
+ .paths
+ .iter()
+ .enumerate()
+ .find_map(|(i, p)| if Rc::ptr_eq(p, path) { Some(i) } else { None })
+ .expect("migration target should be permanent");
+ self.paths.swap(0, idx);
+
+ path.borrow_mut().set_primary(true);
+ old_path
+ }
+
+ /// Migrate to the identified path. If `force` is true, the path
+ /// is forcibly marked as valid and the path is used immediately.
+ /// Otherwise, migration will occur after probing succeeds.
+ /// The path is always probed and will be abandoned if probing fails.
+ /// Returns `true` if the path was migrated.
+ pub fn migrate(&mut self, path: &PathRef, force: bool, now: Instant) -> bool {
+ debug_assert!(!self.is_temporary(path));
+ if force || path.borrow().is_valid() {
+ path.borrow_mut().set_valid(now);
+ mem::drop(self.select_primary(path));
+ self.migration_target = None;
+ } else {
+ self.migration_target = Some(Rc::clone(path));
+ }
+ path.borrow_mut().probe();
+ self.migration_target.is_none()
+ }
+
+ /// Process elapsed time for active paths.
+ /// Returns an true if there are viable paths remaining after tidying up.
+ ///
+ /// TODO(mt) - the paths should own the RTT estimator, so they can find the PTO
+ /// for themselves.
+ pub fn process_timeout(&mut self, now: Instant, pto: Duration) -> bool {
+ let to_retire = &mut self.to_retire;
+ let mut primary_failed = false;
+ self.paths.retain(|p| {
+ if p.borrow_mut().process_timeout(now, pto) {
+ true
+ } else {
+ qdebug!([p.borrow()], "Retiring path");
+ if p.borrow().is_primary() {
+ primary_failed = true;
+ }
+ Self::retire(to_retire, p);
+ false
+ }
+ });
+
+ if primary_failed {
+ self.primary = None;
+ // Find a valid path to fall back to.
+ if let Some(fallback) = self
+ .paths
+ .iter()
+ .rev() // More recent paths are toward the end.
+ .find(|p| p.borrow().is_valid())
+ {
+ // Need a clone as `fallback` is borrowed from `self`.
+ let path = Rc::clone(fallback);
+ qinfo!([path.borrow()], "Failing over after primary path failed");
+ mem::drop(self.select_primary(&path));
+ true
+ } else {
+ false
+ }
+ } else {
+ true
+ }
+ }
+
+ /// Get when the next call to `process_timeout()` should be scheduled.
+ pub fn next_timeout(&self, pto: Duration) -> Option<Instant> {
+ self.paths
+ .iter()
+ .filter_map(|p| p.borrow().next_timeout(pto))
+ .min()
+ }
+
+ /// Set the identified path to be primary.
+ /// This panics if `make_permanent` hasn't been called.
+ pub fn handle_migration(&mut self, path: &PathRef, remote: SocketAddr, now: Instant) {
+ qtrace!([self.primary().borrow()], "handle_migration");
+ // The update here needs to match the checks in `Path::received_on`.
+ // Here, we update the remote port number to match the source port on the
+ // datagram that was received. This ensures that we send subsequent
+ // packets back to the right place.
+ path.borrow_mut().update_port(remote.port());
+
+ if path.borrow().is_primary() {
+ // Update when the path was last regarded as valid.
+ path.borrow_mut().update(now);
+ return;
+ }
+
+ if let Some(old_path) = self.select_primary(path) {
+ // Need to probe the old path if the peer migrates.
+ old_path.borrow_mut().probe();
+ // TODO(mt) - suppress probing if the path was valid within 3PTO.
+ }
+ }
+
+ /// Select a path to send on. This will select the first path that has
+ /// probes to send, then fall back to the primary path.
+ pub fn select_path(&self) -> Option<PathRef> {
+ self.paths
+ .iter()
+ .find_map(|p| {
+ if p.borrow().has_probe() {
+ Some(Rc::clone(p))
+ } else {
+ None
+ }
+ })
+ .or_else(|| self.primary.as_ref().map(Rc::clone))
+ }
+
+ /// A `PATH_RESPONSE` was received.
+ /// Returns `true` if migration occurred.
+ #[must_use]
+ pub fn path_response(&mut self, response: [u8; 8], now: Instant) -> bool {
+ // TODO(mt) consider recording an RTT measurement here as we don't train
+ // RTT for non-primary paths.
+ for p in &self.paths {
+ if p.borrow_mut().path_response(response, now) {
+ // The response was accepted. If this path is one we intend
+ // to migrate to, then migrate.
+ if self
+ .migration_target
+ .as_ref()
+ .map_or(false, |target| Rc::ptr_eq(target, p))
+ {
+ let primary = self.migration_target.take();
+ mem::drop(self.select_primary(&primary.unwrap()));
+ return true;
+ }
+ break;
+ }
+ }
+ false
+ }
+
+ /// Retire all of the connection IDs prior to the indicated sequence number.
+ /// Keep active paths if possible by pulling new connection IDs from the provided store.
+ /// One slightly non-obvious consequence of this is that if migration is being attempted
+ /// and the new path cannot obtain a new connection ID, the migration attempt will fail.
+ pub fn retire_cids(&mut self, retire_prior: u64, store: &mut ConnectionIdStore<[u8; 16]>) {
+ let to_retire = &mut self.to_retire;
+ let migration_target = &mut self.migration_target;
+
+ // First, tell the store to release any connection IDs that are too old.
+ let mut retired = store.retire_prior_to(retire_prior);
+ to_retire.append(&mut retired);
+
+ self.paths.retain(|p| {
+ let current = p.borrow().remote_cid.as_ref().unwrap().sequence_number();
+ if current < retire_prior {
+ to_retire.push(current);
+ let new_cid = store.next();
+ let has_replacement = new_cid.is_some();
+ // There must be a connection ID available for the primary path as we
+ // keep that path at the first index.
+ debug_assert!(!p.borrow().is_primary() || has_replacement);
+ p.borrow_mut().remote_cid = new_cid;
+ if !has_replacement
+ && migration_target
+ .as_ref()
+ .map_or(false, |target| Rc::ptr_eq(target, p))
+ {
+ qinfo!(
+ [p.borrow()],
+ "NEW_CONNECTION_ID with Retire Prior To forced migration to fail"
+ );
+ *migration_target = None;
+ }
+ has_replacement
+ } else {
+ true
+ }
+ });
+ }
+
+ /// Write out any `RETIRE_CONNECTION_ID` frames that are outstanding.
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ while let Some(seqno) = self.to_retire.pop() {
+ if builder.remaining() < 1 + Encoder::varint_len(seqno) {
+ self.to_retire.push(seqno);
+ break;
+ }
+ builder.encode_varint(FRAME_TYPE_RETIRE_CONNECTION_ID);
+ builder.encode_varint(seqno);
+ tokens.push(RecoveryToken::RetireConnectionId(seqno));
+ stats.retire_connection_id += 1;
+ }
+
+ // Write out any ACK_FREQUENCY frames.
+ self.primary()
+ .borrow_mut()
+ .write_cc_frames(builder, tokens, stats);
+ }
+
+ pub fn lost_retire_cid(&mut self, lost: u64) {
+ self.to_retire.push(lost);
+ }
+
+ pub fn acked_retire_cid(&mut self, acked: u64) {
+ self.to_retire.retain(|&seqno| seqno != acked);
+ }
+
+ pub fn lost_ack_frequency(&mut self, lost: &AckRate) {
+ self.primary().borrow_mut().lost_ack_frequency(lost);
+ }
+
+ pub fn acked_ack_frequency(&mut self, acked: &AckRate) {
+ self.primary().borrow_mut().acked_ack_frequency(acked);
+ }
+
+ /// Get an estimate of the RTT on the primary path.
+ #[cfg(test)]
+ pub fn rtt(&self) -> Duration {
+ // Rather than have this fail when there is no active path,
+ // make a new RTT esimate and interrogate that.
+ // That is more expensive, but it should be rare and breaking encapsulation
+ // is worse, especially as this is only used in tests.
+ self.primary_fallible()
+ .map_or(RttEstimate::default().estimate(), |p| {
+ p.borrow().rtt().estimate()
+ })
+ }
+
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ for p in &mut self.paths {
+ p.borrow_mut().set_qlog(qlog.clone());
+ }
+ self.qlog = qlog;
+ }
+}
+
+/// The state of a path with respect to address validation.
+#[derive(Debug)]
+enum ProbeState {
+ /// The path was last valid at the indicated time.
+ Valid,
+ /// The path was previously valid, but a new probe is needed.
+ ProbeNeeded { probe_count: usize },
+ /// The path hasn't been validated, but a probe has been sent.
+ Probing {
+ /// The number of probes that have been sent.
+ probe_count: usize,
+ /// The probe that was last sent.
+ data: [u8; 8],
+ /// Whether the probe was sent in a datagram padded to the path MTU.
+ mtu: bool,
+ /// When the probe was sent.
+ sent: Instant,
+ },
+ /// Validation failed the last time it was attempted.
+ Failed,
+}
+
+impl ProbeState {
+ /// Determine whether the current state requires probing.
+ fn probe_needed(&self) -> bool {
+ matches!(self, Self::ProbeNeeded { .. })
+ }
+}
+
+/// A network path.
+///
+/// Paths are used a little bit strangely by connections:
+/// they need to encapsulate all the state for a path (which
+/// is normal), but that information is not propagated to the
+/// `Paths` instance that holds them. This is because the packet
+/// processing where changes occur can't hold a reference to the
+/// `Paths` instance that owns the `Path`. Any changes to the
+/// path are communicated to `Paths` afterwards.
+#[derive(Debug)]
+pub struct Path {
+ /// A local socket address.
+ local: SocketAddr,
+ /// A remote socket address.
+ remote: SocketAddr,
+ /// The connection IDs that we use when sending on this path.
+ /// This is only needed during the handshake.
+ local_cid: Option<ConnectionId>,
+ /// The current connection ID that we are using and its details.
+ remote_cid: Option<RemoteConnectionIdEntry>,
+
+ /// Whether this is the primary path.
+ primary: bool,
+ /// Whether the current path is considered valid.
+ state: ProbeState,
+ /// For a path that is not validated, this is `None`. For a validated
+ /// path, the time that the path was last valid.
+ validated: Option<Instant>,
+ /// A path challenge was received and PATH_RESPONSE has not been sent.
+ challenge: Option<[u8; 8]>,
+
+ /// The round trip time estimate for this path.
+ rtt: RttEstimate,
+ /// A packet sender for the path, which includes congestion control and a pacer.
+ sender: PacketSender,
+ /// The DSCP/ECN marking to use for outgoing packets on this path.
+ tos: IpTos,
+ /// The IP TTL to use for outgoing packets on this path.
+ ttl: u8,
+
+ /// The number of bytes received on this path.
+ /// Note that this value might saturate on a long-lived connection,
+ /// but we only use it before the path is validated.
+ received_bytes: usize,
+ /// The number of bytes sent on this path.
+ sent_bytes: usize,
+
+ /// For logging of events.
+ qlog: NeqoQlog,
+}
+
+impl Path {
+ /// Create a path from addresses and a remote connection ID.
+ /// This is used for migration and for new datagrams.
+ pub fn temporary(
+ local: SocketAddr,
+ remote: SocketAddr,
+ cc: CongestionControlAlgorithm,
+ pacing: bool,
+ qlog: NeqoQlog,
+ now: Instant,
+ ) -> Self {
+ let mut sender = PacketSender::new(cc, pacing, Self::mtu_by_addr(remote.ip()), now);
+ sender.set_qlog(qlog.clone());
+ Self {
+ local,
+ remote,
+ local_cid: None,
+ remote_cid: None,
+ primary: false,
+ state: ProbeState::ProbeNeeded { probe_count: 0 },
+ validated: None,
+ challenge: None,
+ rtt: RttEstimate::default(),
+ sender,
+ tos: IpTos::default(), // TODO: Default to Ect0 when ECN is supported.
+ ttl: 64, // This is the default TTL on many OSes.
+ received_bytes: 0,
+ sent_bytes: 0,
+ qlog,
+ }
+ }
+
+ /// Whether this path is the primary or current path for the connection.
+ pub fn is_primary(&self) -> bool {
+ self.primary
+ }
+
+ /// Whether this path is a temporary one.
+ pub fn is_temporary(&self) -> bool {
+ self.remote_cid.is_none()
+ }
+
+ /// By adding a remote connection ID, we make the path permanent
+ /// and one that we will later send packets on.
+ /// If `local_cid` is `None`, the existing value will be kept.
+ pub(crate) fn make_permanent(
+ &mut self,
+ local_cid: Option<ConnectionId>,
+ remote_cid: RemoteConnectionIdEntry,
+ ) {
+ if self.local_cid.is_none() {
+ self.local_cid = local_cid;
+ }
+ self.remote_cid.replace(remote_cid);
+ }
+
+ /// Determine if this path was the one that the provided datagram was received on.
+ /// This uses the full local socket address, but ignores the port number on the peer
+ /// if `flexible` is true, allowing for NAT rebinding that retains the same IP.
+ fn received_on(&self, local: SocketAddr, remote: SocketAddr, flexible: bool) -> bool {
+ self.local == local
+ && self.remote.ip() == remote.ip()
+ && (flexible || self.remote.port() == remote.port())
+ }
+
+ /// Update the remote port number. Any flexibility we allow in `received_on`
+ /// need to be adjusted at this point.
+ fn update_port(&mut self, port: u16) {
+ self.remote.set_port(port);
+ }
+
+ /// Set whether this path is primary.
+ pub(crate) fn set_primary(&mut self, primary: bool) {
+ qtrace!([self], "Make primary {}", primary);
+ debug_assert!(self.remote_cid.is_some());
+ self.primary = primary;
+ if !primary {
+ self.sender.discard_in_flight();
+ }
+ }
+
+ /// Set the current path as valid. This updates the time that the path was
+ /// last validated and cancels any path validation.
+ pub fn set_valid(&mut self, now: Instant) {
+ qdebug!([self], "Path validated {:?}", now);
+ self.state = ProbeState::Valid;
+ self.validated = Some(now);
+ }
+
+ /// Update the last use of this path, if it is valid.
+ /// This will keep the path active slightly longer.
+ pub fn update(&mut self, now: Instant) {
+ if self.validated.is_some() {
+ self.validated = Some(now);
+ }
+ }
+
+ fn mtu_by_addr(addr: IpAddr) -> usize {
+ match addr {
+ IpAddr::V4(_) => PATH_MTU_V4,
+ IpAddr::V6(_) => PATH_MTU_V6,
+ }
+ }
+
+ /// Get the path MTU. This is currently fixed based on IP version.
+ pub fn mtu(&self) -> usize {
+ Self::mtu_by_addr(self.remote.ip())
+ }
+
+ /// Get the first local connection ID.
+ /// Only do this for the primary path during the handshake.
+ pub fn local_cid(&self) -> &ConnectionId {
+ self.local_cid.as_ref().unwrap()
+ }
+
+ /// Set the remote connection ID based on the peer's choice.
+ /// This is only valid during the handshake.
+ pub fn set_remote_cid(&mut self, cid: ConnectionIdRef) {
+ self.remote_cid
+ .as_mut()
+ .unwrap()
+ .update_cid(ConnectionId::from(cid));
+ }
+
+ /// Access the remote connection ID.
+ pub fn remote_cid(&self) -> &ConnectionId {
+ self.remote_cid.as_ref().unwrap().connection_id()
+ }
+
+ /// Set the stateless reset token for the connection ID that is currently in use.
+ /// Panics if the sequence number is non-zero as this is only necessary during
+ /// the handshake; all other connection IDs are initialized with a token.
+ pub fn set_reset_token(&mut self, token: [u8; 16]) {
+ self.remote_cid
+ .as_mut()
+ .unwrap()
+ .set_stateless_reset_token(token);
+ }
+
+ /// Determine if the provided token is a stateless reset token.
+ pub fn is_stateless_reset(&self, token: &[u8; 16]) -> bool {
+ self.remote_cid
+ .as_ref()
+ .map_or(false, |rcid| rcid.is_stateless_reset(token))
+ }
+
+ /// Make a datagram.
+ pub fn datagram<V: Into<Vec<u8>>>(&self, payload: V) -> Datagram {
+ Datagram::new(self.local, self.remote, self.tos, Some(self.ttl), payload)
+ }
+
+ /// Get local address as `SocketAddr`
+ pub fn local_address(&self) -> SocketAddr {
+ self.local
+ }
+
+ /// Get remote address as `SocketAddr`
+ pub fn remote_address(&self) -> SocketAddr {
+ self.remote
+ }
+
+ /// Whether the path has been validated.
+ pub fn is_valid(&self) -> bool {
+ self.validated.is_some()
+ }
+
+ /// Handle a `PATH_RESPONSE` frame. Returns true if the response was accepted.
+ pub fn path_response(&mut self, response: [u8; 8], now: Instant) -> bool {
+ if let ProbeState::Probing { data, mtu, .. } = &mut self.state {
+ if response == *data {
+ let need_full_probe = !*mtu;
+ self.set_valid(now);
+ if need_full_probe {
+ qdebug!([self], "Sub-MTU probe successful, reset probe count");
+ self.probe();
+ }
+ true
+ } else {
+ false
+ }
+ } else {
+ false
+ }
+ }
+
+ /// The path has been challenged. This generates a response.
+ /// This only generates a single response at a time.
+ pub fn challenged(&mut self, challenge: [u8; 8]) {
+ self.challenge = Some(challenge.to_owned());
+ }
+
+ /// At the next opportunity, send a probe.
+ /// If the probe count has been exhausted already, marks the path as failed.
+ fn probe(&mut self) {
+ let probe_count = match &self.state {
+ ProbeState::Probing { probe_count, .. } => *probe_count + 1,
+ ProbeState::ProbeNeeded { probe_count, .. } => *probe_count,
+ _ => 0,
+ };
+ self.state = if probe_count >= MAX_PATH_PROBES {
+ qinfo!([self], "Probing failed");
+ ProbeState::Failed
+ } else {
+ qdebug!([self], "Initiating probe");
+ ProbeState::ProbeNeeded { probe_count }
+ };
+ }
+
+ /// Returns true if this path have any probing frames to send.
+ pub fn has_probe(&self) -> bool {
+ self.challenge.is_some() || self.state.probe_needed()
+ }
+
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ stats: &mut FrameStats,
+ mtu: bool, // Whether the packet we're writing into will be a full MTU.
+ now: Instant,
+ ) -> bool {
+ if builder.remaining() < 9 {
+ return false;
+ }
+
+ // Send PATH_RESPONSE.
+ let resp_sent = if let Some(challenge) = self.challenge.take() {
+ qtrace!([self], "Responding to path challenge {}", hex(challenge));
+ builder.encode_varint(FRAME_TYPE_PATH_RESPONSE);
+ builder.encode(&challenge[..]);
+
+ // These frames are not retransmitted in the usual fashion.
+ // There is no token, therefore we need to count `all` specially.
+ stats.path_response += 1;
+ stats.all += 1;
+
+ if builder.remaining() < 9 {
+ return true;
+ }
+ true
+ } else {
+ false
+ };
+
+ // Send PATH_CHALLENGE.
+ if let ProbeState::ProbeNeeded { probe_count } = self.state {
+ qtrace!([self], "Initiating path challenge {}", probe_count);
+ let data = <[u8; 8]>::try_from(&random(8)[..]).unwrap();
+ builder.encode_varint(FRAME_TYPE_PATH_CHALLENGE);
+ builder.encode(&data);
+
+ // As above, no recovery token.
+ stats.path_challenge += 1;
+ stats.all += 1;
+
+ self.state = ProbeState::Probing {
+ probe_count,
+ data,
+ mtu,
+ sent: now,
+ };
+ true
+ } else {
+ resp_sent
+ }
+ }
+
+ /// Write `ACK_FREQUENCY` frames.
+ pub fn write_cc_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ self.rtt.write_frames(builder, tokens, stats);
+ }
+
+ pub fn lost_ack_frequency(&mut self, lost: &AckRate) {
+ self.rtt.frame_lost(lost);
+ }
+
+ pub fn acked_ack_frequency(&mut self, acked: &AckRate) {
+ self.rtt.frame_acked(acked);
+ }
+
+ /// Process a timer for this path.
+ /// This returns true if the path is viable and can be kept alive.
+ pub fn process_timeout(&mut self, now: Instant, pto: Duration) -> bool {
+ if let ProbeState::Probing { sent, .. } = &self.state {
+ if now >= *sent + pto {
+ self.probe();
+ }
+ }
+ if let ProbeState::Failed = self.state {
+ // Retire failed paths immediately.
+ false
+ } else if self.primary {
+ // Keep valid primary paths otherwise.
+ true
+ } else if let ProbeState::Valid = self.state {
+ // Retire validated, non-primary paths.
+ // Allow more than `MAX_PATH_PROBES` times the PTO so that an old
+ // path remains around until after a previous path fails.
+ let count = u32::try_from(MAX_PATH_PROBES + 1).unwrap();
+ self.validated.unwrap() + (pto * count) > now
+ } else {
+ // Keep paths that are being actively probed.
+ true
+ }
+ }
+
+ /// Return the next time that this path needs servicing.
+ /// This only considers retransmissions of probes, not cleanup of the path.
+ /// If there is no other activity, then there is no real need to schedule a
+ /// timer to cleanup old paths.
+ pub fn next_timeout(&self, pto: Duration) -> Option<Instant> {
+ if let ProbeState::Probing { sent, .. } = &self.state {
+ Some(*sent + pto)
+ } else {
+ None
+ }
+ }
+
+ /// Get the RTT estimator for this path.
+ pub fn rtt(&self) -> &RttEstimate {
+ &self.rtt
+ }
+
+ /// Mutably borrow the RTT estimator for this path.
+ pub fn rtt_mut(&mut self) -> &mut RttEstimate {
+ &mut self.rtt
+ }
+
+ /// Read-only access to the owned sender.
+ pub fn sender(&self) -> &PacketSender {
+ &self.sender
+ }
+
+ /// Pass on RTT configuration: the maximum acknowledgment delay of the peer,
+ /// and maybe the minimum delay.
+ pub fn set_ack_delay(
+ &mut self,
+ max_ack_delay: Duration,
+ min_ack_delay: Option<Duration>,
+ ack_ratio: u8,
+ ) {
+ let ack_delay = min_ack_delay.map_or_else(
+ || PeerAckDelay::fixed(max_ack_delay),
+ |m| {
+ PeerAckDelay::flexible(
+ max_ack_delay,
+ m,
+ ack_ratio,
+ self.sender.cwnd(),
+ self.mtu(),
+ self.rtt.estimate(),
+ )
+ },
+ );
+ self.rtt.set_ack_delay(ack_delay);
+ }
+
+ /// Initialize the RTT for the path based on an existing estimate.
+ pub fn prime_rtt(&mut self, rtt: &RttEstimate) {
+ self.rtt.prime_rtt(rtt);
+ }
+
+ /// Record received bytes for the path.
+ pub fn add_received(&mut self, count: usize) {
+ self.received_bytes = self.received_bytes.saturating_add(count);
+ }
+
+ /// Record sent bytes for the path.
+ pub fn add_sent(&mut self, count: usize) {
+ self.sent_bytes = self.sent_bytes.saturating_add(count);
+ }
+
+ /// Record a packet as having been sent on this path.
+ pub fn packet_sent(&mut self, sent: &mut SentPacket) {
+ if !self.is_primary() {
+ sent.clear_primary_path();
+ }
+ self.sender.on_packet_sent(sent, self.rtt.estimate());
+ }
+
+ /// Discard a packet that previously might have been in-flight.
+ pub fn discard_packet(&mut self, sent: &SentPacket, now: Instant, stats: &mut Stats) {
+ if self.rtt.first_sample_time().is_none() {
+ // When discarding a packet there might not be a good RTT estimate.
+ // But discards only occur after receiving something, so that means
+ // that there is some RTT information, which is better than nothing.
+ // Two cases: 1. at the client when handling a Retry and
+ // 2. at the server when disposing the Initial packet number space.
+ qinfo!(
+ [self],
+ "discarding a packet without an RTT estimate; guessing RTT={:?}",
+ now - sent.time_sent
+ );
+ stats.rtt_init_guess = true;
+ self.rtt.update(
+ &mut self.qlog,
+ now - sent.time_sent,
+ Duration::new(0, 0),
+ false,
+ now,
+ );
+ }
+
+ self.sender.discard(sent);
+ }
+
+ /// Record packets as acknowledged with the sender.
+ pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket], now: Instant) {
+ debug_assert!(self.is_primary());
+ self.sender.on_packets_acked(acked_pkts, &self.rtt, now);
+ }
+
+ /// Record packets as lost with the sender.
+ pub fn on_packets_lost(
+ &mut self,
+ prev_largest_acked_sent: Option<Instant>,
+ space: PacketNumberSpace,
+ lost_packets: &[SentPacket],
+ ) {
+ debug_assert!(self.is_primary());
+ let cwnd_reduced = self.sender.on_packets_lost(
+ self.rtt.first_sample_time(),
+ prev_largest_acked_sent,
+ self.rtt.pto(space), // Important: the base PTO, not adjusted.
+ lost_packets,
+ );
+ if cwnd_reduced {
+ self.rtt.update_ack_delay(self.sender.cwnd(), self.mtu());
+ }
+ }
+
+ /// Get the number of bytes that can be written to this path.
+ pub fn amplification_limit(&self) -> usize {
+ if matches!(self.state, ProbeState::Failed) {
+ 0
+ } else if self.is_valid() {
+ usize::MAX
+ } else {
+ self.received_bytes
+ .checked_mul(3)
+ .map_or(usize::MAX, |limit| {
+ let budget = if limit == 0 {
+ // If we have received absolutely nothing thus far, then this endpoint
+ // is the one initiating communication on this path. Allow enough space for
+ // probing.
+ self.mtu() * 5
+ } else {
+ limit
+ };
+ budget.saturating_sub(self.sent_bytes)
+ })
+ }
+ }
+
+ /// Update the `NeqoQLog` instance.
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.sender.set_qlog(qlog);
+ }
+}
+
+impl Display for Path {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ if self.is_primary() {
+ write!(f, "pri-")?; // primary
+ }
+ if !self.is_valid() {
+ write!(f, "unv-")?; // unvalidated
+ }
+ write!(f, "path")?;
+ if let Some(entry) = self.remote_cid.as_ref() {
+ write!(f, ":{}", entry.connection_id())?;
+ }
+ write!(f, " {}->{}", self.local, self.remote)?;
+ Ok(())
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/qlog.rs b/third_party/rust/neqo-transport/src/qlog.rs
new file mode 100644
index 0000000000..434395fd23
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/qlog.rs
@@ -0,0 +1,563 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Functions that handle capturing QLOG traces.
+
+use std::{
+ convert::TryFrom,
+ ops::{Deref, RangeInclusive},
+ string::String,
+ time::Duration,
+};
+
+use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder};
+use qlog::events::{
+ connectivity::{ConnectionStarted, ConnectionState, ConnectionStateUpdated},
+ quic::{
+ AckedRanges, ErrorSpace, MetricsUpdated, PacketDropped, PacketHeader, PacketLost,
+ PacketReceived, PacketSent, QuicFrame, StreamType, VersionInformation,
+ },
+ EventData, RawInfo,
+};
+use smallvec::SmallVec;
+
+use crate::{
+ connection::State,
+ frame::{CloseError, Frame},
+ packet::{DecryptedPacket, PacketNumber, PacketType, PublicPacket},
+ path::PathRef,
+ stream_id::StreamType as NeqoStreamType,
+ tparams::{self, TransportParametersHandler},
+ tracking::SentPacket,
+ version::{Version, VersionConfig, WireVersion},
+};
+
+pub fn connection_tparams_set(qlog: &mut NeqoQlog, tph: &TransportParametersHandler) {
+ qlog.add_event_data(|| {
+ let remote = tph.remote();
+ let ev_data = EventData::TransportParametersSet(
+ qlog::events::quic::TransportParametersSet {
+ owner: None,
+ resumption_allowed: None,
+ early_data_enabled: None,
+ tls_cipher: None,
+ aead_tag_length: None,
+ original_destination_connection_id: remote
+ .get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID)
+ .map(hex),
+ initial_source_connection_id: None,
+ retry_source_connection_id: None,
+ stateless_reset_token: remote.get_bytes(tparams::STATELESS_RESET_TOKEN).map(hex),
+ disable_active_migration: if remote.get_empty(tparams::DISABLE_MIGRATION) {
+ Some(true)
+ } else {
+ None
+ },
+ max_idle_timeout: Some(remote.get_integer(tparams::IDLE_TIMEOUT)),
+ max_udp_payload_size: Some(remote.get_integer(tparams::MAX_UDP_PAYLOAD_SIZE) as u32),
+ ack_delay_exponent: Some(remote.get_integer(tparams::ACK_DELAY_EXPONENT) as u16),
+ max_ack_delay: Some(remote.get_integer(tparams::MAX_ACK_DELAY) as u16),
+ active_connection_id_limit: Some(remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT) as u32),
+ initial_max_data: Some(remote.get_integer(tparams::INITIAL_MAX_DATA)),
+ initial_max_stream_data_bidi_local: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL)),
+ initial_max_stream_data_bidi_remote: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)),
+ initial_max_stream_data_uni: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)),
+ initial_max_streams_bidi: Some(remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)),
+ initial_max_streams_uni: Some(remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI)),
+ preferred_address: remote.get_preferred_address().and_then(|(paddr, cid)| {
+ Some(qlog::events::quic::PreferredAddress {
+ ip_v4: paddr.ipv4()?.ip().to_string(),
+ ip_v6: paddr.ipv6()?.ip().to_string(),
+ port_v4: paddr.ipv4()?.port(),
+ port_v6: paddr.ipv6()?.port(),
+ connection_id: cid.connection_id().to_string(),
+ stateless_reset_token: hex(cid.reset_token()),
+ })
+ }),
+ });
+
+ Some(ev_data)
+ });
+}
+
+pub fn server_connection_started(qlog: &mut NeqoQlog, path: &PathRef) {
+ connection_started(qlog, path);
+}
+
+pub fn client_connection_started(qlog: &mut NeqoQlog, path: &PathRef) {
+ connection_started(qlog, path);
+}
+
+fn connection_started(qlog: &mut NeqoQlog, path: &PathRef) {
+ qlog.add_event_data(|| {
+ let p = path.deref().borrow();
+ let ev_data = EventData::ConnectionStarted(ConnectionStarted {
+ ip_version: if p.local_address().ip().is_ipv4() {
+ Some("ipv4".into())
+ } else {
+ Some("ipv6".into())
+ },
+ src_ip: format!("{}", p.local_address().ip()),
+ dst_ip: format!("{}", p.remote_address().ip()),
+ protocol: Some("QUIC".into()),
+ src_port: p.local_address().port().into(),
+ dst_port: p.remote_address().port().into(),
+ src_cid: Some(format!("{}", p.local_cid())),
+ dst_cid: Some(format!("{}", p.remote_cid())),
+ });
+
+ Some(ev_data)
+ });
+}
+
+pub fn connection_state_updated(qlog: &mut NeqoQlog, new: &State) {
+ qlog.add_event_data(|| {
+ let ev_data = EventData::ConnectionStateUpdated(ConnectionStateUpdated {
+ old: None,
+ new: match new {
+ State::Init | State::WaitInitial => ConnectionState::Attempted,
+ State::WaitVersion | State::Handshaking => ConnectionState::HandshakeStarted,
+ State::Connected => ConnectionState::HandshakeCompleted,
+ State::Confirmed => ConnectionState::HandshakeConfirmed,
+ State::Closing { .. } => ConnectionState::Closing,
+ State::Draining { .. } => ConnectionState::Draining,
+ State::Closed { .. } => ConnectionState::Closed,
+ },
+ });
+
+ Some(ev_data)
+ });
+}
+
+pub fn client_version_information_initiated(qlog: &mut NeqoQlog, version_config: &VersionConfig) {
+ qlog.add_event_data(|| {
+ Some(EventData::VersionInformation(VersionInformation {
+ client_versions: Some(
+ version_config
+ .all()
+ .iter()
+ .map(|v| format!("{:02x}", v.wire_version()))
+ .collect(),
+ ),
+ server_versions: None,
+ chosen_version: Some(format!("{:02x}", version_config.initial().wire_version())),
+ }))
+ });
+}
+
+pub fn client_version_information_negotiated(
+ qlog: &mut NeqoQlog,
+ client: &[Version],
+ server: &[WireVersion],
+ chosen: Version,
+) {
+ qlog.add_event_data(|| {
+ Some(EventData::VersionInformation(VersionInformation {
+ client_versions: Some(
+ client
+ .iter()
+ .map(|v| format!("{:02x}", v.wire_version()))
+ .collect(),
+ ),
+ server_versions: Some(server.iter().map(|v| format!("{v:02x}")).collect()),
+ chosen_version: Some(format!("{:02x}", chosen.wire_version())),
+ }))
+ });
+}
+
+pub fn server_version_information_failed(
+ qlog: &mut NeqoQlog,
+ server: &[Version],
+ client: WireVersion,
+) {
+ qlog.add_event_data(|| {
+ Some(EventData::VersionInformation(VersionInformation {
+ client_versions: Some(vec![format!("{client:02x}")]),
+ server_versions: Some(
+ server
+ .iter()
+ .map(|v| format!("{:02x}", v.wire_version()))
+ .collect(),
+ ),
+ chosen_version: None,
+ }))
+ });
+}
+
+pub fn packet_sent(
+ qlog: &mut NeqoQlog,
+ pt: PacketType,
+ pn: PacketNumber,
+ plen: usize,
+ body: &[u8],
+) {
+ qlog.add_event_with_stream(|stream| {
+ let mut d = Decoder::from(body);
+ let header = PacketHeader::with_type(to_qlog_pkt_type(pt), Some(pn), None, None, None);
+ let raw = RawInfo {
+ length: Some(plen as u64),
+ payload_length: None,
+ data: None,
+ };
+
+ let mut frames = SmallVec::new();
+ while d.remaining() > 0 {
+ if let Ok(f) = Frame::decode(&mut d) {
+ frames.push(frame_to_qlogframe(&f))
+ } else {
+ qinfo!("qlog: invalid frame");
+ break;
+ }
+ }
+
+ let ev_data = EventData::PacketSent(PacketSent {
+ header,
+ frames: Some(frames),
+ is_coalesced: None,
+ retry_token: None,
+ stateless_reset_token: None,
+ supported_versions: None,
+ raw: Some(raw),
+ datagram_id: None,
+ send_at_time: None,
+ trigger: None,
+ });
+
+ stream.add_event_data_now(ev_data)
+ });
+}
+
+pub fn packet_dropped(qlog: &mut NeqoQlog, public_packet: &PublicPacket) {
+ qlog.add_event_data(|| {
+ let header = PacketHeader::with_type(
+ to_qlog_pkt_type(public_packet.packet_type()),
+ None,
+ None,
+ None,
+ None,
+ );
+ let raw = RawInfo {
+ length: Some(public_packet.len() as u64),
+ payload_length: None,
+ data: None,
+ };
+
+ let ev_data = EventData::PacketDropped(PacketDropped {
+ header: Some(header),
+ raw: Some(raw),
+ datagram_id: None,
+ details: None,
+ trigger: None,
+ });
+
+ Some(ev_data)
+ });
+}
+
+pub fn packets_lost(qlog: &mut NeqoQlog, pkts: &[SentPacket]) {
+ qlog.add_event_with_stream(|stream| {
+ for pkt in pkts {
+ let header =
+ PacketHeader::with_type(to_qlog_pkt_type(pkt.pt), Some(pkt.pn), None, None, None);
+
+ let ev_data = EventData::PacketLost(PacketLost {
+ header: Some(header),
+ trigger: None,
+ frames: None,
+ });
+
+ stream.add_event_data_now(ev_data)?;
+ }
+ Ok(())
+ });
+}
+
+pub fn packet_received(
+ qlog: &mut NeqoQlog,
+ public_packet: &PublicPacket,
+ payload: &DecryptedPacket,
+) {
+ qlog.add_event_with_stream(|stream| {
+ let mut d = Decoder::from(&payload[..]);
+
+ let header = PacketHeader::with_type(
+ to_qlog_pkt_type(public_packet.packet_type()),
+ Some(payload.pn()),
+ None,
+ None,
+ None,
+ );
+ let raw = RawInfo {
+ length: Some(public_packet.len() as u64),
+ payload_length: None,
+ data: None,
+ };
+
+ let mut frames = Vec::new();
+
+ while d.remaining() > 0 {
+ if let Ok(f) = Frame::decode(&mut d) {
+ frames.push(frame_to_qlogframe(&f))
+ } else {
+ qinfo!("qlog: invalid frame");
+ break;
+ }
+ }
+
+ let ev_data = EventData::PacketReceived(PacketReceived {
+ header,
+ frames: Some(frames),
+ is_coalesced: None,
+ retry_token: None,
+ stateless_reset_token: None,
+ supported_versions: None,
+ raw: Some(raw),
+ datagram_id: None,
+ trigger: None,
+ });
+
+ stream.add_event_data_now(ev_data)
+ });
+}
+
+#[allow(dead_code)]
+pub enum QlogMetric {
+ MinRtt(Duration),
+ SmoothedRtt(Duration),
+ LatestRtt(Duration),
+ RttVariance(u64),
+ MaxAckDelay(u64),
+ PtoCount(usize),
+ CongestionWindow(usize),
+ BytesInFlight(usize),
+ SsThresh(usize),
+ PacketsInFlight(u64),
+ InRecovery(bool),
+ PacingRate(u64),
+}
+
+pub fn metrics_updated(qlog: &mut NeqoQlog, updated_metrics: &[QlogMetric]) {
+ debug_assert!(!updated_metrics.is_empty());
+
+ qlog.add_event_data(|| {
+ let mut min_rtt: Option<f32> = None;
+ let mut smoothed_rtt: Option<f32> = None;
+ let mut latest_rtt: Option<f32> = None;
+ let mut rtt_variance: Option<f32> = None;
+ let mut pto_count: Option<u16> = None;
+ let mut congestion_window: Option<u64> = None;
+ let mut bytes_in_flight: Option<u64> = None;
+ let mut ssthresh: Option<u64> = None;
+ let mut packets_in_flight: Option<u64> = None;
+ let mut pacing_rate: Option<u64> = None;
+
+ for metric in updated_metrics {
+ match metric {
+ QlogMetric::MinRtt(v) => min_rtt = Some(v.as_secs_f32() * 1000.0),
+ QlogMetric::SmoothedRtt(v) => smoothed_rtt = Some(v.as_secs_f32() * 1000.0),
+ QlogMetric::LatestRtt(v) => latest_rtt = Some(v.as_secs_f32() * 1000.0),
+ QlogMetric::RttVariance(v) => rtt_variance = Some(*v as f32),
+ QlogMetric::PtoCount(v) => pto_count = Some(u16::try_from(*v).unwrap()),
+ QlogMetric::CongestionWindow(v) => {
+ congestion_window = Some(u64::try_from(*v).unwrap());
+ }
+ QlogMetric::BytesInFlight(v) => bytes_in_flight = Some(u64::try_from(*v).unwrap()),
+ QlogMetric::SsThresh(v) => ssthresh = Some(u64::try_from(*v).unwrap()),
+ QlogMetric::PacketsInFlight(v) => packets_in_flight = Some(*v),
+ QlogMetric::PacingRate(v) => pacing_rate = Some(*v),
+ _ => (),
+ }
+ }
+
+ let ev_data = EventData::MetricsUpdated(MetricsUpdated {
+ min_rtt,
+ smoothed_rtt,
+ latest_rtt,
+ rtt_variance,
+ pto_count,
+ congestion_window,
+ bytes_in_flight,
+ ssthresh,
+ packets_in_flight,
+ pacing_rate,
+ });
+
+ Some(ev_data)
+ });
+}
+
+// Helper functions
+
+fn frame_to_qlogframe(frame: &Frame) -> QuicFrame {
+ match frame {
+ Frame::Padding => QuicFrame::Padding,
+ Frame::Ping => QuicFrame::Ping,
+ Frame::Ack {
+ largest_acknowledged,
+ ack_delay,
+ first_ack_range,
+ ack_ranges,
+ } => {
+ let ranges =
+ Frame::decode_ack_frame(*largest_acknowledged, *first_ack_range, ack_ranges).ok();
+
+ let acked_ranges = ranges.map(|all| {
+ AckedRanges::Double(
+ all.into_iter()
+ .map(RangeInclusive::into_inner)
+ .collect::<Vec<_>>(),
+ )
+ });
+
+ QuicFrame::Ack {
+ ack_delay: Some(*ack_delay as f32 / 1000.0),
+ acked_ranges,
+ ect1: None,
+ ect0: None,
+ ce: None,
+ }
+ }
+ Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ final_size,
+ } => QuicFrame::ResetStream {
+ stream_id: stream_id.as_u64(),
+ error_code: *application_error_code,
+ final_size: *final_size,
+ },
+ Frame::StopSending {
+ stream_id,
+ application_error_code,
+ } => QuicFrame::StopSending {
+ stream_id: stream_id.as_u64(),
+ error_code: *application_error_code,
+ },
+ Frame::Crypto { offset, data } => QuicFrame::Crypto {
+ offset: *offset,
+ length: data.len() as u64,
+ },
+ Frame::NewToken { token } => QuicFrame::NewToken {
+ token: qlog::Token {
+ ty: Some(qlog::TokenType::Retry),
+ details: None,
+ raw: Some(RawInfo {
+ data: Some(hex(token)),
+ length: Some(token.len() as u64),
+ payload_length: None,
+ }),
+ },
+ },
+ Frame::Stream {
+ fin,
+ stream_id,
+ offset,
+ data,
+ ..
+ } => QuicFrame::Stream {
+ stream_id: stream_id.as_u64(),
+ offset: *offset,
+ length: data.len() as u64,
+ fin: Some(*fin),
+ raw: None,
+ },
+ Frame::MaxData { maximum_data } => QuicFrame::MaxData {
+ maximum: *maximum_data,
+ },
+ Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data,
+ } => QuicFrame::MaxStreamData {
+ stream_id: stream_id.as_u64(),
+ maximum: *maximum_stream_data,
+ },
+ Frame::MaxStreams {
+ stream_type,
+ maximum_streams,
+ } => QuicFrame::MaxStreams {
+ stream_type: match stream_type {
+ NeqoStreamType::BiDi => StreamType::Bidirectional,
+ NeqoStreamType::UniDi => StreamType::Unidirectional,
+ },
+ maximum: *maximum_streams,
+ },
+ Frame::DataBlocked { data_limit } => QuicFrame::DataBlocked { limit: *data_limit },
+ Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit,
+ } => QuicFrame::StreamDataBlocked {
+ stream_id: stream_id.as_u64(),
+ limit: *stream_data_limit,
+ },
+ Frame::StreamsBlocked {
+ stream_type,
+ stream_limit,
+ } => QuicFrame::StreamsBlocked {
+ stream_type: match stream_type {
+ NeqoStreamType::BiDi => StreamType::Bidirectional,
+ NeqoStreamType::UniDi => StreamType::Unidirectional,
+ },
+ limit: *stream_limit,
+ },
+ Frame::NewConnectionId {
+ sequence_number,
+ retire_prior,
+ connection_id,
+ stateless_reset_token,
+ } => QuicFrame::NewConnectionId {
+ sequence_number: *sequence_number as u32,
+ retire_prior_to: *retire_prior as u32,
+ connection_id_length: Some(connection_id.len() as u8),
+ connection_id: hex(connection_id),
+ stateless_reset_token: Some(hex(stateless_reset_token)),
+ },
+ Frame::RetireConnectionId { sequence_number } => QuicFrame::RetireConnectionId {
+ sequence_number: *sequence_number as u32,
+ },
+ Frame::PathChallenge { data } => QuicFrame::PathChallenge {
+ data: Some(hex(data)),
+ },
+ Frame::PathResponse { data } => QuicFrame::PathResponse {
+ data: Some(hex(data)),
+ },
+ Frame::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ } => QuicFrame::ConnectionClose {
+ error_space: match error_code {
+ CloseError::Transport(_) => Some(ErrorSpace::TransportError),
+ CloseError::Application(_) => Some(ErrorSpace::ApplicationError),
+ },
+ error_code: Some(error_code.code()),
+ error_code_value: Some(0),
+ reason: Some(String::from_utf8_lossy(reason_phrase).to_string()),
+ trigger_frame_type: Some(*frame_type),
+ },
+ Frame::HandshakeDone => QuicFrame::HandshakeDone,
+ Frame::AckFrequency { .. } => QuicFrame::Unknown {
+ frame_type_value: None,
+ raw_frame_type: frame.get_type(),
+ raw: None,
+ },
+ Frame::Datagram { data, .. } => QuicFrame::Datagram {
+ length: data.len() as u64,
+ raw: None,
+ },
+ }
+}
+
+fn to_qlog_pkt_type(ptype: PacketType) -> qlog::events::quic::PacketType {
+ match ptype {
+ PacketType::Initial => qlog::events::quic::PacketType::Initial,
+ PacketType::Handshake => qlog::events::quic::PacketType::Handshake,
+ PacketType::ZeroRtt => qlog::events::quic::PacketType::ZeroRtt,
+ PacketType::Short => qlog::events::quic::PacketType::OneRtt,
+ PacketType::Retry => qlog::events::quic::PacketType::Retry,
+ PacketType::VersionNegotiation => qlog::events::quic::PacketType::VersionNegotiation,
+ PacketType::OtherVersion => qlog::events::quic::PacketType::Unknown,
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/quic_datagrams.rs b/third_party/rust/neqo-transport/src/quic_datagrams.rs
new file mode 100644
index 0000000000..07f3594768
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/quic_datagrams.rs
@@ -0,0 +1,185 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// https://datatracker.ietf.org/doc/html/draft-ietf-quic-datagram
+
+use std::{cmp::min, collections::VecDeque, convert::TryFrom};
+
+use neqo_common::Encoder;
+
+use crate::{
+ events::OutgoingDatagramOutcome,
+ frame::{FRAME_TYPE_DATAGRAM, FRAME_TYPE_DATAGRAM_WITH_LEN},
+ packet::PacketBuilder,
+ recovery::RecoveryToken,
+ ConnectionEvents, Error, Res, Stats,
+};
+
+pub const MAX_QUIC_DATAGRAM: u64 = 65535;
+
+#[derive(Debug, Clone, Copy)]
+pub enum DatagramTracking {
+ None,
+ Id(u64),
+}
+
+impl From<Option<u64>> for DatagramTracking {
+ fn from(v: Option<u64>) -> Self {
+ match v {
+ Some(id) => Self::Id(id),
+ None => Self::None,
+ }
+ }
+}
+
+impl From<DatagramTracking> for Option<u64> {
+ fn from(v: DatagramTracking) -> Self {
+ match v {
+ DatagramTracking::Id(id) => Some(id),
+ DatagramTracking::None => None,
+ }
+ }
+}
+
+struct QuicDatagram {
+ data: Vec<u8>,
+ tracking: DatagramTracking,
+}
+
+impl QuicDatagram {
+ fn tracking(&self) -> &DatagramTracking {
+ &self.tracking
+ }
+}
+
+impl AsRef<[u8]> for QuicDatagram {
+ #[must_use]
+ fn as_ref(&self) -> &[u8] {
+ &self.data[..]
+ }
+}
+
+pub struct QuicDatagrams {
+ /// The max size of a datagram that would be acceptable.
+ local_datagram_size: u64,
+ /// The max size of a datagram that would be acceptable by the peer.
+ remote_datagram_size: u64,
+ max_queued_outgoing_datagrams: usize,
+ /// The max number of datagrams that will be queued in connection events.
+ /// If the number is exceeded, the oldest datagram will be dropped.
+ max_queued_incoming_datagrams: usize,
+ /// Datagram queued for sending.
+ datagrams: VecDeque<QuicDatagram>,
+ conn_events: ConnectionEvents,
+}
+
+impl QuicDatagrams {
+ pub fn new(
+ local_datagram_size: u64,
+ max_queued_outgoing_datagrams: usize,
+ max_queued_incoming_datagrams: usize,
+ conn_events: ConnectionEvents,
+ ) -> Self {
+ Self {
+ local_datagram_size,
+ remote_datagram_size: 0,
+ max_queued_outgoing_datagrams,
+ max_queued_incoming_datagrams,
+ datagrams: VecDeque::with_capacity(max_queued_outgoing_datagrams),
+ conn_events,
+ }
+ }
+
+ pub fn remote_datagram_size(&self) -> u64 {
+ self.remote_datagram_size
+ }
+
+ pub fn set_remote_datagram_size(&mut self, v: u64) {
+ self.remote_datagram_size = min(v, MAX_QUIC_DATAGRAM);
+ }
+
+ /// This function tries to write a datagram frame into a packet.
+ /// If the frame does not fit into the packet, the datagram will
+ /// be dropped and a DatagramLost event will be posted.
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut Stats,
+ ) {
+ while let Some(dgram) = self.datagrams.pop_front() {
+ let len = dgram.as_ref().len();
+ if builder.remaining() > len {
+ // We need 1 more than `len` for the Frame type.
+ let length_len = Encoder::varint_len(u64::try_from(len).unwrap());
+ // Include a length if there is space for another frame after this one.
+ if builder.remaining() >= 1 + length_len + len + PacketBuilder::MINIMUM_FRAME_SIZE {
+ builder.encode_varint(FRAME_TYPE_DATAGRAM_WITH_LEN);
+ builder.encode_vvec(dgram.as_ref());
+ } else {
+ builder.encode_varint(FRAME_TYPE_DATAGRAM);
+ builder.encode(dgram.as_ref());
+ builder.mark_full();
+ }
+ debug_assert!(builder.len() <= builder.limit());
+ stats.frame_tx.datagram += 1;
+ tokens.push(RecoveryToken::Datagram(*dgram.tracking()));
+ } else if tokens.is_empty() {
+ // If the packet is empty, except packet headers, and the
+ // datagram cannot fit, drop it.
+ // Also continue trying to write the next QuicDatagram.
+ self.conn_events
+ .datagram_outcome(dgram.tracking(), OutgoingDatagramOutcome::DroppedTooBig);
+ stats.datagram_tx.dropped_too_big += 1;
+ } else {
+ self.datagrams.push_front(dgram);
+ // Try later on an empty packet.
+ return;
+ }
+ }
+ }
+
+ /// Returns true if there was an unsent datagram that has been dismissed.
+ ///
+ /// # Error
+ ///
+ /// The function returns `TooMuchData` if the supply buffer is bigger than
+ /// the allowed remote datagram size. The funcion does not check if the
+ /// datagram can fit into a packet (i.e. MTU limit). This is checked during
+ /// creation of an actual packet and the datagram will be dropped if it does
+ /// not fit into the packet.
+ pub fn add_datagram(
+ &mut self,
+ buf: &[u8],
+ tracking: DatagramTracking,
+ stats: &mut Stats,
+ ) -> Res<()> {
+ if u64::try_from(buf.len()).unwrap() > self.remote_datagram_size {
+ return Err(Error::TooMuchData);
+ }
+ if self.datagrams.len() == self.max_queued_outgoing_datagrams {
+ self.conn_events.datagram_outcome(
+ self.datagrams.pop_front().unwrap().tracking(),
+ OutgoingDatagramOutcome::DroppedQueueFull,
+ );
+ stats.datagram_tx.dropped_queue_full += 1;
+ }
+ self.datagrams.push_back(QuicDatagram {
+ data: buf.to_vec(),
+ tracking,
+ });
+ Ok(())
+ }
+
+ pub fn handle_datagram(&self, data: &[u8], stats: &mut Stats) -> Res<()> {
+ if self.local_datagram_size < u64::try_from(data.len()).unwrap() {
+ return Err(Error::ProtocolViolation);
+ }
+ self.conn_events
+ .add_datagram(self.max_queued_incoming_datagrams, data, stats);
+ Ok(())
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs
new file mode 100644
index 0000000000..d90989b486
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/recovery.rs
@@ -0,0 +1,1610 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of sent packets and detecting their loss.
+
+#![deny(clippy::pedantic)]
+
+use std::{
+ cmp::{max, min},
+ collections::BTreeMap,
+ convert::TryFrom,
+ mem,
+ ops::RangeInclusive,
+ time::{Duration, Instant},
+};
+
+use neqo_common::{qdebug, qinfo, qlog::NeqoQlog, qtrace, qwarn};
+use smallvec::{smallvec, SmallVec};
+
+use crate::{
+ ackrate::AckRate,
+ cid::ConnectionIdEntry,
+ crypto::CryptoRecoveryToken,
+ packet::PacketNumber,
+ path::{Path, PathRef},
+ qlog::{self, QlogMetric},
+ quic_datagrams::DatagramTracking,
+ rtt::RttEstimate,
+ send_stream::SendStreamRecoveryToken,
+ stats::{Stats, StatsCell},
+ stream_id::{StreamId, StreamType},
+ tracking::{AckToken, PacketNumberSpace, PacketNumberSpaceSet, SentPacket},
+};
+
+pub(crate) const PACKET_THRESHOLD: u64 = 3;
+/// `ACK_ONLY_SIZE_LIMIT` is the minimum size of the congestion window.
+/// If the congestion window is this small, we will only send ACK frames.
+pub(crate) const ACK_ONLY_SIZE_LIMIT: usize = 256;
+/// The maximum number of packets we send on a PTO.
+/// And the maximum number to declare lost when the PTO timer is hit.
+pub const MAX_PTO_PACKET_COUNT: usize = 2;
+/// The preferred limit on the number of packets that are tracked.
+/// If we exceed this number, we start sending `PING` frames sooner to
+/// force the peer to acknowledge some of them.
+pub(crate) const MAX_OUTSTANDING_UNACK: usize = 200;
+/// Disable PING until this many packets are outstanding.
+pub(crate) const MIN_OUTSTANDING_UNACK: usize = 16;
+/// The scale we use for the fast PTO feature.
+pub const FAST_PTO_SCALE: u8 = 100;
+
+#[derive(Debug, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub enum StreamRecoveryToken {
+ Stream(SendStreamRecoveryToken),
+ ResetStream {
+ stream_id: StreamId,
+ },
+ StopSending {
+ stream_id: StreamId,
+ },
+
+ MaxData(u64),
+ DataBlocked(u64),
+
+ MaxStreamData {
+ stream_id: StreamId,
+ max_data: u64,
+ },
+ StreamDataBlocked {
+ stream_id: StreamId,
+ limit: u64,
+ },
+
+ MaxStreams {
+ stream_type: StreamType,
+ max_streams: u64,
+ },
+ StreamsBlocked {
+ stream_type: StreamType,
+ limit: u64,
+ },
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub enum RecoveryToken {
+ Stream(StreamRecoveryToken),
+ Ack(AckToken),
+ Crypto(CryptoRecoveryToken),
+ HandshakeDone,
+ KeepAlive, // Special PING.
+ NewToken(usize),
+ NewConnectionId(ConnectionIdEntry<[u8; 16]>),
+ RetireConnectionId(u64),
+ AckFrequency(AckRate),
+ Datagram(DatagramTracking),
+}
+
+/// `SendProfile` tells a sender how to send packets.
+#[derive(Debug)]
+pub struct SendProfile {
+ /// The limit on the size of the packet.
+ limit: usize,
+ /// Whether this is a PTO, and what space the PTO is for.
+ pto: Option<PacketNumberSpace>,
+ /// What spaces should be probed.
+ probe: PacketNumberSpaceSet,
+ /// Whether pacing is active.
+ paced: bool,
+}
+
+impl SendProfile {
+ pub fn new_limited(limit: usize) -> Self {
+ // When the limit is too low, we only send ACK frames.
+ // Set the limit to `ACK_ONLY_SIZE_LIMIT - 1` to ensure that
+ // ACK-only packets are still limited in size.
+ Self {
+ limit: max(ACK_ONLY_SIZE_LIMIT - 1, limit),
+ pto: None,
+ probe: PacketNumberSpaceSet::default(),
+ paced: false,
+ }
+ }
+
+ pub fn new_paced() -> Self {
+ // When pacing, we still allow ACK frames to be sent.
+ Self {
+ limit: ACK_ONLY_SIZE_LIMIT - 1,
+ pto: None,
+ probe: PacketNumberSpaceSet::default(),
+ paced: true,
+ }
+ }
+
+ pub fn new_pto(pn_space: PacketNumberSpace, mtu: usize, probe: PacketNumberSpaceSet) -> Self {
+ debug_assert!(mtu > ACK_ONLY_SIZE_LIMIT);
+ debug_assert!(probe[pn_space]);
+ Self {
+ limit: mtu,
+ pto: Some(pn_space),
+ probe,
+ paced: false,
+ }
+ }
+
+ /// Whether probing this space is helpful. This isn't necessarily the space
+ /// that caused the timer to pop, but it is helpful to send a PING in a space
+ /// that has the PTO timer armed.
+ pub fn should_probe(&self, space: PacketNumberSpace) -> bool {
+ self.probe[space]
+ }
+
+ /// Determine whether an ACK-only packet should be sent for the given packet
+ /// number space.
+ /// Send only ACKs either: when the space available is too small, or when a PTO
+ /// exists for a later packet number space (which should get the most space).
+ pub fn ack_only(&self, space: PacketNumberSpace) -> bool {
+ self.limit < ACK_ONLY_SIZE_LIMIT || self.pto.map_or(false, |sp| space < sp)
+ }
+
+ pub fn paced(&self) -> bool {
+ self.paced
+ }
+
+ pub fn limit(&self) -> usize {
+ self.limit
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct LossRecoverySpace {
+ space: PacketNumberSpace,
+ largest_acked: Option<PacketNumber>,
+ largest_acked_sent_time: Option<Instant>,
+ /// The time used to calculate the PTO timer for this space.
+ /// This is the time that the last ACK-eliciting packet in this space
+ /// was sent. This might be the time that a probe was sent.
+ last_ack_eliciting: Option<Instant>,
+ /// The number of outstanding packets in this space that are in flight.
+ /// This might be less than the number of ACK-eliciting packets,
+ /// because PTO packets don't count.
+ in_flight_outstanding: usize,
+ sent_packets: BTreeMap<u64, SentPacket>,
+ /// The time that the first out-of-order packet was sent.
+ /// This is `None` if there were no out-of-order packets detected.
+ /// When set to `Some(T)`, time-based loss detection should be enabled.
+ first_ooo_time: Option<Instant>,
+}
+
+impl LossRecoverySpace {
+ pub fn new(space: PacketNumberSpace) -> Self {
+ Self {
+ space,
+ largest_acked: None,
+ largest_acked_sent_time: None,
+ last_ack_eliciting: None,
+ in_flight_outstanding: 0,
+ sent_packets: BTreeMap::default(),
+ first_ooo_time: None,
+ }
+ }
+
+ #[must_use]
+ pub fn space(&self) -> PacketNumberSpace {
+ self.space
+ }
+
+ /// Find the time we sent the first packet that is lower than the
+ /// largest acknowledged and that isn't yet declared lost.
+ /// Use the value we prepared earlier in `detect_lost_packets`.
+ #[must_use]
+ pub fn loss_recovery_timer_start(&self) -> Option<Instant> {
+ self.first_ooo_time
+ }
+
+ pub fn in_flight_outstanding(&self) -> bool {
+ self.in_flight_outstanding > 0
+ }
+
+ pub fn pto_packets(&mut self, count: usize) -> impl Iterator<Item = &SentPacket> {
+ self.sent_packets
+ .iter_mut()
+ .filter_map(|(pn, sent)| {
+ if sent.pto() {
+ qtrace!("PTO: marking packet {} lost ", pn);
+ Some(&*sent)
+ } else {
+ None
+ }
+ })
+ .take(count)
+ }
+
+ pub fn pto_base_time(&self) -> Option<Instant> {
+ if self.in_flight_outstanding() {
+ debug_assert!(self.last_ack_eliciting.is_some());
+ self.last_ack_eliciting
+ } else if self.space == PacketNumberSpace::ApplicationData {
+ None
+ } else {
+ // Nasty special case to prevent handshake deadlocks.
+ // A client needs to keep the PTO timer armed to prevent a stall
+ // of the handshake. Technically, this has to stop once we receive
+ // an ACK of Handshake or 1-RTT, or when we receive HANDSHAKE_DONE,
+ // but a few extra probes won't hurt.
+ // It only means that we fail anti-amplification tests.
+ // A server shouldn't arm its PTO timer this way. The server sends
+ // ack-eliciting, in-flight packets immediately so this only
+ // happens when the server has nothing outstanding. If we had
+ // client authentication, this might cause some extra probes,
+ // but they would be harmless anyway.
+ self.last_ack_eliciting
+ }
+ }
+
+ pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
+ if sent_packet.ack_eliciting() {
+ self.last_ack_eliciting = Some(sent_packet.time_sent);
+ self.in_flight_outstanding += 1;
+ } else if self.space != PacketNumberSpace::ApplicationData
+ && self.last_ack_eliciting.is_none()
+ {
+ // For Initial and Handshake spaces, make sure that we have a PTO baseline
+ // always. See `LossRecoverySpace::pto_base_time()` for details.
+ self.last_ack_eliciting = Some(sent_packet.time_sent);
+ }
+ self.sent_packets.insert(sent_packet.pn, sent_packet);
+ }
+
+ /// If we are only sending ACK frames, send a PING frame after 2 PTOs so that
+ /// the peer sends an ACK frame. If we have received lots of packets and no ACK,
+ /// send a PING frame after 1 PTO. Note that this can't be within a PTO, or
+ /// we would risk setting up a feedback loop; having this many packets
+ /// outstanding can be normal and we don't want to PING too often.
+ pub fn should_probe(&self, pto: Duration, now: Instant) -> bool {
+ let n_pto = if self.sent_packets.len() >= MAX_OUTSTANDING_UNACK {
+ 1
+ } else if self.sent_packets.len() >= MIN_OUTSTANDING_UNACK {
+ 2
+ } else {
+ return false;
+ };
+ self.last_ack_eliciting
+ .map_or(false, |t| now > t + (pto * n_pto))
+ }
+
+ fn remove_packet(&mut self, p: &SentPacket) {
+ if p.ack_eliciting() {
+ debug_assert!(self.in_flight_outstanding > 0);
+ self.in_flight_outstanding -= 1;
+ if self.in_flight_outstanding == 0 {
+ qtrace!("remove_packet outstanding == 0 for space {}", self.space);
+ }
+ }
+ }
+
+ /// Remove all acknowledged packets.
+ /// Returns all the acknowledged packets, with the largest packet number first.
+ /// ...and a boolean indicating if any of those packets were ack-eliciting.
+ /// This operates more efficiently because it assumes that the input is sorted
+ /// in the order that an ACK frame is (from the top).
+ fn remove_acked<R>(&mut self, acked_ranges: R, stats: &mut Stats) -> (Vec<SentPacket>, bool)
+ where
+ R: IntoIterator<Item = RangeInclusive<u64>>,
+ R::IntoIter: ExactSizeIterator,
+ {
+ let acked_ranges = acked_ranges.into_iter();
+ let mut keep = Vec::with_capacity(acked_ranges.len());
+
+ let mut acked = Vec::new();
+ let mut eliciting = false;
+ for range in acked_ranges {
+ let first_keep = *range.end() + 1;
+ if let Some((&first, _)) = self.sent_packets.range(range).next() {
+ let mut tail = self.sent_packets.split_off(&first);
+ if let Some((&next, _)) = tail.range(first_keep..).next() {
+ keep.push(tail.split_off(&next));
+ }
+ for (_, p) in tail.into_iter().rev() {
+ self.remove_packet(&p);
+ eliciting |= p.ack_eliciting();
+ if p.lost() {
+ stats.late_ack += 1;
+ }
+ if p.pto_fired() {
+ stats.pto_ack += 1;
+ }
+ acked.push(p);
+ }
+ }
+ }
+
+ for mut k in keep.into_iter().rev() {
+ self.sent_packets.append(&mut k);
+ }
+
+ (acked, eliciting)
+ }
+
+ /// Remove all tracked packets from the space.
+ /// This is called by a client when 0-RTT packets are dropped, when a Retry is received
+ /// and when keys are dropped.
+ fn remove_ignored(&mut self) -> impl Iterator<Item = SentPacket> {
+ self.in_flight_outstanding = 0;
+ mem::take(&mut self.sent_packets).into_values()
+ }
+
+ /// Remove the primary path marking on any packets this is tracking.
+ fn migrate(&mut self) {
+ for pkt in self.sent_packets.values_mut() {
+ pkt.clear_primary_path();
+ }
+ }
+
+ /// Remove old packets that we've been tracking in case they get acknowledged.
+ /// We try to keep these around until a probe is sent for them, so it is
+ /// important that `cd` is set to at least the current PTO time; otherwise we
+ /// might remove all in-flight packets and stop sending probes.
+ #[allow(clippy::option_if_let_else)] // Hard enough to read as-is.
+ fn remove_old_lost(&mut self, now: Instant, cd: Duration) {
+ let mut it = self.sent_packets.iter();
+ // If the first item is not expired, do nothing.
+ if it.next().map_or(false, |(_, p)| p.expired(now, cd)) {
+ // Find the index of the first unexpired packet.
+ let to_remove = if let Some(first_keep) =
+ it.find_map(|(i, p)| if p.expired(now, cd) { None } else { Some(*i) })
+ {
+ // Some packets haven't expired, so keep those.
+ let keep = self.sent_packets.split_off(&first_keep);
+ mem::replace(&mut self.sent_packets, keep)
+ } else {
+ // All packets are expired.
+ mem::take(&mut self.sent_packets)
+ };
+ for (_, p) in to_remove {
+ self.remove_packet(&p);
+ }
+ }
+ }
+
+ /// Detect lost packets.
+ /// `loss_delay` is the time we will wait before declaring something lost.
+ /// `cleanup_delay` is the time we will wait before cleaning up a lost packet.
+ pub fn detect_lost_packets(
+ &mut self,
+ now: Instant,
+ loss_delay: Duration,
+ cleanup_delay: Duration,
+ lost_packets: &mut Vec<SentPacket>,
+ ) {
+ // Housekeeping.
+ self.remove_old_lost(now, cleanup_delay);
+
+ qtrace!(
+ "detect lost {}: now={:?} delay={:?}",
+ self.space,
+ now,
+ loss_delay,
+ );
+ self.first_ooo_time = None;
+
+ let largest_acked = self.largest_acked;
+
+ // Lost for retrans/CC purposes
+ let mut lost_pns = SmallVec::<[_; 8]>::new();
+
+ for (pn, packet) in self
+ .sent_packets
+ .iter_mut()
+ // BTreeMap iterates in order of ascending PN
+ .take_while(|(&k, _)| k < largest_acked.unwrap_or(PacketNumber::MAX))
+ {
+ // Packets sent before now - loss_delay are deemed lost.
+ if packet.time_sent + loss_delay <= now {
+ qtrace!(
+ "lost={}, time sent {:?} is before lost_delay {:?}",
+ pn,
+ packet.time_sent,
+ loss_delay
+ );
+ } else if largest_acked >= Some(*pn + PACKET_THRESHOLD) {
+ qtrace!(
+ "lost={}, is >= {} from largest acked {:?}",
+ pn,
+ PACKET_THRESHOLD,
+ largest_acked
+ );
+ } else {
+ if largest_acked.is_some() {
+ self.first_ooo_time = Some(packet.time_sent);
+ }
+ // No more packets can be declared lost after this one.
+ break;
+ };
+
+ if packet.declare_lost(now) {
+ lost_pns.push(*pn);
+ }
+ }
+
+ lost_packets.extend(lost_pns.iter().map(|pn| self.sent_packets[pn].clone()));
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct LossRecoverySpaces {
+ /// When we have all of the loss recovery spaces, this will use a separate
+ /// allocation, but this is reduced once the handshake is done.
+ spaces: SmallVec<[LossRecoverySpace; 1]>,
+}
+
+impl LossRecoverySpaces {
+ fn idx(space: PacketNumberSpace) -> usize {
+ match space {
+ PacketNumberSpace::ApplicationData => 0,
+ PacketNumberSpace::Handshake => 1,
+ PacketNumberSpace::Initial => 2,
+ }
+ }
+
+ /// Drop a packet number space and return all the packets that were
+ /// outstanding, so that those can be marked as lost.
+ ///
+ /// # Panics
+ ///
+ /// If the space has already been removed.
+ pub fn drop_space(&mut self, space: PacketNumberSpace) -> impl IntoIterator<Item = SentPacket> {
+ let sp = match space {
+ PacketNumberSpace::Initial => self.spaces.pop(),
+ PacketNumberSpace::Handshake => {
+ let sp = self.spaces.pop();
+ self.spaces.shrink_to_fit();
+ sp
+ }
+ PacketNumberSpace::ApplicationData => panic!("discarding application space"),
+ };
+ let mut sp = sp.unwrap();
+ assert_eq!(sp.space(), space, "dropping spaces out of order");
+ sp.remove_ignored()
+ }
+
+ pub fn get(&self, space: PacketNumberSpace) -> Option<&LossRecoverySpace> {
+ self.spaces.get(Self::idx(space))
+ }
+
+ pub fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut LossRecoverySpace> {
+ self.spaces.get_mut(Self::idx(space))
+ }
+
+ fn iter(&self) -> impl Iterator<Item = &LossRecoverySpace> {
+ self.spaces.iter()
+ }
+
+ fn iter_mut(&mut self) -> impl Iterator<Item = &mut LossRecoverySpace> {
+ self.spaces.iter_mut()
+ }
+}
+
+impl Default for LossRecoverySpaces {
+ fn default() -> Self {
+ Self {
+ spaces: smallvec![
+ LossRecoverySpace::new(PacketNumberSpace::ApplicationData),
+ LossRecoverySpace::new(PacketNumberSpace::Handshake),
+ LossRecoverySpace::new(PacketNumberSpace::Initial),
+ ],
+ }
+ }
+}
+
+#[derive(Debug)]
+struct PtoState {
+ /// The packet number space that caused the PTO to fire.
+ space: PacketNumberSpace,
+ /// The number of probes that we have sent.
+ count: usize,
+ packets: usize,
+ /// The complete set of packet number spaces that can have probes sent.
+ probe: PacketNumberSpaceSet,
+}
+
+impl PtoState {
+ /// The number of packets we send on a PTO.
+ /// And the number to declare lost when the PTO timer is hit.
+ fn pto_packet_count(space: PacketNumberSpace, rx_count: usize) -> usize {
+ if space == PacketNumberSpace::Initial && rx_count == 0 {
+ // For the Initial space, we only send one packet on PTO if we have not received any
+ // packets from the peer yet. This avoids sending useless PING-only packets
+ // when the Client Initial is deemed lost.
+ 1
+ } else {
+ MAX_PTO_PACKET_COUNT
+ }
+ }
+
+ pub fn new(space: PacketNumberSpace, probe: PacketNumberSpaceSet, rx_count: usize) -> Self {
+ debug_assert!(probe[space]);
+ Self {
+ space,
+ count: 1,
+ packets: Self::pto_packet_count(space, rx_count),
+ probe,
+ }
+ }
+
+ pub fn pto(&mut self, space: PacketNumberSpace, probe: PacketNumberSpaceSet, rx_count: usize) {
+ debug_assert!(probe[space]);
+ self.space = space;
+ self.count += 1;
+ self.packets = Self::pto_packet_count(space, rx_count);
+ self.probe = probe;
+ }
+
+ pub fn count(&self) -> usize {
+ self.count
+ }
+
+ pub fn count_pto(&self, stats: &mut Stats) {
+ stats.add_pto_count(self.count);
+ }
+
+ /// Generate a sending profile, indicating what space it should be from.
+ /// This takes a packet from the supply if one remains, or returns `None`.
+ pub fn send_profile(&mut self, mtu: usize) -> Option<SendProfile> {
+ if self.packets > 0 {
+ // This is a PTO, so ignore the limit.
+ self.packets -= 1;
+ Some(SendProfile::new_pto(self.space, mtu, self.probe))
+ } else {
+ None
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct LossRecovery {
+ /// When the handshake was confirmed, if it has been.
+ confirmed_time: Option<Instant>,
+ pto_state: Option<PtoState>,
+ spaces: LossRecoverySpaces,
+ qlog: NeqoQlog,
+ stats: StatsCell,
+ /// The factor by which the PTO period is reduced.
+ /// This enables faster probing at a cost in additional lost packets.
+ fast_pto: u8,
+}
+
+impl LossRecovery {
+ pub fn new(stats: StatsCell, fast_pto: u8) -> Self {
+ Self {
+ confirmed_time: None,
+ pto_state: None,
+ spaces: LossRecoverySpaces::default(),
+ qlog: NeqoQlog::default(),
+ stats,
+ fast_pto,
+ }
+ }
+
+ pub fn largest_acknowledged_pn(&self, pn_space: PacketNumberSpace) -> Option<PacketNumber> {
+ self.spaces.get(pn_space).and_then(|sp| sp.largest_acked)
+ }
+
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.qlog = qlog;
+ }
+
+ pub fn drop_0rtt(&mut self, primary_path: &PathRef, now: Instant) -> Vec<SentPacket> {
+ // The largest acknowledged or loss_time should still be unset.
+ // The client should not have received any ACK frames when it drops 0-RTT.
+ assert!(self
+ .spaces
+ .get(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .largest_acked
+ .is_none());
+ let mut dropped = self
+ .spaces
+ .get_mut(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .remove_ignored()
+ .collect::<Vec<_>>();
+ let mut path = primary_path.borrow_mut();
+ for p in &mut dropped {
+ path.discard_packet(p, now, &mut self.stats.borrow_mut());
+ }
+ dropped
+ }
+
+ pub fn on_packet_sent(&mut self, path: &PathRef, mut sent_packet: SentPacket) {
+ let pn_space = PacketNumberSpace::from(sent_packet.pt);
+ qdebug!([self], "packet {}-{} sent", pn_space, sent_packet.pn);
+ if let Some(space) = self.spaces.get_mut(pn_space) {
+ path.borrow_mut().packet_sent(&mut sent_packet);
+ space.on_packet_sent(sent_packet);
+ } else {
+ qwarn!(
+ [self],
+ "ignoring {}-{} from dropped space",
+ pn_space,
+ sent_packet.pn
+ );
+ }
+ }
+
+ pub fn should_probe(&self, pto: Duration, now: Instant) -> bool {
+ self.spaces
+ .get(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .should_probe(pto, now)
+ }
+
+ /// Record an RTT sample.
+ fn rtt_sample(
+ &mut self,
+ rtt: &mut RttEstimate,
+ send_time: Instant,
+ now: Instant,
+ ack_delay: Duration,
+ ) {
+ let confirmed = self.confirmed_time.map_or(false, |t| t < send_time);
+ if let Some(sample) = now.checked_duration_since(send_time) {
+ rtt.update(&mut self.qlog, sample, ack_delay, confirmed, now);
+ }
+ }
+
+ /// Returns (acked packets, lost packets)
+ pub fn on_ack_received<R>(
+ &mut self,
+ primary_path: &PathRef,
+ pn_space: PacketNumberSpace,
+ largest_acked: u64,
+ acked_ranges: R,
+ ack_delay: Duration,
+ now: Instant,
+ ) -> (Vec<SentPacket>, Vec<SentPacket>)
+ where
+ R: IntoIterator<Item = RangeInclusive<u64>>,
+ R::IntoIter: ExactSizeIterator,
+ {
+ qdebug!(
+ [self],
+ "ACK for {} - largest_acked={}.",
+ pn_space,
+ largest_acked
+ );
+
+ let Some(space) = self.spaces.get_mut(pn_space) else {
+ qinfo!("ACK on discarded space");
+ return (Vec::new(), Vec::new());
+ };
+
+ let (acked_packets, any_ack_eliciting) =
+ space.remove_acked(acked_ranges, &mut self.stats.borrow_mut());
+ if acked_packets.is_empty() {
+ // No new information.
+ return (Vec::new(), Vec::new());
+ }
+
+ // Track largest PN acked per space
+ let prev_largest_acked = space.largest_acked_sent_time;
+ if Some(largest_acked) > space.largest_acked {
+ space.largest_acked = Some(largest_acked);
+
+ // If the largest acknowledged is newly acked and any newly acked
+ // packet was ack-eliciting, update the RTT. (-recovery 5.1)
+ let largest_acked_pkt = acked_packets.first().expect("must be there");
+ space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent);
+ if any_ack_eliciting && largest_acked_pkt.on_primary_path() {
+ self.rtt_sample(
+ primary_path.borrow_mut().rtt_mut(),
+ largest_acked_pkt.time_sent,
+ now,
+ ack_delay,
+ );
+ }
+ }
+
+ // Perform loss detection.
+ // PTO is used to remove lost packets from in-flight accounting.
+ // We need to ensure that we have sent any PTO probes before they are removed
+ // as we rely on the count of in-flight packets to determine whether to send
+ // another probe. Removing them too soon would result in not sending on PTO.
+ let loss_delay = primary_path.borrow().rtt().loss_delay();
+ let cleanup_delay = self.pto_period(primary_path.borrow().rtt(), pn_space);
+ let mut lost = Vec::new();
+ self.spaces.get_mut(pn_space).unwrap().detect_lost_packets(
+ now,
+ loss_delay,
+ cleanup_delay,
+ &mut lost,
+ );
+ self.stats.borrow_mut().lost += lost.len();
+
+ // Tell the congestion controller about any lost packets.
+ // The PTO for congestion control is the raw number, without exponential
+ // backoff, so that we can determine persistent congestion.
+ primary_path
+ .borrow_mut()
+ .on_packets_lost(prev_largest_acked, pn_space, &lost);
+
+ // This must happen after on_packets_lost. If in recovery, this could
+ // take us out, and then lost packets will start a new recovery period
+ // when it shouldn't.
+ primary_path
+ .borrow_mut()
+ .on_packets_acked(&acked_packets, now);
+
+ self.pto_state = None;
+
+ (acked_packets, lost)
+ }
+
+ /// When receiving a retry, get all the sent packets so that they can be flushed.
+ /// We also need to pretend that they never happened for the purposes of congestion control.
+ pub fn retry(&mut self, primary_path: &PathRef, now: Instant) -> Vec<SentPacket> {
+ self.pto_state = None;
+ let mut dropped = self
+ .spaces
+ .iter_mut()
+ .flat_map(LossRecoverySpace::remove_ignored)
+ .collect::<Vec<_>>();
+ let mut path = primary_path.borrow_mut();
+ for p in &mut dropped {
+ path.discard_packet(p, now, &mut self.stats.borrow_mut());
+ }
+ dropped
+ }
+
+ fn confirmed(&mut self, rtt: &RttEstimate, now: Instant) {
+ debug_assert!(self.confirmed_time.is_none());
+ self.confirmed_time = Some(now);
+ // Up until now, the ApplicationData space has been ignored for PTO.
+ // So maybe fire a PTO.
+ if let Some(pto) = self.pto_time(rtt, PacketNumberSpace::ApplicationData) {
+ if pto < now {
+ let probes = PacketNumberSpaceSet::from(&[PacketNumberSpace::ApplicationData]);
+ self.fire_pto(PacketNumberSpace::ApplicationData, probes);
+ }
+ }
+ }
+
+ /// This function is called when the connection migrates.
+ /// It marks all packets that are outstanding as having being sent on a non-primary path.
+ /// This way failure to deliver on the old path doesn't count against the congestion
+ /// control state on the new path and the RTT measurements don't apply either.
+ pub fn migrate(&mut self) {
+ for space in self.spaces.iter_mut() {
+ space.migrate();
+ }
+ }
+
+ /// Discard state for a given packet number space.
+ pub fn discard(&mut self, primary_path: &PathRef, space: PacketNumberSpace, now: Instant) {
+ qdebug!([self], "Reset loss recovery state for {}", space);
+ let mut path = primary_path.borrow_mut();
+ for p in self.spaces.drop_space(space) {
+ path.discard_packet(&p, now, &mut self.stats.borrow_mut());
+ }
+
+ // We just made progress, so discard PTO count.
+ // The spec says that clients should not do this until confirming that
+ // the server has completed address validation, but ignore that.
+ self.pto_state = None;
+
+ if space == PacketNumberSpace::Handshake {
+ self.confirmed(path.rtt(), now);
+ }
+ }
+
+ /// Calculate when the next timeout is likely to be. This is the earlier of the loss timer
+ /// and the PTO timer; either or both might be disabled, so this can return `None`.
+ pub fn next_timeout(&mut self, rtt: &RttEstimate) -> Option<Instant> {
+ let loss_time = self.earliest_loss_time(rtt);
+ let pto_time = self.earliest_pto(rtt);
+ qtrace!(
+ [self],
+ "next_timeout loss={:?} pto={:?}",
+ loss_time,
+ pto_time
+ );
+ match (loss_time, pto_time) {
+ (Some(loss_time), Some(pto_time)) => Some(min(loss_time, pto_time)),
+ (Some(loss_time), None) => Some(loss_time),
+ (None, Some(pto_time)) => Some(pto_time),
+ (None, None) => None,
+ }
+ }
+
+ /// Find when the earliest sent packet should be considered lost.
+ fn earliest_loss_time(&self, rtt: &RttEstimate) -> Option<Instant> {
+ self.spaces
+ .iter()
+ .filter_map(LossRecoverySpace::loss_recovery_timer_start)
+ .min()
+ .map(|val| val + rtt.loss_delay())
+ }
+
+ /// Simple wrapper for the PTO calculation that avoids borrow check rules.
+ fn pto_period_inner(
+ rtt: &RttEstimate,
+ pto_state: Option<&PtoState>,
+ pn_space: PacketNumberSpace,
+ fast_pto: u8,
+ ) -> Duration {
+ // This is a complicated (but safe) way of calculating:
+ // base_pto * F * 2^pto_count
+ // where F = fast_pto / FAST_PTO_SCALE (== 1 by default)
+ let pto_count = pto_state.map_or(0, |p| u32::try_from(p.count).unwrap_or(0));
+ rtt.pto(pn_space)
+ .checked_mul(u32::from(fast_pto) << min(pto_count, u32::BITS - u8::BITS))
+ .map_or(Duration::from_secs(3600), |p| p / u32::from(FAST_PTO_SCALE))
+ }
+
+ /// Get the current PTO period for the given packet number space.
+ /// Unlike calling `RttEstimate::pto` directly, this includes exponential backoff.
+ fn pto_period(&self, rtt: &RttEstimate, pn_space: PacketNumberSpace) -> Duration {
+ Self::pto_period_inner(rtt, self.pto_state.as_ref(), pn_space, self.fast_pto)
+ }
+
+ // Calculate PTO time for the given space.
+ fn pto_time(&self, rtt: &RttEstimate, pn_space: PacketNumberSpace) -> Option<Instant> {
+ if self.confirmed_time.is_none() && pn_space == PacketNumberSpace::ApplicationData {
+ None
+ } else {
+ self.spaces.get(pn_space).and_then(|space| {
+ space
+ .pto_base_time()
+ .map(|t| t + self.pto_period(rtt, pn_space))
+ })
+ }
+ }
+
+ /// Find the earliest PTO time for all active packet number spaces.
+ /// Ignore Application if either Initial or Handshake have an active PTO.
+ fn earliest_pto(&self, rtt: &RttEstimate) -> Option<Instant> {
+ if self.confirmed_time.is_some() {
+ self.pto_time(rtt, PacketNumberSpace::ApplicationData)
+ } else {
+ self.pto_time(rtt, PacketNumberSpace::Initial)
+ .iter()
+ .chain(self.pto_time(rtt, PacketNumberSpace::Handshake).iter())
+ .min()
+ .copied()
+ }
+ }
+
+ fn fire_pto(&mut self, pn_space: PacketNumberSpace, allow_probes: PacketNumberSpaceSet) {
+ let rx_count = self.stats.borrow().packets_rx;
+ if let Some(st) = &mut self.pto_state {
+ st.pto(pn_space, allow_probes, rx_count);
+ } else {
+ self.pto_state = Some(PtoState::new(pn_space, allow_probes, rx_count));
+ }
+
+ self.pto_state
+ .as_mut()
+ .unwrap()
+ .count_pto(&mut self.stats.borrow_mut());
+
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::PtoCount(
+ self.pto_state.as_ref().unwrap().count(),
+ )],
+ );
+ }
+
+ /// This checks whether the PTO timer has fired and fires it if needed.
+ /// When it has, mark a few packets as "lost" for the purposes of having frames
+ /// regenerated in subsequent packets. The packets aren't truly lost, so
+ /// we have to clone the `SentPacket` instance.
+ fn maybe_fire_pto(&mut self, rtt: &RttEstimate, now: Instant, lost: &mut Vec<SentPacket>) {
+ let mut pto_space = None;
+ // The spaces in which we will allow probing.
+ let mut allow_probes = PacketNumberSpaceSet::default();
+ for pn_space in PacketNumberSpace::iter() {
+ if let Some(t) = self.pto_time(rtt, *pn_space) {
+ allow_probes[*pn_space] = true;
+ if t <= now {
+ qdebug!([self], "PTO timer fired for {}", pn_space);
+ let space = self.spaces.get_mut(*pn_space).unwrap();
+ lost.extend(
+ space
+ .pto_packets(PtoState::pto_packet_count(
+ *pn_space,
+ self.stats.borrow().packets_rx,
+ ))
+ .cloned(),
+ );
+
+ pto_space = pto_space.or(Some(*pn_space));
+ }
+ }
+ }
+
+ // This has to happen outside the loop. Increasing the PTO count here causes the
+ // pto_time to increase which might cause PTO for later packet number spaces to not fire.
+ if let Some(pn_space) = pto_space {
+ qtrace!([self], "PTO {}, probing {:?}", pn_space, allow_probes);
+ self.fire_pto(pn_space, allow_probes);
+ }
+ }
+
+ pub fn timeout(&mut self, primary_path: &PathRef, now: Instant) -> Vec<SentPacket> {
+ qtrace!([self], "timeout {:?}", now);
+
+ let loss_delay = primary_path.borrow().rtt().loss_delay();
+
+ let mut lost_packets = Vec::new();
+ for space in self.spaces.iter_mut() {
+ let first = lost_packets.len(); // The first packet lost in this space.
+ let pto = Self::pto_period_inner(
+ primary_path.borrow().rtt(),
+ self.pto_state.as_ref(),
+ space.space(),
+ self.fast_pto,
+ );
+ space.detect_lost_packets(now, loss_delay, pto, &mut lost_packets);
+
+ primary_path.borrow_mut().on_packets_lost(
+ space.largest_acked_sent_time,
+ space.space(),
+ &lost_packets[first..],
+ );
+ }
+ self.stats.borrow_mut().lost += lost_packets.len();
+
+ self.maybe_fire_pto(primary_path.borrow().rtt(), now, &mut lost_packets);
+ lost_packets
+ }
+
+ /// Check how packets should be sent, based on whether there is a PTO,
+ /// what the current congestion window is, and what the pacer says.
+ #[allow(clippy::option_if_let_else)]
+ pub fn send_profile(&mut self, path: &Path, now: Instant) -> SendProfile {
+ qdebug!([self], "get send profile {:?}", now);
+ let sender = path.sender();
+ let mtu = path.mtu();
+ if let Some(profile) = self
+ .pto_state
+ .as_mut()
+ .and_then(|pto| pto.send_profile(mtu))
+ {
+ profile
+ } else {
+ let limit = min(sender.cwnd_avail(), path.amplification_limit());
+ if limit > mtu {
+ // More than an MTU available; we might need to pace.
+ if sender
+ .next_paced(path.rtt().estimate())
+ .map_or(false, |t| t > now)
+ {
+ SendProfile::new_paced()
+ } else {
+ SendProfile::new_limited(mtu)
+ }
+ } else if sender.recovery_packet() {
+ // After entering recovery, allow a packet to be sent immediately.
+ // This uses the PTO machinery, probing in all spaces. This will
+ // result in a PING being sent in every active space.
+ SendProfile::new_pto(PacketNumberSpace::Initial, mtu, PacketNumberSpaceSet::all())
+ } else {
+ SendProfile::new_limited(limit)
+ }
+ }
+ }
+}
+
+impl ::std::fmt::Display for LossRecovery {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "LossRecovery")
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{
+ cell::RefCell,
+ convert::TryInto,
+ ops::{Deref, DerefMut, RangeInclusive},
+ rc::Rc,
+ time::{Duration, Instant},
+ };
+
+ use neqo_common::qlog::NeqoQlog;
+ use test_fixture::{addr, now};
+
+ use super::{
+ LossRecovery, LossRecoverySpace, PacketNumberSpace, SendProfile, SentPacket, FAST_PTO_SCALE,
+ };
+ use crate::{
+ cc::CongestionControlAlgorithm,
+ cid::{ConnectionId, ConnectionIdEntry},
+ packet::PacketType,
+ path::{Path, PathRef},
+ rtt::RttEstimate,
+ stats::{Stats, StatsCell},
+ };
+
+ // Shorthand for a time in milliseconds.
+ const fn ms(t: u64) -> Duration {
+ Duration::from_millis(t)
+ }
+
+ const ON_SENT_SIZE: usize = 100;
+ /// An initial RTT for using with `setup_lr`.
+ const TEST_RTT: Duration = ms(80);
+ const TEST_RTTVAR: Duration = ms(40);
+
+ struct Fixture {
+ lr: LossRecovery,
+ path: PathRef,
+ }
+
+ // This shadows functions on the base object so that the path and RTT estimator
+ // is used consistently in the tests. It also simplifies the function signatures.
+ impl Fixture {
+ pub fn on_ack_received(
+ &mut self,
+ pn_space: PacketNumberSpace,
+ largest_acked: u64,
+ acked_ranges: Vec<RangeInclusive<u64>>,
+ ack_delay: Duration,
+ now: Instant,
+ ) -> (Vec<SentPacket>, Vec<SentPacket>) {
+ self.lr.on_ack_received(
+ &self.path,
+ pn_space,
+ largest_acked,
+ acked_ranges,
+ ack_delay,
+ now,
+ )
+ }
+
+ pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
+ self.lr.on_packet_sent(&self.path, sent_packet);
+ }
+
+ pub fn timeout(&mut self, now: Instant) -> Vec<SentPacket> {
+ self.lr.timeout(&self.path, now)
+ }
+
+ pub fn next_timeout(&mut self) -> Option<Instant> {
+ self.lr.next_timeout(self.path.borrow().rtt())
+ }
+
+ pub fn discard(&mut self, space: PacketNumberSpace, now: Instant) {
+ self.lr.discard(&self.path, space, now);
+ }
+
+ pub fn pto_time(&self, space: PacketNumberSpace) -> Option<Instant> {
+ self.lr.pto_time(self.path.borrow().rtt(), space)
+ }
+
+ pub fn send_profile(&mut self, now: Instant) -> SendProfile {
+ self.lr.send_profile(&self.path.borrow(), now)
+ }
+ }
+
+ impl Default for Fixture {
+ fn default() -> Self {
+ const CC: CongestionControlAlgorithm = CongestionControlAlgorithm::NewReno;
+ let mut path = Path::temporary(addr(), addr(), CC, true, NeqoQlog::default(), now());
+ path.make_permanent(
+ None,
+ ConnectionIdEntry::new(0, ConnectionId::from(&[1, 2, 3]), [0; 16]),
+ );
+ path.set_primary(true);
+ Self {
+ lr: LossRecovery::new(StatsCell::default(), FAST_PTO_SCALE),
+ path: Rc::new(RefCell::new(path)),
+ }
+ }
+ }
+
+ // Most uses of the fixture only care about the loss recovery piece,
+ // but the internal functions need the other bits.
+ impl Deref for Fixture {
+ type Target = LossRecovery;
+ #[must_use]
+ fn deref(&self) -> &Self::Target {
+ &self.lr
+ }
+ }
+
+ impl DerefMut for Fixture {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.lr
+ }
+ }
+
+ fn assert_rtts(
+ lr: &Fixture,
+ latest_rtt: Duration,
+ smoothed_rtt: Duration,
+ rttvar: Duration,
+ min_rtt: Duration,
+ ) {
+ let p = lr.path.borrow();
+ let rtt = p.rtt();
+ println!(
+ "rtts: {:?} {:?} {:?} {:?}",
+ rtt.latest(),
+ rtt.estimate(),
+ rtt.rttvar(),
+ rtt.minimum(),
+ );
+ assert_eq!(rtt.latest(), latest_rtt, "latest RTT");
+ assert_eq!(rtt.estimate(), smoothed_rtt, "smoothed RTT");
+ assert_eq!(rtt.rttvar(), rttvar, "RTT variance");
+ assert_eq!(rtt.minimum(), min_rtt, "min RTT");
+ }
+
+ fn assert_sent_times(
+ lr: &Fixture,
+ initial: Option<Instant>,
+ handshake: Option<Instant>,
+ app_data: Option<Instant>,
+ ) {
+ let est = |sp| {
+ lr.spaces
+ .get(sp)
+ .and_then(LossRecoverySpace::loss_recovery_timer_start)
+ };
+ println!(
+ "loss times: {:?} {:?} {:?}",
+ est(PacketNumberSpace::Initial),
+ est(PacketNumberSpace::Handshake),
+ est(PacketNumberSpace::ApplicationData),
+ );
+ assert_eq!(
+ est(PacketNumberSpace::Initial),
+ initial,
+ "Initial earliest sent time"
+ );
+ assert_eq!(
+ est(PacketNumberSpace::Handshake),
+ handshake,
+ "Handshake earliest sent time"
+ );
+ assert_eq!(
+ est(PacketNumberSpace::ApplicationData),
+ app_data,
+ "AppData earliest sent time"
+ );
+ }
+
+ fn assert_no_sent_times(lr: &Fixture) {
+ assert_sent_times(lr, None, None, None);
+ }
+
+ // In most of the tests below, packets are sent at a fixed cadence, with PACING between each.
+ const PACING: Duration = ms(7);
+ fn pn_time(pn: u64) -> Instant {
+ now() + (PACING * pn.try_into().unwrap())
+ }
+
+ fn pace(lr: &mut Fixture, count: u64) {
+ for pn in 0..count {
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ pn,
+ pn_time(pn),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ }
+ }
+
+ const ACK_DELAY: Duration = ms(24);
+ /// Acknowledge PN with the identified delay.
+ fn ack(lr: &mut Fixture, pn: u64, delay: Duration) {
+ lr.on_ack_received(
+ PacketNumberSpace::ApplicationData,
+ pn,
+ vec![pn..=pn],
+ ACK_DELAY,
+ pn_time(pn) + delay,
+ );
+ }
+
+ fn add_sent(lrs: &mut LossRecoverySpace, packet_numbers: &[u64]) {
+ for &pn in packet_numbers {
+ lrs.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ pn,
+ pn_time(pn),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ }
+ }
+
+ fn match_acked(acked: &[SentPacket], expected: &[u64]) {
+ assert!(acked.iter().map(|p| &p.pn).eq(expected));
+ }
+
+ #[test]
+ fn remove_acked() {
+ let mut lrs = LossRecoverySpace::new(PacketNumberSpace::ApplicationData);
+ let mut stats = Stats::default();
+ add_sent(&mut lrs, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+ let (acked, _) = lrs.remove_acked(vec![], &mut stats);
+ assert!(acked.is_empty());
+ let (acked, _) = lrs.remove_acked(vec![7..=8, 2..=4], &mut stats);
+ match_acked(&acked, &[8, 7, 4, 3, 2]);
+ let (acked, _) = lrs.remove_acked(vec![8..=11], &mut stats);
+ match_acked(&acked, &[10, 9]);
+ let (acked, _) = lrs.remove_acked(vec![0..=2], &mut stats);
+ match_acked(&acked, &[1]);
+ let (acked, _) = lrs.remove_acked(vec![5..=6], &mut stats);
+ match_acked(&acked, &[6, 5]);
+ }
+
+ #[test]
+ fn initial_rtt() {
+ let mut lr = Fixture::default();
+ pace(&mut lr, 1);
+ let rtt = ms(100);
+ ack(&mut lr, 0, rtt);
+ assert_rtts(&lr, rtt, rtt, rtt / 2, rtt);
+ assert_no_sent_times(&lr);
+ }
+
+ /// Send `n` packets (using PACING), then acknowledge the first.
+ fn setup_lr(n: u64) -> Fixture {
+ let mut lr = Fixture::default();
+ pace(&mut lr, n);
+ ack(&mut lr, 0, TEST_RTT);
+ assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT);
+ assert_no_sent_times(&lr);
+ lr
+ }
+
+ // The ack delay is removed from any RTT estimate.
+ #[test]
+ fn ack_delay_adjusted() {
+ let mut lr = setup_lr(2);
+ ack(&mut lr, 1, TEST_RTT + ACK_DELAY);
+ // RTT stays the same, but the RTTVAR is adjusted downwards.
+ assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR * 3 / 4, TEST_RTT);
+ assert_no_sent_times(&lr);
+ }
+
+ // The ack delay is ignored when it would cause a sample to be less than min_rtt.
+ #[test]
+ fn ack_delay_ignored() {
+ let mut lr = setup_lr(2);
+ let extra = ms(8);
+ assert!(extra < ACK_DELAY);
+ ack(&mut lr, 1, TEST_RTT + extra);
+ let expected_rtt = TEST_RTT + (extra / 8);
+ let expected_rttvar = (TEST_RTTVAR * 3 + extra) / 4;
+ assert_rtts(
+ &lr,
+ TEST_RTT + extra,
+ expected_rtt,
+ expected_rttvar,
+ TEST_RTT,
+ );
+ assert_no_sent_times(&lr);
+ }
+
+ // A lower observed RTT is used as min_rtt (and ack delay is ignored).
+ #[test]
+ fn reduce_min_rtt() {
+ let mut lr = setup_lr(2);
+ let delta = ms(4);
+ let reduced_rtt = TEST_RTT - delta;
+ ack(&mut lr, 1, reduced_rtt);
+ let expected_rtt = TEST_RTT - (delta / 8);
+ let expected_rttvar = (TEST_RTTVAR * 3 + delta) / 4;
+ assert_rtts(&lr, reduced_rtt, expected_rtt, expected_rttvar, reduced_rtt);
+ assert_no_sent_times(&lr);
+ }
+
+ // Acknowledging something again has no effect.
+ #[test]
+ fn no_new_acks() {
+ let mut lr = setup_lr(1);
+ let check = |lr: &Fixture| {
+ assert_rtts(lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT);
+ assert_no_sent_times(lr);
+ };
+ check(&lr);
+
+ ack(&mut lr, 0, ms(1339)); // much delayed ACK
+ check(&lr);
+
+ ack(&mut lr, 0, ms(3)); // time travel!
+ check(&lr);
+ }
+
+ // Test time loss detection as part of handling a regular ACK.
+ #[test]
+ fn time_loss_detection_gap() {
+ let mut lr = Fixture::default();
+ // Create a single packet gap, and have pn 0 time out.
+ // This can't use the default pacing, which is too tight.
+ // So send two packets with 1/4 RTT between them. Acknowledge pn 1 after 1 RTT.
+ // pn 0 should then be marked lost because it is then outstanding for 5RTT/4
+ // the loss time for packets is 9RTT/8.
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 0,
+ pn_time(0),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 1,
+ pn_time(0) + TEST_RTT / 4,
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ let (_, lost) = lr.on_ack_received(
+ PacketNumberSpace::ApplicationData,
+ 1,
+ vec![1..=1],
+ ACK_DELAY,
+ pn_time(0) + (TEST_RTT * 5 / 4),
+ );
+ assert_eq!(lost.len(), 1);
+ assert_no_sent_times(&lr);
+ }
+
+ // Test time loss detection as part of an explicit timeout.
+ #[test]
+ fn time_loss_detection_timeout() {
+ let mut lr = setup_lr(3);
+
+ // We want to declare PN 2 as acknowledged before we declare PN 1 as lost.
+ // For this to work, we need PACING above to be less than 1/8 of an RTT.
+ let pn1_sent_time = pn_time(1);
+ let pn1_loss_time = pn1_sent_time + (TEST_RTT * 9 / 8);
+ let pn2_ack_time = pn_time(2) + TEST_RTT;
+ assert!(pn1_loss_time > pn2_ack_time);
+
+ let (_, lost) = lr.on_ack_received(
+ PacketNumberSpace::ApplicationData,
+ 2,
+ vec![2..=2],
+ ACK_DELAY,
+ pn2_ack_time,
+ );
+ assert!(lost.is_empty());
+ // Run the timeout function here to force time-based loss recovery to be enabled.
+ let lost = lr.timeout(pn2_ack_time);
+ assert!(lost.is_empty());
+ assert_sent_times(&lr, None, None, Some(pn1_sent_time));
+
+ // After time elapses, pn 1 is marked lost.
+ let callback_time = lr.next_timeout();
+ assert_eq!(callback_time, Some(pn1_loss_time));
+ let packets = lr.timeout(pn1_loss_time);
+ assert_eq!(packets.len(), 1);
+ // Checking for expiration with zero delay lets us check the loss time.
+ assert!(packets[0].expired(pn1_loss_time, Duration::new(0, 0)));
+ assert_no_sent_times(&lr);
+ }
+
+ #[test]
+ fn big_gap_loss() {
+ let mut lr = setup_lr(5); // This sends packets 0-4 and acknowledges pn 0.
+
+ // Acknowledge just 2-4, which will cause pn 1 to be marked as lost.
+ assert_eq!(super::PACKET_THRESHOLD, 3);
+ let (_, lost) = lr.on_ack_received(
+ PacketNumberSpace::ApplicationData,
+ 4,
+ vec![2..=4],
+ ACK_DELAY,
+ pn_time(4),
+ );
+ assert_eq!(lost.len(), 1);
+ }
+
+ #[test]
+ #[should_panic(expected = "discarding application space")]
+ fn drop_app() {
+ let mut lr = Fixture::default();
+ lr.discard(PacketNumberSpace::ApplicationData, now());
+ }
+
+ #[test]
+ #[should_panic(expected = "dropping spaces out of order")]
+ fn drop_out_of_order() {
+ let mut lr = Fixture::default();
+ lr.discard(PacketNumberSpace::Handshake, now());
+ }
+
+ #[test]
+ fn ack_after_drop() {
+ let mut lr = Fixture::default();
+ lr.discard(PacketNumberSpace::Initial, now());
+ let (acked, lost) = lr.on_ack_received(
+ PacketNumberSpace::Initial,
+ 0,
+ vec![],
+ Duration::from_millis(0),
+ pn_time(0),
+ );
+ assert!(acked.is_empty());
+ assert!(lost.is_empty());
+ }
+
+ #[test]
+ fn drop_spaces() {
+ let mut lr = Fixture::default();
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Initial,
+ 0,
+ pn_time(0),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Handshake,
+ 0,
+ pn_time(1),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 0,
+ pn_time(2),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+
+ // Now put all spaces on the LR timer so we can see them.
+ for sp in &[
+ PacketType::Initial,
+ PacketType::Handshake,
+ PacketType::Short,
+ ] {
+ let sent_pkt = SentPacket::new(*sp, 1, pn_time(3), true, Vec::new(), ON_SENT_SIZE);
+ let pn_space = PacketNumberSpace::from(sent_pkt.pt);
+ lr.on_packet_sent(sent_pkt);
+ lr.on_ack_received(pn_space, 1, vec![1..=1], Duration::from_secs(0), pn_time(3));
+ let mut lost = Vec::new();
+ lr.spaces.get_mut(pn_space).unwrap().detect_lost_packets(
+ pn_time(3),
+ TEST_RTT,
+ TEST_RTT * 3, // unused
+ &mut lost,
+ );
+ assert!(lost.is_empty());
+ }
+
+ lr.discard(PacketNumberSpace::Initial, pn_time(3));
+ assert_sent_times(&lr, None, Some(pn_time(1)), Some(pn_time(2)));
+
+ lr.discard(PacketNumberSpace::Handshake, pn_time(3));
+ assert_sent_times(&lr, None, None, Some(pn_time(2)));
+
+ // There are cases where we send a packet that is not subsequently tracked.
+ // So check that this works.
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Initial,
+ 0,
+ pn_time(3),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ assert_sent_times(&lr, None, None, Some(pn_time(2)));
+ }
+
+ #[test]
+ fn rearm_pto_after_confirmed() {
+ let mut lr = Fixture::default();
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Initial,
+ 0,
+ now(),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ // Set the RTT to the initial value so that discarding doesn't
+ // alter the estimate.
+ let rtt = lr.path.borrow().rtt().estimate();
+ lr.on_ack_received(
+ PacketNumberSpace::Initial,
+ 0,
+ vec![0..=0],
+ Duration::new(0, 0),
+ now() + rtt,
+ );
+
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Handshake,
+ 0,
+ now(),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 0,
+ now(),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+
+ assert_eq!(lr.pto_time(PacketNumberSpace::ApplicationData), None);
+ lr.discard(PacketNumberSpace::Initial, pn_time(1));
+ assert_eq!(lr.pto_time(PacketNumberSpace::ApplicationData), None);
+
+ // Expiring state after the PTO on the ApplicationData space has
+ // expired should result in setting a PTO state.
+ let default_pto = RttEstimate::default().pto(PacketNumberSpace::ApplicationData);
+ let expected_pto = pn_time(2) + default_pto;
+ lr.discard(PacketNumberSpace::Handshake, expected_pto);
+ let profile = lr.send_profile(now());
+ assert!(profile.pto.is_some());
+ assert!(!profile.should_probe(PacketNumberSpace::Initial));
+ assert!(!profile.should_probe(PacketNumberSpace::Handshake));
+ assert!(profile.should_probe(PacketNumberSpace::ApplicationData));
+ }
+
+ #[test]
+ fn no_pto_if_amplification_limited() {
+ let mut lr = Fixture::default();
+ // Eat up the amplification limit by telling the path that we've sent a giant packet.
+ {
+ const SPARE: usize = 10;
+ let mut path = lr.path.borrow_mut();
+ let limit = path.amplification_limit();
+ path.add_sent(limit - SPARE);
+ assert_eq!(path.amplification_limit(), SPARE);
+ }
+
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Initial,
+ 1,
+ now(),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+
+ let handshake_pto = RttEstimate::default().pto(PacketNumberSpace::Handshake);
+ let expected_pto = now() + handshake_pto;
+ assert_eq!(lr.pto_time(PacketNumberSpace::Initial), Some(expected_pto));
+ let profile = lr.send_profile(now());
+ assert!(profile.ack_only(PacketNumberSpace::Initial));
+ assert!(profile.pto.is_none());
+ assert!(!profile.should_probe(PacketNumberSpace::Initial));
+ assert!(!profile.should_probe(PacketNumberSpace::Handshake));
+ assert!(!profile.should_probe(PacketNumberSpace::ApplicationData));
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/recv_stream.rs b/third_party/rust/neqo-transport/src/recv_stream.rs
new file mode 100644
index 0000000000..06ca59685d
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/recv_stream.rs
@@ -0,0 +1,2149 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Building a stream of ordered bytes to give the application from a series of
+// incoming STREAM frames.
+
+use std::{
+ cell::RefCell,
+ cmp::max,
+ collections::BTreeMap,
+ convert::TryFrom,
+ mem,
+ rc::{Rc, Weak},
+};
+
+use neqo_common::{qtrace, Role};
+use smallvec::SmallVec;
+
+use crate::{
+ events::ConnectionEvents,
+ fc::ReceiverFlowControl,
+ frame::FRAME_TYPE_STOP_SENDING,
+ packet::PacketBuilder,
+ recovery::{RecoveryToken, StreamRecoveryToken},
+ send_stream::SendStreams,
+ stats::FrameStats,
+ stream_id::StreamId,
+ AppError, Error, Res,
+};
+
+const RX_STREAM_DATA_WINDOW: u64 = 0x10_0000; // 1MiB
+
+// Export as usize for consistency with SEND_BUFFER_SIZE
+pub const RECV_BUFFER_SIZE: usize = RX_STREAM_DATA_WINDOW as usize;
+
+#[derive(Debug, Default)]
+pub(crate) struct RecvStreams {
+ streams: BTreeMap<StreamId, RecvStream>,
+ keep_alive: Weak<()>,
+}
+
+impl RecvStreams {
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ for stream in self.streams.values_mut() {
+ stream.write_frame(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+ }
+ }
+
+ pub fn insert(&mut self, id: StreamId, stream: RecvStream) {
+ self.streams.insert(id, stream);
+ }
+
+ pub fn get_mut(&mut self, id: StreamId) -> Res<&mut RecvStream> {
+ self.streams.get_mut(&id).ok_or(Error::InvalidStreamId)
+ }
+
+ pub fn keep_alive(&mut self, id: StreamId, k: bool) -> Res<()> {
+ let self_ka = &mut self.keep_alive;
+ let s = self.streams.get_mut(&id).ok_or(Error::InvalidStreamId)?;
+ s.keep_alive = if k {
+ Some(self_ka.upgrade().unwrap_or_else(|| {
+ let r = Rc::new(());
+ *self_ka = Rc::downgrade(&r);
+ r
+ }))
+ } else {
+ None
+ };
+ Ok(())
+ }
+
+ pub fn need_keep_alive(&mut self) -> bool {
+ self.keep_alive.strong_count() > 0
+ }
+
+ pub fn clear(&mut self) {
+ self.streams.clear();
+ }
+
+ pub fn clear_terminal(&mut self, send_streams: &SendStreams, role: Role) -> (u64, u64) {
+ let recv_to_remove = self
+ .streams
+ .iter()
+ .filter_map(|(id, stream)| {
+ // Remove all streams for which the receiving is done (or aborted).
+ // But only if they are unidirectional, or we have finished sending.
+ if stream.is_terminal() && (id.is_uni() || !send_streams.exists(*id)) {
+ Some(*id)
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let mut removed_bidi = 0;
+ let mut removed_uni = 0;
+ for id in &recv_to_remove {
+ self.streams.remove(id);
+ if id.is_remote_initiated(role) {
+ if id.is_bidi() {
+ removed_bidi += 1;
+ } else {
+ removed_uni += 1;
+ }
+ }
+ }
+
+ (removed_bidi, removed_uni)
+ }
+}
+
+/// Holds data not yet read by application. Orders and dedupes data ranges
+/// from incoming STREAM frames.
+#[derive(Debug, Default)]
+pub struct RxStreamOrderer {
+ data_ranges: BTreeMap<u64, Vec<u8>>, // (start_offset, data)
+ retired: u64, // Number of bytes the application has read
+ received: u64, // The number of bytes has stored in `data_ranges`
+}
+
+impl RxStreamOrderer {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Process an incoming stream frame off the wire. This may result in data
+ /// being available to upper layers if frame is not out of order (ooo) or
+ /// if the frame fills a gap.
+ pub fn inbound_frame(&mut self, mut new_start: u64, mut new_data: &[u8]) {
+ qtrace!("Inbound data offset={} len={}", new_start, new_data.len());
+
+ // Get entry before where new entry would go, so we can see if we already
+ // have the new bytes.
+ // Avoid copies and duplicated data.
+ let new_end = new_start + u64::try_from(new_data.len()).unwrap();
+
+ if new_end <= self.retired {
+ // Range already read by application, this frame is very late and unneeded.
+ return;
+ }
+
+ if new_start < self.retired {
+ new_data = &new_data[usize::try_from(self.retired - new_start).unwrap()..];
+ new_start = self.retired;
+ }
+
+ if new_data.is_empty() {
+ // No data to insert
+ return;
+ }
+
+ let extend = if let Some((&prev_start, prev_vec)) =
+ self.data_ranges.range_mut(..=new_start).next_back()
+ {
+ let prev_end = prev_start + u64::try_from(prev_vec.len()).unwrap();
+ if new_end > prev_end {
+ // PPPPPP -> PPPPPP
+ // NNNNNN NN
+ // NNNNNNNN NN
+ // Add a range containing only new data
+ // (In-order frames will take this path, with no overlap)
+ let overlap = prev_end.saturating_sub(new_start);
+ qtrace!(
+ "New frame {}-{} received, overlap: {}",
+ new_start,
+ new_end,
+ overlap
+ );
+ new_start += overlap;
+ new_data = &new_data[usize::try_from(overlap).unwrap()..];
+ // If it is small enough, extend the previous buffer.
+ // This can't always extend, because otherwise the buffer could end up
+ // growing indefinitely without being released.
+ prev_vec.len() < 4096 && prev_end == new_start
+ } else {
+ // PPPPPP -> PPPPPP
+ // NNNN
+ // NNNN
+ // Do nothing
+ qtrace!(
+ "Dropping frame with already-received range {}-{}",
+ new_start,
+ new_end
+ );
+ return;
+ }
+ } else {
+ qtrace!("New frame {}-{} received", new_start, new_end);
+ false
+ };
+
+ let mut to_add = new_data;
+ if self
+ .data_ranges
+ .last_entry()
+ .map_or(false, |e| *e.key() >= new_start)
+ {
+ // Is this at the end (common case)? If so, nothing to do in this block
+ // Common case:
+ // PPPPPP -> PPPPPP
+ // NNNNNNN NNNNNNN
+ // or
+ // PPPPPP -> PPPPPP
+ // NNNNNNN NNNNNNN
+ //
+ // Not the common case, handle possible overlap with next entries
+ // PPPPPP AAA -> PPPPPP
+ // NNNNNNN NNNNNNN
+ // or
+ // PPPPPP AAAA -> PPPPPP AAAA
+ // NNNNNNN NNNNN
+ // or (this is where to_remove is used)
+ // PPPPPP AA -> PPPPPP
+ // NNNNNNN NNNNNNN
+
+ let mut to_remove = SmallVec::<[_; 8]>::new();
+
+ for (&next_start, next_data) in self.data_ranges.range_mut(new_start..) {
+ let next_end = next_start + u64::try_from(next_data.len()).unwrap();
+ let overlap = new_end.saturating_sub(next_start);
+ if overlap == 0 {
+ // Fills in the hole, exactly (probably common)
+ break;
+ } else if next_end >= new_end {
+ qtrace!(
+ "New frame {}-{} overlaps with next frame by {}, truncating",
+ new_start,
+ new_end,
+ overlap
+ );
+ let truncate_to = new_data.len() - usize::try_from(overlap).unwrap();
+ to_add = &new_data[..truncate_to];
+ break;
+ }
+ qtrace!(
+ "New frame {}-{} spans entire next frame {}-{}, replacing",
+ new_start,
+ new_end,
+ next_start,
+ next_end
+ );
+ to_remove.push(next_start);
+ // Continue, since we may have more overlaps
+ }
+
+ for start in to_remove {
+ self.data_ranges.remove(&start);
+ }
+ }
+
+ if !to_add.is_empty() {
+ self.received += u64::try_from(to_add.len()).unwrap();
+ if extend {
+ let (_, buf) = self
+ .data_ranges
+ .range_mut(..=new_start)
+ .next_back()
+ .unwrap();
+ buf.extend_from_slice(to_add);
+ } else {
+ self.data_ranges.insert(new_start, to_add.to_vec());
+ }
+ }
+ }
+
+ /// Are any bytes readable?
+ pub fn data_ready(&self) -> bool {
+ self.data_ranges
+ .keys()
+ .next()
+ .map_or(false, |&start| start <= self.retired)
+ }
+
+ /// How many bytes are readable?
+ fn bytes_ready(&self) -> usize {
+ let mut prev_end = self.retired;
+ self.data_ranges
+ .iter()
+ .map(|(start_offset, data)| {
+ // All ranges don't overlap but we could have partially
+ // retired some of the first entry's data.
+ let data_len = data.len() as u64 - self.retired.saturating_sub(*start_offset);
+ (start_offset, data_len)
+ })
+ .take_while(|(start_offset, data_len)| {
+ if **start_offset <= prev_end {
+ prev_end += data_len;
+ true
+ } else {
+ false
+ }
+ })
+ .map(|(_, data_len)| data_len as usize)
+ .sum()
+ }
+
+ /// Bytes read by the application.
+ pub fn retired(&self) -> u64 {
+ self.retired
+ }
+
+ pub fn received(&self) -> u64 {
+ self.received
+ }
+
+ /// Data bytes buffered. Could be more than bytes_readable if there are
+ /// ranges missing.
+ fn buffered(&self) -> u64 {
+ self.data_ranges
+ .iter()
+ .map(|(&start, data)| data.len() as u64 - (self.retired.saturating_sub(start)))
+ .sum()
+ }
+
+ /// Copy received data (if any) into the buffer. Returns bytes copied.
+ fn read(&mut self, buf: &mut [u8]) -> usize {
+ qtrace!("Reading {} bytes, {} available", buf.len(), self.buffered());
+ let mut copied = 0;
+
+ for (&range_start, range_data) in &mut self.data_ranges {
+ let mut keep = false;
+ if self.retired >= range_start {
+ // Frame data has new contiguous bytes.
+ let copy_offset =
+ usize::try_from(max(range_start, self.retired) - range_start).unwrap();
+ assert!(range_data.len() >= copy_offset);
+ let available = range_data.len() - copy_offset;
+ let space = buf.len() - copied;
+ let copy_bytes = if available > space {
+ keep = true;
+ space
+ } else {
+ available
+ };
+
+ if copy_bytes > 0 {
+ let copy_slc = &range_data[copy_offset..copy_offset + copy_bytes];
+ buf[copied..copied + copy_bytes].copy_from_slice(copy_slc);
+ copied += copy_bytes;
+ self.retired += u64::try_from(copy_bytes).unwrap();
+ }
+ } else {
+ // The data in the buffer isn't contiguous.
+ keep = true;
+ }
+ if keep {
+ let mut keep = self.data_ranges.split_off(&range_start);
+ mem::swap(&mut self.data_ranges, &mut keep);
+ return copied;
+ }
+ }
+
+ self.data_ranges.clear();
+ copied
+ }
+
+ /// Extend the given Vector with any available data.
+ pub fn read_to_end(&mut self, buf: &mut Vec<u8>) -> usize {
+ let orig_len = buf.len();
+ buf.resize(orig_len + self.bytes_ready(), 0);
+ self.read(&mut buf[orig_len..])
+ }
+}
+
+/// QUIC receiving states, based on -transport 3.2.
+#[derive(Debug)]
+#[allow(dead_code)]
+// Because a dead_code warning is easier than clippy::unused_self, see https://github.com/rust-lang/rust/issues/68408
+enum RecvStreamState {
+ Recv {
+ fc: ReceiverFlowControl<StreamId>,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ recv_buf: RxStreamOrderer,
+ },
+ SizeKnown {
+ fc: ReceiverFlowControl<StreamId>,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ recv_buf: RxStreamOrderer,
+ },
+ DataRecvd {
+ fc: ReceiverFlowControl<StreamId>,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ recv_buf: RxStreamOrderer,
+ },
+ DataRead {
+ final_received: u64,
+ final_read: u64,
+ },
+ AbortReading {
+ fc: ReceiverFlowControl<StreamId>,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ final_size_reached: bool,
+ frame_needed: bool,
+ err: AppError,
+ final_received: u64,
+ final_read: u64,
+ },
+ WaitForReset {
+ fc: ReceiverFlowControl<StreamId>,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ final_received: u64,
+ final_read: u64,
+ },
+ ResetRecvd {
+ final_received: u64,
+ final_read: u64,
+ },
+ // Defined by spec but we don't use it: ResetRead
+}
+
+impl RecvStreamState {
+ fn new(
+ max_bytes: u64,
+ stream_id: StreamId,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ ) -> Self {
+ Self::Recv {
+ fc: ReceiverFlowControl::new(stream_id, max_bytes),
+ recv_buf: RxStreamOrderer::new(),
+ session_fc,
+ }
+ }
+
+ fn name(&self) -> &str {
+ match self {
+ Self::Recv { .. } => "Recv",
+ Self::SizeKnown { .. } => "SizeKnown",
+ Self::DataRecvd { .. } => "DataRecvd",
+ Self::DataRead { .. } => "DataRead",
+ Self::AbortReading { .. } => "AbortReading",
+ Self::WaitForReset { .. } => "WaitForReset",
+ Self::ResetRecvd { .. } => "ResetRecvd",
+ }
+ }
+
+ fn recv_buf(&self) -> Option<&RxStreamOrderer> {
+ match self {
+ Self::Recv { recv_buf, .. }
+ | Self::SizeKnown { recv_buf, .. }
+ | Self::DataRecvd { recv_buf, .. } => Some(recv_buf),
+ Self::DataRead { .. }
+ | Self::AbortReading { .. }
+ | Self::WaitForReset { .. }
+ | Self::ResetRecvd { .. } => None,
+ }
+ }
+
+ fn flow_control_consume_data(&mut self, consumed: u64, fin: bool) -> Res<()> {
+ let (fc, session_fc, final_size_reached, retire_data) = match self {
+ Self::Recv { fc, session_fc, .. } => (fc, session_fc, false, false),
+ Self::WaitForReset { fc, session_fc, .. } => (fc, session_fc, false, true),
+ Self::SizeKnown { fc, session_fc, .. } | Self::DataRecvd { fc, session_fc, .. } => {
+ (fc, session_fc, true, false)
+ }
+ Self::AbortReading {
+ fc,
+ session_fc,
+ final_size_reached,
+ ..
+ } => {
+ let old_final_size_reached = *final_size_reached;
+ *final_size_reached |= fin;
+ (fc, session_fc, old_final_size_reached, true)
+ }
+ Self::DataRead { .. } | Self::ResetRecvd { .. } => {
+ return Ok(());
+ }
+ };
+
+ // Check final size:
+ let final_size_ok = match (fin, final_size_reached) {
+ (true, true) => consumed == fc.consumed(),
+ (false, true) => consumed <= fc.consumed(),
+ (true, false) => consumed >= fc.consumed(),
+ (false, false) => true,
+ };
+
+ if !final_size_ok {
+ return Err(Error::FinalSizeError);
+ }
+
+ let new_bytes_consumed = fc.set_consumed(consumed)?;
+ session_fc.borrow_mut().consume(new_bytes_consumed)?;
+ if retire_data {
+ // Let's also retire this data since the stream has been aborted
+ RecvStream::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc);
+ }
+ Ok(())
+ }
+}
+
+// See https://www.w3.org/TR/webtransport/#receive-stream-stats
+#[derive(Debug, Clone, Copy)]
+pub struct RecvStreamStats {
+ // An indicator of progress on how many of the server application’s bytes
+ // intended for this stream have been received so far.
+ // Only sequential bytes up to, but not including, the first missing byte,
+ // are counted. This number can only increase.
+ pub bytes_received: u64,
+ // The total number of bytes the application has successfully read from this
+ // stream. This number can only increase, and is always less than or equal
+ // to bytes_received.
+ pub bytes_read: u64,
+}
+
+impl RecvStreamStats {
+ #[must_use]
+ pub fn new(bytes_received: u64, bytes_read: u64) -> Self {
+ Self {
+ bytes_received,
+ bytes_read,
+ }
+ }
+
+ #[must_use]
+ pub fn bytes_received(&self) -> u64 {
+ self.bytes_received
+ }
+
+ #[must_use]
+ pub fn bytes_read(&self) -> u64 {
+ self.bytes_read
+ }
+}
+
+/// Implement a QUIC receive stream.
+#[derive(Debug)]
+pub struct RecvStream {
+ stream_id: StreamId,
+ state: RecvStreamState,
+ conn_events: ConnectionEvents,
+ keep_alive: Option<Rc<()>>,
+}
+
+impl RecvStream {
+ pub fn new(
+ stream_id: StreamId,
+ max_stream_data: u64,
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ conn_events: ConnectionEvents,
+ ) -> Self {
+ Self {
+ stream_id,
+ state: RecvStreamState::new(max_stream_data, stream_id, session_fc),
+ conn_events,
+ keep_alive: None,
+ }
+ }
+
+ fn set_state(&mut self, new_state: RecvStreamState) {
+ debug_assert_ne!(
+ mem::discriminant(&self.state),
+ mem::discriminant(&new_state)
+ );
+ qtrace!(
+ "RecvStream {} state {} -> {}",
+ self.stream_id.as_u64(),
+ self.state.name(),
+ new_state.name()
+ );
+
+ match new_state {
+ // Receiving all data, or receiving or requesting RESET_STREAM
+ // is cause to stop keep-alives.
+ RecvStreamState::DataRecvd { .. }
+ | RecvStreamState::AbortReading { .. }
+ | RecvStreamState::ResetRecvd { .. } => {
+ self.keep_alive = None;
+ }
+ // Once all the data is read, generate an event.
+ RecvStreamState::DataRead { .. } => {
+ self.conn_events.recv_stream_complete(self.stream_id);
+ }
+ _ => {}
+ }
+
+ self.state = new_state;
+ }
+
+ pub fn stats(&self) -> RecvStreamStats {
+ match &self.state {
+ RecvStreamState::Recv { recv_buf, .. }
+ | RecvStreamState::SizeKnown { recv_buf, .. }
+ | RecvStreamState::DataRecvd { recv_buf, .. } => {
+ let received = recv_buf.received();
+ let read = recv_buf.retired();
+ RecvStreamStats::new(received, read)
+ }
+ RecvStreamState::AbortReading {
+ final_received,
+ final_read,
+ ..
+ }
+ | RecvStreamState::WaitForReset {
+ final_received,
+ final_read,
+ ..
+ }
+ | RecvStreamState::DataRead {
+ final_received,
+ final_read,
+ }
+ | RecvStreamState::ResetRecvd {
+ final_received,
+ final_read,
+ } => {
+ let received = *final_received;
+ let read = *final_read;
+ RecvStreamStats::new(received, read)
+ }
+ }
+ }
+
+ pub fn inbound_stream_frame(&mut self, fin: bool, offset: u64, data: &[u8]) -> Res<()> {
+ // We should post a DataReadable event only once when we change from no-data-ready to
+ // data-ready. Therefore remember the state before processing a new frame.
+ let already_data_ready = self.data_ready();
+ let new_end = offset + u64::try_from(data.len()).unwrap();
+
+ self.state.flow_control_consume_data(new_end, fin)?;
+
+ match &mut self.state {
+ RecvStreamState::Recv {
+ recv_buf,
+ fc,
+ session_fc,
+ } => {
+ recv_buf.inbound_frame(offset, data);
+ if fin {
+ let all_recv =
+ fc.consumed() == recv_buf.retired() + recv_buf.bytes_ready() as u64;
+ let buf = mem::replace(recv_buf, RxStreamOrderer::new());
+ let fc_copy = mem::take(fc);
+ let session_fc_copy = mem::take(session_fc);
+ if all_recv {
+ self.set_state(RecvStreamState::DataRecvd {
+ fc: fc_copy,
+ session_fc: session_fc_copy,
+ recv_buf: buf,
+ });
+ } else {
+ self.set_state(RecvStreamState::SizeKnown {
+ fc: fc_copy,
+ session_fc: session_fc_copy,
+ recv_buf: buf,
+ });
+ }
+ }
+ }
+ RecvStreamState::SizeKnown {
+ recv_buf,
+ fc,
+ session_fc,
+ } => {
+ recv_buf.inbound_frame(offset, data);
+ if fc.consumed() == recv_buf.retired() + recv_buf.bytes_ready() as u64 {
+ let buf = mem::replace(recv_buf, RxStreamOrderer::new());
+ let fc_copy = mem::take(fc);
+ let session_fc_copy = mem::take(session_fc);
+ self.set_state(RecvStreamState::DataRecvd {
+ fc: fc_copy,
+ session_fc: session_fc_copy,
+ recv_buf: buf,
+ });
+ }
+ }
+ RecvStreamState::DataRecvd { .. }
+ | RecvStreamState::DataRead { .. }
+ | RecvStreamState::AbortReading { .. }
+ | RecvStreamState::WaitForReset { .. }
+ | RecvStreamState::ResetRecvd { .. } => {
+ qtrace!("data received when we are in state {}", self.state.name());
+ }
+ }
+
+ if !already_data_ready && (self.data_ready() || self.needs_to_inform_app_about_fin()) {
+ self.conn_events.recv_stream_readable(self.stream_id);
+ }
+
+ Ok(())
+ }
+
+ pub fn reset(&mut self, application_error_code: AppError, final_size: u64) -> Res<()> {
+ self.state.flow_control_consume_data(final_size, true)?;
+ match &mut self.state {
+ RecvStreamState::Recv {
+ fc,
+ session_fc,
+ recv_buf,
+ }
+ | RecvStreamState::SizeKnown {
+ fc,
+ session_fc,
+ recv_buf,
+ } => {
+ // make flow control consumes new data that not really exist.
+ Self::flow_control_retire_data(final_size - fc.retired(), fc, session_fc);
+ self.conn_events
+ .recv_stream_reset(self.stream_id, application_error_code);
+ let received = recv_buf.received();
+ let read = recv_buf.retired();
+ self.set_state(RecvStreamState::ResetRecvd {
+ final_received: received,
+ final_read: read,
+ });
+ }
+ RecvStreamState::AbortReading {
+ fc,
+ session_fc,
+ final_received,
+ final_read,
+ ..
+ }
+ | RecvStreamState::WaitForReset {
+ fc,
+ session_fc,
+ final_received,
+ final_read,
+ } => {
+ // make flow control consumes new data that not really exist.
+ Self::flow_control_retire_data(final_size - fc.retired(), fc, session_fc);
+ self.conn_events
+ .recv_stream_reset(self.stream_id, application_error_code);
+ let received = *final_received;
+ let read = *final_read;
+ self.set_state(RecvStreamState::ResetRecvd {
+ final_received: received,
+ final_read: read,
+ });
+ }
+ _ => {
+ // Ignore reset if in DataRecvd, DataRead, or ResetRecvd
+ }
+ }
+ Ok(())
+ }
+
+ /// If we should tell the sender they have more credit, return an offset
+ fn flow_control_retire_data(
+ new_read: u64,
+ fc: &mut ReceiverFlowControl<StreamId>,
+ session_fc: &mut Rc<RefCell<ReceiverFlowControl<()>>>,
+ ) {
+ if new_read > 0 {
+ fc.add_retired(new_read);
+ session_fc.borrow_mut().add_retired(new_read);
+ }
+ }
+
+ /// Send a flow control update.
+ /// This is used when a peer declares that they are blocked.
+ /// This sends `MAX_STREAM_DATA` if there is any increase possible.
+ pub fn send_flowc_update(&mut self) {
+ if let RecvStreamState::Recv { fc, .. } = &mut self.state {
+ fc.send_flowc_update();
+ }
+ }
+
+ pub fn set_stream_max_data(&mut self, max_data: u64) {
+ if let RecvStreamState::Recv { fc, .. } = &mut self.state {
+ fc.set_max_active(max_data);
+ }
+ }
+
+ pub fn is_terminal(&self) -> bool {
+ matches!(
+ self.state,
+ RecvStreamState::ResetRecvd { .. } | RecvStreamState::DataRead { .. }
+ )
+ }
+
+ // App got all data but did not get the fin signal.
+ fn needs_to_inform_app_about_fin(&self) -> bool {
+ matches!(self.state, RecvStreamState::DataRecvd { .. })
+ }
+
+ fn data_ready(&self) -> bool {
+ self.state
+ .recv_buf()
+ .map_or(false, RxStreamOrderer::data_ready)
+ }
+
+ /// # Errors
+ ///
+ /// `NoMoreData` if data and fin bit were previously read by the application.
+ pub fn read(&mut self, buf: &mut [u8]) -> Res<(usize, bool)> {
+ let data_recvd_state = matches!(self.state, RecvStreamState::DataRecvd { .. });
+ match &mut self.state {
+ RecvStreamState::Recv {
+ recv_buf,
+ fc,
+ session_fc,
+ }
+ | RecvStreamState::SizeKnown {
+ recv_buf,
+ fc,
+ session_fc,
+ ..
+ }
+ | RecvStreamState::DataRecvd {
+ recv_buf,
+ fc,
+ session_fc,
+ } => {
+ let bytes_read = recv_buf.read(buf);
+ Self::flow_control_retire_data(u64::try_from(bytes_read).unwrap(), fc, session_fc);
+ let fin_read = if data_recvd_state {
+ if recv_buf.buffered() == 0 {
+ let received = recv_buf.received();
+ let read = recv_buf.retired();
+ self.set_state(RecvStreamState::DataRead {
+ final_received: received,
+ final_read: read,
+ });
+ true
+ } else {
+ false
+ }
+ } else {
+ false
+ };
+ Ok((bytes_read, fin_read))
+ }
+ RecvStreamState::DataRead { .. }
+ | RecvStreamState::AbortReading { .. }
+ | RecvStreamState::WaitForReset { .. }
+ | RecvStreamState::ResetRecvd { .. } => Err(Error::NoMoreData),
+ }
+ }
+
+ pub fn stop_sending(&mut self, err: AppError) {
+ qtrace!("stop_sending called when in state {}", self.state.name());
+ match &mut self.state {
+ RecvStreamState::Recv {
+ fc,
+ session_fc,
+ recv_buf,
+ }
+ | RecvStreamState::SizeKnown {
+ fc,
+ session_fc,
+ recv_buf,
+ } => {
+ // Retire data
+ Self::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc);
+ let fc_copy = mem::take(fc);
+ let session_fc_copy = mem::take(session_fc);
+ let received = recv_buf.received();
+ let read = recv_buf.retired();
+ self.set_state(RecvStreamState::AbortReading {
+ fc: fc_copy,
+ session_fc: session_fc_copy,
+ final_size_reached: matches!(self.state, RecvStreamState::SizeKnown { .. }),
+ frame_needed: true,
+ err,
+ final_received: received,
+ final_read: read,
+ });
+ }
+ RecvStreamState::DataRecvd {
+ fc,
+ session_fc,
+ recv_buf,
+ } => {
+ Self::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc);
+ let received = recv_buf.received();
+ let read = recv_buf.retired();
+ self.set_state(RecvStreamState::DataRead {
+ final_received: received,
+ final_read: read,
+ });
+ }
+ RecvStreamState::DataRead { .. }
+ | RecvStreamState::AbortReading { .. }
+ | RecvStreamState::WaitForReset { .. }
+ | RecvStreamState::ResetRecvd { .. } => {
+ // Already in terminal state
+ }
+ }
+ }
+
+ /// Maybe write a `MAX_STREAM_DATA` frame.
+ pub fn write_frame(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ match &mut self.state {
+ // Maybe send MAX_STREAM_DATA
+ RecvStreamState::Recv { fc, .. } => fc.write_frames(builder, tokens, stats),
+ // Maybe send STOP_SENDING
+ RecvStreamState::AbortReading {
+ frame_needed, err, ..
+ } => {
+ if *frame_needed
+ && builder.write_varint_frame(&[
+ FRAME_TYPE_STOP_SENDING,
+ self.stream_id.as_u64(),
+ *err,
+ ])
+ {
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::StopSending {
+ stream_id: self.stream_id,
+ }));
+ stats.stop_sending += 1;
+ *frame_needed = false;
+ }
+ }
+ _ => {}
+ }
+ }
+
+ pub fn max_stream_data_lost(&mut self, maximum_data: u64) {
+ if let RecvStreamState::Recv { fc, .. } = &mut self.state {
+ fc.frame_lost(maximum_data);
+ }
+ }
+
+ pub fn stop_sending_lost(&mut self) {
+ if let RecvStreamState::AbortReading { frame_needed, .. } = &mut self.state {
+ *frame_needed = true;
+ }
+ }
+
+ pub fn stop_sending_acked(&mut self) {
+ if let RecvStreamState::AbortReading {
+ fc,
+ session_fc,
+ final_size_reached,
+ final_received,
+ final_read,
+ ..
+ } = &mut self.state
+ {
+ let received = *final_received;
+ let read = *final_read;
+ if *final_size_reached {
+ // We already know the final_size of the stream therefore we
+ // do not need to wait for RESET.
+ self.set_state(RecvStreamState::ResetRecvd {
+ final_received: received,
+ final_read: read,
+ });
+ } else {
+ let fc_copy = mem::take(fc);
+ let session_fc_copy = mem::take(session_fc);
+ self.set_state(RecvStreamState::WaitForReset {
+ fc: fc_copy,
+ session_fc: session_fc_copy,
+ final_received: received,
+ final_read: read,
+ });
+ }
+ }
+ }
+
+ #[cfg(test)]
+ pub fn has_frames_to_write(&self) -> bool {
+ if let RecvStreamState::Recv { fc, .. } = &self.state {
+ fc.frame_needed()
+ } else {
+ false
+ }
+ }
+
+ #[cfg(test)]
+ pub fn fc(&self) -> Option<&ReceiverFlowControl<StreamId>> {
+ match &self.state {
+ RecvStreamState::Recv { fc, .. }
+ | RecvStreamState::SizeKnown { fc, .. }
+ | RecvStreamState::DataRecvd { fc, .. }
+ | RecvStreamState::AbortReading { fc, .. }
+ | RecvStreamState::WaitForReset { fc, .. } => Some(fc),
+ _ => None,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::ops::Range;
+
+ use neqo_common::Encoder;
+
+ use super::*;
+
+ const SESSION_WINDOW: usize = 1024;
+
+ fn recv_ranges(ranges: &[Range<u64>], available: usize) {
+ const ZEROES: &[u8] = &[0; 100];
+ qtrace!("recv_ranges {:?}", ranges);
+
+ let mut s = RxStreamOrderer::default();
+ for r in ranges {
+ let data = &ZEROES[..usize::try_from(r.end - r.start).unwrap()];
+ s.inbound_frame(r.start, data);
+ }
+
+ let mut buf = [0xff; 100];
+ let mut total_recvd = 0;
+ loop {
+ let recvd = s.read(&mut buf[..]);
+ qtrace!("recv_ranges read {}", recvd);
+ total_recvd += recvd;
+ if recvd == 0 {
+ assert_eq!(total_recvd, available);
+ break;
+ }
+ }
+ }
+
+ #[test]
+ #[allow(unknown_lints, clippy::single_range_in_vec_init)] // Because that lint makes no sense here.
+ fn recv_noncontiguous() {
+ // Non-contiguous with the start, no data available.
+ recv_ranges(&[10..20], 0);
+ }
+
+ /// Overlaps with the start of a 10..20 range of bytes.
+ #[test]
+ fn recv_overlap_start() {
+ // Overlap the start, with a larger new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 4..18, 0..4], 20);
+ // Overlap the start, with a larger new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 2..15, 0..2], 20);
+ // Overlap the start, with a smaller new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 8..14, 0..8], 20);
+ // Overlap the start, with a smaller new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 6..13, 0..6], 20);
+
+ // Again with some of the first range split in two.
+ recv_ranges(&[10..11, 11..20, 4..18, 0..4], 20);
+ recv_ranges(&[10..11, 11..20, 2..15, 0..2], 20);
+ recv_ranges(&[10..11, 11..20, 8..14, 0..8], 20);
+ recv_ranges(&[10..11, 11..20, 6..13, 0..6], 20);
+
+ // Again with a gap in the first range.
+ recv_ranges(&[10..11, 12..20, 4..18, 0..4], 20);
+ recv_ranges(&[10..11, 12..20, 2..15, 0..2], 20);
+ recv_ranges(&[10..11, 12..20, 8..14, 0..8], 20);
+ recv_ranges(&[10..11, 12..20, 6..13, 0..6], 20);
+ }
+
+ /// Overlaps with the end of a 10..20 range of bytes.
+ #[test]
+ fn recv_overlap_end() {
+ // Overlap the end, with a larger new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 12..25, 0..10], 25);
+ // Overlap the end, with a larger new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 17..33, 0..10], 33);
+ // Overlap the end, with a smaller new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 15..21, 0..10], 21);
+ // Overlap the end, with a smaller new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 17..25, 0..10], 25);
+
+ // Again with some of the first range split in two.
+ recv_ranges(&[10..19, 19..20, 12..25, 0..10], 25);
+ recv_ranges(&[10..19, 19..20, 17..33, 0..10], 33);
+ recv_ranges(&[10..19, 19..20, 15..21, 0..10], 21);
+ recv_ranges(&[10..19, 19..20, 17..25, 0..10], 25);
+
+ // Again with a gap in the first range.
+ recv_ranges(&[10..18, 19..20, 12..25, 0..10], 25);
+ recv_ranges(&[10..18, 19..20, 17..33, 0..10], 33);
+ recv_ranges(&[10..18, 19..20, 15..21, 0..10], 21);
+ recv_ranges(&[10..18, 19..20, 17..25, 0..10], 25);
+ }
+
+ /// Complete overlaps with the start of a 10..20 range of bytes.
+ #[test]
+ fn recv_overlap_complete() {
+ // Complete overlap, more at the end.
+ recv_ranges(&[10..20, 9..23, 0..9], 23);
+ // Complete overlap, more at the start.
+ recv_ranges(&[10..20, 3..23, 0..3], 23);
+ // Complete overlap, to end.
+ recv_ranges(&[10..20, 5..20, 0..5], 20);
+ // Complete overlap, from start.
+ recv_ranges(&[10..20, 10..27, 0..10], 27);
+ // Complete overlap, from 0 and more.
+ recv_ranges(&[10..20, 0..23], 23);
+
+ // Again with the first range split in two.
+ recv_ranges(&[10..14, 14..20, 9..23, 0..9], 23);
+ recv_ranges(&[10..14, 14..20, 3..23, 0..3], 23);
+ recv_ranges(&[10..14, 14..20, 5..20, 0..5], 20);
+ recv_ranges(&[10..14, 14..20, 10..27, 0..10], 27);
+ recv_ranges(&[10..14, 14..20, 0..23], 23);
+
+ // Again with the a gap in the first range.
+ recv_ranges(&[10..13, 14..20, 9..23, 0..9], 23);
+ recv_ranges(&[10..13, 14..20, 3..23, 0..3], 23);
+ recv_ranges(&[10..13, 14..20, 5..20, 0..5], 20);
+ recv_ranges(&[10..13, 14..20, 10..27, 0..10], 27);
+ recv_ranges(&[10..13, 14..20, 0..23], 23);
+ }
+
+ /// An overlap with no new bytes.
+ #[test]
+ fn recv_overlap_duplicate() {
+ recv_ranges(&[10..20, 11..12, 0..10], 20);
+ recv_ranges(&[10..20, 10..15, 0..10], 20);
+ recv_ranges(&[10..20, 14..20, 0..10], 20);
+ // Now with the first range split.
+ recv_ranges(&[10..14, 14..20, 10..15, 0..10], 20);
+ recv_ranges(&[10..15, 16..20, 21..25, 10..25, 0..10], 25);
+ }
+
+ /// Reading exactly one chunk works, when the next chunk starts immediately.
+ #[test]
+ fn stop_reading_at_chunk() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add three chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ // Read, providing only enough space for the first.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..CHUNK_SIZE]);
+ assert_eq!(count, CHUNK_SIZE);
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, EXTRA_SIZE * 2);
+ }
+
+ #[test]
+ fn recv_overlap_while_reading() {
+ let mut s = RxStreamOrderer::new();
+
+ // Add a chunk
+ s.inbound_frame(0, &[0; 150]);
+ assert_eq!(s.data_ranges.get(&0).unwrap().len(), 150);
+ // Read, providing only enough space for the first 100.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, 100);
+ assert_eq!(s.retired, 100);
+
+ // Add a second frame that overlaps.
+ // This shouldn't truncate the first frame, as we're already
+ // Reading from it.
+ s.inbound_frame(120, &[0; 60]);
+ assert_eq!(s.data_ranges.get(&0).unwrap().len(), 180);
+ // Read second part of first frame and all of the second frame
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, 80);
+ }
+
+ /// Reading exactly one chunk works, when there is a gap.
+ #[test]
+ fn stop_reading_at_gap() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add three chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ // Read, providing only enough space for the first chunk.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..CHUNK_SIZE]);
+ assert_eq!(count, CHUNK_SIZE);
+
+ // Now fill the gap and ensure that everything can be read.
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, EXTRA_SIZE * 2);
+ }
+
+ /// Reading exactly one chunk works, when there is a gap.
+ #[test]
+ fn stop_reading_in_chunk() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add two chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ // Read, providing only enough space for some of the first chunk.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..CHUNK_SIZE - EXTRA_SIZE]);
+ assert_eq!(count, CHUNK_SIZE - EXTRA_SIZE);
+
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, EXTRA_SIZE * 2);
+ }
+
+ /// Read one byte at a time.
+ #[test]
+ fn read_byte_at_a_time() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add two chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ let mut buf = [0; 1];
+ for _ in 0..CHUNK_SIZE + EXTRA_SIZE {
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, 1);
+ }
+ assert_eq!(0, s.read(&mut buf[..]));
+ }
+
+ fn check_stats(stream: &RecvStream, expected_received: u64, expected_read: u64) {
+ let stream_stats = stream.stats();
+ assert_eq!(expected_received, stream_stats.bytes_received());
+ assert_eq!(expected_read, stream_stats.bytes_read());
+ }
+
+ #[test]
+ fn stream_rx() {
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = RecvStream::new(
+ StreamId::from(567),
+ 1024,
+ Rc::new(RefCell::new(ReceiverFlowControl::new((), 1024 * 1024))),
+ conn_events,
+ );
+
+ // test receiving a contig frame and reading it works
+ s.inbound_stream_frame(false, 0, &[1; 10]).unwrap();
+ assert!(s.data_ready());
+ check_stats(&s, 10, 0);
+
+ let mut buf = vec![0u8; 100];
+ assert_eq!(s.read(&mut buf).unwrap(), (10, false));
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 0);
+
+ check_stats(&s, 10, 10);
+
+ // test receiving a noncontig frame
+ s.inbound_stream_frame(false, 12, &[2; 12]).unwrap();
+ assert!(!s.data_ready());
+ assert_eq!(s.read(&mut buf).unwrap(), (0, false));
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12);
+
+ check_stats(&s, 22, 10);
+
+ // another frame that overlaps the first
+ s.inbound_stream_frame(false, 14, &[3; 8]).unwrap();
+ assert!(!s.data_ready());
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12);
+
+ check_stats(&s, 22, 10);
+
+ // fill in the gap, but with a FIN
+ s.inbound_stream_frame(true, 10, &[4; 6]).unwrap_err();
+ assert!(!s.data_ready());
+ assert_eq!(s.read(&mut buf).unwrap(), (0, false));
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12);
+
+ check_stats(&s, 22, 10);
+
+ // fill in the gap
+ s.inbound_stream_frame(false, 10, &[5; 10]).unwrap();
+ assert!(s.data_ready());
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 14);
+
+ check_stats(&s, 24, 10);
+
+ // a legit FIN
+ s.inbound_stream_frame(true, 24, &[6; 18]).unwrap();
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 32);
+ assert!(s.data_ready());
+ assert_eq!(s.read(&mut buf).unwrap(), (32, true));
+
+ check_stats(&s, 42, 42);
+
+ // Stream now no longer readable (is in DataRead state)
+ s.read(&mut buf).unwrap_err();
+ }
+
+ fn check_chunks(s: &mut RxStreamOrderer, expected: &[(u64, usize)]) {
+ assert_eq!(s.data_ranges.len(), expected.len());
+ for ((start, buf), (expected_start, expected_len)) in s.data_ranges.iter().zip(expected) {
+ assert_eq!((*start, buf.len()), (*expected_start, *expected_len));
+ }
+ }
+
+ // Test deduplication when the new data is at the end.
+ #[test]
+ fn stream_rx_dedupe_tail() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(0, &[1; 6]);
+ check_chunks(&mut s, &[(0, 6)]);
+
+ // New data that overlaps entirely (starting from the head), is ignored.
+ s.inbound_frame(0, &[2; 3]);
+ check_chunks(&mut s, &[(0, 6)]);
+
+ // New data that overlaps at the tail has any new data appended.
+ s.inbound_frame(2, &[3; 6]);
+ check_chunks(&mut s, &[(0, 8)]);
+
+ // New data that overlaps entirely (up to the tail), is ignored.
+ s.inbound_frame(4, &[4; 4]);
+ check_chunks(&mut s, &[(0, 8)]);
+
+ // New data that overlaps, starting from the beginning is appended too.
+ s.inbound_frame(0, &[5; 10]);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ // New data that is entirely subsumed is ignored.
+ s.inbound_frame(2, &[6; 2]);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 10);
+ assert_eq!(buf[..10], [1, 1, 1, 1, 1, 1, 3, 3, 5, 5]);
+ }
+
+ /// When chunks are added before existing data, they aren't merged.
+ #[test]
+ fn stream_rx_dedupe_head() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(1, &[6; 6]);
+ check_chunks(&mut s, &[(1, 6)]);
+
+ // Insertion before an existing chunk causes truncation of the new chunk.
+ s.inbound_frame(0, &[7; 6]);
+ check_chunks(&mut s, &[(0, 1), (1, 6)]);
+
+ // Perfect overlap with existing slices has no effect.
+ s.inbound_frame(0, &[8; 7]);
+ check_chunks(&mut s, &[(0, 1), (1, 6)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 7);
+ assert_eq!(buf[..7], [7, 6, 6, 6, 6, 6, 6]);
+ }
+
+ #[test]
+ fn stream_rx_dedupe_new_tail() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(1, &[6; 6]);
+ check_chunks(&mut s, &[(1, 6)]);
+
+ // Insertion before an existing chunk causes truncation of the new chunk.
+ s.inbound_frame(0, &[7; 6]);
+ check_chunks(&mut s, &[(0, 1), (1, 6)]);
+
+ // New data at the end causes the tail to be added to the first chunk,
+ // replacing later chunks entirely.
+ s.inbound_frame(0, &[9; 8]);
+ check_chunks(&mut s, &[(0, 8)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 8);
+ assert_eq!(buf[..8], [7, 9, 9, 9, 9, 9, 9, 9]);
+ }
+
+ #[test]
+ fn stream_rx_dedupe_replace() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(2, &[6; 6]);
+ check_chunks(&mut s, &[(2, 6)]);
+
+ // Insertion before an existing chunk causes truncation of the new chunk.
+ s.inbound_frame(1, &[7; 6]);
+ check_chunks(&mut s, &[(1, 1), (2, 6)]);
+
+ // New data at the start and end replaces all the slices.
+ s.inbound_frame(0, &[9; 10]);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 10);
+ assert_eq!(buf[..10], [9; 10]);
+ }
+
+ #[test]
+ fn trim_retired() {
+ let mut s = RxStreamOrderer::new();
+
+ let mut buf = [0; 18];
+ s.inbound_frame(0, &[1; 10]);
+
+ // Partially read slices are retained.
+ assert_eq!(s.read(&mut buf[..6]), 6);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ // Partially read slices are kept and so are added to.
+ s.inbound_frame(3, &buf[..10]);
+ check_chunks(&mut s, &[(0, 13)]);
+
+ // Wholly read pieces are dropped.
+ assert_eq!(s.read(&mut buf[..]), 7);
+ assert!(s.data_ranges.is_empty());
+
+ // New data that overlaps with retired data is trimmed.
+ s.inbound_frame(0, &buf[..]);
+ check_chunks(&mut s, &[(13, 5)]);
+ }
+
+ #[test]
+ fn stream_flowc_update() {
+ let mut s = create_stream(1024 * RX_STREAM_DATA_WINDOW);
+ let mut buf = vec![0u8; RECV_BUFFER_SIZE + 100]; // Make it overlarge
+
+ assert!(!s.has_frames_to_write());
+ s.inbound_stream_frame(false, 0, &[0; RECV_BUFFER_SIZE])
+ .unwrap();
+ assert!(!s.has_frames_to_write());
+ assert_eq!(s.read(&mut buf).unwrap(), (RECV_BUFFER_SIZE, false));
+ assert!(!s.data_ready());
+
+ // flow msg generated!
+ assert!(s.has_frames_to_write());
+
+ // consume it
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut token = Vec::new();
+ s.write_frame(&mut builder, &mut token, &mut FrameStats::default());
+
+ // it should be gone
+ assert!(!s.has_frames_to_write());
+ }
+
+ fn create_stream(session_fc: u64) -> RecvStream {
+ let conn_events = ConnectionEvents::default();
+ RecvStream::new(
+ StreamId::from(67),
+ RX_STREAM_DATA_WINDOW,
+ Rc::new(RefCell::new(ReceiverFlowControl::new((), session_fc))),
+ conn_events,
+ )
+ }
+
+ #[test]
+ fn stream_max_stream_data() {
+ let mut s = create_stream(1024 * RX_STREAM_DATA_WINDOW);
+ assert!(!s.has_frames_to_write());
+ s.inbound_stream_frame(false, 0, &[0; RECV_BUFFER_SIZE])
+ .unwrap();
+ s.inbound_stream_frame(false, RX_STREAM_DATA_WINDOW, &[1; 1])
+ .unwrap_err();
+ }
+
+ #[test]
+ fn stream_orderer_bytes_ready() {
+ let mut rx_ord = RxStreamOrderer::new();
+
+ rx_ord.inbound_frame(0, &[1; 6]);
+ assert_eq!(rx_ord.bytes_ready(), 6);
+ assert_eq!(rx_ord.buffered(), 6);
+ assert_eq!(rx_ord.retired(), 0);
+
+ // read some so there's an offset into the first frame
+ let mut buf = [0u8; 10];
+ rx_ord.read(&mut buf[..2]);
+ assert_eq!(rx_ord.bytes_ready(), 4);
+ assert_eq!(rx_ord.buffered(), 4);
+ assert_eq!(rx_ord.retired(), 2);
+
+ // an overlapping frame
+ rx_ord.inbound_frame(5, &[2; 6]);
+ assert_eq!(rx_ord.bytes_ready(), 9);
+ assert_eq!(rx_ord.buffered(), 9);
+ assert_eq!(rx_ord.retired(), 2);
+
+ // a noncontig frame
+ rx_ord.inbound_frame(20, &[3; 6]);
+ assert_eq!(rx_ord.bytes_ready(), 9);
+ assert_eq!(rx_ord.buffered(), 15);
+ assert_eq!(rx_ord.retired(), 2);
+
+ // an old frame
+ rx_ord.inbound_frame(0, &[4; 2]);
+ assert_eq!(rx_ord.bytes_ready(), 9);
+ assert_eq!(rx_ord.buffered(), 15);
+ assert_eq!(rx_ord.retired(), 2);
+ }
+
+ #[test]
+ fn no_stream_flowc_event_after_exiting_recv() {
+ let mut s = create_stream(1024 * RX_STREAM_DATA_WINDOW);
+ s.inbound_stream_frame(false, 0, &[0; RECV_BUFFER_SIZE])
+ .unwrap();
+ let mut buf = [0; RECV_BUFFER_SIZE];
+ s.read(&mut buf).unwrap();
+ assert!(s.has_frames_to_write());
+ s.inbound_stream_frame(true, RX_STREAM_DATA_WINDOW, &[])
+ .unwrap();
+ assert!(!s.has_frames_to_write());
+ }
+
+ fn create_stream_with_fc(
+ session_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ fc_limit: u64,
+ ) -> RecvStream {
+ RecvStream::new(
+ StreamId::from(567),
+ fc_limit,
+ session_fc,
+ ConnectionEvents::default(),
+ )
+ }
+
+ fn create_stream_session_flow_control() -> (RecvStream, Rc<RefCell<ReceiverFlowControl<()>>>) {
+ assert!(RX_STREAM_DATA_WINDOW > u64::try_from(SESSION_WINDOW).unwrap());
+ let session_fc = Rc::new(RefCell::new(ReceiverFlowControl::new(
+ (),
+ u64::try_from(SESSION_WINDOW).unwrap(),
+ )));
+ (
+ create_stream_with_fc(Rc::clone(&session_fc), RX_STREAM_DATA_WINDOW),
+ session_fc,
+ )
+ }
+
+ #[test]
+ fn session_flow_control() {
+ let (mut s, session_fc) = create_stream_session_flow_control();
+
+ s.inbound_stream_frame(false, 0, &[0; SESSION_WINDOW])
+ .unwrap();
+ assert!(!session_fc.borrow().frame_needed());
+ // The buffer is big enough to hold SESSION_WINDOW, this will make sure that we always
+ // read everything from he stream.
+ let mut buf = [0; 2 * SESSION_WINDOW];
+ s.read(&mut buf).unwrap();
+ assert!(session_fc.borrow().frame_needed());
+ // consume it
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut token = Vec::new();
+ session_fc
+ .borrow_mut()
+ .write_frames(&mut builder, &mut token, &mut FrameStats::default());
+
+ // Switch to SizeKnown state
+ s.inbound_stream_frame(true, 2 * u64::try_from(SESSION_WINDOW).unwrap() - 1, &[0])
+ .unwrap();
+ assert!(!session_fc.borrow().frame_needed());
+ // Receive new data that can be read.
+ s.inbound_stream_frame(
+ false,
+ u64::try_from(SESSION_WINDOW).unwrap(),
+ &[0; SESSION_WINDOW / 2 + 1],
+ )
+ .unwrap();
+ assert!(!session_fc.borrow().frame_needed());
+ s.read(&mut buf).unwrap();
+ assert!(session_fc.borrow().frame_needed());
+ // consume it
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut token = Vec::new();
+ session_fc
+ .borrow_mut()
+ .write_frames(&mut builder, &mut token, &mut FrameStats::default());
+
+ // Test DataRecvd state
+ let session_fc = Rc::new(RefCell::new(ReceiverFlowControl::new(
+ (),
+ u64::try_from(SESSION_WINDOW).unwrap(),
+ )));
+ let mut s = RecvStream::new(
+ StreamId::from(567),
+ RX_STREAM_DATA_WINDOW,
+ Rc::clone(&session_fc),
+ ConnectionEvents::default(),
+ );
+
+ s.inbound_stream_frame(true, 0, &[0; SESSION_WINDOW])
+ .unwrap();
+ assert!(!session_fc.borrow().frame_needed());
+ s.read(&mut buf).unwrap();
+ assert!(session_fc.borrow().frame_needed());
+ }
+
+ #[test]
+ fn session_flow_control_reset() {
+ let (mut s, session_fc) = create_stream_session_flow_control();
+
+ s.inbound_stream_frame(false, 0, &[0; SESSION_WINDOW / 2])
+ .unwrap();
+ assert!(!session_fc.borrow().frame_needed());
+
+ s.reset(
+ Error::NoError.code(),
+ u64::try_from(SESSION_WINDOW).unwrap(),
+ )
+ .unwrap();
+ assert!(session_fc.borrow().frame_needed());
+ }
+
+ fn check_fc<T: std::fmt::Debug>(fc: &ReceiverFlowControl<T>, consumed: u64, retired: u64) {
+ assert_eq!(fc.consumed(), consumed);
+ assert_eq!(fc.retired(), retired);
+ }
+
+ /// Test consuming the flow control in RecvStreamState::Recv
+ #[test]
+ fn fc_state_recv_1() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+
+ check_fc(&fc.borrow(), SW / 4, 0);
+ check_fc(s.fc().unwrap(), SW / 4, 0);
+ }
+
+ /// Test consuming the flow control in RecvStreamState::Recv
+ /// with multiple streams
+ #[test]
+ fn fc_state_recv_2() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s1 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+ let mut s2 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s1.fc().unwrap(), 0, 0);
+ check_fc(s2.fc().unwrap(), 0, 0);
+
+ s1.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+
+ check_fc(&fc.borrow(), SW / 4, 0);
+ check_fc(s1.fc().unwrap(), SW / 4, 0);
+ check_fc(s2.fc().unwrap(), 0, 0);
+
+ s2.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s1.fc().unwrap(), SW / 4, 0);
+ check_fc(s2.fc().unwrap(), SW / 4, 0);
+ }
+
+ /// Test retiring the flow control in RecvStreamState::Recv
+ /// with multiple streams
+ #[test]
+ fn fc_state_recv_3() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s1 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+ let mut s2 = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s1.fc().unwrap(), 0, 0);
+ check_fc(s2.fc().unwrap(), 0, 0);
+
+ s1.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+ s2.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s1.fc().unwrap(), SW / 4, 0);
+ check_fc(s2.fc().unwrap(), SW / 4, 0);
+
+ // Read data
+ let mut buf = [1; SW_US];
+ assert_eq!(s1.read(&mut buf).unwrap(), (SW_US / 4, false));
+ check_fc(&fc.borrow(), SW / 2, SW / 4);
+ check_fc(s1.fc().unwrap(), SW / 4, SW / 4);
+ check_fc(s2.fc().unwrap(), SW / 4, 0);
+
+ assert_eq!(s2.read(&mut buf).unwrap(), (SW_US / 4, false));
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s1.fc().unwrap(), SW / 4, SW / 4);
+ check_fc(s2.fc().unwrap(), SW / 4, SW / 4);
+
+ // Read when there is no more date to be read will not change fc.
+ assert_eq!(s1.read(&mut buf).unwrap(), (0, false));
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s1.fc().unwrap(), SW / 4, SW / 4);
+ check_fc(s2.fc().unwrap(), SW / 4, SW / 4);
+
+ // Receiving more data on a stream.
+ s1.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW * 3 / 4, SW / 2);
+ check_fc(s1.fc().unwrap(), SW / 2, SW / 4);
+ check_fc(s2.fc().unwrap(), SW / 4, SW / 4);
+
+ // Read data
+ assert_eq!(s1.read(&mut buf).unwrap(), (SW_US / 4, false));
+ check_fc(&fc.borrow(), SW * 3 / 4, SW * 3 / 4);
+ check_fc(s1.fc().unwrap(), SW / 2, SW / 2);
+ check_fc(s2.fc().unwrap(), SW / 4, SW / 4);
+ }
+
+ /// Test consuming the flow control in RecvStreamState::Recv - duplicate data
+ #[test]
+ fn fc_state_recv_4() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+
+ check_fc(&fc.borrow(), SW / 4, 0);
+ check_fc(s.fc().unwrap(), SW / 4, 0);
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 8]).unwrap();
+ check_fc(&fc.borrow(), SW / 4, 0);
+ check_fc(s.fc().unwrap(), SW / 4, 0);
+ }
+
+ /// Test consuming the flow control in RecvStreamState::Recv - filling a gap in the
+ /// data stream.
+ #[test]
+ fn fc_state_recv_5() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+
+ // Receive out of order data.
+ s.inbound_stream_frame(false, SW / 8, &[0; SW_US / 8])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 4, 0);
+ check_fc(s.fc().unwrap(), SW / 4, 0);
+
+ // Filling in the gap will not change fc.
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 8]).unwrap();
+ check_fc(&fc.borrow(), SW / 4, 0);
+ check_fc(s.fc().unwrap(), SW / 4, 0);
+ }
+
+ /// Test consuming the flow control in RecvStreamState::Recv - receiving frame past
+ /// the flow control will cause an error.
+ #[test]
+ fn fc_state_recv_6() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+
+ // Receiving frame past the flow control will cause an error.
+ assert_eq!(
+ s.inbound_stream_frame(false, 0, &[0; SW_US * 3 / 4 + 1]),
+ Err(Error::FlowControlError)
+ );
+ }
+
+ /// Test that the flow controls will send updates.
+ #[test]
+ fn fc_state_recv_7() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW / 2);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 4]).unwrap();
+ let mut buf = [1; SW_US];
+ assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 4, false));
+ check_fc(&fc.borrow(), SW / 4, SW / 4);
+ check_fc(s.fc().unwrap(), SW / 4, SW / 4);
+
+ // Still no fc update needed.
+ assert!(!fc.borrow().frame_needed());
+ assert!(!s.fc().unwrap().frame_needed());
+
+ // Receive one more byte that will cause a fc update after it is read.
+ s.inbound_stream_frame(false, SW / 4, &[0]).unwrap();
+ check_fc(&fc.borrow(), SW / 4 + 1, SW / 4);
+ check_fc(s.fc().unwrap(), SW / 4 + 1, SW / 4);
+ // Only consuming data does not cause a fc update to be sent.
+ assert!(!fc.borrow().frame_needed());
+ assert!(!s.fc().unwrap().frame_needed());
+
+ assert_eq!(s.read(&mut buf).unwrap(), (1, false));
+ check_fc(&fc.borrow(), SW / 4 + 1, SW / 4 + 1);
+ check_fc(s.fc().unwrap(), SW / 4 + 1, SW / 4 + 1);
+ // Data are retired and the sttream fc will send an update.
+ assert!(!fc.borrow().frame_needed());
+ assert!(s.fc().unwrap().frame_needed());
+
+ // Receive more data to increase fc further.
+ s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 4 - 1, false));
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+ assert!(!fc.borrow().frame_needed());
+ assert!(s.fc().unwrap().frame_needed());
+
+ // Write the fc update frame
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut token = Vec::new();
+ let mut stats = FrameStats::default();
+ fc.borrow_mut()
+ .write_frames(&mut builder, &mut token, &mut stats);
+ assert_eq!(stats.max_data, 0);
+ s.write_frame(&mut builder, &mut token, &mut stats);
+ assert_eq!(stats.max_stream_data, 1);
+
+ // Receive 1 byte that will case a session fc update after it is read.
+ s.inbound_stream_frame(false, SW / 2, &[0]).unwrap();
+ assert_eq!(s.read(&mut buf).unwrap(), (1, false));
+ check_fc(&fc.borrow(), SW / 2 + 1, SW / 2 + 1);
+ check_fc(s.fc().unwrap(), SW / 2 + 1, SW / 2 + 1);
+ assert!(fc.borrow().frame_needed());
+ assert!(!s.fc().unwrap().frame_needed());
+ fc.borrow_mut()
+ .write_frames(&mut builder, &mut token, &mut stats);
+ assert_eq!(stats.max_data, 1);
+ s.write_frame(&mut builder, &mut token, &mut stats);
+ assert_eq!(stats.max_stream_data, 1);
+ }
+
+ /// Test flow control in RecvStreamState::SizeKnown
+ #[test]
+ fn fc_state_size_known() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Receiving frame past the final size of a stream will return an error.
+ assert_eq!(
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4 + 1]),
+ Err(Error::FinalSizeError)
+ );
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Add new data to the gap will not change fc.
+ s.inbound_stream_frame(false, SW / 8, &[0; SW_US / 8])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Fill the gap
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 8]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Read all data
+ let mut buf = [1; SW_US];
+ assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 2, true));
+ // the stream does not have fc any more. We can only check the session fc.
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ assert!(s.fc().is_none());
+ }
+
+ /// Test flow control in RecvStreamState::DataRecvd
+ #[test]
+ fn fc_state_data_recv() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW);
+
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Receiving frame past the final size of a stream will return an error.
+ assert_eq!(
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4 + 1]),
+ Err(Error::FinalSizeError)
+ );
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ // Read all data
+ let mut buf = [1; SW_US];
+ assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 2, true));
+ // the stream does not have fc any more. We can only check the session fc.
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ assert!(s.fc().is_none());
+ }
+
+ /// Test flow control in RecvStreamState::DataRead
+ #[test]
+ fn fc_state_data_read() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ let mut buf = [1; SW_US];
+ assert_eq!(s.read(&mut buf).unwrap(), (SW_US / 2, true));
+ // the stream does not have fc any more. We can only check the session fc.
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ assert!(s.fc().is_none());
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap();
+ // the stream does not have fc any more. We can only check the session fc.
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ assert!(s.fc().is_none());
+
+ // Receiving frame past the final size of a stream or the stream's fc limit
+ // will NOT return an error.
+ s.inbound_stream_frame(true, 0, &[0; SW_US / 2 + 1])
+ .unwrap();
+ s.inbound_stream_frame(true, 0, &[0; SW_US * 3 / 4 + 1])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ assert!(s.fc().is_none());
+ }
+
+ /// Test flow control in RecvStreamState::AbortReading and final size is known
+ #[test]
+ fn fc_state_abort_reading_1() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ s.stop_sending(Error::NoError.code());
+ // All data will de retired
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(true, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving frame past the final size of a stream will return an error.
+ assert_eq!(
+ s.inbound_stream_frame(true, SW / 4, &[0; SW_US / 4 + 1]),
+ Err(Error::FinalSizeError)
+ );
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+ }
+
+ /// Test flow control in RecvStreamState::AbortReading and final size is unknown
+ #[test]
+ fn fc_state_abort_reading_2() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ s.stop_sending(Error::NoError.code());
+ // All data will de retired
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving data past the flow control limit will cause an error.
+ assert_eq!(
+ s.inbound_stream_frame(false, 0, &[0; SW_US * 3 / 4 + 1]),
+ Err(Error::FlowControlError)
+ );
+
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving more data will case the data to be retired.
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 2, &[0; 10]).unwrap();
+ check_fc(&fc.borrow(), SW / 2 + 10, SW / 2 + 10);
+ check_fc(s.fc().unwrap(), SW / 2 + 10, SW / 2 + 10);
+
+ // We can still receive the final size.
+ s.inbound_stream_frame(true, SW / 2, &[0; 20]).unwrap();
+ check_fc(&fc.borrow(), SW / 2 + 20, SW / 2 + 20);
+ check_fc(s.fc().unwrap(), SW / 2 + 20, SW / 2 + 20);
+
+ // Receiving frame past the final size of a stream will return an error.
+ assert_eq!(
+ s.inbound_stream_frame(true, SW / 2, &[0; 21]),
+ Err(Error::FinalSizeError)
+ );
+ check_fc(&fc.borrow(), SW / 2 + 20, SW / 2 + 20);
+ check_fc(s.fc().unwrap(), SW / 2 + 20, SW / 2 + 20);
+ }
+
+ /// Test flow control in RecvStreamState::WaitForReset
+ #[test]
+ fn fc_state_wait_for_reset() {
+ const SW: u64 = 1024;
+ const SW_US: usize = 1024;
+ let fc = Rc::new(RefCell::new(ReceiverFlowControl::new((), SW)));
+
+ let mut s = create_stream_with_fc(Rc::clone(&fc), SW * 3 / 4);
+ check_fc(&fc.borrow(), 0, 0);
+ check_fc(s.fc().unwrap(), 0, 0);
+
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, 0);
+ check_fc(s.fc().unwrap(), SW / 2, 0);
+
+ s.stop_sending(Error::NoError.code());
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ s.stop_sending_acked();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving duplicate frames (already consumed data) will not cause an error or
+ // change fc.
+ s.inbound_stream_frame(false, 0, &[0; SW_US / 2]).unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving data past the flow control limit will cause an error.
+ assert_eq!(
+ s.inbound_stream_frame(false, 0, &[0; SW_US * 3 / 4 + 1]),
+ Err(Error::FlowControlError)
+ );
+
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 4, &[0; SW_US / 4])
+ .unwrap();
+ check_fc(&fc.borrow(), SW / 2, SW / 2);
+ check_fc(s.fc().unwrap(), SW / 2, SW / 2);
+
+ // Receiving more data will case the data to be retired.
+ // The stream can still receive duplicate data without a fin bit.
+ s.inbound_stream_frame(false, SW / 2, &[0; 10]).unwrap();
+ check_fc(&fc.borrow(), SW / 2 + 10, SW / 2 + 10);
+ check_fc(s.fc().unwrap(), SW / 2 + 10, SW / 2 + 10);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/rtt.rs b/third_party/rust/neqo-transport/src/rtt.rs
new file mode 100644
index 0000000000..4b05198bc9
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/rtt.rs
@@ -0,0 +1,211 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of sent packets and detecting their loss.
+
+#![deny(clippy::pedantic)]
+
+use std::{
+ cmp::{max, min},
+ time::{Duration, Instant},
+};
+
+use neqo_common::{qlog::NeqoQlog, qtrace};
+
+use crate::{
+ ackrate::{AckRate, PeerAckDelay},
+ packet::PacketBuilder,
+ qlog::{self, QlogMetric},
+ recovery::RecoveryToken,
+ stats::FrameStats,
+ tracking::PacketNumberSpace,
+};
+
+/// The smallest time that the system timer (via `sleep()`, `nanosleep()`,
+/// `select()`, or similar) can reliably deliver; see `neqo_common::hrtime`.
+pub const GRANULARITY: Duration = Duration::from_millis(1);
+// Defined in -recovery 6.2 as 333ms but using lower value.
+pub(crate) const INITIAL_RTT: Duration = Duration::from_millis(100);
+
+#[derive(Debug)]
+#[allow(clippy::module_name_repetitions)]
+pub struct RttEstimate {
+ first_sample_time: Option<Instant>,
+ latest_rtt: Duration,
+ smoothed_rtt: Duration,
+ rttvar: Duration,
+ min_rtt: Duration,
+ ack_delay: PeerAckDelay,
+}
+
+impl RttEstimate {
+ fn init(&mut self, rtt: Duration) {
+ // Only allow this when there are no samples.
+ debug_assert!(self.first_sample_time.is_none());
+ self.latest_rtt = rtt;
+ self.min_rtt = rtt;
+ self.smoothed_rtt = rtt;
+ self.rttvar = rtt / 2;
+ }
+
+ #[cfg(test)]
+ pub const fn from_duration(rtt: Duration) -> Self {
+ Self {
+ first_sample_time: None,
+ latest_rtt: rtt,
+ smoothed_rtt: rtt,
+ rttvar: Duration::from_millis(0),
+ min_rtt: rtt,
+ ack_delay: PeerAckDelay::Fixed(Duration::from_millis(25)),
+ }
+ }
+
+ pub fn set_initial(&mut self, rtt: Duration) {
+ qtrace!("initial RTT={:?}", rtt);
+ if rtt >= GRANULARITY {
+ // Ignore if the value is too small.
+ self.init(rtt);
+ }
+ }
+
+ /// For a new path, prime the RTT based on the state of another path.
+ pub fn prime_rtt(&mut self, other: &Self) {
+ self.set_initial(other.smoothed_rtt + other.rttvar);
+ self.ack_delay = other.ack_delay.clone();
+ }
+
+ pub fn set_ack_delay(&mut self, ack_delay: PeerAckDelay) {
+ self.ack_delay = ack_delay;
+ }
+
+ pub fn update_ack_delay(&mut self, cwnd: usize, mtu: usize) {
+ self.ack_delay.update(cwnd, mtu, self.smoothed_rtt);
+ }
+
+ pub fn update(
+ &mut self,
+ qlog: &mut NeqoQlog,
+ mut rtt_sample: Duration,
+ ack_delay: Duration,
+ confirmed: bool,
+ now: Instant,
+ ) {
+ // Limit ack delay by max_ack_delay if confirmed.
+ let mad = self.ack_delay.max();
+ let ack_delay = if confirmed && ack_delay > mad {
+ mad
+ } else {
+ ack_delay
+ };
+
+ // min_rtt ignores ack delay.
+ self.min_rtt = min(self.min_rtt, rtt_sample);
+ // Adjust for ack delay unless it goes below `min_rtt`.
+ if rtt_sample - self.min_rtt >= ack_delay {
+ rtt_sample -= ack_delay;
+ }
+
+ if self.first_sample_time.is_none() {
+ self.init(rtt_sample);
+ self.first_sample_time = Some(now);
+ } else {
+ // Calculate EWMA RTT (based on {{?RFC6298}}).
+ let rttvar_sample = if self.smoothed_rtt > rtt_sample {
+ self.smoothed_rtt - rtt_sample
+ } else {
+ rtt_sample - self.smoothed_rtt
+ };
+
+ self.latest_rtt = rtt_sample;
+ self.rttvar = (self.rttvar * 3 + rttvar_sample) / 4;
+ self.smoothed_rtt = (self.smoothed_rtt * 7 + rtt_sample) / 8;
+ }
+ qtrace!(
+ "RTT latest={:?} -> estimate={:?}~{:?}",
+ self.latest_rtt,
+ self.smoothed_rtt,
+ self.rttvar
+ );
+ qlog::metrics_updated(
+ qlog,
+ &[
+ QlogMetric::LatestRtt(self.latest_rtt),
+ QlogMetric::MinRtt(self.min_rtt),
+ QlogMetric::SmoothedRtt(self.smoothed_rtt),
+ ],
+ );
+ }
+
+ /// Get the estimated value.
+ pub fn estimate(&self) -> Duration {
+ self.smoothed_rtt
+ }
+
+ pub fn pto(&self, pn_space: PacketNumberSpace) -> Duration {
+ let mut t = self.estimate() + max(4 * self.rttvar, GRANULARITY);
+ if pn_space == PacketNumberSpace::ApplicationData {
+ t += self.ack_delay.max();
+ }
+ t
+ }
+
+ /// Calculate the loss delay based on the current estimate and the last
+ /// RTT measurement received.
+ pub fn loss_delay(&self) -> Duration {
+ // kTimeThreshold = 9/8
+ // loss_delay = kTimeThreshold * max(latest_rtt, smoothed_rtt)
+ // loss_delay = max(loss_delay, kGranularity)
+ let rtt = max(self.latest_rtt, self.smoothed_rtt);
+ max(rtt * 9 / 8, GRANULARITY)
+ }
+
+ pub fn first_sample_time(&self) -> Option<Instant> {
+ self.first_sample_time
+ }
+
+ #[cfg(test)]
+ pub fn latest(&self) -> Duration {
+ self.latest_rtt
+ }
+
+ pub fn rttvar(&self) -> Duration {
+ self.rttvar
+ }
+
+ pub fn minimum(&self) -> Duration {
+ self.min_rtt
+ }
+
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ self.ack_delay.write_frames(builder, tokens, stats);
+ }
+
+ pub fn frame_lost(&mut self, lost: &AckRate) {
+ self.ack_delay.frame_lost(lost);
+ }
+
+ pub fn frame_acked(&mut self, acked: &AckRate) {
+ self.ack_delay.frame_acked(acked);
+ }
+}
+
+impl Default for RttEstimate {
+ fn default() -> Self {
+ Self {
+ first_sample_time: None,
+ latest_rtt: INITIAL_RTT,
+ smoothed_rtt: INITIAL_RTT,
+ rttvar: INITIAL_RTT / 2,
+ min_rtt: INITIAL_RTT,
+ ack_delay: PeerAckDelay::default(),
+ }
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs
new file mode 100644
index 0000000000..5feb785ac6
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/send_stream.rs
@@ -0,0 +1,2636 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Buffering data to send until it is acked.
+
+use std::{
+ cell::RefCell,
+ cmp::{max, min, Ordering},
+ collections::{BTreeMap, VecDeque},
+ convert::TryFrom,
+ hash::{Hash, Hasher},
+ mem,
+ ops::Add,
+ rc::Rc,
+};
+
+use indexmap::IndexMap;
+use neqo_common::{qdebug, qerror, qinfo, qtrace, Encoder, Role};
+use smallvec::SmallVec;
+
+use crate::{
+ events::ConnectionEvents,
+ fc::SenderFlowControl,
+ frame::{Frame, FRAME_TYPE_RESET_STREAM},
+ packet::PacketBuilder,
+ recovery::{RecoveryToken, StreamRecoveryToken},
+ stats::FrameStats,
+ stream_id::StreamId,
+ streams::SendOrder,
+ tparams::{self, TransportParameters},
+ AppError, Error, Res,
+};
+
+pub const SEND_BUFFER_SIZE: usize = 0x10_0000; // 1 MiB
+
+/// The priority that is assigned to sending data for the stream.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum TransmissionPriority {
+ /// This stream is more important than the functioning of the connection.
+ /// Don't use this priority unless the stream really is that important.
+ /// A stream at this priority can starve out other connection functions,
+ /// including flow control, which could be very bad.
+ Critical,
+ /// The stream is very important. Stream data will be written ahead of
+ /// some of the less critical connection functions, like path validation,
+ /// connection ID management, and session tickets.
+ Important,
+ /// High priority streams are important, but not enough to disrupt
+ /// connection operation. They go ahead of session tickets though.
+ High,
+ /// The default priority.
+ Normal,
+ /// Low priority streams get sent last.
+ Low,
+}
+
+impl Default for TransmissionPriority {
+ fn default() -> Self {
+ Self::Normal
+ }
+}
+
+impl PartialOrd for TransmissionPriority {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl Ord for TransmissionPriority {
+ fn cmp(&self, other: &Self) -> Ordering {
+ if self == other {
+ return Ordering::Equal;
+ }
+ match (self, other) {
+ (Self::Critical, _) => Ordering::Greater,
+ (_, Self::Critical) => Ordering::Less,
+ (Self::Important, _) => Ordering::Greater,
+ (_, Self::Important) => Ordering::Less,
+ (Self::High, _) => Ordering::Greater,
+ (_, Self::High) => Ordering::Less,
+ (Self::Normal, _) => Ordering::Greater,
+ (_, Self::Normal) => Ordering::Less,
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl Add<RetransmissionPriority> for TransmissionPriority {
+ type Output = Self;
+ fn add(self, rhs: RetransmissionPriority) -> Self::Output {
+ match rhs {
+ RetransmissionPriority::Fixed(fixed) => fixed,
+ RetransmissionPriority::Same => self,
+ RetransmissionPriority::Higher => match self {
+ Self::Critical => Self::Critical,
+ Self::Important | Self::High => Self::Important,
+ Self::Normal => Self::High,
+ Self::Low => Self::Normal,
+ },
+ RetransmissionPriority::MuchHigher => match self {
+ Self::Critical | Self::Important => Self::Critical,
+ Self::High | Self::Normal => Self::Important,
+ Self::Low => Self::High,
+ },
+ }
+ }
+}
+
+/// If data is lost, this determines the priority that applies to retransmissions
+/// of that data.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RetransmissionPriority {
+ /// Prioritize retransmission at a fixed priority.
+ /// With this, it is possible to prioritize retransmissions lower than transmissions.
+ /// Doing that can create a deadlock with flow control which might cause the connection
+ /// to stall unless new data stops arriving fast enough that retransmissions can complete.
+ Fixed(TransmissionPriority),
+ /// Don't increase priority for retransmission. This is probably not a good idea
+ /// as it could mean starving flow control.
+ Same,
+ /// Increase the priority of retransmissions (the default).
+ /// Retransmissions of `Critical` or `Important` aren't elevated at all.
+ Higher,
+ /// Increase the priority of retransmissions a lot.
+ /// This is useful for streams that are particularly exposed to head-of-line blocking.
+ MuchHigher,
+}
+
+impl Default for RetransmissionPriority {
+ fn default() -> Self {
+ Self::Higher
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Copy)]
+enum RangeState {
+ Sent,
+ Acked,
+}
+
+/// Track ranges in the stream as sent or acked. Acked implies sent. Not in a
+/// range implies needing-to-be-sent, either initially or as a retransmission.
+#[derive(Debug, Default, PartialEq)]
+struct RangeTracker {
+ // offset, (len, RangeState). Use u64 for len because ranges can exceed 32bits.
+ used: BTreeMap<u64, (u64, RangeState)>,
+}
+
+impl RangeTracker {
+ fn highest_offset(&self) -> u64 {
+ self.used
+ .range(..)
+ .next_back()
+ .map_or(0, |(k, (v, _))| *k + *v)
+ }
+
+ fn acked_from_zero(&self) -> u64 {
+ self.used
+ .get(&0)
+ .filter(|(_, state)| *state == RangeState::Acked)
+ .map_or(0, |(v, _)| *v)
+ }
+
+ /// Find the first unmarked range. If all are contiguous, this will return
+ /// (highest_offset(), None).
+ fn first_unmarked_range(&self) -> (u64, Option<u64>) {
+ let mut prev_end = 0;
+
+ for (cur_off, (cur_len, _)) in &self.used {
+ if prev_end == *cur_off {
+ prev_end = cur_off + cur_len;
+ } else {
+ return (prev_end, Some(cur_off - prev_end));
+ }
+ }
+ (prev_end, None)
+ }
+
+ /// Turn one range into a list of subranges that align with existing
+ /// ranges.
+ /// Check impermissible overlaps in subregions: Sent cannot overwrite Acked.
+ //
+ // e.g. given N is new and ABC are existing:
+ // NNNNNNNNNNNNNNNN
+ // AAAAA BBBCCCCC ...then we want 5 chunks:
+ // 1122222333444555
+ //
+ // but also if we have this:
+ // NNNNNNNNNNNNNNNN
+ // AAAAAAAAAA BBBB ...then break existing A and B ranges up:
+ //
+ // 1111111122222233
+ // aaAAAAAAAA BBbb
+ //
+ // Doing all this work up front should make handling each chunk much
+ // easier.
+ fn chunk_range_on_edges(
+ &mut self,
+ new_off: u64,
+ new_len: u64,
+ new_state: RangeState,
+ ) -> Vec<(u64, u64, RangeState)> {
+ let mut tmp_off = new_off;
+ let mut tmp_len = new_len;
+ let mut v = Vec::new();
+
+ // cut previous overlapping range if needed
+ let prev = self.used.range_mut(..tmp_off).next_back();
+ if let Some((prev_off, (prev_len, prev_state))) = prev {
+ let prev_state = *prev_state;
+ let overlap = (*prev_off + *prev_len).saturating_sub(new_off);
+ *prev_len -= overlap;
+ if overlap > 0 {
+ self.used.insert(new_off, (overlap, prev_state));
+ }
+ }
+
+ let mut last_existing_remaining = None;
+ for (off, (len, state)) in self.used.range(tmp_off..tmp_off + tmp_len) {
+ // Create chunk for "overhang" before an existing range
+ if tmp_off < *off {
+ let sub_len = off - tmp_off;
+ v.push((tmp_off, sub_len, new_state));
+ tmp_off += sub_len;
+ tmp_len -= sub_len;
+ }
+
+ // Create chunk to match existing range
+ let sub_len = min(*len, tmp_len);
+ let remaining_len = len - sub_len;
+ if new_state == RangeState::Sent && *state == RangeState::Acked {
+ qinfo!(
+ "Attempted to downgrade overlapping range Acked range {}-{} with Sent {}-{}",
+ off,
+ len,
+ new_off,
+ new_len
+ );
+ } else {
+ v.push((tmp_off, sub_len, new_state));
+ }
+ tmp_off += sub_len;
+ tmp_len -= sub_len;
+
+ if remaining_len > 0 {
+ last_existing_remaining = Some((*off, sub_len, remaining_len, *state));
+ }
+ }
+
+ // Maybe break last existing range in two so that a final chunk will
+ // have the same length as an existing range entry
+ if let Some((off, sub_len, remaining_len, state)) = last_existing_remaining {
+ *self.used.get_mut(&off).expect("must be there") = (sub_len, state);
+ self.used.insert(off + sub_len, (remaining_len, state));
+ }
+
+ // Create final chunk if anything remains of the new range
+ if tmp_len > 0 {
+ v.push((tmp_off, tmp_len, new_state));
+ }
+
+ v
+ }
+
+ /// Merge contiguous Acked ranges into the first entry (0). This range may
+ /// be dropped from the send buffer.
+ fn coalesce_acked_from_zero(&mut self) {
+ let acked_range_from_zero = self
+ .used
+ .get_mut(&0)
+ .filter(|(_, state)| *state == RangeState::Acked)
+ .map(|(len, _)| *len);
+
+ if let Some(len_from_zero) = acked_range_from_zero {
+ let mut new_len_from_zero = len_from_zero;
+
+ // See if there's another Acked range entry contiguous to this one
+ while let Some((next_len, _)) = self
+ .used
+ .get(&new_len_from_zero)
+ .filter(|(_, state)| *state == RangeState::Acked)
+ {
+ let to_remove = new_len_from_zero;
+ new_len_from_zero += *next_len;
+ self.used.remove(&to_remove);
+ }
+
+ if len_from_zero != new_len_from_zero {
+ self.used.get_mut(&0).expect("must be there").0 = new_len_from_zero;
+ }
+ }
+ }
+
+ fn mark_range(&mut self, off: u64, len: usize, state: RangeState) {
+ if len == 0 {
+ qinfo!("mark 0-length range at {}", off);
+ return;
+ }
+
+ let subranges = self.chunk_range_on_edges(off, len as u64, state);
+
+ for (sub_off, sub_len, sub_state) in subranges {
+ self.used.insert(sub_off, (sub_len, sub_state));
+ }
+
+ self.coalesce_acked_from_zero();
+ }
+
+ fn unmark_range(&mut self, off: u64, len: usize) {
+ if len == 0 {
+ qdebug!("unmark 0-length range at {}", off);
+ return;
+ }
+
+ let len = u64::try_from(len).unwrap();
+ let end_off = off + len;
+
+ let mut to_remove = SmallVec::<[_; 8]>::new();
+ let mut to_add = None;
+
+ // Walk backwards through possibly affected existing ranges
+ for (cur_off, (cur_len, cur_state)) in self.used.range_mut(..off + len).rev() {
+ // Maybe fixup range preceding the removed range
+ if *cur_off < off {
+ // Check for overlap
+ if *cur_off + *cur_len > off {
+ if *cur_state == RangeState::Acked {
+ qdebug!(
+ "Attempted to unmark Acked range {}-{} with unmark_range {}-{}",
+ cur_off,
+ cur_len,
+ off,
+ off + len
+ );
+ } else {
+ *cur_len = off - cur_off;
+ }
+ }
+ break;
+ }
+
+ if *cur_state == RangeState::Acked {
+ qdebug!(
+ "Attempted to unmark Acked range {}-{} with unmark_range {}-{}",
+ cur_off,
+ cur_len,
+ off,
+ off + len
+ );
+ continue;
+ }
+
+ // Add a new range for old subrange extending beyond
+ // to-be-unmarked range
+ let cur_end_off = cur_off + *cur_len;
+ if cur_end_off > end_off {
+ let new_cur_off = off + len;
+ let new_cur_len = cur_end_off - end_off;
+ assert_eq!(to_add, None);
+ to_add = Some((new_cur_off, new_cur_len, *cur_state));
+ }
+
+ to_remove.push(*cur_off);
+ }
+
+ for remove_off in to_remove {
+ self.used.remove(&remove_off);
+ }
+
+ if let Some((new_cur_off, new_cur_len, cur_state)) = to_add {
+ self.used.insert(new_cur_off, (new_cur_len, cur_state));
+ }
+ }
+
+ /// Unmark all sent ranges.
+ pub fn unmark_sent(&mut self) {
+ self.unmark_range(0, usize::try_from(self.highest_offset()).unwrap());
+ }
+}
+
+/// Buffer to contain queued bytes and track their state.
+#[derive(Debug, Default, PartialEq)]
+pub struct TxBuffer {
+ retired: u64, // contig acked bytes, no longer in buffer
+ send_buf: VecDeque<u8>, // buffer of not-acked bytes
+ ranges: RangeTracker, // ranges in buffer that have been sent or acked
+}
+
+impl TxBuffer {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Attempt to add some or all of the passed-in buffer to the TxBuffer.
+ pub fn send(&mut self, buf: &[u8]) -> usize {
+ let can_buffer = min(SEND_BUFFER_SIZE - self.buffered(), buf.len());
+ if can_buffer > 0 {
+ self.send_buf.extend(&buf[..can_buffer]);
+ assert!(self.send_buf.len() <= SEND_BUFFER_SIZE);
+ }
+ can_buffer
+ }
+
+ pub fn next_bytes(&self) -> Option<(u64, &[u8])> {
+ let (start, maybe_len) = self.ranges.first_unmarked_range();
+
+ if start == self.retired + u64::try_from(self.buffered()).unwrap() {
+ return None;
+ }
+
+ // Convert from ranges-relative-to-zero to
+ // ranges-relative-to-buffer-start
+ let buff_off = usize::try_from(start - self.retired).unwrap();
+
+ // Deque returns two slices. Create a subslice from whichever
+ // one contains the first unmarked data.
+ let slc = if buff_off < self.send_buf.as_slices().0.len() {
+ &self.send_buf.as_slices().0[buff_off..]
+ } else {
+ &self.send_buf.as_slices().1[buff_off - self.send_buf.as_slices().0.len()..]
+ };
+
+ let len = if let Some(range_len) = maybe_len {
+ // Truncate if range crosses deque slices
+ min(usize::try_from(range_len).unwrap(), slc.len())
+ } else {
+ slc.len()
+ };
+
+ debug_assert!(len > 0);
+ debug_assert!(len <= slc.len());
+
+ Some((start, &slc[..len]))
+ }
+
+ pub fn mark_as_sent(&mut self, offset: u64, len: usize) {
+ self.ranges.mark_range(offset, len, RangeState::Sent);
+ }
+
+ pub fn mark_as_acked(&mut self, offset: u64, len: usize) {
+ self.ranges.mark_range(offset, len, RangeState::Acked);
+
+ // We can drop contig acked range from the buffer
+ let new_retirable = self.ranges.acked_from_zero() - self.retired;
+ debug_assert!(new_retirable <= self.buffered() as u64);
+ let keep_len =
+ self.buffered() - usize::try_from(new_retirable).expect("should fit in usize");
+
+ // Truncate front
+ self.send_buf.rotate_left(self.buffered() - keep_len);
+ self.send_buf.truncate(keep_len);
+
+ self.retired += new_retirable;
+ }
+
+ pub fn mark_as_lost(&mut self, offset: u64, len: usize) {
+ self.ranges.unmark_range(offset, len);
+ }
+
+ /// Forget about anything that was marked as sent.
+ pub fn unmark_sent(&mut self) {
+ self.ranges.unmark_sent();
+ }
+
+ pub fn retired(&self) -> u64 {
+ self.retired
+ }
+
+ fn buffered(&self) -> usize {
+ self.send_buf.len()
+ }
+
+ fn avail(&self) -> usize {
+ SEND_BUFFER_SIZE - self.buffered()
+ }
+
+ fn used(&self) -> u64 {
+ self.retired + u64::try_from(self.buffered()).unwrap()
+ }
+}
+
+/// QUIC sending stream states, based on -transport 3.1.
+#[derive(Debug)]
+pub(crate) enum SendStreamState {
+ Ready {
+ fc: SenderFlowControl<StreamId>,
+ conn_fc: Rc<RefCell<SenderFlowControl<()>>>,
+ },
+ Send {
+ fc: SenderFlowControl<StreamId>,
+ conn_fc: Rc<RefCell<SenderFlowControl<()>>>,
+ send_buf: TxBuffer,
+ },
+ // Note: `DataSent` is entered when the stream is closed, not when all data has been
+ // sent for the first time.
+ DataSent {
+ send_buf: TxBuffer,
+ fin_sent: bool,
+ fin_acked: bool,
+ },
+ DataRecvd {
+ retired: u64,
+ written: u64,
+ },
+ ResetSent {
+ err: AppError,
+ final_size: u64,
+ priority: Option<TransmissionPriority>,
+ final_retired: u64,
+ final_written: u64,
+ },
+ ResetRecvd {
+ final_retired: u64,
+ final_written: u64,
+ },
+}
+
+impl SendStreamState {
+ fn tx_buf_mut(&mut self) -> Option<&mut TxBuffer> {
+ match self {
+ Self::Send { send_buf, .. } | Self::DataSent { send_buf, .. } => Some(send_buf),
+ Self::Ready { .. }
+ | Self::DataRecvd { .. }
+ | Self::ResetSent { .. }
+ | Self::ResetRecvd { .. } => None,
+ }
+ }
+
+ fn tx_avail(&self) -> usize {
+ match self {
+ // In Ready, TxBuffer not yet allocated but size is known
+ Self::Ready { .. } => SEND_BUFFER_SIZE,
+ Self::Send { send_buf, .. } | Self::DataSent { send_buf, .. } => send_buf.avail(),
+ Self::DataRecvd { .. } | Self::ResetSent { .. } | Self::ResetRecvd { .. } => 0,
+ }
+ }
+
+ fn name(&self) -> &str {
+ match self {
+ Self::Ready { .. } => "Ready",
+ Self::Send { .. } => "Send",
+ Self::DataSent { .. } => "DataSent",
+ Self::DataRecvd { .. } => "DataRecvd",
+ Self::ResetSent { .. } => "ResetSent",
+ Self::ResetRecvd { .. } => "ResetRecvd",
+ }
+ }
+
+ fn transition(&mut self, new_state: Self) {
+ qtrace!("SendStream state {} -> {}", self.name(), new_state.name());
+ *self = new_state;
+ }
+}
+
+// See https://www.w3.org/TR/webtransport/#send-stream-stats.
+#[derive(Debug, Clone, Copy)]
+pub struct SendStreamStats {
+ // The total number of bytes the consumer has successfully written to
+ // this stream. This number can only increase.
+ pub bytes_written: u64,
+ // An indicator of progress on how many of the consumer bytes written to
+ // this stream has been sent at least once. This number can only increase,
+ // and is always less than or equal to bytes_written.
+ pub bytes_sent: u64,
+ // An indicator of progress on how many of the consumer bytes written to
+ // this stream have been sent and acknowledged as received by the server
+ // using QUIC’s ACK mechanism. Only sequential bytes up to,
+ // but not including, the first non-acknowledged byte, are counted.
+ // This number can only increase and is always less than or equal to
+ // bytes_sent.
+ pub bytes_acked: u64,
+}
+
+impl SendStreamStats {
+ #[must_use]
+ pub fn new(bytes_written: u64, bytes_sent: u64, bytes_acked: u64) -> Self {
+ Self {
+ bytes_written,
+ bytes_sent,
+ bytes_acked,
+ }
+ }
+
+ #[must_use]
+ pub fn bytes_written(&self) -> u64 {
+ self.bytes_written
+ }
+
+ #[must_use]
+ pub fn bytes_sent(&self) -> u64 {
+ self.bytes_sent
+ }
+
+ #[must_use]
+ pub fn bytes_acked(&self) -> u64 {
+ self.bytes_acked
+ }
+}
+
+/// Implement a QUIC send stream.
+#[derive(Debug)]
+pub struct SendStream {
+ stream_id: StreamId,
+ state: SendStreamState,
+ conn_events: ConnectionEvents,
+ priority: TransmissionPriority,
+ retransmission_priority: RetransmissionPriority,
+ retransmission_offset: u64,
+ sendorder: Option<SendOrder>,
+ bytes_sent: u64,
+ fair: bool,
+}
+
+impl Hash for SendStream {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.stream_id.hash(state);
+ }
+}
+
+impl PartialEq for SendStream {
+ fn eq(&self, other: &Self) -> bool {
+ self.stream_id == other.stream_id
+ }
+}
+impl Eq for SendStream {}
+
+impl SendStream {
+ pub fn new(
+ stream_id: StreamId,
+ max_stream_data: u64,
+ conn_fc: Rc<RefCell<SenderFlowControl<()>>>,
+ conn_events: ConnectionEvents,
+ ) -> Self {
+ let ss = Self {
+ stream_id,
+ state: SendStreamState::Ready {
+ fc: SenderFlowControl::new(stream_id, max_stream_data),
+ conn_fc,
+ },
+ conn_events,
+ priority: TransmissionPriority::default(),
+ retransmission_priority: RetransmissionPriority::default(),
+ retransmission_offset: 0,
+ sendorder: None,
+ bytes_sent: 0,
+ fair: false,
+ };
+ if ss.avail() > 0 {
+ ss.conn_events.send_stream_writable(stream_id);
+ }
+ ss
+ }
+
+ pub fn write_frames(
+ &mut self,
+ priority: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ qtrace!("write STREAM frames at priority {:?}", priority);
+ if !self.write_reset_frame(priority, builder, tokens, stats) {
+ self.write_blocked_frame(priority, builder, tokens, stats);
+ self.write_stream_frame(priority, builder, tokens, stats);
+ }
+ }
+
+ // return false if the builder is full and the caller should stop iterating
+ pub fn write_frames_with_early_return(
+ &mut self,
+ priority: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> bool {
+ if !self.write_reset_frame(priority, builder, tokens, stats) {
+ self.write_blocked_frame(priority, builder, tokens, stats);
+ if builder.is_full() {
+ return false;
+ }
+ self.write_stream_frame(priority, builder, tokens, stats);
+ if builder.is_full() {
+ return false;
+ }
+ }
+ true
+ }
+
+ pub fn set_fairness(&mut self, make_fair: bool) {
+ self.fair = make_fair;
+ }
+
+ pub fn is_fair(&self) -> bool {
+ self.fair
+ }
+
+ pub fn set_priority(
+ &mut self,
+ transmission: TransmissionPriority,
+ retransmission: RetransmissionPriority,
+ ) {
+ self.priority = transmission;
+ self.retransmission_priority = retransmission;
+ }
+
+ pub fn sendorder(&self) -> Option<SendOrder> {
+ self.sendorder
+ }
+
+ pub fn set_sendorder(&mut self, sendorder: Option<SendOrder>) {
+ self.sendorder = sendorder;
+ }
+
+ /// If all data has been buffered or written, how much was sent.
+ pub fn final_size(&self) -> Option<u64> {
+ match &self.state {
+ SendStreamState::DataSent { send_buf, .. } => Some(send_buf.used()),
+ SendStreamState::ResetSent { final_size, .. } => Some(*final_size),
+ _ => None,
+ }
+ }
+
+ pub fn stats(&self) -> SendStreamStats {
+ SendStreamStats::new(self.bytes_written(), self.bytes_sent, self.bytes_acked())
+ }
+
+ pub fn bytes_written(&self) -> u64 {
+ match &self.state {
+ SendStreamState::Send { send_buf, .. } | SendStreamState::DataSent { send_buf, .. } => {
+ send_buf.retired() + u64::try_from(send_buf.buffered()).unwrap()
+ }
+ SendStreamState::DataRecvd {
+ retired, written, ..
+ } => *retired + *written,
+ SendStreamState::ResetSent {
+ final_retired,
+ final_written,
+ ..
+ }
+ | SendStreamState::ResetRecvd {
+ final_retired,
+ final_written,
+ ..
+ } => *final_retired + *final_written,
+ SendStreamState::Ready { .. } => 0,
+ }
+ }
+
+ pub fn bytes_acked(&self) -> u64 {
+ match &self.state {
+ SendStreamState::Send { send_buf, .. } | SendStreamState::DataSent { send_buf, .. } => {
+ send_buf.retired()
+ }
+ SendStreamState::DataRecvd { retired, .. } => *retired,
+ SendStreamState::ResetSent { final_retired, .. }
+ | SendStreamState::ResetRecvd { final_retired, .. } => *final_retired,
+ SendStreamState::Ready { .. } => 0,
+ }
+ }
+
+ /// Return the next range to be sent, if any.
+ /// If this is a retransmission, cut off what is sent at the retransmission
+ /// offset.
+ fn next_bytes(&mut self, retransmission_only: bool) -> Option<(u64, &[u8])> {
+ match self.state {
+ SendStreamState::Send { ref send_buf, .. } => {
+ send_buf.next_bytes().and_then(|(offset, slice)| {
+ if retransmission_only {
+ qtrace!(
+ [self],
+ "next_bytes apply retransmission limit at {}",
+ self.retransmission_offset
+ );
+ if self.retransmission_offset > offset {
+ let len = min(
+ usize::try_from(self.retransmission_offset - offset).unwrap(),
+ slice.len(),
+ );
+ Some((offset, &slice[..len]))
+ } else {
+ None
+ }
+ } else {
+ Some((offset, slice))
+ }
+ })
+ }
+ SendStreamState::DataSent {
+ ref send_buf,
+ fin_sent,
+ ..
+ } => {
+ let bytes = send_buf.next_bytes();
+ if bytes.is_some() {
+ bytes
+ } else if fin_sent {
+ None
+ } else {
+ // Send empty stream frame with fin set
+ Some((send_buf.used(), &[]))
+ }
+ }
+ SendStreamState::Ready { .. }
+ | SendStreamState::DataRecvd { .. }
+ | SendStreamState::ResetSent { .. }
+ | SendStreamState::ResetRecvd { .. } => None,
+ }
+ }
+
+ /// Calculate how many bytes (length) can fit into available space and whether
+ /// the remainder of the space can be filled (or if a length field is needed).
+ fn length_and_fill(data_len: usize, space: usize) -> (usize, bool) {
+ if data_len >= space {
+ // More data than space allows, or an exact fit => fast path.
+ qtrace!("SendStream::length_and_fill fill {}", space);
+ return (space, true);
+ }
+
+ // Estimate size of the length field based on the available space,
+ // less 1, which is the worst case.
+ let length = min(space.saturating_sub(1), data_len);
+ let length_len = Encoder::varint_len(u64::try_from(length).unwrap());
+ debug_assert!(length_len <= space); // We don't depend on this being true, but it is true.
+
+ // From here we can always fit `data_len`, but we might as well fill
+ // if there is no space for the length field plus another frame.
+ let fill = data_len + length_len + PacketBuilder::MINIMUM_FRAME_SIZE > space;
+ qtrace!("SendStream::length_and_fill {} fill {}", data_len, fill);
+ (data_len, fill)
+ }
+
+ /// Maybe write a `STREAM` frame.
+ pub fn write_stream_frame(
+ &mut self,
+ priority: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ let retransmission = if priority == self.priority {
+ false
+ } else if priority == self.priority + self.retransmission_priority {
+ true
+ } else {
+ return;
+ };
+
+ let id = self.stream_id;
+ let final_size = self.final_size();
+ if let Some((offset, data)) = self.next_bytes(retransmission) {
+ let overhead = 1 // Frame type
+ + Encoder::varint_len(id.as_u64())
+ + if offset > 0 {
+ Encoder::varint_len(offset)
+ } else {
+ 0
+ };
+ if overhead > builder.remaining() {
+ qtrace!([self], "write_frame no space for header");
+ return;
+ }
+
+ let (length, fill) = Self::length_and_fill(data.len(), builder.remaining() - overhead);
+ let fin = final_size.map_or(false, |fs| fs == offset + u64::try_from(length).unwrap());
+ if length == 0 && !fin {
+ qtrace!([self], "write_frame no data, no fin");
+ return;
+ }
+
+ // Write the stream out.
+ builder.encode_varint(Frame::stream_type(fin, offset > 0, fill));
+ builder.encode_varint(id.as_u64());
+ if offset > 0 {
+ builder.encode_varint(offset);
+ }
+ if fill {
+ builder.encode(&data[..length]);
+ builder.mark_full();
+ } else {
+ builder.encode_vvec(&data[..length]);
+ }
+ debug_assert!(builder.len() <= builder.limit());
+
+ self.mark_as_sent(offset, length, fin);
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::Stream(
+ SendStreamRecoveryToken {
+ id,
+ offset,
+ length,
+ fin,
+ },
+ )));
+ stats.stream += 1;
+ }
+ }
+
+ pub fn reset_acked(&mut self) {
+ match self.state {
+ SendStreamState::Ready { .. }
+ | SendStreamState::Send { .. }
+ | SendStreamState::DataSent { .. }
+ | SendStreamState::DataRecvd { .. } => {
+ qtrace!([self], "Reset acked while in {} state?", self.state.name());
+ }
+ SendStreamState::ResetSent {
+ final_retired,
+ final_written,
+ ..
+ } => self.state.transition(SendStreamState::ResetRecvd {
+ final_retired,
+ final_written,
+ }),
+ SendStreamState::ResetRecvd { .. } => qtrace!([self], "already in ResetRecvd state"),
+ };
+ }
+
+ pub fn reset_lost(&mut self) {
+ match self.state {
+ SendStreamState::ResetSent {
+ ref mut priority, ..
+ } => {
+ *priority = Some(self.priority + self.retransmission_priority);
+ }
+ SendStreamState::ResetRecvd { .. } => (),
+ _ => unreachable!(),
+ }
+ }
+
+ /// Maybe write a `RESET_STREAM` frame.
+ pub fn write_reset_frame(
+ &mut self,
+ p: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) -> bool {
+ if let SendStreamState::ResetSent {
+ final_size,
+ err,
+ ref mut priority,
+ ..
+ } = self.state
+ {
+ if *priority != Some(p) {
+ return false;
+ }
+ if builder.write_varint_frame(&[
+ FRAME_TYPE_RESET_STREAM,
+ self.stream_id.as_u64(),
+ err,
+ final_size,
+ ]) {
+ tokens.push(RecoveryToken::Stream(StreamRecoveryToken::ResetStream {
+ stream_id: self.stream_id,
+ }));
+ stats.reset_stream += 1;
+ *priority = None;
+ true
+ } else {
+ false
+ }
+ } else {
+ false
+ }
+ }
+
+ pub fn blocked_lost(&mut self, limit: u64) {
+ if let SendStreamState::Ready { fc, .. } | SendStreamState::Send { fc, .. } =
+ &mut self.state
+ {
+ fc.frame_lost(limit);
+ } else {
+ qtrace!([self], "Ignoring lost STREAM_DATA_BLOCKED({})", limit);
+ }
+ }
+
+ /// Maybe write a `STREAM_DATA_BLOCKED` frame.
+ pub fn write_blocked_frame(
+ &mut self,
+ priority: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ // Send STREAM_DATA_BLOCKED at normal priority always.
+ if priority == self.priority {
+ if let SendStreamState::Ready { fc, .. } | SendStreamState::Send { fc, .. } =
+ &mut self.state
+ {
+ fc.write_frames(builder, tokens, stats);
+ }
+ }
+ }
+
+ pub fn mark_as_sent(&mut self, offset: u64, len: usize, fin: bool) {
+ self.bytes_sent = max(self.bytes_sent, offset + u64::try_from(len).unwrap());
+
+ if let Some(buf) = self.state.tx_buf_mut() {
+ buf.mark_as_sent(offset, len);
+ self.send_blocked_if_space_needed(0);
+ };
+
+ if fin {
+ if let SendStreamState::DataSent { fin_sent, .. } = &mut self.state {
+ *fin_sent = true;
+ }
+ }
+ }
+
+ pub fn mark_as_acked(&mut self, offset: u64, len: usize, fin: bool) {
+ match self.state {
+ SendStreamState::Send {
+ ref mut send_buf, ..
+ } => {
+ send_buf.mark_as_acked(offset, len);
+ if self.avail() > 0 {
+ self.conn_events.send_stream_writable(self.stream_id);
+ }
+ }
+ SendStreamState::DataSent {
+ ref mut send_buf,
+ ref mut fin_acked,
+ ..
+ } => {
+ send_buf.mark_as_acked(offset, len);
+ if fin {
+ *fin_acked = true;
+ }
+ if *fin_acked && send_buf.buffered() == 0 {
+ self.conn_events.send_stream_complete(self.stream_id);
+ let retired = send_buf.retired();
+ let buffered = u64::try_from(send_buf.buffered()).unwrap();
+ self.state.transition(SendStreamState::DataRecvd {
+ retired,
+ written: buffered,
+ });
+ }
+ }
+ _ => qtrace!(
+ [self],
+ "mark_as_acked called from state {}",
+ self.state.name()
+ ),
+ }
+ }
+
+ pub fn mark_as_lost(&mut self, offset: u64, len: usize, fin: bool) {
+ self.retransmission_offset = max(
+ self.retransmission_offset,
+ offset + u64::try_from(len).unwrap(),
+ );
+ qtrace!(
+ [self],
+ "mark_as_lost retransmission offset={}",
+ self.retransmission_offset
+ );
+ if let Some(buf) = self.state.tx_buf_mut() {
+ buf.mark_as_lost(offset, len);
+ }
+
+ if fin {
+ if let SendStreamState::DataSent {
+ fin_sent,
+ fin_acked,
+ ..
+ } = &mut self.state
+ {
+ *fin_sent = *fin_acked;
+ }
+ }
+ }
+
+ /// Bytes sendable on stream. Constrained by stream credit available,
+ /// connection credit available, and space in the tx buffer.
+ pub fn avail(&self) -> usize {
+ if let SendStreamState::Ready { fc, conn_fc } | SendStreamState::Send { fc, conn_fc, .. } =
+ &self.state
+ {
+ min(
+ min(fc.available(), conn_fc.borrow().available()),
+ self.state.tx_avail(),
+ )
+ } else {
+ 0
+ }
+ }
+
+ pub fn set_max_stream_data(&mut self, limit: u64) {
+ if let SendStreamState::Ready { fc, .. } | SendStreamState::Send { fc, .. } =
+ &mut self.state
+ {
+ let stream_was_blocked = fc.available() == 0;
+ fc.update(limit);
+ if stream_was_blocked && self.avail() > 0 {
+ self.conn_events.send_stream_writable(self.stream_id);
+ }
+ }
+ }
+
+ pub fn is_terminal(&self) -> bool {
+ matches!(
+ self.state,
+ SendStreamState::DataRecvd { .. } | SendStreamState::ResetRecvd { .. }
+ )
+ }
+
+ pub fn send(&mut self, buf: &[u8]) -> Res<usize> {
+ self.send_internal(buf, false)
+ }
+
+ pub fn send_atomic(&mut self, buf: &[u8]) -> Res<usize> {
+ self.send_internal(buf, true)
+ }
+
+ fn send_blocked_if_space_needed(&mut self, needed_space: usize) {
+ if let SendStreamState::Ready { fc, conn_fc } | SendStreamState::Send { fc, conn_fc, .. } =
+ &mut self.state
+ {
+ if fc.available() <= needed_space {
+ fc.blocked();
+ }
+
+ if conn_fc.borrow().available() <= needed_space {
+ conn_fc.borrow_mut().blocked();
+ }
+ }
+ }
+
+ fn send_internal(&mut self, buf: &[u8], atomic: bool) -> Res<usize> {
+ if buf.is_empty() {
+ qerror!([self], "zero-length send on stream");
+ return Err(Error::InvalidInput);
+ }
+
+ if let SendStreamState::Ready { fc, conn_fc } = &mut self.state {
+ let owned_fc = mem::replace(fc, SenderFlowControl::new(self.stream_id, 0));
+ let owned_conn_fc = Rc::clone(conn_fc);
+ self.state.transition(SendStreamState::Send {
+ fc: owned_fc,
+ conn_fc: owned_conn_fc,
+ send_buf: TxBuffer::new(),
+ });
+ }
+
+ if !matches!(self.state, SendStreamState::Send { .. }) {
+ return Err(Error::FinalSizeError);
+ }
+
+ let buf = if buf.is_empty() || (self.avail() == 0) {
+ return Ok(0);
+ } else if self.avail() < buf.len() {
+ if atomic {
+ self.send_blocked_if_space_needed(buf.len());
+ return Ok(0);
+ } else {
+ &buf[..self.avail()]
+ }
+ } else {
+ buf
+ };
+
+ match &mut self.state {
+ SendStreamState::Ready { .. } => unreachable!(),
+ SendStreamState::Send {
+ fc,
+ conn_fc,
+ send_buf,
+ } => {
+ let sent = send_buf.send(buf);
+ fc.consume(sent);
+ conn_fc.borrow_mut().consume(sent);
+ Ok(sent)
+ }
+ _ => Err(Error::FinalSizeError),
+ }
+ }
+
+ pub fn close(&mut self) {
+ match &mut self.state {
+ SendStreamState::Ready { .. } => {
+ self.state.transition(SendStreamState::DataSent {
+ send_buf: TxBuffer::new(),
+ fin_sent: false,
+ fin_acked: false,
+ });
+ }
+ SendStreamState::Send { send_buf, .. } => {
+ let owned_buf = mem::replace(send_buf, TxBuffer::new());
+ self.state.transition(SendStreamState::DataSent {
+ send_buf: owned_buf,
+ fin_sent: false,
+ fin_acked: false,
+ });
+ }
+ SendStreamState::DataSent { .. } => qtrace!([self], "already in DataSent state"),
+ SendStreamState::DataRecvd { .. } => qtrace!([self], "already in DataRecvd state"),
+ SendStreamState::ResetSent { .. } => qtrace!([self], "already in ResetSent state"),
+ SendStreamState::ResetRecvd { .. } => qtrace!([self], "already in ResetRecvd state"),
+ }
+ }
+
+ pub fn reset(&mut self, err: AppError) {
+ match &self.state {
+ SendStreamState::Ready { fc, .. } => {
+ let final_size = fc.used();
+ self.state.transition(SendStreamState::ResetSent {
+ err,
+ final_size,
+ priority: Some(self.priority),
+ final_retired: 0,
+ final_written: 0,
+ });
+ }
+ SendStreamState::Send { fc, send_buf, .. } => {
+ let final_size = fc.used();
+ let final_retired = send_buf.retired();
+ let buffered = u64::try_from(send_buf.buffered()).unwrap();
+ self.state.transition(SendStreamState::ResetSent {
+ err,
+ final_size,
+ priority: Some(self.priority),
+ final_retired,
+ final_written: buffered,
+ });
+ }
+ SendStreamState::DataSent { send_buf, .. } => {
+ let final_size = send_buf.used();
+ let final_retired = send_buf.retired();
+ let buffered = u64::try_from(send_buf.buffered()).unwrap();
+ self.state.transition(SendStreamState::ResetSent {
+ err,
+ final_size,
+ priority: Some(self.priority),
+ final_retired,
+ final_written: buffered,
+ });
+ }
+ SendStreamState::DataRecvd { .. } => qtrace!([self], "already in DataRecvd state"),
+ SendStreamState::ResetSent { .. } => qtrace!([self], "already in ResetSent state"),
+ SendStreamState::ResetRecvd { .. } => qtrace!([self], "already in ResetRecvd state"),
+ };
+ }
+
+ #[cfg(test)]
+ pub(crate) fn state(&mut self) -> &mut SendStreamState {
+ &mut self.state
+ }
+}
+
+impl ::std::fmt::Display for SendStream {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "SendStream {}", self.stream_id)
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct OrderGroup {
+ // This vector is sorted by StreamId
+ vec: Vec<StreamId>,
+
+ // Since we need to remember where we were, we'll store the iterator next
+ // position in the object. This means there can only be a single iterator active
+ // at a time!
+ next: usize,
+ // This is used when an iterator is created to set the start/stop point for the
+ // iteration. The iterator must iterate from this entry to the end, and then
+ // wrap and iterate from 0 until before the initial value of next.
+ // This value may need to be updated after insertion and removal; in theory we should
+ // track the target entry across modifications, but in practice it should be good
+ // enough to simply leave it alone unless it points past the end of the
+ // Vec, and re-initialize to 0 in that case.
+}
+
+pub struct OrderGroupIter<'a> {
+ group: &'a mut OrderGroup,
+ // We store the next position in the OrderGroup.
+ // Otherwise we'd need an explicit "done iterating" call to be made, or implement Drop to
+ // copy the value back.
+ // This is where next was when we iterated for the first time; when we get back to that we
+ // stop.
+ started_at: Option<usize>,
+}
+
+impl OrderGroup {
+ pub fn iter(&mut self) -> OrderGroupIter {
+ // Ids may have been deleted since we last iterated
+ if self.next >= self.vec.len() {
+ self.next = 0;
+ }
+ OrderGroupIter {
+ started_at: None,
+ group: self,
+ }
+ }
+
+ pub fn stream_ids(&self) -> &Vec<StreamId> {
+ &self.vec
+ }
+
+ pub fn clear(&mut self) {
+ self.vec.clear();
+ }
+
+ pub fn push(&mut self, stream_id: StreamId) {
+ self.vec.push(stream_id);
+ }
+
+ #[cfg(test)]
+ pub fn truncate(&mut self, position: usize) {
+ self.vec.truncate(position);
+ }
+
+ fn update_next(&mut self) -> usize {
+ let next = self.next;
+ self.next = (self.next + 1) % self.vec.len();
+ next
+ }
+
+ pub fn insert(&mut self, stream_id: StreamId) {
+ match self.vec.binary_search(&stream_id) {
+ Ok(_) => {
+ // element already in vector @ `pos`
+ panic!("Duplicate stream_id {}", stream_id)
+ }
+ Err(pos) => self.vec.insert(pos, stream_id),
+ }
+ }
+
+ pub fn remove(&mut self, stream_id: StreamId) {
+ match self.vec.binary_search(&stream_id) {
+ Ok(pos) => {
+ self.vec.remove(pos);
+ }
+ Err(_) => {
+ // element already in vector @ `pos`
+ panic!("Missing stream_id {}", stream_id)
+ }
+ }
+ }
+}
+
+impl<'a> Iterator for OrderGroupIter<'a> {
+ type Item = StreamId;
+ fn next(&mut self) -> Option<Self::Item> {
+ // Stop when we would return the started_at element on the next
+ // call. Note that this must take into account wrapping.
+ if self.started_at == Some(self.group.next) || self.group.vec.is_empty() {
+ return None;
+ }
+ self.started_at = self.started_at.or(Some(self.group.next));
+ let orig = self.group.update_next();
+ Some(self.group.vec[orig])
+ }
+}
+
+#[derive(Debug, Default)]
+pub(crate) struct SendStreams {
+ map: IndexMap<StreamId, SendStream>,
+
+ // What we really want is a Priority Queue that we can do arbitrary
+ // removes from (so we can reprioritize). BinaryHeap doesn't work,
+ // because there's no remove(). BTreeMap doesn't work, since you can't
+ // duplicate keys. PriorityQueue does have what we need, except for an
+ // ordered iterator that doesn't consume the queue. So we roll our own.
+
+ // Added complication: We want to have Fairness for streams of the same
+ // 'group' (for WebTransport), but for H3 (and other non-WT streams) we
+ // tend to get better pageload performance by prioritizing by creation order.
+ //
+ // Two options are to walk the 'map' first, ignoring WebTransport
+ // streams, then process the unordered and ordered WebTransport
+ // streams. The second is to have a sorted Vec for unfair streams (and
+ // use a normal iterator for that), and then chain the iterators for
+ // the unordered and ordered WebTranport streams. The first works very
+ // well for H3, and for WebTransport nodes are visited twice on every
+ // processing loop. The second adds insertion and removal costs, but
+ // avoids a CPU penalty for WebTransport streams. For now we'll do #1.
+ //
+ // So we use a sorted Vec<> for the regular streams (that's usually all of
+ // them), and then a BTreeMap of an entry for each SendOrder value, and
+ // for each of those entries a Vec of the stream_ids at that
+ // sendorder. In most cases (such as stream-per-frame), there will be
+ // a single stream at a given sendorder.
+
+ // These both store stream_ids, which need to be looked up in 'map'.
+ // This avoids the complexity of trying to hold references to the
+ // Streams which are owned by the IndexMap.
+ sendordered: BTreeMap<SendOrder, OrderGroup>,
+ regular: OrderGroup, // streams with no SendOrder set, sorted in stream_id order
+}
+
+impl SendStreams {
+ pub fn get(&self, id: StreamId) -> Res<&SendStream> {
+ self.map.get(&id).ok_or(Error::InvalidStreamId)
+ }
+
+ pub fn get_mut(&mut self, id: StreamId) -> Res<&mut SendStream> {
+ self.map.get_mut(&id).ok_or(Error::InvalidStreamId)
+ }
+
+ pub fn exists(&self, id: StreamId) -> bool {
+ self.map.contains_key(&id)
+ }
+
+ pub fn insert(&mut self, id: StreamId, stream: SendStream) {
+ self.map.insert(id, stream);
+ }
+
+ fn group_mut(&mut self, sendorder: Option<SendOrder>) -> &mut OrderGroup {
+ if let Some(order) = sendorder {
+ self.sendordered.entry(order).or_default()
+ } else {
+ &mut self.regular
+ }
+ }
+
+ pub fn set_sendorder(&mut self, stream_id: StreamId, sendorder: Option<SendOrder>) -> Res<()> {
+ self.set_fairness(stream_id, true)?;
+ if let Some(stream) = self.map.get_mut(&stream_id) {
+ // don't grab stream here; causes borrow errors
+ let old_sendorder = stream.sendorder();
+ if old_sendorder != sendorder {
+ // we have to remove it from the list it was in, and reinsert it with the new
+ // sendorder key
+ let mut group = self.group_mut(old_sendorder);
+ group.remove(stream_id);
+ self.get_mut(stream_id).unwrap().set_sendorder(sendorder);
+ group = self.group_mut(sendorder);
+ group.insert(stream_id);
+ qtrace!(
+ "ordering of stream_ids: {:?}",
+ self.sendordered.values().collect::<Vec::<_>>()
+ );
+ }
+ Ok(())
+ } else {
+ Err(Error::InvalidStreamId)
+ }
+ }
+
+ pub fn set_fairness(&mut self, stream_id: StreamId, make_fair: bool) -> Res<()> {
+ let stream: &mut SendStream = self.map.get_mut(&stream_id).ok_or(Error::InvalidStreamId)?;
+ let was_fair = stream.fair;
+ stream.set_fairness(make_fair);
+ if !was_fair && make_fair {
+ // Move to the regular OrderGroup.
+
+ // We know sendorder can't have been set, since
+ // set_sendorder() will call this routine if it's not
+ // already set as fair.
+
+ // This normally is only called when a new stream is created. If
+ // so, because of how we allocate StreamIds, it should always have
+ // the largest value. This means we can just append it to the
+ // regular vector. However, if we were ever to change this
+ // invariant, things would break subtly.
+
+ // To be safe we can try to insert at the end and if not
+ // fall back to binary-search insertion
+ if matches!(self.regular.stream_ids().last(), Some(last) if stream_id > *last) {
+ self.regular.push(stream_id);
+ } else {
+ self.regular.insert(stream_id);
+ }
+ } else if was_fair && !make_fair {
+ // remove from the OrderGroup
+ let group = if let Some(sendorder) = stream.sendorder {
+ self.sendordered.get_mut(&sendorder).unwrap()
+ } else {
+ &mut self.regular
+ };
+ group.remove(stream_id);
+ }
+ Ok(())
+ }
+
+ pub fn acked(&mut self, token: &SendStreamRecoveryToken) {
+ if let Some(ss) = self.map.get_mut(&token.id) {
+ ss.mark_as_acked(token.offset, token.length, token.fin);
+ }
+ }
+
+ pub fn reset_acked(&mut self, id: StreamId) {
+ if let Some(ss) = self.map.get_mut(&id) {
+ ss.reset_acked();
+ }
+ }
+
+ pub fn lost(&mut self, token: &SendStreamRecoveryToken) {
+ if let Some(ss) = self.map.get_mut(&token.id) {
+ ss.mark_as_lost(token.offset, token.length, token.fin);
+ }
+ }
+
+ pub fn reset_lost(&mut self, stream_id: StreamId) {
+ if let Some(ss) = self.map.get_mut(&stream_id) {
+ ss.reset_lost();
+ }
+ }
+
+ pub fn blocked_lost(&mut self, stream_id: StreamId, limit: u64) {
+ if let Some(ss) = self.map.get_mut(&stream_id) {
+ ss.blocked_lost(limit);
+ }
+ }
+
+ pub fn clear(&mut self) {
+ self.map.clear();
+ self.sendordered.clear();
+ self.regular.clear();
+ }
+
+ pub fn remove_terminal(&mut self) {
+ let map: &mut IndexMap<StreamId, SendStream> = &mut self.map;
+ let regular: &mut OrderGroup = &mut self.regular;
+ let sendordered: &mut BTreeMap<SendOrder, OrderGroup> = &mut self.sendordered;
+
+ // Take refs to all the items we need to modify instead of &mut
+ // self to keep the compiler happy (if we use self.map.retain it
+ // gets upset due to borrows)
+ map.retain(|stream_id, stream| {
+ if stream.is_terminal() {
+ if stream.is_fair() {
+ match stream.sendorder() {
+ None => regular.remove(*stream_id),
+ Some(sendorder) => {
+ sendordered.get_mut(&sendorder).unwrap().remove(*stream_id);
+ }
+ };
+ }
+ // if unfair, we're done
+ return false;
+ }
+ true
+ });
+ }
+
+ pub(crate) fn write_frames(
+ &mut self,
+ priority: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ qtrace!("write STREAM frames at priority {:?}", priority);
+ // WebTransport data (which is Normal) may have a SendOrder
+ // priority attached. The spec states (6.3 write-chunk 6.1):
+
+ // First, we send any streams without Fairness defined, with
+ // ordering defined by StreamId. (Http3 streams used for
+ // e.g. pageload benefit from being processed in order of creation
+ // so the far side can start acting on a datum/request sooner. All
+ // WebTransport streams MUST have fairness set.) Then we send
+ // streams with fairness set (including all WebTransport streams)
+ // as follows:
+
+ // If stream.[[SendOrder]] is null then this sending MUST NOT
+ // starve except for flow control reasons or error. If
+ // stream.[[SendOrder]] is not null then this sending MUST starve
+ // until all bytes queued for sending on WebTransportSendStreams
+ // with a non-null and higher [[SendOrder]], that are neither
+ // errored nor blocked by flow control, have been sent.
+
+ // So data without SendOrder goes first. Then the highest priority
+ // SendOrdered streams.
+ //
+ // Fairness is implemented by a round-robining or "statefully
+ // iterating" within a single sendorder/unordered vector. We do
+ // this by recording where we stopped in the previous pass, and
+ // starting there the next pass. If we store an index into the
+ // vec, this means we can't use a chained iterator, since we want
+ // to retain our place-in-the-vector. If we rotate the vector,
+ // that would let us use the chained iterator, but would require
+ // more expensive searches for insertion and removal (since the
+ // sorted order would be lost).
+
+ // Iterate the map, but only those without fairness, then iterate
+ // OrderGroups, then iterate each group
+ qdebug!("processing streams... unfair:");
+ for stream in self.map.values_mut() {
+ if !stream.is_fair() {
+ qdebug!(" {}", stream);
+ if !stream.write_frames_with_early_return(priority, builder, tokens, stats) {
+ break;
+ }
+ }
+ }
+ qdebug!("fair streams:");
+ let stream_ids = self.regular.iter().chain(
+ self.sendordered
+ .values_mut()
+ .rev()
+ .flat_map(|group| group.iter()),
+ );
+ for stream_id in stream_ids {
+ let stream = self.map.get_mut(&stream_id).unwrap();
+ if let Some(order) = stream.sendorder() {
+ qdebug!(" {} ({})", stream_id, order)
+ } else {
+ qdebug!(" None")
+ }
+ if !stream.write_frames_with_early_return(priority, builder, tokens, stats) {
+ break;
+ }
+ }
+ }
+
+ pub fn update_initial_limit(&mut self, remote: &TransportParameters) {
+ for (id, ss) in self.map.iter_mut() {
+ let limit = if id.is_bidi() {
+ assert!(!id.is_remote_initiated(Role::Client));
+ remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)
+ } else {
+ remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)
+ };
+ ss.set_max_stream_data(limit);
+ }
+ }
+}
+
+impl<'a> IntoIterator for &'a mut SendStreams {
+ type Item = (&'a StreamId, &'a mut SendStream);
+ type IntoIter = indexmap::map::IterMut<'a, StreamId, SendStream>;
+
+ fn into_iter(self) -> indexmap::map::IterMut<'a, StreamId, SendStream> {
+ self.map.iter_mut()
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SendStreamRecoveryToken {
+ pub(crate) id: StreamId,
+ offset: u64,
+ length: usize,
+ fin: bool,
+}
+
+#[cfg(test)]
+mod tests {
+ use neqo_common::{event::Provider, hex_with_len, qtrace};
+
+ use super::*;
+ use crate::events::ConnectionEvent;
+
+ fn connection_fc(limit: u64) -> Rc<RefCell<SenderFlowControl<()>>> {
+ Rc::new(RefCell::new(SenderFlowControl::new((), limit)))
+ }
+
+ #[test]
+ fn test_mark_range() {
+ let mut rt = RangeTracker::default();
+
+ // ranges can go from nothing->Sent if queued for retrans and then
+ // acks arrive
+ rt.mark_range(5, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 10);
+ assert_eq!(rt.acked_from_zero(), 0);
+ rt.mark_range(10, 4, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 14);
+ assert_eq!(rt.acked_from_zero(), 0);
+
+ rt.mark_range(0, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 14);
+ assert_eq!(rt.acked_from_zero(), 0);
+ rt.mark_range(0, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 14);
+ assert_eq!(rt.acked_from_zero(), 14);
+
+ rt.mark_range(12, 20, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 32);
+ assert_eq!(rt.acked_from_zero(), 32);
+
+ // ack the lot
+ rt.mark_range(0, 400, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 400);
+ assert_eq!(rt.acked_from_zero(), 400);
+
+ // acked trumps sent
+ rt.mark_range(0, 200, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 400);
+ assert_eq!(rt.acked_from_zero(), 400);
+ }
+
+ #[test]
+ fn unmark_sent_start() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(0, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 0);
+
+ rt.unmark_sent();
+ assert_eq!(rt.highest_offset(), 0);
+ assert_eq!(rt.acked_from_zero(), 0);
+ assert_eq!(rt.first_unmarked_range(), (0, None));
+ }
+
+ #[test]
+ fn unmark_sent_middle() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(0, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 5);
+ rt.mark_range(5, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 10);
+ assert_eq!(rt.acked_from_zero(), 5);
+ rt.mark_range(10, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 15);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (15, None));
+
+ rt.unmark_sent();
+ assert_eq!(rt.highest_offset(), 15);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (5, Some(5)));
+ }
+
+ #[test]
+ fn unmark_sent_end() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(0, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 5);
+ rt.mark_range(5, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 10);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (10, None));
+
+ rt.unmark_sent();
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (5, None));
+ }
+
+ #[test]
+ fn truncate_front() {
+ let mut v = VecDeque::new();
+ v.push_back(5);
+ v.push_back(6);
+ v.push_back(7);
+ v.push_front(4usize);
+
+ v.rotate_left(1);
+ v.truncate(3);
+ assert_eq!(*v.front().unwrap(), 5);
+ assert_eq!(*v.back().unwrap(), 7);
+ }
+
+ #[test]
+ fn test_unmark_range() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(5, 5, RangeState::Acked);
+ rt.mark_range(10, 5, RangeState::Sent);
+
+ // Should unmark sent but not acked range
+ rt.unmark_range(7, 6);
+
+ let res = rt.first_unmarked_range();
+ assert_eq!(res, (0, Some(5)));
+ assert_eq!(
+ rt.used.iter().next().unwrap(),
+ (&5, &(5, RangeState::Acked))
+ );
+ assert_eq!(
+ rt.used.iter().nth(1).unwrap(),
+ (&13, &(2, RangeState::Sent))
+ );
+ assert!(rt.used.iter().nth(2).is_none());
+ rt.mark_range(0, 5, RangeState::Sent);
+
+ let res = rt.first_unmarked_range();
+ assert_eq!(res, (10, Some(3)));
+ rt.mark_range(10, 3, RangeState::Sent);
+
+ let res = rt.first_unmarked_range();
+ assert_eq!(res, (15, None));
+ }
+
+ #[test]
+ #[allow(clippy::cognitive_complexity)]
+ fn tx_buffer_next_bytes_1() {
+ let mut txb = TxBuffer::new();
+
+ assert_eq!(txb.avail(), SEND_BUFFER_SIZE);
+
+ // Fill the buffer
+ assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE);
+ assert!(matches!(txb.next_bytes(),
+ Some((0, x)) if x.len()==SEND_BUFFER_SIZE
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark almost all as sent. Get what's left
+ let one_byte_from_end = SEND_BUFFER_SIZE as u64 - 1;
+ txb.mark_as_sent(0, one_byte_from_end as usize);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 1
+ && start == one_byte_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark all as sent. Get nothing
+ txb.mark_as_sent(0, SEND_BUFFER_SIZE);
+ assert!(txb.next_bytes().is_none());
+
+ // Mark as lost. Get it again
+ txb.mark_as_lost(one_byte_from_end, 1);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 1
+ && start == one_byte_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark a larger range lost, including beyond what's in the buffer even.
+ // Get a little more
+ let five_bytes_from_end = SEND_BUFFER_SIZE as u64 - 5;
+ txb.mark_as_lost(five_bytes_from_end, 100);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 5
+ && start == five_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Contig acked range at start means it can be removed from buffer
+ // Impl of vecdeque should now result in a split buffer when more data
+ // is sent
+ txb.mark_as_acked(0, five_bytes_from_end as usize);
+ assert_eq!(txb.send(&[2; 30]), 30);
+ // Just get 5 even though there is more
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 5
+ && start == five_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+ assert_eq!(txb.retired, five_bytes_from_end);
+ assert_eq!(txb.buffered(), 35);
+
+ // Marking that bit as sent should let the last contig bit be returned
+ // when called again
+ txb.mark_as_sent(five_bytes_from_end, 5);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 30
+ && start == SEND_BUFFER_SIZE as u64
+ && x.iter().all(|ch| *ch == 2)));
+ }
+
+ #[test]
+ fn tx_buffer_next_bytes_2() {
+ let mut txb = TxBuffer::new();
+
+ assert_eq!(txb.avail(), SEND_BUFFER_SIZE);
+
+ // Fill the buffer
+ assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE);
+ assert!(matches!(txb.next_bytes(),
+ Some((0, x)) if x.len()==SEND_BUFFER_SIZE
+ && x.iter().all(|ch| *ch == 1)));
+
+ // As above
+ let forty_bytes_from_end = SEND_BUFFER_SIZE as u64 - 40;
+
+ txb.mark_as_acked(0, forty_bytes_from_end as usize);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 40
+ && start == forty_bytes_from_end
+ ));
+
+ // Valid new data placed in split locations
+ assert_eq!(txb.send(&[2; 100]), 100);
+
+ // Mark a little more as sent
+ txb.mark_as_sent(forty_bytes_from_end, 10);
+ let thirty_bytes_from_end = forty_bytes_from_end + 10;
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 30
+ && start == thirty_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark a range 'A' in second slice as sent. Should still return the same
+ let range_a_start = SEND_BUFFER_SIZE as u64 + 30;
+ let range_a_end = range_a_start + 10;
+ txb.mark_as_sent(range_a_start, 10);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 30
+ && start == thirty_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Ack entire first slice and into second slice
+ let ten_bytes_past_end = SEND_BUFFER_SIZE as u64 + 10;
+ txb.mark_as_acked(0, ten_bytes_past_end as usize);
+
+ // Get up to marked range A
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 20
+ && start == ten_bytes_past_end
+ && x.iter().all(|ch| *ch == 2)));
+
+ txb.mark_as_sent(ten_bytes_past_end, 20);
+
+ // Get bit after earlier marked range A
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 60
+ && start == range_a_end
+ && x.iter().all(|ch| *ch == 2)));
+
+ // No more bytes.
+ txb.mark_as_sent(range_a_end, 60);
+ assert!(txb.next_bytes().is_none());
+ }
+
+ #[test]
+ fn test_stream_tx() {
+ let conn_fc = connection_fc(4096);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(4.into(), 1024, Rc::clone(&conn_fc), conn_events);
+
+ let res = s.send(&[4; 100]).unwrap();
+ assert_eq!(res, 100);
+ s.mark_as_sent(0, 50, false);
+ if let SendStreamState::Send { fc, .. } = s.state() {
+ assert_eq!(fc.used(), 100);
+ } else {
+ panic!("unexpected stream state");
+ }
+
+ // Should hit stream flow control limit before filling up send buffer
+ let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap();
+ assert_eq!(res, 1024 - 100);
+
+ // should do nothing, max stream data already 1024
+ s.set_max_stream_data(1024);
+ let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap();
+ assert_eq!(res, 0);
+
+ // should now hit the conn flow control (4096)
+ s.set_max_stream_data(1_048_576);
+ let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap();
+ assert_eq!(res, 3072);
+
+ // should now hit the tx buffer size
+ conn_fc.borrow_mut().update(SEND_BUFFER_SIZE as u64);
+ let res = s.send(&[4; SEND_BUFFER_SIZE + 100]).unwrap();
+ assert_eq!(res, SEND_BUFFER_SIZE - 4096);
+
+ // TODO(agrover@mozilla.com): test ooo acks somehow
+ s.mark_as_acked(0, 40, false);
+ }
+
+ #[test]
+ fn test_tx_buffer_acks() {
+ let mut tx = TxBuffer::new();
+ assert_eq!(tx.send(&[4; 100]), 100);
+ let res = tx.next_bytes().unwrap();
+ assert_eq!(res.0, 0);
+ assert_eq!(res.1.len(), 100);
+ tx.mark_as_sent(0, 100);
+ let res = tx.next_bytes();
+ assert_eq!(res, None);
+
+ tx.mark_as_acked(0, 100);
+ let res = tx.next_bytes();
+ assert_eq!(res, None);
+ }
+
+ #[test]
+ fn send_stream_writable_event_gen() {
+ let conn_fc = connection_fc(2);
+ let mut conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(4.into(), 0, Rc::clone(&conn_fc), conn_events.clone());
+
+ // Stream is initially blocked (conn:2, stream:0)
+ // and will not accept data.
+ assert_eq!(s.send(b"hi").unwrap(), 0);
+
+ // increasing to (conn:2, stream:2) will allow 2 bytes, and also
+ // generate a SendStreamWritable event.
+ s.set_max_stream_data(2);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 1);
+ assert!(matches!(
+ evts[0],
+ ConnectionEvent::SendStreamWritable { .. }
+ ));
+ assert_eq!(s.send(b"hello").unwrap(), 2);
+
+ // increasing to (conn:2, stream:4) will not generate an event or allow
+ // sending anything.
+ s.set_max_stream_data(4);
+ assert_eq!(conn_events.events().count(), 0);
+ assert_eq!(s.send(b"hello").unwrap(), 0);
+
+ // Increasing conn max (conn:4, stream:4) will unblock but not emit
+ // event b/c that happens in Connection::emit_frame() (tested in
+ // connection.rs)
+ assert!(conn_fc.borrow_mut().update(4));
+ assert_eq!(conn_events.events().count(), 0);
+ assert_eq!(s.avail(), 2);
+ assert_eq!(s.send(b"hello").unwrap(), 2);
+
+ // No event because still blocked by conn
+ s.set_max_stream_data(1_000_000_000);
+ assert_eq!(conn_events.events().count(), 0);
+
+ // No event because happens in emit_frame()
+ conn_fc.borrow_mut().update(1_000_000_000);
+ assert_eq!(conn_events.events().count(), 0);
+
+ // Unblocking both by a large amount will cause avail() to be limited by
+ // tx buffer size.
+ assert_eq!(s.avail(), SEND_BUFFER_SIZE - 4);
+
+ assert_eq!(
+ s.send(&[b'a'; SEND_BUFFER_SIZE]).unwrap(),
+ SEND_BUFFER_SIZE - 4
+ );
+
+ // No event because still blocked by tx buffer full
+ s.set_max_stream_data(2_000_000_000);
+ assert_eq!(conn_events.events().count(), 0);
+ assert_eq!(s.send(b"hello").unwrap(), 0);
+ }
+
+ #[test]
+ fn send_stream_writable_event_new_stream() {
+ let conn_fc = connection_fc(2);
+ let mut conn_events = ConnectionEvents::default();
+
+ let _s = SendStream::new(4.into(), 100, conn_fc, conn_events.clone());
+
+ // Creating a new stream with conn and stream credits should result in
+ // an event.
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 1);
+ assert!(matches!(
+ evts[0],
+ ConnectionEvent::SendStreamWritable { .. }
+ ));
+ }
+
+ fn as_stream_token(t: &RecoveryToken) -> &SendStreamRecoveryToken {
+ if let RecoveryToken::Stream(StreamRecoveryToken::Stream(rt)) = &t {
+ rt
+ } else {
+ panic!();
+ }
+ }
+
+ #[test]
+ // Verify lost frames handle fin properly
+ fn send_stream_get_frame_data() {
+ let conn_fc = connection_fc(100);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(0.into(), 100, conn_fc, conn_events);
+ s.send(&[0; 10]).unwrap();
+ s.close();
+
+ let mut ss = SendStreams::default();
+ ss.insert(StreamId::from(0), s);
+
+ let mut tokens = Vec::new();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+
+ // Write a small frame: no fin.
+ let written = builder.len();
+ builder.set_limit(written + 6);
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ assert_eq!(builder.len(), written + 6);
+ assert_eq!(tokens.len(), 1);
+ let f1_token = tokens.remove(0);
+ assert!(!as_stream_token(&f1_token).fin);
+
+ // Write the rest: fin.
+ let written = builder.len();
+ builder.set_limit(written + 200);
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ assert_eq!(builder.len(), written + 10);
+ assert_eq!(tokens.len(), 1);
+ let f2_token = tokens.remove(0);
+ assert!(as_stream_token(&f2_token).fin);
+
+ // Should be no more data to frame.
+ let written = builder.len();
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ assert_eq!(builder.len(), written);
+ assert!(tokens.is_empty());
+
+ // Mark frame 1 as lost
+ ss.lost(as_stream_token(&f1_token));
+
+ // Next frame should not set fin even though stream has fin but frame
+ // does not include end of stream
+ let written = builder.len();
+ ss.write_frames(
+ TransmissionPriority::default() + RetransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ assert_eq!(builder.len(), written + 7); // Needs a length this time.
+ assert_eq!(tokens.len(), 1);
+ let f4_token = tokens.remove(0);
+ assert!(!as_stream_token(&f4_token).fin);
+
+ // Mark frame 2 as lost
+ ss.lost(as_stream_token(&f2_token));
+
+ // Next frame should set fin because it includes end of stream
+ let written = builder.len();
+ ss.write_frames(
+ TransmissionPriority::default() + RetransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ assert_eq!(builder.len(), written + 10);
+ assert_eq!(tokens.len(), 1);
+ let f5_token = tokens.remove(0);
+ assert!(as_stream_token(&f5_token).fin);
+ }
+
+ #[test]
+ #[allow(clippy::cognitive_complexity)]
+ // Verify lost frames handle fin properly with zero length fin
+ fn send_stream_get_frame_zerolength_fin() {
+ let conn_fc = connection_fc(100);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(0.into(), 100, conn_fc, conn_events);
+ s.send(&[0; 10]).unwrap();
+
+ let mut ss = SendStreams::default();
+ ss.insert(StreamId::from(0), s);
+
+ let mut tokens = Vec::new();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ let f1_token = tokens.remove(0);
+ assert_eq!(as_stream_token(&f1_token).offset, 0);
+ assert_eq!(as_stream_token(&f1_token).length, 10);
+ assert!(!as_stream_token(&f1_token).fin);
+
+ // Should be no more data to frame
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ assert!(tokens.is_empty());
+
+ ss.get_mut(StreamId::from(0)).unwrap().close();
+
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ let f2_token = tokens.remove(0);
+ assert_eq!(as_stream_token(&f2_token).offset, 10);
+ assert_eq!(as_stream_token(&f2_token).length, 0);
+ assert!(as_stream_token(&f2_token).fin);
+
+ // Mark frame 2 as lost
+ ss.lost(as_stream_token(&f2_token));
+
+ // Next frame should set fin
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ let f3_token = tokens.remove(0);
+ assert_eq!(as_stream_token(&f3_token).offset, 10);
+ assert_eq!(as_stream_token(&f3_token).length, 0);
+ assert!(as_stream_token(&f3_token).fin);
+
+ // Mark frame 1 as lost
+ ss.lost(as_stream_token(&f1_token));
+
+ // Next frame should set fin and include all data
+ ss.write_frames(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut FrameStats::default(),
+ );
+ let f4_token = tokens.remove(0);
+ assert_eq!(as_stream_token(&f4_token).offset, 0);
+ assert_eq!(as_stream_token(&f4_token).length, 10);
+ assert!(as_stream_token(&f4_token).fin);
+ }
+
+ #[test]
+ fn data_blocked() {
+ let conn_fc = connection_fc(5);
+ let conn_events = ConnectionEvents::default();
+
+ let stream_id = StreamId::from(4);
+ let mut s = SendStream::new(stream_id, 2, Rc::clone(&conn_fc), conn_events);
+
+ // Only two bytes can be sent due to the stream limit.
+ assert_eq!(s.send(b"abc").unwrap(), 2);
+ assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..])));
+
+ // This doesn't report blocking yet.
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut tokens = Vec::new();
+ let mut stats = FrameStats::default();
+ s.write_blocked_frame(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.stream_data_blocked, 0);
+
+ // Blocking is reported after sending the last available credit.
+ s.mark_as_sent(0, 2, false);
+ s.write_blocked_frame(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.stream_data_blocked, 1);
+
+ // Now increase the stream limit and test the connection limit.
+ s.set_max_stream_data(10);
+
+ assert_eq!(s.send(b"abcd").unwrap(), 3);
+ assert_eq!(s.next_bytes(false), Some((2, &b"abc"[..])));
+ // DATA_BLOCKED is not sent yet.
+ conn_fc
+ .borrow_mut()
+ .write_frames(&mut builder, &mut tokens, &mut stats);
+ assert_eq!(stats.data_blocked, 0);
+
+ // DATA_BLOCKED is queued once bytes using all credit are sent.
+ s.mark_as_sent(2, 3, false);
+ conn_fc
+ .borrow_mut()
+ .write_frames(&mut builder, &mut tokens, &mut stats);
+ assert_eq!(stats.data_blocked, 1);
+ }
+
+ #[test]
+ fn data_blocked_atomic() {
+ let conn_fc = connection_fc(5);
+ let conn_events = ConnectionEvents::default();
+
+ let stream_id = StreamId::from(4);
+ let mut s = SendStream::new(stream_id, 2, Rc::clone(&conn_fc), conn_events);
+
+ // Stream is initially blocked (conn:5, stream:2)
+ // and will not accept atomic write of 3 bytes.
+ assert_eq!(s.send_atomic(b"abc").unwrap(), 0);
+
+ // Assert that STREAM_DATA_BLOCKED is sent.
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut tokens = Vec::new();
+ let mut stats = FrameStats::default();
+ s.write_blocked_frame(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.stream_data_blocked, 1);
+
+ // Assert that a non-atomic write works.
+ assert_eq!(s.send(b"abc").unwrap(), 2);
+ assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..])));
+ s.mark_as_sent(0, 2, false);
+
+ // Set limits to (conn:5, stream:10).
+ s.set_max_stream_data(10);
+
+ // An atomic write of 4 bytes exceeds the remaining limit of 3.
+ assert_eq!(s.send_atomic(b"abcd").unwrap(), 0);
+
+ // Assert that DATA_BLOCKED is sent.
+ conn_fc
+ .borrow_mut()
+ .write_frames(&mut builder, &mut tokens, &mut stats);
+ assert_eq!(stats.data_blocked, 1);
+
+ // Check that a non-atomic write works.
+ assert_eq!(s.send(b"abcd").unwrap(), 3);
+ assert_eq!(s.next_bytes(false), Some((2, &b"abc"[..])));
+ s.mark_as_sent(2, 3, false);
+
+ // Increase limits to (conn:15, stream:15).
+ s.set_max_stream_data(15);
+ conn_fc.borrow_mut().update(15);
+
+ // Check that atomic writing right up to the limit works.
+ assert_eq!(s.send_atomic(b"abcdefghij").unwrap(), 10);
+ }
+
+ #[test]
+ fn ack_fin_first() {
+ const MESSAGE: &[u8] = b"hello";
+ let len_u64 = u64::try_from(MESSAGE.len()).unwrap();
+
+ let conn_fc = connection_fc(len_u64);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(StreamId::new(100), 0, conn_fc, conn_events);
+ s.set_max_stream_data(len_u64);
+
+ // Send all the data, then the fin.
+ _ = s.send(MESSAGE).unwrap();
+ s.mark_as_sent(0, MESSAGE.len(), false);
+ s.close();
+ s.mark_as_sent(len_u64, 0, true);
+
+ // Ack the fin, then the data.
+ s.mark_as_acked(len_u64, 0, true);
+ s.mark_as_acked(0, MESSAGE.len(), false);
+ assert!(s.is_terminal());
+ }
+
+ #[test]
+ fn ack_then_lose_fin() {
+ const MESSAGE: &[u8] = b"hello";
+ let len_u64 = u64::try_from(MESSAGE.len()).unwrap();
+
+ let conn_fc = connection_fc(len_u64);
+ let conn_events = ConnectionEvents::default();
+
+ let id = StreamId::new(100);
+ let mut s = SendStream::new(id, 0, conn_fc, conn_events);
+ s.set_max_stream_data(len_u64);
+
+ // Send all the data, then the fin.
+ _ = s.send(MESSAGE).unwrap();
+ s.mark_as_sent(0, MESSAGE.len(), false);
+ s.close();
+ s.mark_as_sent(len_u64, 0, true);
+
+ // Ack the fin, then mark it lost.
+ s.mark_as_acked(len_u64, 0, true);
+ s.mark_as_lost(len_u64, 0, true);
+
+ // No frame should be sent here.
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut tokens = Vec::new();
+ let mut stats = FrameStats::default();
+ s.write_stream_frame(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.stream, 0);
+ }
+
+ /// Create a `SendStream` and force it into a state where it believes that
+ /// `offset` bytes have already been sent and acknowledged.
+ fn stream_with_sent(stream: u64, offset: usize) -> SendStream {
+ const MAX_VARINT: u64 = (1 << 62) - 1;
+
+ let conn_fc = connection_fc(MAX_VARINT);
+ let mut s = SendStream::new(
+ StreamId::from(stream),
+ MAX_VARINT,
+ conn_fc,
+ ConnectionEvents::default(),
+ );
+
+ let mut send_buf = TxBuffer::new();
+ send_buf.retired = u64::try_from(offset).unwrap();
+ send_buf.ranges.mark_range(0, offset, RangeState::Acked);
+ let mut fc = SenderFlowControl::new(StreamId::from(stream), MAX_VARINT);
+ fc.consume(offset);
+ let conn_fc = Rc::new(RefCell::new(SenderFlowControl::new((), MAX_VARINT)));
+ s.state = SendStreamState::Send {
+ fc,
+ conn_fc,
+ send_buf,
+ };
+ s
+ }
+
+ fn frame_sent_sid(stream: u64, offset: usize, len: usize, fin: bool, space: usize) -> bool {
+ const BUF: &[u8] = &[0x42; 128];
+
+ qtrace!(
+ "frame_sent stream={} offset={} len={} fin={}, space={}",
+ stream,
+ offset,
+ len,
+ fin,
+ space
+ );
+
+ let mut s = stream_with_sent(stream, offset);
+
+ // Now write out the proscribed data and maybe close.
+ if len > 0 {
+ s.send(&BUF[..len]).unwrap();
+ }
+ if fin {
+ s.close();
+ }
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let header_len = builder.len();
+ builder.set_limit(header_len + space);
+
+ let mut tokens = Vec::new();
+ let mut stats = FrameStats::default();
+ s.write_stream_frame(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ qtrace!(
+ "STREAM frame: {}",
+ hex_with_len(&builder.as_ref()[header_len..])
+ );
+ stats.stream > 0
+ }
+
+ fn frame_sent(offset: usize, len: usize, fin: bool, space: usize) -> bool {
+ frame_sent_sid(0, offset, len, fin, space)
+ }
+
+ #[test]
+ fn stream_frame_empty() {
+ // Stream frames with empty data and no fin never work.
+ assert!(!frame_sent(10, 0, false, 2));
+ assert!(!frame_sent(10, 0, false, 3));
+ assert!(!frame_sent(10, 0, false, 4));
+ assert!(!frame_sent(10, 0, false, 5));
+ assert!(!frame_sent(10, 0, false, 100));
+
+ // Empty data with fin is only a problem if there is no space.
+ assert!(!frame_sent(0, 0, true, 1));
+ assert!(frame_sent(0, 0, true, 2));
+ assert!(!frame_sent(10, 0, true, 2));
+ assert!(frame_sent(10, 0, true, 3));
+ assert!(frame_sent(10, 0, true, 4));
+ assert!(frame_sent(10, 0, true, 5));
+ assert!(frame_sent(10, 0, true, 100));
+ }
+
+ #[test]
+ fn stream_frame_minimum() {
+ // Add minimum data
+ assert!(!frame_sent(10, 1, false, 3));
+ assert!(!frame_sent(10, 1, true, 3));
+ assert!(frame_sent(10, 1, false, 4));
+ assert!(frame_sent(10, 1, true, 4));
+ assert!(frame_sent(10, 1, false, 5));
+ assert!(frame_sent(10, 1, true, 5));
+ assert!(frame_sent(10, 1, false, 100));
+ assert!(frame_sent(10, 1, true, 100));
+ }
+
+ #[test]
+ fn stream_frame_more() {
+ // Try more data
+ assert!(!frame_sent(10, 100, false, 3));
+ assert!(!frame_sent(10, 100, true, 3));
+ assert!(frame_sent(10, 100, false, 4));
+ assert!(frame_sent(10, 100, true, 4));
+ assert!(frame_sent(10, 100, false, 5));
+ assert!(frame_sent(10, 100, true, 5));
+ assert!(frame_sent(10, 100, false, 100));
+ assert!(frame_sent(10, 100, true, 100));
+
+ assert!(frame_sent(10, 100, false, 1000));
+ assert!(frame_sent(10, 100, true, 1000));
+ }
+
+ #[test]
+ fn stream_frame_big_id() {
+ // A value that encodes to the largest varint.
+ const BIG: u64 = 1 << 30;
+ const BIGSZ: usize = 1 << 30;
+
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 16));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, true, 16));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 17));
+ assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 17));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 18));
+ assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 18));
+
+ assert!(!frame_sent_sid(BIG, BIGSZ, 1, false, 17));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 1, true, 17));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 18));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 18));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 19));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 19));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 100));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 100));
+ }
+
+ fn stream_frame_at_boundary(data: &[u8]) {
+ fn send_with_extra_capacity(data: &[u8], extra: usize, expect_full: bool) -> Vec<u8> {
+ qtrace!("send_with_extra_capacity {} + {}", data.len(), extra);
+ let mut s = stream_with_sent(0, 0);
+ s.send(data).unwrap();
+ s.close();
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let header_len = builder.len();
+ // Add 2 for the frame type and stream ID, then add the extra.
+ builder.set_limit(header_len + data.len() + 2 + extra);
+ let mut tokens = Vec::new();
+ let mut stats = FrameStats::default();
+ s.write_stream_frame(
+ TransmissionPriority::default(),
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.stream, 1);
+ assert_eq!(builder.is_full(), expect_full);
+ Vec::from(Encoder::from(builder)).split_off(header_len)
+ }
+
+ // The minimum amount of extra space for getting another frame in.
+ let mut enc = Encoder::new();
+ enc.encode_varint(u64::try_from(data.len()).unwrap());
+ let len_buf = Vec::from(enc);
+ let minimum_extra = len_buf.len() + PacketBuilder::MINIMUM_FRAME_SIZE;
+
+ // For anything short of the minimum extra, the frame should fill the packet.
+ for i in 0..minimum_extra {
+ let frame = send_with_extra_capacity(data, i, true);
+ let (header, body) = frame.split_at(2);
+ assert_eq!(header, &[0b1001, 0]);
+ assert_eq!(body, data);
+ }
+
+ // Once there is space for another packet AND a length field,
+ // then a length will be added.
+ let frame = send_with_extra_capacity(data, minimum_extra, false);
+ let (header, rest) = frame.split_at(2);
+ assert_eq!(header, &[0b1011, 0]);
+ let (len, body) = rest.split_at(len_buf.len());
+ assert_eq!(len, &len_buf);
+ assert_eq!(body, data);
+ }
+
+ /// 16383/16384 is an odd boundary in STREAM frame construction.
+ /// That is the boundary where a length goes from 2 bytes to 4 bytes.
+ /// Test that we correctly add a length field to the frame; and test
+ /// that if we don't, then we don't allow other frames to be added.
+ #[test]
+ fn stream_frame_16384() {
+ stream_frame_at_boundary(&[4; 16383]);
+ stream_frame_at_boundary(&[4; 16384]);
+ }
+
+ /// 63/64 is the other odd boundary.
+ #[test]
+ fn stream_frame_64() {
+ stream_frame_at_boundary(&[2; 63]);
+ stream_frame_at_boundary(&[2; 64]);
+ }
+
+ fn check_stats(
+ stream: &SendStream,
+ expected_written: u64,
+ expected_sent: u64,
+ expected_acked: u64,
+ ) {
+ let stream_stats = stream.stats();
+ assert_eq!(stream_stats.bytes_written(), expected_written);
+ assert_eq!(stream_stats.bytes_sent(), expected_sent);
+ assert_eq!(stream_stats.bytes_acked(), expected_acked);
+ }
+
+ #[test]
+ fn send_stream_stats() {
+ const MESSAGE: &[u8] = b"hello";
+ let len_u64 = u64::try_from(MESSAGE.len()).unwrap();
+
+ let conn_fc = connection_fc(len_u64);
+ let conn_events = ConnectionEvents::default();
+
+ let id = StreamId::new(100);
+ let mut s = SendStream::new(id, 0, conn_fc, conn_events);
+ s.set_max_stream_data(len_u64);
+
+ // Initial stats should be all 0.
+ check_stats(&s, 0, 0, 0);
+ // Adter sending the data, bytes_written should be increased.
+ _ = s.send(MESSAGE).unwrap();
+ check_stats(&s, len_u64, 0, 0);
+
+ // Adter calling mark_as_sent, bytes_sent should be increased.
+ s.mark_as_sent(0, MESSAGE.len(), false);
+ check_stats(&s, len_u64, len_u64, 0);
+
+ s.close();
+ s.mark_as_sent(len_u64, 0, true);
+
+ // In the end, check bytes_acked.
+ s.mark_as_acked(0, MESSAGE.len(), false);
+ check_stats(&s, len_u64, len_u64, len_u64);
+
+ s.mark_as_acked(len_u64, 0, true);
+ assert!(s.is_terminal());
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/sender.rs b/third_party/rust/neqo-transport/src/sender.rs
new file mode 100644
index 0000000000..9a00dfc7a7
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/sender.rs
@@ -0,0 +1,130 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+#![allow(clippy::module_name_repetitions)]
+
+use std::{
+ fmt::{self, Debug, Display},
+ time::{Duration, Instant},
+};
+
+use neqo_common::qlog::NeqoQlog;
+
+use crate::{
+ cc::{ClassicCongestionControl, CongestionControl, CongestionControlAlgorithm, Cubic, NewReno},
+ pace::Pacer,
+ rtt::RttEstimate,
+ tracking::SentPacket,
+};
+
+/// The number of packets we allow to burst from the pacer.
+pub const PACING_BURST_SIZE: usize = 2;
+
+#[derive(Debug)]
+pub struct PacketSender {
+ cc: Box<dyn CongestionControl>,
+ pacer: Pacer,
+}
+
+impl Display for PacketSender {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{} {}", self.cc, self.pacer)
+ }
+}
+
+impl PacketSender {
+ #[must_use]
+ pub fn new(
+ alg: CongestionControlAlgorithm,
+ pacing_enabled: bool,
+ mtu: usize,
+ now: Instant,
+ ) -> Self {
+ Self {
+ cc: match alg {
+ CongestionControlAlgorithm::NewReno => {
+ Box::new(ClassicCongestionControl::new(NewReno::default()))
+ }
+ CongestionControlAlgorithm::Cubic => {
+ Box::new(ClassicCongestionControl::new(Cubic::default()))
+ }
+ },
+ pacer: Pacer::new(pacing_enabled, now, mtu * PACING_BURST_SIZE, mtu),
+ }
+ }
+
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.cc.set_qlog(qlog);
+ }
+
+ #[must_use]
+ pub fn cwnd(&self) -> usize {
+ self.cc.cwnd()
+ }
+
+ #[must_use]
+ pub fn cwnd_avail(&self) -> usize {
+ self.cc.cwnd_avail()
+ }
+
+ pub fn on_packets_acked(
+ &mut self,
+ acked_pkts: &[SentPacket],
+ rtt_est: &RttEstimate,
+ now: Instant,
+ ) {
+ self.cc.on_packets_acked(acked_pkts, rtt_est, now);
+ }
+
+ /// Called when packets are lost. Returns true if the congestion window was reduced.
+ pub fn on_packets_lost(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) -> bool {
+ self.cc.on_packets_lost(
+ first_rtt_sample_time,
+ prev_largest_acked_sent,
+ pto,
+ lost_packets,
+ )
+ }
+
+ pub fn discard(&mut self, pkt: &SentPacket) {
+ self.cc.discard(pkt);
+ }
+
+ /// When we migrate, the congestion controller for the previously active path drops
+ /// all bytes in flight.
+ pub fn discard_in_flight(&mut self) {
+ self.cc.discard_in_flight();
+ }
+
+ pub fn on_packet_sent(&mut self, pkt: &SentPacket, rtt: Duration) {
+ self.pacer
+ .spend(pkt.time_sent, rtt, self.cc.cwnd(), pkt.size);
+ self.cc.on_packet_sent(pkt);
+ }
+
+ #[must_use]
+ pub fn next_paced(&self, rtt: Duration) -> Option<Instant> {
+ // Only pace if there are bytes in flight.
+ if self.cc.bytes_in_flight() > 0 {
+ Some(self.pacer.next(rtt, self.cc.cwnd()))
+ } else {
+ None
+ }
+ }
+
+ #[must_use]
+ pub fn recovery_packet(&self) -> bool {
+ self.cc.recovery_packet()
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs
new file mode 100644
index 0000000000..12a7d2f9e0
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/server.rs
@@ -0,0 +1,782 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// This file implements a server that can handle multiple connections.
+
+use std::{
+ cell::RefCell,
+ collections::{HashMap, HashSet, VecDeque},
+ fs::OpenOptions,
+ mem,
+ net::SocketAddr,
+ ops::{Deref, DerefMut},
+ path::PathBuf,
+ rc::{Rc, Weak},
+ time::{Duration, Instant},
+};
+
+use neqo_common::{
+ self as common, event::Provider, hex, qdebug, qerror, qinfo, qlog::NeqoQlog, qtrace, qwarn,
+ timer::Timer, Datagram, Decoder, Role,
+};
+use neqo_crypto::{
+ encode_ech_config, AntiReplay, Cipher, PrivateKey, PublicKey, ZeroRttCheckResult,
+ ZeroRttChecker,
+};
+use qlog::streamer::QlogStreamer;
+
+pub use crate::addr_valid::ValidateAddress;
+use crate::{
+ addr_valid::{AddressValidation, AddressValidationResult},
+ cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef},
+ connection::{Connection, Output, State},
+ packet::{PacketBuilder, PacketType, PublicPacket},
+ ConnectionParameters, Res, Version,
+};
+
+pub enum InitialResult {
+ Accept,
+ Drop,
+ Retry(Vec<u8>),
+}
+
+/// MIN_INITIAL_PACKET_SIZE is the smallest packet that can be used to establish
+/// a new connection across all QUIC versions this server supports.
+const MIN_INITIAL_PACKET_SIZE: usize = 1200;
+/// The size of timer buckets. This is higher than the actual timer granularity
+/// as this depends on there being some distribution of events.
+const TIMER_GRANULARITY: Duration = Duration::from_millis(4);
+/// The number of buckets in the timer. As mentioned in the definition of `Timer`,
+/// the granularity and capacity need to multiply to be larger than the largest
+/// delay that might be used. That's the idle timeout (currently 30s).
+const TIMER_CAPACITY: usize = 16384;
+
+type StateRef = Rc<RefCell<ServerConnectionState>>;
+type ConnectionTableRef = Rc<RefCell<HashMap<ConnectionId, StateRef>>>;
+
+#[derive(Debug)]
+pub struct ServerConnectionState {
+ c: Connection,
+ active_attempt: Option<AttemptKey>,
+ last_timer: Instant,
+}
+
+impl Deref for ServerConnectionState {
+ type Target = Connection;
+ fn deref(&self) -> &Self::Target {
+ &self.c
+ }
+}
+
+impl DerefMut for ServerConnectionState {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.c
+ }
+}
+
+/// A `AttemptKey` is used to disambiguate connection attempts.
+/// Multiple connection attempts with the same key won't produce multiple connections.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+struct AttemptKey {
+ // Using the remote address is sufficient for disambiguation,
+ // until we support multiple local socket addresses.
+ remote_address: SocketAddr,
+ odcid: ConnectionId,
+}
+
+/// A `ServerZeroRttChecker` is a simple wrapper around a single checker.
+/// It uses `RefCell` so that the wrapped checker can be shared between
+/// multiple connections created by the server.
+#[derive(Clone, Debug)]
+struct ServerZeroRttChecker {
+ checker: Rc<RefCell<Box<dyn ZeroRttChecker>>>,
+}
+
+impl ServerZeroRttChecker {
+ pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self {
+ Self {
+ checker: Rc::new(RefCell::new(checker)),
+ }
+ }
+}
+
+impl ZeroRttChecker for ServerZeroRttChecker {
+ fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
+ self.checker.borrow().check(token)
+ }
+}
+
+/// `InitialDetails` holds important information for processing `Initial` packets.
+struct InitialDetails {
+ src_cid: ConnectionId,
+ dst_cid: ConnectionId,
+ token: Vec<u8>,
+ version: Version,
+}
+
+impl InitialDetails {
+ fn new(packet: &PublicPacket) -> Self {
+ Self {
+ src_cid: ConnectionId::from(packet.scid()),
+ dst_cid: ConnectionId::from(packet.dcid()),
+ token: packet.token().to_vec(),
+ version: packet.version().unwrap(),
+ }
+ }
+}
+
+struct EchConfig {
+ config: u8,
+ public_name: String,
+ sk: PrivateKey,
+ pk: PublicKey,
+ encoded: Vec<u8>,
+}
+
+impl EchConfig {
+ fn new(config: u8, public_name: &str, sk: &PrivateKey, pk: &PublicKey) -> Res<Self> {
+ let encoded = encode_ech_config(config, public_name, pk)?;
+ Ok(Self {
+ config,
+ public_name: String::from(public_name),
+ sk: sk.clone(),
+ pk: pk.clone(),
+ encoded,
+ })
+ }
+}
+
+pub struct Server {
+ /// The names of certificates.
+ certs: Vec<String>,
+ /// The ALPN values that the server supports.
+ protocols: Vec<String>,
+ /// The cipher suites that the server supports.
+ ciphers: Vec<Cipher>,
+ /// Anti-replay configuration for 0-RTT.
+ anti_replay: AntiReplay,
+ /// A function for determining if 0-RTT can be accepted.
+ zero_rtt_checker: ServerZeroRttChecker,
+ /// A connection ID generator.
+ cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ /// Connection parameters.
+ conn_params: ConnectionParameters,
+ /// Active connection attempts, keyed by `AttemptKey`. Initial packets with
+ /// the same key are routed to the connection that was first accepted.
+ /// This is cleared out when the connection is closed or established.
+ active_attempts: HashMap<AttemptKey, StateRef>,
+ /// All connections, keyed by ConnectionId.
+ connections: ConnectionTableRef,
+ /// The connections that have new events.
+ active: HashSet<ActiveConnectionRef>,
+ /// The set of connections that need immediate processing.
+ waiting: VecDeque<StateRef>,
+ /// Outstanding timers for connections.
+ timers: Timer<StateRef>,
+ /// Address validation logic, which determines whether we send a Retry.
+ address_validation: Rc<RefCell<AddressValidation>>,
+ /// Directory to create qlog traces in
+ qlog_dir: Option<PathBuf>,
+ /// Encrypted client hello (ECH) configuration.
+ ech_config: Option<EchConfig>,
+}
+
+impl Server {
+ /// Construct a new server.
+ /// * `now` is the time that the server is instantiated.
+ /// * `certs` is a list of the certificates that should be configured.
+ /// * `protocols` is the preference list of ALPN values.
+ /// * `anti_replay` is an anti-replay context.
+ /// * `zero_rtt_checker` determines whether 0-RTT should be accepted. This will be passed the
+ /// value of the `extra` argument that was passed to `Connection::send_ticket` to see if it is
+ /// OK.
+ /// * `cid_generator` is responsible for generating connection IDs and parsing them; connection
+ /// IDs produced by the manager cannot be zero-length.
+ pub fn new(
+ now: Instant,
+ certs: &[impl AsRef<str>],
+ protocols: &[impl AsRef<str>],
+ anti_replay: AntiReplay,
+ zero_rtt_checker: Box<dyn ZeroRttChecker>,
+ cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ conn_params: ConnectionParameters,
+ ) -> Res<Self> {
+ let validation = AddressValidation::new(now, ValidateAddress::Never)?;
+ Ok(Self {
+ certs: certs.iter().map(|x| String::from(x.as_ref())).collect(),
+ protocols: protocols.iter().map(|x| String::from(x.as_ref())).collect(),
+ ciphers: Vec::new(),
+ anti_replay,
+ zero_rtt_checker: ServerZeroRttChecker::new(zero_rtt_checker),
+ cid_generator,
+ conn_params,
+ active_attempts: HashMap::default(),
+ connections: Rc::default(),
+ active: HashSet::default(),
+ waiting: VecDeque::default(),
+ timers: Timer::new(now, TIMER_GRANULARITY, TIMER_CAPACITY),
+ address_validation: Rc::new(RefCell::new(validation)),
+ qlog_dir: None,
+ ech_config: None,
+ })
+ }
+
+ /// Set or clear directory to create logs of connection events in QLOG format.
+ pub fn set_qlog_dir(&mut self, dir: Option<PathBuf>) {
+ self.qlog_dir = dir;
+ }
+
+ /// Set the policy for address validation.
+ pub fn set_validation(&mut self, v: ValidateAddress) {
+ self.address_validation.borrow_mut().set_validation(v);
+ }
+
+ /// Set the cipher suites that should be used. Set an empty value to use
+ /// default values.
+ pub fn set_ciphers(&mut self, ciphers: impl AsRef<[Cipher]>) {
+ self.ciphers = Vec::from(ciphers.as_ref());
+ }
+
+ pub fn enable_ech(
+ &mut self,
+ config: u8,
+ public_name: &str,
+ sk: &PrivateKey,
+ pk: &PublicKey,
+ ) -> Res<()> {
+ self.ech_config = Some(EchConfig::new(config, public_name, sk, pk)?);
+ Ok(())
+ }
+
+ pub fn ech_config(&self) -> &[u8] {
+ self.ech_config.as_ref().map_or(&[], |cfg| &cfg.encoded)
+ }
+
+ fn remove_timer(&mut self, c: &StateRef) {
+ let last = c.borrow().last_timer;
+ self.timers.remove(last, |t| Rc::ptr_eq(t, c));
+ }
+
+ fn process_connection(
+ &mut self,
+ c: StateRef,
+ dgram: Option<&Datagram>,
+ now: Instant,
+ ) -> Option<Datagram> {
+ qtrace!([self], "Process connection {:?}", c);
+ let out = c.borrow_mut().process(dgram, now);
+ match out {
+ Output::Datagram(_) => {
+ qtrace!([self], "Sending packet, added to waiting connections");
+ self.waiting.push_back(Rc::clone(&c));
+ }
+ Output::Callback(delay) => {
+ let next = now + delay;
+ if next != c.borrow().last_timer {
+ qtrace!([self], "Change timer to {:?}", next);
+ self.remove_timer(&c);
+ c.borrow_mut().last_timer = next;
+ self.timers.add(next, Rc::clone(&c));
+ }
+ }
+ Output::None => {
+ self.remove_timer(&c);
+ }
+ }
+ if c.borrow().has_events() {
+ qtrace!([self], "Connection active: {:?}", c);
+ self.active.insert(ActiveConnectionRef { c: Rc::clone(&c) });
+ }
+
+ if *c.borrow().state() > State::Handshaking {
+ // Remove any active connection attempt now that this is no longer handshaking.
+ if let Some(k) = c.borrow_mut().active_attempt.take() {
+ self.active_attempts.remove(&k);
+ }
+ }
+
+ if matches!(c.borrow().state(), State::Closed(_)) {
+ c.borrow_mut().set_qlog(NeqoQlog::disabled());
+ self.connections
+ .borrow_mut()
+ .retain(|_, v| !Rc::ptr_eq(v, &c));
+ }
+ out.dgram()
+ }
+
+ fn connection(&self, cid: ConnectionIdRef) -> Option<StateRef> {
+ self.connections.borrow().get(&cid[..]).map(Rc::clone)
+ }
+
+ fn handle_initial(
+ &mut self,
+ initial: InitialDetails,
+ dgram: &Datagram,
+ now: Instant,
+ ) -> Option<Datagram> {
+ qdebug!([self], "Handle initial");
+ let res = self
+ .address_validation
+ .borrow()
+ .validate(&initial.token, dgram.source(), now);
+ match res {
+ AddressValidationResult::Invalid => None,
+ AddressValidationResult::Pass => self.connection_attempt(initial, dgram, None, now),
+ AddressValidationResult::ValidRetry(orig_dcid) => {
+ self.connection_attempt(initial, dgram, Some(orig_dcid), now)
+ }
+ AddressValidationResult::Validate => {
+ qinfo!([self], "Send retry for {:?}", initial.dst_cid);
+
+ let res = self.address_validation.borrow().generate_retry_token(
+ &initial.dst_cid,
+ dgram.source(),
+ now,
+ );
+ let Ok(token) = res else {
+ qerror!([self], "unable to generate token, dropping packet");
+ return None;
+ };
+ if let Some(new_dcid) = self.cid_generator.borrow_mut().generate_cid() {
+ let packet = PacketBuilder::retry(
+ initial.version,
+ &initial.src_cid,
+ &new_dcid,
+ &token,
+ &initial.dst_cid,
+ );
+ if let Ok(p) = packet {
+ let retry = Datagram::new(
+ dgram.destination(),
+ dgram.source(),
+ dgram.tos(),
+ dgram.ttl(),
+ p,
+ );
+ Some(retry)
+ } else {
+ qerror!([self], "unable to encode retry, dropping packet");
+ None
+ }
+ } else {
+ qerror!([self], "no connection ID for retry, dropping packet");
+ None
+ }
+ }
+ }
+ }
+
+ fn connection_attempt(
+ &mut self,
+ initial: InitialDetails,
+ dgram: &Datagram,
+ orig_dcid: Option<ConnectionId>,
+ now: Instant,
+ ) -> Option<Datagram> {
+ let attempt_key = AttemptKey {
+ remote_address: dgram.source(),
+ odcid: orig_dcid.as_ref().unwrap_or(&initial.dst_cid).clone(),
+ };
+ if let Some(c) = self.active_attempts.get(&attempt_key) {
+ qdebug!(
+ [self],
+ "Handle Initial for existing connection attempt {:?}",
+ attempt_key
+ );
+ let c = Rc::clone(c);
+ self.process_connection(c, Some(dgram), now)
+ } else {
+ self.accept_connection(attempt_key, initial, dgram, orig_dcid, now)
+ }
+ }
+
+ fn create_qlog_trace(&self, odcid: ConnectionIdRef<'_>) -> NeqoQlog {
+ if let Some(qlog_dir) = &self.qlog_dir {
+ let mut qlog_path = qlog_dir.to_path_buf();
+
+ qlog_path.push(format!("{}.qlog", odcid));
+
+ // The original DCID is chosen by the client. Using create_new()
+ // prevents attackers from overwriting existing logs.
+ match OpenOptions::new()
+ .write(true)
+ .create_new(true)
+ .open(&qlog_path)
+ {
+ Ok(f) => {
+ qinfo!("Qlog output to {}", qlog_path.display());
+
+ let streamer = QlogStreamer::new(
+ qlog::QLOG_VERSION.to_string(),
+ Some("Neqo server qlog".to_string()),
+ Some("Neqo server qlog".to_string()),
+ None,
+ std::time::Instant::now(),
+ common::qlog::new_trace(Role::Server),
+ qlog::events::EventImportance::Base,
+ Box::new(f),
+ );
+ let n_qlog = NeqoQlog::enabled(streamer, qlog_path);
+ match n_qlog {
+ Ok(nql) => nql,
+ Err(e) => {
+ // Keep going but w/o qlogging
+ qerror!("NeqoQlog error: {}", e);
+ NeqoQlog::disabled()
+ }
+ }
+ }
+ Err(e) => {
+ qerror!(
+ "Could not open file {} for qlog output: {}",
+ qlog_path.display(),
+ e
+ );
+ NeqoQlog::disabled()
+ }
+ }
+ } else {
+ NeqoQlog::disabled()
+ }
+ }
+
+ fn setup_connection(
+ &mut self,
+ c: &mut Connection,
+ attempt_key: &AttemptKey,
+ initial: InitialDetails,
+ orig_dcid: Option<ConnectionId>,
+ ) {
+ let zcheck = self.zero_rtt_checker.clone();
+ if c.server_enable_0rtt(&self.anti_replay, zcheck).is_err() {
+ qwarn!([self], "Unable to enable 0-RTT");
+ }
+ if let Some(odcid) = orig_dcid {
+ // There was a retry, so set the connection IDs for.
+ c.set_retry_cids(odcid, initial.src_cid, initial.dst_cid);
+ }
+ c.set_validation(Rc::clone(&self.address_validation));
+ c.set_qlog(self.create_qlog_trace(attempt_key.odcid.as_cid_ref()));
+ if let Some(cfg) = &self.ech_config {
+ if c.server_enable_ech(cfg.config, &cfg.public_name, &cfg.sk, &cfg.pk)
+ .is_err()
+ {
+ qwarn!([self], "Unable to enable ECH");
+ }
+ }
+ }
+
+ fn accept_connection(
+ &mut self,
+ attempt_key: AttemptKey,
+ initial: InitialDetails,
+ dgram: &Datagram,
+ orig_dcid: Option<ConnectionId>,
+ now: Instant,
+ ) -> Option<Datagram> {
+ qinfo!([self], "Accept connection {:?}", attempt_key);
+ // The internal connection ID manager that we use is not used directly.
+ // Instead, wrap it so that we can save connection IDs.
+
+ let cid_mgr = Rc::new(RefCell::new(ServerConnectionIdGenerator {
+ c: Weak::new(),
+ cid_generator: Rc::clone(&self.cid_generator),
+ connections: Rc::clone(&self.connections),
+ saved_cids: Vec::new(),
+ }));
+
+ let mut params = self.conn_params.clone();
+ params.get_versions_mut().set_initial(initial.version);
+ let sconn = Connection::new_server(
+ &self.certs,
+ &self.protocols,
+ Rc::clone(&cid_mgr) as _,
+ params,
+ );
+
+ match sconn {
+ Ok(mut c) => {
+ self.setup_connection(&mut c, &attempt_key, initial, orig_dcid);
+ let c = Rc::new(RefCell::new(ServerConnectionState {
+ c,
+ last_timer: now,
+ active_attempt: Some(attempt_key.clone()),
+ }));
+ cid_mgr.borrow_mut().set_connection(Rc::clone(&c));
+ let previous_attempt = self.active_attempts.insert(attempt_key, Rc::clone(&c));
+ debug_assert!(previous_attempt.is_none());
+ self.process_connection(c, Some(dgram), now)
+ }
+ Err(e) => {
+ qwarn!([self], "Unable to create connection");
+ if e == crate::Error::VersionNegotiation {
+ crate::qlog::server_version_information_failed(
+ &mut self.create_qlog_trace(attempt_key.odcid.as_cid_ref()),
+ self.conn_params.get_versions().all(),
+ initial.version.wire_version(),
+ )
+ }
+ None
+ }
+ }
+ }
+
+ /// Handle 0-RTT packets that were sent with the client's choice of connection ID.
+ /// Most 0-RTT will arrive this way. A client can usually send 1-RTT after it
+ /// receives a connection ID from the server.
+ fn handle_0rtt(
+ &mut self,
+ dgram: &Datagram,
+ dcid: ConnectionId,
+ now: Instant,
+ ) -> Option<Datagram> {
+ let attempt_key = AttemptKey {
+ remote_address: dgram.source(),
+ odcid: dcid,
+ };
+ if let Some(c) = self.active_attempts.get(&attempt_key) {
+ qdebug!(
+ [self],
+ "Handle 0-RTT for existing connection attempt {:?}",
+ attempt_key
+ );
+ let c = Rc::clone(c);
+ self.process_connection(c, Some(dgram), now)
+ } else {
+ qdebug!([self], "Dropping 0-RTT for unknown connection");
+ None
+ }
+ }
+
+ fn process_input(&mut self, dgram: &Datagram, now: Instant) -> Option<Datagram> {
+ qtrace!("Process datagram: {}", hex(&dgram[..]));
+
+ // This is only looking at the first packet header in the datagram.
+ // All packets in the datagram are routed to the same connection.
+ let res = PublicPacket::decode(&dgram[..], self.cid_generator.borrow().as_decoder());
+ let Ok((packet, _remainder)) = res else {
+ qtrace!([self], "Discarding {:?}", dgram);
+ return None;
+ };
+
+ // Finding an existing connection. Should be the most common case.
+ if let Some(c) = self.connection(packet.dcid()) {
+ return self.process_connection(c, Some(dgram), now);
+ }
+
+ if packet.packet_type() == PacketType::Short {
+ // TODO send a stateless reset here.
+ qtrace!([self], "Short header packet for an unknown connection");
+ return None;
+ }
+
+ if packet.packet_type() == PacketType::OtherVersion
+ || (packet.packet_type() == PacketType::Initial
+ && !self
+ .conn_params
+ .get_versions()
+ .all()
+ .contains(&packet.version().unwrap()))
+ {
+ if dgram.len() < MIN_INITIAL_PACKET_SIZE {
+ qdebug!([self], "Unsupported version: too short");
+ return None;
+ }
+
+ qdebug!([self], "Unsupported version: {:x}", packet.wire_version());
+ let vn = PacketBuilder::version_negotiation(
+ &packet.scid()[..],
+ &packet.dcid()[..],
+ packet.wire_version(),
+ self.conn_params.get_versions().all(),
+ );
+
+ crate::qlog::server_version_information_failed(
+ &mut self.create_qlog_trace(packet.dcid()),
+ self.conn_params.get_versions().all(),
+ packet.wire_version(),
+ );
+
+ return Some(Datagram::new(
+ dgram.destination(),
+ dgram.source(),
+ dgram.tos(),
+ dgram.ttl(),
+ vn,
+ ));
+ }
+
+ match packet.packet_type() {
+ PacketType::Initial => {
+ if dgram.len() < MIN_INITIAL_PACKET_SIZE {
+ qdebug!([self], "Drop initial: too short");
+ return None;
+ }
+ // Copy values from `packet` because they are currently still borrowing from
+ // `dgram`.
+ let initial = InitialDetails::new(&packet);
+ self.handle_initial(initial, dgram, now)
+ }
+ PacketType::ZeroRtt => {
+ let dcid = ConnectionId::from(packet.dcid());
+ self.handle_0rtt(dgram, dcid, now)
+ }
+ PacketType::OtherVersion => unreachable!(),
+ _ => {
+ qtrace!([self], "Not an initial packet");
+ None
+ }
+ }
+ }
+
+ /// Iterate through the pending connections looking for any that might want
+ /// to send a datagram. Stop at the first one that does.
+ fn process_next_output(&mut self, now: Instant) -> Option<Datagram> {
+ qtrace!([self], "No packet to send, look at waiting connections");
+ while let Some(c) = self.waiting.pop_front() {
+ if let Some(d) = self.process_connection(c, None, now) {
+ return Some(d);
+ }
+ }
+ qtrace!([self], "No packet to send still, run timers");
+ while let Some(c) = self.timers.take_next(now) {
+ if let Some(d) = self.process_connection(c, None, now) {
+ return Some(d);
+ }
+ }
+ None
+ }
+
+ fn next_time(&mut self, now: Instant) -> Option<Duration> {
+ if self.waiting.is_empty() {
+ self.timers.next_time().map(|x| x - now)
+ } else {
+ Some(Duration::new(0, 0))
+ }
+ }
+
+ pub fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output {
+ dgram
+ .and_then(|d| self.process_input(d, now))
+ .or_else(|| self.process_next_output(now))
+ .map(|d| {
+ qtrace!([self], "Send packet: {:?}", d);
+ Output::Datagram(d)
+ })
+ .or_else(|| {
+ self.next_time(now).map(|delay| {
+ qtrace!([self], "Wait: {:?}", delay);
+ Output::Callback(delay)
+ })
+ })
+ .unwrap_or_else(|| {
+ qtrace!([self], "Go dormant");
+ Output::None
+ })
+ }
+
+ /// This lists the connections that have received new events
+ /// as a result of calling `process()`.
+ pub fn active_connections(&mut self) -> Vec<ActiveConnectionRef> {
+ mem::take(&mut self.active).into_iter().collect()
+ }
+
+ pub fn add_to_waiting(&mut self, c: ActiveConnectionRef) {
+ self.waiting.push_back(c.connection());
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ActiveConnectionRef {
+ c: StateRef,
+}
+
+impl ActiveConnectionRef {
+ pub fn borrow(&self) -> impl Deref<Target = Connection> + '_ {
+ std::cell::Ref::map(self.c.borrow(), |c| &c.c)
+ }
+
+ pub fn borrow_mut(&mut self) -> impl DerefMut<Target = Connection> + '_ {
+ std::cell::RefMut::map(self.c.borrow_mut(), |c| &mut c.c)
+ }
+
+ pub fn connection(&self) -> StateRef {
+ Rc::clone(&self.c)
+ }
+}
+
+impl std::hash::Hash for ActiveConnectionRef {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ let ptr: *const _ = self.c.as_ref();
+ ptr.hash(state);
+ }
+}
+
+impl PartialEq for ActiveConnectionRef {
+ fn eq(&self, other: &Self) -> bool {
+ Rc::ptr_eq(&self.c, &other.c)
+ }
+}
+
+impl Eq for ActiveConnectionRef {}
+
+struct ServerConnectionIdGenerator {
+ c: Weak<RefCell<ServerConnectionState>>,
+ connections: ConnectionTableRef,
+ cid_generator: Rc<RefCell<dyn ConnectionIdGenerator>>,
+ saved_cids: Vec<ConnectionId>,
+}
+
+impl ServerConnectionIdGenerator {
+ pub fn set_connection(&mut self, c: StateRef) {
+ let saved = std::mem::replace(&mut self.saved_cids, Vec::with_capacity(0));
+ for cid in saved {
+ qtrace!("ServerConnectionIdGenerator inserting saved cid {}", cid);
+ self.insert_cid(cid, Rc::clone(&c));
+ }
+ self.c = Rc::downgrade(&c);
+ }
+
+ fn insert_cid(&mut self, cid: ConnectionId, rc: StateRef) {
+ debug_assert!(!cid.is_empty());
+ self.connections.borrow_mut().insert(cid, rc);
+ }
+}
+
+impl ConnectionIdDecoder for ServerConnectionIdGenerator {
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
+ self.cid_generator.borrow_mut().decode_cid(dec)
+ }
+}
+
+impl ConnectionIdGenerator for ServerConnectionIdGenerator {
+ fn generate_cid(&mut self) -> Option<ConnectionId> {
+ let maybe_cid = self.cid_generator.borrow_mut().generate_cid();
+ if let Some(cid) = maybe_cid {
+ if let Some(rc) = self.c.upgrade() {
+ self.insert_cid(cid.clone(), rc);
+ } else {
+ // This function can be called before the connection is set.
+ // So save any connection IDs until that hookup happens.
+ qtrace!("ServerConnectionIdGenerator saving cid {}", cid);
+ self.saved_cids.push(cid.clone());
+ }
+ Some(cid)
+ } else {
+ None
+ }
+ }
+
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
+ self
+ }
+}
+
+impl ::std::fmt::Display for Server {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Server")
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/stats.rs b/third_party/rust/neqo-transport/src/stats.rs
new file mode 100644
index 0000000000..d6c7a911f9
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/stats.rs
@@ -0,0 +1,235 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of some useful statistics.
+#![deny(clippy::pedantic)]
+
+use std::{
+ cell::RefCell,
+ fmt::{self, Debug},
+ ops::Deref,
+ rc::Rc,
+ time::Duration,
+};
+
+use neqo_common::qinfo;
+
+use crate::packet::PacketNumber;
+
+pub(crate) const MAX_PTO_COUNTS: usize = 16;
+
+#[derive(Default, Clone)]
+#[cfg_attr(test, derive(PartialEq, Eq))]
+#[allow(clippy::module_name_repetitions)]
+pub struct FrameStats {
+ pub all: usize,
+ pub ack: usize,
+ pub largest_acknowledged: PacketNumber,
+
+ pub crypto: usize,
+ pub stream: usize,
+ pub reset_stream: usize,
+ pub stop_sending: usize,
+
+ pub ping: usize,
+ pub padding: usize,
+
+ pub max_streams: usize,
+ pub streams_blocked: usize,
+ pub max_data: usize,
+ pub data_blocked: usize,
+ pub max_stream_data: usize,
+ pub stream_data_blocked: usize,
+
+ pub new_connection_id: usize,
+ pub retire_connection_id: usize,
+
+ pub path_challenge: usize,
+ pub path_response: usize,
+
+ pub connection_close: usize,
+ pub handshake_done: usize,
+ pub new_token: usize,
+
+ pub ack_frequency: usize,
+ pub datagram: usize,
+}
+
+impl Debug for FrameStats {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(
+ f,
+ " crypto {} done {} token {} close {}",
+ self.crypto, self.handshake_done, self.new_token, self.connection_close,
+ )?;
+ writeln!(
+ f,
+ " ack {} (max {}) ping {} padding {}",
+ self.ack, self.largest_acknowledged, self.ping, self.padding
+ )?;
+ writeln!(
+ f,
+ " stream {} reset {} stop {}",
+ self.stream, self.reset_stream, self.stop_sending,
+ )?;
+ writeln!(
+ f,
+ " max: stream {} data {} stream_data {}",
+ self.max_streams, self.max_data, self.max_stream_data,
+ )?;
+ writeln!(
+ f,
+ " blocked: stream {} data {} stream_data {}",
+ self.streams_blocked, self.data_blocked, self.stream_data_blocked,
+ )?;
+ writeln!(f, " datagram {}", self.datagram)?;
+ writeln!(
+ f,
+ " ncid {} rcid {} pchallenge {} presponse {}",
+ self.new_connection_id,
+ self.retire_connection_id,
+ self.path_challenge,
+ self.path_response,
+ )?;
+ writeln!(f, " ack_frequency {}", self.ack_frequency)
+ }
+}
+
+/// Datagram stats
+#[derive(Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct DatagramStats {
+ /// The number of datagrams declared lost.
+ pub lost: usize,
+ /// The number of datagrams dropped due to being too large.
+ pub dropped_too_big: usize,
+ /// The number of datagrams dropped due to reaching the limit of the
+ /// outgoing queue.
+ pub dropped_queue_full: usize,
+}
+
+/// Connection statistics
+#[derive(Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct Stats {
+ info: String,
+
+ /// Total packets received, including all the bad ones.
+ pub packets_rx: usize,
+ /// Duplicate packets received.
+ pub dups_rx: usize,
+ /// Dropped packets or dropped garbage.
+ pub dropped_rx: usize,
+ /// The number of packet that were saved for later processing.
+ pub saved_datagrams: usize,
+
+ /// Total packets sent.
+ pub packets_tx: usize,
+ /// Total number of packets that are declared lost.
+ pub lost: usize,
+ /// Late acknowledgments, for packets that were declared lost already.
+ pub late_ack: usize,
+ /// Acknowledgments for packets that contained data that was marked
+ /// for retransmission when the PTO timer popped.
+ pub pto_ack: usize,
+
+ /// Whether the connection was resumed successfully.
+ pub resumed: bool,
+
+ /// The current, estimated round-trip time on the primary path.
+ pub rtt: Duration,
+ /// The current, estimated round-trip time variation on the primary path.
+ pub rttvar: Duration,
+ /// Whether the first RTT sample was guessed from a discarded packet.
+ pub rtt_init_guess: bool,
+
+ /// Count PTOs. Single PTOs, 2 PTOs in a row, 3 PTOs in row, etc. are counted
+ /// separately.
+ pub pto_counts: [usize; MAX_PTO_COUNTS],
+
+ /// Count frames received.
+ pub frame_rx: FrameStats,
+ /// Count frames sent.
+ pub frame_tx: FrameStats,
+
+ /// The number of incoming datagrams dropped due to reaching the limit
+ /// of the incoming queue.
+ pub incoming_datagram_dropped: usize,
+
+ pub datagram_tx: DatagramStats,
+}
+
+impl Stats {
+ pub fn init(&mut self, info: String) {
+ self.info = info;
+ }
+
+ pub fn pkt_dropped(&mut self, reason: impl AsRef<str>) {
+ self.dropped_rx += 1;
+ qinfo!(
+ [self.info],
+ "Dropped received packet: {}; Total: {}",
+ reason.as_ref(),
+ self.dropped_rx
+ );
+ }
+
+ /// # Panics
+ ///
+ /// When preconditions are violated.
+ pub fn add_pto_count(&mut self, count: usize) {
+ debug_assert!(count > 0);
+ if count >= MAX_PTO_COUNTS {
+ // We can't move this count any further, so stop.
+ return;
+ }
+ self.pto_counts[count - 1] += 1;
+ if count > 1 {
+ debug_assert!(self.pto_counts[count - 2] > 0);
+ self.pto_counts[count - 2] -= 1;
+ }
+ }
+}
+
+impl Debug for Stats {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(f, "stats for {}", self.info)?;
+ writeln!(
+ f,
+ " rx: {} drop {} dup {} saved {}",
+ self.packets_rx, self.dropped_rx, self.dups_rx, self.saved_datagrams
+ )?;
+ writeln!(
+ f,
+ " tx: {} lost {} lateack {} ptoack {}",
+ self.packets_tx, self.lost, self.late_ack, self.pto_ack
+ )?;
+ writeln!(f, " resumed: {} ", self.resumed)?;
+ writeln!(f, " frames rx:")?;
+ self.frame_rx.fmt(f)?;
+ writeln!(f, " frames tx:")?;
+ self.frame_tx.fmt(f)
+ }
+}
+
+#[derive(Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct StatsCell {
+ stats: Rc<RefCell<Stats>>,
+}
+
+impl Deref for StatsCell {
+ type Target = RefCell<Stats>;
+ fn deref(&self) -> &Self::Target {
+ &self.stats
+ }
+}
+
+impl Debug for StatsCell {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ self.stats.borrow().fmt(f)
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/stream_id.rs b/third_party/rust/neqo-transport/src/stream_id.rs
new file mode 100644
index 0000000000..f3b07b86a8
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/stream_id.rs
@@ -0,0 +1,177 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Stream ID and stream index handling.
+
+use neqo_common::Role;
+
+#[derive(PartialEq, Debug, Copy, Clone, PartialOrd, Eq, Ord, Hash)]
+
+/// The type of stream, either Bi-Directional or Uni-Directional.
+pub enum StreamType {
+ BiDi,
+ UniDi,
+}
+
+#[derive(Debug, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)]
+pub struct StreamId(u64);
+
+impl StreamId {
+ pub const fn new(id: u64) -> Self {
+ Self(id)
+ }
+
+ pub fn init(stream_type: StreamType, role: Role) -> Self {
+ let type_val = match stream_type {
+ StreamType::BiDi => 0,
+ StreamType::UniDi => 2,
+ };
+ Self(type_val + Self::role_bit(role))
+ }
+
+ pub fn as_u64(self) -> u64 {
+ self.0
+ }
+
+ pub fn is_bidi(self) -> bool {
+ self.as_u64() & 0x02 == 0
+ }
+
+ pub fn is_uni(self) -> bool {
+ !self.is_bidi()
+ }
+
+ pub fn stream_type(self) -> StreamType {
+ if self.is_bidi() {
+ StreamType::BiDi
+ } else {
+ StreamType::UniDi
+ }
+ }
+
+ pub fn is_client_initiated(self) -> bool {
+ self.as_u64() & 0x01 == 0
+ }
+
+ pub fn is_server_initiated(self) -> bool {
+ !self.is_client_initiated()
+ }
+
+ pub fn role(self) -> Role {
+ if self.is_client_initiated() {
+ Role::Client
+ } else {
+ Role::Server
+ }
+ }
+
+ pub fn is_self_initiated(self, my_role: Role) -> bool {
+ match my_role {
+ Role::Client if self.is_client_initiated() => true,
+ Role::Server if self.is_server_initiated() => true,
+ _ => false,
+ }
+ }
+
+ pub fn is_remote_initiated(self, my_role: Role) -> bool {
+ !self.is_self_initiated(my_role)
+ }
+
+ pub fn is_send_only(self, my_role: Role) -> bool {
+ self.is_uni() && self.is_self_initiated(my_role)
+ }
+
+ pub fn is_recv_only(self, my_role: Role) -> bool {
+ self.is_uni() && self.is_remote_initiated(my_role)
+ }
+
+ pub fn next(&mut self) {
+ self.0 += 4;
+ }
+
+ /// This returns a bit that is shared by all streams created by this role.
+ pub fn role_bit(role: Role) -> u64 {
+ match role {
+ Role::Server => 1,
+ Role::Client => 0,
+ }
+ }
+}
+
+impl From<u64> for StreamId {
+ fn from(val: u64) -> Self {
+ Self::new(val)
+ }
+}
+
+impl From<&u64> for StreamId {
+ fn from(val: &u64) -> Self {
+ Self::new(*val)
+ }
+}
+
+impl PartialEq<u64> for StreamId {
+ fn eq(&self, other: &u64) -> bool {
+ self.as_u64() == *other
+ }
+}
+
+impl AsRef<u64> for StreamId {
+ fn as_ref(&self) -> &u64 {
+ &self.0
+ }
+}
+
+impl ::std::fmt::Display for StreamId {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}", self.as_u64())
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use neqo_common::Role;
+
+ use super::StreamId;
+
+ #[test]
+ fn bidi_stream_properties() {
+ let id1 = StreamId::from(16);
+ assert!(id1.is_bidi());
+ assert!(!id1.is_uni());
+ assert!(id1.is_client_initiated());
+ assert!(!id1.is_server_initiated());
+ assert_eq!(id1.role(), Role::Client);
+ assert!(id1.is_self_initiated(Role::Client));
+ assert!(!id1.is_self_initiated(Role::Server));
+ assert!(!id1.is_remote_initiated(Role::Client));
+ assert!(id1.is_remote_initiated(Role::Server));
+ assert!(!id1.is_send_only(Role::Server));
+ assert!(!id1.is_send_only(Role::Client));
+ assert!(!id1.is_recv_only(Role::Server));
+ assert!(!id1.is_recv_only(Role::Client));
+ assert_eq!(id1.as_u64(), 16);
+ }
+
+ #[test]
+ fn uni_stream_properties() {
+ let id2 = StreamId::from(35);
+ assert!(!id2.is_bidi());
+ assert!(id2.is_uni());
+ assert!(!id2.is_client_initiated());
+ assert!(id2.is_server_initiated());
+ assert_eq!(id2.role(), Role::Server);
+ assert!(!id2.is_self_initiated(Role::Client));
+ assert!(id2.is_self_initiated(Role::Server));
+ assert!(id2.is_remote_initiated(Role::Client));
+ assert!(!id2.is_remote_initiated(Role::Server));
+ assert!(id2.is_send_only(Role::Server));
+ assert!(!id2.is_send_only(Role::Client));
+ assert!(!id2.is_recv_only(Role::Server));
+ assert!(id2.is_recv_only(Role::Client));
+ assert_eq!(id2.as_u64(), 35);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/streams.rs b/third_party/rust/neqo-transport/src/streams.rs
new file mode 100644
index 0000000000..7cbb29ce02
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/streams.rs
@@ -0,0 +1,547 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Stream management for a connection.
+use std::{cell::RefCell, cmp::Ordering, rc::Rc};
+
+use neqo_common::{qtrace, qwarn, Role};
+
+use crate::{
+ fc::{LocalStreamLimits, ReceiverFlowControl, RemoteStreamLimits, SenderFlowControl},
+ frame::Frame,
+ packet::PacketBuilder,
+ recovery::{RecoveryToken, StreamRecoveryToken},
+ recv_stream::{RecvStream, RecvStreams},
+ send_stream::{SendStream, SendStreams, TransmissionPriority},
+ stats::FrameStats,
+ stream_id::{StreamId, StreamType},
+ tparams::{self, TransportParametersHandler},
+ ConnectionEvents, Error, Res,
+};
+
+pub type SendOrder = i64;
+
+#[derive(Copy, Clone)]
+pub struct StreamOrder {
+ pub sendorder: Option<SendOrder>,
+}
+
+// We want highest to lowest, with None being higher than any value
+impl Ord for StreamOrder {
+ fn cmp(&self, other: &Self) -> Ordering {
+ if self.sendorder.is_some() && other.sendorder.is_some() {
+ // We want reverse order (high to low) when both values are specified.
+ other.sendorder.cmp(&self.sendorder)
+ } else {
+ self.sendorder.cmp(&other.sendorder)
+ }
+ }
+}
+
+impl PartialOrd for StreamOrder {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl PartialEq for StreamOrder {
+ fn eq(&self, other: &Self) -> bool {
+ self.sendorder == other.sendorder
+ }
+}
+
+impl Eq for StreamOrder {}
+
+pub struct Streams {
+ role: Role,
+ tps: Rc<RefCell<TransportParametersHandler>>,
+ events: ConnectionEvents,
+ sender_fc: Rc<RefCell<SenderFlowControl<()>>>,
+ receiver_fc: Rc<RefCell<ReceiverFlowControl<()>>>,
+ remote_stream_limits: RemoteStreamLimits,
+ local_stream_limits: LocalStreamLimits,
+ pub(crate) send: SendStreams,
+ pub(crate) recv: RecvStreams,
+}
+
+impl Streams {
+ pub fn new(
+ tps: Rc<RefCell<TransportParametersHandler>>,
+ role: Role,
+ events: ConnectionEvents,
+ ) -> Self {
+ let limit_bidi = tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI);
+ let limit_uni = tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAMS_UNI);
+ let max_data = tps.borrow().local.get_integer(tparams::INITIAL_MAX_DATA);
+ Self {
+ role,
+ tps,
+ events,
+ sender_fc: Rc::new(RefCell::new(SenderFlowControl::new((), 0))),
+ receiver_fc: Rc::new(RefCell::new(ReceiverFlowControl::new((), max_data))),
+ remote_stream_limits: RemoteStreamLimits::new(limit_bidi, limit_uni, role),
+ local_stream_limits: LocalStreamLimits::new(role),
+ send: SendStreams::default(),
+ recv: RecvStreams::default(),
+ }
+ }
+
+ pub fn is_stream_id_allowed(&self, stream_id: StreamId) -> bool {
+ self.remote_stream_limits[stream_id.stream_type()].is_allowed(stream_id)
+ }
+
+ pub fn zero_rtt_rejected(&mut self) {
+ self.clear_streams();
+ debug_assert_eq!(
+ self.remote_stream_limits[StreamType::BiDi].max_active(),
+ self.tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)
+ );
+ debug_assert_eq!(
+ self.remote_stream_limits[StreamType::UniDi].max_active(),
+ self.tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAMS_UNI)
+ );
+ self.local_stream_limits = LocalStreamLimits::new(self.role);
+ }
+
+ pub fn input_frame(&mut self, frame: Frame, stats: &mut FrameStats) -> Res<()> {
+ match frame {
+ Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ final_size,
+ } => {
+ stats.reset_stream += 1;
+ if let (_, Some(rs)) = self.obtain_stream(stream_id)? {
+ rs.reset(application_error_code, final_size)?;
+ }
+ }
+ Frame::StopSending {
+ stream_id,
+ application_error_code,
+ } => {
+ stats.stop_sending += 1;
+ self.events
+ .send_stream_stop_sending(stream_id, application_error_code);
+ if let (Some(ss), _) = self.obtain_stream(stream_id)? {
+ ss.reset(application_error_code);
+ }
+ }
+ Frame::Stream {
+ fin,
+ stream_id,
+ offset,
+ data,
+ ..
+ } => {
+ stats.stream += 1;
+ if let (_, Some(rs)) = self.obtain_stream(stream_id)? {
+ rs.inbound_stream_frame(fin, offset, data)?;
+ }
+ }
+ Frame::MaxData { maximum_data } => {
+ stats.max_data += 1;
+ self.handle_max_data(maximum_data);
+ }
+ Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data,
+ } => {
+ qtrace!(
+ "Stream {} Received MaxStreamData {}",
+ stream_id,
+ maximum_stream_data
+ );
+ stats.max_stream_data += 1;
+ if let (Some(ss), _) = self.obtain_stream(stream_id)? {
+ ss.set_max_stream_data(maximum_stream_data);
+ }
+ }
+ Frame::MaxStreams {
+ stream_type,
+ maximum_streams,
+ } => {
+ stats.max_streams += 1;
+ self.handle_max_streams(stream_type, maximum_streams);
+ }
+ Frame::DataBlocked { data_limit } => {
+ // Should never happen since we set data limit to max
+ qwarn!("Received DataBlocked with data limit {}", data_limit);
+ stats.data_blocked += 1;
+ self.handle_data_blocked();
+ }
+ Frame::StreamDataBlocked { stream_id, .. } => {
+ qtrace!("Received StreamDataBlocked");
+ stats.stream_data_blocked += 1;
+ // Terminate connection with STREAM_STATE_ERROR if send-only
+ // stream (-transport 19.13)
+ if stream_id.is_send_only(self.role) {
+ return Err(Error::StreamStateError);
+ }
+
+ if let (_, Some(rs)) = self.obtain_stream(stream_id)? {
+ rs.send_flowc_update();
+ }
+ }
+ Frame::StreamsBlocked { .. } => {
+ stats.streams_blocked += 1;
+ // We send an update evry time we retire a stream. There is no need to
+ // trigger flow updates here.
+ }
+ _ => unreachable!("This is not a stream Frame"),
+ }
+ Ok(())
+ }
+
+ fn write_maintenance_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ // Send `DATA_BLOCKED` as necessary.
+ self.sender_fc
+ .borrow_mut()
+ .write_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+
+ // Send `MAX_DATA` as necessary.
+ self.receiver_fc
+ .borrow_mut()
+ .write_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+
+ self.recv.write_frames(builder, tokens, stats);
+
+ self.remote_stream_limits[StreamType::BiDi].write_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+ self.remote_stream_limits[StreamType::UniDi].write_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+
+ self.local_stream_limits[StreamType::BiDi].write_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+
+ self.local_stream_limits[StreamType::UniDi].write_frames(builder, tokens, stats);
+ }
+
+ pub fn write_frames(
+ &mut self,
+ priority: TransmissionPriority,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if priority == TransmissionPriority::Important {
+ self.write_maintenance_frames(builder, tokens, stats);
+ if builder.is_full() {
+ return;
+ }
+ }
+
+ self.send.write_frames(priority, builder, tokens, stats);
+ }
+
+ pub fn lost(&mut self, token: &StreamRecoveryToken) {
+ match token {
+ StreamRecoveryToken::Stream(st) => self.send.lost(st),
+ StreamRecoveryToken::ResetStream { stream_id } => self.send.reset_lost(*stream_id),
+ StreamRecoveryToken::StreamDataBlocked { stream_id, limit } => {
+ self.send.blocked_lost(*stream_id, *limit);
+ }
+ StreamRecoveryToken::MaxStreamData {
+ stream_id,
+ max_data,
+ } => {
+ if let Ok((_, Some(rs))) = self.obtain_stream(*stream_id) {
+ rs.max_stream_data_lost(*max_data);
+ }
+ }
+ StreamRecoveryToken::StopSending { stream_id } => {
+ if let Ok((_, Some(rs))) = self.obtain_stream(*stream_id) {
+ rs.stop_sending_lost();
+ }
+ }
+ StreamRecoveryToken::StreamsBlocked { stream_type, limit } => {
+ self.local_stream_limits[*stream_type].frame_lost(*limit);
+ }
+ StreamRecoveryToken::MaxStreams {
+ stream_type,
+ max_streams,
+ } => {
+ self.remote_stream_limits[*stream_type].frame_lost(*max_streams);
+ }
+ StreamRecoveryToken::DataBlocked(limit) => {
+ self.sender_fc.borrow_mut().frame_lost(*limit);
+ }
+ StreamRecoveryToken::MaxData(maximum_data) => {
+ self.receiver_fc.borrow_mut().frame_lost(*maximum_data);
+ }
+ }
+ }
+
+ pub fn acked(&mut self, token: &StreamRecoveryToken) {
+ match token {
+ StreamRecoveryToken::Stream(st) => self.send.acked(st),
+ StreamRecoveryToken::ResetStream { stream_id } => self.send.reset_acked(*stream_id),
+ StreamRecoveryToken::StopSending { stream_id } => {
+ if let Ok((_, Some(rs))) = self.obtain_stream(*stream_id) {
+ rs.stop_sending_acked();
+ }
+ }
+ // We only worry when these are lost
+ StreamRecoveryToken::DataBlocked(_)
+ | StreamRecoveryToken::StreamDataBlocked { .. }
+ | StreamRecoveryToken::MaxStreamData { .. }
+ | StreamRecoveryToken::StreamsBlocked { .. }
+ | StreamRecoveryToken::MaxStreams { .. }
+ | StreamRecoveryToken::MaxData(_) => (),
+ }
+ }
+
+ pub fn clear_streams(&mut self) {
+ self.send.clear();
+ self.recv.clear();
+ }
+
+ pub fn cleanup_closed_streams(&mut self) {
+ // filter the list, removing closed streams
+ self.send.remove_terminal();
+
+ let send = &self.send;
+ let (removed_bidi, removed_uni) = self.recv.clear_terminal(send, self.role);
+
+ // Send max_streams updates if we removed remote-initiated recv streams.
+ // The updates will be send if any steams has been removed.
+ self.remote_stream_limits[StreamType::BiDi].add_retired(removed_bidi);
+ self.remote_stream_limits[StreamType::UniDi].add_retired(removed_uni);
+ }
+
+ fn ensure_created_if_remote(&mut self, stream_id: StreamId) -> Res<()> {
+ if !stream_id.is_remote_initiated(self.role)
+ || !self.remote_stream_limits[stream_id.stream_type()].is_new_stream(stream_id)?
+ {
+ // If it is not a remote stream and stream already exist.
+ return Ok(());
+ }
+
+ let tp = match stream_id.stream_type() {
+ // From the local perspective, this is a remote- originated BiDi stream. From
+ // the remote perspective, this is a local-originated BiDi stream. Therefore,
+ // look at the local transport parameters for the
+ // INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value to decide how much this endpoint
+ // will allow its peer to send.
+ StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
+ StreamType::UniDi => tparams::INITIAL_MAX_STREAM_DATA_UNI,
+ };
+ let recv_initial_max_stream_data = self.tps.borrow().local.get_integer(tp);
+
+ while self.remote_stream_limits[stream_id.stream_type()].is_new_stream(stream_id)? {
+ let next_stream_id =
+ self.remote_stream_limits[stream_id.stream_type()].take_stream_id();
+ self.events.new_stream(next_stream_id);
+
+ self.recv.insert(
+ next_stream_id,
+ RecvStream::new(
+ next_stream_id,
+ recv_initial_max_stream_data,
+ Rc::clone(&self.receiver_fc),
+ self.events.clone(),
+ ),
+ );
+
+ if next_stream_id.is_bidi() {
+ // From the local perspective, this is a remote- originated BiDi stream.
+ // From the remote perspective, this is a local-originated BiDi stream.
+ // Therefore, look at the remote's transport parameters for the
+ // INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value to decide how much this endpoint
+ // is allowed to send its peer.
+ let send_initial_max_stream_data = self
+ .tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL);
+ self.send.insert(
+ next_stream_id,
+ SendStream::new(
+ next_stream_id,
+ send_initial_max_stream_data,
+ Rc::clone(&self.sender_fc),
+ self.events.clone(),
+ ),
+ );
+ }
+ }
+ Ok(())
+ }
+
+ /// Get or make a stream, and implicitly open additional streams as
+ /// indicated by its stream id.
+ pub fn obtain_stream(
+ &mut self,
+ stream_id: StreamId,
+ ) -> Res<(Option<&mut SendStream>, Option<&mut RecvStream>)> {
+ self.ensure_created_if_remote(stream_id)?;
+ Ok((
+ self.send.get_mut(stream_id).ok(),
+ self.recv.get_mut(stream_id).ok(),
+ ))
+ }
+
+ pub fn set_sendorder(&mut self, stream_id: StreamId, sendorder: Option<SendOrder>) -> Res<()> {
+ self.send.set_sendorder(stream_id, sendorder)
+ }
+
+ pub fn set_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> {
+ self.send.set_fairness(stream_id, fairness)
+ }
+
+ pub fn stream_create(&mut self, st: StreamType) -> Res<StreamId> {
+ match self.local_stream_limits.take_stream_id(st) {
+ None => Err(Error::StreamLimitError),
+ Some(new_id) => {
+ let send_limit_tp = match st {
+ StreamType::UniDi => tparams::INITIAL_MAX_STREAM_DATA_UNI,
+ StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
+ };
+ let send_limit = self.tps.borrow().remote().get_integer(send_limit_tp);
+ let stream = SendStream::new(
+ new_id,
+ send_limit,
+ Rc::clone(&self.sender_fc),
+ self.events.clone(),
+ );
+ self.send.insert(new_id, stream);
+
+ if st == StreamType::BiDi {
+ // From the local perspective, this is a local- originated BiDi stream. From the
+ // remote perspective, this is a remote-originated BiDi stream. Therefore, look
+ // at the local transport parameters for the
+ // INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value to decide how
+ // much this endpoint will allow its peer to send.
+ let recv_initial_max_stream_data = self
+ .tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL);
+
+ self.recv.insert(
+ new_id,
+ RecvStream::new(
+ new_id,
+ recv_initial_max_stream_data,
+ Rc::clone(&self.receiver_fc),
+ self.events.clone(),
+ ),
+ );
+ }
+ Ok(new_id)
+ }
+ }
+ }
+
+ pub fn handle_max_data(&mut self, maximum_data: u64) {
+ let conn_was_blocked = self.sender_fc.borrow().available() == 0;
+ let conn_credit_increased = self.sender_fc.borrow_mut().update(maximum_data);
+
+ if conn_was_blocked && conn_credit_increased {
+ for (id, ss) in &mut self.send {
+ if ss.avail() > 0 {
+ // These may not actually all be writable if one
+ // uses up all the conn credit. Not our fault.
+ self.events.send_stream_writable(*id);
+ }
+ }
+ }
+ }
+
+ pub fn handle_data_blocked(&mut self) {
+ self.receiver_fc.borrow_mut().send_flowc_update();
+ }
+
+ pub fn set_initial_limits(&mut self) {
+ _ = self.local_stream_limits[StreamType::BiDi].update(
+ self.tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI),
+ );
+ _ = self.local_stream_limits[StreamType::UniDi].update(
+ self.tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_STREAMS_UNI),
+ );
+
+ // As a client, there are two sets of initial limits for sending stream data.
+ // If the second limit is higher and streams have been created, then
+ // ensure that streams are not blocked on the lower limit.
+ if self.role == Role::Client {
+ self.send.update_initial_limit(self.tps.borrow().remote());
+ }
+
+ self.sender_fc.borrow_mut().update(
+ self.tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_DATA),
+ );
+
+ if self.local_stream_limits[StreamType::BiDi].available() > 0 {
+ self.events.send_stream_creatable(StreamType::BiDi);
+ }
+ if self.local_stream_limits[StreamType::UniDi].available() > 0 {
+ self.events.send_stream_creatable(StreamType::UniDi);
+ }
+ }
+
+ pub fn handle_max_streams(&mut self, stream_type: StreamType, maximum_streams: u64) {
+ if self.local_stream_limits[stream_type].update(maximum_streams) {
+ self.events.send_stream_creatable(stream_type);
+ }
+ }
+
+ pub fn get_send_stream_mut(&mut self, stream_id: StreamId) -> Res<&mut SendStream> {
+ self.send.get_mut(stream_id)
+ }
+
+ pub fn get_send_stream(&self, stream_id: StreamId) -> Res<&SendStream> {
+ self.send.get(stream_id)
+ }
+
+ pub fn get_recv_stream_mut(&mut self, stream_id: StreamId) -> Res<&mut RecvStream> {
+ self.recv.get_mut(stream_id)
+ }
+
+ pub fn keep_alive(&mut self, stream_id: StreamId, keep: bool) -> Res<()> {
+ self.recv.keep_alive(stream_id, keep)
+ }
+
+ pub fn need_keep_alive(&mut self) -> bool {
+ self.recv.need_keep_alive()
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/tparams.rs b/third_party/rust/neqo-transport/src/tparams.rs
new file mode 100644
index 0000000000..1297829094
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/tparams.rs
@@ -0,0 +1,1130 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Transport parameters. See -transport section 7.3.
+
+use std::{
+ cell::RefCell,
+ collections::HashMap,
+ convert::TryFrom,
+ net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
+ rc::Rc,
+};
+
+use neqo_common::{hex, qdebug, qinfo, qtrace, Decoder, Encoder, Role};
+use neqo_crypto::{
+ constants::{TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS},
+ ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult},
+ random, HandshakeMessage, ZeroRttCheckResult, ZeroRttChecker,
+};
+
+use crate::{
+ cid::{ConnectionId, ConnectionIdEntry, CONNECTION_ID_SEQNO_PREFERRED, MAX_CONNECTION_ID_LEN},
+ version::{Version, VersionConfig, WireVersion},
+ Error, Res,
+};
+
+pub type TransportParameterId = u64;
+macro_rules! tpids {
+ { $($n:ident = $v:expr),+ $(,)? } => {
+ $(pub const $n: TransportParameterId = $v as TransportParameterId;)+
+
+ /// A complete list of internal transport parameters.
+ #[cfg(not(test))]
+ pub(crate) const INTERNAL_TRANSPORT_PARAMETERS: &[TransportParameterId] = &[ $($n),+ ];
+ };
+ }
+tpids! {
+ ORIGINAL_DESTINATION_CONNECTION_ID = 0x00,
+ IDLE_TIMEOUT = 0x01,
+ STATELESS_RESET_TOKEN = 0x02,
+ MAX_UDP_PAYLOAD_SIZE = 0x03,
+ INITIAL_MAX_DATA = 0x04,
+ INITIAL_MAX_STREAM_DATA_BIDI_LOCAL = 0x05,
+ INITIAL_MAX_STREAM_DATA_BIDI_REMOTE = 0x06,
+ INITIAL_MAX_STREAM_DATA_UNI = 0x07,
+ INITIAL_MAX_STREAMS_BIDI = 0x08,
+ INITIAL_MAX_STREAMS_UNI = 0x09,
+ ACK_DELAY_EXPONENT = 0x0a,
+ MAX_ACK_DELAY = 0x0b,
+ DISABLE_MIGRATION = 0x0c,
+ PREFERRED_ADDRESS = 0x0d,
+ ACTIVE_CONNECTION_ID_LIMIT = 0x0e,
+ INITIAL_SOURCE_CONNECTION_ID = 0x0f,
+ RETRY_SOURCE_CONNECTION_ID = 0x10,
+ VERSION_INFORMATION = 0x11,
+ GREASE_QUIC_BIT = 0x2ab2,
+ MIN_ACK_DELAY = 0xff02_de1a,
+ MAX_DATAGRAM_FRAME_SIZE = 0x0020,
+}
+
+#[derive(Clone, Debug)]
+pub struct PreferredAddress {
+ v4: Option<SocketAddrV4>,
+ v6: Option<SocketAddrV6>,
+}
+
+impl PreferredAddress {
+ /// Make a new preferred address configuration.
+ ///
+ /// # Panics
+ ///
+ /// If neither address is provided, or if either address is of the wrong type.
+ #[must_use]
+ pub fn new(v4: Option<SocketAddrV4>, v6: Option<SocketAddrV6>) -> Self {
+ assert!(v4.is_some() || v6.is_some());
+ if let Some(a) = v4 {
+ assert!(!a.ip().is_unspecified());
+ assert_ne!(a.port(), 0);
+ }
+ if let Some(a) = v6 {
+ assert!(!a.ip().is_unspecified());
+ assert_ne!(a.port(), 0);
+ }
+ Self { v4, v6 }
+ }
+
+ /// A generic version of `new()` for testing.
+ #[must_use]
+ #[cfg(test)]
+ pub fn new_any(v4: Option<std::net::SocketAddr>, v6: Option<std::net::SocketAddr>) -> Self {
+ use std::net::SocketAddr;
+
+ let v4 = v4.map(|v4| {
+ let SocketAddr::V4(v4) = v4 else {
+ panic!("not v4");
+ };
+ v4
+ });
+ let v6 = v6.map(|v6| {
+ let SocketAddr::V6(v6) = v6 else {
+ panic!("not v6");
+ };
+ v6
+ });
+ Self::new(v4, v6)
+ }
+
+ #[must_use]
+ pub fn ipv4(&self) -> Option<SocketAddrV4> {
+ self.v4
+ }
+ #[must_use]
+ pub fn ipv6(&self) -> Option<SocketAddrV6> {
+ self.v6
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum TransportParameter {
+ Bytes(Vec<u8>),
+ Integer(u64),
+ Empty,
+ PreferredAddress {
+ v4: Option<SocketAddrV4>,
+ v6: Option<SocketAddrV6>,
+ cid: ConnectionId,
+ srt: [u8; 16],
+ },
+ Versions {
+ current: WireVersion,
+ other: Vec<WireVersion>,
+ },
+}
+
+impl TransportParameter {
+ fn encode(&self, enc: &mut Encoder, tp: TransportParameterId) {
+ qdebug!("TP encoded; type 0x{:02x} val {:?}", tp, self);
+ enc.encode_varint(tp);
+ match self {
+ Self::Bytes(a) => {
+ enc.encode_vvec(a);
+ }
+ Self::Integer(a) => {
+ enc.encode_vvec_with(|enc_inner| {
+ enc_inner.encode_varint(*a);
+ });
+ }
+ Self::Empty => {
+ enc.encode_varint(0_u64);
+ }
+ Self::PreferredAddress { v4, v6, cid, srt } => {
+ enc.encode_vvec_with(|enc_inner| {
+ if let Some(v4) = v4 {
+ enc_inner.encode(&v4.ip().octets()[..]);
+ enc_inner.encode_uint(2, v4.port());
+ } else {
+ enc_inner.encode(&[0; 6]);
+ }
+ if let Some(v6) = v6 {
+ enc_inner.encode(&v6.ip().octets()[..]);
+ enc_inner.encode_uint(2, v6.port());
+ } else {
+ enc_inner.encode(&[0; 18]);
+ }
+ enc_inner.encode_vec(1, &cid[..]);
+ enc_inner.encode(&srt[..]);
+ });
+ }
+ Self::Versions { current, other } => {
+ enc.encode_vvec_with(|enc_inner| {
+ enc_inner.encode_uint(4, *current);
+ for v in other {
+ enc_inner.encode_uint(4, *v);
+ }
+ });
+ }
+ };
+ }
+
+ fn decode_preferred_address(d: &mut Decoder) -> Res<Self> {
+ // IPv4 address (maybe)
+ let v4ip =
+ Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?).unwrap());
+ let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap();
+ // Can't have non-zero IP and zero port, or vice versa.
+ if v4ip.is_unspecified() ^ (v4port == 0) {
+ return Err(Error::TransportParameterError);
+ }
+ let v4 = if v4port == 0 {
+ None
+ } else {
+ Some(SocketAddrV4::new(v4ip, v4port))
+ };
+
+ // IPv6 address (mostly the same as v4)
+ let v6ip =
+ Ipv6Addr::from(<[u8; 16]>::try_from(d.decode(16).ok_or(Error::NoMoreData)?).unwrap());
+ let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap();
+ if v6ip.is_unspecified() ^ (v6port == 0) {
+ return Err(Error::TransportParameterError);
+ }
+ let v6 = if v6port == 0 {
+ None
+ } else {
+ Some(SocketAddrV6::new(v6ip, v6port, 0, 0))
+ };
+ // Need either v4 or v6 to be present.
+ if v4.is_none() && v6.is_none() {
+ return Err(Error::TransportParameterError);
+ }
+
+ // Connection ID (non-zero length)
+ let cid = ConnectionId::from(d.decode_vec(1).ok_or(Error::NoMoreData)?);
+ if cid.len() == 0 || cid.len() > MAX_CONNECTION_ID_LEN {
+ return Err(Error::TransportParameterError);
+ }
+
+ // Stateless reset token
+ let srtbuf = d.decode(16).ok_or(Error::NoMoreData)?;
+ let srt = <[u8; 16]>::try_from(srtbuf).unwrap();
+
+ Ok(Self::PreferredAddress { v4, v6, cid, srt })
+ }
+
+ fn decode_versions(dec: &mut Decoder) -> Res<Self> {
+ fn dv(dec: &mut Decoder) -> Res<WireVersion> {
+ let v = dec.decode_uint(4).ok_or(Error::NoMoreData)?;
+ if v == 0 {
+ Err(Error::TransportParameterError)
+ } else {
+ Ok(v as WireVersion)
+ }
+ }
+
+ let current = dv(dec)?;
+ // This rounding down is OK because `decode` checks for left over data.
+ let count = dec.remaining() / 4;
+ let mut other = Vec::with_capacity(count);
+ for _ in 0..count {
+ other.push(dv(dec)?);
+ }
+ Ok(Self::Versions { current, other })
+ }
+
+ fn decode(dec: &mut Decoder) -> Res<Option<(TransportParameterId, Self)>> {
+ let tp = dec.decode_varint().ok_or(Error::NoMoreData)?;
+ let content = dec.decode_vvec().ok_or(Error::NoMoreData)?;
+ qtrace!("TP {:x} length {:x}", tp, content.len());
+ let mut d = Decoder::from(content);
+ let value = match tp {
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()),
+ STATELESS_RESET_TOKEN => {
+ if d.remaining() != 16 {
+ return Err(Error::TransportParameterError);
+ }
+ Self::Bytes(d.decode_remainder().to_vec())
+ }
+ IDLE_TIMEOUT
+ | INITIAL_MAX_DATA
+ | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
+ | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
+ | INITIAL_MAX_STREAM_DATA_UNI
+ | MAX_ACK_DELAY
+ | MAX_DATAGRAM_FRAME_SIZE => match d.decode_varint() {
+ Some(v) => Self::Integer(v),
+ None => return Err(Error::TransportParameterError),
+ },
+
+ INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI => match d.decode_varint() {
+ Some(v) if v <= (1 << 60) => Self::Integer(v),
+ _ => return Err(Error::StreamLimitError),
+ },
+
+ MAX_UDP_PAYLOAD_SIZE => match d.decode_varint() {
+ Some(v) if v >= 1200 => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+
+ ACK_DELAY_EXPONENT => match d.decode_varint() {
+ Some(v) if v <= 20 => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+ ACTIVE_CONNECTION_ID_LIMIT => match d.decode_varint() {
+ Some(v) if v >= 2 => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+
+ DISABLE_MIGRATION | GREASE_QUIC_BIT => Self::Empty,
+
+ PREFERRED_ADDRESS => Self::decode_preferred_address(&mut d)?,
+
+ MIN_ACK_DELAY => match d.decode_varint() {
+ Some(v) if v < (1 << 24) => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+
+ VERSION_INFORMATION => Self::decode_versions(&mut d)?,
+
+ // Skip.
+ _ => return Ok(None),
+ };
+ if d.remaining() > 0 {
+ return Err(Error::TooMuchData);
+ }
+ qdebug!("TP decoded; type 0x{:02x} val {:?}", tp, value);
+ Ok(Some((tp, value)))
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+pub struct TransportParameters {
+ params: HashMap<TransportParameterId, TransportParameter>,
+}
+
+impl TransportParameters {
+ /// Set a value.
+ pub fn set(&mut self, k: TransportParameterId, v: TransportParameter) {
+ self.params.insert(k, v);
+ }
+
+ /// Clear a key.
+ pub fn remove(&mut self, k: TransportParameterId) {
+ self.params.remove(&k);
+ }
+
+ /// Decode is a static function that parses transport parameters
+ /// using the provided decoder.
+ pub(crate) fn decode(d: &mut Decoder) -> Res<Self> {
+ let mut tps = Self::default();
+ qtrace!("Parsed fixed TP header");
+
+ while d.remaining() > 0 {
+ match TransportParameter::decode(d) {
+ Ok(Some((tipe, tp))) => {
+ tps.set(tipe, tp);
+ }
+ Ok(None) => {}
+ Err(e) => return Err(e),
+ }
+ }
+ Ok(tps)
+ }
+
+ pub(crate) fn encode(&self, enc: &mut Encoder) {
+ for (tipe, tp) in &self.params {
+ tp.encode(enc, *tipe);
+ }
+ }
+
+ // Get an integer type or a default.
+ pub fn get_integer(&self, tp: TransportParameterId) -> u64 {
+ let default = match tp {
+ IDLE_TIMEOUT
+ | INITIAL_MAX_DATA
+ | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
+ | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
+ | INITIAL_MAX_STREAM_DATA_UNI
+ | INITIAL_MAX_STREAMS_BIDI
+ | INITIAL_MAX_STREAMS_UNI
+ | MIN_ACK_DELAY
+ | MAX_DATAGRAM_FRAME_SIZE => 0,
+ MAX_UDP_PAYLOAD_SIZE => 65527,
+ ACK_DELAY_EXPONENT => 3,
+ MAX_ACK_DELAY => 25,
+ ACTIVE_CONNECTION_ID_LIMIT => 2,
+ _ => panic!("Transport parameter not known or not an Integer"),
+ };
+ match self.params.get(&tp) {
+ None => default,
+ Some(TransportParameter::Integer(x)) => *x,
+ _ => panic!("Internal error"),
+ }
+ }
+
+ // Set an integer type or a default.
+ pub fn set_integer(&mut self, tp: TransportParameterId, value: u64) {
+ match tp {
+ IDLE_TIMEOUT
+ | INITIAL_MAX_DATA
+ | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
+ | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
+ | INITIAL_MAX_STREAM_DATA_UNI
+ | INITIAL_MAX_STREAMS_BIDI
+ | INITIAL_MAX_STREAMS_UNI
+ | MAX_UDP_PAYLOAD_SIZE
+ | ACK_DELAY_EXPONENT
+ | MAX_ACK_DELAY
+ | ACTIVE_CONNECTION_ID_LIMIT
+ | MIN_ACK_DELAY
+ | MAX_DATAGRAM_FRAME_SIZE => {
+ self.set(tp, TransportParameter::Integer(value));
+ }
+ _ => panic!("Transport parameter not known"),
+ }
+ }
+
+ pub fn get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]> {
+ match tp {
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID
+ | STATELESS_RESET_TOKEN => {}
+ _ => panic!("Transport parameter not known or not type bytes"),
+ }
+
+ match self.params.get(&tp) {
+ None => None,
+ Some(TransportParameter::Bytes(x)) => Some(x),
+ _ => panic!("Internal error"),
+ }
+ }
+
+ pub fn set_bytes(&mut self, tp: TransportParameterId, value: Vec<u8>) {
+ match tp {
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID
+ | STATELESS_RESET_TOKEN => {
+ self.set(tp, TransportParameter::Bytes(value));
+ }
+ _ => panic!("Transport parameter not known or not type bytes"),
+ }
+ }
+
+ pub fn set_empty(&mut self, tp: TransportParameterId) {
+ match tp {
+ DISABLE_MIGRATION | GREASE_QUIC_BIT => {
+ self.set(tp, TransportParameter::Empty);
+ }
+ _ => panic!("Transport parameter not known or not type empty"),
+ }
+ }
+
+ /// Set version information.
+ pub fn set_versions(&mut self, role: Role, versions: &VersionConfig) {
+ let rbuf = random(4);
+ let mut other = Vec::with_capacity(versions.all().len() + 1);
+ let mut dec = Decoder::new(&rbuf);
+ let grease = (dec.decode_uint(4).unwrap() as u32) & 0xf0f0_f0f0 | 0x0a0a_0a0a;
+ other.push(grease);
+ for &v in versions.all() {
+ if role == Role::Client && !versions.initial().is_compatible(v) {
+ continue;
+ }
+ other.push(v.wire_version());
+ }
+ let current = versions.initial().wire_version();
+ self.set(
+ VERSION_INFORMATION,
+ TransportParameter::Versions { current, other },
+ );
+ }
+
+ fn compatible_upgrade(&mut self, v: Version) {
+ if let Some(TransportParameter::Versions {
+ ref mut current, ..
+ }) = self.params.get_mut(&VERSION_INFORMATION)
+ {
+ *current = v.wire_version();
+ } else {
+ unreachable!("Compatible upgrade without transport parameters set!");
+ }
+ }
+
+ pub fn get_empty(&self, tipe: TransportParameterId) -> bool {
+ match self.params.get(&tipe) {
+ None => false,
+ Some(TransportParameter::Empty) => true,
+ _ => panic!("Internal error"),
+ }
+ }
+
+ /// Return true if the remembered transport parameters are OK for 0-RTT.
+ /// Generally this means that any value that is currently in effect is greater than
+ /// or equal to the promised value.
+ pub(crate) fn ok_for_0rtt(&self, remembered: &Self) -> bool {
+ for (k, v_rem) in &remembered.params {
+ // Skip checks for these, which don't affect 0-RTT.
+ if matches!(
+ *k,
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID
+ | STATELESS_RESET_TOKEN
+ | IDLE_TIMEOUT
+ | ACK_DELAY_EXPONENT
+ | MAX_ACK_DELAY
+ | ACTIVE_CONNECTION_ID_LIMIT
+ | PREFERRED_ADDRESS
+ ) {
+ continue;
+ }
+ let ok = if let Some(v_self) = self.params.get(k) {
+ match (v_self, v_rem) {
+ (TransportParameter::Integer(i_self), TransportParameter::Integer(i_rem)) => {
+ if *k == MIN_ACK_DELAY {
+ // MIN_ACK_DELAY is backwards:
+ // it can only be reduced safely.
+ *i_self <= *i_rem
+ } else {
+ *i_self >= *i_rem
+ }
+ }
+ (TransportParameter::Empty, TransportParameter::Empty) => true,
+ (
+ TransportParameter::Versions {
+ current: v_self, ..
+ },
+ TransportParameter::Versions { current: v_rem, .. },
+ ) => v_self == v_rem,
+ _ => false,
+ }
+ } else {
+ false
+ };
+ if !ok {
+ return false;
+ }
+ }
+ true
+ }
+
+ /// Get the preferred address in a usable form.
+ #[must_use]
+ pub fn get_preferred_address(&self) -> Option<(PreferredAddress, ConnectionIdEntry<[u8; 16]>)> {
+ if let Some(TransportParameter::PreferredAddress { v4, v6, cid, srt }) =
+ self.params.get(&PREFERRED_ADDRESS)
+ {
+ Some((
+ PreferredAddress::new(*v4, *v6),
+ ConnectionIdEntry::new(CONNECTION_ID_SEQNO_PREFERRED, cid.clone(), *srt),
+ ))
+ } else {
+ None
+ }
+ }
+
+ /// Get the version negotiation values for validation.
+ #[must_use]
+ pub fn get_versions(&self) -> Option<(WireVersion, &[WireVersion])> {
+ if let Some(TransportParameter::Versions { current, other }) =
+ self.params.get(&VERSION_INFORMATION)
+ {
+ Some((*current, other))
+ } else {
+ None
+ }
+ }
+
+ #[must_use]
+ pub fn has_value(&self, tp: TransportParameterId) -> bool {
+ self.params.contains_key(&tp)
+ }
+}
+
+#[derive(Debug)]
+pub struct TransportParametersHandler {
+ role: Role,
+ versions: VersionConfig,
+ pub(crate) local: TransportParameters,
+ pub(crate) remote: Option<TransportParameters>,
+ pub(crate) remote_0rtt: Option<TransportParameters>,
+}
+
+impl TransportParametersHandler {
+ pub fn new(role: Role, versions: VersionConfig) -> Self {
+ let mut local = TransportParameters::default();
+ local.set_versions(role, &versions);
+ Self {
+ role,
+ versions,
+ local,
+ remote: None,
+ remote_0rtt: None,
+ }
+ }
+
+ /// When resuming, the version is set based on the ticket.
+ /// That needs to be done to override the default choice from configuration.
+ pub fn set_version(&mut self, version: Version) {
+ debug_assert_eq!(self.role, Role::Client);
+ self.versions.set_initial(version);
+ self.local.set_versions(self.role, &self.versions);
+ }
+
+ pub fn remote(&self) -> &TransportParameters {
+ match (self.remote.as_ref(), self.remote_0rtt.as_ref()) {
+ (Some(tp), _) | (_, Some(tp)) => tp,
+ _ => panic!("no transport parameters from peer"),
+ }
+ }
+
+ /// Get the version as set (or as determined by a compatible upgrade).
+ pub fn version(&self) -> Version {
+ self.versions.initial()
+ }
+
+ fn compatible_upgrade(&mut self, remote_tp: &TransportParameters) -> Res<()> {
+ if let Some((current, other)) = remote_tp.get_versions() {
+ qtrace!(
+ "Peer versions: {:x} {:x?}; config {:?}",
+ current,
+ other,
+ self.versions,
+ );
+
+ if self.role == Role::Client {
+ let chosen = Version::try_from(current)?;
+ if self.versions.compatible().any(|&v| v == chosen) {
+ Ok(())
+ } else {
+ qinfo!(
+ "Chosen version {:x} is not compatible with initial version {:x}",
+ current,
+ self.versions.initial().wire_version(),
+ );
+ Err(Error::TransportParameterError)
+ }
+ } else {
+ if current != self.versions.initial().wire_version() {
+ qinfo!(
+ "Current version {:x} != own version {:x}",
+ current,
+ self.versions.initial().wire_version(),
+ );
+ return Err(Error::TransportParameterError);
+ }
+
+ if let Some(preferred) = self.versions.preferred_compatible(other) {
+ if preferred != self.versions.initial() {
+ qinfo!(
+ "Compatible upgrade {:?} ==> {:?}",
+ self.versions.initial(),
+ preferred
+ );
+ self.versions.set_initial(preferred);
+ self.local.compatible_upgrade(preferred);
+ }
+ Ok(())
+ } else {
+ qinfo!("Unable to find any compatible version");
+ Err(Error::TransportParameterError)
+ }
+ }
+ } else {
+ Ok(())
+ }
+ }
+}
+
+impl ExtensionHandler for TransportParametersHandler {
+ fn write(&mut self, msg: HandshakeMessage, d: &mut [u8]) -> ExtensionWriterResult {
+ if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
+ return ExtensionWriterResult::Skip;
+ }
+
+ qdebug!("Writing transport parameters, msg={:?}", msg);
+
+ // TODO(ekr@rtfm.com): Modify to avoid a copy.
+ let mut enc = Encoder::default();
+ self.local.encode(&mut enc);
+ assert!(enc.len() <= d.len());
+ d[..enc.len()].copy_from_slice(enc.as_ref());
+ ExtensionWriterResult::Write(enc.len())
+ }
+
+ fn handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult {
+ qtrace!(
+ "Handling transport parameters, msg={:?} value={}",
+ msg,
+ hex(d),
+ );
+
+ if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
+ return ExtensionHandlerResult::Alert(110); // unsupported_extension
+ }
+
+ let mut dec = Decoder::from(d);
+ match TransportParameters::decode(&mut dec) {
+ Ok(tp) => {
+ if self.compatible_upgrade(&tp).is_ok() {
+ self.remote = Some(tp);
+ ExtensionHandlerResult::Ok
+ } else {
+ ExtensionHandlerResult::Alert(47)
+ }
+ }
+ _ => ExtensionHandlerResult::Alert(47), // illegal_parameter
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct TpZeroRttChecker<T> {
+ handler: Rc<RefCell<TransportParametersHandler>>,
+ app_checker: T,
+}
+
+impl<T> TpZeroRttChecker<T>
+where
+ T: ZeroRttChecker + 'static,
+{
+ pub fn wrap(
+ handler: Rc<RefCell<TransportParametersHandler>>,
+ app_checker: T,
+ ) -> Box<dyn ZeroRttChecker> {
+ Box::new(Self {
+ handler,
+ app_checker,
+ })
+ }
+}
+
+impl<T> ZeroRttChecker for TpZeroRttChecker<T>
+where
+ T: ZeroRttChecker,
+{
+ fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
+ // Reject 0-RTT if there is no token.
+ if token.is_empty() {
+ qdebug!("0-RTT: no token, no 0-RTT");
+ return ZeroRttCheckResult::Reject;
+ }
+ let mut dec = Decoder::from(token);
+ let Some(tpslice) = dec.decode_vvec() else {
+ qinfo!("0-RTT: token code error");
+ return ZeroRttCheckResult::Fail;
+ };
+ let mut dec_tp = Decoder::from(tpslice);
+ let Ok(remembered) = TransportParameters::decode(&mut dec_tp) else {
+ qinfo!("0-RTT: transport parameter decode error");
+ return ZeroRttCheckResult::Fail;
+ };
+ if self.handler.borrow().local.ok_for_0rtt(&remembered) {
+ qinfo!("0-RTT: transport parameters OK, passing to application checker");
+ self.app_checker.check(dec.decode_remainder())
+ } else {
+ qinfo!("0-RTT: transport parameters bad, rejecting");
+ ZeroRttCheckResult::Reject
+ }
+ }
+}
+
+#[cfg(test)]
+#[allow(unused_variables)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn basic_tps() {
+ const RESET_TOKEN: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8];
+ let mut tps = TransportParameters::default();
+ tps.set(
+ STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(RESET_TOKEN.to_vec()),
+ );
+ tps.params
+ .insert(INITIAL_MAX_STREAMS_BIDI, TransportParameter::Integer(10));
+
+ let mut enc = Encoder::default();
+ tps.encode(&mut enc);
+
+ let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
+ assert_eq!(tps, tps2);
+
+ println!("TPS = {tps:?}");
+ assert_eq!(tps2.get_integer(IDLE_TIMEOUT), 0); // Default
+ assert_eq!(tps2.get_integer(MAX_ACK_DELAY), 25); // Default
+ assert_eq!(tps2.get_integer(ACTIVE_CONNECTION_ID_LIMIT), 2); // Default
+ assert_eq!(tps2.get_integer(INITIAL_MAX_STREAMS_BIDI), 10); // Sent
+ assert_eq!(tps2.get_bytes(STATELESS_RESET_TOKEN), Some(RESET_TOKEN));
+ assert_eq!(tps2.get_bytes(ORIGINAL_DESTINATION_CONNECTION_ID), None);
+ assert_eq!(tps2.get_bytes(INITIAL_SOURCE_CONNECTION_ID), None);
+ assert_eq!(tps2.get_bytes(RETRY_SOURCE_CONNECTION_ID), None);
+ assert!(!tps2.has_value(ORIGINAL_DESTINATION_CONNECTION_ID));
+ assert!(!tps2.has_value(INITIAL_SOURCE_CONNECTION_ID));
+ assert!(!tps2.has_value(RETRY_SOURCE_CONNECTION_ID));
+ assert!(tps2.has_value(STATELESS_RESET_TOKEN));
+
+ let mut enc = Encoder::default();
+ tps.encode(&mut enc);
+
+ let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
+ }
+
+ fn make_spa() -> TransportParameter {
+ TransportParameter::PreferredAddress {
+ v4: Some(SocketAddrV4::new(Ipv4Addr::from(0xc000_0201), 443)),
+ v6: Some(SocketAddrV6::new(
+ Ipv6Addr::from(0xfe80_0000_0000_0000_0000_0000_0000_0001),
+ 443,
+ 0,
+ 0,
+ )),
+ cid: ConnectionId::from(&[1, 2, 3, 4, 5]),
+ srt: [3; 16],
+ }
+ }
+
+ #[test]
+ fn preferred_address_encode_decode() {
+ const ENCODED: &[u8] = &[
+ 0x0d, 0x2e, 0xc0, 0x00, 0x02, 0x01, 0x01, 0xbb, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xbb, 0x05, 0x01,
+ 0x02, 0x03, 0x04, 0x05, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
+ 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
+ ];
+ let spa = make_spa();
+ let mut enc = Encoder::new();
+ spa.encode(&mut enc, PREFERRED_ADDRESS);
+ assert_eq!(enc.as_ref(), ENCODED);
+
+ let mut dec = enc.as_decoder();
+ let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
+ assert_eq!(id, PREFERRED_ADDRESS);
+ assert_eq!(decoded, spa);
+ }
+
+ fn mutate_spa<F>(wrecker: F) -> TransportParameter
+ where
+ F: FnOnce(&mut Option<SocketAddrV4>, &mut Option<SocketAddrV6>, &mut ConnectionId),
+ {
+ let mut spa = make_spa();
+ if let TransportParameter::PreferredAddress {
+ ref mut v4,
+ ref mut v6,
+ ref mut cid,
+ ..
+ } = &mut spa
+ {
+ wrecker(v4, v6, cid);
+ } else {
+ unreachable!();
+ }
+ spa
+ }
+
+ /// This takes a `TransportParameter::PreferredAddress` that has been mutilated.
+ /// It then encodes it, working from the knowledge that the `encode` function
+ /// doesn't care about validity, and decodes it. The result should be failure.
+ fn assert_invalid_spa(spa: TransportParameter) {
+ let mut enc = Encoder::new();
+ spa.encode(&mut enc, PREFERRED_ADDRESS);
+ assert_eq!(
+ TransportParameter::decode(&mut enc.as_decoder()).unwrap_err(),
+ Error::TransportParameterError
+ );
+ }
+
+ /// This is for those rare mutations that are acceptable.
+ fn assert_valid_spa(spa: TransportParameter) {
+ let mut enc = Encoder::new();
+ spa.encode(&mut enc, PREFERRED_ADDRESS);
+ let mut dec = enc.as_decoder();
+ let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
+ assert_eq!(id, PREFERRED_ADDRESS);
+ assert_eq!(decoded, spa);
+ }
+
+ #[test]
+ fn preferred_address_zero_address() {
+ // Either port being zero is bad.
+ assert_invalid_spa(mutate_spa(|v4, _, _| {
+ v4.as_mut().unwrap().set_port(0);
+ }));
+ assert_invalid_spa(mutate_spa(|_, v6, _| {
+ v6.as_mut().unwrap().set_port(0);
+ }));
+ // Either IP being zero is bad.
+ assert_invalid_spa(mutate_spa(|v4, _, _| {
+ v4.as_mut().unwrap().set_ip(Ipv4Addr::from(0));
+ }));
+ assert_invalid_spa(mutate_spa(|_, v6, _| {
+ v6.as_mut().unwrap().set_ip(Ipv6Addr::from(0));
+ }));
+ // Either address being absent is OK.
+ assert_valid_spa(mutate_spa(|v4, _, _| {
+ *v4 = None;
+ }));
+ assert_valid_spa(mutate_spa(|_, v6, _| {
+ *v6 = None;
+ }));
+ // Both addresses being absent is bad.
+ assert_invalid_spa(mutate_spa(|v4, v6, _| {
+ *v4 = None;
+ *v6 = None;
+ }));
+ }
+
+ #[test]
+ fn preferred_address_bad_cid() {
+ assert_invalid_spa(mutate_spa(|_, _, cid| {
+ *cid = ConnectionId::from(&[]);
+ }));
+ assert_invalid_spa(mutate_spa(|_, _, cid| {
+ *cid = ConnectionId::from(&[0x0c; 21]);
+ }));
+ }
+
+ #[test]
+ fn preferred_address_truncated() {
+ let spa = make_spa();
+ let mut enc = Encoder::new();
+ spa.encode(&mut enc, PREFERRED_ADDRESS);
+ let mut dec = Decoder::from(&enc.as_ref()[..enc.len() - 1]);
+ assert_eq!(
+ TransportParameter::decode(&mut dec).unwrap_err(),
+ Error::NoMoreData
+ );
+ }
+
+ #[test]
+ #[should_panic(expected = "v4.is_some() || v6.is_some()")]
+ fn preferred_address_neither() {
+ _ = PreferredAddress::new(None, None);
+ }
+
+ #[test]
+ #[should_panic(expected = ".is_unspecified")]
+ fn preferred_address_v4_unspecified() {
+ _ = PreferredAddress::new(Some(SocketAddrV4::new(Ipv4Addr::from(0), 443)), None);
+ }
+
+ #[test]
+ #[should_panic(expected = "left != right")]
+ fn preferred_address_v4_zero_port() {
+ _ = PreferredAddress::new(
+ Some(SocketAddrV4::new(Ipv4Addr::from(0xc000_0201), 0)),
+ None,
+ );
+ }
+
+ #[test]
+ #[should_panic(expected = ".is_unspecified")]
+ fn preferred_address_v6_unspecified() {
+ _ = PreferredAddress::new(None, Some(SocketAddrV6::new(Ipv6Addr::from(0), 443, 0, 0)));
+ }
+
+ #[test]
+ #[should_panic(expected = "left != right")]
+ fn preferred_address_v6_zero_port() {
+ _ = PreferredAddress::new(None, Some(SocketAddrV6::new(Ipv6Addr::from(1), 0, 0, 0)));
+ }
+
+ #[test]
+ fn compatible_0rtt_ignored_values() {
+ let mut tps_a = TransportParameters::default();
+ tps_a.set(
+ STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(vec![1, 2, 3]),
+ );
+ tps_a.set(IDLE_TIMEOUT, TransportParameter::Integer(10));
+ tps_a.set(MAX_ACK_DELAY, TransportParameter::Integer(22));
+ tps_a.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(33));
+
+ let mut tps_b = TransportParameters::default();
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+
+ tps_b.set(
+ STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(vec![8, 9, 10]),
+ );
+ tps_b.set(IDLE_TIMEOUT, TransportParameter::Integer(100));
+ tps_b.set(MAX_ACK_DELAY, TransportParameter::Integer(2));
+ tps_b.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(44));
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+ }
+
+ #[test]
+ fn compatible_0rtt_integers() {
+ let mut tps_a = TransportParameters::default();
+ const INTEGER_KEYS: &[TransportParameterId] = &[
+ INITIAL_MAX_DATA,
+ INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
+ INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
+ INITIAL_MAX_STREAM_DATA_UNI,
+ INITIAL_MAX_STREAMS_BIDI,
+ INITIAL_MAX_STREAMS_UNI,
+ MAX_UDP_PAYLOAD_SIZE,
+ MIN_ACK_DELAY,
+ MAX_DATAGRAM_FRAME_SIZE,
+ ];
+ for i in INTEGER_KEYS {
+ tps_a.set(*i, TransportParameter::Integer(12));
+ }
+
+ let tps_b = tps_a.clone();
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+
+ // For each integer key, choose a new value that will be accepted.
+ for i in INTEGER_KEYS {
+ let mut tps_b = tps_a.clone();
+ // Set a safe new value; reducing MIN_ACK_DELAY instead.
+ let safe_value = if *i == MIN_ACK_DELAY { 11 } else { 13 };
+ tps_b.set(*i, TransportParameter::Integer(safe_value));
+ // If the new value is not safe relative to the remembered value,
+ // then we can't attempt 0-RTT with these parameters.
+ assert!(!tps_a.ok_for_0rtt(&tps_b));
+ // The opposite situation is fine.
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+ }
+
+ // Drop integer values and check that that is OK.
+ for i in INTEGER_KEYS {
+ let mut tps_b = tps_a.clone();
+ tps_b.remove(*i);
+ // A value that is missing from what is rememebered is OK.
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ // A value that is rememebered, but not current is not OK.
+ assert!(!tps_b.ok_for_0rtt(&tps_a));
+ }
+ }
+
+ /// `ACTIVE_CONNECTION_ID_LIMIT` can't be less than 2.
+ #[test]
+ fn active_connection_id_limit_min_2() {
+ let mut tps = TransportParameters::default();
+
+ // Intentionally set an invalid value for the ACTIVE_CONNECTION_ID_LIMIT transport
+ // parameter.
+ tps.params
+ .insert(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(1));
+
+ let mut enc = Encoder::default();
+ tps.encode(&mut enc);
+
+ // When decoding a set of transport parameters with an invalid ACTIVE_CONNECTION_ID_LIMIT
+ // the result should be an error.
+ let invalid_decode_result = TransportParameters::decode(&mut enc.as_decoder());
+ assert!(invalid_decode_result.is_err());
+ }
+
+ #[test]
+ fn versions_encode_decode() {
+ const ENCODED: &[u8] = &[
+ 0x11, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x1a, 0x2a, 0x3a, 0x4a, 0x5a, 0x6a, 0x7a, 0x8a,
+ ];
+ let vn = TransportParameter::Versions {
+ current: Version::Version1.wire_version(),
+ other: vec![0x1a2a_3a4a, 0x5a6a_7a8a],
+ };
+
+ let mut enc = Encoder::new();
+ vn.encode(&mut enc, VERSION_INFORMATION);
+ assert_eq!(enc.as_ref(), ENCODED);
+
+ let mut dec = enc.as_decoder();
+ let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
+ assert_eq!(id, VERSION_INFORMATION);
+ assert_eq!(decoded, vn);
+ }
+
+ #[test]
+ fn versions_truncated() {
+ const TRUNCATED: &[u8] = &[
+ 0x80, 0xff, 0x73, 0xdb, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x1a, 0x2a, 0x3a, 0x4a, 0x5a,
+ 0x6a, 0x7a,
+ ];
+ let mut dec = Decoder::from(&TRUNCATED);
+ assert_eq!(
+ TransportParameter::decode(&mut dec).unwrap_err(),
+ Error::NoMoreData
+ );
+ }
+
+ #[test]
+ fn versions_zero() {
+ const ZERO1: &[u8] = &[0x11, 0x04, 0x00, 0x00, 0x00, 0x00];
+ const ZERO2: &[u8] = &[0x11, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00];
+
+ let mut dec = Decoder::from(&ZERO1);
+ assert_eq!(
+ TransportParameter::decode(&mut dec).unwrap_err(),
+ Error::TransportParameterError
+ );
+ let mut dec = Decoder::from(&ZERO2);
+ assert_eq!(
+ TransportParameter::decode(&mut dec).unwrap_err(),
+ Error::TransportParameterError
+ );
+ }
+
+ #[test]
+ fn versions_equal_0rtt() {
+ let mut current = TransportParameters::default();
+ current.set(
+ VERSION_INFORMATION,
+ TransportParameter::Versions {
+ current: Version::Version1.wire_version(),
+ other: vec![0x1a2a_3a4a],
+ },
+ );
+
+ let mut remembered = TransportParameters::default();
+ // It's OK to not remember having versions.
+ assert!(current.ok_for_0rtt(&remembered));
+ // But it is bad in the opposite direction.
+ assert!(!remembered.ok_for_0rtt(&current));
+
+ // If the version matches, it's OK to use 0-RTT.
+ remembered.set(
+ VERSION_INFORMATION,
+ TransportParameter::Versions {
+ current: Version::Version1.wire_version(),
+ other: vec![0x5a6a_7a8a, 0x9aaa_baca],
+ },
+ );
+ assert!(current.ok_for_0rtt(&remembered));
+ assert!(remembered.ok_for_0rtt(&current));
+
+ // An apparent "upgrade" is still cause to reject 0-RTT.
+ remembered.set(
+ VERSION_INFORMATION,
+ TransportParameter::Versions {
+ current: Version::Version1.wire_version() + 1,
+ other: vec![],
+ },
+ );
+ assert!(!current.ok_for_0rtt(&remembered));
+ assert!(!remembered.ok_for_0rtt(&current));
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs
new file mode 100644
index 0000000000..64d00257d3
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/tracking.rs
@@ -0,0 +1,1228 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of received packets and generating acks thereof.
+
+#![deny(clippy::pedantic)]
+
+use std::{
+ cmp::min,
+ collections::VecDeque,
+ convert::TryFrom,
+ ops::{Index, IndexMut},
+ time::{Duration, Instant},
+};
+
+use neqo_common::{qdebug, qinfo, qtrace, qwarn};
+use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL};
+use smallvec::{smallvec, SmallVec};
+
+use crate::{
+ packet::{PacketBuilder, PacketNumber, PacketType},
+ recovery::RecoveryToken,
+ stats::FrameStats,
+};
+
+// TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)]
+pub enum PacketNumberSpace {
+ Initial,
+ Handshake,
+ ApplicationData,
+}
+
+#[allow(clippy::use_self)] // https://github.com/rust-lang/rust-clippy/issues/3410
+impl PacketNumberSpace {
+ pub fn iter() -> impl Iterator<Item = &'static PacketNumberSpace> {
+ const SPACES: &[PacketNumberSpace] = &[
+ PacketNumberSpace::Initial,
+ PacketNumberSpace::Handshake,
+ PacketNumberSpace::ApplicationData,
+ ];
+ SPACES.iter()
+ }
+}
+
+impl From<Epoch> for PacketNumberSpace {
+ fn from(epoch: Epoch) -> Self {
+ match epoch {
+ TLS_EPOCH_INITIAL => Self::Initial,
+ TLS_EPOCH_HANDSHAKE => Self::Handshake,
+ _ => Self::ApplicationData,
+ }
+ }
+}
+
+impl From<PacketType> for PacketNumberSpace {
+ fn from(pt: PacketType) -> Self {
+ match pt {
+ PacketType::Initial => Self::Initial,
+ PacketType::Handshake => Self::Handshake,
+ PacketType::ZeroRtt | PacketType::Short => Self::ApplicationData,
+ _ => panic!("Attempted to get space from wrong packet type"),
+ }
+ }
+}
+
+#[derive(Clone, Copy, Default)]
+pub struct PacketNumberSpaceSet {
+ initial: bool,
+ handshake: bool,
+ application_data: bool,
+}
+
+impl PacketNumberSpaceSet {
+ pub fn all() -> Self {
+ Self {
+ initial: true,
+ handshake: true,
+ application_data: true,
+ }
+ }
+}
+
+impl Index<PacketNumberSpace> for PacketNumberSpaceSet {
+ type Output = bool;
+
+ fn index(&self, space: PacketNumberSpace) -> &Self::Output {
+ match space {
+ PacketNumberSpace::Initial => &self.initial,
+ PacketNumberSpace::Handshake => &self.handshake,
+ PacketNumberSpace::ApplicationData => &self.application_data,
+ }
+ }
+}
+
+impl IndexMut<PacketNumberSpace> for PacketNumberSpaceSet {
+ fn index_mut(&mut self, space: PacketNumberSpace) -> &mut Self::Output {
+ match space {
+ PacketNumberSpace::Initial => &mut self.initial,
+ PacketNumberSpace::Handshake => &mut self.handshake,
+ PacketNumberSpace::ApplicationData => &mut self.application_data,
+ }
+ }
+}
+
+impl<T: AsRef<[PacketNumberSpace]>> From<T> for PacketNumberSpaceSet {
+ fn from(spaces: T) -> Self {
+ let mut v = Self::default();
+ for sp in spaces.as_ref() {
+ v[*sp] = true;
+ }
+ v
+ }
+}
+
+impl std::fmt::Debug for PacketNumberSpaceSet {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ let mut first = true;
+ f.write_str("(")?;
+ for sp in PacketNumberSpace::iter() {
+ if self[*sp] {
+ if !first {
+ f.write_str("+")?;
+ first = false;
+ }
+ std::fmt::Display::fmt(sp, f)?;
+ }
+ }
+ f.write_str(")")
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SentPacket {
+ pub pt: PacketType,
+ pub pn: PacketNumber,
+ ack_eliciting: bool,
+ pub time_sent: Instant,
+ primary_path: bool,
+ pub tokens: Vec<RecoveryToken>,
+
+ time_declared_lost: Option<Instant>,
+ /// After a PTO, this is true when the packet has been released.
+ pto: bool,
+
+ pub size: usize,
+}
+
+impl SentPacket {
+ pub fn new(
+ pt: PacketType,
+ pn: PacketNumber,
+ time_sent: Instant,
+ ack_eliciting: bool,
+ tokens: Vec<RecoveryToken>,
+ size: usize,
+ ) -> Self {
+ Self {
+ pt,
+ pn,
+ time_sent,
+ ack_eliciting,
+ primary_path: true,
+ tokens,
+ time_declared_lost: None,
+ pto: false,
+ size,
+ }
+ }
+
+ /// Returns `true` if the packet will elicit an ACK.
+ pub fn ack_eliciting(&self) -> bool {
+ self.ack_eliciting
+ }
+
+ /// Returns `true` if the packet was sent on the primary path.
+ pub fn on_primary_path(&self) -> bool {
+ self.primary_path
+ }
+
+ /// Clears the flag that had this packet on the primary path.
+ /// Used when migrating to clear out state.
+ pub fn clear_primary_path(&mut self) {
+ self.primary_path = false;
+ }
+
+ /// Whether the packet has been declared lost.
+ pub fn lost(&self) -> bool {
+ self.time_declared_lost.is_some()
+ }
+
+ /// Whether accounting for the loss or acknowledgement in the
+ /// congestion controller is pending.
+ /// Returns `true` if the packet counts as being "in flight",
+ /// and has not previously been declared lost.
+ /// Note that this should count packets that contain only ACK and PADDING,
+ /// but we don't send PADDING, so we don't track that.
+ pub fn cc_outstanding(&self) -> bool {
+ self.ack_eliciting() && self.on_primary_path() && !self.lost()
+ }
+
+ /// Whether the packet should be tracked as in-flight.
+ pub fn cc_in_flight(&self) -> bool {
+ self.ack_eliciting() && self.on_primary_path()
+ }
+
+ /// Declare the packet as lost. Returns `true` if this is the first time.
+ pub fn declare_lost(&mut self, now: Instant) -> bool {
+ if self.lost() {
+ false
+ } else {
+ self.time_declared_lost = Some(now);
+ true
+ }
+ }
+
+ /// Ask whether this tracked packet has been declared lost for long enough
+ /// that it can be expired and no longer tracked.
+ pub fn expired(&self, now: Instant, expiration_period: Duration) -> bool {
+ self.time_declared_lost
+ .map_or(false, |loss_time| (loss_time + expiration_period) <= now)
+ }
+
+ /// Whether the packet contents were cleared out after a PTO.
+ pub fn pto_fired(&self) -> bool {
+ self.pto
+ }
+
+ /// On PTO, we need to get the recovery tokens so that we can ensure that
+ /// the frames we sent can be sent again in the PTO packet(s). Do that just once.
+ pub fn pto(&mut self) -> bool {
+ if self.pto || self.lost() {
+ false
+ } else {
+ self.pto = true;
+ true
+ }
+ }
+}
+
+impl std::fmt::Display for PacketNumberSpace {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.write_str(match self {
+ Self::Initial => "in",
+ Self::Handshake => "hs",
+ Self::ApplicationData => "ap",
+ })
+ }
+}
+
+/// `InsertionResult` tracks whether something was inserted for `PacketRange::add()`.
+pub enum InsertionResult {
+ Largest,
+ Smallest,
+ NotInserted,
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct PacketRange {
+ largest: PacketNumber,
+ smallest: PacketNumber,
+ ack_needed: bool,
+}
+
+impl PacketRange {
+ /// Make a single packet range.
+ pub fn new(pn: PacketNumber) -> Self {
+ Self {
+ largest: pn,
+ smallest: pn,
+ ack_needed: true,
+ }
+ }
+
+ /// Get the number of acknowleged packets in the range.
+ pub fn len(&self) -> u64 {
+ self.largest - self.smallest + 1
+ }
+
+ /// Returns whether this needs to be sent.
+ pub fn ack_needed(&self) -> bool {
+ self.ack_needed
+ }
+
+ /// Return whether the given number is in the range.
+ pub fn contains(&self, pn: PacketNumber) -> bool {
+ (pn >= self.smallest) && (pn <= self.largest)
+ }
+
+ /// Maybe add a packet number to the range. Returns true if it was added
+ /// at the small end (which indicates that this might need merging with a
+ /// preceding range).
+ pub fn add(&mut self, pn: PacketNumber) -> InsertionResult {
+ assert!(!self.contains(pn));
+ // Only insert if this is adjacent the current range.
+ if (self.largest + 1) == pn {
+ qtrace!([self], "Adding largest {}", pn);
+ self.largest += 1;
+ self.ack_needed = true;
+ InsertionResult::Largest
+ } else if self.smallest == (pn + 1) {
+ qtrace!([self], "Adding smallest {}", pn);
+ self.smallest -= 1;
+ self.ack_needed = true;
+ InsertionResult::Smallest
+ } else {
+ InsertionResult::NotInserted
+ }
+ }
+
+ /// Maybe merge a higher-numbered range into this.
+ fn merge_larger(&mut self, other: &Self) {
+ qinfo!([self], "Merging {}", other);
+ // This only works if they are immediately adjacent.
+ assert_eq!(self.largest + 1, other.smallest);
+
+ self.largest = other.largest;
+ self.ack_needed = self.ack_needed || other.ack_needed;
+ }
+
+ /// When a packet containing the range `other` is acknowledged,
+ /// clear the `ack_needed` attribute on this.
+ /// Requires that other is equal to this, or a larger range.
+ pub fn acknowledged(&mut self, other: &Self) {
+ if (other.smallest <= self.smallest) && (other.largest >= self.largest) {
+ self.ack_needed = false;
+ }
+ }
+}
+
+impl ::std::fmt::Display for PacketRange {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}->{}", self.largest, self.smallest)
+ }
+}
+
+/// The ACK delay we use.
+pub const DEFAULT_ACK_DELAY: Duration = Duration::from_millis(20); // 20ms
+/// The default number of in-order packets we will receive after
+/// largest acknowledged without sending an immediate acknowledgment.
+pub const DEFAULT_ACK_PACKET_TOLERANCE: PacketNumber = 1;
+const MAX_TRACKED_RANGES: usize = 32;
+const MAX_ACKS_PER_FRAME: usize = 32;
+
+/// A structure that tracks what was included in an ACK.
+#[derive(Debug, Clone)]
+pub struct AckToken {
+ space: PacketNumberSpace,
+ ranges: Vec<PacketRange>,
+}
+
+/// A structure that tracks what packets have been received,
+/// and what needs acknowledgement for a packet number space.
+#[derive(Debug)]
+pub struct RecvdPackets {
+ space: PacketNumberSpace,
+ ranges: VecDeque<PacketRange>,
+ /// The packet number of the lowest number packet that we are tracking.
+ min_tracked: PacketNumber,
+ /// The time we got the largest acknowledged.
+ largest_pn_time: Option<Instant>,
+ /// The time that we should be sending an ACK.
+ ack_time: Option<Instant>,
+ /// The time we last sent an ACK.
+ last_ack_time: Option<Instant>,
+ /// The current ACK frequency sequence number.
+ ack_frequency_seqno: u64,
+ /// The time to delay after receiving the first packet that is
+ /// not immediately acknowledged.
+ ack_delay: Duration,
+ /// The number of ack-eliciting packets that have been received, but
+ /// not acknowledged.
+ unacknowledged_count: PacketNumber,
+ /// The number of contiguous packets that can be received without
+ /// acknowledging immediately.
+ unacknowledged_tolerance: PacketNumber,
+ /// Whether we are ignoring packets that arrive out of order
+ /// for the purposes of generating immediate acknowledgment.
+ ignore_order: bool,
+}
+
+impl RecvdPackets {
+ /// Make a new `RecvdPackets` for the indicated packet number space.
+ pub fn new(space: PacketNumberSpace) -> Self {
+ Self {
+ space,
+ ranges: VecDeque::new(),
+ min_tracked: 0,
+ largest_pn_time: None,
+ ack_time: None,
+ last_ack_time: None,
+ ack_frequency_seqno: 0,
+ ack_delay: DEFAULT_ACK_DELAY,
+ unacknowledged_count: 0,
+ unacknowledged_tolerance: DEFAULT_ACK_PACKET_TOLERANCE,
+ ignore_order: false,
+ }
+ }
+
+ /// Get the time at which the next ACK should be sent.
+ pub fn ack_time(&self) -> Option<Instant> {
+ self.ack_time
+ }
+
+ /// Update acknowledgment delay parameters.
+ pub fn ack_freq(
+ &mut self,
+ seqno: u64,
+ tolerance: PacketNumber,
+ delay: Duration,
+ ignore_order: bool,
+ ) {
+ // Yes, this means that we will overwrite values if a sequence number is
+ // reused, but that is better than using an `Option<PacketNumber>`
+ // when it will always be `Some`.
+ if seqno >= self.ack_frequency_seqno {
+ self.ack_frequency_seqno = seqno;
+ self.unacknowledged_tolerance = tolerance;
+ self.ack_delay = delay;
+ self.ignore_order = ignore_order;
+ }
+ }
+
+ /// Returns true if an ACK frame should be sent now.
+ fn ack_now(&self, now: Instant, rtt: Duration) -> bool {
+ // If ack_time is Some, then we have something to acknowledge.
+ // In that case, either ack because `now >= ack_time`, or
+ // because it is more than an RTT since the last time we sent an ack.
+ self.ack_time.map_or(false, |next| {
+ next <= now || self.last_ack_time.map_or(false, |last| last + rtt <= now)
+ })
+ }
+
+ // A simple addition of a packet number to the tracked set.
+ // This doesn't do a binary search on the assumption that
+ // new packets will generally be added to the start of the list.
+ fn add(&mut self, pn: PacketNumber) {
+ for i in 0..self.ranges.len() {
+ match self.ranges[i].add(pn) {
+ InsertionResult::Largest => return,
+ InsertionResult::Smallest => {
+ // If this was the smallest, it might have filled a gap.
+ let nxt = i + 1;
+ if (nxt < self.ranges.len()) && (pn - 1 == self.ranges[nxt].largest) {
+ let larger = self.ranges.remove(i).unwrap();
+ self.ranges[i].merge_larger(&larger);
+ }
+ return;
+ }
+ InsertionResult::NotInserted => {
+ if self.ranges[i].largest < pn {
+ self.ranges.insert(i, PacketRange::new(pn));
+ return;
+ }
+ }
+ }
+ }
+ self.ranges.push_back(PacketRange::new(pn));
+ }
+
+ fn trim_ranges(&mut self) {
+ // Limit the number of ranges that are tracked to MAX_TRACKED_RANGES.
+ if self.ranges.len() > MAX_TRACKED_RANGES {
+ let oldest = self.ranges.pop_back().unwrap();
+ if oldest.ack_needed {
+ qwarn!([self], "Dropping unacknowledged ACK range: {}", oldest);
+ // TODO(mt) Record some statistics about this so we can tune MAX_TRACKED_RANGES.
+ } else {
+ qdebug!([self], "Drop ACK range: {}", oldest);
+ }
+ self.min_tracked = oldest.largest + 1;
+ }
+ }
+
+ /// Add the packet to the tracked set.
+ /// Return true if the packet was the largest received so far.
+ pub fn set_received(&mut self, now: Instant, pn: PacketNumber, ack_eliciting: bool) -> bool {
+ let next_in_order_pn = self.ranges.front().map_or(0, |r| r.largest + 1);
+ qdebug!([self], "received {}, next: {}", pn, next_in_order_pn);
+
+ self.add(pn);
+ self.trim_ranges();
+
+ // The new addition was the largest, so update the time we use for calculating ACK delay.
+ let largest = if pn >= next_in_order_pn {
+ self.largest_pn_time = Some(now);
+ true
+ } else {
+ false
+ };
+
+ if ack_eliciting {
+ self.unacknowledged_count += 1;
+
+ let immediate_ack = self.space != PacketNumberSpace::ApplicationData
+ || (pn != next_in_order_pn && !self.ignore_order)
+ || self.unacknowledged_count > self.unacknowledged_tolerance;
+
+ let ack_time = if immediate_ack {
+ now
+ } else {
+ // Note that `ack_delay` can change and that won't take effect if
+ // we are waiting on the previous delay timer.
+ // If ACK delay increases, we might send an ACK a bit early;
+ // if ACK delay decreases, we might send an ACK a bit later.
+ // We could use min() here, but change is rare and the size
+ // of the change is very small.
+ self.ack_time.unwrap_or_else(|| now + self.ack_delay)
+ };
+ qdebug!([self], "Set ACK timer to {:?}", ack_time);
+ self.ack_time = Some(ack_time);
+ }
+ largest
+ }
+
+ /// If we just received a PING frame, we should immediately acknowledge.
+ pub fn immediate_ack(&mut self, now: Instant) {
+ self.ack_time = Some(now);
+ qdebug!([self], "immediate_ack at {:?}", now);
+ }
+
+ /// Check if the packet is a duplicate.
+ pub fn is_duplicate(&self, pn: PacketNumber) -> bool {
+ if pn < self.min_tracked {
+ return true;
+ }
+ self.ranges
+ .iter()
+ .take_while(|r| pn <= r.largest)
+ .any(|r| r.contains(pn))
+ }
+
+ /// Mark the given range as having been acknowledged.
+ pub fn acknowledged(&mut self, acked: &[PacketRange]) {
+ let mut range_iter = self.ranges.iter_mut();
+ let mut cur = range_iter.next().expect("should have at least one range");
+ for ack in acked {
+ while cur.smallest > ack.largest {
+ cur = match range_iter.next() {
+ Some(c) => c,
+ None => return,
+ };
+ }
+ cur.acknowledged(ack);
+ }
+ }
+
+ /// Generate an ACK frame for this packet number space.
+ ///
+ /// Unlike other frame generators this doesn't modify the underlying instance
+ /// to track what has been sent. This only clears the delayed ACK timer.
+ ///
+ /// When sending ACKs, we want to always send the most recent ranges,
+ /// even if they have been sent in other packets.
+ ///
+ /// We don't send ranges that have been acknowledged, but they still need
+ /// to be tracked so that duplicates can be detected.
+ fn write_frame(
+ &mut self,
+ now: Instant,
+ rtt: Duration,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ // The worst possible ACK frame, assuming only one range.
+ // Note that this assumes one byte for the type and count of extra ranges.
+ const LONGEST_ACK_HEADER: usize = 1 + 8 + 8 + 1 + 8;
+
+ // Check that we aren't delaying ACKs.
+ if !self.ack_now(now, rtt) {
+ return;
+ }
+
+ // Drop extra ACK ranges to fit the available space. Do this based on
+ // a worst-case estimate of frame size for simplicity.
+ //
+ // When congestion limited, ACK-only packets are 255 bytes at most
+ // (`recovery::ACK_ONLY_SIZE_LIMIT - 1`). This results in limiting the
+ // ranges to 13 here.
+ let max_ranges = if let Some(avail) = builder.remaining().checked_sub(LONGEST_ACK_HEADER) {
+ // Apply a hard maximum to keep plenty of space for other stuff.
+ min(1 + (avail / 16), MAX_ACKS_PER_FRAME)
+ } else {
+ return;
+ };
+
+ let ranges = self
+ .ranges
+ .iter()
+ .filter(|r| r.ack_needed())
+ .take(max_ranges)
+ .cloned()
+ .collect::<Vec<_>>();
+
+ builder.encode_varint(crate::frame::FRAME_TYPE_ACK);
+ let mut iter = ranges.iter();
+ let Some(first) = iter.next() else { return };
+ builder.encode_varint(first.largest);
+ stats.largest_acknowledged = first.largest;
+ stats.ack += 1;
+
+ let elapsed = now.duration_since(self.largest_pn_time.unwrap());
+ // We use the default exponent, so delay is in multiples of 8 microseconds.
+ let ack_delay = u64::try_from(elapsed.as_micros() / 8).unwrap_or(u64::MAX);
+ let ack_delay = min((1 << 62) - 1, ack_delay);
+ builder.encode_varint(ack_delay);
+ builder.encode_varint(u64::try_from(ranges.len() - 1).unwrap()); // extra ranges
+ builder.encode_varint(first.len() - 1); // first range
+
+ let mut last = first.smallest;
+ for r in iter {
+ // the difference must be at least 2 because 0-length gaps,
+ // (difference 1) are illegal.
+ builder.encode_varint(last - r.largest - 2); // Gap
+ builder.encode_varint(r.len() - 1); // Range
+ last = r.smallest;
+ }
+
+ // We've sent an ACK, reset the timer.
+ self.ack_time = None;
+ self.last_ack_time = Some(now);
+ self.unacknowledged_count = 0;
+
+ tokens.push(RecoveryToken::Ack(AckToken {
+ space: self.space,
+ ranges,
+ }));
+ }
+}
+
+impl ::std::fmt::Display for RecvdPackets {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Recvd-{}", self.space)
+ }
+}
+
+#[derive(Debug)]
+pub struct AckTracker {
+ /// This stores information about received packets in *reverse* order
+ /// by spaces. Why reverse? Because we ultimately only want to keep
+ /// `ApplicationData` and this allows us to drop other spaces easily.
+ spaces: SmallVec<[RecvdPackets; 1]>,
+}
+
+impl AckTracker {
+ pub fn drop_space(&mut self, space: PacketNumberSpace) {
+ let sp = match space {
+ PacketNumberSpace::Initial => self.spaces.pop(),
+ PacketNumberSpace::Handshake => {
+ let sp = self.spaces.pop();
+ self.spaces.shrink_to_fit();
+ sp
+ }
+ PacketNumberSpace::ApplicationData => panic!("discarding application space"),
+ };
+ assert_eq!(sp.unwrap().space, space, "dropping spaces out of order");
+ }
+
+ pub fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut RecvdPackets> {
+ self.spaces.get_mut(match space {
+ PacketNumberSpace::ApplicationData => 0,
+ PacketNumberSpace::Handshake => 1,
+ PacketNumberSpace::Initial => 2,
+ })
+ }
+
+ pub fn ack_freq(
+ &mut self,
+ seqno: u64,
+ tolerance: PacketNumber,
+ delay: Duration,
+ ignore_order: bool,
+ ) {
+ // Only ApplicationData ever delays ACK.
+ self.get_mut(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .ack_freq(seqno, tolerance, delay, ignore_order);
+ }
+
+ // Force an ACK to be generated immediately (a PING was received).
+ pub fn immediate_ack(&mut self, now: Instant) {
+ self.get_mut(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .immediate_ack(now);
+ }
+
+ /// Determine the earliest time that an ACK might be needed.
+ pub fn ack_time(&self, now: Instant) -> Option<Instant> {
+ for recvd in &self.spaces {
+ qtrace!("ack_time for {} = {:?}", recvd.space, recvd.ack_time());
+ }
+
+ if self.spaces.len() == 1 {
+ self.spaces[0].ack_time()
+ } else {
+ // Ignore any time that is in the past relative to `now`.
+ // That is something of a hack, but there are cases where we can't send ACK
+ // frames for all spaces, which can mean that one space is stuck in the past.
+ // That isn't a problem because we guarantee that earlier spaces will always
+ // be able to send ACK frames.
+ self.spaces
+ .iter()
+ .filter_map(|recvd| recvd.ack_time().filter(|t| *t > now))
+ .min()
+ }
+ }
+
+ pub fn acked(&mut self, token: &AckToken) {
+ if let Some(space) = self.get_mut(token.space) {
+ space.acknowledged(&token.ranges);
+ }
+ }
+
+ pub(crate) fn write_frame(
+ &mut self,
+ pn_space: PacketNumberSpace,
+ now: Instant,
+ rtt: Duration,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if let Some(space) = self.get_mut(pn_space) {
+ space.write_frame(now, rtt, builder, tokens, stats);
+ }
+ }
+}
+
+impl Default for AckTracker {
+ fn default() -> Self {
+ Self {
+ spaces: smallvec![
+ RecvdPackets::new(PacketNumberSpace::ApplicationData),
+ RecvdPackets::new(PacketNumberSpace::Handshake),
+ RecvdPackets::new(PacketNumberSpace::Initial),
+ ],
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::collections::HashSet;
+
+ use lazy_static::lazy_static;
+ use neqo_common::Encoder;
+
+ use super::{
+ AckTracker, Duration, Instant, PacketNumberSpace, PacketNumberSpaceSet, RecoveryToken,
+ RecvdPackets, MAX_TRACKED_RANGES,
+ };
+ use crate::{
+ frame::Frame,
+ packet::{PacketBuilder, PacketNumber},
+ stats::FrameStats,
+ };
+
+ const RTT: Duration = Duration::from_millis(100);
+ lazy_static! {
+ static ref NOW: Instant = Instant::now();
+ }
+
+ fn test_ack_range(pns: &[PacketNumber], nranges: usize) {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::Initial); // Any space will do.
+ let mut packets = HashSet::new();
+
+ for pn in pns {
+ rp.set_received(*NOW, *pn, true);
+ packets.insert(*pn);
+ }
+
+ assert_eq!(rp.ranges.len(), nranges);
+
+ // Check that all these packets will be detected as duplicates.
+ for pn in pns {
+ assert!(rp.is_duplicate(*pn));
+ }
+
+ // Check that the ranges decrease monotonically and don't overlap.
+ let mut iter = rp.ranges.iter();
+ let mut last = iter.next().expect("should have at least one");
+ for n in iter {
+ assert!(n.largest + 1 < last.smallest);
+ last = n;
+ }
+
+ // Check that the ranges include the right values.
+ let mut in_ranges = HashSet::new();
+ for range in &rp.ranges {
+ for included in range.smallest..=range.largest {
+ in_ranges.insert(included);
+ }
+ }
+ assert_eq!(packets, in_ranges);
+ }
+
+ #[test]
+ fn pn0() {
+ test_ack_range(&[0], 1);
+ }
+
+ #[test]
+ fn pn1() {
+ test_ack_range(&[1], 1);
+ }
+
+ #[test]
+ fn two_ranges() {
+ test_ack_range(&[0, 1, 2, 5, 6, 7], 2);
+ }
+
+ #[test]
+ fn fill_in_range() {
+ test_ack_range(&[0, 1, 2, 5, 6, 7, 3, 4], 1);
+ }
+
+ #[test]
+ fn too_many_ranges() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::Initial); // Any space will do.
+
+ // This will add one too many disjoint ranges.
+ for i in 0..=MAX_TRACKED_RANGES {
+ rp.set_received(*NOW, (i * 2) as u64, true);
+ }
+
+ assert_eq!(rp.ranges.len(), MAX_TRACKED_RANGES);
+ assert_eq!(rp.ranges.back().unwrap().largest, 2);
+
+ // Even though the range was dropped, we still consider it a duplicate.
+ assert!(rp.is_duplicate(0));
+ assert!(!rp.is_duplicate(1));
+ assert!(rp.is_duplicate(2));
+ }
+
+ #[test]
+ fn ack_delay() {
+ const COUNT: PacketNumber = 9;
+ const DELAY: Duration = Duration::from_millis(7);
+ // Only application data packets are delayed.
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ assert!(rp.ack_time().is_none());
+ assert!(!rp.ack_now(*NOW, RTT));
+
+ rp.ack_freq(0, COUNT, DELAY, false);
+
+ // Some packets won't cause an ACK to be needed.
+ for i in 0..COUNT {
+ rp.set_received(*NOW, i, true);
+ assert_eq!(Some(*NOW + DELAY), rp.ack_time());
+ assert!(!rp.ack_now(*NOW, RTT));
+ assert!(rp.ack_now(*NOW + DELAY, RTT));
+ }
+
+ // Exceeding COUNT will move the ACK time to now.
+ rp.set_received(*NOW, COUNT, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ assert!(rp.ack_now(*NOW, RTT));
+ }
+
+ #[test]
+ fn no_ack_delay() {
+ for space in &[PacketNumberSpace::Initial, PacketNumberSpace::Handshake] {
+ let mut rp = RecvdPackets::new(*space);
+ assert!(rp.ack_time().is_none());
+ assert!(!rp.ack_now(*NOW, RTT));
+
+ // Any packet in these spaces is acknowledged straight away.
+ rp.set_received(*NOW, 0, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ assert!(rp.ack_now(*NOW, RTT));
+ }
+ }
+
+ #[test]
+ fn ooo_no_ack_delay_new() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ assert!(rp.ack_time().is_none());
+ assert!(!rp.ack_now(*NOW, RTT));
+
+ // Anything other than packet 0 is acknowledged immediately.
+ rp.set_received(*NOW, 1, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ assert!(rp.ack_now(*NOW, RTT));
+ }
+
+ fn write_frame_at(rp: &mut RecvdPackets, now: Instant) {
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ let mut stats = FrameStats::default();
+ let mut tokens = Vec::new();
+ rp.write_frame(now, RTT, &mut builder, &mut tokens, &mut stats);
+ assert!(!tokens.is_empty());
+ assert_eq!(stats.ack, 1);
+ }
+
+ fn write_frame(rp: &mut RecvdPackets) {
+ write_frame_at(rp, *NOW);
+ }
+
+ #[test]
+ fn ooo_no_ack_delay_fill() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ rp.set_received(*NOW, 1, true);
+ write_frame(&mut rp);
+
+ // Filling in behind the largest acknowledged causes immediate ACK.
+ rp.set_received(*NOW, 0, true);
+ write_frame(&mut rp);
+
+ // Receiving the next packet won't elicit an ACK.
+ rp.set_received(*NOW, 2, true);
+ assert!(!rp.ack_now(*NOW, RTT));
+ }
+
+ #[test]
+ fn immediate_ack_after_rtt() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ rp.set_received(*NOW, 1, true);
+ write_frame(&mut rp);
+
+ // Filling in behind the largest acknowledged causes immediate ACK.
+ rp.set_received(*NOW, 0, true);
+ write_frame(&mut rp);
+
+ // A new packet ordinarily doesn't result in an ACK, but this time it does.
+ rp.set_received(*NOW + RTT, 2, true);
+ write_frame_at(&mut rp, *NOW + RTT);
+ }
+
+ #[test]
+ fn ooo_no_ack_delay_threshold_new() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+
+ // Set tolerance to 2 and then it takes three packets.
+ rp.ack_freq(0, 2, Duration::from_millis(10), true);
+
+ rp.set_received(*NOW, 1, true);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 2, true);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 3, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ }
+
+ #[test]
+ fn ooo_no_ack_delay_threshold_gap() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ rp.set_received(*NOW, 1, true);
+ write_frame(&mut rp);
+
+ // Set tolerance to 2 and then it takes three packets.
+ rp.ack_freq(0, 2, Duration::from_millis(10), true);
+
+ rp.set_received(*NOW, 3, true);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 4, true);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 5, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ }
+
+ /// Test that an in-order packet that is not ack-eliciting doesn't
+ /// increase the number of packets needed to cause an ACK.
+ #[test]
+ fn non_ack_eliciting_skip() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ rp.ack_freq(0, 1, Duration::from_millis(10), true);
+
+ // This should be ignored.
+ rp.set_received(*NOW, 0, false);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ // Skip 1 (it has no effect).
+ rp.set_received(*NOW, 2, true);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 3, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ }
+
+ /// If a packet that is not ack-eliciting is reordered, that's fine too.
+ #[test]
+ fn non_ack_eliciting_reorder() {
+ let mut rp = RecvdPackets::new(PacketNumberSpace::ApplicationData);
+ rp.ack_freq(0, 1, Duration::from_millis(10), false);
+
+ // These are out of order, but they are not ack-eliciting.
+ rp.set_received(*NOW, 1, false);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 0, false);
+ assert_ne!(Some(*NOW), rp.ack_time());
+
+ // These are in order.
+ rp.set_received(*NOW, 2, true);
+ assert_ne!(Some(*NOW), rp.ack_time());
+ rp.set_received(*NOW, 3, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ }
+
+ #[test]
+ fn aggregate_ack_time() {
+ const DELAY: Duration = Duration::from_millis(17);
+ let mut tracker = AckTracker::default();
+ tracker.ack_freq(0, 1, DELAY, false);
+ // This packet won't trigger an ACK.
+ tracker
+ .get_mut(PacketNumberSpace::Handshake)
+ .unwrap()
+ .set_received(*NOW, 0, false);
+ assert_eq!(None, tracker.ack_time(*NOW));
+
+ // This should be delayed.
+ tracker
+ .get_mut(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ assert_eq!(Some(*NOW + DELAY), tracker.ack_time(*NOW));
+
+ // This should move the time forward.
+ let later = *NOW + (DELAY / 2);
+ tracker
+ .get_mut(PacketNumberSpace::Initial)
+ .unwrap()
+ .set_received(later, 0, true);
+ assert_eq!(Some(later), tracker.ack_time(*NOW));
+ }
+
+ #[test]
+ #[should_panic(expected = "discarding application space")]
+ fn drop_app() {
+ let mut tracker = AckTracker::default();
+ tracker.drop_space(PacketNumberSpace::ApplicationData);
+ }
+
+ #[test]
+ #[should_panic(expected = "dropping spaces out of order")]
+ fn drop_out_of_order() {
+ let mut tracker = AckTracker::default();
+ tracker.drop_space(PacketNumberSpace::Handshake);
+ }
+
+ #[test]
+ fn drop_spaces() {
+ let mut tracker = AckTracker::default();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ tracker
+ .get_mut(PacketNumberSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ // The reference time for `ack_time` has to be in the past or we filter out the timer.
+ assert!(tracker
+ .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap())
+ .is_some());
+
+ let mut tokens = Vec::new();
+ let mut stats = FrameStats::default();
+ tracker.write_frame(
+ PacketNumberSpace::Initial,
+ *NOW,
+ RTT,
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.ack, 1);
+
+ // Mark another packet as received so we have cause to send another ACK in that space.
+ tracker
+ .get_mut(PacketNumberSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 1, true);
+ assert!(tracker
+ .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap())
+ .is_some());
+
+ // Now drop that space.
+ tracker.drop_space(PacketNumberSpace::Initial);
+
+ assert!(tracker.get_mut(PacketNumberSpace::Initial).is_none());
+ assert!(tracker
+ .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap())
+ .is_none());
+ tracker.write_frame(
+ PacketNumberSpace::Initial,
+ *NOW,
+ RTT,
+ &mut builder,
+ &mut tokens,
+ &mut stats,
+ );
+ assert_eq!(stats.ack, 1);
+ if let RecoveryToken::Ack(tok) = &tokens[0] {
+ tracker.acked(tok); // Should be a noop.
+ } else {
+ panic!("not an ACK token");
+ }
+ }
+
+ #[test]
+ fn no_room_for_ack() {
+ let mut tracker = AckTracker::default();
+ tracker
+ .get_mut(PacketNumberSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ assert!(tracker
+ .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap())
+ .is_some());
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ builder.set_limit(10);
+
+ let mut stats = FrameStats::default();
+ tracker.write_frame(
+ PacketNumberSpace::Initial,
+ *NOW,
+ RTT,
+ &mut builder,
+ &mut Vec::new(),
+ &mut stats,
+ );
+ assert_eq!(stats.ack, 0);
+ assert_eq!(builder.len(), 1); // Only the short packet header has been added.
+ }
+
+ #[test]
+ fn no_room_for_extra_range() {
+ let mut tracker = AckTracker::default();
+ tracker
+ .get_mut(PacketNumberSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ tracker
+ .get_mut(PacketNumberSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 2, true);
+ assert!(tracker
+ .ack_time(NOW.checked_sub(Duration::from_millis(1)).unwrap())
+ .is_some());
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, []);
+ builder.set_limit(32);
+
+ let mut stats = FrameStats::default();
+ tracker.write_frame(
+ PacketNumberSpace::Initial,
+ *NOW,
+ RTT,
+ &mut builder,
+ &mut Vec::new(),
+ &mut stats,
+ );
+ assert_eq!(stats.ack, 1);
+
+ let mut dec = builder.as_decoder();
+ _ = dec.decode_byte().unwrap(); // Skip the short header.
+ let frame = Frame::decode(&mut dec).unwrap();
+ if let Frame::Ack { ack_ranges, .. } = frame {
+ assert_eq!(ack_ranges.len(), 0);
+ } else {
+ panic!("not an ACK!");
+ }
+ }
+
+ #[test]
+ fn ack_time_elapsed() {
+ let mut tracker = AckTracker::default();
+
+ // While we have multiple PN spaces, we ignore ACK timers from the past.
+ // Send out of order to cause the delayed ack timer to be set to `*NOW`.
+ tracker
+ .get_mut(PacketNumberSpace::ApplicationData)
+ .unwrap()
+ .set_received(*NOW, 3, true);
+ assert!(tracker.ack_time(*NOW + Duration::from_millis(1)).is_none());
+
+ // When we are reduced to one space, that filter is off.
+ tracker.drop_space(PacketNumberSpace::Initial);
+ tracker.drop_space(PacketNumberSpace::Handshake);
+ assert_eq!(
+ tracker.ack_time(*NOW + Duration::from_millis(1)),
+ Some(*NOW)
+ );
+ }
+
+ #[test]
+ fn pnspaceset_default() {
+ let set = PacketNumberSpaceSet::default();
+ assert!(!set[PacketNumberSpace::Initial]);
+ assert!(!set[PacketNumberSpace::Handshake]);
+ assert!(!set[PacketNumberSpace::ApplicationData]);
+ }
+
+ #[test]
+ fn pnspaceset_from() {
+ let set = PacketNumberSpaceSet::from(&[PacketNumberSpace::Initial]);
+ assert!(set[PacketNumberSpace::Initial]);
+ assert!(!set[PacketNumberSpace::Handshake]);
+ assert!(!set[PacketNumberSpace::ApplicationData]);
+
+ let set =
+ PacketNumberSpaceSet::from(&[PacketNumberSpace::Handshake, PacketNumberSpace::Initial]);
+ assert!(set[PacketNumberSpace::Initial]);
+ assert!(set[PacketNumberSpace::Handshake]);
+ assert!(!set[PacketNumberSpace::ApplicationData]);
+
+ let set = PacketNumberSpaceSet::from(&[
+ PacketNumberSpace::ApplicationData,
+ PacketNumberSpace::ApplicationData,
+ ]);
+ assert!(!set[PacketNumberSpace::Initial]);
+ assert!(!set[PacketNumberSpace::Handshake]);
+ assert!(set[PacketNumberSpace::ApplicationData]);
+ }
+
+ #[test]
+ fn pnspaceset_copy() {
+ let set = PacketNumberSpaceSet::from(&[
+ PacketNumberSpace::Handshake,
+ PacketNumberSpace::ApplicationData,
+ ]);
+ let copy = set;
+ assert!(!copy[PacketNumberSpace::Initial]);
+ assert!(copy[PacketNumberSpace::Handshake]);
+ assert!(copy[PacketNumberSpace::ApplicationData]);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/version.rs b/third_party/rust/neqo-transport/src/version.rs
new file mode 100644
index 0000000000..13db0bf024
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/version.rs
@@ -0,0 +1,235 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::convert::TryFrom;
+
+use neqo_common::qdebug;
+
+use crate::{Error, Res};
+
+pub type WireVersion = u32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Version {
+ Version2,
+ Version1,
+ Draft29,
+ Draft30,
+ Draft31,
+ Draft32,
+}
+
+impl Version {
+ pub const fn wire_version(self) -> WireVersion {
+ match self {
+ Self::Version2 => 0x6b33_43cf,
+ Self::Version1 => 1,
+ Self::Draft29 => 0xff00_0000 + 29,
+ Self::Draft30 => 0xff00_0000 + 30,
+ Self::Draft31 => 0xff00_0000 + 31,
+ Self::Draft32 => 0xff00_0000 + 32,
+ }
+ }
+
+ pub(crate) fn initial_salt(self) -> &'static [u8] {
+ const INITIAL_SALT_V2: &[u8] = &[
+ 0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26,
+ 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9,
+ ];
+ const INITIAL_SALT_V1: &[u8] = &[
+ 0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
+ 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
+ ];
+ const INITIAL_SALT_29_32: &[u8] = &[
+ 0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61,
+ 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99,
+ ];
+ match self {
+ Self::Version2 => INITIAL_SALT_V2,
+ Self::Version1 => INITIAL_SALT_V1,
+ Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32 => INITIAL_SALT_29_32,
+ }
+ }
+
+ pub(crate) fn label_prefix(self) -> &'static str {
+ match self {
+ Self::Version2 => "quicv2 ",
+ Self::Version1 | Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32 => {
+ "quic "
+ }
+ }
+ }
+
+ pub(crate) fn retry_secret(self) -> &'static [u8] {
+ const RETRY_SECRET_V2: &[u8] = &[
+ 0xc4, 0xdd, 0x24, 0x84, 0xd6, 0x81, 0xae, 0xfa, 0x4f, 0xf4, 0xd6, 0x9c, 0x2c, 0x20,
+ 0x29, 0x99, 0x84, 0xa7, 0x65, 0xa5, 0xd3, 0xc3, 0x19, 0x82, 0xf3, 0x8f, 0xc7, 0x41,
+ 0x62, 0x15, 0x5e, 0x9f,
+ ];
+ const RETRY_SECRET_V1: &[u8] = &[
+ 0xd9, 0xc9, 0x94, 0x3e, 0x61, 0x01, 0xfd, 0x20, 0x00, 0x21, 0x50, 0x6b, 0xcc, 0x02,
+ 0x81, 0x4c, 0x73, 0x03, 0x0f, 0x25, 0xc7, 0x9d, 0x71, 0xce, 0x87, 0x6e, 0xca, 0x87,
+ 0x6e, 0x6f, 0xca, 0x8e,
+ ];
+ const RETRY_SECRET_29: &[u8] = &[
+ 0x8b, 0x0d, 0x37, 0xeb, 0x85, 0x35, 0x02, 0x2e, 0xbc, 0x8d, 0x76, 0xa2, 0x07, 0xd8,
+ 0x0d, 0xf2, 0x26, 0x46, 0xec, 0x06, 0xdc, 0x80, 0x96, 0x42, 0xc3, 0x0a, 0x8b, 0xaa,
+ 0x2b, 0xaa, 0xff, 0x4c,
+ ];
+ match self {
+ Self::Version2 => RETRY_SECRET_V2,
+ Self::Version1 => RETRY_SECRET_V1,
+ Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32 => RETRY_SECRET_29,
+ }
+ }
+
+ pub(crate) fn is_draft(self) -> bool {
+ matches!(
+ self,
+ Self::Draft29 | Self::Draft30 | Self::Draft31 | Self::Draft32,
+ )
+ }
+
+ /// Determine if `self` can be upgraded to `other` compatibly.
+ pub fn is_compatible(self, other: Self) -> bool {
+ self == other
+ || matches!(
+ (self, other),
+ (Self::Version1, Self::Version2) | (Self::Version2, Self::Version1)
+ )
+ }
+
+ pub fn all() -> Vec<Self> {
+ vec![
+ Self::Version2,
+ Self::Version1,
+ Self::Draft32,
+ Self::Draft31,
+ Self::Draft30,
+ Self::Draft29,
+ ]
+ }
+
+ pub fn compatible<'a>(
+ self,
+ all: impl IntoIterator<Item = &'a Self>,
+ ) -> impl Iterator<Item = &'a Self> {
+ all.into_iter().filter(move |&v| self.is_compatible(*v))
+ }
+}
+
+impl Default for Version {
+ fn default() -> Self {
+ Self::Version1
+ }
+}
+
+impl TryFrom<WireVersion> for Version {
+ type Error = Error;
+
+ fn try_from(wire: WireVersion) -> Res<Self> {
+ if wire == 1 {
+ Ok(Self::Version1)
+ } else if wire == 0x6b33_43cf {
+ Ok(Self::Version2)
+ } else if wire == 0xff00_0000 + 29 {
+ Ok(Self::Draft29)
+ } else if wire == 0xff00_0000 + 30 {
+ Ok(Self::Draft30)
+ } else if wire == 0xff00_0000 + 31 {
+ Ok(Self::Draft31)
+ } else if wire == 0xff00_0000 + 32 {
+ Ok(Self::Draft32)
+ } else {
+ Err(Error::VersionNegotiation)
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct VersionConfig {
+ /// The version that a client uses to establish a connection.
+ ///
+ /// For a client, this is the version that is sent out in an Initial packet.
+ /// A client that resumes will set this to the version from the original
+ /// connection.
+ /// A client that handles a Version Negotiation packet will be initialized with
+ /// a version chosen from the packet, but it will then have this value overridden
+ /// to match the original configuration so that the version negotiation can be
+ /// authenticated.
+ ///
+ /// For a server `Connection`, this is the only type of Initial packet that
+ /// can be accepted; the correct value is set by `Server`, see below.
+ ///
+ /// For a `Server`, this value is not used; if an Initial packet is received
+ /// in a supported version (as listed in `versions`), new instances of
+ /// `Connection` will be created with this value set to match what was received.
+ ///
+ /// An invariant here is that this version is always listed in `all`.
+ initial: Version,
+ /// The set of versions that are enabled, in preference order. For a server,
+ /// only the relative order of compatible versions matters.
+ all: Vec<Version>,
+}
+
+impl VersionConfig {
+ pub fn new(initial: Version, all: Vec<Version>) -> Self {
+ assert!(all.contains(&initial));
+ Self { initial, all }
+ }
+
+ pub fn initial(&self) -> Version {
+ self.initial
+ }
+
+ pub fn all(&self) -> &[Version] {
+ &self.all
+ }
+
+ /// Overwrite the initial value; used by the `Server` when handling new connections
+ /// and by the client on resumption.
+ pub(crate) fn set_initial(&mut self, initial: Version) {
+ qdebug!(
+ "Overwrite initial version {:?} ==> {:?}",
+ self.initial,
+ initial
+ );
+ assert!(self.all.contains(&initial));
+ self.initial = initial;
+ }
+
+ pub fn compatible(&self) -> impl Iterator<Item = &Version> {
+ self.initial.compatible(&self.all)
+ }
+
+ fn find_preferred<'a>(
+ preferences: impl IntoIterator<Item = &'a Version>,
+ vn: &[WireVersion],
+ ) -> Option<Version> {
+ for v in preferences {
+ if vn.contains(&v.wire_version()) {
+ return Some(*v);
+ }
+ }
+ None
+ }
+
+ /// Determine the preferred version based on a version negotiation packet.
+ pub(crate) fn preferred(&self, vn: &[WireVersion]) -> Option<Version> {
+ Self::find_preferred(&self.all, vn)
+ }
+
+ /// Determine the preferred version based on a set of compatible versions.
+ pub(crate) fn preferred_compatible(&self, vn: &[WireVersion]) -> Option<Version> {
+ Self::find_preferred(self.compatible(), vn)
+ }
+}
+
+impl Default for VersionConfig {
+ fn default() -> Self {
+ Self::new(Version::default(), Version::all())
+ }
+}