summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-transport/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/neqo-transport/src')
-rw-r--r--third_party/rust/neqo-transport/src/addr_valid.rs509
-rw-r--r--third_party/rust/neqo-transport/src/cc/classic_cc.rs955
-rw-r--r--third_party/rust/neqo-transport/src/cc/mod.rs55
-rw-r--r--third_party/rust/neqo-transport/src/cc/new_reno.rs44
-rw-r--r--third_party/rust/neqo-transport/src/cid.rs157
-rw-r--r--third_party/rust/neqo-transport/src/connection/idle.rs90
-rw-r--r--third_party/rust/neqo-transport/src/connection/mod.rs2768
-rw-r--r--third_party/rust/neqo-transport/src/connection/params.rs71
-rw-r--r--third_party/rust/neqo-transport/src/connection/saved.rs72
-rw-r--r--third_party/rust/neqo-transport/src/connection/state.rs207
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/cc.rs526
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/close.rs206
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/handshake.rs697
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/idle.rs274
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/keys.rs330
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/mod.rs305
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/recovery.rs636
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/resumption.rs182
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/stream.rs580
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/vn.rs201
-rw-r--r--third_party/rust/neqo-transport/src/connection/tests/zerortt.rs193
-rw-r--r--third_party/rust/neqo-transport/src/crypto.rs1293
-rw-r--r--third_party/rust/neqo-transport/src/dump.rs32
-rw-r--r--third_party/rust/neqo-transport/src/events.rs254
-rw-r--r--third_party/rust/neqo-transport/src/flow_mgr.rs400
-rw-r--r--third_party/rust/neqo-transport/src/frame.rs835
-rw-r--r--third_party/rust/neqo-transport/src/lib.rs195
-rw-r--r--third_party/rust/neqo-transport/src/pace.rs138
-rw-r--r--third_party/rust/neqo-transport/src/packet/mod.rs1339
-rw-r--r--third_party/rust/neqo-transport/src/packet/retry.rs63
-rw-r--r--third_party/rust/neqo-transport/src/path.rs109
-rw-r--r--third_party/rust/neqo-transport/src/qlog.rs442
-rw-r--r--third_party/rust/neqo-transport/src/recovery.rs1470
-rw-r--r--third_party/rust/neqo-transport/src/recv_stream.rs1110
-rw-r--r--third_party/rust/neqo-transport/src/send_stream.rs1746
-rw-r--r--third_party/rust/neqo-transport/src/sender.rs124
-rw-r--r--third_party/rust/neqo-transport/src/server.rs636
-rw-r--r--third_party/rust/neqo-transport/src/stats.rs195
-rw-r--r--third_party/rust/neqo-transport/src/stream_id.rs205
-rw-r--r--third_party/rust/neqo-transport/src/tparams.rs541
-rw-r--r--third_party/rust/neqo-transport/src/tracking.rs992
41 files changed, 21177 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/addr_valid.rs b/third_party/rust/neqo-transport/src/addr_valid.rs
new file mode 100644
index 0000000000..a8fcd76ab9
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/addr_valid.rs
@@ -0,0 +1,509 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// This file implements functions necessary for address validation.
+
+use neqo_common::{qinfo, qtrace, Decoder, Encoder, Role};
+use neqo_crypto::{
+ constants::{TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3},
+ selfencrypt::SelfEncrypt,
+};
+
+use crate::cid::ConnectionId;
+use crate::packet::PacketBuilder;
+use crate::recovery::RecoveryToken;
+use crate::stats::FrameStats;
+use crate::Res;
+
+use smallvec::SmallVec;
+use std::convert::TryFrom;
+use std::net::{IpAddr, SocketAddr};
+use std::time::{Duration, Instant};
+
+/// A prefix we add to Retry tokens to distinguish them from NEW_TOKEN tokens.
+const TOKEN_IDENTIFIER_RETRY: &[u8] = &[0x52, 0x65, 0x74, 0x72, 0x79];
+/// A prefix on NEW_TOKEN tokens, that is maximally Hamming distant from NEW_TOKEN.
+/// Together, these need to have a low probability of collision, even if there is
+/// corruption of individual bits in transit.
+const TOKEN_IDENTIFIER_NEW_TOKEN: &[u8] = &[0xad, 0x9a, 0x8b, 0x8d, 0x86];
+
+/// The maximum number of tokens we'll save from NEW_TOKEN frames.
+/// This should be the same as the value of MAX_TICKETS in neqo-crypto.
+const MAX_NEW_TOKEN: usize = 4;
+/// The number of tokens we'll track for the purposes of looking for duplicates.
+/// This is based on how many might be received over a period where could be
+/// retransmissions. It should be at least `MAX_NEW_TOKEN`.
+const MAX_SAVED_TOKENS: usize = 8;
+
+/// `ValidateAddress` determines what sort of address validation is performed.
+/// In short, this determines when a Retry packet is sent.
+#[derive(Debug, PartialEq, Eq)]
+pub enum ValidateAddress {
+ /// Require address validation never.
+ Never,
+ /// Require address validation unless a NEW_TOKEN token is provided.
+ NoToken,
+ /// Require address validation even if a NEW_TOKEN token is provided.
+ Always,
+}
+
+pub enum AddressValidationResult {
+ Pass,
+ ValidRetry(ConnectionId),
+ Validate,
+ Invalid,
+}
+
+pub struct AddressValidation {
+ /// What sort of validation is performed.
+ validation: ValidateAddress,
+ /// A self-encryption object used for protecting Retry tokens.
+ self_encrypt: SelfEncrypt,
+ /// When this object was created.
+ start_time: Instant,
+}
+
+impl AddressValidation {
+ pub fn new(now: Instant, validation: ValidateAddress) -> Res<Self> {
+ Ok(Self {
+ validation,
+ self_encrypt: SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256)?,
+ start_time: now,
+ })
+ }
+
+ fn encode_aad(peer_address: SocketAddr, retry: bool) -> Encoder {
+ // Let's be "clever" by putting the peer's address in the AAD.
+ // We don't need to encode these into the token as they should be
+ // available when we need to check the token.
+ let mut aad = Encoder::default();
+ if retry {
+ aad.encode(TOKEN_IDENTIFIER_RETRY);
+ } else {
+ aad.encode(TOKEN_IDENTIFIER_NEW_TOKEN);
+ }
+ match peer_address.ip() {
+ IpAddr::V4(a) => {
+ aad.encode_byte(4);
+ aad.encode(&a.octets());
+ }
+ IpAddr::V6(a) => {
+ aad.encode_byte(6);
+ aad.encode(&a.octets());
+ }
+ }
+ if retry {
+ aad.encode_uint(2, peer_address.port());
+ }
+ aad
+ }
+
+ pub fn generate_token(
+ &self,
+ dcid: Option<&ConnectionId>,
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> Res<Vec<u8>> {
+ const EXPIRATION_RETRY: Duration = Duration::from_secs(5);
+ const EXPIRATION_NEW_TOKEN: Duration = Duration::from_secs(60 * 60 * 24);
+
+ // TODO(mt) rotate keys on a fixed schedule.
+ let retry = dcid.is_some();
+ let mut data = Encoder::default();
+ let end = now
+ + if retry {
+ EXPIRATION_RETRY
+ } else {
+ EXPIRATION_NEW_TOKEN
+ };
+ let end_millis = u32::try_from(end.duration_since(self.start_time).as_millis())?;
+ data.encode_uint(4, end_millis);
+ if let Some(dcid) = dcid {
+ data.encode(dcid);
+ }
+
+ // Include the token identifier ("Retry"/~) in the AAD, then keep it for plaintext.
+ let mut buf = Self::encode_aad(peer_address, retry);
+ let encrypted = self.self_encrypt.seal(&buf, &data)?;
+ buf.truncate(TOKEN_IDENTIFIER_RETRY.len());
+ buf.encode(&encrypted);
+ Ok(buf.into())
+ }
+
+ /// This generates a token for use with Retry.
+ pub fn generate_retry_token(
+ &self,
+ dcid: &ConnectionId,
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> Res<Vec<u8>> {
+ self.generate_token(Some(dcid), peer_address, now)
+ }
+
+ /// This generates a token for use with NEW_TOKEN.
+ pub fn generate_new_token(&self, peer_address: SocketAddr, now: Instant) -> Res<Vec<u8>> {
+ self.generate_token(None, peer_address, now)
+ }
+
+ pub fn set_validation(&mut self, validation: ValidateAddress) {
+ qtrace!("AddressValidation {:p}: set to {:?}", self, validation);
+ self.validation = validation;
+ }
+
+ /// Decrypts `token` and returns the connection ID it contains.
+ /// Returns a tuple with a boolean indicating whether this thinks
+ /// that the token was a Retry token, and a connection ID, that is
+ /// None if the token wasn't successfully decrypted.
+ fn decrypt_token(
+ &self,
+ token: &[u8],
+ peer_address: SocketAddr,
+ retry: bool,
+ now: Instant,
+ ) -> Option<ConnectionId> {
+ let peer_addr = Self::encode_aad(peer_address, retry);
+ let data = if let Ok(d) = self.self_encrypt.open(&peer_addr, token) {
+ d
+ } else {
+ return None;
+ };
+ let mut dec = Decoder::new(&data);
+ match dec.decode_uint(4) {
+ Some(d) => {
+ let end = self.start_time + Duration::from_millis(d);
+ if end < now {
+ qtrace!("Expired token: {:?} vs. {:?}", end, now);
+ return None;
+ }
+ }
+ _ => return None,
+ }
+ Some(ConnectionId::from(dec.decode_remainder()))
+ }
+
+ /// Calculate the Hamming difference between our identifier and the target.
+ /// Less than one difference per byte indicates that it is likely not a Retry.
+ /// This generous interpretation allows for a lot of damage in transit.
+ /// Note that if this check fails, then the token will be treated like it came
+ /// from NEW_TOKEN instead. If there truly is corruption of packets that causes
+ /// validation failure, it will be a failure that we try to recover from.
+ fn is_likely_retry(token: &[u8]) -> bool {
+ let mut difference = 0;
+ for i in 0..TOKEN_IDENTIFIER_RETRY.len() {
+ difference += (token[i] ^ TOKEN_IDENTIFIER_RETRY[i]).count_ones();
+ }
+ usize::try_from(difference).unwrap() < TOKEN_IDENTIFIER_RETRY.len()
+ }
+
+ pub fn validate(
+ &self,
+ token: &[u8],
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> AddressValidationResult {
+ qtrace!(
+ "AddressValidation {:p}: validate {:?}",
+ self,
+ self.validation
+ );
+
+ if token.is_empty() {
+ if self.validation == ValidateAddress::Never {
+ qinfo!("AddressValidation: no token; accepting");
+ return AddressValidationResult::Pass;
+ } else {
+ qinfo!("AddressValidation: no token; validating");
+ return AddressValidationResult::Validate;
+ }
+ }
+ if token.len() <= TOKEN_IDENTIFIER_RETRY.len() {
+ // Treat bad tokens strictly.
+ qinfo!("AddressValidation: too short token");
+ return AddressValidationResult::Invalid;
+ }
+ let retry = Self::is_likely_retry(token);
+ let enc = &token[TOKEN_IDENTIFIER_RETRY.len()..];
+ // Note that this allows the token identifier part to be corrupted.
+ // That's OK here as we don't depend on that being authenticated.
+ if let Some(cid) = self.decrypt_token(enc, peer_address, retry, now) {
+ if retry {
+ // This is from Retry, so we should have an ODCID >= 8.
+ if cid.len() >= 8 {
+ qinfo!("AddressValidation: valid Retry token for {}", cid);
+ AddressValidationResult::ValidRetry(cid)
+ } else {
+ panic!("AddressValidation: Retry token with small CID {}", cid);
+ }
+ } else if cid.is_empty() {
+ // An empty connection ID means NEW_TOKEN.
+ if self.validation == ValidateAddress::Always {
+ qinfo!("AddressValidation: valid NEW_TOKEN token; validating again");
+ AddressValidationResult::Validate
+ } else {
+ qinfo!("AddressValidation: valid NEW_TOKEN token; accepting");
+ AddressValidationResult::Pass
+ }
+ } else {
+ panic!("AddressValidation: NEW_TOKEN token with CID {}", cid);
+ }
+ } else {
+ // From here on, we have a token that we couldn't decrypt.
+ // We've either lost the keys or we've received junk.
+ if retry {
+ // If this looked like a Retry, treat it as being bad.
+ qinfo!("AddressValidation: invalid Retry token; rejecting");
+ AddressValidationResult::Invalid
+ } else if self.validation == ValidateAddress::Never {
+ // We don't require validation, so OK.
+ qinfo!("AddressValidation: invalid NEW_TOKEN token; accepting");
+ AddressValidationResult::Pass
+ } else {
+ // This might be an invalid NEW_TOKEN token, or a valid one
+ // for which we have since lost the keys. Check again.
+ qinfo!("AddressValidation: invalid NEW_TOKEN token; validating again");
+ AddressValidationResult::Validate
+ }
+ }
+ }
+}
+
+// Note: these lint override can be removed in later versions where the lints
+// either don't trip a false positive or don't apply. rustc 1.46 is fine.
+#[allow(dead_code, clippy::large_enum_variant)]
+pub enum NewTokenState {
+ Client {
+ /// Tokens that haven't been taken yet.
+ pending: SmallVec<[Vec<u8>; MAX_NEW_TOKEN]>,
+ /// Tokens that have been taken, saved so that we can discard duplicates.
+ old: SmallVec<[Vec<u8>; MAX_SAVED_TOKENS]>,
+ },
+ Server(NewTokenSender),
+}
+
+impl NewTokenState {
+ pub fn new(role: Role) -> Self {
+ match role {
+ Role::Client => Self::Client {
+ pending: SmallVec::<[_; MAX_NEW_TOKEN]>::new(),
+ old: SmallVec::<[_; MAX_SAVED_TOKENS]>::new(),
+ },
+ Role::Server => Self::Server(NewTokenSender::default()),
+ }
+ }
+
+ /// Is there a token available?
+ pub fn has_token(&self) -> bool {
+ match self {
+ Self::Client { ref pending, .. } => !pending.is_empty(),
+ Self::Server(..) => false,
+ }
+ }
+
+ /// If this is a client, take a token if there is one.
+ /// If this is a server, panic.
+ pub fn take_token(&mut self) -> Option<&[u8]> {
+ if let Self::Client {
+ ref mut pending,
+ ref mut old,
+ } = self
+ {
+ if let Some(t) = pending.pop() {
+ if old.len() >= MAX_SAVED_TOKENS {
+ old.remove(0);
+ }
+ old.push(t);
+ Some(&old[old.len() - 1])
+ } else {
+ None
+ }
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this is a client, save a token.
+ /// If this is a server, panic.
+ pub fn save_token(&mut self, token: Vec<u8>) {
+ if let Self::Client {
+ ref mut pending,
+ ref old,
+ } = self
+ {
+ for t in old.iter().rev().chain(pending.iter().rev()) {
+ if t == &token {
+ qinfo!("NewTokenState discarding duplicate NEW_TOKEN");
+ return;
+ }
+ }
+
+ if pending.len() >= MAX_NEW_TOKEN {
+ pending.remove(0);
+ }
+ pending.push(token);
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this is a server, maybe send a frame.
+ /// If this is a client, do nothing.
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ if let Self::Server(ref mut sender) = self {
+ sender.write_frames(builder, tokens, stats);
+ }
+ }
+
+ /// If this a server, buffer a NEW_TOKEN for sending.
+ /// If this is a client, panic.
+ pub fn send_new_token(&mut self, token: Vec<u8>) {
+ if let Self::Server(ref mut sender) = self {
+ sender.send_new_token(token);
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this a server, process a lost signal for a NEW_TOKEN frame.
+ /// If this is a client, panic.
+ pub fn lost(&mut self, seqno: usize) {
+ if let Self::Server(ref mut sender) = self {
+ sender.lost(seqno);
+ } else {
+ unreachable!();
+ }
+ }
+
+ /// If this a server, process remove the acknowledged NEW_TOKEN frame.
+ /// If this is a client, panic.
+ pub fn acked(&mut self, seqno: usize) {
+ if let Self::Server(ref mut sender) = self {
+ sender.acked(seqno);
+ } else {
+ unreachable!();
+ }
+ }
+}
+
+struct NewTokenFrameStatus {
+ seqno: usize,
+ token: Vec<u8>,
+ needs_sending: bool,
+}
+
+impl NewTokenFrameStatus {
+ fn len(&self) -> usize {
+ 1 + Encoder::vvec_len(self.token.len())
+ }
+}
+
+#[derive(Default)]
+pub struct NewTokenSender {
+ /// The unacknowledged NEW_TOKEN frames we are yet to send.
+ tokens: Vec<NewTokenFrameStatus>,
+ /// A sequence number that is used to track individual tokens
+ /// by reference (so that recovery tokens can be simple).
+ next_seqno: usize,
+}
+
+impl NewTokenSender {
+ /// Add a token to be sent.
+ pub fn send_new_token(&mut self, token: Vec<u8>) {
+ self.tokens.push(NewTokenFrameStatus {
+ seqno: self.next_seqno,
+ token,
+ needs_sending: true,
+ });
+ self.next_seqno += 1;
+ }
+
+ pub fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ for t in self.tokens.iter_mut() {
+ if t.needs_sending && t.len() <= builder.remaining() {
+ t.needs_sending = false;
+
+ builder.encode_varint(crate::frame::FRAME_TYPE_NEW_TOKEN);
+ builder.encode_vvec(&t.token);
+
+ tokens.push(RecoveryToken::NewToken(t.seqno));
+ stats.new_token += 1;
+ }
+ }
+ }
+
+ pub fn lost(&mut self, seqno: usize) {
+ for t in self.tokens.iter_mut() {
+ if t.seqno == seqno {
+ t.needs_sending = true;
+ break;
+ }
+ }
+ }
+
+ pub fn acked(&mut self, seqno: usize) {
+ self.tokens.retain(|i| i.seqno != seqno);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::NewTokenState;
+ use neqo_common::Role;
+
+ const ONE: &[u8] = &[1, 2, 3];
+ const TWO: &[u8] = &[4, 5];
+
+ #[test]
+ fn duplicate_saved() {
+ let mut tokens = NewTokenState::new(Role::Client);
+ tokens.save_token(ONE.to_vec());
+ tokens.save_token(TWO.to_vec());
+ tokens.save_token(ONE.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably TWO
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably ONE
+ assert!(!tokens.has_token());
+ assert!(tokens.take_token().is_none());
+ }
+
+ #[test]
+ fn duplicate_after_take() {
+ let mut tokens = NewTokenState::new(Role::Client);
+ tokens.save_token(ONE.to_vec());
+ tokens.save_token(TWO.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably TWO
+ tokens.save_token(ONE.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably ONE
+ assert!(!tokens.has_token());
+ assert!(tokens.take_token().is_none());
+ }
+
+ #[test]
+ fn duplicate_after_empty() {
+ let mut tokens = NewTokenState::new(Role::Client);
+ tokens.save_token(ONE.to_vec());
+ tokens.save_token(TWO.to_vec());
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably TWO
+ assert!(tokens.has_token());
+ assert!(tokens.take_token().is_some()); // probably ONE
+ tokens.save_token(ONE.to_vec());
+ assert!(!tokens.has_token());
+ assert!(tokens.take_token().is_none());
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/cc/classic_cc.rs b/third_party/rust/neqo-transport/src/cc/classic_cc.rs
new file mode 100644
index 0000000000..b41969c680
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/classic_cc.rs
@@ -0,0 +1,955 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use std::cmp::{max, min};
+use std::fmt::{self, Debug, Display};
+use std::time::{Duration, Instant};
+
+use super::CongestionControl;
+
+use crate::cc::MAX_DATAGRAM_SIZE;
+use crate::qlog::{self, QlogMetric};
+use crate::sender::PACING_BURST_SIZE;
+use crate::tracking::SentPacket;
+use neqo_common::{const_max, const_min, qdebug, qinfo, qlog::NeqoQlog, qtrace};
+
+pub const CWND_INITIAL_PKTS: usize = 10;
+const CWND_INITIAL: usize = const_min(
+ CWND_INITIAL_PKTS * MAX_DATAGRAM_SIZE,
+ const_max(2 * MAX_DATAGRAM_SIZE, 14720),
+);
+pub const CWND_MIN: usize = MAX_DATAGRAM_SIZE * 2;
+const PERSISTENT_CONG_THRESH: u32 = 3;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum State {
+ /// In either slow start or congestion avoidance, not recovery.
+ SlowStart,
+ /// In congestion avoidance.
+ CongestionAvoidance,
+ /// In a recovery period, but no packets have been sent yet. This is a
+ /// transient state because we want to exempt the first packet sent after
+ /// entering recovery from the congestion window.
+ RecoveryStart,
+ /// In a recovery period, with the first packet sent at this time.
+ Recovery,
+ /// Start of persistent congestion, which is transient, like `RecoveryStart`.
+ PersistentCongestion,
+}
+
+impl State {
+ pub fn in_recovery(self) -> bool {
+ matches!(self, Self::RecoveryStart | Self::Recovery)
+ }
+
+ pub fn in_slow_start(self) -> bool {
+ self == Self::SlowStart
+ }
+
+ /// These states are transient, we tell qlog on entry, but not on exit.
+ pub fn transient(self) -> bool {
+ matches!(self, Self::RecoveryStart | Self::PersistentCongestion)
+ }
+
+ /// Update a transient state to the true state.
+ pub fn update(&mut self) {
+ *self = match self {
+ Self::PersistentCongestion => Self::SlowStart,
+ Self::RecoveryStart => Self::Recovery,
+ _ => unreachable!(),
+ };
+ }
+
+ pub fn to_qlog(self) -> &'static str {
+ match self {
+ Self::SlowStart | Self::PersistentCongestion => "slow_start",
+ Self::CongestionAvoidance => "congestion_avoidance",
+ Self::Recovery | Self::RecoveryStart => "recovery",
+ }
+ }
+}
+
+pub trait WindowAdjustment: Display + Debug {
+ fn on_packets_acked(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize);
+ fn on_congestion_event(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize);
+}
+
+#[derive(Debug)]
+pub struct ClassicCongestionControl<T> {
+ cc_algorithm: T,
+ state: State,
+ congestion_window: usize, // = kInitialWindow
+ bytes_in_flight: usize,
+ acked_bytes: usize,
+ ssthresh: usize,
+ recovery_start: Option<Instant>,
+
+ qlog: NeqoQlog,
+}
+
+impl<T: WindowAdjustment> Display for ClassicCongestionControl<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{} CongCtrl {}/{} ssthresh {}",
+ self.cc_algorithm, self.bytes_in_flight, self.congestion_window, self.ssthresh,
+ )?;
+ Ok(())
+ }
+}
+
+impl<T: WindowAdjustment> CongestionControl for ClassicCongestionControl<T> {
+ fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.qlog = qlog;
+ }
+
+ #[must_use]
+ fn cwnd(&self) -> usize {
+ self.congestion_window
+ }
+
+ #[must_use]
+ fn bytes_in_flight(&self) -> usize {
+ self.bytes_in_flight
+ }
+
+ #[must_use]
+ fn cwnd_avail(&self) -> usize {
+ // BIF can be higher than cwnd due to PTO packets, which are sent even
+ // if avail is 0, but still count towards BIF.
+ self.congestion_window.saturating_sub(self.bytes_in_flight)
+ }
+
+ // Multi-packet version of OnPacketAckedCC
+ fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]) {
+ // Check whether we are app limited before acked packets are removed
+ // from bytes_in_flight.
+ let is_app_limited = self.app_limited();
+ qtrace!(
+ [self],
+ "app limited={}, bytes_in_flight:{}, cwnd: {}, state: {:?} pacing_burst_size: {}",
+ is_app_limited,
+ self.bytes_in_flight,
+ self.congestion_window,
+ self.state,
+ MAX_DATAGRAM_SIZE * PACING_BURST_SIZE,
+ );
+
+ let mut acked_bytes = 0;
+ for pkt in acked_pkts.iter().filter(|pkt| pkt.cc_outstanding()) {
+ assert!(self.bytes_in_flight >= pkt.size);
+ self.bytes_in_flight -= pkt.size;
+
+ if !self.after_recovery_start(pkt) {
+ // Do not increase congestion window for packets sent before
+ // recovery last started.
+ continue;
+ }
+
+ if self.state.in_recovery() {
+ self.set_state(State::CongestionAvoidance);
+ qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]);
+ }
+
+ acked_bytes += pkt.size;
+ }
+
+ if !is_app_limited {
+ self.acked_bytes += acked_bytes;
+ }
+
+ qtrace!([self], "ACK received, acked_bytes = {}", self.acked_bytes);
+
+ // Slow start, up to the slow start threshold.
+ if self.congestion_window < self.ssthresh {
+ let increase = min(self.ssthresh - self.congestion_window, self.acked_bytes);
+ self.congestion_window += increase;
+ self.acked_bytes -= increase;
+ qinfo!([self], "slow start += {}", increase);
+ if self.congestion_window == self.ssthresh {
+ // This doesn't look like it is necessary, but it can happen
+ // after persistent congestion.
+ self.set_state(State::CongestionAvoidance);
+ }
+ }
+ // Congestion avoidance, above the slow start threshold.
+ if self.congestion_window >= self.ssthresh {
+ let (cwnd, acked_bytes) = self
+ .cc_algorithm
+ .on_packets_acked(self.congestion_window, self.acked_bytes);
+ self.congestion_window = cwnd;
+ self.acked_bytes = acked_bytes;
+ }
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[
+ QlogMetric::CongestionWindow(self.congestion_window),
+ QlogMetric::BytesInFlight(self.bytes_in_flight),
+ ],
+ );
+ }
+
+ /// Update congestion controller state based on lost packets.
+ fn on_packets_lost(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) {
+ if lost_packets.is_empty() {
+ return;
+ }
+
+ for pkt in lost_packets.iter().filter(|pkt| pkt.ack_eliciting()) {
+ assert!(self.bytes_in_flight >= pkt.size);
+ self.bytes_in_flight -= pkt.size;
+ }
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+
+ qdebug!([self], "Pkts lost {}", lost_packets.len());
+
+ self.on_congestion_event(lost_packets.last().unwrap());
+ self.detect_persistent_congestion(
+ first_rtt_sample_time,
+ prev_largest_acked_sent,
+ pto,
+ lost_packets,
+ );
+ }
+
+ fn discard(&mut self, pkt: &SentPacket) {
+ if pkt.cc_outstanding() {
+ assert!(self.bytes_in_flight >= pkt.size);
+ self.bytes_in_flight -= pkt.size;
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+ qtrace!([self], "Ignore pkt with size {}", pkt.size);
+ }
+ }
+
+ fn on_packet_sent(&mut self, pkt: &SentPacket) {
+ // Record the recovery time and exit any transient state.
+ if self.state.transient() {
+ self.recovery_start = Some(pkt.time_sent);
+ self.state.update();
+ }
+
+ if !pkt.ack_eliciting() {
+ return;
+ }
+
+ self.bytes_in_flight += pkt.size;
+ qdebug!(
+ [self],
+ "Pkt Sent len {}, bif {}, cwnd {}",
+ pkt.size,
+ self.bytes_in_flight,
+ self.congestion_window
+ );
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::BytesInFlight(self.bytes_in_flight)],
+ );
+ }
+
+ /// Whether a packet can be sent immediately as a result of entering recovery.
+ #[must_use]
+ fn recovery_packet(&self) -> bool {
+ self.state == State::RecoveryStart
+ }
+}
+
+impl<T: WindowAdjustment> ClassicCongestionControl<T> {
+ pub fn new(cc_algorithm: T) -> Self {
+ Self {
+ cc_algorithm,
+ state: State::SlowStart,
+ congestion_window: CWND_INITIAL,
+ bytes_in_flight: 0,
+ acked_bytes: 0,
+ ssthresh: usize::MAX,
+ recovery_start: None,
+ qlog: NeqoQlog::disabled(),
+ }
+ }
+
+ #[cfg(test)]
+ #[must_use]
+ pub fn ssthresh(&self) -> usize {
+ self.ssthresh
+ }
+
+ fn set_state(&mut self, state: State) {
+ if self.state != state {
+ qdebug!([self], "state -> {:?}", state);
+ let old_state = self.state;
+ self.qlog.add_event(|| {
+ // No need to tell qlog about exit from transient states.
+ if old_state.transient() {
+ None
+ } else {
+ Some(::qlog::event::Event::congestion_state_updated(
+ Some(old_state.to_qlog().to_owned()),
+ state.to_qlog().to_owned(),
+ ))
+ }
+ });
+ self.state = state;
+ }
+ }
+
+ fn detect_persistent_congestion(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) {
+ if first_rtt_sample_time.is_none() {
+ return;
+ }
+
+ let pc_period = pto * PERSISTENT_CONG_THRESH;
+
+ let mut last_pn = 1 << 62; // Impossibly large, but not enough to overflow.
+ let mut start = None;
+
+ // Look for the first lost packet after the previous largest acknowledged.
+ // Ignore packets that weren't ack-eliciting for the start of this range.
+ // Also, make sure to ignore any packets sent before we got an RTT estimate
+ // as we might not have sent PTO packets soon enough after those.
+ let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent);
+ for p in lost_packets
+ .iter()
+ .skip_while(|p| Some(p.time_sent) < cutoff)
+ {
+ if p.pn != last_pn + 1 {
+ // Not a contiguous range of lost packets, start over.
+ start = None;
+ }
+ last_pn = p.pn;
+ if !p.ack_eliciting() {
+ // Not interesting, keep looking.
+ continue;
+ }
+ if let Some(t) = start {
+ if p.time_sent.duration_since(t) > pc_period {
+ qinfo!([self], "persistent congestion");
+ self.congestion_window = CWND_MIN;
+ self.acked_bytes = 0;
+ self.set_state(State::PersistentCongestion);
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::CongestionWindow(self.congestion_window)],
+ );
+ return;
+ }
+ } else {
+ start = Some(p.time_sent);
+ }
+ }
+ }
+
+ #[must_use]
+ fn after_recovery_start(&mut self, packet: &SentPacket) -> bool {
+ // At the start of the first recovery period, if the state is
+ // transient, all packets will have been sent before recovery.
+ self.recovery_start
+ .map_or(!self.state.transient(), |t| packet.time_sent >= t)
+ }
+
+ /// Handle a congestion event.
+ fn on_congestion_event(&mut self, last_packet: &SentPacket) {
+ // Start a new congestion event if lost packet was sent after the start
+ // of the previous congestion recovery period.
+ if self.after_recovery_start(last_packet) {
+ let (cwnd, acked_bytes) = self
+ .cc_algorithm
+ .on_congestion_event(self.congestion_window, self.acked_bytes);
+ self.congestion_window = max(cwnd, CWND_MIN);
+ self.acked_bytes = acked_bytes;
+ self.ssthresh = self.congestion_window;
+ qinfo!(
+ [self],
+ "Cong event -> recovery; cwnd {}, ssthresh {}",
+ self.congestion_window,
+ self.ssthresh
+ );
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[
+ QlogMetric::CongestionWindow(self.congestion_window),
+ QlogMetric::SsThresh(self.ssthresh),
+ QlogMetric::InRecovery(true),
+ ],
+ );
+ self.set_state(State::RecoveryStart);
+ }
+ }
+
+ #[allow(clippy::unused_self)]
+ fn app_limited(&self) -> bool {
+ if self.bytes_in_flight >= self.congestion_window {
+ false
+ } else if self.state.in_slow_start() {
+ // Allow for potential doubling of the congestion window during slow start.
+ // That is, the application might not have been able to send enough to respond
+ // to increases to the congestion window.
+ self.bytes_in_flight < self.congestion_window / 2
+ } else {
+ // We're not limited if the in-flight data is within a single burst of the
+ // congestion window.
+ (self.bytes_in_flight + MAX_DATAGRAM_SIZE * PACING_BURST_SIZE) < self.congestion_window
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{ClassicCongestionControl, CWND_INITIAL, CWND_MIN, PERSISTENT_CONG_THRESH};
+ use crate::cc::new_reno::NewReno;
+ use crate::cc::{CongestionControl, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE};
+ use crate::packet::{PacketNumber, PacketType};
+ use crate::tracking::SentPacket;
+ use std::convert::TryFrom;
+ use std::time::{Duration, Instant};
+ use test_fixture::now;
+
+ const PTO: Duration = Duration::from_millis(100);
+ const RTT: Duration = Duration::from_millis(98);
+ const ZERO: Duration = Duration::from_secs(0);
+ const EPSILON: Duration = Duration::from_nanos(1);
+ const GAP: Duration = Duration::from_secs(1);
+ /// The largest time between packets without causing persistent congestion.
+ const SUB_PC: Duration = Duration::from_millis(100 * PERSISTENT_CONG_THRESH as u64);
+ /// The minimum time between packets to cause persistent congestion.
+ /// Uses an odd expression because `Duration` arithmetic isn't `const`.
+ const PC: Duration = Duration::from_nanos(100_000_000 * (PERSISTENT_CONG_THRESH as u64) + 1);
+
+ fn cwnd_is_default(cc: &ClassicCongestionControl<NewReno>) {
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+ assert_eq!(cc.ssthresh(), usize::MAX);
+ }
+
+ fn cwnd_is_halved(cc: &ClassicCongestionControl<NewReno>) {
+ assert_eq!(cc.cwnd(), CWND_INITIAL / 2);
+ assert_eq!(cc.ssthresh(), CWND_INITIAL / 2);
+ }
+
+ #[test]
+ fn issue_876() {
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ let time_now = now();
+ let time_before = time_now - Duration::from_millis(100);
+ let time_after = time_now + Duration::from_millis(150);
+
+ let sent_packets = &[
+ SentPacket::new(
+ PacketType::Short,
+ 1, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE - 1, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 2, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE - 2, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 3, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 4, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 5, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 6, // pn
+ time_before, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ ),
+ SentPacket::new(
+ PacketType::Short,
+ 7, // pn
+ time_after, // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE - 3, // size
+ ),
+ ];
+
+ // Send some more packets so that the cc is not app-limited.
+ for p in &sent_packets[..6] {
+ cc.on_packet_sent(p);
+ }
+ assert_eq!(cc.acked_bytes, 0);
+ cwnd_is_default(&cc);
+ assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 3);
+
+ cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[0..1]);
+
+ // We are now in recovery
+ assert!(cc.recovery_packet());
+ assert_eq!(cc.acked_bytes, 0);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2);
+
+ // Send a packet after recovery starts
+ cc.on_packet_sent(&sent_packets[6]);
+ assert!(!cc.recovery_packet());
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.acked_bytes, 0);
+ assert_eq!(cc.bytes_in_flight(), 6 * MAX_DATAGRAM_SIZE - 5);
+
+ // and ack it. cwnd increases slightly
+ cc.on_packets_acked(&sent_packets[6..]);
+ assert_eq!(cc.acked_bytes, sent_packets[6].size);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2);
+
+ // Packet from before is lost. Should not hurt cwnd.
+ cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]);
+ assert!(!cc.recovery_packet());
+ assert_eq!(cc.acked_bytes, sent_packets[6].size);
+ cwnd_is_halved(&cc);
+ assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE);
+ }
+
+ fn lost(pn: PacketNumber, ack_eliciting: bool, t: Duration) -> SentPacket {
+ SentPacket::new(
+ PacketType::Short,
+ pn,
+ now() + t,
+ ack_eliciting,
+ Vec::new(),
+ 100,
+ )
+ }
+
+ fn persistent_congestion(lost_packets: &[SentPacket]) -> bool {
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ for p in lost_packets {
+ cc.on_packet_sent(p);
+ }
+
+ cc.on_packets_lost(Some(now()), None, PTO, lost_packets);
+ if cc.cwnd() == CWND_INITIAL / 2 {
+ false
+ } else if cc.cwnd() == CWND_MIN {
+ true
+ } else {
+ panic!("unexpected cwnd");
+ }
+ }
+
+ /// A span of exactly the PC threshold only reduces the window on loss.
+ #[test]
+ fn persistent_congestion_none() {
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, SUB_PC),
+ ]));
+ }
+
+ /// A span of just more than the PC threshold causes persistent congestion.
+ #[test]
+ fn persistent_congestion_simple() {
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PC),
+ ]));
+ }
+
+ /// Both packets need to be ack-eliciting.
+ #[test]
+ fn persistent_congestion_non_ack_eliciting() {
+ assert!(!persistent_congestion(&[
+ lost(1, false, ZERO),
+ lost(2, true, PC),
+ ]));
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, false, PC),
+ ]));
+ }
+
+ /// Packets in the middle, of any type, are OK.
+ #[test]
+ fn persistent_congestion_middle() {
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, false, RTT),
+ lost(3, true, PC),
+ ]));
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, RTT),
+ lost(3, true, PC),
+ ]));
+ }
+
+ /// Leading non-ack-eliciting packets are skipped.
+ #[test]
+ fn persistent_congestion_leading_non_ack_eliciting() {
+ assert!(!persistent_congestion(&[
+ lost(1, false, ZERO),
+ lost(2, true, RTT),
+ lost(3, true, PC),
+ ]));
+ assert!(persistent_congestion(&[
+ lost(1, false, ZERO),
+ lost(2, true, RTT),
+ lost(3, true, RTT + PC),
+ ]));
+ }
+
+ /// Trailing non-ack-eliciting packets aren't relevant.
+ #[test]
+ fn persistent_congestion_trailing_non_ack_eliciting() {
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PC),
+ lost(3, false, PC + EPSILON),
+ ]));
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, SUB_PC),
+ lost(3, false, PC),
+ ]));
+ }
+
+ /// Gaps in the middle, of any type, restart the count.
+ #[test]
+ fn persistent_congestion_gap_reset() {
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(3, true, PC),
+ ]));
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, RTT),
+ lost(4, true, GAP),
+ lost(5, true, GAP + PTO * PERSISTENT_CONG_THRESH),
+ ]));
+ }
+
+ /// A span either side of a gap will cause persistent congestion.
+ #[test]
+ fn persistent_congestion_gap_or() {
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PC),
+ lost(4, true, GAP),
+ lost(5, true, GAP + PTO),
+ ]));
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, true, GAP),
+ lost(5, true, GAP + PC),
+ ]));
+ }
+
+ /// A gap only restarts after an ack-eliciting packet.
+ #[test]
+ fn persistent_congestion_gap_non_ack_eliciting() {
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, false, GAP),
+ lost(5, true, GAP + PC),
+ ]));
+ assert!(!persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, false, GAP),
+ lost(5, true, GAP + RTT),
+ lost(6, true, GAP + RTT + SUB_PC),
+ ]));
+ assert!(persistent_congestion(&[
+ lost(1, true, ZERO),
+ lost(2, true, PTO),
+ lost(4, false, GAP),
+ lost(5, true, GAP + RTT),
+ lost(6, true, GAP + RTT + PC),
+ ]));
+ }
+
+ /// Get a time, in multiples of `PTO`, relative to `now()`.
+ fn by_pto(t: u32) -> Instant {
+ now() + (PTO * t)
+ }
+
+ /// Make packets that will be made lost.
+ /// `times` is the time of sending, in multiples of `PTO`, relative to `now()`.
+ fn make_lost(times: &[u32]) -> Vec<SentPacket> {
+ times
+ .iter()
+ .enumerate()
+ .map(|(i, &t)| {
+ SentPacket::new(
+ PacketType::Short,
+ u64::try_from(i).unwrap(),
+ by_pto(t),
+ true,
+ Vec::new(),
+ 1000,
+ )
+ })
+ .collect::<Vec<_>>()
+ }
+
+ /// Call `detect_persistent_congestion` using times relative to now and the fixed PTO time.
+ /// `last_ack` and `rtt_time` are times in multiples of `PTO`, relative to `now()`,
+ /// for the time of the largest acknowledged and the first RTT sample, respectively.
+ fn persistent_congestion_by_pto(last_ack: u32, rtt_time: u32, lost: &[SentPacket]) -> bool {
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+
+ let last_ack = Some(by_pto(last_ack));
+ let rtt_time = Some(by_pto(rtt_time));
+
+ // Persistent congestion is never declared if the RTT time is `None`.
+ cc.detect_persistent_congestion(None, None, PTO, lost);
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+ cc.detect_persistent_congestion(None, last_ack, PTO, lost);
+ assert_eq!(cc.cwnd(), CWND_INITIAL);
+
+ cc.detect_persistent_congestion(rtt_time, last_ack, PTO, lost);
+ cc.cwnd() == CWND_MIN
+ }
+
+ /// No persistent congestion can be had if there are no lost packets.
+ #[test]
+ fn persistent_congestion_no_lost() {
+ let lost = make_lost(&[]);
+ assert!(!persistent_congestion_by_pto(0, 0, &lost));
+ }
+
+ /// No persistent congestion can be had if there is only one lost packet.
+ #[test]
+ fn persistent_congestion_one_lost() {
+ let lost = make_lost(&[1]);
+ assert!(!persistent_congestion_by_pto(0, 0, &lost));
+ }
+
+ /// Persistent congestion can't happen based on old packets.
+ #[test]
+ fn persistent_congestion_past() {
+ // Packets sent prior to either the last acknowledged or the first RTT
+ // sample are not considered. So 0 is ignored.
+ let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]);
+ assert!(!persistent_congestion_by_pto(1, 1, &lost));
+ assert!(!persistent_congestion_by_pto(0, 1, &lost));
+ assert!(!persistent_congestion_by_pto(1, 0, &lost));
+ }
+
+ /// Persistent congestion doesn't start unless the packet is ack-eliciting.
+ #[test]
+ fn persistent_congestion_ack_eliciting() {
+ let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ lost[0] = SentPacket::new(
+ lost[0].pt,
+ lost[0].pn,
+ lost[0].time_sent,
+ false,
+ Vec::new(),
+ lost[0].size,
+ );
+ assert!(!persistent_congestion_by_pto(0, 0, &lost));
+ }
+
+ /// Detect persistent congestion. Note that the first lost packet needs to have a time
+ /// greater than the previously acknowledged packet AND the first RTT sample. And the
+ /// difference in times needs to be greater than the persistent congestion threshold.
+ #[test]
+ fn persistent_congestion_min() {
+ let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ assert!(persistent_congestion_by_pto(0, 0, &lost));
+ }
+
+ /// Make sure that not having a previous largest acknowledged also results
+ /// in detecting persistent congestion. (This is not expected to happen, but
+ /// the code permits it).
+ #[test]
+ fn persistent_congestion_no_prev_ack() {
+ let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]);
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+ cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, &lost);
+ assert_eq!(cc.cwnd(), CWND_MIN);
+ }
+
+ /// The code asserts on ordering errors.
+ #[test]
+ #[should_panic]
+ fn persistent_congestion_unsorted() {
+ let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]);
+ assert!(!persistent_congestion_by_pto(0, 0, &lost));
+ }
+
+ #[test]
+ fn app_limited_slow_start() {
+ const LESS_THAN_CWND_PKTS: usize = 4;
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+
+ for i in 0..CWND_INITIAL_PKTS {
+ let sent = SentPacket::new(
+ PacketType::Short,
+ u64::try_from(i).unwrap(), // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&sent);
+ }
+ assert_eq!(cc.bytes_in_flight(), CWND_INITIAL);
+
+ for i in 0..LESS_THAN_CWND_PKTS {
+ let acked = SentPacket::new(
+ PacketType::Short,
+ u64::try_from(i).unwrap(), // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packets_acked(&[acked]);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (CWND_INITIAL_PKTS - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ assert_eq!(cc.cwnd(), (CWND_INITIAL_PKTS + i + 1) * MAX_DATAGRAM_SIZE);
+ }
+
+ // Now we are app limited
+ for i in 4..CWND_INITIAL_PKTS {
+ let p = [SentPacket::new(
+ PacketType::Short,
+ u64::try_from(i).unwrap(), // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ )];
+ cc.on_packets_acked(&p);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (CWND_INITIAL_PKTS - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ assert_eq!(cc.cwnd(), (CWND_INITIAL_PKTS + 4) * MAX_DATAGRAM_SIZE);
+ }
+ }
+
+ #[test]
+ fn app_limited_congestion_avoidance() {
+ const CWND_PKTS_CA: usize = CWND_INITIAL_PKTS / 2;
+
+ let mut cc = ClassicCongestionControl::new(NewReno::default());
+
+ // Change state to congestion avoidance by introducing loss.
+
+ let p_lost = SentPacket::new(
+ PacketType::Short,
+ 1, // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&p_lost);
+ cwnd_is_default(&cc);
+ cc.on_packets_lost(Some(now()), None, PTO, &[p_lost]);
+ cwnd_is_halved(&cc);
+ let p_not_lost = SentPacket::new(
+ PacketType::Short,
+ 1, // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&p_not_lost);
+ cc.on_packets_acked(&[p_not_lost]);
+ cwnd_is_halved(&cc);
+ // cc is app limited therefore cwnd in not increased.
+ assert_eq!(cc.acked_bytes, 0);
+
+ // Now we are in the congestion avoidance state.
+ let mut pkts = Vec::new();
+ for i in 0..CWND_PKTS_CA {
+ let p = SentPacket::new(
+ PacketType::Short,
+ u64::try_from(i + 3).unwrap(), // pn
+ now(), // time sent
+ true, // ack eliciting
+ Vec::new(), // tokens
+ MAX_DATAGRAM_SIZE, // size
+ );
+ cc.on_packet_sent(&p);
+ pkts.push(p);
+ }
+ assert_eq!(cc.bytes_in_flight(), CWND_INITIAL / 2);
+
+ for i in 0..CWND_PKTS_CA - 2 {
+ cc.on_packets_acked(&pkts[i..=i]);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (CWND_PKTS_CA - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ assert_eq!(cc.cwnd(), CWND_PKTS_CA * MAX_DATAGRAM_SIZE);
+ assert_eq!(cc.acked_bytes, MAX_DATAGRAM_SIZE * (i + 1));
+ }
+
+ // Now we are app limited
+ for i in CWND_PKTS_CA - 2..CWND_PKTS_CA {
+ cc.on_packets_acked(&pkts[i..=i]);
+
+ assert_eq!(
+ cc.bytes_in_flight(),
+ (CWND_PKTS_CA - i - 1) * MAX_DATAGRAM_SIZE
+ );
+ assert_eq!(cc.cwnd(), CWND_PKTS_CA * MAX_DATAGRAM_SIZE);
+ assert_eq!(cc.acked_bytes, MAX_DATAGRAM_SIZE * 3);
+ }
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/cc/mod.rs b/third_party/rust/neqo-transport/src/cc/mod.rs
new file mode 100644
index 0000000000..996054ad08
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/mod.rs
@@ -0,0 +1,55 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use crate::path::PATH_MTU_V6;
+use crate::tracking::SentPacket;
+use neqo_common::qlog::NeqoQlog;
+
+use std::fmt::{Debug, Display};
+use std::time::{Duration, Instant};
+
+mod classic_cc;
+mod new_reno;
+
+pub use classic_cc::ClassicCongestionControl;
+pub use classic_cc::{CWND_INITIAL_PKTS, CWND_MIN};
+pub use new_reno::NewReno;
+
+pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6;
+
+pub trait CongestionControl: Display + Debug {
+ fn set_qlog(&mut self, qlog: NeqoQlog);
+
+ fn cwnd(&self) -> usize;
+
+ fn bytes_in_flight(&self) -> usize;
+
+ fn cwnd_avail(&self) -> usize;
+
+ fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]);
+
+ fn on_packets_lost(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ );
+
+ fn recovery_packet(&self) -> bool;
+
+ fn discard(&mut self, pkt: &SentPacket);
+
+ fn on_packet_sent(&mut self, pkt: &SentPacket);
+}
+
+#[derive(Copy, Clone)]
+pub enum CongestionControlAlgorithm {
+ NewReno,
+}
diff --git a/third_party/rust/neqo-transport/src/cc/new_reno.rs b/third_party/rust/neqo-transport/src/cc/new_reno.rs
new file mode 100644
index 0000000000..a398887d61
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cc/new_reno.rs
@@ -0,0 +1,44 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+
+use std::fmt::{self, Display};
+
+use crate::cc::{classic_cc::WindowAdjustment, MAX_DATAGRAM_SIZE};
+use neqo_common::qinfo;
+
+#[derive(Debug)]
+pub struct NewReno {}
+
+impl Default for NewReno {
+ fn default() -> Self {
+ Self {}
+ }
+}
+
+impl Display for NewReno {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "NewReno")?;
+ Ok(())
+ }
+}
+
+impl WindowAdjustment for NewReno {
+ fn on_packets_acked(&mut self, mut curr_cwnd: usize, mut acked_bytes: usize) -> (usize, usize) {
+ if acked_bytes >= curr_cwnd {
+ acked_bytes -= curr_cwnd;
+ curr_cwnd += MAX_DATAGRAM_SIZE;
+ qinfo!([self], "congestion avoidance += {}", MAX_DATAGRAM_SIZE);
+ }
+ (curr_cwnd, acked_bytes)
+ }
+
+ fn on_congestion_event(&mut self, curr_cwnd: usize, acked_bytes: usize) -> (usize, usize) {
+ (curr_cwnd / 2, acked_bytes / 2)
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/cid.rs b/third_party/rust/neqo-transport/src/cid.rs
new file mode 100644
index 0000000000..ef2b938c28
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/cid.rs
@@ -0,0 +1,157 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Encoding and decoding packets off the wire.
+
+use neqo_common::{hex, hex_with_len, Decoder};
+use neqo_crypto::random;
+
+use std::borrow::Borrow;
+use std::cmp::max;
+use std::convert::AsRef;
+
+pub const MAX_CONNECTION_ID_LEN: usize = 20;
+
+#[derive(Clone, Default, Eq, Hash, PartialEq)]
+pub struct ConnectionId {
+ pub(crate) cid: Vec<u8>,
+}
+
+impl ConnectionId {
+ pub fn generate(len: usize) -> Self {
+ assert!(matches!(len, 0..=MAX_CONNECTION_ID_LEN));
+ Self { cid: random(len) }
+ }
+
+ // Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long.
+ pub fn generate_initial() -> Self {
+ let v = random(1);
+ // Bias selection toward picking 8 (>50% of the time).
+ let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into();
+ Self::generate(len)
+ }
+
+ pub fn as_cid_ref(&self) -> ConnectionIdRef {
+ ConnectionIdRef::from(&self.cid[..])
+ }
+}
+
+impl AsRef<[u8]> for ConnectionId {
+ fn as_ref(&self) -> &[u8] {
+ self.borrow()
+ }
+}
+
+impl Borrow<[u8]> for ConnectionId {
+ fn borrow(&self) -> &[u8] {
+ &self.cid
+ }
+}
+
+impl From<&[u8]> for ConnectionId {
+ fn from(buf: &[u8]) -> Self {
+ Self {
+ cid: Vec::from(buf),
+ }
+ }
+}
+
+impl<'a> From<&ConnectionIdRef<'a>> for ConnectionId {
+ fn from(cidref: &ConnectionIdRef<'a>) -> Self {
+ Self {
+ cid: Vec::from(cidref.cid),
+ }
+ }
+}
+
+impl std::ops::Deref for ConnectionId {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.cid
+ }
+}
+
+impl ::std::fmt::Debug for ConnectionId {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "CID {}", hex_with_len(&self.cid))
+ }
+}
+
+impl ::std::fmt::Display for ConnectionId {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}", hex(&self.cid))
+ }
+}
+
+impl<'a> PartialEq<ConnectionIdRef<'a>> for ConnectionId {
+ fn eq(&self, other: &ConnectionIdRef<'a>) -> bool {
+ &self.cid[..] == other.cid
+ }
+}
+
+#[derive(Hash, Eq, PartialEq)]
+pub struct ConnectionIdRef<'a> {
+ cid: &'a [u8],
+}
+
+impl<'a> ::std::fmt::Debug for ConnectionIdRef<'a> {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "CID {}", hex_with_len(&self.cid))
+ }
+}
+
+impl<'a> ::std::fmt::Display for ConnectionIdRef<'a> {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}", hex(&self.cid))
+ }
+}
+
+impl<'a> From<&'a [u8]> for ConnectionIdRef<'a> {
+ fn from(cid: &'a [u8]) -> Self {
+ Self { cid }
+ }
+}
+
+impl<'a> std::ops::Deref for ConnectionIdRef<'a> {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.cid
+ }
+}
+
+impl<'a> PartialEq<ConnectionId> for ConnectionIdRef<'a> {
+ fn eq(&self, other: &ConnectionId) -> bool {
+ self.cid == &other.cid[..]
+ }
+}
+
+pub trait ConnectionIdDecoder {
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>>;
+}
+
+pub trait ConnectionIdManager: ConnectionIdDecoder {
+ fn generate_cid(&mut self) -> ConnectionId;
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder;
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use test_fixture::fixture_init;
+
+ #[test]
+ fn generate_initial_cid() {
+ fixture_init();
+ for _ in 0..100 {
+ let cid = ConnectionId::generate_initial();
+ if !matches!(cid.len(), 8..=MAX_CONNECTION_ID_LEN) {
+ panic!("connection ID {:?}", cid);
+ }
+ }
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/idle.rs b/third_party/rust/neqo-transport/src/connection/idle.rs
new file mode 100644
index 0000000000..9cf3be20a1
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/idle.rs
@@ -0,0 +1,90 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::cmp::{max, min};
+use std::time::{Duration, Instant};
+
+pub const LOCAL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
+
+#[derive(Debug, Clone)]
+/// There's a little bit of different behavior for resetting idle timeout. See
+/// -transport 10.2 ("Idle Timeout").
+enum IdleTimeoutState {
+ Init,
+ PacketReceived(Instant),
+ AckElicitingPacketSent(Instant),
+}
+
+#[derive(Debug, Clone)]
+/// There's a little bit of different behavior for resetting idle timeout. See
+/// -transport 10.2 ("Idle Timeout").
+pub struct IdleTimeout {
+ timeout: Duration,
+ state: IdleTimeoutState,
+}
+
+#[cfg(test)]
+impl IdleTimeout {
+ pub fn new(timeout: Duration) -> Self {
+ Self {
+ timeout,
+ state: IdleTimeoutState::Init,
+ }
+ }
+}
+
+impl Default for IdleTimeout {
+ fn default() -> Self {
+ Self {
+ timeout: LOCAL_IDLE_TIMEOUT,
+ state: IdleTimeoutState::Init,
+ }
+ }
+}
+
+impl IdleTimeout {
+ pub fn set_peer_timeout(&mut self, peer_timeout: Duration) {
+ self.timeout = min(self.timeout, peer_timeout);
+ }
+
+ pub fn expiry(&self, now: Instant, pto: Duration) -> Instant {
+ let start = match self.state {
+ IdleTimeoutState::Init => now,
+ IdleTimeoutState::PacketReceived(t) | IdleTimeoutState::AckElicitingPacketSent(t) => t,
+ };
+ start + max(self.timeout, pto * 3)
+ }
+
+ pub fn on_packet_sent(&mut self, now: Instant) {
+ // Only reset idle timeout if we've received a packet since the last
+ // time we reset the timeout here.
+ match self.state {
+ IdleTimeoutState::AckElicitingPacketSent(_) => {}
+ IdleTimeoutState::Init | IdleTimeoutState::PacketReceived(_) => {
+ self.state = IdleTimeoutState::AckElicitingPacketSent(now);
+ }
+ }
+ }
+
+ pub fn on_packet_received(&mut self, now: Instant) {
+ // Only update if this doesn't rewind the idle timeout.
+ // We sometimes process packets after caching them, which uses
+ // the time the packet was received. That could be in the past.
+ let update = match self.state {
+ IdleTimeoutState::Init => true,
+ IdleTimeoutState::AckElicitingPacketSent(t) | IdleTimeoutState::PacketReceived(t) => {
+ t <= now
+ }
+ };
+ if update {
+ self.state = IdleTimeoutState::PacketReceived(now);
+ }
+ }
+
+ pub fn expired(&self, now: Instant, pto: Duration) -> bool {
+ now >= self.expiry(now, pto)
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/mod.rs b/third_party/rust/neqo-transport/src/connection/mod.rs
new file mode 100644
index 0000000000..ba6e628809
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/mod.rs
@@ -0,0 +1,2768 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// The class implementing a QUIC connection.
+
+use std::cell::RefCell;
+use std::cmp::{max, min};
+use std::collections::HashMap;
+use std::convert::TryFrom;
+use std::fmt::{self, Debug};
+use std::mem;
+use std::net::SocketAddr;
+use std::rc::{Rc, Weak};
+use std::time::{Duration, Instant};
+
+use smallvec::SmallVec;
+
+use neqo_common::{
+ event::Provider as EventProvider, hex, hex_snip_middle, qdebug, qerror, qinfo, qlog::NeqoQlog,
+ qtrace, qwarn, Datagram, Decoder, Encoder, Role,
+};
+use neqo_crypto::agent::CertificateInfo;
+use neqo_crypto::{
+ Agent, AntiReplay, AuthenticationStatus, Cipher, Client, HandshakeState, ResumptionToken,
+ SecretAgentInfo, Server, ZeroRttChecker,
+};
+
+use crate::addr_valid::{AddressValidation, NewTokenState};
+use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdManager, ConnectionIdRef};
+use crate::crypto::{Crypto, CryptoDxState, CryptoSpace};
+use crate::dump::*;
+use crate::events::{ConnectionEvent, ConnectionEvents};
+use crate::flow_mgr::FlowMgr;
+use crate::frame::{
+ AckRange, CloseError, Frame, FrameType, StreamType, FRAME_TYPE_CONNECTION_CLOSE_APPLICATION,
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
+};
+use crate::packet::{
+ DecryptedPacket, PacketBuilder, PacketNumber, PacketType, PublicPacket, QuicVersion,
+};
+use crate::path::Path;
+use crate::qlog;
+use crate::recovery::{LossRecovery, RecoveryToken, SendProfile, GRANULARITY};
+use crate::recv_stream::{RecvStream, RecvStreams, RECV_BUFFER_SIZE};
+use crate::send_stream::{SendStream, SendStreams};
+use crate::stats::{Stats, StatsCell};
+use crate::stream_id::{StreamId, StreamIndex, StreamIndexes};
+use crate::tparams::{
+ self, TransportParameter, TransportParameterId, TransportParameters, TransportParametersHandler,
+};
+use crate::tracking::{AckTracker, PNSpace, SentPacket};
+use crate::ConnectionParameters;
+use crate::{AppError, ConnectionError, Error, Res};
+
+mod idle;
+pub mod params;
+mod saved;
+mod state;
+
+use idle::IdleTimeout;
+pub use idle::LOCAL_IDLE_TIMEOUT;
+use saved::SavedDatagrams;
+pub use state::State;
+use state::StateSignaling;
+
+#[derive(Debug, Default)]
+struct Packet(Vec<u8>);
+
+pub const LOCAL_STREAM_LIMIT_BIDI: u64 = 16;
+pub const LOCAL_STREAM_LIMIT_UNI: u64 = 16;
+
+/// The number of Initial packets that the client will send in response
+/// to receiving an undecryptable packet during the early part of the
+/// handshake. This is a hack, but a useful one.
+const EXTRA_INITIALS: usize = 4;
+const LOCAL_MAX_DATA: u64 = 0x3FFF_FFFF_FFFF_FFFF; // 2^62-1
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum ZeroRttState {
+ Init,
+ Sending,
+ AcceptedClient,
+ AcceptedServer,
+ Rejected,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// Type returned from process() and `process_output()`. Users are required to
+/// call these repeatedly until `Callback` or `None` is returned.
+pub enum Output {
+ /// Connection requires no action.
+ None,
+ /// Connection requires the datagram be sent.
+ Datagram(Datagram),
+ /// Connection requires `process_input()` be called when the `Duration`
+ /// elapses.
+ Callback(Duration),
+}
+
+impl Output {
+ /// Convert into an `Option<Datagram>`.
+ #[must_use]
+ pub fn dgram(self) -> Option<Datagram> {
+ match self {
+ Self::Datagram(dg) => Some(dg),
+ _ => None,
+ }
+ }
+
+ /// Get a reference to the Datagram, if any.
+ pub fn as_dgram_ref(&self) -> Option<&Datagram> {
+ match self {
+ Self::Datagram(dg) => Some(dg),
+ _ => None,
+ }
+ }
+
+ /// Ask how long the caller should wait before calling back.
+ #[must_use]
+ pub fn callback(&self) -> Duration {
+ match self {
+ Self::Callback(t) => *t,
+ _ => Duration::new(0, 0),
+ }
+ }
+}
+
+/// Used by inner functions like Connection::output.
+enum SendOption {
+ /// Yes, please send this datagram.
+ Yes(Datagram),
+ /// Don't send. If this was blocked on the pacer (the arg is true).
+ No(bool),
+}
+
+impl Default for SendOption {
+ fn default() -> Self {
+ Self::No(false)
+ }
+}
+
+/// Used by `Connection::preprocess` to determine what to do
+/// with an packet before attempting to remove protection.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum PreprocessResult {
+ /// End processing and return successfully.
+ End,
+ /// Stop processing this datagram and move on to the next.
+ Next,
+ /// Continue and process this packet.
+ Continue,
+}
+
+/// Alias the common form for ConnectionIdManager.
+type CidMgr = Rc<RefCell<dyn ConnectionIdManager>>;
+
+/// An FixedConnectionIdManager produces random connection IDs of a fixed length.
+pub struct FixedConnectionIdManager {
+ len: usize,
+}
+impl FixedConnectionIdManager {
+ pub fn new(len: usize) -> Self {
+ Self { len }
+ }
+}
+impl ConnectionIdDecoder for FixedConnectionIdManager {
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
+ dec.decode(self.len).map(ConnectionIdRef::from)
+ }
+}
+impl ConnectionIdManager for FixedConnectionIdManager {
+ fn generate_cid(&mut self) -> ConnectionId {
+ ConnectionId::generate(self.len)
+ }
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
+ self
+ }
+}
+
+/// `AddressValidationInfo` holds information relevant to either
+/// responding to address validation (`NewToken`, `Retry`) or generating
+/// tokens for address validation (`Server`).
+enum AddressValidationInfo {
+ None,
+ // We are a client and have information from `NEW_TOKEN`.
+ NewToken(Vec<u8>),
+ // We are a client and have received a `Retry` packet.
+ Retry {
+ token: Vec<u8>,
+ retry_source_cid: ConnectionId,
+ },
+ // We are a server and can generate tokens.
+ Server(Weak<RefCell<AddressValidation>>),
+}
+
+impl AddressValidationInfo {
+ pub fn token(&self) -> &[u8] {
+ match self {
+ Self::NewToken(token) | Self::Retry { token, .. } => &token,
+ _ => &[],
+ }
+ }
+
+ pub fn generate_new_token(
+ &mut self,
+ peer_address: SocketAddr,
+ now: Instant,
+ ) -> Option<Vec<u8>> {
+ match self {
+ Self::Server(ref w) => {
+ if let Some(validation) = w.upgrade() {
+ validation
+ .borrow()
+ .generate_new_token(peer_address, now)
+ .ok()
+ } else {
+ None
+ }
+ }
+ Self::None => None,
+ _ => unreachable!("called a server function on a client"),
+ }
+ }
+}
+
+/// A QUIC Connection
+///
+/// First, create a new connection using `new_client()` or `new_server()`.
+///
+/// For the life of the connection, handle activity in the following manner:
+/// 1. Perform operations using the `stream_*()` methods.
+/// 1. Call `process_input()` when a datagram is received or the timer
+/// expires. Obtain information on connection state changes by checking
+/// `events()`.
+/// 1. Having completed handling current activity, repeatedly call
+/// `process_output()` for packets to send, until it returns `Output::Callback`
+/// or `Output::None`.
+///
+/// After the connection is closed (either by calling `close()` or by the
+/// remote) continue processing until `state()` returns `Closed`.
+pub struct Connection {
+ role: Role,
+ state: State,
+ tps: Rc<RefCell<TransportParametersHandler>>,
+ /// What we are doing with 0-RTT.
+ zero_rtt_state: ZeroRttState,
+ /// This object will generate connection IDs for the connection.
+ cid_manager: CidMgr,
+ /// Network paths. Right now, this tracks at most one path, so it uses `Option`.
+ path: Option<Path>,
+ /// The connection IDs that we will accept.
+ /// This includes any we advertise in NEW_CONNECTION_ID that haven't been bound to a path yet.
+ /// During the handshake at the server, it also includes the randomized DCID pick by the client.
+ valid_cids: Vec<ConnectionId>,
+ address_validation: AddressValidationInfo,
+
+ /// The source connection ID that this endpoint uses for the handshake.
+ /// Since we need to communicate this to our peer in tparams, setting this
+ /// value is part of constructing the struct.
+ local_initial_source_cid: ConnectionId,
+ /// The source connection ID from the first packet from the other end.
+ /// This is checked against the peer's transport parameters.
+ remote_initial_source_cid: Option<ConnectionId>,
+ /// The destination connection ID from the first packet from the client.
+ /// This is checked by the client against the server's transport parameters.
+ original_destination_cid: Option<ConnectionId>,
+
+ /// We sometimes save a datagram against the possibility that keys will later
+ /// become available. This avoids reporting packets as dropped during the handshake
+ /// when they are either just reordered or we haven't been able to install keys yet.
+ /// In particular, this occurs when asynchronous certificate validation happens.
+ saved_datagrams: SavedDatagrams,
+
+ pub(crate) crypto: Crypto,
+ pub(crate) acks: AckTracker,
+ idle_timeout: IdleTimeout,
+ pub(crate) indexes: StreamIndexes,
+ connection_ids: HashMap<u64, (ConnectionId, [u8; 16])>, // (sequence number, (connection id, reset token))
+ pub(crate) send_streams: SendStreams,
+ pub(crate) recv_streams: RecvStreams,
+ pub(crate) flow_mgr: Rc<RefCell<FlowMgr>>,
+ state_signaling: StateSignaling,
+ loss_recovery: LossRecovery,
+ events: ConnectionEvents,
+ new_token: NewTokenState,
+ stats: StatsCell,
+ qlog: NeqoQlog,
+ /// A session ticket was received without NEW_TOKEN,
+ /// this is when that turns into an event without NEW_TOKEN.
+ release_resumption_token_timer: Option<Instant>,
+ quic_version: QuicVersion,
+}
+
+impl Debug for Connection {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{:?} Connection: {:?} {:?}",
+ self.role, self.state, self.path
+ )
+ }
+}
+
+impl Connection {
+ /// Create a new QUIC connection with Client role.
+ pub fn new_client(
+ server_name: &str,
+ protocols: &[impl AsRef<str>],
+ cid_manager: CidMgr,
+ local_addr: SocketAddr,
+ remote_addr: SocketAddr,
+ conn_params: &ConnectionParameters,
+ ) -> Res<Self> {
+ let dcid = ConnectionId::generate_initial();
+ let mut c = Self::new(
+ Role::Client,
+ Client::new(server_name)?.into(),
+ cid_manager,
+ protocols,
+ None,
+ conn_params,
+ )?;
+ c.crypto
+ .states
+ .init(conn_params.get_quic_version(), Role::Client, &dcid);
+ c.original_destination_cid = Some(dcid);
+ c.initialize_path(local_addr, remote_addr);
+ Ok(c)
+ }
+
+ /// Create a new QUIC connection with Server role.
+ pub fn new_server(
+ certs: &[impl AsRef<str>],
+ protocols: &[impl AsRef<str>],
+ cid_manager: CidMgr,
+ conn_params: &ConnectionParameters,
+ ) -> Res<Self> {
+ Self::new(
+ Role::Server,
+ Server::new(certs)?.into(),
+ cid_manager,
+ protocols,
+ None,
+ conn_params,
+ )
+ }
+
+ pub fn server_enable_0rtt(
+ &mut self,
+ anti_replay: &AntiReplay,
+ zero_rtt_checker: impl ZeroRttChecker + 'static,
+ ) -> Res<()> {
+ self.crypto
+ .server_enable_0rtt(self.tps.clone(), anti_replay, zero_rtt_checker)
+ }
+
+ fn set_tp_defaults(tps: &mut TransportParameters) {
+ tps.set_integer(
+ tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
+ u64::try_from(RECV_BUFFER_SIZE).unwrap(),
+ );
+ tps.set_integer(
+ tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
+ u64::try_from(RECV_BUFFER_SIZE).unwrap(),
+ );
+ tps.set_integer(
+ tparams::INITIAL_MAX_STREAM_DATA_UNI,
+ u64::try_from(RECV_BUFFER_SIZE).unwrap(),
+ );
+ tps.set_integer(tparams::INITIAL_MAX_STREAMS_BIDI, LOCAL_STREAM_LIMIT_BIDI);
+ tps.set_integer(tparams::INITIAL_MAX_STREAMS_UNI, LOCAL_STREAM_LIMIT_UNI);
+ tps.set_integer(tparams::INITIAL_MAX_DATA, LOCAL_MAX_DATA);
+ tps.set_integer(
+ tparams::IDLE_TIMEOUT,
+ u64::try_from(LOCAL_IDLE_TIMEOUT.as_millis()).unwrap(),
+ );
+ tps.set_empty(tparams::DISABLE_MIGRATION);
+ tps.set_empty(tparams::GREASE_QUIC_BIT);
+ }
+
+ fn new(
+ role: Role,
+ agent: Agent,
+ cid_manager: CidMgr,
+ protocols: &[impl AsRef<str>],
+ path: Option<Path>,
+ conn_params: &ConnectionParameters,
+ ) -> Res<Self> {
+ let tphandler = Rc::new(RefCell::new(TransportParametersHandler::default()));
+ Self::set_tp_defaults(&mut tphandler.borrow_mut().local);
+ tphandler.borrow_mut().local.set_integer(
+ tparams::INITIAL_MAX_STREAMS_BIDI,
+ conn_params.get_max_streams(StreamType::BiDi),
+ );
+ tphandler.borrow_mut().local.set_integer(
+ tparams::INITIAL_MAX_STREAMS_UNI,
+ conn_params.get_max_streams(StreamType::UniDi),
+ );
+ let local_initial_source_cid = cid_manager.borrow_mut().generate_cid();
+ tphandler.borrow_mut().local.set_bytes(
+ tparams::INITIAL_SOURCE_CONNECTION_ID,
+ local_initial_source_cid.to_vec(),
+ );
+
+ let crypto = Crypto::new(agent, protocols, tphandler.clone())?;
+
+ let stats = StatsCell::default();
+ let c = Self {
+ role,
+ state: State::Init,
+ cid_manager,
+ path,
+ valid_cids: Vec::new(),
+ tps: tphandler,
+ zero_rtt_state: ZeroRttState::Init,
+ address_validation: AddressValidationInfo::None,
+ local_initial_source_cid,
+ remote_initial_source_cid: None,
+ original_destination_cid: None,
+ saved_datagrams: SavedDatagrams::default(),
+ crypto,
+ acks: AckTracker::default(),
+ idle_timeout: IdleTimeout::default(),
+ indexes: StreamIndexes::new(),
+ connection_ids: HashMap::new(),
+ send_streams: SendStreams::default(),
+ recv_streams: RecvStreams::default(),
+ flow_mgr: Rc::new(RefCell::new(FlowMgr::default())),
+ state_signaling: StateSignaling::Idle,
+ loss_recovery: LossRecovery::new(conn_params.get_cc_algorithm(), stats.clone()),
+ events: ConnectionEvents::default(),
+ new_token: NewTokenState::new(role),
+ stats,
+ qlog: NeqoQlog::disabled(),
+ release_resumption_token_timer: None,
+ quic_version: conn_params.get_quic_version(),
+ };
+ c.stats.borrow_mut().init(format!("{}", c));
+ Ok(c)
+ }
+
+ /// Get the local path.
+ pub fn path(&self) -> Option<&Path> {
+ self.path.as_ref()
+ }
+
+ /// Set or clear the qlog for this connection.
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.loss_recovery.set_qlog(qlog.clone());
+ self.qlog = qlog;
+ }
+
+ /// Get the qlog (if any) for this connection.
+ pub fn qlog_mut(&mut self) -> &mut NeqoQlog {
+ &mut self.qlog
+ }
+
+ /// Get the original destination connection id for this connection. This
+ /// will always be present for Role::Client but not if Role::Server is in
+ /// State::Init.
+ pub fn odcid(&self) -> Option<&ConnectionId> {
+ self.original_destination_cid.as_ref()
+ }
+
+ /// Set a local transport parameter, possibly overriding a default value.
+ pub fn set_local_tparam(&self, tp: TransportParameterId, value: TransportParameter) -> Res<()> {
+ if *self.state() == State::Init {
+ self.tps.borrow_mut().local.set(tp, value);
+ Ok(())
+ } else {
+ qerror!("Current state: {:?}", self.state());
+ qerror!("Cannot set local tparam when not in an initial connection state.");
+ Err(Error::ConnectionState)
+ }
+ }
+
+ /// `odcid` is their original choice for our CID, which we get from the Retry token.
+ /// `remote_cid` is the value from the Source Connection ID field of
+ /// an incoming packet: what the peer wants us to use now.
+ /// `retry_cid` is what we asked them to use when we sent the Retry.
+ pub(crate) fn set_retry_cids(
+ &mut self,
+ odcid: ConnectionId,
+ remote_cid: ConnectionId,
+ retry_cid: ConnectionId,
+ ) {
+ debug_assert_eq!(self.role, Role::Server);
+ qtrace!(
+ [self],
+ "Retry CIDs: odcid={} remote={} retry={}",
+ odcid,
+ remote_cid,
+ retry_cid
+ );
+ // We advertise "our" choices in transport parameters.
+ let local_tps = &mut self.tps.borrow_mut().local;
+ local_tps.set_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID, odcid.to_vec());
+ local_tps.set_bytes(tparams::RETRY_SOURCE_CONNECTION_ID, retry_cid.to_vec());
+
+ // ...and save their choices for later validation.
+ self.remote_initial_source_cid = Some(remote_cid);
+ }
+
+ fn retry_sent(&self) -> bool {
+ self.tps
+ .borrow()
+ .local
+ .get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID)
+ .is_some()
+ }
+
+ /// Set ALPN preferences. Strings that appear earlier in the list are given
+ /// higher preference.
+ pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> {
+ self.crypto.tls.set_alpn(protocols)?;
+ Ok(())
+ }
+
+ /// Enable a set of ciphers.
+ pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> {
+ if self.state != State::Init {
+ qerror!([self], "Cannot enable ciphers in state {:?}", self.state);
+ return Err(Error::ConnectionState);
+ }
+ self.crypto.tls.set_ciphers(ciphers)?;
+ Ok(())
+ }
+
+ fn make_resumption_token(&mut self) -> ResumptionToken {
+ debug_assert_eq!(self.role, Role::Client);
+ debug_assert!(self.crypto.has_resumption_token());
+ self.crypto
+ .create_resumption_token(
+ self.new_token.take_token(),
+ self.tps
+ .borrow()
+ .remote
+ .as_ref()
+ .expect("should have transport parameters"),
+ u64::try_from(self.loss_recovery.rtt().as_millis()).unwrap_or(0),
+ )
+ .unwrap()
+ }
+
+ fn create_resumption_token(&mut self, now: Instant) {
+ if self.role == Role::Server || self.state < State::Connected {
+ return;
+ }
+
+ qtrace!(
+ [self],
+ "Maybe create resumption token: {} {}",
+ self.crypto.has_resumption_token(),
+ self.new_token.has_token()
+ );
+
+ while self.crypto.has_resumption_token() && self.new_token.has_token() {
+ let token = self.make_resumption_token();
+ self.events.client_resumption_token(token);
+ }
+
+ // If we have a resumption ticket check or set a timer.
+ if self.crypto.has_resumption_token() {
+ let arm = if let Some(expiration_time) = self.release_resumption_token_timer {
+ if expiration_time <= now {
+ let token = self.make_resumption_token();
+ self.events.client_resumption_token(token);
+ self.release_resumption_token_timer = None;
+
+ // This means that we release one session ticket every 3 PTOs
+ // if no NEW_TOKEN frame is received.
+ self.crypto.has_resumption_token()
+ } else {
+ false
+ }
+ } else {
+ true
+ };
+
+ if arm {
+ self.release_resumption_token_timer =
+ Some(now + 3 * self.loss_recovery.pto_raw(PNSpace::ApplicationData));
+ }
+ }
+ }
+
+ /// Get a resumption token. The correct way to obtain a resumption token is
+ /// waiting for the `ConnectionEvent::ResumptionToken` event. However, some
+ /// servers don't send `NEW_TOKEN` frames and so that event might be slow in
+ /// arriving. This is especially a problem for short-lived connections, where
+ /// the connection is closed before any events are released. This retrieves
+ /// the token, without waiting for the `NEW_TOKEN` frame to arrive.
+ ///
+ /// # Panics
+ /// If this is called on a server.
+ pub fn take_resumption_token(&mut self, now: Instant) -> Option<ResumptionToken> {
+ assert_eq!(self.role, Role::Client);
+
+ if self.crypto.has_resumption_token() {
+ let token = self.make_resumption_token();
+ if self.crypto.has_resumption_token() {
+ self.release_resumption_token_timer =
+ Some(now + 3 * self.loss_recovery.pto_raw(PNSpace::ApplicationData));
+ }
+ Some(token)
+ } else {
+ None
+ }
+ }
+
+ /// Enable resumption, using a token previously provided.
+ /// This can only be called once and only on the client.
+ /// After calling the function, it should be possible to attempt 0-RTT
+ /// if the token supports that.
+ pub fn enable_resumption(&mut self, now: Instant, token: impl AsRef<[u8]>) -> Res<()> {
+ if self.state != State::Init {
+ qerror!([self], "set token in state {:?}", self.state);
+ return Err(Error::ConnectionState);
+ }
+ if self.role == Role::Server {
+ return Err(Error::ConnectionState);
+ }
+
+ qinfo!(
+ [self],
+ "resumption token {}",
+ hex_snip_middle(token.as_ref())
+ );
+ let mut dec = Decoder::from(token.as_ref());
+
+ let smoothed_rtt =
+ Duration::from_millis(dec.decode_varint().ok_or(Error::InvalidResumptionToken)?);
+ qtrace!([self], " RTT {:?}", smoothed_rtt);
+
+ let tp_slice = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?;
+ qtrace!([self], " transport parameters {}", hex(&tp_slice));
+ let mut dec_tp = Decoder::from(tp_slice);
+ let tp =
+ TransportParameters::decode(&mut dec_tp).map_err(|_| Error::InvalidResumptionToken)?;
+
+ let init_token = dec.decode_vvec().ok_or(Error::InvalidResumptionToken)?;
+ qtrace!([self], " Initial token {}", hex(&init_token));
+
+ let tok = dec.decode_remainder();
+ qtrace!([self], " TLS token {}", hex(&tok));
+ match self.crypto.tls {
+ Agent::Client(ref mut c) => {
+ let res = c.enable_resumption(&tok);
+ if let Err(e) = res {
+ self.absorb_error::<Error>(now, Err(Error::from(e)));
+ return Ok(());
+ }
+ }
+ Agent::Server(_) => return Err(Error::WrongRole),
+ }
+
+ self.tps.borrow_mut().remote_0rtt = Some(tp);
+ if !init_token.is_empty() {
+ self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec());
+ }
+ if smoothed_rtt > GRANULARITY {
+ self.loss_recovery.set_initial_rtt(smoothed_rtt);
+ }
+ self.set_initial_limits();
+ // Start up TLS, which has the effect of setting up all the necessary
+ // state for 0-RTT. This only stages the CRYPTO frames.
+ let res = self.client_start(now);
+ self.absorb_error(now, res);
+ Ok(())
+ }
+
+ pub(crate) fn set_validation(&mut self, validation: Rc<RefCell<AddressValidation>>) {
+ qtrace!([self], "Enabling NEW_TOKEN");
+ assert_eq!(self.role, Role::Server);
+ self.address_validation = AddressValidationInfo::Server(Rc::downgrade(&validation));
+ }
+
+ /// Send a TLS session ticket AND a NEW_TOKEN frame (if possible).
+ pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<()> {
+ if self.role == Role::Client {
+ return Err(Error::WrongRole);
+ }
+
+ let tps = &self.tps;
+ if let Agent::Server(ref mut s) = self.crypto.tls {
+ let mut enc = Encoder::default();
+ enc.encode_vvec_with(|mut enc_inner| {
+ tps.borrow().local.encode(&mut enc_inner);
+ });
+ enc.encode(extra);
+ let records = s.send_ticket(now, &enc)?;
+ qinfo!([self], "send session ticket {}", hex(&enc));
+ self.crypto.buffer_records(records)?;
+ } else {
+ unreachable!();
+ }
+
+ // If we are able, also send a NEW_TOKEN frame.
+ // This should be recording all remote addresses that are valid,
+ // but there are just 0 or 1 in the current implementation.
+ if let Some(p) = self.path.as_ref() {
+ if let Some(token) = self
+ .address_validation
+ .generate_new_token(p.remote_address(), now)
+ {
+ self.new_token.send_new_token(token);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn tls_info(&self) -> Option<&SecretAgentInfo> {
+ self.crypto.tls.info()
+ }
+
+ /// Get the peer's certificate chain and other info.
+ pub fn peer_certificate(&self) -> Option<CertificateInfo> {
+ self.crypto.tls.peer_certificate()
+ }
+
+ /// Call by application when the peer cert has been verified.
+ ///
+ /// This panics if there is no active peer. It's OK to call this
+ /// when authentication isn't needed, that will likely only cause
+ /// the connection to fail. However, if no packets have been
+ /// exchanged, it's not OK.
+ pub fn authenticated(&mut self, status: AuthenticationStatus, now: Instant) {
+ qinfo!([self], "Authenticated {:?}", status);
+ self.crypto.tls.authenticated(status);
+ let res = self.handshake(now, PNSpace::Handshake, None);
+ self.absorb_error(now, res);
+ self.process_saved(now);
+ }
+
+ /// Get the role of the connection.
+ pub fn role(&self) -> Role {
+ self.role
+ }
+
+ /// Get the state of the connection.
+ pub fn state(&self) -> &State {
+ &self.state
+ }
+
+ /// Get the 0-RTT state of the connection.
+ pub fn zero_rtt_state(&self) -> &ZeroRttState {
+ &self.zero_rtt_state
+ }
+
+ /// Get a snapshot of collected statistics.
+ pub fn stats(&self) -> Stats {
+ self.stats.borrow().clone()
+ }
+
+ // This function wraps a call to another function and sets the connection state
+ // properly if that call fails.
+ fn capture_error<T>(&mut self, now: Instant, frame_type: FrameType, res: Res<T>) -> Res<T> {
+ if let Err(v) = &res {
+ #[cfg(debug_assertions)]
+ let msg = format!("{:?}", v);
+ #[cfg(not(debug_assertions))]
+ let msg = "";
+ let error = ConnectionError::Transport(v.clone());
+ match &self.state {
+ State::Closing { error: err, .. }
+ | State::Draining { error: err, .. }
+ | State::Closed(err) => {
+ qwarn!([self], "Closing again after error {:?}", err);
+ }
+ State::Init => {
+ // We have not even sent anything just close the connection without sending any error.
+ // This may happen when client_start fails.
+ self.set_state(State::Closed(error));
+ }
+ State::WaitInitial => {
+ // We don't have any state yet, so don't bother with
+ // the closing state, just send one CONNECTION_CLOSE.
+ self.state_signaling.close(error.clone(), frame_type, msg);
+ self.set_state(State::Closed(error));
+ }
+ _ => {
+ self.state_signaling.close(error.clone(), frame_type, msg);
+ if matches!(v, Error::KeysExhausted) {
+ self.set_state(State::Closed(error));
+ } else {
+ self.set_state(State::Closing {
+ error,
+ timeout: self.get_closing_period_time(now),
+ });
+ }
+ }
+ }
+ }
+ res
+ }
+
+ /// For use with process_input(). Errors there can be ignored, but this
+ /// needs to ensure that the state is updated.
+ fn absorb_error<T>(&mut self, now: Instant, res: Res<T>) -> Option<T> {
+ self.capture_error(now, 0, res).ok()
+ }
+
+ fn process_timer(&mut self, now: Instant) {
+ if let State::Closing { error, timeout } | State::Draining { error, timeout } = &self.state
+ {
+ if *timeout <= now {
+ // Close timeout expired, move to Closed
+ let st = State::Closed(error.clone());
+ self.set_state(st);
+ qinfo!("Closing timer expired");
+ return;
+ }
+ }
+ if let State::Closed(_) = self.state {
+ qdebug!("Timer fired while closed");
+ return;
+ }
+
+ let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData);
+ if self.idle_timeout.expired(now, pto) {
+ qinfo!([self], "idle timeout expired");
+ self.set_state(State::Closed(ConnectionError::Transport(
+ Error::IdleTimeout,
+ )));
+ return;
+ }
+
+ self.cleanup_streams();
+
+ let res = self.crypto.states.check_key_update(now);
+ self.absorb_error(now, res);
+
+ let lost = self.loss_recovery.timeout(now);
+ self.handle_lost_packets(&lost);
+ qlog::packets_lost(&mut self.qlog, &lost);
+
+ if self.release_resumption_token_timer.is_some() {
+ self.create_resumption_token(now);
+ }
+ }
+
+ /// Process new input datagrams on the connection.
+ pub fn process_input(&mut self, d: Datagram, now: Instant) {
+ let res = self.input(d, now);
+ self.absorb_error(now, res);
+ self.process_saved(now);
+ self.cleanup_streams();
+ }
+
+ /// Get the time that we next need to be called back, relative to `now`.
+ fn next_delay(&mut self, now: Instant, paced: bool) -> Duration {
+ qtrace!([self], "Get callback delay {:?}", now);
+
+ // Only one timer matters when closing...
+ if let State::Closing { timeout, .. } | State::Draining { timeout, .. } = self.state {
+ return timeout.duration_since(now);
+ }
+
+ let mut delays = SmallVec::<[_; 6]>::new();
+ if let Some(ack_time) = self.acks.ack_time(now) {
+ qtrace!([self], "Delayed ACK timer {:?}", ack_time);
+ delays.push(ack_time);
+ }
+
+ let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData);
+ let idle_time = self.idle_timeout.expiry(now, pto);
+ qtrace!([self], "Idle timer {:?}", idle_time);
+ delays.push(idle_time);
+
+ if let Some(lr_time) = self.loss_recovery.next_timeout() {
+ qtrace!([self], "Loss recovery timer {:?}", lr_time);
+ delays.push(lr_time);
+ }
+
+ if let Some(key_update_time) = self.crypto.states.update_time() {
+ qtrace!([self], "Key update timer {:?}", key_update_time);
+ delays.push(key_update_time);
+ }
+
+ if paced {
+ if let Some(pace_time) = self.loss_recovery.next_paced() {
+ qtrace!([self], "Pacing timer {:?}", pace_time);
+ delays.push(pace_time);
+ }
+ }
+
+ // `release_resumption_token_timer` is not considered here, because
+ // it is not important enough to force the application to set a
+ // timeout for it It is expected thatt other activities will
+ // drive it.
+
+ let earliest = delays.into_iter().min().unwrap();
+ // TODO(agrover, mt) - need to analyze and fix #47
+ // rather than just clamping to zero here.
+ qdebug!(
+ [self],
+ "delay duration {:?}",
+ max(now, earliest).duration_since(now)
+ );
+ debug_assert!(earliest > now);
+ max(now, earliest).duration_since(now)
+ }
+
+ /// Get output packets, as a result of receiving packets, or actions taken
+ /// by the application.
+ /// Returns datagrams to send, and how long to wait before calling again
+ /// even if no incoming packets.
+ #[must_use = "Output of the process_output function must be handled"]
+ pub fn process_output(&mut self, now: Instant) -> Output {
+ qtrace!([self], "process_output {:?} {:?}", self.state, now);
+
+ if self.state == State::Init {
+ if self.role == Role::Client {
+ let res = self.client_start(now);
+ self.absorb_error(now, res);
+ }
+ } else {
+ self.process_timer(now);
+ }
+
+ match self.output(now) {
+ SendOption::Yes(dgram) => Output::Datagram(dgram),
+ SendOption::No(paced) => match self.state {
+ State::Init | State::Closed(_) => Output::None,
+ State::Closing { timeout, .. } | State::Draining { timeout, .. } => {
+ Output::Callback(timeout.duration_since(now))
+ }
+ _ => Output::Callback(self.next_delay(now, paced)),
+ },
+ }
+ }
+
+ /// Process input and generate output.
+ #[must_use = "Output of the process function must be handled"]
+ pub fn process(&mut self, dgram: Option<Datagram>, now: Instant) -> Output {
+ if let Some(d) = dgram {
+ let res = self.input(d, now);
+ self.absorb_error(now, res);
+ self.process_saved(now);
+ }
+ self.process_output(now)
+ }
+
+ fn is_valid_cid(&self, cid: &ConnectionIdRef) -> bool {
+ self.valid_cids.iter().any(|c| c == cid) || self.path.iter().any(|p| p.valid_local_cid(cid))
+ }
+
+ fn handle_retry(&mut self, packet: &PublicPacket) -> Res<()> {
+ qinfo!([self], "received Retry");
+ if matches!(self.address_validation, AddressValidationInfo::Retry { .. }) {
+ self.stats.borrow_mut().pkt_dropped("Extra Retry");
+ return Ok(());
+ }
+ if packet.token().is_empty() {
+ self.stats.borrow_mut().pkt_dropped("Retry without a token");
+ return Ok(());
+ }
+ if !packet.is_valid_retry(&self.original_destination_cid.as_ref().unwrap()) {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Retry with bad integrity tag");
+ return Ok(());
+ }
+ if let Some(p) = &mut self.path {
+ // At this point, we shouldn't have a remote connection ID for the path.
+ p.set_remote_cid(packet.scid());
+ } else {
+ qinfo!([self], "No path, but we received a Retry");
+ return Err(Error::InternalError);
+ };
+
+ let retry_scid = ConnectionId::from(packet.scid());
+ qinfo!(
+ [self],
+ "Valid Retry received, token={} scid={}",
+ hex(packet.token()),
+ retry_scid
+ );
+
+ let lost_packets = self.loss_recovery.retry();
+ self.handle_lost_packets(&lost_packets);
+
+ self.crypto
+ .states
+ .init(self.quic_version, self.role, &retry_scid);
+ self.address_validation = AddressValidationInfo::Retry {
+ token: packet.token().to_vec(),
+ retry_source_cid: retry_scid,
+ };
+ Ok(())
+ }
+
+ fn discard_keys(&mut self, space: PNSpace, now: Instant) {
+ if self.crypto.discard(space) {
+ qinfo!([self], "Drop packet number space {}", space);
+ self.loss_recovery.discard(space, now);
+ self.acks.drop_space(space);
+ }
+ }
+
+ fn token_equal(a: &[u8; 16], b: &[u8; 16]) -> bool {
+ // rustc might decide to optimize this and make this non-constant-time
+ // with respect to `t`, but it doesn't appear to currently.
+ let mut c = 0;
+ for (&a, &b) in a.iter().zip(b) {
+ c |= a ^ b;
+ }
+ c == 0
+ }
+
+ fn is_stateless_reset(&self, d: &Datagram) -> bool {
+ if d.len() < 16 {
+ return false;
+ }
+ let token = <&[u8; 16]>::try_from(&d[d.len() - 16..]).unwrap();
+ // TODO(mt) only check the path that matches the datagram.
+ self.path
+ .as_ref()
+ .map(|p| p.reset_token())
+ .flatten()
+ .map_or(false, |t| Self::token_equal(t, token))
+ }
+
+ fn check_stateless_reset<'a, 'b>(
+ &'a mut self,
+ d: &'b Datagram,
+ first: bool,
+ now: Instant,
+ ) -> Res<()> {
+ if first && self.is_stateless_reset(d) {
+ // Failing to process a packet in a datagram might
+ // indicate that there is a stateless reset present.
+ qdebug!([self], "Stateless reset: {}", hex(&d[d.len() - 16..]));
+ self.state_signaling.reset();
+ self.set_state(State::Draining {
+ error: ConnectionError::Transport(Error::StatelessReset),
+ timeout: self.get_closing_period_time(now),
+ });
+ Err(Error::StatelessReset)
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Process any saved datagrams that might be available for processing.
+ fn process_saved(&mut self, now: Instant) {
+ while let Some(cspace) = self.saved_datagrams.available() {
+ qdebug!([self], "process saved for space {:?}", cspace);
+ debug_assert!(self.crypto.states.rx_hp(cspace).is_some());
+ for saved in self.saved_datagrams.take_saved() {
+ qtrace!([self], "input saved @{:?}: {:?}", saved.t, saved.d);
+ let res = self.input(saved.d, saved.t);
+ self.absorb_error(now, res);
+ }
+ }
+ }
+
+ /// In case a datagram arrives that we can only partially process, save any
+ /// part that we don't have keys for.
+ fn save_datagram(&mut self, cspace: CryptoSpace, d: Datagram, remaining: usize, now: Instant) {
+ let d = if remaining < d.len() {
+ Datagram::new(d.source(), d.destination(), &d[d.len() - remaining..])
+ } else {
+ d
+ };
+ self.saved_datagrams.save(cspace, d, now);
+ self.stats.borrow_mut().saved_datagrams += 1;
+ }
+
+ /// Perform any processing that we might have to do on packets prior to
+ /// attempting to remove protection.
+ fn preprocess(
+ &mut self,
+ packet: &PublicPacket,
+ dcid: Option<&ConnectionId>,
+ now: Instant,
+ ) -> Res<PreprocessResult> {
+ if dcid.map_or(false, |d| d != packet.dcid()) {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Coalesced packet has different DCID");
+ return Ok(PreprocessResult::Next);
+ }
+
+ match (packet.packet_type(), &self.state, &self.role) {
+ (PacketType::Initial, State::Init, Role::Server) => {
+ if !packet.is_valid_initial() {
+ self.stats.borrow_mut().pkt_dropped("Invalid Initial");
+ return Ok(PreprocessResult::Next);
+ }
+ qinfo!(
+ [self],
+ "Received valid Initial packet with scid {:?} dcid {:?}",
+ packet.scid(),
+ packet.dcid()
+ );
+ self.set_state(State::WaitInitial);
+ self.loss_recovery.start_pacer(now);
+ self.crypto
+ .states
+ .init(self.quic_version, self.role, &packet.dcid());
+
+ // We need to make sure that we set this transport parameter.
+ // This has to happen prior to processing the packet so that
+ // the TLS handshake has all it needs.
+ if !self.retry_sent() {
+ self.tps.borrow_mut().local.set_bytes(
+ tparams::ORIGINAL_DESTINATION_CONNECTION_ID,
+ packet.dcid().to_vec(),
+ )
+ }
+ }
+ (PacketType::VersionNegotiation, State::WaitInitial, Role::Client) => {
+ match packet.supported_versions() {
+ Ok(versions) => {
+ if versions.is_empty()
+ || versions.contains(&self.quic_version.as_u32())
+ || packet.dcid() != self.odcid().unwrap()
+ || matches!(self.address_validation, AddressValidationInfo::Retry { .. })
+ {
+ // Ignore VersionNegotiation packets that contain the current version.
+ // Or don't have the right connection ID.
+ // Or are received after a Retry.
+ self.stats.borrow_mut().pkt_dropped("Invalid VN");
+ return Ok(PreprocessResult::End);
+ }
+
+ self.set_state(State::Closed(ConnectionError::Transport(
+ Error::VersionNegotiation,
+ )));
+ return Err(Error::VersionNegotiation);
+ }
+ Err(_) => {
+ self.stats.borrow_mut().pkt_dropped("Invalid VN");
+ return Ok(PreprocessResult::End);
+ }
+ }
+ }
+ (PacketType::Retry, State::WaitInitial, Role::Client) => {
+ self.handle_retry(packet)?;
+ return Ok(PreprocessResult::Next);
+ }
+ (PacketType::Handshake, State::WaitInitial, Role::Client)
+ | (PacketType::Short, State::WaitInitial, Role::Client) => {
+ // This packet can't be processed now, but it could be a sign
+ // that Initial packets were lost.
+ // Resend Initial CRYPTO frames immediately a few times just
+ // in case. As we don't have an RTT estimate yet, this helps
+ // when there is a short RTT and losses.
+ if dcid.is_none()
+ && self.is_valid_cid(packet.dcid())
+ && self.stats.borrow().saved_datagrams <= EXTRA_INITIALS
+ {
+ self.crypto.resend_unacked(PNSpace::Initial);
+ }
+ }
+ (PacketType::VersionNegotiation, ..)
+ | (PacketType::Retry, ..)
+ | (PacketType::OtherVersion, ..) => {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped(format!("{:?}", packet.packet_type()));
+ return Ok(PreprocessResult::Next);
+ }
+ _ => {}
+ }
+
+ let res = match self.state {
+ State::Init => {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped("Received while in Init state");
+ PreprocessResult::Next
+ }
+ State::WaitInitial => PreprocessResult::Continue,
+ State::Handshaking | State::Connected | State::Confirmed => {
+ if !self.is_valid_cid(packet.dcid()) {
+ self.stats
+ .borrow_mut()
+ .pkt_dropped(format!("Invalid DCID {:?}", packet.dcid()));
+ PreprocessResult::Next
+ } else {
+ if self.role == Role::Server && packet.packet_type() == PacketType::Handshake {
+ // Server has received a Handshake packet -> discard Initial keys and states
+ self.discard_keys(PNSpace::Initial, now);
+ }
+ PreprocessResult::Continue
+ }
+ }
+ State::Closing { .. } => {
+ // Don't bother processing the packet. Instead ask to get a
+ // new close frame.
+ self.state_signaling.send_close();
+ PreprocessResult::Next
+ }
+ State::Draining { .. } | State::Closed(..) => {
+ // Do nothing.
+ self.stats
+ .borrow_mut()
+ .pkt_dropped(format!("State {:?}", self.state));
+ PreprocessResult::Next
+ }
+ };
+ Ok(res)
+ }
+
+ /// Take a datagram as input. This reports an error if the packet was bad.
+ fn input(&mut self, d: Datagram, now: Instant) -> Res<()> {
+ let mut slc = &d[..];
+ let mut dcid = None;
+
+ qtrace!([self], "input {}", hex(&**d));
+
+ // Handle each packet in the datagram.
+ while !slc.is_empty() {
+ self.stats.borrow_mut().packets_rx += 1;
+ let (packet, remainder) =
+ match PublicPacket::decode(slc, self.cid_manager.borrow().as_decoder()) {
+ Ok((packet, remainder)) => (packet, remainder),
+ Err(e) => {
+ qinfo!([self], "Garbage packet: {}", e);
+ qtrace!([self], "Garbage packet contents: {}", hex(slc));
+ self.stats.borrow_mut().pkt_dropped("Garbage packet");
+ break;
+ }
+ };
+ match self.preprocess(&packet, dcid.as_ref(), now)? {
+ PreprocessResult::Continue => (),
+ PreprocessResult::Next => break,
+ PreprocessResult::End => return Ok(()),
+ }
+
+ qtrace!([self], "Received unverified packet {:?}", packet);
+
+ let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData);
+ match packet.decrypt(&mut self.crypto.states, now + pto) {
+ Ok(payload) => {
+ // TODO(ekr@rtfm.com): Have the server blow away the initial
+ // crypto state if this fails? Otherwise, we will get a panic
+ // on the assert for doesn't exist.
+ // OK, we have a valid packet.
+ self.idle_timeout.on_packet_received(now);
+ dump_packet(
+ self,
+ "-> RX",
+ payload.packet_type(),
+ payload.pn(),
+ &payload[..],
+ );
+ qlog::packet_received(&mut self.qlog, &packet, &payload);
+ let res = self.process_packet(&payload, now);
+ if res.is_err() && self.path.is_none() {
+ // We need to make a path for sending an error message.
+ // But this connection is going to be closed.
+ self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid()));
+ self.initialize_path(d.destination(), d.source());
+ }
+ res?;
+ if self.state == State::WaitInitial {
+ self.start_handshake(&packet, &d)?;
+ }
+ self.process_migrations(&d)?;
+ }
+ Err(e) => {
+ match e {
+ Error::KeysPending(cspace) => {
+ // This packet can't be decrypted because we don't have the keys yet.
+ // Don't check this packet for a stateless reset, just return.
+ let remaining = slc.len();
+ self.save_datagram(cspace, d, remaining, now);
+ return Ok(());
+ }
+ Error::KeysExhausted => {
+ // Exhausting read keys is fatal.
+ return Err(e);
+ }
+ _ => (),
+ }
+ // Decryption failure, or not having keys is not fatal.
+ // If the state isn't available, or we can't decrypt the packet, drop
+ // the rest of the datagram on the floor, but don't generate an error.
+ self.check_stateless_reset(&d, dcid.is_none(), now)?;
+ self.stats.borrow_mut().pkt_dropped("Decryption failure");
+ qlog::packet_dropped(&mut self.qlog, &packet);
+ }
+ }
+ slc = remainder;
+ dcid = Some(ConnectionId::from(packet.dcid()));
+ }
+ self.check_stateless_reset(&d, dcid.is_none(), now)?;
+ Ok(())
+ }
+
+ fn process_packet(&mut self, packet: &DecryptedPacket, now: Instant) -> Res<()> {
+ // TODO(ekr@rtfm.com): Have the server blow away the initial
+ // crypto state if this fails? Otherwise, we will get a panic
+ // on the assert for doesn't exist.
+ // OK, we have a valid packet.
+
+ let space = PNSpace::from(packet.packet_type());
+ if self.acks.get_mut(space).unwrap().is_duplicate(packet.pn()) {
+ qdebug!([self], "Duplicate packet from {} pn={}", space, packet.pn());
+ self.stats.borrow_mut().dups_rx += 1;
+ return Ok(());
+ }
+
+ let mut ack_eliciting = false;
+ let mut d = Decoder::from(&packet[..]);
+ let mut consecutive_padding = 0;
+ while d.remaining() > 0 {
+ let mut f = Frame::decode(&mut d)?;
+
+ // Skip padding
+ while f == Frame::Padding && d.remaining() > 0 {
+ consecutive_padding += 1;
+ f = Frame::decode(&mut d)?;
+ }
+ if consecutive_padding > 0 {
+ qdebug!(
+ [self],
+ "PADDING frame repeated {} times",
+ consecutive_padding
+ );
+ consecutive_padding = 0;
+ }
+
+ ack_eliciting |= f.ack_eliciting();
+ let t = f.get_type();
+ let res = self.input_frame(packet.packet_type(), f, now);
+ self.capture_error(now, t, res)?;
+ }
+ self.acks
+ .get_mut(space)
+ .unwrap()
+ .set_received(now, packet.pn(), ack_eliciting);
+
+ Ok(())
+ }
+
+ fn initialize_path(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr) {
+ debug_assert!(self.path.is_none());
+ self.path = Some(Path::new(
+ local_addr,
+ remote_addr,
+ self.local_initial_source_cid.clone(),
+ // Ideally we know what the peer wants us to use for the remote CID.
+ // But we will use our own guess if necessary.
+ self.remote_initial_source_cid
+ .as_ref()
+ .or_else(|| self.original_destination_cid.as_ref())
+ .unwrap()
+ .clone(),
+ ));
+ }
+
+ fn start_handshake(&mut self, packet: &PublicPacket, d: &Datagram) -> Res<()> {
+ qtrace!([self], "starting handshake");
+ debug_assert_eq!(packet.packet_type(), PacketType::Initial);
+ self.remote_initial_source_cid = Some(ConnectionId::from(packet.scid()));
+
+ if self.role == Role::Server {
+ // A server needs to accept the client's selected CID during the handshake.
+ self.valid_cids.push(ConnectionId::from(packet.dcid()));
+ self.original_destination_cid = Some(ConnectionId::from(packet.dcid()));
+ // Install a path.
+ self.initialize_path(d.destination(), d.source());
+
+ self.zero_rtt_state = match self.crypto.enable_0rtt(self.role) {
+ Ok(true) => {
+ qdebug!([self], "Accepted 0-RTT");
+ ZeroRttState::AcceptedServer
+ }
+ _ => ZeroRttState::Rejected,
+ };
+ } else {
+ qdebug!([self], "Changing to use Server CID={}", packet.scid());
+ let p = self
+ .path
+ .iter_mut()
+ .find(|p| p.received_on(&d))
+ .expect("should have a path for sending Initial");
+ p.set_remote_cid(packet.scid());
+ }
+
+ self.set_state(State::Handshaking);
+ Ok(())
+ }
+
+ fn process_migrations(&self, d: &Datagram) -> Res<()> {
+ if self.path.iter().any(|p| p.received_on(&d)) {
+ Ok(())
+ } else {
+ // Right now, we don't support any form of migration.
+ // So generate an error if a packet is received on a new path.
+ Err(Error::InvalidMigration)
+ }
+ }
+
+ fn output(&mut self, now: Instant) -> SendOption {
+ qtrace!([self], "output {:?}", now);
+ if let Some(mut path) = self.path.take() {
+ let res = match &self.state {
+ State::Init
+ | State::WaitInitial
+ | State::Handshaking
+ | State::Connected
+ | State::Confirmed => self.output_path(&mut path, now),
+ State::Closing { .. } | State::Draining { .. } | State::Closed(_) => {
+ if let Some(frame) = self.state_signaling.close_frame() {
+ self.output_close(&path, &frame)
+ } else {
+ Ok(SendOption::default())
+ }
+ }
+ };
+ let out = self.absorb_error(now, res).unwrap_or_default();
+ self.path = Some(path);
+ out
+ } else {
+ SendOption::default()
+ }
+ }
+
+ fn build_packet_header(
+ path: &Path,
+ cspace: CryptoSpace,
+ encoder: Encoder,
+ tx: &CryptoDxState,
+ address_validation: &AddressValidationInfo,
+ quic_version: QuicVersion,
+ grease_quic_bit: bool,
+ ) -> (PacketType, PacketBuilder) {
+ let pt = PacketType::from(cspace);
+ let mut builder = if pt == PacketType::Short {
+ qdebug!("Building Short dcid {}", path.remote_cid());
+ PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid())
+ } else {
+ qdebug!(
+ "Building {:?} dcid {} scid {}",
+ pt,
+ path.remote_cid(),
+ path.local_cid(),
+ );
+
+ PacketBuilder::long(
+ encoder,
+ pt,
+ quic_version,
+ path.remote_cid(),
+ path.local_cid(),
+ )
+ };
+ builder.scramble(grease_quic_bit);
+ if pt == PacketType::Initial {
+ builder.initial_token(address_validation.token());
+ }
+
+ (pt, builder)
+ }
+
+ fn add_packet_number(
+ builder: &mut PacketBuilder,
+ tx: &CryptoDxState,
+ largest_acknowledged: Option<PacketNumber>,
+ ) -> PacketNumber {
+ // Get the packet number and work out how long it is.
+ let pn = tx.next_pn();
+ let unacked_range = if let Some(la) = largest_acknowledged {
+ // Double the range from this to the last acknowledged in this space.
+ (pn - la) << 1
+ } else {
+ pn + 1
+ };
+ // Count how many bytes in this range are non-zero.
+ let pn_len = mem::size_of::<PacketNumber>()
+ - usize::try_from(unacked_range.leading_zeros() / 8).unwrap();
+ // pn_len can't be zero (unacked_range is > 0)
+ // TODO(mt) also use `4*path CWND/path MTU` to set a minimum length.
+ builder.pn(pn, pn_len);
+ pn
+ }
+
+ fn can_grease_quic_bit(&self) -> bool {
+ let tph = self.tps.borrow();
+ if let Some(r) = &tph.remote {
+ r.get_empty(tparams::GREASE_QUIC_BIT)
+ } else if let Some(r) = &tph.remote_0rtt {
+ r.get_empty(tparams::GREASE_QUIC_BIT)
+ } else {
+ false
+ }
+ }
+
+ fn output_close(&mut self, path: &Path, frame: &Frame) -> Res<SendOption> {
+ let mut encoder = Encoder::with_capacity(path.mtu());
+ let grease_quic_bit = self.can_grease_quic_bit();
+ for space in PNSpace::iter() {
+ let (cspace, tx) = if let Some(crypto) = self.crypto.states.select_tx(*space) {
+ crypto
+ } else {
+ continue;
+ };
+
+ let (_, mut builder) = Self::build_packet_header(
+ path,
+ cspace,
+ encoder,
+ tx,
+ &AddressValidationInfo::None,
+ self.quic_version,
+ grease_quic_bit,
+ );
+ let _ = Self::add_packet_number(
+ &mut builder,
+ tx,
+ self.loss_recovery.largest_acknowledged_pn(*space),
+ );
+
+ // ConnectionError::Application is only allowed at 1RTT.
+ let sanitized = if *space == PNSpace::ApplicationData {
+ &frame
+ } else {
+ frame.sanitize_close()
+ };
+ if let Frame::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ } = sanitized
+ {
+ builder.encode_varint(sanitized.get_type());
+ builder.encode_varint(error_code.code());
+ if let CloseError::Transport(_) = error_code {
+ builder.encode_varint(*frame_type);
+ }
+ let reason_len = min(min(reason_phrase.len(), 256), builder.remaining() - 2);
+ builder.encode_vvec(&reason_phrase[..reason_len]);
+ } else {
+ unreachable!();
+ }
+
+ encoder = builder.build(tx)?;
+ }
+
+ Ok(SendOption::Yes(path.datagram(encoder)))
+ }
+
+ /// Write frames to the provided builder. Returns a list of tokens used for
+ /// tracking loss or acknowledgment, whether any frame was ACK eliciting, and
+ /// whether the packet was padded.
+ fn write_frames(
+ &mut self,
+ space: PNSpace,
+ profile: &SendProfile,
+ builder: &mut PacketBuilder,
+ mut pad: bool,
+ now: Instant,
+ ) -> (Vec<RecoveryToken>, bool, bool) {
+ let mut tokens = Vec::new();
+ let stats = &mut self.stats.borrow_mut().frame_tx;
+
+ let ack_token = self.acks.write_frame(space, now, builder, stats);
+
+ if profile.ack_only(space) {
+ // If we are CC limited we can only send acks!
+ if let Some(t) = ack_token {
+ tokens.push(t);
+ }
+ return (tokens, false, false);
+ }
+
+ if space == PNSpace::ApplicationData && self.role == Role::Server {
+ if let Some(t) = self.state_signaling.write_done(builder) {
+ tokens.push(t);
+ stats.handshake_done += 1;
+ }
+ }
+
+ if let Some(t) = self.crypto.streams.write_frame(space, builder) {
+ tokens.push(t);
+ stats.crypto += 1;
+ }
+
+ if space == PNSpace::ApplicationData {
+ self.flow_mgr
+ .borrow_mut()
+ .write_frames(builder, &mut tokens, stats);
+
+ self.send_streams.write_frames(builder, &mut tokens, stats);
+ self.new_token.write_frames(builder, &mut tokens, stats);
+ }
+
+ // Anything - other than ACK - that registered a token wants an acknowledgment.
+ let ack_eliciting = !tokens.is_empty()
+ || if profile.should_probe(space) {
+ // Nothing ack-eliciting and we need to probe; send PING.
+ debug_assert_ne!(builder.remaining(), 0);
+ builder.encode_varint(crate::frame::FRAME_TYPE_PING);
+ stats.ping += 1;
+ stats.all += 1;
+ true
+ } else {
+ false
+ };
+
+ // Add padding. Only pad 1-RTT packets so that we don't prevent coalescing.
+ // And avoid padding packets that otherwise only contain ACK because adding PADDING
+ // causes those packets to consume congestion window, which is not tracked (yet).
+ pad &= ack_eliciting && space == PNSpace::ApplicationData;
+ if pad {
+ builder.pad();
+ stats.padding += 1;
+ stats.all += 1;
+ }
+
+ if let Some(t) = ack_token {
+ tokens.push(t);
+ }
+ stats.all += tokens.len();
+ (tokens, ack_eliciting, pad)
+ }
+
+ /// Build a datagram, possibly from multiple packets (for different PN
+ /// spaces) and each containing 1+ frames.
+ fn output_path(&mut self, path: &mut Path, now: Instant) -> Res<SendOption> {
+ let mut initial_sent = None;
+ let mut needs_padding = false;
+ let grease_quic_bit = self.can_grease_quic_bit();
+
+ // Determine how we are sending packets (PTO, etc..).
+ let profile = self.loss_recovery.send_profile(now, path.mtu());
+ qdebug!([self], "output_path send_profile {:?}", profile);
+
+ // Frames for different epochs must go in different packets, but then these
+ // packets can go in a single datagram
+ let mut encoder = Encoder::with_capacity(profile.limit());
+ for space in PNSpace::iter() {
+ // Ensure we have tx crypto state for this epoch, or skip it.
+ let (cspace, tx) = if let Some(crypto) = self.crypto.states.select_tx(*space) {
+ crypto
+ } else {
+ continue;
+ };
+
+ let header_start = encoder.len();
+ let (pt, mut builder) = Self::build_packet_header(
+ path,
+ cspace,
+ encoder,
+ tx,
+ &self.address_validation,
+ self.quic_version,
+ grease_quic_bit,
+ );
+ let pn = Self::add_packet_number(
+ &mut builder,
+ tx,
+ self.loss_recovery.largest_acknowledged_pn(*space),
+ );
+ let payload_start = builder.len();
+
+ // Work out if we have space left.
+ let aead_expansion = tx.expansion();
+ if builder.len() + aead_expansion > profile.limit() {
+ // No space for a packet of this type.
+ encoder = builder.abort();
+ continue;
+ }
+
+ // Add frames to the packet.
+ let limit = profile.limit() - aead_expansion;
+ builder.set_limit(limit);
+ let (tokens, ack_eliciting, padded) =
+ self.write_frames(*space, &profile, &mut builder, needs_padding, now);
+ if builder.packet_empty() {
+ // Nothing to include in this packet.
+ encoder = builder.abort();
+ continue;
+ }
+
+ dump_packet(self, "TX ->", pt, pn, &builder[payload_start..]);
+ qlog::packet_sent(
+ &mut self.qlog,
+ pt,
+ pn,
+ builder.len() - header_start + aead_expansion,
+ &builder[payload_start..],
+ );
+
+ self.stats.borrow_mut().packets_tx += 1;
+ encoder = builder.build(self.crypto.states.tx(cspace).unwrap())?;
+ debug_assert!(encoder.len() <= path.mtu());
+ self.crypto.states.auto_update()?;
+
+ if ack_eliciting {
+ self.idle_timeout.on_packet_sent(now);
+ }
+ let sent = SentPacket::new(
+ pt,
+ pn,
+ now,
+ ack_eliciting,
+ tokens,
+ encoder.len() - header_start,
+ );
+ if padded {
+ needs_padding = false;
+ self.loss_recovery.on_packet_sent(sent);
+ } else if pt == PacketType::Initial && (self.role == Role::Client || ack_eliciting) {
+ // Packets containing Initial packets might need padding, and we want to
+ // track that padding along with the Initial packet. So defer tracking.
+ initial_sent = Some(sent);
+ needs_padding = true;
+ } else {
+ if pt == PacketType::Handshake && self.role == Role::Client {
+ needs_padding = false;
+ }
+ self.loss_recovery.on_packet_sent(sent);
+ }
+
+ if *space == PNSpace::Handshake {
+ if self.role == Role::Client {
+ // Client can send Handshake packets -> discard Initial keys and states
+ self.discard_keys(PNSpace::Initial, now);
+ } else if self.state == State::Confirmed {
+ // We could discard handshake keys in set_state, but wait until after sending an ACK.
+ self.discard_keys(PNSpace::Handshake, now);
+ }
+ }
+ }
+
+ if encoder.is_empty() {
+ Ok(SendOption::No(profile.paced()))
+ } else {
+ // Perform additional padding for Initial packets as necessary.
+ let mut packets: Vec<u8> = encoder.into();
+ if let Some(mut initial) = initial_sent.take() {
+ if needs_padding {
+ qdebug!([self], "pad Initial to path MTU {}", path.mtu());
+ initial.size += path.mtu() - packets.len();
+ packets.resize(path.mtu(), 0);
+ }
+ self.loss_recovery.on_packet_sent(initial);
+ }
+ Ok(SendOption::Yes(path.datagram(packets)))
+ }
+ }
+
+ pub fn initiate_key_update(&mut self) -> Res<()> {
+ if self.state == State::Confirmed {
+ let la = self
+ .loss_recovery
+ .largest_acknowledged_pn(PNSpace::ApplicationData);
+ qinfo!([self], "Initiating key update");
+ self.crypto.states.initiate_key_update(la)
+ } else {
+ Err(Error::KeyUpdateBlocked)
+ }
+ }
+
+ #[cfg(test)]
+ pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
+ self.crypto.states.get_epochs()
+ }
+
+ fn client_start(&mut self, now: Instant) -> Res<()> {
+ qinfo!([self], "client_start");
+ debug_assert_eq!(self.role, Role::Client);
+ qlog::client_connection_started(&mut self.qlog, self.path.as_ref().unwrap());
+ self.loss_recovery.start_pacer(now);
+
+ self.handshake(now, PNSpace::Initial, None)?;
+ self.set_state(State::WaitInitial);
+ self.zero_rtt_state = if self.crypto.enable_0rtt(self.role)? {
+ qdebug!([self], "Enabled 0-RTT");
+ ZeroRttState::Sending
+ } else {
+ ZeroRttState::Init
+ };
+ Ok(())
+ }
+
+ fn get_closing_period_time(&self, now: Instant) -> Instant {
+ // Spec says close time should be at least PTO times 3.
+ now + (self.loss_recovery.pto_raw(PNSpace::ApplicationData) * 3)
+ }
+
+ /// Close the connection.
+ pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) {
+ let error = ConnectionError::Application(app_error);
+ let timeout = self.get_closing_period_time(now);
+ self.state_signaling.close(error.clone(), 0, msg);
+ self.set_state(State::Closing { error, timeout });
+ }
+
+ fn set_initial_limits(&mut self) {
+ let tps = self.tps.borrow();
+ let remote = tps.remote();
+ self.indexes.remote_max_stream_bidi =
+ StreamIndex::new(remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI));
+ self.indexes.remote_max_stream_uni =
+ StreamIndex::new(remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI));
+ self.flow_mgr
+ .borrow_mut()
+ .conn_increase_max_credit(remote.get_integer(tparams::INITIAL_MAX_DATA));
+
+ let peer_timeout = remote.get_integer(tparams::IDLE_TIMEOUT);
+ if peer_timeout > 0 {
+ self.idle_timeout
+ .set_peer_timeout(Duration::from_millis(peer_timeout));
+ }
+ }
+
+ /// Process the final set of transport parameters.
+ fn process_tps(&mut self) -> Res<()> {
+ self.validate_cids()?;
+ {
+ let tps = self.tps.borrow();
+ if let Some(token) = tps
+ .remote
+ .as_ref()
+ .unwrap()
+ .get_bytes(tparams::STATELESS_RESET_TOKEN)
+ {
+ let reset_token = <[u8; 16]>::try_from(token).unwrap().to_owned();
+ self.path.as_mut().unwrap().set_reset_token(reset_token);
+ }
+ let mad = Duration::from_millis(
+ tps.remote
+ .as_ref()
+ .unwrap()
+ .get_integer(tparams::MAX_ACK_DELAY),
+ );
+ self.loss_recovery.set_peer_max_ack_delay(mad);
+ }
+ self.set_initial_limits();
+ qlog::connection_tparams_set(&mut self.qlog, &*self.tps.borrow());
+ Ok(())
+ }
+
+ fn validate_cids(&mut self) -> Res<()> {
+ match self.quic_version {
+ QuicVersion::Draft27 => self.validate_cids_draft_27(),
+ _ => self.validate_cids_draft_28_plus(),
+ }
+ }
+
+ fn validate_cids_draft_27(&mut self) -> Res<()> {
+ if let AddressValidationInfo::Retry { token, .. } = &self.address_validation {
+ debug_assert!(!token.is_empty());
+ let tph = self.tps.borrow();
+ let tp = tph
+ .remote
+ .as_ref()
+ .unwrap()
+ .get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID);
+ if self
+ .original_destination_cid
+ .as_ref()
+ .map(ConnectionId::as_cid_ref)
+ != tp.map(ConnectionIdRef::from)
+ {
+ return Err(Error::InvalidRetry);
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_cids_draft_28_plus(&mut self) -> Res<()> {
+ let tph = self.tps.borrow();
+ let remote_tps = tph.remote.as_ref().unwrap();
+
+ let tp = remote_tps.get_bytes(tparams::INITIAL_SOURCE_CONNECTION_ID);
+ if self
+ .remote_initial_source_cid
+ .as_ref()
+ .map(ConnectionId::as_cid_ref)
+ != tp.map(ConnectionIdRef::from)
+ {
+ qwarn!(
+ [self],
+ "ISCID test failed: self cid {:?} != tp cid {:?}",
+ self.remote_initial_source_cid,
+ tp.map(hex),
+ );
+ return Err(Error::ProtocolViolation);
+ }
+
+ if self.role == Role::Client {
+ let tp = remote_tps.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID);
+ if self
+ .original_destination_cid
+ .as_ref()
+ .map(ConnectionId::as_cid_ref)
+ != tp.map(ConnectionIdRef::from)
+ {
+ qwarn!(
+ [self],
+ "ODCID test failed: self cid {:?} != tp cid {:?}",
+ self.original_destination_cid,
+ tp.map(hex),
+ );
+ return Err(Error::ProtocolViolation);
+ }
+
+ let tp = remote_tps.get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID);
+ let expected = if let AddressValidationInfo::Retry {
+ retry_source_cid, ..
+ } = &self.address_validation
+ {
+ Some(retry_source_cid.as_cid_ref())
+ } else {
+ None
+ };
+ if expected != tp.map(ConnectionIdRef::from) {
+ qwarn!(
+ [self],
+ "RSCID test failed. self cid {:?} != tp cid {:?}",
+ expected,
+ tp.map(hex),
+ );
+ return Err(Error::ProtocolViolation);
+ }
+ }
+
+ Ok(())
+ }
+
+ fn handshake(&mut self, now: Instant, space: PNSpace, data: Option<&[u8]>) -> Res<()> {
+ qtrace!([self], "Handshake space={} data={:0x?}", space, data);
+
+ let try_update = data.is_some();
+ match self.crypto.handshake(now, space, data)? {
+ HandshakeState::Authenticated(_) | HandshakeState::InProgress => (),
+ HandshakeState::AuthenticationPending => self.events.authentication_needed(),
+ HandshakeState::Complete(_) => {
+ if !self.state.connected() {
+ self.set_connected(now)?;
+ }
+ }
+ _ => {
+ unreachable!("Crypto state should not be new or failed after successful handshake")
+ }
+ }
+
+ // There is a chance that this could be called less often, but getting the
+ // conditions right is a little tricky, so call it on every CRYPTO frame.
+ if try_update {
+ // We have transport parameters, it's go time.
+ if self.tps.borrow().remote.is_some() {
+ self.set_initial_limits();
+ }
+ if self.crypto.install_keys(self.role)? {
+ self.saved_datagrams.make_available(CryptoSpace::Handshake);
+ }
+ }
+
+ Ok(())
+ }
+
+ fn handle_max_data(&mut self, maximum_data: u64) {
+ let conn_was_blocked = self.flow_mgr.borrow().conn_credit_avail() == 0;
+ let conn_credit_increased = self
+ .flow_mgr
+ .borrow_mut()
+ .conn_increase_max_credit(maximum_data);
+
+ if conn_was_blocked && conn_credit_increased {
+ for (id, ss) in &mut self.send_streams {
+ if ss.avail() > 0 {
+ // These may not actually all be writable if one
+ // uses up all the conn credit. Not our fault.
+ self.events.send_stream_writable(*id)
+ }
+ }
+ }
+ }
+
+ fn input_frame(&mut self, ptype: PacketType, frame: Frame, now: Instant) -> Res<()> {
+ if !frame.is_allowed(ptype) {
+ qinfo!("frame not allowed: {:?} {:?}", frame, ptype);
+ return Err(Error::ProtocolViolation);
+ }
+ self.stats.borrow_mut().frame_rx.all += 1;
+ let space = PNSpace::from(ptype);
+ match frame {
+ Frame::Padding => {
+ // Note: This counts contiguous padding as a single frame.
+ self.stats.borrow_mut().frame_rx.padding += 1;
+ }
+ Frame::Ping => {
+ // If we get a PING and there are outstanding CRYPTO frames,
+ // prepare to resend them.
+ self.stats.borrow_mut().frame_rx.ping += 1;
+ self.crypto.resend_unacked(space);
+ }
+ Frame::Ack {
+ largest_acknowledged,
+ ack_delay,
+ first_ack_range,
+ ack_ranges,
+ } => {
+ self.handle_ack(
+ space,
+ largest_acknowledged,
+ ack_delay,
+ first_ack_range,
+ ack_ranges,
+ now,
+ )?;
+ }
+ Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ ..
+ } => {
+ // TODO(agrover@mozilla.com): use final_size for connection MaxData calc
+ self.stats.borrow_mut().frame_rx.reset_stream += 1;
+ if let (_, Some(rs)) = self.obtain_stream(stream_id)? {
+ rs.reset(application_error_code);
+ }
+ }
+ Frame::StopSending {
+ stream_id,
+ application_error_code,
+ } => {
+ self.stats.borrow_mut().frame_rx.stop_sending += 1;
+ self.events
+ .send_stream_stop_sending(stream_id, application_error_code);
+ if let (Some(ss), _) = self.obtain_stream(stream_id)? {
+ ss.reset(application_error_code);
+ }
+ }
+ Frame::Crypto { offset, data } => {
+ qtrace!(
+ [self],
+ "Crypto frame on space={} offset={}, data={:0x?}",
+ space,
+ offset,
+ &data
+ );
+ self.stats.borrow_mut().frame_rx.crypto += 1;
+ self.crypto.streams.inbound_frame(space, offset, data);
+ if self.crypto.streams.data_ready(space) {
+ let mut buf = Vec::new();
+ let read = self.crypto.streams.read_to_end(space, &mut buf);
+ qdebug!("Read {} bytes", read);
+ self.handshake(now, space, Some(&buf))?;
+ self.create_resumption_token(now);
+ } else {
+ // If we get a useless CRYPTO frame send outstanding CRYPTO frames again.
+ self.crypto.resend_unacked(space);
+ }
+ }
+ Frame::NewToken { token } => {
+ self.stats.borrow_mut().frame_rx.new_token += 1;
+ self.new_token.save_token(token.to_vec());
+ self.create_resumption_token(now);
+ }
+ Frame::Stream {
+ fin,
+ stream_id,
+ offset,
+ data,
+ ..
+ } => {
+ self.stats.borrow_mut().frame_rx.stream += 1;
+ if let (_, Some(rs)) = self.obtain_stream(stream_id)? {
+ rs.inbound_stream_frame(fin, offset, data)?;
+ }
+ }
+ Frame::MaxData { maximum_data } => {
+ self.stats.borrow_mut().frame_rx.max_data += 1;
+ self.handle_max_data(maximum_data);
+ }
+ Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data,
+ } => {
+ self.stats.borrow_mut().frame_rx.max_stream_data += 1;
+ if let (Some(ss), _) = self.obtain_stream(stream_id)? {
+ ss.set_max_stream_data(maximum_stream_data);
+ }
+ }
+ Frame::MaxStreams {
+ stream_type,
+ maximum_streams,
+ } => {
+ self.stats.borrow_mut().frame_rx.max_streams += 1;
+ let remote_max = match stream_type {
+ StreamType::BiDi => &mut self.indexes.remote_max_stream_bidi,
+ StreamType::UniDi => &mut self.indexes.remote_max_stream_uni,
+ };
+
+ if maximum_streams > *remote_max {
+ *remote_max = maximum_streams;
+ self.events.send_stream_creatable(stream_type);
+ }
+ }
+ Frame::DataBlocked { data_limit } => {
+ // Should never happen since we set data limit to max
+ qwarn!(
+ [self],
+ "Received DataBlocked with data limit {}",
+ data_limit
+ );
+ self.stats.borrow_mut().frame_rx.data_blocked += 1;
+ // But if it does, open it up all the way
+ self.flow_mgr.borrow_mut().max_data(LOCAL_MAX_DATA);
+ }
+ Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit,
+ } => {
+ self.stats.borrow_mut().frame_rx.stream_data_blocked += 1;
+ // Terminate connection with STREAM_STATE_ERROR if send-only
+ // stream (-transport 19.13)
+ if stream_id.is_send_only(self.role()) {
+ return Err(Error::StreamStateError);
+ }
+
+ if let (_, Some(rs)) = self.obtain_stream(stream_id)? {
+ if let Some(msd) = rs.max_stream_data() {
+ qinfo!(
+ [self],
+ "Got StreamDataBlocked(id {} MSD {}); curr MSD {}",
+ stream_id.as_u64(),
+ stream_data_limit,
+ msd
+ );
+ if stream_data_limit != msd {
+ self.flow_mgr.borrow_mut().max_stream_data(stream_id, msd)
+ }
+ }
+ }
+ }
+ Frame::StreamsBlocked { stream_type, .. } => {
+ self.stats.borrow_mut().frame_rx.streams_blocked += 1;
+ let local_max = match stream_type {
+ StreamType::BiDi => &mut self.indexes.local_max_stream_bidi,
+ StreamType::UniDi => &mut self.indexes.local_max_stream_uni,
+ };
+
+ self.flow_mgr
+ .borrow_mut()
+ .max_streams(*local_max, stream_type)
+ }
+ Frame::NewConnectionId {
+ sequence_number,
+ connection_id,
+ stateless_reset_token,
+ ..
+ } => {
+ self.stats.borrow_mut().frame_rx.new_connection_id += 1;
+ let cid = ConnectionId::from(connection_id);
+ let srt = stateless_reset_token.to_owned();
+ self.connection_ids.insert(sequence_number, (cid, srt));
+ }
+ Frame::RetireConnectionId { sequence_number } => {
+ self.stats.borrow_mut().frame_rx.retire_connection_id += 1;
+ self.connection_ids.remove(&sequence_number);
+ }
+ Frame::PathChallenge { data } => {
+ self.stats.borrow_mut().frame_rx.path_challenge += 1;
+ self.flow_mgr.borrow_mut().path_response(data);
+ }
+ Frame::PathResponse { .. } => {
+ // Should never see this, we don't support migration atm and
+ // do not send path challenges
+ qwarn!([self], "Received Path Response");
+ self.stats.borrow_mut().frame_rx.path_response += 1;
+ }
+ Frame::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ } => {
+ self.stats.borrow_mut().frame_rx.connection_close += 1;
+ let reason_phrase = String::from_utf8_lossy(&reason_phrase);
+ qinfo!(
+ [self],
+ "ConnectionClose received. Error code: {:?} frame type {:x} reason {}",
+ error_code,
+ frame_type,
+ reason_phrase
+ );
+ let (detail, frame_type) = if let CloseError::Application(_) = error_code {
+ // Use a transport error here because we want to send
+ // NO_ERROR in this case.
+ (
+ Error::PeerApplicationError(error_code.code()),
+ FRAME_TYPE_CONNECTION_CLOSE_APPLICATION,
+ )
+ } else {
+ (
+ Error::PeerError(error_code.code()),
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT,
+ )
+ };
+ let error = ConnectionError::Transport(detail);
+ self.state_signaling.drain(error.clone(), frame_type, "");
+ self.set_state(State::Draining {
+ error,
+ timeout: self.get_closing_period_time(now),
+ });
+ }
+ Frame::HandshakeDone => {
+ self.stats.borrow_mut().frame_rx.handshake_done += 1;
+ if self.role == Role::Server || !self.state.connected() {
+ return Err(Error::ProtocolViolation);
+ }
+ self.set_state(State::Confirmed);
+ self.discard_keys(PNSpace::Handshake, now);
+ }
+ };
+
+ Ok(())
+ }
+
+ /// Given a set of `SentPacket` instances, ensure that the source of the packet
+ /// is told that they are lost. This gives the frame generation code a chance
+ /// to retransmit the frame as needed.
+ fn handle_lost_packets(&mut self, lost_packets: &[SentPacket]) {
+ for lost in lost_packets {
+ for token in &lost.tokens {
+ qdebug!([self], "Lost: {:?}", token);
+ match token {
+ RecoveryToken::Ack(_) => {}
+ RecoveryToken::Stream(st) => self.send_streams.lost(&st),
+ RecoveryToken::Crypto(ct) => self.crypto.lost(&ct),
+ RecoveryToken::Flow(ft) => self.flow_mgr.borrow_mut().lost(
+ &ft,
+ &mut self.send_streams,
+ &mut self.recv_streams,
+ &mut self.indexes,
+ ),
+ RecoveryToken::HandshakeDone => self.state_signaling.handshake_done(),
+ RecoveryToken::NewToken(seqno) => self.new_token.lost(*seqno),
+ }
+ }
+ }
+ }
+
+ fn decode_ack_delay(&self, v: u64) -> Duration {
+ // If we have remote transport parameters, use them.
+ // Otherwise, ack delay should be zero (because it's the handshake).
+ if let Some(r) = self.tps.borrow().remote.as_ref() {
+ let exponent = u32::try_from(r.get_integer(tparams::ACK_DELAY_EXPONENT)).unwrap();
+ Duration::from_micros(v.checked_shl(exponent).unwrap_or(u64::MAX))
+ } else {
+ Duration::new(0, 0)
+ }
+ }
+
+ fn handle_ack(
+ &mut self,
+ space: PNSpace,
+ largest_acknowledged: u64,
+ ack_delay: u64,
+ first_ack_range: u64,
+ ack_ranges: Vec<AckRange>,
+ now: Instant,
+ ) -> Res<()> {
+ qinfo!(
+ [self],
+ "Rx ACK space={}, largest_acked={}, first_ack_range={}, ranges={:?}",
+ space,
+ largest_acknowledged,
+ first_ack_range,
+ ack_ranges
+ );
+
+ let acked_ranges =
+ Frame::decode_ack_frame(largest_acknowledged, first_ack_range, &ack_ranges)?;
+ let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received(
+ space,
+ largest_acknowledged,
+ acked_ranges,
+ self.decode_ack_delay(ack_delay),
+ now,
+ );
+ for acked in acked_packets {
+ for token in &acked.tokens {
+ match token {
+ RecoveryToken::Ack(at) => self.acks.acked(at),
+ RecoveryToken::Stream(st) => self.send_streams.acked(st),
+ RecoveryToken::Crypto(ct) => self.crypto.acked(ct),
+ RecoveryToken::Flow(ft) => {
+ self.flow_mgr.borrow_mut().acked(ft, &mut self.send_streams)
+ }
+ RecoveryToken::HandshakeDone => (),
+ RecoveryToken::NewToken(seqno) => self.new_token.acked(*seqno),
+ }
+ }
+ }
+ self.handle_lost_packets(&lost_packets);
+ qlog::packets_lost(&mut self.qlog, &lost_packets);
+ let stats = &mut self.stats.borrow_mut().frame_rx;
+ stats.ack += 1;
+ stats.largest_acknowledged = max(stats.largest_acknowledged, largest_acknowledged);
+ Ok(())
+ }
+
+ /// When the server rejects 0-RTT we need to drop a bunch of stuff.
+ fn client_0rtt_rejected(&mut self) {
+ if !matches!(self.zero_rtt_state, ZeroRttState::Sending) {
+ return;
+ }
+ qdebug!([self], "0-RTT rejected");
+
+ // Tell 0-RTT packets that they were "lost".
+ let dropped = self.loss_recovery.drop_0rtt();
+ self.handle_lost_packets(&dropped);
+
+ self.send_streams.clear();
+ self.recv_streams.clear();
+ self.indexes = StreamIndexes::new();
+ self.crypto.states.discard_0rtt_keys();
+ self.events.client_0rtt_rejected();
+ }
+
+ fn set_connected(&mut self, now: Instant) -> Res<()> {
+ qinfo!([self], "TLS connection complete");
+ if self.crypto.tls.info().map(SecretAgentInfo::alpn).is_none() {
+ qwarn!([self], "No ALPN. Closing connection.");
+ // 120 = no_application_protocol
+ return Err(Error::CryptoAlert(120));
+ }
+ if self.role == Role::Server {
+ // Remove the randomized client CID from the list of acceptable CIDs.
+ debug_assert_eq!(1, self.valid_cids.len());
+ self.valid_cids.clear();
+ // Generate a qlog event that the server connection started.
+ qlog::server_connection_started(&mut self.qlog, self.path.as_ref().unwrap());
+ } else {
+ self.zero_rtt_state = if self.crypto.tls.info().unwrap().early_data_accepted() {
+ ZeroRttState::AcceptedClient
+ } else {
+ self.client_0rtt_rejected();
+ ZeroRttState::Rejected
+ };
+ }
+
+ // Setting application keys has to occur after 0-RTT rejection.
+ let pto = self.loss_recovery.pto_raw(PNSpace::ApplicationData);
+ self.crypto.install_application_keys(now + pto)?;
+ self.process_tps()?;
+ self.set_state(State::Connected);
+ self.create_resumption_token(now);
+ self.saved_datagrams
+ .make_available(CryptoSpace::ApplicationData);
+ self.stats.borrow_mut().resumed = self.crypto.tls.info().unwrap().resumed();
+ if self.role == Role::Server {
+ self.state_signaling.handshake_done();
+ self.set_state(State::Confirmed);
+ }
+ qinfo!([self], "Connection established");
+ Ok(())
+ }
+
+ fn set_state(&mut self, state: State) {
+ if state > self.state {
+ qinfo!([self], "State change from {:?} -> {:?}", self.state, state);
+ self.state = state.clone();
+ if self.state.closed() {
+ self.send_streams.clear();
+ self.recv_streams.clear();
+ }
+ self.events.connection_state_change(state);
+ qlog::connection_state_updated(&mut self.qlog, &self.state)
+ } else if mem::discriminant(&state) != mem::discriminant(&self.state) {
+ // Only tolerate a regression in state if the new state is closing
+ // and the connection is already closed.
+ debug_assert!(matches!(state, State::Closing { .. } | State::Draining { .. }));
+ debug_assert!(self.state.closed());
+ }
+ }
+
+ fn cleanup_streams(&mut self) {
+ self.send_streams.clear_terminal();
+ let recv_to_remove = self
+ .recv_streams
+ .iter()
+ .filter_map(|(id, stream)| {
+ // Remove all streams for which the receiving is done (or aborted).
+ // But only if they are unidirectional, or we have finished sending.
+ if stream.is_terminal() && (id.is_uni() || !self.send_streams.exists(*id)) {
+ Some(*id)
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let mut removed_bidi = 0;
+ let mut removed_uni = 0;
+ for id in &recv_to_remove {
+ self.recv_streams.remove(&id);
+ if id.is_remote_initiated(self.role()) {
+ if id.is_bidi() {
+ removed_bidi += 1;
+ } else {
+ removed_uni += 1;
+ }
+ }
+ }
+
+ // Send max_streams updates if we removed remote-initiated recv streams.
+ if removed_bidi > 0 {
+ self.indexes.local_max_stream_bidi += removed_bidi;
+ self.flow_mgr
+ .borrow_mut()
+ .max_streams(self.indexes.local_max_stream_bidi, StreamType::BiDi)
+ }
+ if removed_uni > 0 {
+ self.indexes.local_max_stream_uni += removed_uni;
+ self.flow_mgr
+ .borrow_mut()
+ .max_streams(self.indexes.local_max_stream_uni, StreamType::UniDi)
+ }
+ }
+
+ /// Get or make a stream, and implicitly open additional streams as
+ /// indicated by its stream id.
+ fn obtain_stream(
+ &mut self,
+ stream_id: StreamId,
+ ) -> Res<(Option<&mut SendStream>, Option<&mut RecvStream>)> {
+ if !self.state.connected()
+ && !matches!(
+ (&self.state, &self.zero_rtt_state),
+ (State::Handshaking, ZeroRttState::AcceptedServer)
+ )
+ {
+ return Err(Error::ConnectionState);
+ }
+
+ // May require creating new stream(s)
+ if stream_id.is_remote_initiated(self.role()) {
+ let next_stream_idx = if stream_id.is_bidi() {
+ &mut self.indexes.local_next_stream_bidi
+ } else {
+ &mut self.indexes.local_next_stream_uni
+ };
+ let stream_idx: StreamIndex = stream_id.into();
+
+ if stream_idx >= *next_stream_idx {
+ let recv_initial_max_stream_data = if stream_id.is_bidi() {
+ if stream_idx > self.indexes.local_max_stream_bidi {
+ qwarn!(
+ [self],
+ "remote bidi stream create blocked, next={:?} max={:?}",
+ stream_idx,
+ self.indexes.local_max_stream_bidi
+ );
+ return Err(Error::StreamLimitError);
+ }
+ // From the local perspective, this is a remote- originated BiDi stream. From
+ // the remote perspective, this is a local-originated BiDi stream. Therefore,
+ // look at the local transport parameters for the
+ // INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value to decide how much this endpoint
+ // will allow its peer to send.
+ self.tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)
+ } else {
+ if stream_idx > self.indexes.local_max_stream_uni {
+ qwarn!(
+ [self],
+ "remote uni stream create blocked, next={:?} max={:?}",
+ stream_idx,
+ self.indexes.local_max_stream_uni
+ );
+ return Err(Error::StreamLimitError);
+ }
+ self.tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)
+ };
+
+ loop {
+ let next_stream_id =
+ next_stream_idx.to_stream_id(stream_id.stream_type(), stream_id.role());
+ self.events.new_stream(next_stream_id);
+
+ self.recv_streams.insert(
+ next_stream_id,
+ RecvStream::new(
+ next_stream_id,
+ recv_initial_max_stream_data,
+ self.flow_mgr.clone(),
+ self.events.clone(),
+ ),
+ );
+
+ if next_stream_id.is_bidi() {
+ // From the local perspective, this is a remote- originated BiDi stream.
+ // From the remote perspective, this is a local-originated BiDi stream.
+ // Therefore, look at the remote's transport parameters for the
+ // INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value to decide how much this endpoint
+ // is allowed to send its peer.
+ let send_initial_max_stream_data = self
+ .tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL);
+ self.send_streams.insert(
+ next_stream_id,
+ SendStream::new(
+ next_stream_id,
+ send_initial_max_stream_data,
+ self.flow_mgr.clone(),
+ self.events.clone(),
+ ),
+ );
+ }
+
+ *next_stream_idx += 1;
+ if *next_stream_idx > stream_idx {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok((
+ self.send_streams.get_mut(stream_id).ok(),
+ self.recv_streams.get_mut(&stream_id),
+ ))
+ }
+
+ /// Create a stream.
+ /// Returns new stream id
+ /// # Errors
+ /// `ConnectionState` if the connecton stat does not allow to create streams.
+ /// `StreamLimitError` if we are limiied by server's stream concurence.
+ pub fn stream_create(&mut self, st: StreamType) -> Res<u64> {
+ // Can't make streams while closing, otherwise rely on the stream limits.
+ match self.state {
+ State::Closing { .. } | State::Draining { .. } | State::Closed { .. } => {
+ return Err(Error::ConnectionState);
+ }
+ State::WaitInitial | State::Handshaking => {
+ if self.role == Role::Client && self.zero_rtt_state != ZeroRttState::Sending {
+ return Err(Error::ConnectionState);
+ }
+ }
+ // In all other states, trust that the stream limits are correct.
+ _ => (),
+ }
+
+ Ok(match st {
+ StreamType::UniDi => {
+ if self.indexes.remote_next_stream_uni >= self.indexes.remote_max_stream_uni {
+ self.flow_mgr
+ .borrow_mut()
+ .streams_blocked(self.indexes.remote_max_stream_uni, StreamType::UniDi);
+ qwarn!(
+ [self],
+ "local uni stream create blocked, next={:?} max={:?}",
+ self.indexes.remote_next_stream_uni,
+ self.indexes.remote_max_stream_uni
+ );
+ return Err(Error::StreamLimitError);
+ }
+ let new_id = self
+ .indexes
+ .remote_next_stream_uni
+ .to_stream_id(StreamType::UniDi, self.role);
+ self.indexes.remote_next_stream_uni += 1;
+ let initial_max_stream_data = self
+ .tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI);
+
+ self.send_streams.insert(
+ new_id,
+ SendStream::new(
+ new_id,
+ initial_max_stream_data,
+ self.flow_mgr.clone(),
+ self.events.clone(),
+ ),
+ );
+ new_id.as_u64()
+ }
+ StreamType::BiDi => {
+ if self.indexes.remote_next_stream_bidi >= self.indexes.remote_max_stream_bidi {
+ self.flow_mgr
+ .borrow_mut()
+ .streams_blocked(self.indexes.remote_max_stream_bidi, StreamType::BiDi);
+ qwarn!(
+ [self],
+ "local bidi stream create blocked, next={:?} max={:?}",
+ self.indexes.remote_next_stream_bidi,
+ self.indexes.remote_max_stream_bidi
+ );
+ return Err(Error::StreamLimitError);
+ }
+ let new_id = self
+ .indexes
+ .remote_next_stream_bidi
+ .to_stream_id(StreamType::BiDi, self.role);
+ self.indexes.remote_next_stream_bidi += 1;
+ // From the local perspective, this is a local- originated BiDi stream. From the
+ // remote perspective, this is a remote-originated BiDi stream. Therefore, look at
+ // the remote transport parameters for the INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value
+ // to decide how much this endpoint is allowed to send its peer.
+ let send_initial_max_stream_data = self
+ .tps
+ .borrow()
+ .remote()
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE);
+
+ self.send_streams.insert(
+ new_id,
+ SendStream::new(
+ new_id,
+ send_initial_max_stream_data,
+ self.flow_mgr.clone(),
+ self.events.clone(),
+ ),
+ );
+ // From the local perspective, this is a local- originated BiDi stream. From the
+ // remote perspective, this is a remote-originated BiDi stream. Therefore, look at
+ // the local transport parameters for the INITIAL_MAX_STREAM_DATA_BIDI_LOCAL value
+ // to decide how much this endpoint will allow its peer to send.
+ let recv_initial_max_stream_data = self
+ .tps
+ .borrow()
+ .local
+ .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL);
+
+ self.recv_streams.insert(
+ new_id,
+ RecvStream::new(
+ new_id,
+ recv_initial_max_stream_data,
+ self.flow_mgr.clone(),
+ self.events.clone(),
+ ),
+ );
+ new_id.as_u64()
+ }
+ })
+ }
+
+ /// Send data on a stream.
+ /// Returns how many bytes were successfully sent. Could be less
+ /// than total, based on receiver credit space available, etc.
+ /// # Errors
+ /// `InvalidStreamId` the stream does not exist,
+ /// `InvalidInput` if length of `data` is zero,
+ /// `FinalSizeError` if the stream has already been closed.
+ pub fn stream_send(&mut self, stream_id: u64, data: &[u8]) -> Res<usize> {
+ self.send_streams.get_mut(stream_id.into())?.send(data)
+ }
+
+ /// Send all data or nothing on a stream. May cause DATA_BLOCKED or
+ /// STREAM_DATA_BLOCKED frames to be sent.
+ /// Returns true if data was successfully sent, otherwise false.
+ /// # Errors
+ /// `InvalidStreamId` the stream does not exist,
+ /// `InvalidInput` if length of `data` is zero,
+ /// `FinalSizeError` if the stream has already been closed.
+ pub fn stream_send_atomic(&mut self, stream_id: u64, data: &[u8]) -> Res<bool> {
+ let val = self
+ .send_streams
+ .get_mut(stream_id.into())?
+ .send_atomic(data);
+ if let Ok(val) = val {
+ debug_assert!(
+ val == 0 || val == data.len(),
+ "Unexpected value {} when trying to send {} bytes atomically",
+ val,
+ data.len()
+ );
+ }
+ val.map(|v| v == data.len())
+ }
+
+ /// Bytes that stream_send() is guaranteed to accept for sending.
+ /// i.e. that will not be blocked by flow credits or send buffer max
+ /// capacity.
+ pub fn stream_avail_send_space(&self, stream_id: u64) -> Res<usize> {
+ Ok(self.send_streams.get(stream_id.into())?.avail())
+ }
+
+ /// Close the stream. Enqueued data will be sent.
+ pub fn stream_close_send(&mut self, stream_id: u64) -> Res<()> {
+ self.send_streams.get_mut(stream_id.into())?.close();
+ Ok(())
+ }
+
+ /// Abandon transmission of in-flight and future stream data.
+ pub fn stream_reset_send(&mut self, stream_id: u64, err: AppError) -> Res<()> {
+ self.send_streams.get_mut(stream_id.into())?.reset(err);
+ Ok(())
+ }
+
+ /// Read buffered data from stream. bool says whether read bytes includes
+ /// the final data on stream.
+ /// # Errors
+ /// `InvalidStreamId` if the stream does not exist.
+ /// `NoMoreData` if data and fin bit were previously read by the application.
+ pub fn stream_recv(&mut self, stream_id: u64, data: &mut [u8]) -> Res<(usize, bool)> {
+ let stream = self
+ .recv_streams
+ .get_mut(&stream_id.into())
+ .ok_or(Error::InvalidStreamId)?;
+
+ let rb = stream.read(data)?;
+ Ok((rb.0 as usize, rb.1))
+ }
+
+ /// Application is no longer interested in this stream.
+ pub fn stream_stop_sending(&mut self, stream_id: u64, err: AppError) -> Res<()> {
+ let stream = self
+ .recv_streams
+ .get_mut(&stream_id.into())
+ .ok_or(Error::InvalidStreamId)?;
+
+ stream.stop_sending(err);
+ Ok(())
+ }
+
+ #[cfg(test)]
+ pub fn get_pto(&self) -> Duration {
+ self.loss_recovery.pto_raw(PNSpace::ApplicationData)
+ }
+}
+
+impl EventProvider for Connection {
+ type Event = ConnectionEvent;
+
+ /// Return true if there are outstanding events.
+ fn has_events(&self) -> bool {
+ self.events.has_events()
+ }
+
+ /// Get events that indicate state changes on the connection. This method
+ /// correctly handles cases where handling one event can obsolete
+ /// previously-queued events, or cause new events to be generated.
+ fn next_event(&mut self) -> Option<Self::Event> {
+ self.events.next_event()
+ }
+}
+
+impl ::std::fmt::Display for Connection {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{:?} ", self.role)?;
+ if let Some(cid) = self.odcid() {
+ std::fmt::Display::fmt(&cid, f)
+ } else {
+ write!(f, "...")
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/third_party/rust/neqo-transport/src/connection/params.rs b/third_party/rust/neqo-transport/src/connection/params.rs
new file mode 100644
index 0000000000..c4404b54d9
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/params.rs
@@ -0,0 +1,71 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::frame::StreamType;
+use crate::{
+ CongestionControlAlgorithm, QuicVersion, LOCAL_STREAM_LIMIT_BIDI, LOCAL_STREAM_LIMIT_UNI,
+};
+
+/// ConnectionParameters use for setting intitial value for QUIC parameters.
+/// This collect like initial limits, protocol version and congestion control.
+#[derive(Clone)]
+pub struct ConnectionParameters {
+ quic_version: QuicVersion,
+ cc_algorithm: CongestionControlAlgorithm,
+ max_streams_bidi: u64,
+ max_streams_uni: u64,
+}
+
+impl Default for ConnectionParameters {
+ fn default() -> Self {
+ Self {
+ quic_version: QuicVersion::default(),
+ cc_algorithm: CongestionControlAlgorithm::NewReno,
+ max_streams_bidi: LOCAL_STREAM_LIMIT_BIDI,
+ max_streams_uni: LOCAL_STREAM_LIMIT_UNI,
+ }
+ }
+}
+
+impl ConnectionParameters {
+ pub fn get_quic_version(&self) -> QuicVersion {
+ self.quic_version
+ }
+
+ pub fn quic_version(mut self, v: QuicVersion) -> Self {
+ self.quic_version = v;
+ self
+ }
+
+ pub fn get_cc_algorithm(&self) -> CongestionControlAlgorithm {
+ self.cc_algorithm
+ }
+
+ pub fn cc_algorithm(mut self, v: CongestionControlAlgorithm) -> Self {
+ self.cc_algorithm = v;
+ self
+ }
+
+ pub fn get_max_streams(&self, stream_type: StreamType) -> u64 {
+ match stream_type {
+ StreamType::BiDi => self.max_streams_bidi,
+ StreamType::UniDi => self.max_streams_uni,
+ }
+ }
+
+ pub fn max_streams(mut self, stream_type: StreamType, v: u64) -> Self {
+ assert!(v <= (1 << 60), "max_streams's parameter too big");
+ match stream_type {
+ StreamType::BiDi => {
+ self.max_streams_bidi = v;
+ }
+ StreamType::UniDi => {
+ self.max_streams_uni = v;
+ }
+ }
+ self
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/saved.rs b/third_party/rust/neqo-transport/src/connection/saved.rs
new file mode 100644
index 0000000000..b2da0c644a
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/saved.rs
@@ -0,0 +1,72 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::mem;
+use std::time::Instant;
+
+use crate::crypto::CryptoSpace;
+use neqo_common::{qdebug, qinfo, Datagram};
+
+/// The number of datagrams that are saved during the handshake when
+/// keys to decrypt them are not yet available.
+///
+/// This value exceeds what should be possible to send during the handshake.
+/// Neither endpoint should have enough congestion window to send this
+/// much before the handshake completes.
+const MAX_SAVED_DATAGRAMS: usize = 32;
+
+pub struct SavedDatagram {
+ /// The datagram.
+ pub d: Datagram,
+ /// The time that the datagram was received.
+ pub t: Instant,
+}
+
+#[derive(Default)]
+pub struct SavedDatagrams {
+ handshake: Vec<SavedDatagram>,
+ application_data: Vec<SavedDatagram>,
+ available: Option<CryptoSpace>,
+}
+
+impl SavedDatagrams {
+ fn store(&mut self, cspace: CryptoSpace) -> &mut Vec<SavedDatagram> {
+ match cspace {
+ CryptoSpace::Handshake => &mut self.handshake,
+ CryptoSpace::ApplicationData => &mut self.application_data,
+ _ => panic!("unexpected space"),
+ }
+ }
+
+ pub fn save(&mut self, cspace: CryptoSpace, d: Datagram, t: Instant) {
+ let store = self.store(cspace);
+
+ if store.len() < MAX_SAVED_DATAGRAMS {
+ qdebug!("saving datagram of {} bytes", d.len());
+ store.push(SavedDatagram { d, t });
+ } else {
+ qinfo!("not saving datagram of {} bytes", d.len());
+ }
+ }
+
+ pub fn make_available(&mut self, cspace: CryptoSpace) {
+ debug_assert_ne!(cspace, CryptoSpace::ZeroRtt);
+ debug_assert_ne!(cspace, CryptoSpace::Initial);
+ if !self.store(cspace).is_empty() {
+ self.available = Some(cspace);
+ }
+ }
+
+ pub fn available(&self) -> Option<CryptoSpace> {
+ self.available
+ }
+
+ pub fn take_saved(&mut self) -> Vec<SavedDatagram> {
+ self.available
+ .take()
+ .map_or_else(Vec::new, |cspace| mem::take(self.store(cspace)))
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/state.rs b/third_party/rust/neqo-transport/src/connection/state.rs
new file mode 100644
index 0000000000..1b38c1650c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/state.rs
@@ -0,0 +1,207 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::cmp::Ordering;
+use std::mem;
+use std::time::Instant;
+
+use crate::frame::{Frame, FrameType};
+use crate::packet::PacketBuilder;
+use crate::recovery::RecoveryToken;
+use crate::{CloseError, ConnectionError};
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+/// The state of the Connection.
+pub enum State {
+ Init,
+ WaitInitial,
+ Handshaking,
+ Connected,
+ Confirmed,
+ Closing {
+ error: ConnectionError,
+ timeout: Instant,
+ },
+ Draining {
+ error: ConnectionError,
+ timeout: Instant,
+ },
+ Closed(ConnectionError),
+}
+
+impl State {
+ #[must_use]
+ pub fn connected(&self) -> bool {
+ matches!(self, Self::Connected | Self::Confirmed)
+ }
+
+ #[must_use]
+ pub fn closed(&self) -> bool {
+ matches!(self, Self::Closing { .. } | Self::Draining { .. } | Self::Closed(_))
+ }
+}
+
+// Implement `PartialOrd` so that we can enforce monotonic state progression.
+impl PartialOrd for State {
+ #[allow(clippy::match_same_arms)] // Lint bug: rust-lang/rust-clippy#860
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ if mem::discriminant(self) == mem::discriminant(other) {
+ return Some(Ordering::Equal);
+ }
+ Some(match (self, other) {
+ (Self::Init, _) => Ordering::Less,
+ (_, Self::Init) => Ordering::Greater,
+ (Self::WaitInitial, _) => Ordering::Less,
+ (_, Self::WaitInitial) => Ordering::Greater,
+ (Self::Handshaking, _) => Ordering::Less,
+ (_, Self::Handshaking) => Ordering::Greater,
+ (Self::Connected, _) => Ordering::Less,
+ (_, Self::Connected) => Ordering::Greater,
+ (Self::Confirmed, _) => Ordering::Less,
+ (_, Self::Confirmed) => Ordering::Greater,
+ (Self::Closing { .. }, _) => Ordering::Less,
+ (_, Self::Closing { .. }) => Ordering::Greater,
+ (Self::Draining { .. }, _) => Ordering::Less,
+ (_, Self::Draining { .. }) => Ordering::Greater,
+ (Self::Closed(_), _) => unreachable!(),
+ })
+ }
+}
+
+impl Ord for State {
+ fn cmp(&self, other: &Self) -> Ordering {
+ if mem::discriminant(self) == mem::discriminant(other) {
+ return Ordering::Equal;
+ }
+ match (self, other) {
+ (Self::Init, _) => Ordering::Less,
+ (_, Self::Init) => Ordering::Greater,
+ (Self::WaitInitial, _) => Ordering::Less,
+ (_, Self::WaitInitial) => Ordering::Greater,
+ (Self::Handshaking, _) => Ordering::Less,
+ (_, Self::Handshaking) => Ordering::Greater,
+ (Self::Connected, _) => Ordering::Less,
+ (_, Self::Connected) => Ordering::Greater,
+ (Self::Confirmed, _) => Ordering::Less,
+ (_, Self::Confirmed) => Ordering::Greater,
+ (Self::Closing { .. }, _) => Ordering::Less,
+ (_, Self::Closing { .. }) => Ordering::Greater,
+ (Self::Draining { .. }, _) => Ordering::Less,
+ (_, Self::Draining { .. }) => Ordering::Greater,
+ (Self::Closed(_), _) => unreachable!(),
+ }
+ }
+}
+
+type ClosingFrame = Frame<'static>;
+
+/// `StateSignaling` manages whether we need to send HANDSHAKE_DONE and CONNECTION_CLOSE.
+/// Valid state transitions are:
+/// * Idle -> HandshakeDone: at the server when the handshake completes
+/// * HandshakeDone -> Idle: when a HANDSHAKE_DONE frame is sent
+/// * Idle/HandshakeDone -> Closing/Draining: when closing or draining
+/// * Closing/Draining -> CloseSent: after sending CONNECTION_CLOSE
+/// * CloseSent -> Closing: any time a new CONNECTION_CLOSE is needed
+/// * -> Reset: from any state in case of a stateless reset
+#[derive(Debug, Clone, PartialEq)]
+pub enum StateSignaling {
+ Idle,
+ HandshakeDone,
+ /// These states save the frame that needs to be sent.
+ Closing(ClosingFrame),
+ Draining(ClosingFrame),
+ /// This state saves the frame that might need to be sent again.
+ /// If it is `None`, then we are draining and don't send.
+ CloseSent(Option<ClosingFrame>),
+ Reset,
+}
+
+impl StateSignaling {
+ pub fn handshake_done(&mut self) {
+ if *self != Self::Idle {
+ debug_assert!(false, "StateSignaling must be in Idle state.");
+ return;
+ }
+ *self = Self::HandshakeDone
+ }
+
+ pub fn write_done(&mut self, builder: &mut PacketBuilder) -> Option<RecoveryToken> {
+ if *self == Self::HandshakeDone && builder.remaining() >= 1 {
+ *self = Self::Idle;
+ builder.encode_varint(Frame::HandshakeDone.get_type());
+ Some(RecoveryToken::HandshakeDone)
+ } else {
+ None
+ }
+ }
+
+ fn make_close_frame(
+ error: ConnectionError,
+ frame_type: FrameType,
+ message: impl AsRef<str>,
+ ) -> ClosingFrame {
+ let reason_phrase = message.as_ref().as_bytes().to_owned();
+ Frame::ConnectionClose {
+ error_code: CloseError::from(error),
+ frame_type,
+ reason_phrase,
+ }
+ }
+
+ pub fn close(
+ &mut self,
+ error: ConnectionError,
+ frame_type: FrameType,
+ message: impl AsRef<str>,
+ ) {
+ if *self != Self::Reset {
+ *self = Self::Closing(Self::make_close_frame(error, frame_type, message));
+ }
+ }
+
+ pub fn drain(
+ &mut self,
+ error: ConnectionError,
+ frame_type: FrameType,
+ message: impl AsRef<str>,
+ ) {
+ if *self != Self::Reset {
+ *self = Self::Draining(Self::make_close_frame(error, frame_type, message));
+ }
+ }
+
+ /// If a close is pending, take a frame.
+ pub fn close_frame(&mut self) -> Option<ClosingFrame> {
+ match self {
+ Self::Closing(frame) => {
+ // When we are closing, we might need to send the close frame again.
+ let frame = mem::replace(frame, Frame::Padding);
+ *self = Self::CloseSent(Some(frame.clone()));
+ Some(frame)
+ }
+ Self::Draining(frame) => {
+ // When we are draining, just send once.
+ let frame = mem::replace(frame, Frame::Padding);
+ *self = Self::CloseSent(None);
+ Some(frame)
+ }
+ _ => None,
+ }
+ }
+
+ /// If a close can be sent again, prepare to send it again.
+ pub fn send_close(&mut self) {
+ if let Self::CloseSent(Some(frame)) = self {
+ let frame = mem::replace(frame, Frame::Padding);
+ *self = Self::Closing(frame);
+ }
+ }
+
+ /// We just got a stateless reset. Terminate.
+ pub fn reset(&mut self) {
+ *self = Self::Reset;
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/cc.rs b/third_party/rust/neqo-transport/src/connection/tests/cc.rs
new file mode 100644
index 0000000000..f69234a357
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/cc.rs
@@ -0,0 +1,526 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Connection, Output};
+use super::{
+ assert_full_cwnd, connect_force_idle, connect_rtt_idle, cwnd_packets, default_client,
+ default_server, fill_cwnd, send_something, AT_LEAST_PTO, DEFAULT_RTT, POST_HANDSHAKE_CWND,
+};
+use crate::cc::{CWND_MIN, MAX_DATAGRAM_SIZE};
+use crate::frame::StreamType;
+use crate::packet::PacketNumber;
+use crate::recovery::{ACK_ONLY_SIZE_LIMIT, PACKET_THRESHOLD};
+use crate::sender::PACING_BURST_SIZE;
+use crate::stats::MAX_PTO_COUNTS;
+use crate::tparams::{self, TransportParameter};
+use crate::tracking::MAX_UNACKED_PKTS;
+
+use neqo_common::{qdebug, qinfo, qtrace, Datagram};
+use std::convert::TryFrom;
+use std::time::{Duration, Instant};
+use test_fixture::{self, now};
+
+fn induce_persistent_congestion(
+ client: &mut Connection,
+ server: &mut Connection,
+ mut now: Instant,
+) -> Instant {
+ // Note: wait some arbitrary time that should be longer than pto
+ // timer. This is rather brittle.
+ now += AT_LEAST_PTO;
+
+ let mut pto_counts = [0; MAX_PTO_COUNTS];
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qtrace!([client], "first PTO");
+ let (c_tx_dgrams, next_now) = fill_cwnd(client, 0, now);
+ now = next_now;
+ assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qtrace!([client], "second PTO");
+ now += AT_LEAST_PTO * 2;
+ let (c_tx_dgrams, next_now) = fill_cwnd(client, 0, now);
+ now = next_now;
+ assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets
+
+ pto_counts[0] = 0;
+ pto_counts[1] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ qtrace!([client], "third PTO");
+ now += AT_LEAST_PTO * 4;
+ let (c_tx_dgrams, next_now) = fill_cwnd(client, 0, now);
+ now = next_now;
+ assert_eq!(c_tx_dgrams.len(), 2); // Two PTO packets
+
+ pto_counts[1] = 0;
+ pto_counts[2] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Generate ACK
+ let s_tx_dgram = ack_bytes(server, 0, c_tx_dgrams, now);
+
+ // An ACK for the third PTO causes persistent congestion.
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+
+ assert_eq!(client.loss_recovery.cwnd(), CWND_MIN);
+ now
+}
+
+// Receive multiple packets and generate an ack-only packet.
+fn ack_bytes<D>(dest: &mut Connection, stream: u64, in_dgrams: D, now: Instant) -> Vec<Datagram>
+where
+ D: IntoIterator<Item = Datagram>,
+ D::IntoIter: ExactSizeIterator,
+{
+ let mut srv_buf = [0; 4_096];
+
+ let in_dgrams = in_dgrams.into_iter();
+ qdebug!([dest], "ack_bytes {} datagrams", in_dgrams.len());
+ for dgram in in_dgrams {
+ dest.process_input(dgram, now);
+ }
+
+ loop {
+ let (bytes_read, _fin) = dest.stream_recv(stream, &mut srv_buf).unwrap();
+ qtrace!([dest], "ack_bytes read {} bytes", bytes_read);
+ if bytes_read == 0 {
+ break;
+ }
+ }
+
+ let mut tx_dgrams = Vec::new();
+ while let Output::Datagram(dg) = dest.process_output(now) {
+ tx_dgrams.push(dg);
+ }
+
+ assert!((tx_dgrams.len() == 1) || (tx_dgrams.len() == 2));
+ tx_dgrams
+}
+
+#[test]
+/// Verify initial CWND is honored.
+fn cc_slow_start() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ server
+ .set_local_tparam(
+ tparams::INITIAL_MAX_DATA,
+ TransportParameter::Integer(65536),
+ )
+ .unwrap();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Try to send a lot of data
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ let (c_tx_dgrams, _) = fill_cwnd(&mut client, 2, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+ assert!(client.loss_recovery.cwnd_avail() < ACK_ONLY_SIZE_LIMIT);
+}
+
+#[test]
+/// Verify that CC moves to cong avoidance when a packet is marked lost.
+fn cc_slow_start_to_cong_avoidance_recovery_period() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+ // Predict the packet number of the last packet sent.
+ // We have already sent one packet in `connect_force_idle` (an ACK),
+ // so this will be equal to the number of packets in this flight.
+ let flight1_largest = PacketNumber::try_from(c_tx_dgrams.len()).unwrap();
+
+ // Server: Receive and generate ack
+ now += DEFAULT_RTT / 2;
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+ assert_eq!(
+ server.stats().frame_tx.largest_acknowledged,
+ flight1_largest
+ );
+
+ // Client: Process ack
+ now += DEFAULT_RTT / 2;
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+ assert_eq!(
+ client.stats().frame_rx.largest_acknowledged,
+ flight1_largest
+ );
+
+ // Client: send more
+ let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND * 2);
+ let flight2_largest = flight1_largest + u64::try_from(c_tx_dgrams.len()).unwrap();
+
+ // Server: Receive and generate ack again, but drop first packet
+ now += DEFAULT_RTT / 2;
+ c_tx_dgrams.remove(0);
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+ assert_eq!(
+ server.stats().frame_tx.largest_acknowledged,
+ flight2_largest
+ );
+
+ // Client: Process ack
+ now += DEFAULT_RTT / 2;
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+ assert_eq!(
+ client.stats().frame_rx.largest_acknowledged,
+ flight2_largest
+ );
+}
+
+#[test]
+/// Verify that CC stays in recovery period when packet sent before start of
+/// recovery period is acked.
+fn cc_cong_avoidance_recovery_period_unchanged() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // Create stream 0
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets
+ let (mut c_tx_dgrams, now) = fill_cwnd(&mut client, 0, now());
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Drop 0th packet. When acked, this should put client into CARP.
+ c_tx_dgrams.remove(0);
+
+ let c_tx_dgrams2 = c_tx_dgrams.split_off(5);
+
+ // Server: Receive and generate ack
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+
+ let cwnd1 = client.loss_recovery.cwnd();
+
+ // Generate ACK for more received packets
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams2, now);
+
+ // ACK more packets but they were sent before end of recovery period
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+
+ // cwnd should not have changed since ACKed packets were sent before
+ // recovery period expired
+ let cwnd2 = client.loss_recovery.cwnd();
+ assert_eq!(cwnd1, cwnd2);
+}
+
+#[test]
+/// Ensure that a single packet is sent after entering recovery, even
+/// when that exceeds the available congestion window.
+fn single_packet_on_recovery() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // Drop a few packets, up to the reordering threshold.
+ for _ in 0..PACKET_THRESHOLD {
+ let _dropped = send_something(&mut client, now());
+ }
+ let delivered = send_something(&mut client, now());
+
+ // Now fill the congestion window.
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+ let _ = fill_cwnd(&mut client, 0, now());
+ assert!(client.loss_recovery.cwnd_avail() < ACK_ONLY_SIZE_LIMIT);
+
+ // Acknowledge just one packet and cause one packet to be declared lost.
+ // The length is the amount of credit the client should have.
+ let ack = server.process(Some(delivered), now()).dgram();
+ assert!(ack.is_some());
+
+ // The client should see the loss and enter recovery.
+ // As there are many outstanding packets, there should be no available cwnd.
+ client.process_input(ack.unwrap(), now());
+ assert_eq!(client.loss_recovery.cwnd_avail(), 0);
+
+ // The client should send one packet, ignoring the cwnd.
+ let dgram = client.process_output(now()).dgram();
+ assert!(dgram.is_some());
+}
+
+#[test]
+/// Verify that CC moves out of recovery period when packet sent after start
+/// of recovery period is acked.
+fn cc_cong_avoidance_recovery_period_to_cong_avoidance() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets
+ let (mut c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now);
+
+ // Drop 0th packet. When acked, this should put client into CARP.
+ c_tx_dgrams.remove(0);
+
+ // Server: Receive and generate ack
+ now += DEFAULT_RTT / 2;
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+ // Client: Process ack
+ now += DEFAULT_RTT / 2;
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+
+ // Should be in CARP now.
+ now += DEFAULT_RTT / 2;
+ qinfo!(
+ "moving to congestion avoidance {}",
+ client.loss_recovery.cwnd()
+ );
+
+ // Now make sure that we increase congestion window according to the
+ // accurate byte counting version of congestion avoidance.
+ // Check over several increases to be sure.
+ let mut expected_cwnd = client.loss_recovery.cwnd();
+ // Fill cwnd.
+ let (mut c_tx_dgrams, next_now) = fill_cwnd(&mut client, 0, now);
+ now = next_now;
+ for i in 0..5 {
+ qinfo!("iteration {}", i);
+
+ let c_tx_size: usize = c_tx_dgrams.iter().map(|d| d.len()).sum();
+ qinfo!(
+ "client sending {} bytes into cwnd of {}",
+ c_tx_size,
+ client.loss_recovery.cwnd()
+ );
+ assert_eq!(c_tx_size, expected_cwnd);
+
+ // As acks arrive we will continue filling cwnd and save all packets
+ // from this cycle will be stored in next_c_tx_dgrams.
+ let mut next_c_tx_dgrams: Vec<Datagram> = Vec::new();
+
+ // Until we process all the packets, the congestion window remains the same.
+ // Note that we need the client to process ACK frames in stages, so split the
+ // datagrams into two, ensuring that we allow for an ACK for each batch.
+ let most = c_tx_dgrams.len() - MAX_UNACKED_PKTS - 1;
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams.drain(..most), now);
+ for dgram in s_tx_dgram {
+ assert_eq!(client.loss_recovery.cwnd(), expected_cwnd);
+ client.process_input(dgram, now);
+ // make sure to fill cwnd again.
+ let (mut new_pkts, next_now) = fill_cwnd(&mut client, 0, now);
+ now = next_now;
+ next_c_tx_dgrams.append(&mut new_pkts);
+ }
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+ for dgram in s_tx_dgram {
+ assert_eq!(client.loss_recovery.cwnd(), expected_cwnd);
+ client.process_input(dgram, now);
+ // make sure to fill cwnd again.
+ let (mut new_pkts, next_now) = fill_cwnd(&mut client, 0, now);
+ now = next_now;
+ next_c_tx_dgrams.append(&mut new_pkts);
+ }
+ expected_cwnd += MAX_DATAGRAM_SIZE;
+ assert_eq!(client.loss_recovery.cwnd(), expected_cwnd);
+ c_tx_dgrams = next_c_tx_dgrams;
+ }
+}
+
+#[test]
+/// Verify transition to persistent congestion state if conditions are met.
+fn cc_slow_start_to_persistent_congestion_no_acks() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let now = connect_rtt_idle(&mut client, &mut server, DEFAULT_RTT);
+
+ // Create stream 0
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now);
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Server: Receive and generate ack
+ now += DEFAULT_RTT / 2;
+ let _ = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+ // ACK lost.
+ induce_persistent_congestion(&mut client, &mut server, now);
+}
+
+#[test]
+/// Verify transition to persistent congestion state if conditions are met.
+fn cc_slow_start_to_persistent_congestion_some_acks() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // Create stream 0
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now());
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Server: Receive and generate ack
+ now += Duration::from_millis(100);
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+ now += Duration::from_millis(100);
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+
+ // send bytes that will be lost
+ let (_, next_now) = fill_cwnd(&mut client, 0, now);
+ now = next_now + Duration::from_millis(100);
+
+ induce_persistent_congestion(&mut client, &mut server, now);
+}
+
+#[test]
+/// Verify persistent congestion moves to slow start after recovery period
+/// ends.
+fn cc_persistent_congestion_to_slow_start() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // Create stream 0
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets
+ let (c_tx_dgrams, mut now) = fill_cwnd(&mut client, 0, now());
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // Server: Receive and generate ack
+ now += Duration::from_millis(10);
+ let _ = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+ // ACK lost.
+
+ now = induce_persistent_congestion(&mut client, &mut server, now);
+
+ // New part of test starts here
+
+ now += Duration::from_millis(10);
+
+ // Send packets from after start of CARP
+ let (c_tx_dgrams, next_now) = fill_cwnd(&mut client, 0, now);
+ assert_eq!(c_tx_dgrams.len(), 2);
+
+ // Server: Receive and generate ack
+ now = next_now + Duration::from_millis(100);
+ let s_tx_dgram = ack_bytes(&mut server, 0, c_tx_dgrams, now);
+
+ // No longer in CARP. (pkts acked from after start of CARP)
+ // Should be in slow start now.
+ for dgram in s_tx_dgram {
+ client.process_input(dgram, now);
+ }
+
+ // ACKing 2 packets should let client send 4.
+ let (c_tx_dgrams, _) = fill_cwnd(&mut client, 0, now);
+ assert_eq!(c_tx_dgrams.len(), 4);
+}
+
+#[test]
+fn ack_are_not_cc() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // Create a stream
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+
+ // Buffer up lot of data and generate packets, so that cc window is filled.
+ let (c_tx_dgrams, now) = fill_cwnd(&mut client, 0, now());
+ assert_full_cwnd(&c_tx_dgrams, POST_HANDSHAKE_CWND);
+
+ // The server hasn't received any of these packets yet, the server
+ // won't ACK, but if it sends an ack-eliciting packet instead.
+ qdebug!([server], "Sending ack-eliciting");
+ assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 1);
+ server.stream_send(1, b"dropped").unwrap();
+ let dropped_packet = server.process(None, now).dgram();
+ assert!(dropped_packet.is_some()); // Now drop this one.
+
+ // Now the server sends a packet that will force an ACK,
+ // because the client will detect a gap.
+ server.stream_send(1, b"sent").unwrap();
+ let ack_eliciting_packet = server.process(None, now).dgram();
+ assert!(ack_eliciting_packet.is_some());
+
+ // The client can ack the server packet even if cc windows is full.
+ qdebug!([client], "Process ack-eliciting");
+ let ack_pkt = client.process(ack_eliciting_packet, now).dgram();
+ assert!(ack_pkt.is_some());
+ qdebug!([server], "Handle ACK");
+ let prev_ack_count = server.stats().frame_rx.ack;
+ server.process_input(ack_pkt.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.ack, prev_ack_count + 1);
+}
+
+#[test]
+fn pace() {
+ const RTT: Duration = Duration::from_millis(1000);
+ const DATA: &[u8] = &[0xcc; 4_096];
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = connect_rtt_idle(&mut client, &mut server, RTT);
+
+ // Now fill up the pipe and watch it trickle out.
+ let stream = client.stream_create(StreamType::BiDi).unwrap();
+ loop {
+ let written = client.stream_send(stream, DATA).unwrap();
+ if written < DATA.len() {
+ break;
+ }
+ }
+ let mut count = 0;
+ // We should get a burst at first.
+ for _ in 0..PACING_BURST_SIZE {
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+ count += 1;
+ }
+ let gap = client.process_output(now).callback();
+ assert_ne!(gap, Duration::new(0, 0));
+ // The last one will not be paced.
+ for _ in PACING_BURST_SIZE..cwnd_packets(POST_HANDSHAKE_CWND) - 1 {
+ assert_eq!(client.process_output(now).callback(), gap);
+ now += gap;
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+ count += 1;
+ }
+ let dgram = client.process_output(now).dgram();
+ assert!(dgram.is_some());
+ count += 1;
+ assert_eq!(count, cwnd_packets(POST_HANDSHAKE_CWND));
+ let fin = client.process_output(now).callback();
+ assert_ne!(fin, Duration::new(0, 0));
+ assert_ne!(fin, gap);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/close.rs b/third_party/rust/neqo-transport/src/connection/tests/close.rs
new file mode 100644
index 0000000000..715073523b
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/close.rs
@@ -0,0 +1,206 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Connection, Output, State};
+use super::{connect, connect_force_idle, default_client, default_server, send_something};
+use crate::tparams::{self, TransportParameter};
+use crate::{AppError, ConnectionError, Error, ERROR_APPLICATION_CLOSE};
+
+use neqo_common::Datagram;
+use std::time::Duration;
+use test_fixture::{self, loopback, now};
+
+fn assert_draining(c: &Connection, expected: &Error) {
+ assert!(c.state().closed());
+ if let State::Draining {
+ error: ConnectionError::Transport(error),
+ ..
+ } = c.state()
+ {
+ assert_eq!(error, expected);
+ } else {
+ panic!();
+ }
+}
+
+#[test]
+fn connection_close() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let now = now();
+
+ client.close(now, 42, "");
+
+ let out = client.process(None, now);
+
+ server.process_input(out.dgram().unwrap(), now);
+ assert_draining(&server, &Error::PeerApplicationError(42));
+}
+
+#[test]
+fn connection_close_with_long_reason_string() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let now = now();
+ // Create a long string and use it as the close reason.
+ let long_reason = String::from_utf8([0x61; 2048].to_vec()).unwrap();
+ client.close(now, 42, long_reason);
+
+ let out = client.process(None, now);
+
+ server.process_input(out.dgram().unwrap(), now);
+ assert_draining(&server, &Error::PeerApplicationError(42));
+}
+
+// During the handshake, an application close should be sanitized.
+#[test]
+fn early_application_close() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // One flight each.
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ let dgram = server.process(dgram, now()).dgram();
+ assert!(dgram.is_some());
+
+ server.close(now(), 77, String::from(""));
+ assert!(server.state().closed());
+ let dgram = server.process(None, now()).dgram();
+ assert!(dgram.is_some());
+
+ client.process_input(dgram.unwrap(), now());
+ assert_draining(&client, &Error::PeerError(ERROR_APPLICATION_CLOSE));
+}
+
+#[test]
+fn bad_tls_version() {
+ let mut client = default_client();
+ // Do a bad, bad thing.
+ client
+ .crypto
+ .tls
+ .set_option(neqo_crypto::Opt::Tls13CompatMode, true)
+ .unwrap();
+ let mut server = default_server();
+
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ let dgram = server.process(dgram, now()).dgram();
+ assert_eq!(
+ *server.state(),
+ State::Closed(ConnectionError::Transport(Error::ProtocolViolation))
+ );
+ assert!(dgram.is_some());
+ client.process_input(dgram.unwrap(), now());
+ assert_draining(&client, &Error::PeerError(Error::ProtocolViolation.code()));
+}
+
+/// Test the interaction between the loss recovery timer
+/// and the closing timer.
+#[test]
+fn closing_timers_interation() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let mut now = now();
+
+ // We're going to induce time-based loss recovery so that timer is set.
+ let _p1 = send_something(&mut client, now);
+ let p2 = send_something(&mut client, now);
+ let ack = server.process(Some(p2), now).dgram();
+ assert!(ack.is_some()); // This is an ACK.
+
+ // After processing the ACK, we should be on the loss recovery timer.
+ let cb = client.process(ack, now).callback();
+ assert_ne!(cb, Duration::from_secs(0));
+ now += cb;
+
+ // Rather than let the timer pop, close the connection.
+ client.close(now, 0, "");
+ let client_close = client.process(None, now).dgram();
+ assert!(client_close.is_some());
+ // This should now report the end of the closing period, not a
+ // zero-duration wait driven by the (now defunct) loss recovery timer.
+ let client_close_timer = client.process(None, now).callback();
+ assert_ne!(client_close_timer, Duration::from_secs(0));
+}
+
+#[test]
+fn closing_and_draining() {
+ const APP_ERROR: AppError = 7;
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // Save a packet from the client for later.
+ let p1 = send_something(&mut client, now());
+
+ // Close the connection.
+ client.close(now(), APP_ERROR, "");
+ let client_close = client.process(None, now()).dgram();
+ assert!(client_close.is_some());
+ let client_close_timer = client.process(None, now()).callback();
+ assert_ne!(client_close_timer, Duration::from_secs(0));
+
+ // The client will spit out the same packet in response to anything it receives.
+ let p3 = send_something(&mut server, now());
+ let client_close2 = client.process(Some(p3), now()).dgram();
+ assert_eq!(
+ client_close.as_ref().unwrap().len(),
+ client_close2.as_ref().unwrap().len()
+ );
+
+ // After this time, the client should transition to closed.
+ let end = client.process(None, now() + client_close_timer);
+ assert_eq!(end, Output::None);
+ assert_eq!(
+ *client.state(),
+ State::Closed(ConnectionError::Application(APP_ERROR))
+ );
+
+ // When the server receives the close, it too should generate CONNECTION_CLOSE.
+ let server_close = server.process(client_close, now()).dgram();
+ assert!(server.state().closed());
+ assert!(server_close.is_some());
+ // .. but it ignores any further close packets.
+ let server_close_timer = server.process(client_close2, now()).callback();
+ assert_ne!(server_close_timer, Duration::from_secs(0));
+ // Even a legitimate packet without a close in it.
+ let server_close_timer2 = server.process(Some(p1), now()).callback();
+ assert_eq!(server_close_timer, server_close_timer2);
+
+ let end = server.process(None, now() + server_close_timer);
+ assert_eq!(end, Output::None);
+ assert_eq!(
+ *server.state(),
+ State::Closed(ConnectionError::Transport(Error::PeerApplicationError(
+ APP_ERROR
+ )))
+ );
+}
+
+/// Test that a client can handle a stateless reset correctly.
+#[test]
+fn stateless_reset_client() {
+ let mut client = default_client();
+ let mut server = default_server();
+ server
+ .set_local_tparam(
+ tparams::STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(vec![77; 16]),
+ )
+ .unwrap();
+ connect_force_idle(&mut client, &mut server);
+
+ client.process_input(Datagram::new(loopback(), loopback(), vec![77; 21]), now());
+ assert_draining(&client, &Error::StatelessReset);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/handshake.rs b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
new file mode 100644
index 0000000000..c9769b3c3c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/handshake.rs
@@ -0,0 +1,697 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Connection, FixedConnectionIdManager, Output, State, LOCAL_IDLE_TIMEOUT};
+use super::{
+ assert_error, connect_force_idle, connect_with_rtt, default_client, default_server, get_tokens,
+ handshake, maybe_authenticate, send_something, AT_LEAST_PTO, DEFAULT_RTT, DEFAULT_STREAM_DATA,
+};
+use crate::connection::AddressValidation;
+use crate::events::ConnectionEvent;
+use crate::frame::StreamType;
+use crate::path::PATH_MTU_V6;
+use crate::server::ValidateAddress;
+use crate::{ConnectionError, ConnectionParameters, Error};
+
+use neqo_common::{event::Provider, qdebug, Datagram};
+use neqo_crypto::{constants::TLS_CHACHA20_POLY1305_SHA256, AuthenticationStatus};
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::time::Duration;
+use test_fixture::{self, assertions, fixture_init, loopback, now, split_datagram};
+
+#[test]
+fn full_handshake() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let out = client.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(*client.state(), State::Connected);
+
+ qdebug!("---- server: FIN -> ACKS");
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(*server.state(), State::Confirmed);
+
+ qdebug!("---- client: ACKS -> 0");
+ let out = client.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_none());
+ assert_eq!(*client.state(), State::Confirmed);
+}
+
+#[test]
+fn handshake_failed_authentication() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded);
+ assert!(client.events().any(authentication_needed));
+ qdebug!("---- client: Alert(certificate_revoked)");
+ client.authenticated(AuthenticationStatus::CertRevoked, now());
+
+ qdebug!("---- client: -> Alert(certificate_revoked)");
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+
+ qdebug!("---- server: Alert(certificate_revoked)");
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_error(&client, &ConnectionError::Transport(Error::CryptoAlert(44)));
+ assert_error(&server, &ConnectionError::Transport(Error::PeerError(300)));
+}
+
+#[test]
+fn no_alpn() {
+ fixture_init();
+ let mut client = Connection::new_client(
+ "example.com",
+ &["bad-alpn"],
+ Rc::new(RefCell::new(FixedConnectionIdManager::new(9))),
+ loopback(),
+ loopback(),
+ &ConnectionParameters::default(),
+ )
+ .unwrap();
+ let mut server = default_server();
+
+ handshake(&mut client, &mut server, now(), Duration::new(0, 0));
+ // TODO (mt): errors are immediate, which means that we never send CONNECTION_CLOSE
+ // and the client never sees the server's rejection of its handshake.
+ //assert_error(&client, ConnectionError::Transport(Error::CryptoAlert(120)));
+ assert_error(
+ &server,
+ &ConnectionError::Transport(Error::CryptoAlert(120)),
+ );
+}
+
+#[test]
+fn dup_server_flight1() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(out.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let out_to_rep = server.process(out.dgram(), now());
+ assert!(out_to_rep.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out_to_rep.as_dgram_ref());
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(Some(out_to_rep.as_dgram_ref().unwrap().clone()), now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ assert_eq!(3, client.stats().packets_rx);
+ assert_eq!(0, client.stats().dups_rx);
+ assert_eq!(1, client.stats().dropped_rx);
+
+ qdebug!("---- Dup, ignored");
+ let out = client.process(out_to_rep.dgram(), now());
+ assert!(out.as_dgram_ref().is_none());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+
+ // Four packets total received, 1 of them is a dup and one has been dropped because Initial keys
+ // are dropped. Add 2 counts of the padding that the server adds to Initial packets.
+ assert_eq!(6, client.stats().packets_rx);
+ assert_eq!(1, client.stats().dups_rx);
+ assert_eq!(3, client.stats().dropped_rx);
+}
+
+// Test that we split crypto data if they cannot fit into one packet.
+// To test this we will use a long server certificate.
+#[test]
+fn crypto_frame_split() {
+ let mut client = default_client();
+
+ let mut server = Connection::new_server(
+ test_fixture::LONG_CERT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(FixedConnectionIdManager::new(6))),
+ &ConnectionParameters::default(),
+ )
+ .expect("create a server");
+
+ let client1 = client.process(None, now());
+ assert!(client1.as_dgram_ref().is_some());
+
+ // The entire server flight doesn't fit in a single packet because the
+ // certificate is large, therefore the server will produce 2 packets.
+ let server1 = server.process(client1.dgram(), now());
+ assert!(server1.as_dgram_ref().is_some());
+ let server2 = server.process(None, now());
+ assert!(server2.as_dgram_ref().is_some());
+
+ let client2 = client.process(server1.dgram(), now());
+ // This is an ack.
+ assert!(client2.as_dgram_ref().is_some());
+ // The client might have the certificate now, so we can't guarantee that
+ // this will work.
+ let auth1 = maybe_authenticate(&mut client);
+ assert_eq!(*client.state(), State::Handshaking);
+
+ // let server process the ack for the first packet.
+ let server3 = server.process(client2.dgram(), now());
+ assert!(server3.as_dgram_ref().is_none());
+
+ // Consume the second packet from the server.
+ let client3 = client.process(server2.dgram(), now());
+
+ // Check authentication.
+ let auth2 = maybe_authenticate(&mut client);
+ assert!(auth1 ^ auth2);
+ // Now client has all data to finish handshake.
+ assert_eq!(*client.state(), State::Connected);
+
+ let client4 = client.process(server3.dgram(), now());
+ // One of these will contain data depending on whether Authentication was completed
+ // after the first or second server packet.
+ assert!(client3.as_dgram_ref().is_some() ^ client4.as_dgram_ref().is_some());
+
+ let _ = server.process(client3.dgram(), now());
+ let _ = server.process(client4.dgram(), now());
+
+ assert_eq!(*client.state(), State::Connected);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+/// Run a single ChaCha20-Poly1305 test and get a PTO.
+#[test]
+fn chacha20poly1305() {
+ let mut server = default_server();
+ let mut client = Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(FixedConnectionIdManager::new(0))),
+ loopback(),
+ loopback(),
+ &ConnectionParameters::default(),
+ )
+ .expect("create a default client");
+ client.set_ciphers(&[TLS_CHACHA20_POLY1305_SHA256]).unwrap();
+ connect_force_idle(&mut client, &mut server);
+}
+
+/// Test that a server can send 0.5 RTT application data.
+#[test]
+fn send_05rtt() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let c1 = client.process(None, now()).dgram();
+ assert!(c1.is_some());
+ let s1 = server.process(c1, now()).dgram().unwrap();
+ assert_eq!(s1.len(), PATH_MTU_V6);
+
+ // The server should accept writes at this point.
+ let s2 = send_something(&mut server, now());
+
+ // Complete the handshake at the client.
+ client.process_input(s1, now());
+ maybe_authenticate(&mut client);
+ assert_eq!(*client.state(), State::Connected);
+
+ // The client should receive the 0.5-RTT data now.
+ client.process_input(s2, now());
+ let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+ let stream_id = client
+ .events()
+ .find_map(|e| {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ Some(stream_id)
+ } else {
+ None
+ }
+ })
+ .unwrap();
+ let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(&buf[..l], DEFAULT_STREAM_DATA);
+ assert!(ended);
+}
+
+/// Test that a client buffers 0.5-RTT data when it arrives early.
+#[test]
+fn reorder_05rtt() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let c1 = client.process(None, now()).dgram();
+ assert!(c1.is_some());
+ let s1 = server.process(c1, now()).dgram().unwrap();
+
+ // The server should accept writes at this point.
+ let s2 = send_something(&mut server, now());
+
+ // We can't use the standard facility to complete the handshake, so
+ // drive it as aggressively as possible.
+ client.process_input(s2, now());
+ assert_eq!(client.stats().saved_datagrams, 1);
+
+ // After processing the first packet, the client should go back and
+ // process the 0.5-RTT packet data, which should make data available.
+ client.process_input(s1, now());
+ // We can't use `maybe_authenticate` here as that consumes events.
+ client.authenticated(AuthenticationStatus::Ok, now());
+ assert_eq!(*client.state(), State::Connected);
+
+ let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+ let stream_id = client
+ .events()
+ .find_map(|e| {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ Some(stream_id)
+ } else {
+ None
+ }
+ })
+ .unwrap();
+ let (l, ended) = client.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(&buf[..l], DEFAULT_STREAM_DATA);
+ assert!(ended);
+}
+
+#[test]
+fn reorder_05rtt_with_0rtt() {
+ const RTT: Duration = Duration::from_millis(100);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ let validation = AddressValidation::new(now(), ValidateAddress::NoToken).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+ // Include RTT in sending the ticket or the ticket age reported by the
+ // client is wrong, which causes the server to reject 0-RTT.
+ now += RTT / 2;
+ server.send_ticket(now, &[]).unwrap();
+ let ticket = server.process_output(now).dgram().unwrap();
+ now += RTT / 2;
+ client.process_input(ticket, now);
+
+ let token = get_tokens(&mut client).pop().unwrap();
+ let mut client = default_client();
+ client.enable_resumption(now, token).unwrap();
+ let mut server = default_server();
+
+ // Send ClientHello and some 0-RTT.
+ let c1 = send_something(&mut client, now);
+ assertions::assert_coalesced_0rtt(&c1[..]);
+ // Drop the 0-RTT from the coalesced datagram, so that the server
+ // acknowledges the next 0-RTT packet.
+ let (c1, _) = split_datagram(&c1);
+ let c2 = send_something(&mut client, now);
+
+ // Handle the first packet and send 0.5-RTT in response. Drop the response.
+ now += RTT / 2;
+ let _ = server.process(Some(c1), now).dgram().unwrap();
+ // The gap in 0-RTT will result in this 0.5 RTT containing an ACK.
+ server.process_input(c2, now);
+ let s2 = send_something(&mut server, now);
+
+ // Save the 0.5 RTT.
+ now += RTT / 2;
+ client.process_input(s2, now);
+ assert_eq!(client.stats().saved_datagrams, 1);
+
+ // Now PTO at the client and cause the server to re-send handshake packets.
+ now += AT_LEAST_PTO;
+ let c3 = client.process(None, now).dgram();
+
+ now += RTT / 2;
+ let s3 = server.process(c3, now).dgram().unwrap();
+ assertions::assert_no_1rtt(&s3[..]);
+
+ // The client should be able to process the 0.5 RTT now.
+ // This should contain an ACK, so we are processing an ACK from the past.
+ now += RTT / 2;
+ client.process_input(s3, now);
+ maybe_authenticate(&mut client);
+ let c4 = client.process(None, now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+ assert_eq!(client.loss_recovery.rtt(), RTT);
+
+ now += RTT / 2;
+ server.process_input(c4.unwrap(), now);
+ assert_eq!(*server.state(), State::Confirmed);
+ assert_eq!(server.loss_recovery.rtt(), RTT);
+}
+
+/// Test that a server that coalesces 0.5 RTT with handshake packets
+/// doesn't cause the client to drop application data.
+#[test]
+fn coalesce_05rtt() {
+ const RTT: Duration = Duration::from_millis(100);
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ // The first exchange doesn't offer a chance for the server to send.
+ // So drop the server flight and wait for the PTO.
+ let c1 = client.process(None, now).dgram();
+ assert!(c1.is_some());
+ now += RTT / 2;
+ let s1 = server.process(c1, now).dgram();
+ assert!(s1.is_some());
+
+ // Drop the server flight. Then send some data.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+ assert!(server.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok());
+ assert!(server.stream_close_send(stream_id).is_ok());
+
+ // Now after a PTO the client can send another packet.
+ // The server should then send its entire flight again,
+ // including the application data, which it sends in a 1-RTT packet.
+ now += AT_LEAST_PTO;
+ let c2 = client.process(None, now).dgram();
+ assert!(c2.is_some());
+ now += RTT / 2;
+ let s2 = server.process(c2, now).dgram();
+ // Even though there is a 1-RTT packet at the end of the datagram, the
+ // flight should be padded to full size.
+ assert_eq!(s2.as_ref().unwrap().len(), PATH_MTU_V6);
+
+ // The client should process the datagram. It can't process the 1-RTT
+ // packet until authentication completes though. So it saves it.
+ now += RTT / 2;
+ assert_eq!(client.stats().dropped_rx, 0);
+ let _ = client.process(s2, now).dgram();
+ // This packet will contain an ACK, but we can ignore it.
+ assert_eq!(client.stats().dropped_rx, 0);
+ assert_eq!(client.stats().packets_rx, 3);
+ assert_eq!(client.stats().saved_datagrams, 1);
+
+ // After (successful) authentication, the packet is processed.
+ maybe_authenticate(&mut client);
+ let c3 = client.process(None, now).dgram();
+ assert!(c3.is_some());
+ assert_eq!(client.stats().dropped_rx, 0); // No Initial padding.
+ assert_eq!(client.stats().packets_rx, 4);
+ assert_eq!(client.stats().saved_datagrams, 1);
+ assert_eq!(client.stats().frame_rx.padding, 1); // Padding uses frames.
+
+ // Allow the handshake to complete.
+ now += RTT / 2;
+ let s3 = server.process(c3, now).dgram();
+ assert!(s3.is_some());
+ assert_eq!(*server.state(), State::Confirmed);
+ now += RTT / 2;
+ let _ = client.process(s3, now).dgram();
+ assert_eq!(*client.state(), State::Confirmed);
+
+ assert_eq!(client.stats().dropped_rx, 0); // No dropped packets.
+}
+
+#[test]
+fn reorder_handshake() {
+ const RTT: Duration = Duration::from_millis(100);
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c1 = client.process(None, now).dgram();
+ assert!(c1.is_some());
+
+ now += RTT / 2;
+ let s1 = server.process(c1, now).dgram();
+ assert!(s1.is_some());
+
+ // Drop the Initial packet from this.
+ let (_, s_hs) = split_datagram(&s1.unwrap());
+ assert!(s_hs.is_some());
+
+ // Pass just the handshake packet in and the client can't handle it yet.
+ // It can only send another Initial packet.
+ now += RTT / 2;
+ let dgram = client.process(s_hs, now).dgram();
+ assertions::assert_initial(&dgram.as_ref().unwrap(), false);
+ assert_eq!(client.stats().saved_datagrams, 1);
+ assert_eq!(client.stats().packets_rx, 1);
+
+ // Get the server to try again.
+ // Though we currently allow the server to arm its PTO timer, use
+ // a second client Initial packet to cause it to send again.
+ now += AT_LEAST_PTO;
+ let c2 = client.process(None, now).dgram();
+ now += RTT / 2;
+ let s2 = server.process(c2, now).dgram();
+ assert!(s2.is_some());
+
+ let (s_init, s_hs) = split_datagram(&s2.unwrap());
+ assert!(s_hs.is_some());
+
+ // Processing the Handshake packet first should save it.
+ now += RTT / 2;
+ client.process_input(s_hs.unwrap(), now);
+ assert_eq!(client.stats().saved_datagrams, 2);
+ assert_eq!(client.stats().packets_rx, 2);
+
+ client.process_input(s_init, now);
+ // Each saved packet should now be "received" again.
+ assert_eq!(client.stats().packets_rx, 7);
+ maybe_authenticate(&mut client);
+ let c3 = client.process(None, now).dgram();
+ assert!(c3.is_some());
+
+ // Note that though packets were saved and processed very late,
+ // they don't cause the RTT to change.
+ now += RTT / 2;
+ let s3 = server.process(c3, now).dgram();
+ assert_eq!(*server.state(), State::Confirmed);
+ assert_eq!(server.loss_recovery.rtt(), RTT);
+
+ now += RTT / 2;
+ client.process_input(s3.unwrap(), now);
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(client.loss_recovery.rtt(), RTT);
+}
+
+#[test]
+fn reorder_1rtt() {
+ const RTT: Duration = Duration::from_millis(100);
+ const PACKETS: usize = 6; // Many, but not enough to overflow cwnd.
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c1 = client.process(None, now).dgram();
+ assert!(c1.is_some());
+
+ now += RTT / 2;
+ let s1 = server.process(c1, now).dgram();
+ assert!(s1.is_some());
+
+ now += RTT / 2;
+ client.process_input(s1.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let c2 = client.process(None, now).dgram();
+ assert!(c2.is_some());
+
+ // Now get a bunch of packets from the client.
+ // Give them to the server before giving it `c2`.
+ for _ in 0..PACKETS {
+ let d = send_something(&mut client, now);
+ server.process_input(d, now + RTT / 2);
+ }
+ // The server has now received those packets, and saved them.
+ // The two extra received are Initial + the junk we use for padding.
+ assert_eq!(server.stats().packets_rx, PACKETS + 2);
+ assert_eq!(server.stats().saved_datagrams, PACKETS);
+ assert_eq!(server.stats().dropped_rx, 1);
+
+ now += RTT / 2;
+ let s2 = server.process(c2, now).dgram();
+ // The server has now received those packets, and saved them.
+ // The two additional are an Initial ACK and Handshake.
+ assert_eq!(server.stats().packets_rx, PACKETS * 2 + 4);
+ assert_eq!(server.stats().saved_datagrams, PACKETS);
+ assert_eq!(server.stats().dropped_rx, 1);
+ assert_eq!(*server.state(), State::Confirmed);
+ assert_eq!(server.loss_recovery.rtt(), RTT);
+
+ now += RTT / 2;
+ client.process_input(s2.unwrap(), now);
+ assert_eq!(client.loss_recovery.rtt(), RTT);
+
+ // All the stream data that was sent should now be available.
+ let streams = server
+ .events()
+ .filter_map(|e| {
+ if let ConnectionEvent::RecvStreamReadable { stream_id } = e {
+ Some(stream_id)
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(streams.len(), PACKETS);
+ for stream_id in streams {
+ let mut buf = vec![0; DEFAULT_STREAM_DATA.len() + 1];
+ let (recvd, fin) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(recvd, DEFAULT_STREAM_DATA.len());
+ assert!(fin);
+ }
+}
+
+#[test]
+fn corrupted_initial() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let d = client.process(None, now()).dgram().unwrap();
+ let mut corrupted = Vec::from(&d[..]);
+ // Find the last non-zero value and corrupt that.
+ let (idx, _) = corrupted
+ .iter()
+ .enumerate()
+ .rev()
+ .find(|(_, &v)| v != 0)
+ .unwrap();
+ corrupted[idx] ^= 0x76;
+ let dgram = Datagram::new(d.source(), d.destination(), corrupted);
+ server.process_input(dgram, now());
+ // The server should have received two packets,
+ // the first should be dropped, the second saved.
+ assert_eq!(server.stats().packets_rx, 2);
+ assert_eq!(server.stats().dropped_rx, 2);
+ assert_eq!(server.stats().saved_datagrams, 0);
+}
+
+#[test]
+// Absent path PTU discovery, max v6 packet size should be PATH_MTU_V6.
+fn verify_pkt_honors_mtu() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ // Try to send a large stream and verify first packet is correctly sized
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_send(2, &[0xbb; 2000]).unwrap(), 2000);
+ let pkt0 = client.process(None, now);
+ assert!(matches!(pkt0, Output::Datagram(_)));
+ assert_eq!(pkt0.as_dgram_ref().unwrap().len(), PATH_MTU_V6);
+}
+
+#[test]
+fn extra_initial_hs() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c_init = client.process(None, now).dgram();
+ assert!(c_init.is_some());
+ now += DEFAULT_RTT / 2;
+ let s_init = server.process(c_init, now).dgram();
+ assert!(s_init.is_some());
+ now += DEFAULT_RTT / 2;
+
+ // Drop the Initial packet, keep only the Handshake.
+ let (_, undecryptable) = split_datagram(&s_init.unwrap());
+ assert!(undecryptable.is_some());
+
+ // Feed the same undecryptable packet into the client a few times.
+ // Do that EXTRA_INITIALS times and each time the client will emit
+ // another Initial packet.
+ for _ in 0..=super::super::EXTRA_INITIALS {
+ let c_init = client.process(undecryptable.clone(), now).dgram();
+ assertions::assert_initial(&c_init.as_ref().unwrap(), false);
+ now += DEFAULT_RTT / 10;
+ }
+
+ // After EXTRA_INITIALS, the client stops sending Initial packets.
+ let nothing = client.process(undecryptable, now).dgram();
+ assert!(nothing.is_none());
+
+ // Until PTO, where another Initial can be used to complete the handshake.
+ now += AT_LEAST_PTO;
+ let c_init = client.process(None, now).dgram();
+ assertions::assert_initial(c_init.as_ref().unwrap(), false);
+ now += DEFAULT_RTT / 2;
+ let s_init = server.process(c_init, now).dgram();
+ now += DEFAULT_RTT / 2;
+ client.process_input(s_init.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let c_fin = client.process_output(now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+ now += DEFAULT_RTT / 2;
+ server.process_input(c_fin.unwrap(), now);
+ assert_eq!(*server.state(), State::Confirmed);
+}
+
+#[test]
+fn extra_initial_invalid_cid() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = now();
+
+ let c_init = client.process(None, now).dgram();
+ assert!(c_init.is_some());
+ now += DEFAULT_RTT / 2;
+ let s_init = server.process(c_init, now).dgram();
+ assert!(s_init.is_some());
+ now += DEFAULT_RTT / 2;
+
+ // If the client receives a packet that contains the wrong connection
+ // ID, it won't send another Initial.
+ let (_, hs) = split_datagram(&s_init.unwrap());
+ let hs = hs.unwrap();
+ let mut copy = hs.to_vec();
+ assert_ne!(copy[5], 0); // The DCID should be non-zero length.
+ copy[6] ^= 0xc4;
+ let dgram_copy = Datagram::new(hs.destination(), hs.source(), copy);
+ let nothing = client.process(Some(dgram_copy), now).dgram();
+ assert!(nothing.is_none());
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/idle.rs b/third_party/rust/neqo-transport/src/connection/tests/idle.rs
new file mode 100644
index 0000000000..962645eff3
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/idle.rs
@@ -0,0 +1,274 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{IdleTimeout, Output, State, LOCAL_IDLE_TIMEOUT};
+use super::{
+ connect, connect_force_idle, connect_with_rtt, default_client, default_server,
+ maybe_authenticate, send_something, AT_LEAST_PTO,
+};
+use crate::frame::StreamType;
+use crate::packet::PacketBuilder;
+use crate::tparams::{self, TransportParameter};
+use crate::tracking::PNSpace;
+
+use neqo_common::Encoder;
+use std::time::Duration;
+use test_fixture::{self, now, split_datagram};
+
+#[test]
+fn idle_timeout() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ // Still connected after 29 seconds. Idle timer not reset
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT - Duration::from_secs(1));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT);
+
+ // Not connected after LOCAL_IDLE_TIMEOUT seconds.
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+#[test]
+fn asymmetric_idle_timeout() {
+ const LOWER_TIMEOUT_MS: u64 = 1000;
+ const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS);
+ // Sanity check the constant.
+ assert!(LOWER_TIMEOUT < LOCAL_IDLE_TIMEOUT);
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // Overwrite the default at the server.
+ server
+ .tps
+ .borrow_mut()
+ .local
+ .set_integer(tparams::IDLE_TIMEOUT, LOWER_TIMEOUT_MS);
+ server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT);
+
+ // Now connect and force idleness manually.
+ connect(&mut client, &mut server);
+ let p1 = send_something(&mut server, now());
+ let p2 = send_something(&mut server, now());
+ client.process_input(p2, now());
+ let ack = client.process(Some(p1), now()).dgram();
+ assert!(ack.is_some());
+ // Now the server has its ACK and both should be idle.
+ assert_eq!(server.process(ack, now()), Output::Callback(LOWER_TIMEOUT));
+ assert_eq!(client.process(None, now()), Output::Callback(LOWER_TIMEOUT));
+}
+
+#[test]
+fn tiny_idle_timeout() {
+ const RTT: Duration = Duration::from_millis(500);
+ const LOWER_TIMEOUT_MS: u64 = 100;
+ const LOWER_TIMEOUT: Duration = Duration::from_millis(LOWER_TIMEOUT_MS);
+ // We won't respect a value that is lower than 3*PTO, sanity check.
+ assert!(LOWER_TIMEOUT < 3 * RTT);
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // Overwrite the default at the server.
+ server
+ .set_local_tparam(
+ tparams::IDLE_TIMEOUT,
+ TransportParameter::Integer(LOWER_TIMEOUT_MS),
+ )
+ .unwrap();
+ server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT);
+
+ // Now connect with an RTT and force idleness manually.
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+ let p1 = send_something(&mut server, now);
+ let p2 = send_something(&mut server, now);
+ now += RTT / 2;
+ client.process_input(p2, now);
+ let ack = client.process(Some(p1), now).dgram();
+ assert!(ack.is_some());
+
+ // The client should be idle now, but with a different timer.
+ if let Output::Callback(t) = client.process(None, now) {
+ assert!(t > LOWER_TIMEOUT);
+ } else {
+ panic!("Client not idle");
+ }
+
+ // The server should go idle after the ACK, but again with a larger timeout.
+ now += RTT / 2;
+ if let Output::Callback(t) = client.process(ack, now) {
+ assert!(t > LOWER_TIMEOUT);
+ } else {
+ panic!("Client not idle");
+ }
+}
+
+#[test]
+fn idle_send_packet1() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_send(2, b"hello").unwrap(), 5);
+
+ let out = client.process(None, now + Duration::from_secs(10));
+ let out = server.process(out.dgram(), now + Duration::from_secs(10));
+
+ // Still connected after 39 seconds because idle timer reset by outgoing
+ // packet
+ let _ = client.process(
+ out.dgram(),
+ now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(9),
+ );
+ assert!(matches!(client.state(), State::Confirmed));
+
+ // Not connected after 40 seconds.
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(10));
+
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+#[test]
+fn idle_send_packet2() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_send(2, b"hello").unwrap(), 5);
+
+ let _out = client.process(None, now + Duration::from_secs(10));
+
+ assert_eq!(client.stream_send(2, b"there").unwrap(), 5);
+ let _out = client.process(None, now + Duration::from_secs(20));
+
+ // Still connected after 39 seconds.
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(9));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ // Not connected after 40 seconds because timer not reset by second
+ // outgoing packet
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(10));
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+#[test]
+fn idle_recv_packet() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+ assert_eq!(client.stream_send(0, b"hello").unwrap(), 5);
+
+ // Respond with another packet
+ let out = client.process(None, now + Duration::from_secs(10));
+ server.process_input(out.dgram().unwrap(), now + Duration::from_secs(10));
+ assert_eq!(server.stream_send(0, b"world").unwrap(), 5);
+ let out = server.process_output(now + Duration::from_secs(10));
+ assert_ne!(out.as_dgram_ref(), None);
+
+ let _ = client.process(out.dgram(), now + Duration::from_secs(20));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ // Still connected after 49 seconds because idle timer reset by received
+ // packet
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(19));
+ assert!(matches!(client.state(), State::Confirmed));
+
+ // Not connected after 50 seconds.
+ let _ = client.process(None, now + LOCAL_IDLE_TIMEOUT + Duration::from_secs(20));
+
+ assert!(matches!(client.state(), State::Closed(_)));
+}
+
+/// Caching packets should not cause the connection to become idle.
+/// This requires a few tricks to keep the connection from going
+/// idle while preventing any progress on the handshake.
+#[test]
+fn idle_caching() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let start = now();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+
+ // Perform the first round trip, but drop the Initial from the server.
+ // The client then caches the Handshake packet.
+ let dgram = client.process_output(start).dgram();
+ let dgram = server.process(dgram, start).dgram();
+ let (_, handshake) = split_datagram(&dgram.unwrap());
+ client.process_input(handshake.unwrap(), start);
+
+ // Perform an exchange and keep the connection alive.
+ // Only allow a packet containing a PING to pass.
+ let middle = start + AT_LEAST_PTO;
+ let _ = client.process_output(middle);
+ let dgram = client.process_output(middle).dgram();
+
+ // Get the server to send its first probe and throw that away.
+ let _ = server.process_output(middle).dgram();
+ // Now let the server process the client PING. This causes the server
+ // to send CRYPTO frames again, so manually extract and discard those.
+ let ping_before_s = server.stats().frame_rx.ping;
+ server.process_input(dgram.unwrap(), middle);
+ assert_eq!(server.stats().frame_rx.ping, ping_before_s + 1);
+ let crypto = server
+ .crypto
+ .streams
+ .write_frame(PNSpace::Initial, &mut builder);
+ assert!(crypto.is_some());
+ let crypto = server
+ .crypto
+ .streams
+ .write_frame(PNSpace::Initial, &mut builder);
+ assert!(crypto.is_none());
+ let dgram = server.process_output(middle).dgram();
+
+ // Now only allow the Initial packet from the server through;
+ // it shouldn't contain a CRYPTO frame.
+ let (initial, _) = split_datagram(&dgram.unwrap());
+ let ping_before_c = client.stats().frame_rx.ping;
+ let ack_before = client.stats().frame_rx.ack;
+ client.process_input(initial, middle);
+ assert_eq!(client.stats().frame_rx.ping, ping_before_c + 1);
+ assert_eq!(client.stats().frame_rx.ack, ack_before + 1);
+
+ let end = start + LOCAL_IDLE_TIMEOUT + (AT_LEAST_PTO / 2);
+ // Now let the server Initial through, with the CRYPTO frame.
+ let dgram = server.process_output(end).dgram();
+ let (initial, _) = split_datagram(&dgram.unwrap());
+ let _ = client.process(Some(initial), end);
+ maybe_authenticate(&mut client);
+ let dgram = client.process_output(end).dgram();
+ let dgram = server.process(dgram, end).dgram();
+ client.process_input(dgram.unwrap(), end);
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(*server.state(), State::Confirmed);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/keys.rs b/third_party/rust/neqo-transport/src/connection/tests/keys.rs
new file mode 100644
index 0000000000..cc572e85ca
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/keys.rs
@@ -0,0 +1,330 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::super::{ConnectionError, ERROR_AEAD_LIMIT_REACHED};
+use super::super::{Connection, Error, Output, State, StreamType, LOCAL_IDLE_TIMEOUT};
+use super::{
+ connect, connect_force_idle, default_client, default_server, maybe_authenticate,
+ send_and_receive, send_something, AT_LEAST_PTO,
+};
+use crate::crypto::{OVERWRITE_INVOCATIONS, UPDATE_WRITE_KEYS_AT};
+use crate::packet::PacketNumber;
+use crate::path::PATH_MTU_V6;
+
+use neqo_common::{qdebug, Datagram};
+use test_fixture::{self, now};
+
+fn check_discarded(peer: &mut Connection, pkt: Datagram, dropped: usize, dups: usize) {
+ // Make sure to flush any saved datagrams before doing this.
+ let _ = peer.process_output(now());
+
+ let before = peer.stats();
+ let out = peer.process(Some(pkt), now());
+ assert!(out.as_dgram_ref().is_none());
+ let after = peer.stats();
+ assert_eq!(dropped, after.dropped_rx - before.dropped_rx);
+ assert_eq!(dups, after.dups_rx - before.dups_rx);
+}
+
+fn assert_update_blocked(c: &mut Connection) {
+ assert_eq!(
+ c.initiate_key_update().unwrap_err(),
+ Error::KeyUpdateBlocked
+ )
+}
+
+fn overwrite_invocations(n: PacketNumber) {
+ OVERWRITE_INVOCATIONS.with(|v| {
+ *v.borrow_mut() = Some(n);
+ });
+}
+
+#[test]
+fn discarded_initial_keys() {
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let init_pkt_c = client.process(None, now()).dgram();
+ assert!(init_pkt_c.is_some());
+ assert_eq!(init_pkt_c.as_ref().unwrap().len(), PATH_MTU_V6);
+
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let init_pkt_s = server.process(init_pkt_c.clone(), now()).dgram();
+ assert!(init_pkt_s.is_some());
+
+ qdebug!("---- client: cert verification");
+ let out = client.process(init_pkt_s.clone(), now()).dgram();
+ assert!(out.is_some());
+
+ // The client has received handshake packet. It will remove the Initial keys.
+ // We will check this by processing init_pkt_s a second time.
+ // The initial packet should be dropped. The packet contains a Handshake packet as well, which
+ // will be marked as dup. And it will contain padding, which will be "dropped".
+ check_discarded(&mut client, init_pkt_s.unwrap(), 2, 1);
+
+ assert!(maybe_authenticate(&mut client));
+
+ // The server has not removed the Initial keys yet, because it has not yet received a Handshake
+ // packet from the client.
+ // We will check this by processing init_pkt_c a second time.
+ // The dropped packet is padding. The Initial packet has been mark dup.
+ check_discarded(&mut server, init_pkt_c.clone().unwrap(), 1, 1);
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let out = client.process(None, now()).dgram();
+ assert!(out.is_some());
+
+ // The server will process the first Handshake packet.
+ // After this the Initial keys will be dropped.
+ let out = server.process(out, now()).dgram();
+ assert!(out.is_some());
+
+ // Check that the Initial keys are dropped at the server
+ // We will check this by processing init_pkt_c a third time.
+ // The Initial packet has been dropped and padding that follows it.
+ // There is no dups, everything has been dropped.
+ check_discarded(&mut server, init_pkt_c.unwrap(), 1, 0);
+}
+
+#[test]
+fn key_update_client() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+ let mut now = now();
+
+ assert_eq!(client.get_epochs(), (Some(3), Some(3))); // (write, read)
+ assert_eq!(server.get_epochs(), (Some(3), Some(3)));
+
+ assert!(client.initiate_key_update().is_ok());
+ assert_update_blocked(&mut client);
+
+ // Initiating an update should only increase the write epoch.
+ assert_eq!(
+ Output::Callback(LOCAL_IDLE_TIMEOUT),
+ client.process(None, now)
+ );
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+
+ // Send something to propagate the update.
+ assert!(send_and_receive(&mut client, &mut server, now).is_none());
+
+ // The server should now be waiting to discharge read keys.
+ assert_eq!(server.get_epochs(), (Some(4), Some(3)));
+ let res = server.process(None, now);
+ if let Output::Callback(t) = res {
+ assert!(t < LOCAL_IDLE_TIMEOUT);
+ } else {
+ panic!("server should now be waiting to clear keys");
+ }
+
+ // Without having had time to purge old keys, more updates are blocked.
+ // The spec would permits it at this point, but we are more conservative.
+ assert_update_blocked(&mut client);
+ // The server can't update until it receives an ACK for a packet.
+ assert_update_blocked(&mut server);
+
+ // Waiting now for at least a PTO should cause the server to drop old keys.
+ // But at this point the client hasn't received a key update from the server.
+ // It will be stuck with old keys.
+ now += AT_LEAST_PTO;
+ let dgram = client.process(None, now).dgram();
+ assert!(dgram.is_some()); // Drop this packet.
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+ let _ = server.process(None, now);
+ assert_eq!(server.get_epochs(), (Some(4), Some(4)));
+
+ // Even though the server has updated, it hasn't received an ACK yet.
+ assert_update_blocked(&mut server);
+
+ // Now get an ACK from the server.
+ // The previous PTO packet (see above) was dropped, so we should get an ACK here.
+ let dgram = send_and_receive(&mut client, &mut server, now);
+ assert!(dgram.is_some());
+ let res = client.process(dgram, now);
+ // This is the first packet that the client has received from the server
+ // with new keys, so its read timer just started.
+ if let Output::Callback(t) = res {
+ assert!(t < LOCAL_IDLE_TIMEOUT);
+ } else {
+ panic!("client should now be waiting to clear keys");
+ }
+
+ assert_update_blocked(&mut client);
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+ // The server can't update until it gets something from the client.
+ assert_update_blocked(&mut server);
+
+ now += AT_LEAST_PTO;
+ let _ = client.process(None, now);
+ assert_eq!(client.get_epochs(), (Some(4), Some(4)));
+}
+
+#[test]
+fn key_update_consecutive() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let now = now();
+
+ assert!(server.initiate_key_update().is_ok());
+ assert_eq!(server.get_epochs(), (Some(4), Some(3)));
+
+ // Server sends something.
+ // Send twice and drop the first to induce an ACK from the client.
+ let _ = send_something(&mut server, now); // Drop this.
+
+ // Another packet from the server will cause the client to ACK and update keys.
+ let dgram = send_and_receive(&mut server, &mut client, now);
+ assert!(dgram.is_some());
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+
+ // Have the server process the ACK.
+ if let Output::Callback(_) = server.process(dgram, now) {
+ assert_eq!(server.get_epochs(), (Some(4), Some(3)));
+ // Now move the server temporarily into the future so that it
+ // rotates the keys. The client stays in the present.
+ let _ = server.process(None, now + AT_LEAST_PTO);
+ assert_eq!(server.get_epochs(), (Some(4), Some(4)));
+ } else {
+ panic!("server should have a timer set");
+ }
+
+ // Now update keys on the server again.
+ assert!(server.initiate_key_update().is_ok());
+ assert_eq!(server.get_epochs(), (Some(5), Some(4)));
+
+ let dgram = send_something(&mut server, now + AT_LEAST_PTO);
+
+ // However, as the server didn't wait long enough to update again, the
+ // client hasn't rotated its keys, so the packet gets dropped.
+ check_discarded(&mut client, dgram, 1, 0);
+}
+
+// Key updates can't be initiated too early.
+#[test]
+fn key_update_before_confirmed() {
+ let mut client = default_client();
+ assert_update_blocked(&mut client);
+ let mut server = default_server();
+ assert_update_blocked(&mut server);
+
+ // Client Initial
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ assert_update_blocked(&mut client);
+
+ // Server Initial + Handshake
+ let dgram = server.process(dgram, now()).dgram();
+ assert!(dgram.is_some());
+ assert_update_blocked(&mut server);
+
+ // Client Handshake
+ client.process_input(dgram.unwrap(), now());
+ assert_update_blocked(&mut client);
+
+ assert!(maybe_authenticate(&mut client));
+ assert_update_blocked(&mut client);
+
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ assert_update_blocked(&mut client);
+
+ // Server HANDSHAKE_DONE
+ let dgram = server.process(dgram, now()).dgram();
+ assert!(dgram.is_some());
+ assert!(server.initiate_key_update().is_ok());
+
+ // Client receives HANDSHAKE_DONE
+ let dgram = client.process(dgram, now()).dgram();
+ assert!(dgram.is_none());
+ assert!(client.initiate_key_update().is_ok());
+}
+
+#[test]
+fn exhaust_write_keys() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ overwrite_invocations(0);
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert!(client.stream_send(stream_id, b"explode!").is_ok());
+ let dgram = client.process_output(now()).dgram();
+ assert!(dgram.is_none());
+ assert!(matches!(
+ client.state(),
+ State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ ));
+}
+
+#[test]
+fn exhaust_read_keys() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let dgram = send_something(&mut client, now());
+
+ overwrite_invocations(0);
+ let dgram = server.process(Some(dgram), now()).dgram();
+ assert!(matches!(
+ server.state(),
+ State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ ));
+
+ client.process_input(dgram.unwrap(), now());
+ assert!(matches!(client.state(), State::Draining {
+ error: ConnectionError::Transport(Error::PeerError(ERROR_AEAD_LIMIT_REACHED)), ..
+ }));
+}
+
+#[test]
+fn automatic_update_write_keys() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ overwrite_invocations(UPDATE_WRITE_KEYS_AT);
+ let _ = send_something(&mut client, now());
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+}
+
+#[test]
+fn automatic_update_write_keys_later() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ overwrite_invocations(UPDATE_WRITE_KEYS_AT + 2);
+ // No update after the first.
+ let _ = send_something(&mut client, now());
+ assert_eq!(client.get_epochs(), (Some(3), Some(3)));
+ // The second will update though.
+ let _ = send_something(&mut client, now());
+ assert_eq!(client.get_epochs(), (Some(4), Some(3)));
+}
+
+#[test]
+fn automatic_update_write_keys_blocked() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ // An outstanding key update will block the automatic update.
+ client.initiate_key_update().unwrap();
+
+ overwrite_invocations(UPDATE_WRITE_KEYS_AT);
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert!(client.stream_send(stream_id, b"explode!").is_ok());
+ let dgram = client.process_output(now()).dgram();
+ // Not being able to update is fatal.
+ assert!(dgram.is_none());
+ assert!(matches!(
+ client.state(),
+ State::Closed(ConnectionError::Transport(Error::KeysExhausted))
+ ));
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/mod.rs b/third_party/rust/neqo-transport/src/connection/tests/mod.rs
new file mode 100644
index 0000000000..20add0faad
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/mod.rs
@@ -0,0 +1,305 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![deny(clippy::pedantic)]
+
+use super::{
+ Connection, ConnectionError, FixedConnectionIdManager, Output, State, LOCAL_IDLE_TIMEOUT,
+};
+use crate::addr_valid::{AddressValidation, ValidateAddress};
+use crate::cc::CWND_INITIAL_PKTS;
+use crate::events::ConnectionEvent;
+use crate::frame::StreamType;
+use crate::path::PATH_MTU_V6;
+use crate::recovery::ACK_ONLY_SIZE_LIMIT;
+use crate::ConnectionParameters;
+
+use std::cell::RefCell;
+use std::mem;
+use std::rc::Rc;
+use std::time::{Duration, Instant};
+
+use neqo_common::{event::Provider, qdebug, qtrace, Datagram};
+use neqo_crypto::{AllowZeroRtt, AuthenticationStatus, ResumptionToken};
+use test_fixture::{self, fixture_init, loopback, now};
+
+// All the tests.
+mod cc;
+mod close;
+mod handshake;
+mod idle;
+mod keys;
+mod recovery;
+mod resumption;
+mod stream;
+mod vn;
+mod zerortt;
+
+const DEFAULT_RTT: Duration = Duration::from_millis(100);
+const AT_LEAST_PTO: Duration = Duration::from_secs(1);
+const DEFAULT_STREAM_DATA: &[u8] = b"message";
+
+// This is fabulous: because test_fixture uses the public API for Connection,
+// it gets a different type to the ones that are referenced via super::super::*.
+// Thus, this code can't use default_client() and default_server() from
+// test_fixture because they produce different types.
+//
+// These are a direct copy of those functions.
+pub fn default_client() -> Connection {
+ fixture_init();
+ Connection::new_client(
+ test_fixture::DEFAULT_SERVER_NAME,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(FixedConnectionIdManager::new(3))),
+ loopback(),
+ loopback(),
+ &ConnectionParameters::default(),
+ )
+ .expect("create a default client")
+}
+pub fn default_server() -> Connection {
+ fixture_init();
+
+ let mut c = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(FixedConnectionIdManager::new(5))),
+ &ConnectionParameters::default(),
+ )
+ .expect("create a default server");
+ c.server_enable_0rtt(&test_fixture::anti_replay(), AllowZeroRtt {})
+ .expect("enable 0-RTT");
+ c
+}
+
+/// If state is `AuthenticationNeeded` call `authenticated()`. This function will
+/// consume all outstanding events on the connection.
+pub fn maybe_authenticate(conn: &mut Connection) -> bool {
+ let authentication_needed = |e| matches!(e, ConnectionEvent::AuthenticationNeeded);
+ if conn.events().any(authentication_needed) {
+ conn.authenticated(AuthenticationStatus::Ok, now());
+ return true;
+ }
+ false
+}
+
+/// Drive the handshake between the client and server.
+fn handshake(
+ client: &mut Connection,
+ server: &mut Connection,
+ now: Instant,
+ rtt: Duration,
+) -> Instant {
+ let mut a = client;
+ let mut b = server;
+ let mut now = now;
+
+ let mut input = None;
+ let is_done = |c: &mut Connection| matches!(c.state(), State::Confirmed | State::Closing { .. } | State::Closed(..));
+
+ while !is_done(a) {
+ let _ = maybe_authenticate(a);
+ let had_input = input.is_some();
+ let output = a.process(input, now).dgram();
+ assert!(had_input || output.is_some());
+ input = output;
+ qtrace!("t += {:?}", rtt / 2);
+ now += rtt / 2;
+ mem::swap(&mut a, &mut b);
+ }
+ let _ = a.process(input, now);
+ now
+}
+
+fn connect_with_rtt(
+ client: &mut Connection,
+ server: &mut Connection,
+ now: Instant,
+ rtt: Duration,
+) -> Instant {
+ let now = handshake(client, server, now, rtt);
+ assert_eq!(*client.state(), State::Confirmed);
+ assert_eq!(*client.state(), State::Confirmed);
+
+ assert_eq!(client.loss_recovery.rtt(), rtt);
+ assert_eq!(server.loss_recovery.rtt(), rtt);
+ now
+}
+
+fn connect(client: &mut Connection, server: &mut Connection) {
+ connect_with_rtt(client, server, now(), Duration::new(0, 0));
+}
+
+fn assert_error(c: &Connection, err: &ConnectionError) {
+ match c.state() {
+ State::Closing { error, .. } | State::Draining { error, .. } | State::Closed(error) => {
+ assert_eq!(*error, *err);
+ }
+ _ => panic!("bad state {:?}", c.state()),
+ }
+}
+
+fn exchange_ticket(
+ client: &mut Connection,
+ server: &mut Connection,
+ now: Instant,
+) -> ResumptionToken {
+ let validation = AddressValidation::new(now, ValidateAddress::NoToken).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ server.send_ticket(now, &[]).expect("can send ticket");
+ let ticket = server.process_output(now).dgram();
+ assert!(ticket.is_some());
+ client.process_input(ticket.unwrap(), now);
+ assert_eq!(*client.state(), State::Confirmed);
+ get_tokens(client).pop().expect("should have token")
+}
+
+/// Connect with an RTT and then force both peers to be idle.
+/// Getting the client and server to reach an idle state is surprisingly hard.
+/// The server sends `HANDSHAKE_DONE` at the end of the handshake, and the client
+/// doesn't immediately acknowledge it. Reordering packets does the trick.
+fn connect_rtt_idle(client: &mut Connection, server: &mut Connection, rtt: Duration) -> Instant {
+ let mut now = connect_with_rtt(client, server, now(), rtt);
+ let p1 = send_something(server, now);
+ let p2 = send_something(server, now);
+ now += rtt / 2;
+ // Delivering p2 first at the client causes it to want to ACK.
+ client.process_input(p2, now);
+ // Delivering p1 should not have the client change its mind about the ACK.
+ let ack = client.process(Some(p1), now).dgram();
+ assert!(ack.is_some());
+ assert_eq!(
+ server.process(ack, now),
+ Output::Callback(LOCAL_IDLE_TIMEOUT)
+ );
+ assert_eq!(
+ client.process_output(now),
+ Output::Callback(LOCAL_IDLE_TIMEOUT)
+ );
+ now
+}
+
+fn connect_force_idle(client: &mut Connection, server: &mut Connection) {
+ connect_rtt_idle(client, server, Duration::new(0, 0));
+}
+
+/// This fills the congestion window from a single source.
+/// As the pacer will interfere with this, this moves time forward
+/// as `Output::Callback` is received. Because it is hard to tell
+/// from the return value whether a timeout is an ACK delay, PTO, or
+/// pacing, this looks at the congestion window to tell when to stop.
+/// Returns a list of datagrams and the new time.
+fn fill_cwnd(src: &mut Connection, stream: u64, mut now: Instant) -> (Vec<Datagram>, Instant) {
+ const BLOCK_SIZE: usize = 4_096;
+ let mut total_dgrams = Vec::new();
+
+ qtrace!(
+ "fill_cwnd starting cwnd: {}",
+ src.loss_recovery.cwnd_avail()
+ );
+
+ loop {
+ let bytes_sent = src.stream_send(stream, &[0x42; BLOCK_SIZE]).unwrap();
+ qtrace!("fill_cwnd wrote {} bytes", bytes_sent);
+ if bytes_sent < BLOCK_SIZE {
+ break;
+ }
+ }
+
+ loop {
+ let pkt = src.process_output(now);
+ qtrace!(
+ "fill_cwnd cwnd remaining={}, output: {:?}",
+ src.loss_recovery.cwnd_avail(),
+ pkt
+ );
+ match pkt {
+ Output::Datagram(dgram) => {
+ total_dgrams.push(dgram);
+ }
+ Output::Callback(t) => {
+ if src.loss_recovery.cwnd_avail() < ACK_ONLY_SIZE_LIMIT {
+ break;
+ }
+ now += t;
+ }
+ Output::None => panic!(),
+ }
+ }
+
+ (total_dgrams, now)
+}
+
+/// This magic number is the size of the client's CWND after the handshake completes.
+/// This is the same as the initial congestion window, because during the handshake
+/// the cc is app limited and cwnd is not increased.
+///
+/// As we change how we build packets, or even as NSS changes,
+/// this number might be different. The tests that depend on this
+/// value could fail as a result of variations, so it's OK to just
+/// change this value, but it is good to first understand where the
+/// change came from.
+const POST_HANDSHAKE_CWND: usize = PATH_MTU_V6 * CWND_INITIAL_PKTS;
+
+/// Determine the number of packets required to fill the CWND.
+const fn cwnd_packets(data: usize) -> usize {
+ (data + ACK_ONLY_SIZE_LIMIT - 1) / PATH_MTU_V6
+}
+
+/// Determine the size of the last packet.
+/// The minimal size of a packet is `ACK_ONLY_SIZE_LIMIT`.
+fn last_packet(cwnd: usize) -> usize {
+ if (cwnd % PATH_MTU_V6) > ACK_ONLY_SIZE_LIMIT {
+ cwnd % PATH_MTU_V6
+ } else {
+ PATH_MTU_V6
+ }
+}
+
+/// Assert that the set of packets fill the CWND.
+fn assert_full_cwnd(packets: &[Datagram], cwnd: usize) {
+ assert_eq!(packets.len(), cwnd_packets(cwnd));
+ let (last, rest) = packets.split_last().unwrap();
+ assert!(rest.iter().all(|d| d.len() == PATH_MTU_V6));
+ assert_eq!(last.len(), last_packet(cwnd));
+}
+
+/// Send something on a stream from `sender` to `receiver`.
+/// Return the resulting datagram.
+#[must_use]
+fn send_something(sender: &mut Connection, now: Instant) -> Datagram {
+ let stream_id = sender.stream_create(StreamType::UniDi).unwrap();
+ assert!(sender.stream_send(stream_id, DEFAULT_STREAM_DATA).is_ok());
+ assert!(sender.stream_close_send(stream_id).is_ok());
+ qdebug!([sender], "send_something on {}", stream_id);
+ let dgram = sender.process(None, now).dgram();
+ dgram.expect("should have something to send")
+}
+
+/// Send something on a stream from `sender` to `receiver`.
+/// Return any ACK that might result.
+fn send_and_receive(
+ sender: &mut Connection,
+ receiver: &mut Connection,
+ now: Instant,
+) -> Option<Datagram> {
+ let dgram = send_something(sender, now);
+ receiver.process(Some(dgram), now).dgram()
+}
+
+fn get_tokens(client: &mut Connection) -> Vec<ResumptionToken> {
+ client
+ .events()
+ .filter_map(|e| {
+ if let ConnectionEvent::ResumptionToken(token) = e {
+ Some(token)
+ } else {
+ None
+ }
+ })
+ .collect()
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/recovery.rs b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs
new file mode 100644
index 0000000000..ba7069ccb2
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/recovery.rs
@@ -0,0 +1,636 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Output, State, LOCAL_IDLE_TIMEOUT};
+use super::{
+ assert_full_cwnd, connect, connect_force_idle, connect_with_rtt, default_client,
+ default_server, fill_cwnd, maybe_authenticate, send_and_receive, send_something, AT_LEAST_PTO,
+ POST_HANDSHAKE_CWND,
+};
+use crate::frame::StreamType;
+use crate::path::PATH_MTU_V6;
+use crate::recovery::PTO_PACKET_COUNT;
+use crate::stats::MAX_PTO_COUNTS;
+use crate::tparams::TransportParameter;
+use crate::tracking::ACK_DELAY;
+
+use neqo_common::qdebug;
+use neqo_crypto::AuthenticationStatus;
+use std::time::Duration;
+use test_fixture::{self, now, split_datagram};
+
+#[test]
+fn pto_works_basic() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ // Send data on two streams
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_send(2, b"hello").unwrap(), 5);
+ assert_eq!(client.stream_send(2, b" world").unwrap(), 6);
+
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 6);
+ assert_eq!(client.stream_send(6, b"there!").unwrap(), 6);
+
+ // Send a packet after some time.
+ now += Duration::from_secs(10);
+ let out = client.process(None, now);
+ assert!(out.dgram().is_some());
+
+ // Nothing to do, should return callback
+ let out = client.process(None, now);
+ assert!(matches!(out, Output::Callback(_)));
+
+ // One second later, it should want to send PTO packet
+ now += AT_LEAST_PTO;
+ let out = client.process(None, now);
+
+ let stream_before = server.stats().frame_rx.stream;
+ server.process_input(out.dgram().unwrap(), now);
+ assert_eq!(server.stats().frame_rx.stream, stream_before + 2);
+}
+
+#[test]
+fn pto_works_full_cwnd() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let res = client.process(None, now());
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ // Send lots of data.
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ let (dgrams, now) = fill_cwnd(&mut client, 2, now());
+ assert_full_cwnd(&dgrams, POST_HANDSHAKE_CWND);
+
+ neqo_common::qwarn!("waiting over");
+ // Fill the CWND after waiting for a PTO.
+ let (dgrams, now) = fill_cwnd(&mut client, 2, now + AT_LEAST_PTO);
+ // Two packets in the PTO.
+ // The first should be full sized; the second might be small.
+ assert_eq!(dgrams.len(), 2);
+ assert_eq!(dgrams[0].len(), PATH_MTU_V6);
+
+ // Both datagrams contain a STREAM frame.
+ for d in dgrams {
+ let stream_before = server.stats().frame_rx.stream;
+ server.process_input(d, now);
+ assert_eq!(server.stats().frame_rx.stream, stream_before + 1);
+ }
+}
+
+#[test]
+#[allow(clippy::cognitive_complexity)]
+fn pto_works_ping() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let now = now();
+
+ let res = client.process(None, now);
+ assert_eq!(res, Output::Callback(LOCAL_IDLE_TIMEOUT));
+
+ // Send "zero" pkt
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_send(2, b"zero").unwrap(), 4);
+ let pkt0 = client.process(None, now + Duration::from_secs(10));
+ assert!(matches!(pkt0, Output::Datagram(_)));
+
+ // Send "one" pkt
+ assert_eq!(client.stream_send(2, b"one").unwrap(), 3);
+ let pkt1 = client.process(None, now + Duration::from_secs(10));
+
+ // Send "two" pkt
+ assert_eq!(client.stream_send(2, b"two").unwrap(), 3);
+ let pkt2 = client.process(None, now + Duration::from_secs(10));
+
+ // Send "three" pkt
+ assert_eq!(client.stream_send(2, b"three").unwrap(), 5);
+ let pkt3 = client.process(None, now + Duration::from_secs(10));
+
+ // Nothing to do, should return callback
+ let out = client.process(None, now + Duration::from_secs(10));
+ // Check callback delay is what we expect
+ assert!(matches!(out, Output::Callback(x) if x == Duration::from_millis(45)));
+
+ // Process these by server, skipping pkt0
+ let srv0_pkt1 = server.process(pkt1.dgram(), now + Duration::from_secs(10));
+ // ooo, ack client pkt 1
+ assert!(matches!(srv0_pkt1, Output::Datagram(_)));
+
+ // process pkt2 (no ack yet)
+ let srv2 = server.process(
+ pkt2.dgram(),
+ now + Duration::from_secs(10) + Duration::from_millis(20),
+ );
+ assert!(matches!(srv2, Output::Callback(_)));
+
+ // process pkt3 (acked)
+ let srv2 = server.process(
+ pkt3.dgram(),
+ now + Duration::from_secs(10) + Duration::from_millis(20),
+ );
+ // ack client pkt 2 & 3
+ assert!(matches!(srv2, Output::Datagram(_)));
+
+ // client processes ack
+ let pkt4 = client.process(
+ srv2.dgram(),
+ now + Duration::from_secs(10) + Duration::from_millis(40),
+ );
+ // client resends data from pkt0
+ assert!(matches!(pkt4, Output::Datagram(_)));
+
+ // server sees ooo pkt0 and generates ack
+ let srv_pkt2 = server.process(
+ pkt0.dgram(),
+ now + Duration::from_secs(10) + Duration::from_millis(40),
+ );
+ assert!(matches!(srv_pkt2, Output::Datagram(_)));
+
+ // Orig data is acked
+ let pkt5 = client.process(
+ srv_pkt2.dgram(),
+ now + Duration::from_secs(10) + Duration::from_millis(40),
+ );
+ assert!(matches!(pkt5, Output::Callback(_)));
+
+ // PTO expires. No unacked data. Only send PING.
+ let pkt6 = client.process(
+ None,
+ now + Duration::from_secs(10) + Duration::from_millis(110),
+ );
+
+ let ping_before = server.stats().frame_rx.ping;
+ server.process_input(
+ pkt6.dgram().unwrap(),
+ now + Duration::from_secs(10) + Duration::from_millis(110),
+ );
+ assert_eq!(server.stats().frame_rx.ping, ping_before + 1);
+}
+
+#[test]
+fn pto_initial() {
+ const INITIAL_PTO: Duration = Duration::from_millis(300);
+ let mut now = now();
+
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let pkt1 = client.process(None, now).dgram();
+ assert!(pkt1.is_some());
+ assert_eq!(pkt1.clone().unwrap().len(), PATH_MTU_V6);
+
+ let delay = client.process(None, now).callback();
+ assert_eq!(delay, INITIAL_PTO);
+
+ // Resend initial after PTO.
+ now += delay;
+ let pkt2 = client.process(None, now).dgram();
+ assert!(pkt2.is_some());
+ assert_eq!(pkt2.unwrap().len(), PATH_MTU_V6);
+
+ let pkt3 = client.process(None, now).dgram();
+ assert!(pkt3.is_some());
+ assert_eq!(pkt3.unwrap().len(), PATH_MTU_V6);
+
+ let delay = client.process(None, now).callback();
+ // PTO has doubled.
+ assert_eq!(delay, INITIAL_PTO * 2);
+
+ // Server process the first initial pkt.
+ let mut server = default_server();
+ let out = server.process(pkt1, now).dgram();
+ assert!(out.is_some());
+
+ // Client receives ack for the first initial packet as well a Handshake packet.
+ // After the handshake packet the initial keys and the crypto stream for the initial
+ // packet number space will be discarded.
+ // Here only an ack for the Handshake packet will be sent.
+ let out = client.process(out, now).dgram();
+ assert!(out.is_some());
+
+ // We do not have PTO for the resent initial packet any more, but
+ // the Handshake PTO timer should be armed. As the RTT is apparently
+ // the same as the initial PTO value, and there is only one sample,
+ // the PTO will be 3x the INITIAL PTO.
+ let delay = client.process(None, now).callback();
+ assert_eq!(delay, INITIAL_PTO * 3);
+}
+
+/// A complete handshake that involves a PTO in the Handshake space.
+#[test]
+fn pto_handshake_complete() {
+ let mut now = now();
+ // start handshake
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let pkt = client.process(None, now).dgram();
+ let cb = client.process(None, now).callback();
+ assert_eq!(cb, Duration::from_millis(300));
+
+ now += Duration::from_millis(10);
+ let pkt = server.process(pkt, now).dgram();
+
+ now += Duration::from_millis(10);
+ let pkt = client.process(pkt, now).dgram();
+
+ let cb = client.process(None, now).callback();
+ // The client now has a single RTT estimate (20ms), so
+ // the handshake PTO is set based on that.
+ assert_eq!(cb, Duration::from_millis(60));
+
+ now += Duration::from_millis(10);
+ let pkt = server.process(pkt, now).dgram();
+ assert!(pkt.is_none());
+
+ now += Duration::from_millis(10);
+ client.authenticated(AuthenticationStatus::Ok, now);
+
+ qdebug!("---- client: SH..FIN -> FIN");
+ let pkt1 = client.process(None, now).dgram();
+ assert!(pkt1.is_some());
+
+ let cb = client.process(None, now).callback();
+ assert_eq!(cb, Duration::from_millis(60));
+
+ let mut pto_counts = [0; MAX_PTO_COUNTS];
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Wait for PTO to expire and resend a handshake packet
+ now += Duration::from_millis(60);
+ let pkt2 = client.process(None, now).dgram();
+ assert!(pkt2.is_some());
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Get a second PTO packet.
+ let pkt3 = client.process(None, now).dgram();
+ assert!(pkt3.is_some());
+
+ // PTO has been doubled.
+ let cb = client.process(None, now).callback();
+ assert_eq!(cb, Duration::from_millis(120));
+
+ // We still have only a single PTO
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ now += Duration::from_millis(10);
+ // Server receives the first packet.
+ // The output will be a Handshake packet with an ack and a app pn space packet with
+ // HANDSHAKE_DONE.
+ let pkt = server.process(pkt1, now).dgram();
+ assert!(pkt.is_some());
+
+ // Check that the PTO packets (pkt2, pkt3) are Handshake packets.
+ // The server discarded the Handshake keys already, therefore they are dropped.
+ let dropped_before1 = server.stats().dropped_rx;
+ let frames_before = server.stats().frame_rx.all;
+ server.process_input(pkt2.unwrap(), now);
+ assert_eq!(1, server.stats().dropped_rx - dropped_before1);
+ assert_eq!(server.stats().frame_rx.all, frames_before);
+
+ let dropped_before2 = server.stats().dropped_rx;
+ server.process_input(pkt3.unwrap(), now);
+ assert_eq!(1, server.stats().dropped_rx - dropped_before2);
+ assert_eq!(server.stats().frame_rx.all, frames_before);
+
+ now += Duration::from_millis(10);
+ // Client receive ack for the first packet
+ let cb = client.process(pkt, now).callback();
+ // Ack delay timer for the packet carrying HANDSHAKE_DONE.
+ assert_eq!(cb, ACK_DELAY);
+
+ // Let the ack timer expire.
+ now += cb;
+ let out = client.process(None, now).dgram();
+ assert!(out.is_some());
+ let cb = client.process(None, now).callback();
+ // The handshake keys are discarded, but now we're back to the idle timeout.
+ // We don't send another PING because the handshake space is done and there
+ // is nothing to probe for.
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+ assert_eq!(cb, LOCAL_IDLE_TIMEOUT - ACK_DELAY);
+}
+
+/// Test that PTO in the Handshake space contains the right frames.
+#[test]
+fn pto_handshake_frames() {
+ let mut now = now();
+ qdebug!("---- client: generate CH");
+ let mut client = default_client();
+ let pkt = client.process(None, now);
+
+ now += Duration::from_millis(10);
+ qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN");
+ let mut server = default_server();
+ let pkt = server.process(pkt.dgram(), now);
+
+ now += Duration::from_millis(10);
+ qdebug!("---- client: cert verification");
+ let pkt = client.process(pkt.dgram(), now);
+
+ now += Duration::from_millis(10);
+ let _ = server.process(pkt.dgram(), now);
+
+ now += Duration::from_millis(10);
+ client.authenticated(AuthenticationStatus::Ok, now);
+
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_send(2, b"zero").unwrap(), 4);
+ qdebug!("---- client: SH..FIN -> FIN and 1RTT packet");
+ let pkt1 = client.process(None, now).dgram();
+ assert!(pkt1.is_some());
+
+ // Get PTO timer.
+ let out = client.process(None, now);
+ assert_eq!(out, Output::Callback(Duration::from_millis(60)));
+
+ // Wait for PTO to expire and resend a handshake packet.
+ now += Duration::from_millis(60);
+ let pkt2 = client.process(None, now).dgram();
+ assert!(pkt2.is_some());
+
+ now += Duration::from_millis(10);
+ let crypto_before = server.stats().frame_rx.crypto;
+ server.process_input(pkt2.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.crypto, crypto_before + 1)
+}
+
+/// In the case that the Handshake takes too many packets, the server might
+/// be stalled on the anti-amplification limit. If a Handshake ACK from the
+/// client is lost, the client has to keep the PTO timer armed or the server
+/// might be unable to send anything, causing a deadlock.
+#[test]
+fn handshake_ack_pto() {
+ const RTT: Duration = Duration::from_millis(10);
+ let mut now = now();
+ let mut client = default_client();
+ let mut server = default_server();
+ // This is a greasing transport parameter, and large enough that the
+ // server needs to send two Handshake packets.
+ let big = TransportParameter::Bytes(vec![0; PATH_MTU_V6]);
+ server.set_local_tparam(0xce16, big).unwrap();
+
+ let c1 = client.process(None, now).dgram();
+
+ now += RTT / 2;
+ let s1 = server.process(c1, now).dgram();
+ assert!(s1.is_some());
+ let s2 = server.process(None, now).dgram();
+ assert!(s1.is_some());
+
+ // Now let the client have the Initial, but drop the first coalesced Handshake packet.
+ now += RTT / 2;
+ let (initial, _) = split_datagram(&s1.unwrap());
+ client.process_input(initial, now);
+ let c2 = client.process(s2, now).dgram();
+ assert!(c2.is_some()); // This is an ACK. Drop it.
+ let delay = client.process(None, now).callback();
+ assert_eq!(delay, RTT * 3);
+
+ let mut pto_counts = [0; MAX_PTO_COUNTS];
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Wait for the PTO and ensure that the client generates a packet.
+ now += delay;
+ let c3 = client.process(None, now).dgram();
+ assert!(c3.is_some());
+
+ now += RTT / 2;
+ let ping_before = server.stats().frame_rx.ping;
+ server.process_input(c3.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.ping, ping_before + 1);
+
+ pto_counts[0] = 1;
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+
+ // Now complete the handshake as cheaply as possible.
+ let dgram = server.process(None, now).dgram();
+ client.process_input(dgram.unwrap(), now);
+ maybe_authenticate(&mut client);
+ let dgram = client.process(None, now).dgram();
+ assert_eq!(*client.state(), State::Connected);
+ let dgram = server.process(dgram, now).dgram();
+ assert_eq!(*server.state(), State::Confirmed);
+ client.process_input(dgram.unwrap(), now);
+ assert_eq!(*client.state(), State::Confirmed);
+
+ assert_eq!(client.stats.borrow().pto_counts, pto_counts);
+}
+
+#[test]
+fn loss_recovery_crash() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ let now = now();
+
+ // The server sends something, but we will drop this.
+ let _ = send_something(&mut server, now);
+
+ // Then send something again, but let it through.
+ let ack = send_and_receive(&mut server, &mut client, now);
+ assert!(ack.is_some());
+
+ // Have the server process the ACK.
+ let cb = server.process(ack, now).callback();
+ assert!(cb > Duration::from_secs(0));
+
+ // Now we leap into the future. The server should regard the first
+ // packet as lost based on time alone.
+ let dgram = server.process(None, now + AT_LEAST_PTO).dgram();
+ assert!(dgram.is_some());
+
+ // This crashes.
+ let _ = send_something(&mut server, now + AT_LEAST_PTO);
+}
+
+// If we receive packets after the PTO timer has fired, we won't clear
+// the PTO state, but we might need to acknowledge those packets.
+// This shouldn't happen, but we found that some implementations do this.
+#[test]
+fn ack_after_pto() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect_force_idle(&mut client, &mut server);
+
+ let mut now = now();
+
+ // The client sends and is forced into a PTO.
+ let _ = send_something(&mut client, now);
+
+ // Jump forward to the PTO and drain the PTO packets.
+ now += AT_LEAST_PTO;
+ for _ in 0..PTO_PACKET_COUNT {
+ let dgram = client.process(None, now).dgram();
+ assert!(dgram.is_some());
+ }
+ assert!(client.process(None, now).dgram().is_none());
+
+ // The server now needs to send something that will cause the
+ // client to want to acknowledge it. A little out of order
+ // delivery is just the thing.
+ // Note: The server can't ACK anything here, but none of what
+ // the client has sent so far has been transferred.
+ let _ = send_something(&mut server, now);
+ let dgram = send_something(&mut server, now);
+
+ // The client is now after a PTO, but if it receives something
+ // that demands acknowledgment, it will send just the ACK.
+ let ack = client.process(Some(dgram), now).dgram();
+ assert!(ack.is_some());
+
+ // Make sure that the packet only contained an ACK frame.
+ let all_frames_before = server.stats().frame_rx.all;
+ let ack_before = server.stats().frame_rx.ack;
+ server.process_input(ack.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.all, all_frames_before + 1);
+ assert_eq!(server.stats().frame_rx.ack, ack_before + 1);
+}
+
+/// When we declare a packet as lost, we keep it around for a while for another loss period.
+/// Those packets should not affect how we report the loss recovery timer.
+/// As the loss recovery timer based on RTT we use that to drive the state.
+#[test]
+fn lost_but_kept_and_lr_timer() {
+ const RTT: Duration = Duration::from_secs(1);
+ let mut client = default_client();
+ let mut server = default_server();
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+ // Two packets (p1, p2) are sent at around t=0. The first is lost.
+ let _p1 = send_something(&mut client, now);
+ let p2 = send_something(&mut client, now);
+
+ // At t=RTT/2 the server receives the packet and ACKs it.
+ now += RTT / 2;
+ let ack = server.process(Some(p2), now).dgram();
+ assert!(ack.is_some());
+ // The client also sends another two packets (p3, p4), again losing the first.
+ let _p3 = send_something(&mut client, now);
+ let p4 = send_something(&mut client, now);
+
+ // At t=RTT the client receives the ACK and goes into timed loss recovery.
+ // The client doesn't call p1 lost at this stage, but it will soon.
+ now += RTT / 2;
+ let res = client.process(ack, now);
+ // The client should be on a loss recovery timer as p1 is missing.
+ let lr_timer = res.callback();
+ // Loss recovery timer should be RTT/8, but only check for 0 or >=RTT/2.
+ assert_ne!(lr_timer, Duration::from_secs(0));
+ assert!(lr_timer < (RTT / 2));
+ // The server also receives and acknowledges p4, again sending an ACK.
+ let ack = server.process(Some(p4), now).dgram();
+ assert!(ack.is_some());
+
+ // At t=RTT*3/2 the client should declare p1 to be lost.
+ now += RTT / 2;
+ // So the client will send the data from p1 again.
+ let res = client.process(None, now);
+ assert!(res.dgram().is_some());
+ // When the client processes the ACK, it should engage the
+ // loss recovery timer for p3, not p1 (even though it still tracks p1).
+ let res = client.process(ack, now);
+ let lr_timer2 = res.callback();
+ assert_eq!(lr_timer, lr_timer2);
+}
+
+/// We should not be setting the loss recovery timer based on packets
+/// that are sent prior to the largest acknowledged.
+/// Testing this requires that we construct a case where one packet
+/// number space causes the loss recovery timer to be engaged. At the same time,
+/// there is a packet in another space that hasn't been acknowledged AND
+/// that packet number space has not received acknowledgments for later packets.
+#[test]
+fn loss_time_past_largest_acked() {
+ const RTT: Duration = Duration::from_secs(10);
+ const INCR: Duration = Duration::from_millis(1);
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let mut now = now();
+
+ // Start the handshake.
+ let c_in = client.process(None, now).dgram();
+ now += RTT / 2;
+ let s_hs1 = server.process(c_in, now).dgram();
+
+ // Get some spare server handshake packets for the client to ACK.
+ // This involves a time machine, so be a little cautious.
+ // This test uses an RTT of 10s, but our server starts
+ // with a much lower RTT estimate, so the PTO at this point should
+ // be much smaller than an RTT and so the server shouldn't see
+ // time go backwards.
+ let s_pto = server.process(None, now).callback();
+ assert_ne!(s_pto, Duration::from_secs(0));
+ assert!(s_pto < RTT);
+ let s_hs2 = server.process(None, now + s_pto).dgram();
+ assert!(s_hs2.is_some());
+ let s_hs3 = server.process(None, now + s_pto).dgram();
+ assert!(s_hs3.is_some());
+
+ // Get some Handshake packets from the client.
+ // We need one to be left unacknowledged before one that is acknowledged.
+ // So that the client engages the loss recovery timer.
+ // This is complicated by the fact that it is hard to cause the client
+ // to generate an ack-eliciting packet. For that, we use the Finished message.
+ // Reordering delivery ensures that the later packet is also acknowledged.
+ now += RTT / 2;
+ let c_hs1 = client.process(s_hs1, now).dgram();
+ assert!(c_hs1.is_some()); // This comes first, so it's useless.
+ maybe_authenticate(&mut client);
+ let c_hs2 = client.process(None, now).dgram();
+ assert!(c_hs2.is_some()); // This one will elicit an ACK.
+
+ // The we need the outstanding packet to be sent after the
+ // application data packet, so space these out a tiny bit.
+ let _p1 = send_something(&mut client, now + INCR);
+ let c_hs3 = client.process(s_hs2, now + (INCR * 2)).dgram();
+ assert!(c_hs3.is_some()); // This will be left outstanding.
+ let c_hs4 = client.process(s_hs3, now + (INCR * 3)).dgram();
+ assert!(c_hs4.is_some()); // This will be acknowledged.
+
+ // Process c_hs2 and c_hs4, but skip c_hs3.
+ // Then get an ACK for the client.
+ now += RTT / 2;
+ // Deliver c_hs4 first, but don't generate a packet.
+ server.process_input(c_hs4.unwrap(), now);
+ let s_ack = server.process(c_hs2, now).dgram();
+ assert!(s_ack.is_some());
+ // This includes an ACK, but it also includes HANDSHAKE_DONE,
+ // which we need to remove because that will cause the Handshake loss
+ // recovery state to be dropped.
+ let (s_hs_ack, _s_ap_ack) = split_datagram(&s_ack.unwrap());
+
+ // Now the client should start its loss recovery timer based on the ACK.
+ now += RTT / 2;
+ let c_ack = client.process(Some(s_hs_ack), now).dgram();
+ assert!(c_ack.is_none());
+ // The client should now have the loss recovery timer active.
+ let lr_time = client.process(None, now).callback();
+ assert_ne!(lr_time, Duration::from_secs(0));
+ assert!(lr_time < (RTT / 2));
+
+ // Skipping forward by the loss recovery timer should cause the client to
+ // mark packets as lost and retransmit, after which we should be on the PTO
+ // timer.
+ now += lr_time;
+ let delay = client.process(None, now).callback();
+ assert_ne!(delay, Duration::from_secs(0));
+ assert!(delay > lr_time);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/resumption.rs b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs
new file mode 100644
index 0000000000..19421dded8
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/resumption.rs
@@ -0,0 +1,182 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::{
+ connect, connect_with_rtt, default_client, default_server, exchange_ticket, get_tokens,
+ send_something, AT_LEAST_PTO,
+};
+use crate::addr_valid::{AddressValidation, ValidateAddress};
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::time::Duration;
+use test_fixture::{self, assertions, now};
+
+#[test]
+fn resume() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ assert!(client.tls_info().unwrap().resumed());
+ assert!(server.tls_info().unwrap().resumed());
+}
+
+#[test]
+fn remember_smoothed_rtt() {
+ const RTT1: Duration = Duration::from_millis(130);
+ const RTT2: Duration = Duration::from_millis(70);
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ let now = connect_with_rtt(&mut client, &mut server, now(), RTT1);
+ assert_eq!(client.loss_recovery.rtt(), RTT1);
+
+ let token = exchange_ticket(&mut client, &mut server, now);
+ let mut client = default_client();
+ let mut server = default_server();
+ client.enable_resumption(now, token).unwrap();
+ assert_eq!(
+ client.loss_recovery.rtt(),
+ RTT1,
+ "client should remember previous RTT"
+ );
+
+ connect_with_rtt(&mut client, &mut server, now, RTT2);
+ assert_eq!(
+ client.loss_recovery.rtt(),
+ RTT2,
+ "previous RTT should be completely erased"
+ );
+}
+
+/// Check that a resumed connection uses a token on Initial packets.
+#[test]
+fn address_validation_token_resume() {
+ const RTT: Duration = Duration::from_millis(10);
+
+ let mut client = default_client();
+ let mut server = default_server();
+ let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ let mut now = connect_with_rtt(&mut client, &mut server, now(), RTT);
+
+ let token = exchange_ticket(&mut client, &mut server, now);
+ let mut client = default_client();
+ client.enable_resumption(now, token).unwrap();
+ let mut server = default_server();
+
+ // Grab an Initial packet from the client.
+ let dgram = client.process(None, now).dgram();
+ assertions::assert_initial(dgram.as_ref().unwrap(), true);
+
+ // Now try to complete the handshake after giving time for a client PTO.
+ now += AT_LEAST_PTO;
+ connect_with_rtt(&mut client, &mut server, now, RTT);
+ assert!(client.crypto.tls.info().unwrap().resumed());
+ assert!(server.crypto.tls.info().unwrap().resumed());
+}
+
+fn can_resume(token: impl AsRef<[u8]>, initial_has_token: bool) {
+ let mut client = default_client();
+ client.enable_resumption(now(), token).unwrap();
+ let initial = client.process_output(now()).dgram();
+ assertions::assert_initial(initial.as_ref().unwrap(), initial_has_token);
+}
+
+#[test]
+fn two_tickets_on_timer() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // Send two tickets and then bundle those into a packet.
+ server.send_ticket(now(), &[]).expect("send ticket1");
+ server.send_ticket(now(), &[]).expect("send ticket2");
+ let pkt = send_something(&mut server, now());
+
+ // process() will return an ack first
+ assert!(client.process(Some(pkt), now()).dgram().is_some());
+ // We do not have a ResumptionToken event yet, because NEW_TOKEN was not sent.
+ assert_eq!(get_tokens(&mut client).len(), 0);
+
+ // We need to wait for release_resumption_token_timer to expire. The timer will be
+ // set to 3 * PTO
+ let mut now = now() + 3 * client.get_pto();
+ let _ = client.process(None, now);
+ let mut recv_tokens = get_tokens(&mut client);
+ assert_eq!(recv_tokens.len(), 1);
+ let token1 = recv_tokens.pop().unwrap();
+ // Wai for anottheer 3 * PTO to get the nex okeen.
+ now += 3 * client.get_pto();
+ let _ = client.process(None, now);
+ let mut recv_tokens = get_tokens(&mut client);
+ assert_eq!(recv_tokens.len(), 1);
+ let token2 = recv_tokens.pop().unwrap();
+ // Wait for 3 * PTO, but now there are no more tokens.
+ now += 3 * client.get_pto();
+ let _ = client.process(None, now);
+ assert_eq!(get_tokens(&mut client).len(), 0);
+ assert_ne!(token1.as_ref(), token2.as_ref());
+
+ can_resume(&token1, false);
+ can_resume(&token2, false);
+}
+
+#[test]
+fn two_tickets_with_new_token() {
+ let mut client = default_client();
+ let mut server = default_server();
+ let validation = AddressValidation::new(now(), ValidateAddress::Always).unwrap();
+ let validation = Rc::new(RefCell::new(validation));
+ server.set_validation(Rc::clone(&validation));
+ connect(&mut client, &mut server);
+
+ // Send two tickets with tokens and then bundle those into a packet.
+ server.send_ticket(now(), &[]).expect("send ticket1");
+ server.send_ticket(now(), &[]).expect("send ticket2");
+ let pkt = send_something(&mut server, now());
+
+ client.process_input(pkt, now());
+ let mut all_tokens = get_tokens(&mut client);
+ assert_eq!(all_tokens.len(), 2);
+ let token1 = all_tokens.pop().unwrap();
+ let token2 = all_tokens.pop().unwrap();
+ assert_ne!(token1.as_ref(), token2.as_ref());
+
+ can_resume(&token1, true);
+ can_resume(&token2, true);
+}
+
+/// By disabling address validation, the server won't send `NEW_TOKEN`, but
+/// we can take the session ticket still.
+#[test]
+fn take_token() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ server.send_ticket(now(), &[]).unwrap();
+ let dgram = server.process(None, now()).dgram();
+ client.process_input(dgram.unwrap(), now());
+
+ // There should be no ResumptionToken event here.
+ let tokens = get_tokens(&mut client);
+ assert_eq!(tokens.len(), 0);
+
+ // But we should be able to get the token directly, and use it.
+ let token = client.take_resumption_token(now()).unwrap();
+ can_resume(&token, false);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/stream.rs b/third_party/rust/neqo-transport/src/connection/tests/stream.rs
new file mode 100644
index 0000000000..1c9eebaa17
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/stream.rs
@@ -0,0 +1,580 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::State;
+use super::{
+ connect, default_client, default_server, maybe_authenticate, send_something,
+ DEFAULT_STREAM_DATA,
+};
+use crate::events::ConnectionEvent;
+use crate::frame::StreamType;
+use crate::recv_stream::RECV_BUFFER_SIZE;
+use crate::send_stream::SEND_BUFFER_SIZE;
+use crate::tparams::{self, TransportParameter};
+use crate::tracking::MAX_UNACKED_PKTS;
+use crate::{Error, StreamId};
+
+use neqo_common::{event::Provider, qdebug};
+use std::convert::TryFrom;
+use test_fixture::now;
+
+#[test]
+fn stream_create() {
+ let mut client = default_client();
+
+ let out = client.process(None, now());
+ let mut server = default_server();
+ let out = server.process(out.dgram(), now());
+
+ let out = client.process(out.dgram(), now());
+ let _ = server.process(out.dgram(), now());
+ assert!(maybe_authenticate(&mut client));
+ let out = client.process(None, now());
+
+ // client now in State::Connected
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2);
+ assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 6);
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 0);
+ assert_eq!(client.stream_create(StreamType::BiDi).unwrap(), 4);
+
+ let _ = server.process(out.dgram(), now());
+ // server now in State::Connected
+ assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 3);
+ assert_eq!(server.stream_create(StreamType::UniDi).unwrap(), 7);
+ assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 1);
+ assert_eq!(server.stream_create(StreamType::BiDi).unwrap(), 5);
+}
+
+#[test]
+#[allow(clippy::cognitive_complexity)]
+// tests stream send/recv after connection is established.
+fn transfer() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ qdebug!("---- client");
+ let out = client.process(None, now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ // -->> Initial[0]: CRYPTO[CH]
+
+ qdebug!("---- server");
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ // <<-- Initial[0]: CRYPTO[SH] ACK[0]
+ // <<-- Handshake[0]: CRYPTO[EE, CERT, CV, FIN]
+
+ qdebug!("---- client");
+ let out = client.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ // -->> Initial[1]: ACK[0]
+
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+
+ qdebug!("---- client");
+ let out = client.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(*client.state(), State::Connected);
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ // -->> Handshake[0]: CRYPTO[FIN], ACK[0]
+
+ qdebug!("---- server");
+ let out = server.process(out.dgram(), now());
+ assert!(out.as_dgram_ref().is_some());
+ assert_eq!(*server.state(), State::Confirmed);
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ // ACK and HANDSHAKE_DONE
+ // -->> nothing
+
+ qdebug!("---- client");
+ // Send
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[6; 100]).unwrap();
+ client.stream_send(client_stream_id, &[7; 40]).unwrap();
+ client.stream_send(client_stream_id, &[8; 4000]).unwrap();
+
+ // Send to another stream but some data after fin has been set
+ let client_stream_id2 = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id2, &[6; 60]).unwrap();
+ client.stream_close_send(client_stream_id2).unwrap();
+ client.stream_send(client_stream_id2, &[7; 50]).unwrap_err();
+ // Sending this much takes a few datagrams.
+ let mut datagrams = vec![];
+ let mut out = client.process(out.dgram(), now());
+ while let Some(d) = out.dgram() {
+ datagrams.push(d);
+ out = client.process(None, now());
+ }
+ assert_eq!(datagrams.len(), 4);
+ assert_eq!(*client.state(), State::Confirmed);
+
+ qdebug!("---- server");
+ for (d_num, d) in datagrams.into_iter().enumerate() {
+ let out = server.process(Some(d), now());
+ assert_eq!(
+ out.as_dgram_ref().is_some(),
+ (d_num + 1) % (MAX_UNACKED_PKTS + 1) == 0
+ );
+ qdebug!("Output={:0x?}", out.as_dgram_ref());
+ }
+ assert_eq!(*server.state(), State::Confirmed);
+
+ let mut buf = vec![0; 4000];
+
+ let mut stream_ids = server.events().filter_map(|evt| match evt {
+ ConnectionEvent::NewStream { stream_id, .. } => Some(stream_id),
+ _ => None,
+ });
+ let first_stream = stream_ids.next().expect("should have a new stream event");
+ let second_stream = stream_ids
+ .next()
+ .expect("should have a second new stream event");
+ assert!(stream_ids.next().is_none());
+ let (received1, fin1) = server.stream_recv(first_stream.as_u64(), &mut buf).unwrap();
+ assert_eq!(received1, 4000);
+ assert_eq!(fin1, false);
+ let (received2, fin2) = server.stream_recv(first_stream.as_u64(), &mut buf).unwrap();
+ assert_eq!(received2, 140);
+ assert_eq!(fin2, false);
+
+ let (received3, fin3) = server
+ .stream_recv(second_stream.as_u64(), &mut buf)
+ .unwrap();
+ assert_eq!(received3, 60);
+ assert_eq!(fin3, true);
+}
+
+#[test]
+// Send fin even if a peer closes a reomte bidi send stream before sending any data.
+fn report_fin_when_stream_closed_wo_data() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ let _ = server.process(out.dgram(), now());
+
+ assert_eq!(Ok(()), server.stream_close_send(stream_id));
+ let out = server.process(None, now());
+ let _ = client.process(out.dgram(), now());
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..});
+ assert!(client.events().any(stream_readable));
+}
+
+#[test]
+fn max_data() {
+ const SMALL_MAX_DATA: usize = 16383;
+
+ let mut client = default_client();
+ let mut server = default_server();
+
+ server
+ .set_local_tparam(
+ tparams::INITIAL_MAX_DATA,
+ TransportParameter::Integer(u64::try_from(SMALL_MAX_DATA).unwrap()),
+ )
+ .unwrap();
+
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(client.events().count(), 2); // SendStreamWritable, StateChange(connected)
+ assert_eq!(stream_id, 2);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SMALL_MAX_DATA
+ );
+ assert_eq!(
+ client
+ .stream_send(stream_id, &vec![b'a'; RECV_BUFFER_SIZE].into_boxed_slice())
+ .unwrap(),
+ SMALL_MAX_DATA
+ );
+ assert_eq!(client.events().count(), 0);
+
+ assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0);
+ client
+ .send_streams
+ .get_mut(stream_id.into())
+ .unwrap()
+ .mark_as_sent(0, 4096, false);
+ assert_eq!(client.events().count(), 0);
+ client
+ .send_streams
+ .get_mut(stream_id.into())
+ .unwrap()
+ .mark_as_acked(0, 4096, false);
+ assert_eq!(client.events().count(), 0);
+
+ assert_eq!(client.stream_send(stream_id, b"hello").unwrap(), 0);
+ // no event because still limited by conn max data
+ assert_eq!(client.events().count(), 0);
+
+ // Increase max data. Avail space now limited by stream credit
+ client.handle_max_data(100_000_000);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SEND_BUFFER_SIZE - SMALL_MAX_DATA
+ );
+
+ // Increase max stream data. Avail space now limited by tx buffer
+ client
+ .send_streams
+ .get_mut(stream_id.into())
+ .unwrap()
+ .set_max_stream_data(100_000_000);
+ assert_eq!(
+ client.stream_avail_send_space(stream_id).unwrap(),
+ SEND_BUFFER_SIZE - SMALL_MAX_DATA + 4096
+ );
+
+ let evts = client.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 1);
+ assert!(matches!(evts[0], ConnectionEvent::SendStreamWritable{..}));
+}
+
+#[test]
+// If we send a stop_sending to the peer, we should not accept more data from the peer.
+fn do_not_accept_data_after_stop_sending() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ let _ = server.process(out.dgram(), now());
+
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..});
+ assert!(server.events().any(stream_readable));
+
+ // Send one more packet from client. The packet should arrive after the server
+ // has already requested stop_sending.
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out_second_data_frame = client.process(None, now());
+ // Call stop sending.
+ assert_eq!(
+ Ok(()),
+ server.stream_stop_sending(stream_id, Error::NoError.code())
+ );
+
+ // Receive the second data frame. The frame should be ignored and
+ // DataReadable events shouldn't be posted.
+ let out = server.process(out_second_data_frame.dgram(), now());
+ assert!(!server.events().any(stream_readable));
+
+ let _ = client.process(out.dgram(), now());
+ assert_eq!(
+ Err(Error::FinalSizeError),
+ client.stream_send(stream_id, &[0x00])
+ );
+}
+
+#[test]
+// Server sends stop_sending, the client simultaneous sends reset.
+fn simultaneous_stop_sending_and_reset() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ let _ = server.process(out.dgram(), now());
+
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..});
+ assert!(server.events().any(stream_readable));
+
+ // The client resets the stream. The packet with reset should arrive after the server
+ // has already requested stop_sending.
+ client
+ .stream_reset_send(stream_id, Error::NoError.code())
+ .unwrap();
+ let out_reset_frame = client.process(None, now());
+ // Call stop sending.
+ assert_eq!(
+ Ok(()),
+ server.stream_stop_sending(stream_id, Error::NoError.code())
+ );
+
+ // Receive the second data frame. The frame should be ignored and
+ // DataReadable events shouldn't be posted.
+ let out = server.process(out_reset_frame.dgram(), now());
+ assert!(!server.events().any(stream_readable));
+
+ // The client gets the STOP_SENDING frame.
+ let _ = client.process(out.dgram(), now());
+ assert_eq!(
+ Err(Error::InvalidStreamId),
+ client.stream_send(stream_id, &[0x00])
+ );
+}
+
+#[test]
+fn client_fin_reorder() {
+ let mut client = default_client();
+ let mut server = default_server();
+
+ // Send ClientHello.
+ let client_hs = client.process(None, now());
+ assert!(client_hs.as_dgram_ref().is_some());
+
+ let server_hs = server.process(client_hs.dgram(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc...
+
+ let client_ack = client.process(server_hs.dgram(), now());
+ assert!(client_ack.as_dgram_ref().is_some());
+
+ let server_out = server.process(client_ack.dgram(), now());
+ assert!(server_out.as_dgram_ref().is_none());
+
+ assert!(maybe_authenticate(&mut client));
+ assert_eq!(*client.state(), State::Connected);
+
+ let client_fin = client.process(None, now());
+ assert!(client_fin.as_dgram_ref().is_some());
+
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[1, 2, 3]).unwrap();
+ let client_stream_data = client.process(None, now());
+ assert!(client_stream_data.as_dgram_ref().is_some());
+
+ // Now stream data gets before client_fin
+ let server_out = server.process(client_stream_data.dgram(), now());
+ assert!(server_out.as_dgram_ref().is_none()); // the packet will be discarded
+
+ assert_eq!(*server.state(), State::Handshaking);
+ let server_out = server.process(client_fin.dgram(), now());
+ assert!(server_out.as_dgram_ref().is_some());
+}
+
+#[test]
+fn after_fin_is_read_conn_events_for_stream_should_be_removed() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let id = server.stream_create(StreamType::BiDi).unwrap();
+ server.stream_send(id, &[6; 10]).unwrap();
+ server.stream_close_send(id).unwrap();
+ let out = server.process(None, now()).dgram();
+ assert!(out.is_some());
+
+ let _ = client.process(out, now());
+
+ // read from the stream before checking connection events.
+ let mut buf = vec![0; 4000];
+ let (_, fin) = client.stream_recv(id, &mut buf).unwrap();
+ assert_eq!(fin, true);
+
+ // Make sure we do not have RecvStreamReadable events for the stream when fin has been read.
+ let readable_stream_evt =
+ |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id);
+ assert!(!client.events().any(readable_stream_evt));
+}
+
+#[test]
+fn after_stream_stop_sending_is_called_conn_events_for_stream_should_be_removed() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let id = server.stream_create(StreamType::BiDi).unwrap();
+ server.stream_send(id, &[6; 10]).unwrap();
+ server.stream_close_send(id).unwrap();
+ let out = server.process(None, now()).dgram();
+ assert!(out.is_some());
+
+ let _ = client.process(out, now());
+
+ // send stop seending.
+ client
+ .stream_stop_sending(id, Error::NoError.code())
+ .unwrap();
+
+ // Make sure we do not have RecvStreamReadable events for the stream after stream_stop_sending
+ // has been called.
+ let readable_stream_evt =
+ |e| matches!(e, ConnectionEvent::RecvStreamReadable { stream_id } if stream_id == id);
+ assert!(!client.events().any(readable_stream_evt));
+}
+
+#[test]
+fn stream_data_blocked_generates_max_stream_data() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let now = now();
+
+ // Send some data and include STREAM_DATA_BLOCKED with any value.
+ let stream_id = server.stream_create(StreamType::UniDi).unwrap();
+ let _ = server.stream_send(stream_id, DEFAULT_STREAM_DATA).unwrap();
+ server.flow_mgr.borrow_mut().stream_data_blocked(
+ StreamId::from(stream_id),
+ u64::try_from(DEFAULT_STREAM_DATA.len()).unwrap(),
+ );
+
+ let dgram = server.process(None, now).dgram();
+ assert!(dgram.is_some());
+
+ let sdb_before = client.stats().frame_rx.stream_data_blocked;
+ client.process_input(dgram.unwrap(), now);
+ assert_eq!(client.stats().frame_rx.stream_data_blocked, sdb_before + 1);
+
+ // Consume the data.
+ let mut buf = [0; 10];
+ let (count, end) = client.stream_recv(stream_id, &mut buf[..]).unwrap();
+ assert_eq!(count, DEFAULT_STREAM_DATA.len());
+ assert!(!end);
+
+ let dgram = client.process_output(now).dgram();
+
+ // Client should have sent a MAX_STREAM_DATA frame with just a small increase
+ // on the default window size.
+ let msd_before = server.stats().frame_rx.max_stream_data;
+ server.process_input(dgram.unwrap(), now);
+ assert_eq!(server.stats().frame_rx.max_stream_data, msd_before + 1);
+
+ // Test that more space is available, but that it is small.
+ let mut written = 0;
+ loop {
+ const LARGE_BUFFER: &[u8] = &[0; 1024];
+ let amount = server.stream_send(stream_id, LARGE_BUFFER).unwrap();
+ if amount == 0 {
+ break;
+ }
+ written += amount;
+ }
+ assert_eq!(written, RECV_BUFFER_SIZE - DEFAULT_STREAM_DATA.len());
+}
+
+/// See <https://github.com/mozilla/neqo/issues/871>
+#[test]
+fn max_streams_after_bidi_closed() {
+ const REQUEST: &[u8] = b"ping";
+ const RESPONSE: &[u8] = b"pong";
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ while client.stream_create(StreamType::BiDi).is_ok() {
+ // Exhaust the stream limit.
+ }
+ // Write on the one stream and send that out.
+ let _ = client.stream_send(stream_id, REQUEST).unwrap();
+ client.stream_close_send(stream_id).unwrap();
+ let dgram = client.process(None, now()).dgram();
+
+ // Now handle the stream and send an incomplete response.
+ server.process_input(dgram.unwrap(), now());
+ server.stream_send(stream_id, RESPONSE).unwrap();
+ let dgram = server.process_output(now()).dgram();
+
+ // The server shouldn't have released more stream credit.
+ client.process_input(dgram.unwrap(), now());
+ let e = client.stream_create(StreamType::BiDi).unwrap_err();
+ assert!(matches!(e, Error::StreamLimitError));
+
+ // Closing the stream isn't enough.
+ server.stream_close_send(stream_id).unwrap();
+ let dgram = server.process_output(now()).dgram();
+ client.process_input(dgram.unwrap(), now());
+ assert!(client.stream_create(StreamType::BiDi).is_err());
+
+ // The server needs to see an acknowledgment from the client for its
+ // response AND the server has to read all of the request.
+ // and the server needs to read all the data. Read first.
+ let mut buf = [0; REQUEST.len()];
+ let (count, fin) = server.stream_recv(stream_id, &mut buf).unwrap();
+ assert_eq!(&buf[..count], REQUEST);
+ assert!(fin);
+
+ // We need an ACK from the client now, but that isn't guaranteed,
+ // so give the client one more packet just in case.
+ let dgram = send_something(&mut server, now());
+ client.process_input(dgram, now());
+
+ // Now get the client to send the ACK and have the server handle that.
+ let dgram = send_something(&mut client, now());
+ let dgram = server.process(Some(dgram), now()).dgram();
+ client.process_input(dgram.unwrap(), now());
+ assert!(client.stream_create(StreamType::BiDi).is_ok());
+ assert!(client.stream_create(StreamType::BiDi).is_err());
+}
+
+#[test]
+fn no_dupdata_readable_events() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ let _ = server.process(out.dgram(), now());
+
+ // We have a data_readable event.
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..});
+ assert!(server.events().any(stream_readable));
+
+ // Send one more data frame from client. The previous stream data has not been read yet,
+ // therefore there should not be a new DataReadable event.
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out_second_data_frame = client.process(None, now());
+ let _ = server.process(out_second_data_frame.dgram(), now());
+ assert!(!server.events().any(stream_readable));
+
+ // One more frame with a fin will not produce a new DataReadable event, because the
+ // previous stream data has not been read yet.
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ client.stream_close_send(stream_id).unwrap();
+ let out_third_data_frame = client.process(None, now());
+ let _ = server.process(out_third_data_frame.dgram(), now());
+ assert!(!server.events().any(stream_readable));
+}
+
+#[test]
+fn no_dupdata_readable_events_empty_last_frame() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ // create a stream
+ let stream_id = client.stream_create(StreamType::BiDi).unwrap();
+ client.stream_send(stream_id, &[0x00]).unwrap();
+ let out = client.process(None, now());
+ let _ = server.process(out.dgram(), now());
+
+ // We have a data_readable event.
+ let stream_readable = |e| matches!(e, ConnectionEvent::RecvStreamReadable {..});
+ assert!(server.events().any(stream_readable));
+
+ // An empty frame with a fin will not produce a new DataReadable event, because
+ // the previous stream data has not been read yet.
+ client.stream_close_send(stream_id).unwrap();
+ let out_second_data_frame = client.process(None, now());
+ let _ = server.process(out_second_data_frame.dgram(), now());
+ assert!(!server.events().any(stream_readable));
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/vn.rs b/third_party/rust/neqo-transport/src/connection/tests/vn.rs
new file mode 100644
index 0000000000..f7d37c7864
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/vn.rs
@@ -0,0 +1,201 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{ConnectionError, Output, State};
+use super::{default_client, default_server};
+use crate::packet::PACKET_BIT_LONG;
+use crate::{Error, QuicVersion};
+
+use neqo_common::{Datagram, Decoder, Encoder};
+use std::time::Duration;
+use test_fixture::{self, loopback, now};
+
+// The expected PTO duration after the first Initial is sent.
+const INITIAL_PTO: Duration = Duration::from_millis(300);
+
+#[test]
+fn unknown_version() {
+ let mut client = default_client();
+ // Start the handshake.
+ let _ = client.process(None, now()).dgram();
+
+ let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a];
+ unknown_version_packet.resize(1200, 0x0);
+ let _ = client.process(
+ Some(Datagram::new(
+ loopback(),
+ loopback(),
+ unknown_version_packet,
+ )),
+ now(),
+ );
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn server_receive_unknown_first_packet() {
+ let mut server = default_server();
+
+ let mut unknown_version_packet = vec![0x80, 0x1a, 0x1a, 0x1a, 0x1a];
+ unknown_version_packet.resize(1200, 0x0);
+
+ assert_eq!(
+ server.process(
+ Some(Datagram::new(
+ loopback(),
+ loopback(),
+ unknown_version_packet,
+ )),
+ now(),
+ ),
+ Output::None
+ );
+
+ assert_eq!(1, server.stats().dropped_rx);
+}
+
+fn create_vn(initial_pkt: &[u8], versions: &[u32]) -> Vec<u8> {
+ let mut dec = Decoder::from(&initial_pkt[5..]); // Skip past version.
+ let dst_cid = dec.decode_vec(1).expect("client DCID");
+ let src_cid = dec.decode_vec(1).expect("client SCID");
+
+ let mut encoder = Encoder::default();
+ encoder.encode_byte(PACKET_BIT_LONG);
+ encoder.encode(&[0; 4]); // Zero version == VN.
+ encoder.encode_vec(1, dst_cid);
+ encoder.encode_vec(1, src_cid);
+
+ for v in versions {
+ encoder.encode_uint(4, *v);
+ }
+ encoder.into()
+}
+
+#[test]
+fn version_negotiation_current_version() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(
+ &initial_pkt,
+ &[0x1a1a_1a1a, QuicVersion::default().as_u32()],
+ );
+
+ let dgram = Datagram::new(loopback(), loopback(), vn);
+ let delay = client.process(Some(dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_only_reserved() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]);
+
+ let dgram = Datagram::new(loopback(), loopback(), vn);
+ assert_eq!(client.process(Some(dgram), now()), Output::None);
+ match client.state() {
+ State::Closed(err) => {
+ assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation))
+ }
+ _ => panic!("Invalid client state"),
+ }
+}
+
+#[test]
+fn version_negotiation_corrupted() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a]);
+
+ let dgram = Datagram::new(loopback(), loopback(), &vn[..vn.len() - 1]);
+ let delay = client.process(Some(dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_empty() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[]);
+
+ let dgram = Datagram::new(loopback(), loopback(), vn);
+ let delay = client.process(Some(dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
+
+#[test]
+fn version_negotiation_not_supported() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]);
+
+ assert_eq!(
+ client.process(Some(Datagram::new(loopback(), loopback(), vn)), now(),),
+ Output::None
+ );
+ match client.state() {
+ State::Closed(err) => {
+ assert_eq!(*err, ConnectionError::Transport(Error::VersionNegotiation))
+ }
+ _ => panic!("Invalid client state"),
+ }
+}
+
+#[test]
+fn version_negotiation_bad_cid() {
+ let mut client = default_client();
+ // Start the handshake.
+ let initial_pkt = client
+ .process(None, now())
+ .dgram()
+ .expect("a datagram")
+ .to_vec();
+
+ let mut vn = create_vn(&initial_pkt, &[0x1a1a_1a1a, 0x2a2a_2a2a, 0xff00_0001]);
+ vn[6] ^= 0xc4;
+
+ let dgram = Datagram::new(loopback(), loopback(), vn);
+ let delay = client.process(Some(dgram), now()).callback();
+ assert_eq!(delay, INITIAL_PTO);
+ assert_eq!(*client.state(), State::WaitInitial);
+ assert_eq!(1, client.stats().dropped_rx);
+}
diff --git a/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs
new file mode 100644
index 0000000000..4ccffe4203
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/connection/tests/zerortt.rs
@@ -0,0 +1,193 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use super::super::{Connection, FixedConnectionIdManager};
+use super::{connect, default_client, default_server, exchange_ticket};
+use crate::events::ConnectionEvent;
+use crate::frame::StreamType;
+use crate::{ConnectionParameters, Error};
+
+use neqo_common::event::Provider;
+use neqo_crypto::{AllowZeroRtt, AntiReplay};
+use std::cell::RefCell;
+use std::rc::Rc;
+use test_fixture::{self, assertions, now};
+
+#[test]
+fn zero_rtt_negotiate() {
+ // Note that the two servers in this test will get different anti-replay filters.
+ // That's OK because we aren't testing anti-replay.
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+ assert!(client.tls_info().unwrap().early_data_accepted());
+ assert!(server.tls_info().unwrap().early_data_accepted());
+}
+
+#[test]
+fn zero_rtt_send_recv() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = default_server();
+
+ // Send ClientHello.
+ let client_hs = client.process(None, now());
+ assert!(client_hs.as_dgram_ref().is_some());
+
+ // Now send a 0-RTT packet.
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[1, 2, 3]).unwrap();
+ let client_0rtt = client.process(None, now());
+ assert!(client_0rtt.as_dgram_ref().is_some());
+ // 0-RTT packets on their own shouldn't be padded to 1200.
+ assert!(client_0rtt.as_dgram_ref().unwrap().len() < 1200);
+
+ let server_hs = server.process(client_hs.dgram(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // ServerHello, etc...
+ let server_process_0rtt = server.process(client_0rtt.dgram(), now());
+ assert!(server_process_0rtt.as_dgram_ref().is_none());
+
+ let server_stream_id = server
+ .events()
+ .find_map(|evt| match evt {
+ ConnectionEvent::NewStream { stream_id, .. } => Some(stream_id),
+ _ => None,
+ })
+ .expect("should have received a new stream event");
+ assert_eq!(client_stream_id, server_stream_id.as_u64());
+}
+
+#[test]
+fn zero_rtt_send_coalesce() {
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = default_server();
+
+ // Write 0-RTT before generating any packets.
+ // This should result in a datagram that coalesces Initial and 0-RTT.
+ let client_stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(client_stream_id, &[1, 2, 3]).unwrap();
+ let client_0rtt = client.process(None, now());
+ assert!(client_0rtt.as_dgram_ref().is_some());
+
+ assertions::assert_coalesced_0rtt(&client_0rtt.as_dgram_ref().unwrap()[..]);
+
+ let server_hs = server.process(client_0rtt.dgram(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc...
+
+ let server_stream_id = server
+ .events()
+ .find_map(|evt| match evt {
+ ConnectionEvent::NewStream { stream_id } => Some(stream_id),
+ _ => None,
+ })
+ .expect("should have received a new stream event");
+ assert_eq!(client_stream_id, server_stream_id.as_u64());
+}
+
+#[test]
+fn zero_rtt_before_resumption_token() {
+ let mut client = default_client();
+ assert!(client.stream_create(StreamType::BiDi).is_err());
+}
+
+#[test]
+fn zero_rtt_send_reject() {
+ const MESSAGE: &[u8] = &[1, 2, 3];
+
+ let mut client = default_client();
+ let mut server = default_server();
+ connect(&mut client, &mut server);
+
+ let token = exchange_ticket(&mut client, &mut server, now());
+ let mut client = default_client();
+ client
+ .enable_resumption(now(), token)
+ .expect("should set token");
+ let mut server = Connection::new_server(
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ Rc::new(RefCell::new(FixedConnectionIdManager::new(10))),
+ &ConnectionParameters::default(),
+ )
+ .unwrap();
+ // Using a freshly initialized anti-replay context
+ // should result in the server rejecting 0-RTT.
+ let ar =
+ AntiReplay::new(now(), test_fixture::ANTI_REPLAY_WINDOW, 1, 3).expect("setup anti-replay");
+ server
+ .server_enable_0rtt(&ar, AllowZeroRtt {})
+ .expect("enable 0-RTT");
+
+ // Send ClientHello.
+ let client_hs = client.process(None, now());
+ assert!(client_hs.as_dgram_ref().is_some());
+
+ // Write some data on the client.
+ let stream_id = client.stream_create(StreamType::UniDi).unwrap();
+ client.stream_send(stream_id, MESSAGE).unwrap();
+ let client_0rtt = client.process(None, now());
+ assert!(client_0rtt.as_dgram_ref().is_some());
+
+ let server_hs = server.process(client_hs.dgram(), now());
+ assert!(server_hs.as_dgram_ref().is_some()); // Should produce ServerHello etc...
+ let server_ignored = server.process(client_0rtt.dgram(), now());
+ assert!(server_ignored.as_dgram_ref().is_none());
+
+ // The server shouldn't receive that 0-RTT data.
+ let recvd_stream_evt = |e| matches!(e, ConnectionEvent::NewStream { .. });
+ assert!(!server.events().any(recvd_stream_evt));
+
+ // Client should get a rejection.
+ let client_fin = client.process(server_hs.dgram(), now());
+ let recvd_0rtt_reject = |e| e == ConnectionEvent::ZeroRttRejected;
+ assert!(client.events().any(recvd_0rtt_reject));
+
+ // Server consume client_fin
+ let server_ack = server.process(client_fin.dgram(), now());
+ assert!(server_ack.as_dgram_ref().is_some());
+ let client_out = client.process(server_ack.dgram(), now());
+ assert!(client_out.as_dgram_ref().is_none());
+
+ // ...and the client stream should be gone.
+ let res = client.stream_send(stream_id, MESSAGE);
+ assert!(res.is_err());
+ assert_eq!(res.unwrap_err(), Error::InvalidStreamId);
+
+ // Open a new stream and send data. StreamId should start with 0.
+ let stream_id_after_reject = client.stream_create(StreamType::UniDi).unwrap();
+ assert_eq!(stream_id, stream_id_after_reject);
+ client.stream_send(stream_id_after_reject, MESSAGE).unwrap();
+ let client_after_reject = client.process(None, now());
+ assert!(client_after_reject.as_dgram_ref().is_some());
+
+ // The server should receive new stream
+ let server_out = server.process(client_after_reject.dgram(), now());
+ assert!(server_out.as_dgram_ref().is_none()); // suppress the ack
+ assert!(server.events().any(recvd_stream_evt));
+}
diff --git a/third_party/rust/neqo-transport/src/crypto.rs b/third_party/rust/neqo-transport/src/crypto.rs
new file mode 100644
index 0000000000..bed669f003
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/crypto.rs
@@ -0,0 +1,1293 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::cell::RefCell;
+use std::cmp::{max, min};
+use std::convert::TryFrom;
+use std::mem;
+use std::ops::{Index, IndexMut, Range};
+use std::rc::Rc;
+use std::time::Instant;
+
+use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role};
+use neqo_crypto::{
+ aead::Aead, hkdf, hp::HpKey, Agent, AntiReplay, Cipher, Epoch, HandshakeState, Record,
+ RecordList, ResumptionToken, SymKey, ZeroRttChecker, TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE,
+ TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT,
+ TLS_VERSION_1_3,
+};
+
+use crate::packet::{PacketBuilder, PacketNumber, QuicVersion};
+use crate::recovery::RecoveryToken;
+use crate::recv_stream::RxStreamOrderer;
+use crate::send_stream::TxBuffer;
+use crate::tparams::{TpZeroRttChecker, TransportParameters, TransportParametersHandler};
+use crate::tracking::PNSpace;
+use crate::{Error, Res};
+
+const MAX_AUTH_TAG: usize = 32;
+/// The number of invocations remaining on a write cipher before we try
+/// to update keys. This has to be much smaller than the number returned
+/// by `CryptoDxState::limit` or updates will happen too often. As we don't
+/// need to ask permission to update, this can be quite small.
+pub(crate) const UPDATE_WRITE_KEYS_AT: PacketNumber = 100;
+
+// This is a testing kludge that allows for overwriting the number of
+// invocations of the next cipher to operate. With this, it is possible
+// to test what happens when the number of invocations reaches 0, or
+// when it hits `UPDATE_WRITE_KEYS_AT` and an automatic update should occur.
+// This is a little crude, but it saves a lot of plumbing.
+#[cfg(test)]
+thread_local!(pub(crate) static OVERWRITE_INVOCATIONS: RefCell<Option<PacketNumber>> = RefCell::default());
+
+#[derive(Debug)]
+pub struct Crypto {
+ pub(crate) tls: Agent,
+ pub(crate) streams: CryptoStreams,
+ pub(crate) states: CryptoStates,
+}
+
+type TpHandler = Rc<RefCell<TransportParametersHandler>>;
+
+impl Crypto {
+ pub fn new(mut agent: Agent, protocols: &[impl AsRef<str>], tphandler: TpHandler) -> Res<Self> {
+ agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
+ agent.set_ciphers(&[
+ TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384,
+ TLS_CHACHA20_POLY1305_SHA256,
+ ])?;
+ agent.set_alpn(protocols)?;
+ agent.disable_end_of_early_data()?;
+ // Always enable 0-RTT on the client, but the server needs
+ // more configuration passed to server_enable_0rtt.
+ if let Agent::Client(c) = &mut agent {
+ c.enable_0rtt()?;
+ }
+ agent.extension_handler(0xffa5, tphandler)?;
+ Ok(Self {
+ tls: agent,
+ streams: Default::default(),
+ states: Default::default(),
+ })
+ }
+
+ pub fn server_enable_0rtt(
+ &mut self,
+ tphandler: TpHandler,
+ anti_replay: &AntiReplay,
+ zero_rtt_checker: impl ZeroRttChecker + 'static,
+ ) -> Res<()> {
+ if let Agent::Server(s) = &mut self.tls {
+ Ok(s.enable_0rtt(
+ anti_replay,
+ 0xffff_ffff,
+ TpZeroRttChecker::wrap(tphandler, zero_rtt_checker),
+ )?)
+ } else {
+ panic!("not a server");
+ }
+ }
+
+ pub fn handshake(
+ &mut self,
+ now: Instant,
+ space: PNSpace,
+ data: Option<&[u8]>,
+ ) -> Res<&HandshakeState> {
+ let input = data.map(|d| {
+ qtrace!("Handshake record received {:0x?} ", d);
+ let epoch = match space {
+ PNSpace::Initial => TLS_EPOCH_INITIAL,
+ PNSpace::Handshake => TLS_EPOCH_HANDSHAKE,
+ // Our epoch progresses forward, but the TLS epoch is fixed to 3.
+ PNSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA,
+ };
+ Record {
+ ct: TLS_CT_HANDSHAKE,
+ epoch,
+ data: d.to_vec(),
+ }
+ });
+
+ match self.tls.handshake_raw(now, input) {
+ Ok(output) => {
+ self.buffer_records(output)?;
+ Ok(self.tls.state())
+ }
+ Err(e) => {
+ qinfo!("Handshake failed");
+ Err(match self.tls.alert() {
+ Some(a) => Error::CryptoAlert(*a),
+ _ => Error::CryptoError(e),
+ })
+ }
+ }
+ }
+
+ /// Enable 0-RTT and return `true` if it is enabled successfully.
+ pub fn enable_0rtt(&mut self, role: Role) -> Res<bool> {
+ let info = self.tls.preinfo()?;
+ // `info.early_data()` returns false for a server,
+ // so use `early_data_cipher()` to tell if 0-RTT is enabled.
+ let cipher = info.early_data_cipher();
+ if cipher.is_none() {
+ return Ok(false);
+ }
+ let (dir, secret) = match role {
+ Role::Client => (
+ CryptoDxDirection::Write,
+ self.tls.write_secret(TLS_EPOCH_ZERO_RTT),
+ ),
+ Role::Server => (
+ CryptoDxDirection::Read,
+ self.tls.read_secret(TLS_EPOCH_ZERO_RTT),
+ ),
+ };
+ let secret = secret.ok_or(Error::InternalError)?;
+ self.states.set_0rtt_keys(dir, &secret, cipher.unwrap());
+ Ok(true)
+ }
+
+ /// Returns true if new handshake keys were installed.
+ pub fn install_keys(&mut self, role: Role) -> Res<bool> {
+ if !self.tls.state().is_final() {
+ let installed_hs = self.install_handshake_keys()?;
+ if role == Role::Server {
+ self.maybe_install_application_write_key()?;
+ }
+ Ok(installed_hs)
+ } else {
+ Ok(false)
+ }
+ }
+
+ fn install_handshake_keys(&mut self) -> Res<bool> {
+ qtrace!([self], "Attempt to install handshake keys");
+ let write_secret = if let Some(secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) {
+ secret
+ } else {
+ // No keys is fine.
+ return Ok(false);
+ };
+ let read_secret = self
+ .tls
+ .read_secret(TLS_EPOCH_HANDSHAKE)
+ .ok_or(Error::InternalError)?;
+ let cipher = match self.tls.info() {
+ None => self.tls.preinfo()?.cipher_suite(),
+ Some(info) => Some(info.cipher_suite()),
+ }
+ .ok_or(Error::InternalError)?;
+ self.states
+ .set_handshake_keys(&write_secret, &read_secret, cipher);
+ qdebug!([self], "Handshake keys installed");
+ Ok(true)
+ }
+
+ fn maybe_install_application_write_key(&mut self) -> Res<()> {
+ qtrace!([self], "Attempt to install application write key");
+ if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) {
+ self.states.set_application_write_key(secret)?;
+ qdebug!([self], "Application write key installed");
+ }
+ Ok(())
+ }
+
+ pub fn install_application_keys(&mut self, expire_0rtt: Instant) -> Res<()> {
+ self.maybe_install_application_write_key()?;
+ // The write key might have been installed earlier, but it should
+ // always be installed now.
+ debug_assert!(self.states.app_write.is_some());
+ let read_secret = self
+ .tls
+ .read_secret(TLS_EPOCH_APPLICATION_DATA)
+ .ok_or(Error::InternalError)?;
+ self.states
+ .set_application_read_key(read_secret, expire_0rtt)?;
+ qdebug!([self], "application read keys installed");
+ Ok(())
+ }
+
+ /// Buffer crypto records for sending.
+ pub fn buffer_records(&mut self, records: RecordList) -> Res<()> {
+ for r in records {
+ if r.ct != TLS_CT_HANDSHAKE {
+ return Err(Error::ProtocolViolation);
+ }
+ qtrace!([self], "Adding CRYPTO data {:?}", r);
+ self.streams.send(PNSpace::from(r.epoch), &r.data);
+ }
+ Ok(())
+ }
+
+ pub fn acked(&mut self, token: &CryptoRecoveryToken) {
+ qinfo!(
+ "Acked crypto frame space={} offset={} length={}",
+ token.space,
+ token.offset,
+ token.length
+ );
+ self.streams.acked(token);
+ }
+
+ pub fn lost(&mut self, token: &CryptoRecoveryToken) {
+ qinfo!(
+ "Lost crypto frame space={} offset={} length={}",
+ token.space,
+ token.offset,
+ token.length
+ );
+ self.streams.lost(token);
+ }
+
+ /// Mark any outstanding frames in the indicated space as "lost" so
+ /// that they can be sent again.
+ pub fn resend_unacked(&mut self, space: PNSpace) {
+ self.streams.resend_unacked(space);
+ }
+
+ /// Discard state for a packet number space and return true
+ /// if something was discarded.
+ pub fn discard(&mut self, space: PNSpace) -> bool {
+ self.streams.discard(space);
+ self.states.discard(space)
+ }
+
+ pub fn create_resumption_token(
+ &mut self,
+ new_token: Option<&[u8]>,
+ tps: &TransportParameters,
+ rtt: u64,
+ ) -> Option<ResumptionToken> {
+ if let Agent::Client(ref mut c) = self.tls {
+ if let Some(ref t) = c.resumption_token() {
+ qtrace!("TLS token {}", hex(t.as_ref()));
+ let mut enc = Encoder::default();
+ enc.encode_varint(rtt);
+ enc.encode_vvec_with(|enc_inner| {
+ tps.encode(enc_inner);
+ });
+ enc.encode_vvec(new_token.unwrap_or(&[]));
+ enc.encode(t.as_ref());
+ qinfo!("resumption token {}", hex_snip_middle(&enc[..]));
+ Some(ResumptionToken::new(enc.into(), t.expiration_time()))
+ } else {
+ None
+ }
+ } else {
+ unreachable!("It is a server.");
+ }
+ }
+
+ pub fn has_resumption_token(&self) -> bool {
+ if let Agent::Client(c) = &self.tls {
+ c.has_resumption_token()
+ } else {
+ unreachable!("It is a server.");
+ }
+ }
+}
+
+impl ::std::fmt::Display for Crypto {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Crypto")
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum CryptoDxDirection {
+ Read,
+ Write,
+}
+
+#[derive(Debug)]
+pub struct CryptoDxState {
+ direction: CryptoDxDirection,
+ /// The epoch of this crypto state. This initially tracks TLS epochs
+ /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application.
+ /// But we don't need to keep that, and QUIC isn't limited in how
+ /// many times keys can be updated, so we don't use `u16` for this.
+ epoch: usize,
+ aead: Aead,
+ hpkey: HpKey,
+ /// This tracks the range of packet numbers that have been seen. This allows
+ /// for verifying that packet numbers before a key update are strictly lower
+ /// than packet numbers after a key update.
+ used_pn: Range<PacketNumber>,
+ /// This is the minimum packet number that is allowed.
+ min_pn: PacketNumber,
+ /// The total number of operations that are remaining before the keys
+ /// become exhausted and can't be used any more.
+ invocations: PacketNumber,
+}
+
+impl CryptoDxState {
+ #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45.
+ #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range.
+ pub fn new(
+ direction: CryptoDxDirection,
+ epoch: Epoch,
+ secret: &SymKey,
+ cipher: Cipher,
+ ) -> Self {
+ qinfo!(
+ "Making {:?} {} CryptoDxState, cipher={}",
+ direction,
+ epoch,
+ cipher
+ );
+ Self {
+ direction,
+ epoch: usize::from(epoch),
+ aead: Aead::new(TLS_VERSION_1_3, cipher, secret, "quic ").unwrap(),
+ hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, "quic hp").unwrap(),
+ used_pn: 0..0,
+ min_pn: 0,
+ invocations: Self::limit(direction, cipher),
+ }
+ }
+
+ pub fn new_initial(
+ quic_version: QuicVersion,
+ direction: CryptoDxDirection,
+ label: &str,
+ dcid: &[u8],
+ ) -> Self {
+ qtrace!("new_initial for {:?}", quic_version);
+ const INITIAL_SALT_27: &[u8] = &[
+ 0xc3, 0xee, 0xf7, 0x12, 0xc7, 0x2e, 0xbb, 0x5a, 0x11, 0xa7, 0xd2, 0x43, 0x2b, 0xb4,
+ 0x63, 0x65, 0xbe, 0xf9, 0xf5, 0x02,
+ ];
+ const INITIAL_SALT_29_32: &[u8] = &[
+ 0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61,
+ 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99,
+ ];
+ let salt = match quic_version {
+ QuicVersion::Draft27 | QuicVersion::Draft28 => INITIAL_SALT_27,
+ QuicVersion::Draft29
+ | QuicVersion::Draft30
+ | QuicVersion::Draft31
+ | QuicVersion::Draft32 => INITIAL_SALT_29_32,
+ };
+ let cipher = TLS_AES_128_GCM_SHA256;
+ let initial_secret = hkdf::extract(
+ TLS_VERSION_1_3,
+ cipher,
+ Some(
+ hkdf::import_key(TLS_VERSION_1_3, cipher, salt)
+ .as_ref()
+ .unwrap(),
+ ),
+ hkdf::import_key(TLS_VERSION_1_3, cipher, dcid)
+ .as_ref()
+ .unwrap(),
+ )
+ .unwrap();
+
+ let secret =
+ hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap();
+
+ Self::new(direction, TLS_EPOCH_INITIAL, &secret, cipher)
+ }
+
+ /// Determine the confidentiality and integrity limits for the cipher.
+ fn limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber {
+ match direction {
+ // This uses the smaller limits for 2^16 byte packets
+ // as we don't control incoming packet size.
+ CryptoDxDirection::Read => match cipher {
+ TLS_AES_128_GCM_SHA256 => 1 << 52,
+ TLS_AES_256_GCM_SHA384 => PacketNumber::MAX,
+ TLS_CHACHA20_POLY1305_SHA256 => 1 << 36,
+ _ => unreachable!(),
+ },
+ // This uses the larger limits for 2^11 byte packets.
+ CryptoDxDirection::Write => match cipher {
+ TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => 1 << 28,
+ TLS_CHACHA20_POLY1305_SHA256 => PacketNumber::MAX,
+ _ => unreachable!(),
+ },
+ }
+ }
+
+ fn invoked(&mut self) -> Res<()> {
+ #[cfg(test)]
+ OVERWRITE_INVOCATIONS.with(|v| {
+ if let Some(i) = v.borrow_mut().take() {
+ neqo_common::qwarn!("Setting {:?} invocations to {}", self.direction, i);
+ self.invocations = i;
+ }
+ });
+ self.invocations = self
+ .invocations
+ .checked_sub(1)
+ .ok_or(Error::KeysExhausted)?;
+ Ok(())
+ }
+
+ /// Determine whether we should initiate a key update.
+ pub fn should_update(&self) -> bool {
+ // There is no point in updating read keys as the limit is global.
+ debug_assert_eq!(self.direction, CryptoDxDirection::Write);
+ self.invocations <= UPDATE_WRITE_KEYS_AT
+ }
+
+ pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self {
+ let pn = self.next_pn();
+ // We count invocations of each write key just for that key, but all
+ // attempts to invocations to read count toward a single limit.
+ // This doesn't count use of Handshake keys.
+ let invocations = if self.direction == CryptoDxDirection::Read {
+ self.invocations
+ } else {
+ Self::limit(CryptoDxDirection::Write, cipher)
+ };
+ Self {
+ direction: self.direction,
+ epoch: self.epoch + 1,
+ aead: Aead::new(TLS_VERSION_1_3, cipher, next_secret, "quic ").unwrap(),
+ hpkey: self.hpkey.clone(),
+ used_pn: pn..pn,
+ min_pn: pn,
+ invocations,
+ }
+ }
+
+ #[must_use]
+ pub fn key_phase(&self) -> bool {
+ // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ...
+ self.epoch & 1 != 1
+ }
+
+ /// This is a continuation of a previous, so adjust the range accordingly.
+ /// Fail if the two ranges overlap. Do nothing if the directions don't match.
+ pub fn continuation(&mut self, prev: &Self) -> Res<()> {
+ debug_assert_eq!(self.direction, prev.direction);
+ let next = prev.next_pn();
+ self.min_pn = next;
+ // TODO(mt) use Range::is_empty() when available
+ if self.used_pn.start == self.used_pn.end {
+ self.used_pn = next..next;
+ Ok(())
+ } else if prev.used_pn.end > self.used_pn.start {
+ qdebug!(
+ [self],
+ "Found packet with too new packet number {} > {}, compared to {}",
+ self.used_pn.start,
+ prev.used_pn.end,
+ prev,
+ );
+ Err(Error::PacketNumberOverlap)
+ } else {
+ self.used_pn.start = next;
+ Ok(())
+ }
+ }
+
+ /// Mark a packet number as used. If this is too low, reject it.
+ /// Note that this won't catch a value that is too high if packets protected with
+ /// old keys are received after a key update. That needs to be caught elsewhere.
+ pub fn used(&mut self, pn: PacketNumber) -> Res<()> {
+ if pn < self.min_pn {
+ qdebug!(
+ [self],
+ "Found packet with too old packet number: {} < {}",
+ pn,
+ self.min_pn
+ );
+ return Err(Error::PacketNumberOverlap);
+ }
+ if self.used_pn.start == self.used_pn.end {
+ self.used_pn.start = pn;
+ }
+ self.used_pn.end = max(pn + 1, self.used_pn.end);
+ Ok(())
+ }
+
+ #[must_use]
+ pub fn needs_update(&self) -> bool {
+ // Only initiate a key update if we have processed exactly one packet
+ // and we are in an epoch greater than 3.
+ self.used_pn.start + 1 == self.used_pn.end
+ && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA)
+ }
+
+ #[must_use]
+ pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool {
+ if let Some(la) = largest_acknowledged {
+ self.used_pn.contains(&la)
+ } else {
+ // If we haven't received any acknowledgments, it's OK to update
+ // the first application data epoch.
+ self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA)
+ }
+ }
+
+ pub fn compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>> {
+ let mask = self.hpkey.mask(sample)?;
+ qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask));
+ Ok(mask)
+ }
+
+ #[must_use]
+ pub fn next_pn(&self) -> PacketNumber {
+ self.used_pn.end
+ }
+
+ pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
+ debug_assert_eq!(self.direction, CryptoDxDirection::Write);
+ qtrace!(
+ [self],
+ "encrypt pn={} hdr={} body={}",
+ pn,
+ hex(hdr),
+ hex(body)
+ );
+ // The numbers in `Self::limit` assume a maximum packet size of 2^11.
+ assert!(body.len() <= 2048);
+ self.invoked()?;
+
+ let size = body.len() + MAX_AUTH_TAG;
+ let mut out = vec![0; size];
+ let res = self.aead.encrypt(pn, hdr, body, &mut out)?;
+
+ qtrace!([self], "encrypt ct={}", hex(res));
+ debug_assert_eq!(pn, self.next_pn());
+ self.used(pn)?;
+ Ok(res.to_vec())
+ }
+
+ #[must_use]
+ pub fn expansion(&self) -> usize {
+ self.aead.expansion()
+ }
+
+ pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
+ debug_assert_eq!(self.direction, CryptoDxDirection::Read);
+ qtrace!(
+ [self],
+ "decrypt pn={} hdr={} body={}",
+ pn,
+ hex(hdr),
+ hex(body)
+ );
+ self.invoked()?;
+ let mut out = vec![0; body.len()];
+ let res = self.aead.decrypt(pn, hdr, body, &mut out)?;
+ self.used(pn)?;
+ Ok(res.to_vec())
+ }
+
+ #[cfg(test)]
+ pub(crate) fn test_default() -> Self {
+ // This matches the value in packet.rs
+ const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
+ Self::new_initial(
+ QuicVersion::default(),
+ CryptoDxDirection::Write,
+ "server in",
+ CLIENT_CID,
+ )
+ }
+
+ /// Get the amount of extra padding packets protected with this profile need.
+ /// This is the difference between the size of the header protection sample
+ /// and the AEAD expansion.
+ pub fn extra_padding(&self) -> usize {
+ self.hpkey
+ .sample_size()
+ .saturating_sub(self.aead.expansion())
+ }
+}
+
+impl std::fmt::Display for CryptoDxState {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "epoch {} {:?}", self.epoch, self.direction)
+ }
+}
+
+#[derive(Debug)]
+pub struct CryptoState {
+ tx: CryptoDxState,
+ rx: CryptoDxState,
+}
+
+impl Index<CryptoDxDirection> for CryptoState {
+ type Output = CryptoDxState;
+
+ fn index(&self, dir: CryptoDxDirection) -> &Self::Output {
+ match dir {
+ CryptoDxDirection::Read => &self.rx,
+ CryptoDxDirection::Write => &self.tx,
+ }
+ }
+}
+
+impl IndexMut<CryptoDxDirection> for CryptoState {
+ fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output {
+ match dir {
+ CryptoDxDirection::Read => &mut self.rx,
+ CryptoDxDirection::Write => &mut self.tx,
+ }
+ }
+}
+
+/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys.
+/// This includes the secret needed to generate the next set of keys.
+#[derive(Debug)]
+pub(crate) struct CryptoDxAppData {
+ dx: CryptoDxState,
+ cipher: Cipher,
+ // Not the secret used to create `self.dx`, but the one needed for the next iteration.
+ next_secret: SymKey,
+}
+
+impl CryptoDxAppData {
+ pub fn new(dir: CryptoDxDirection, secret: SymKey, cipher: Cipher) -> Res<Self> {
+ Ok(Self {
+ dx: CryptoDxState::new(dir, TLS_EPOCH_APPLICATION_DATA, &secret, cipher),
+ cipher,
+ next_secret: Self::update_secret(cipher, &secret)?,
+ })
+ }
+
+ fn update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey> {
+ let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?;
+ Ok(next)
+ }
+
+ pub fn next(&self) -> Res<Self> {
+ if self.dx.epoch == usize::max_value() {
+ // Guard against too many key updates.
+ return Err(Error::KeysExhausted);
+ }
+ let next_secret = Self::update_secret(self.cipher, &self.next_secret)?;
+ Ok(Self {
+ dx: self.dx.next(&self.next_secret, self.cipher),
+ cipher: self.cipher,
+ next_secret,
+ })
+ }
+
+ pub fn epoch(&self) -> usize {
+ self.dx.epoch
+ }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub enum CryptoSpace {
+ Initial,
+ ZeroRtt,
+ Handshake,
+ ApplicationData,
+}
+
+#[derive(Debug, Default)]
+pub struct CryptoStates {
+ initial: Option<CryptoState>,
+ handshake: Option<CryptoState>,
+ zero_rtt: Option<CryptoDxState>, // One direction only!
+ cipher: Cipher,
+ app_write: Option<CryptoDxAppData>,
+ app_read: Option<CryptoDxAppData>,
+ app_read_next: Option<CryptoDxAppData>,
+ // If this is set, then we have noticed a genuine update.
+ // Once this time passes, we should switch in new keys.
+ read_update_time: Option<Instant>,
+}
+
+impl CryptoStates {
+ /// Select a `CryptoDxState` and `CryptoSpace` for the given `PNSpace`.
+ /// This selects 0-RTT keys for `PNSpace::ApplicationData` if 1-RTT keys are
+ /// not yet available.
+ pub fn select_tx(&mut self, space: PNSpace) -> Option<(CryptoSpace, &mut CryptoDxState)> {
+ match space {
+ PNSpace::Initial => self
+ .tx(CryptoSpace::Initial)
+ .map(|dx| (CryptoSpace::Initial, dx)),
+ PNSpace::Handshake => self
+ .tx(CryptoSpace::Handshake)
+ .map(|dx| (CryptoSpace::Handshake, dx)),
+ PNSpace::ApplicationData => {
+ if let Some(app) = self.app_write.as_mut() {
+ Some((CryptoSpace::ApplicationData, &mut app.dx))
+ } else {
+ self.zero_rtt.as_mut().map(|dx| (CryptoSpace::ZeroRtt, dx))
+ }
+ }
+ }
+ }
+
+ pub fn tx<'a>(&'a mut self, cspace: CryptoSpace) -> Option<&'a mut CryptoDxState> {
+ let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx);
+ match cspace {
+ CryptoSpace::Initial => tx(self.initial.as_mut()),
+ CryptoSpace::ZeroRtt => self
+ .zero_rtt
+ .as_mut()
+ .filter(|z| z.direction == CryptoDxDirection::Write),
+ CryptoSpace::Handshake => tx(self.handshake.as_mut()),
+ CryptoSpace::ApplicationData => self.app_write.as_mut().map(|app| &mut app.dx),
+ }
+ }
+
+ pub fn rx_hp(&mut self, cspace: CryptoSpace) -> Option<&mut CryptoDxState> {
+ if let CryptoSpace::ApplicationData = cspace {
+ self.app_read.as_mut().map(|ar| &mut ar.dx)
+ } else {
+ self.rx(cspace, false)
+ }
+ }
+
+ pub fn rx<'a>(
+ &'a mut self,
+ cspace: CryptoSpace,
+ key_phase: bool,
+ ) -> Option<&'a mut CryptoDxState> {
+ let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx);
+ match cspace {
+ CryptoSpace::Initial => rx(self.initial.as_mut()),
+ CryptoSpace::ZeroRtt => self
+ .zero_rtt
+ .as_mut()
+ .filter(|z| z.direction == CryptoDxDirection::Read),
+ CryptoSpace::Handshake => rx(self.handshake.as_mut()),
+ CryptoSpace::ApplicationData => {
+ let f = |a: Option<&'a mut CryptoDxAppData>| {
+ a.filter(|ar| ar.dx.key_phase() == key_phase)
+ };
+ // XOR to reduce the leakage about which key is chosen.
+ f(self.app_read.as_mut())
+ .xor(f(self.app_read_next.as_mut()))
+ .map(|ar| &mut ar.dx)
+ }
+ }
+ }
+
+ /// Whether keys for processing packets in the indicated space are pending.
+ /// This allows the caller to determine whether to save a packet for later
+ /// when keys are not available.
+ /// NOTE: 0-RTT keys are not considered here. The expectation is that a
+ /// server will have to save 0-RTT packets in a different place. Though it
+ /// is possible to attribute 0-RTT packets to an existing connection if there
+ /// is a multi-packet Initial, that is an unusual circumstance, so we
+ /// don't do caching for that in those places that call this function.
+ pub fn rx_pending(&self, space: CryptoSpace) -> bool {
+ match space {
+ CryptoSpace::Initial | CryptoSpace::ZeroRtt => false,
+ CryptoSpace::Handshake => self.handshake.is_none() && self.initial.is_some(),
+ CryptoSpace::ApplicationData => self.app_read.is_none(),
+ }
+ }
+
+ /// Create the initial crypto state.
+ pub fn init(&mut self, quic_version: QuicVersion, role: Role, dcid: &[u8]) {
+ const CLIENT_INITIAL_LABEL: &str = "client in";
+ const SERVER_INITIAL_LABEL: &str = "server in";
+
+ qinfo!(
+ [self],
+ "Creating initial cipher state role={:?} dcid={}",
+ role,
+ hex(dcid)
+ );
+
+ let (write, read) = match role {
+ Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL),
+ Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL),
+ };
+
+ let mut initial = CryptoState {
+ tx: CryptoDxState::new_initial(quic_version, CryptoDxDirection::Write, write, dcid),
+ rx: CryptoDxState::new_initial(quic_version, CryptoDxDirection::Read, read, dcid),
+ };
+ if let Some(prev) = &self.initial {
+ qinfo!(
+ [self],
+ "Continue packet numbers for initial after retry (write is {:?})",
+ prev.rx.used_pn,
+ );
+ initial.tx.continuation(&prev.tx).unwrap();
+ }
+ self.initial = Some(initial);
+ }
+
+ pub fn set_0rtt_keys(&mut self, dir: CryptoDxDirection, secret: &SymKey, cipher: Cipher) {
+ qtrace!([self], "install 0-RTT keys");
+ self.zero_rtt = Some(CryptoDxState::new(dir, TLS_EPOCH_ZERO_RTT, secret, cipher));
+ }
+
+ /// Discard keys and return true if that happened.
+ pub fn discard(&mut self, space: PNSpace) -> bool {
+ match space {
+ PNSpace::Initial => self.initial.take().is_some(),
+ PNSpace::Handshake => self.handshake.take().is_some(),
+ PNSpace::ApplicationData => panic!("Can't drop application data keys"),
+ }
+ }
+
+ pub fn discard_0rtt_keys(&mut self) {
+ qtrace!([self], "discard 0-RTT keys");
+ assert!(
+ self.app_read.is_none(),
+ "Can't discard 0-RTT after setting application keys"
+ );
+ self.zero_rtt = None;
+ }
+
+ pub fn set_handshake_keys(
+ &mut self,
+ write_secret: &SymKey,
+ read_secret: &SymKey,
+ cipher: Cipher,
+ ) {
+ self.cipher = cipher;
+ self.handshake = Some(CryptoState {
+ tx: CryptoDxState::new(
+ CryptoDxDirection::Write,
+ TLS_EPOCH_HANDSHAKE,
+ write_secret,
+ cipher,
+ ),
+ rx: CryptoDxState::new(
+ CryptoDxDirection::Read,
+ TLS_EPOCH_HANDSHAKE,
+ read_secret,
+ cipher,
+ ),
+ });
+ }
+
+ pub fn set_application_write_key(&mut self, secret: SymKey) -> Res<()> {
+ debug_assert!(self.app_write.is_none());
+ debug_assert_ne!(self.cipher, 0);
+ let mut app = CryptoDxAppData::new(CryptoDxDirection::Write, secret, self.cipher)?;
+ if let Some(z) = &self.zero_rtt {
+ if z.direction == CryptoDxDirection::Write {
+ app.dx.continuation(z)?;
+ }
+ }
+ self.zero_rtt = None;
+ self.app_write = Some(app);
+ Ok(())
+ }
+
+ pub fn set_application_read_key(&mut self, secret: SymKey, expire_0rtt: Instant) -> Res<()> {
+ debug_assert!(self.app_write.is_some(), "should have write keys installed");
+ debug_assert!(self.app_read.is_none());
+ let mut app = CryptoDxAppData::new(CryptoDxDirection::Read, secret, self.cipher)?;
+ if let Some(z) = &self.zero_rtt {
+ if z.direction == CryptoDxDirection::Read {
+ app.dx.continuation(z)?;
+ }
+ self.read_update_time = Some(expire_0rtt);
+ }
+ self.app_read_next = Some(app.next()?);
+ self.app_read = Some(app);
+ Ok(())
+ }
+
+ /// Update the write keys.
+ pub fn initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()> {
+ // Only update if we are able to. We can only do this if we have
+ // received an acknowledgement for a packet in the current phase.
+ // Also, skip this if we are waiting for read keys on the existing
+ // key update to be rolled over.
+ let write = &self.app_write.as_ref().unwrap().dx;
+ if write.can_update(largest_acknowledged) && self.read_update_time.is_none() {
+ // This call additionally checks that we don't advance to the next
+ // epoch while a key update is in progress.
+ if self.maybe_update_write()? {
+ Ok(())
+ } else {
+ qdebug!([self], "Write keys already updated");
+ Err(Error::KeyUpdateBlocked)
+ }
+ } else {
+ qdebug!([self], "Waiting for ACK or blocked on read key timer");
+ Err(Error::KeyUpdateBlocked)
+ }
+ }
+
+ /// Try to update, and return true if it happened.
+ fn maybe_update_write(&mut self) -> Res<bool> {
+ // Update write keys. But only do so if the write keys are not already
+ // ahead of the read keys. If we initiated the key update, the write keys
+ // will already be ahead.
+ debug_assert!(self.read_update_time.is_none());
+ let write = &self.app_write.as_ref().unwrap();
+ let read = &self.app_read.as_ref().unwrap();
+ if write.epoch() == read.epoch() {
+ qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1);
+ self.app_write = Some(write.next()?);
+ Ok(true)
+ } else {
+ Ok(false)
+ }
+ }
+
+ /// Check whether write keys are close to running out of invocations.
+ /// If that is close, update them if possible. Failing to update at
+ /// this stage is cause for a fatal error.
+ pub fn auto_update(&mut self) -> Res<()> {
+ if let Some(app_write) = self.app_write.as_ref() {
+ if app_write.dx.should_update() {
+ qinfo!([self], "Initiating automatic key update");
+ if !self.maybe_update_write()? {
+ return Err(Error::KeysExhausted);
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn has_0rtt_read(&self) -> bool {
+ self.zero_rtt
+ .as_ref()
+ .filter(|z| z.direction == CryptoDxDirection::Read)
+ .is_some()
+ }
+
+ /// Prepare to update read keys. This doesn't happen immediately as
+ /// we want to ensure that we can continue to receive any delayed
+ /// packets that use the old keys. So we just set a timer.
+ pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> {
+ qtrace!([self], "Key update received");
+ // If we received a key update, then we assume that the peer has
+ // acknowledged a packet we sent in this epoch. It's OK to do that
+ // because they aren't allowed to update without first having received
+ // something from us. If the ACK isn't in the packet that triggered this
+ // key update, it must be in some other packet they have sent.
+ let _ = self.maybe_update_write()?;
+
+ // We shouldn't have 0-RTT keys at this point, but if we do, dump them.
+ debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read());
+ if self.has_0rtt_read() {
+ self.zero_rtt = None;
+ }
+ self.read_update_time = Some(expiration);
+ Ok(())
+ }
+
+ #[must_use]
+ pub fn update_time(&self) -> Option<Instant> {
+ self.read_update_time
+ }
+
+ /// Check if time has passed for updating key update parameters.
+ /// If it has, then swap keys over and allow more key updates to be initiated.
+ /// This is also used to discard 0-RTT read keys at the server in the same way.
+ pub fn check_key_update(&mut self, now: Instant) -> Res<()> {
+ if let Some(expiry) = self.read_update_time {
+ // If enough time has passed, then install new keys and clear the timer.
+ if now >= expiry {
+ if self.has_0rtt_read() {
+ qtrace!([self], "Discarding 0-RTT keys");
+ self.zero_rtt = None;
+ } else {
+ qtrace!([self], "Rotating read keys");
+ mem::swap(&mut self.app_read, &mut self.app_read_next);
+ self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?);
+ }
+ self.read_update_time = None;
+ }
+ }
+ Ok(())
+ }
+
+ /// Get the current/highest epoch. This returns (write, read) epochs.
+ #[cfg(test)]
+ pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
+ let to_epoch = |app: &Option<CryptoDxAppData>| app.as_ref().map(|a| a.dx.epoch);
+ (to_epoch(&self.app_write), to_epoch(&self.app_read))
+ }
+
+ /// While we are awaiting the completion of a key update, we might receive
+ /// valid packets that are protected with old keys. We need to ensure that
+ /// these don't carry packet numbers higher than those in packets protected
+ /// with the newer keys. To ensure that, this is called after every decryption.
+ pub fn check_pn_overlap(&mut self) -> Res<()> {
+ // We only need to do the check while we are waiting for read keys to be updated.
+ if self.read_update_time.is_some() {
+ qtrace!([self], "Checking for PN overlap");
+ let next_dx = &mut self.app_read_next.as_mut().unwrap().dx;
+ next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?;
+ }
+ Ok(())
+ }
+
+ /// Make some state for removing protection in tests.
+ #[cfg(test)]
+ pub(crate) fn test_default() -> Self {
+ let read = |epoch| {
+ let mut dx = CryptoDxState::test_default();
+ dx.direction = CryptoDxDirection::Read;
+ dx.epoch = epoch;
+ dx
+ };
+ let app_read = |epoch| CryptoDxAppData {
+ dx: read(epoch),
+ cipher: TLS_AES_128_GCM_SHA256,
+ next_secret: hkdf::import_key(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &[0xaa; 32])
+ .unwrap(),
+ };
+ Self {
+ initial: Some(CryptoState {
+ tx: CryptoDxState::test_default(),
+ rx: read(0),
+ }),
+ handshake: None,
+ zero_rtt: None,
+ cipher: TLS_AES_128_GCM_SHA256,
+ // This isn't used, but the epoch is read to check for a key update.
+ app_write: Some(app_read(3)),
+ app_read: Some(app_read(3)),
+ app_read_next: Some(app_read(4)),
+ read_update_time: None,
+ }
+ }
+
+ #[cfg(test)]
+ pub(crate) fn test_chacha() -> Self {
+ const SECRET: &[u8] = &[
+ 0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad,
+ 0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3,
+ 0x0f, 0x21, 0x63, 0x2b,
+ ];
+ let secret =
+ hkdf::import_key(TLS_VERSION_1_3, TLS_CHACHA20_POLY1305_SHA256, SECRET).unwrap();
+ let app_read = |epoch| CryptoDxAppData {
+ dx: CryptoDxState {
+ direction: CryptoDxDirection::Read,
+ epoch,
+ aead: Aead::new(
+ TLS_VERSION_1_3,
+ TLS_CHACHA20_POLY1305_SHA256,
+ &secret,
+ "quic ",
+ )
+ .unwrap(),
+ hpkey: HpKey::extract(
+ TLS_VERSION_1_3,
+ TLS_CHACHA20_POLY1305_SHA256,
+ &secret,
+ "quic hp",
+ )
+ .unwrap(),
+ used_pn: 0..645_971_972,
+ min_pn: 0,
+ invocations: 10,
+ },
+ cipher: TLS_CHACHA20_POLY1305_SHA256,
+ next_secret: secret.clone(),
+ };
+ Self {
+ initial: None,
+ handshake: None,
+ zero_rtt: None,
+ cipher: TLS_CHACHA20_POLY1305_SHA256,
+ app_write: Some(app_read(3)),
+ app_read: Some(app_read(3)),
+ app_read_next: Some(app_read(4)),
+ read_update_time: None,
+ }
+ }
+}
+
+impl std::fmt::Display for CryptoStates {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "CryptoStates")
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct CryptoStream {
+ tx: TxBuffer,
+ rx: RxStreamOrderer,
+}
+
+#[derive(Debug)]
+#[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408
+pub enum CryptoStreams {
+ Initial {
+ initial: CryptoStream,
+ handshake: CryptoStream,
+ application: CryptoStream,
+ },
+ Handshake {
+ handshake: CryptoStream,
+ application: CryptoStream,
+ },
+ ApplicationData {
+ application: CryptoStream,
+ },
+}
+
+impl CryptoStreams {
+ pub fn discard(&mut self, space: PNSpace) {
+ match space {
+ PNSpace::Initial => {
+ if let Self::Initial {
+ handshake,
+ application,
+ ..
+ } = self
+ {
+ *self = Self::Handshake {
+ handshake: mem::take(handshake),
+ application: mem::take(application),
+ };
+ }
+ }
+ PNSpace::Handshake => {
+ if let Self::Handshake { application, .. } = self {
+ *self = Self::ApplicationData {
+ application: mem::take(application),
+ };
+ } else if matches!(self, Self::Initial {..}) {
+ panic!("Discarding handshake before initial discarded");
+ }
+ }
+ PNSpace::ApplicationData => panic!("Discarding application data crypto streams"),
+ }
+ }
+
+ pub fn send(&mut self, space: PNSpace, data: &[u8]) {
+ self.get_mut(space).unwrap().tx.send(data);
+ }
+
+ pub fn inbound_frame(&mut self, space: PNSpace, offset: u64, data: &[u8]) {
+ self.get_mut(space).unwrap().rx.inbound_frame(offset, data);
+ }
+
+ pub fn data_ready(&self, space: PNSpace) -> bool {
+ self.get(space).map_or(false, |cs| cs.rx.data_ready())
+ }
+
+ pub fn read_to_end(&mut self, space: PNSpace, buf: &mut Vec<u8>) -> usize {
+ self.get_mut(space).unwrap().rx.read_to_end(buf)
+ }
+
+ pub fn acked(&mut self, token: &CryptoRecoveryToken) {
+ self.get_mut(token.space)
+ .unwrap()
+ .tx
+ .mark_as_acked(token.offset, token.length);
+ }
+
+ pub fn lost(&mut self, token: &CryptoRecoveryToken) {
+ // See BZ 1624800, ignore lost packets in spaces we've dropped keys
+ if let Some(cs) = self.get_mut(token.space) {
+ cs.tx.mark_as_lost(token.offset, token.length);
+ }
+ }
+
+ /// Resend any Initial or Handshake CRYPTO frames that might be outstanding.
+ /// This can help speed up handshake times.
+ pub fn resend_unacked(&mut self, space: PNSpace) {
+ if space != PNSpace::ApplicationData {
+ if let Some(cs) = self.get_mut(space) {
+ cs.tx.unmark_sent();
+ }
+ }
+ }
+
+ fn get(&self, space: PNSpace) -> Option<&CryptoStream> {
+ let (initial, hs, app) = match self {
+ Self::Initial {
+ initial,
+ handshake,
+ application,
+ } => (Some(initial), Some(handshake), Some(application)),
+ Self::Handshake {
+ handshake,
+ application,
+ } => (None, Some(handshake), Some(application)),
+ Self::ApplicationData { application } => (None, None, Some(application)),
+ };
+ match space {
+ PNSpace::Initial => initial,
+ PNSpace::Handshake => hs,
+ PNSpace::ApplicationData => app,
+ }
+ }
+
+ fn get_mut(&mut self, space: PNSpace) -> Option<&mut CryptoStream> {
+ let (initial, hs, app) = match self {
+ Self::Initial {
+ initial,
+ handshake,
+ application,
+ } => (Some(initial), Some(handshake), Some(application)),
+ Self::Handshake {
+ handshake,
+ application,
+ } => (None, Some(handshake), Some(application)),
+ Self::ApplicationData { application } => (None, None, Some(application)),
+ };
+ match space {
+ PNSpace::Initial => initial,
+ PNSpace::Handshake => hs,
+ PNSpace::ApplicationData => app,
+ }
+ }
+
+ pub fn write_frame(
+ &mut self,
+ space: PNSpace,
+ builder: &mut PacketBuilder,
+ ) -> Option<RecoveryToken> {
+ let cs = self.get_mut(space).unwrap();
+ if let Some((offset, data)) = cs.tx.next_bytes() {
+ let mut header_len = 1 + Encoder::varint_len(offset) + 1;
+
+ // Don't bother if there isn't room for the header and some data.
+ if builder.remaining() < header_len + 1 {
+ return None;
+ }
+ // Calculate length of data based on the minimum of:
+ // - available data
+ // - remaining space, less the header, which counts only one byte
+ // for the length at first to avoid underestimating length
+ let length = min(data.len(), builder.remaining() - header_len);
+ header_len += Encoder::varint_len(u64::try_from(length).unwrap()) - 1;
+ let length = min(data.len(), builder.remaining() - header_len);
+
+ builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO);
+ builder.encode_varint(offset);
+ builder.encode_vvec(&data[..length]);
+ cs.tx.mark_as_sent(offset, length);
+
+ qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length);
+ Some(RecoveryToken::Crypto(CryptoRecoveryToken {
+ space,
+ offset,
+ length,
+ }))
+ } else {
+ None
+ }
+ }
+}
+
+impl Default for CryptoStreams {
+ fn default() -> Self {
+ Self::Initial {
+ initial: CryptoStream::default(),
+ handshake: CryptoStream::default(),
+ application: CryptoStream::default(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct CryptoRecoveryToken {
+ space: PNSpace,
+ offset: u64,
+ length: usize,
+}
diff --git a/third_party/rust/neqo-transport/src/dump.rs b/third_party/rust/neqo-transport/src/dump.rs
new file mode 100644
index 0000000000..e8f5b32ae9
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/dump.rs
@@ -0,0 +1,32 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Enable just this file for logging to just see packets.
+// e.g. "RUST_LOG=neqo_transport::dump neqo-client ..."
+
+use crate::connection::Connection;
+use crate::frame::Frame;
+use crate::packet::{PacketNumber, PacketType};
+use neqo_common::{qdebug, Decoder};
+
+#[allow(clippy::module_name_repetitions)]
+pub fn dump_packet(conn: &Connection, dir: &str, pt: PacketType, pn: PacketNumber, payload: &[u8]) {
+ let mut s = String::from("");
+ let mut d = Decoder::from(payload);
+ while d.remaining() > 0 {
+ let f = match Frame::decode(&mut d) {
+ Ok(f) => f,
+ Err(_) => {
+ s.push_str(" [broken]...");
+ break;
+ }
+ };
+ if let Some(x) = f.dump() {
+ s.push_str(&format!("\n {} {}", dir, &x));
+ }
+ }
+ qdebug!([conn], "pn={} type={:?}{}", pn, pt, s);
+}
diff --git a/third_party/rust/neqo-transport/src/events.rs b/third_party/rust/neqo-transport/src/events.rs
new file mode 100644
index 0000000000..fc998cf32c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/events.rs
@@ -0,0 +1,254 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Collecting a list of events relevant to whoever is using the Connection.
+
+use std::cell::RefCell;
+use std::collections::VecDeque;
+use std::rc::Rc;
+
+use crate::connection::State;
+use crate::frame::StreamType;
+use crate::stream_id::StreamId;
+use crate::AppError;
+use neqo_common::event::Provider as EventProvider;
+use neqo_crypto::ResumptionToken;
+
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
+pub enum ConnectionEvent {
+ /// Cert authentication needed
+ AuthenticationNeeded,
+ /// A new uni (read) or bidi stream has been opened by the peer.
+ NewStream {
+ stream_id: StreamId,
+ },
+ /// Space available in the buffer for an application write to succeed.
+ SendStreamWritable {
+ stream_id: StreamId,
+ },
+ /// New bytes available for reading.
+ RecvStreamReadable {
+ stream_id: u64,
+ },
+ /// Peer reset the stream.
+ RecvStreamReset {
+ stream_id: u64,
+ app_error: AppError,
+ },
+ /// Peer has sent STOP_SENDING
+ SendStreamStopSending {
+ stream_id: u64,
+ app_error: AppError,
+ },
+ /// Peer has acked everything sent on the stream.
+ SendStreamComplete {
+ stream_id: u64,
+ },
+ /// Peer increased MAX_STREAMS
+ SendStreamCreatable {
+ stream_type: StreamType,
+ },
+ /// Connection state change.
+ StateChange(State),
+ /// The server rejected 0-RTT.
+ /// This event invalidates all state in streams that has been created.
+ /// Any data written to streams needs to be written again.
+ ZeroRttRejected,
+ ResumptionToken(ResumptionToken),
+}
+
+#[derive(Debug, Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct ConnectionEvents {
+ events: Rc<RefCell<VecDeque<ConnectionEvent>>>,
+}
+
+impl ConnectionEvents {
+ pub fn authentication_needed(&self) {
+ self.insert(ConnectionEvent::AuthenticationNeeded);
+ }
+
+ pub fn new_stream(&self, stream_id: StreamId) {
+ self.insert(ConnectionEvent::NewStream { stream_id });
+ }
+
+ pub fn recv_stream_readable(&self, stream_id: StreamId) {
+ self.insert(ConnectionEvent::RecvStreamReadable {
+ stream_id: stream_id.as_u64(),
+ });
+ }
+
+ pub fn recv_stream_reset(&self, stream_id: StreamId, app_error: AppError) {
+ // If reset, no longer readable.
+ self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64()));
+
+ self.insert(ConnectionEvent::RecvStreamReset {
+ stream_id: stream_id.as_u64(),
+ app_error,
+ });
+ }
+
+ pub fn send_stream_writable(&self, stream_id: StreamId) {
+ self.insert(ConnectionEvent::SendStreamWritable { stream_id });
+ }
+
+ pub fn send_stream_stop_sending(&self, stream_id: StreamId, app_error: AppError) {
+ // If stopped, no longer writable.
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id.as_u64()));
+
+ self.insert(ConnectionEvent::SendStreamStopSending {
+ stream_id: stream_id.as_u64(),
+ app_error,
+ });
+ }
+
+ pub fn send_stream_complete(&self, stream_id: StreamId) {
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamWritable { stream_id: x } if *x == stream_id));
+
+ self.remove(|evt| matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. } if *x == stream_id.as_u64()));
+
+ self.insert(ConnectionEvent::SendStreamComplete {
+ stream_id: stream_id.as_u64(),
+ });
+ }
+
+ pub fn send_stream_creatable(&self, stream_type: StreamType) {
+ self.insert(ConnectionEvent::SendStreamCreatable { stream_type });
+ }
+
+ pub fn connection_state_change(&self, state: State) {
+ // If closing, existing events no longer relevant.
+ match state {
+ State::Closing { .. } | State::Closed(_) => self.events.borrow_mut().clear(),
+ _ => (),
+ }
+ self.insert(ConnectionEvent::StateChange(state));
+ }
+
+ pub fn client_resumption_token(&self, token: ResumptionToken) {
+ self.insert(ConnectionEvent::ResumptionToken(token));
+ }
+
+ pub fn client_0rtt_rejected(&self) {
+ // If 0rtt rejected, must start over and existing events are no longer
+ // relevant.
+ self.events.borrow_mut().clear();
+ self.insert(ConnectionEvent::ZeroRttRejected);
+ }
+
+ pub fn recv_stream_complete(&self, stream_id: StreamId) {
+ // If stopped, no longer readable.
+ self.remove(|evt| matches!(evt, ConnectionEvent::RecvStreamReadable { stream_id: x } if *x == stream_id.as_u64()));
+ }
+
+ fn insert(&self, event: ConnectionEvent) {
+ let mut q = self.events.borrow_mut();
+
+ // Special-case two enums that are not strictly PartialEq equal but that
+ // we wish to avoid inserting duplicates.
+ let already_present = match &event {
+ ConnectionEvent::SendStreamStopSending { stream_id, .. } => q.iter().any(|evt| {
+ matches!(evt, ConnectionEvent::SendStreamStopSending { stream_id: x, .. }
+ if *x == *stream_id)
+ }),
+ ConnectionEvent::RecvStreamReset { stream_id, .. } => q.iter().any(|evt| {
+ matches!(evt, ConnectionEvent::RecvStreamReset { stream_id: x, .. }
+ if *x == *stream_id)
+ }),
+ _ => q.contains(&event),
+ };
+ if !already_present {
+ q.push_back(event);
+ }
+ }
+
+ fn remove<F>(&self, f: F)
+ where
+ F: Fn(&ConnectionEvent) -> bool,
+ {
+ self.events.borrow_mut().retain(|evt| !f(evt))
+ }
+}
+
+impl EventProvider for ConnectionEvents {
+ type Event = ConnectionEvent;
+
+ fn has_events(&self) -> bool {
+ !self.events.borrow().is_empty()
+ }
+
+ fn next_event(&mut self) -> Option<Self::Event> {
+ self.events.borrow_mut().pop_front()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{ConnectionError, Error};
+
+ #[test]
+ fn event_culling() {
+ let mut evts = ConnectionEvents::default();
+
+ evts.client_0rtt_rejected();
+ evts.client_0rtt_rejected();
+ assert_eq!(evts.events().count(), 1);
+ assert_eq!(evts.events().count(), 0);
+
+ evts.new_stream(4.into());
+ evts.new_stream(4.into());
+ assert_eq!(evts.events().count(), 1);
+
+ evts.recv_stream_readable(6.into());
+ evts.recv_stream_reset(6.into(), 66);
+ evts.recv_stream_reset(6.into(), 65);
+ assert_eq!(evts.events().count(), 1);
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(8.into());
+ evts.send_stream_stop_sending(8.into(), 55);
+ evts.send_stream_stop_sending(8.into(), 56);
+ let events = evts.events().collect::<Vec<_>>();
+ assert_eq!(events.len(), 1);
+ assert_eq!(
+ events[0],
+ ConnectionEvent::SendStreamStopSending {
+ stream_id: 8,
+ app_error: 55
+ }
+ );
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(8.into());
+ evts.send_stream_stop_sending(8.into(), 55);
+ evts.send_stream_stop_sending(8.into(), 56);
+ evts.send_stream_complete(8.into());
+ assert_eq!(evts.events().count(), 1);
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(9.into());
+ evts.send_stream_stop_sending(10.into(), 55);
+ evts.send_stream_stop_sending(11.into(), 56);
+ evts.send_stream_complete(12.into());
+ assert_eq!(evts.events().count(), 5);
+
+ evts.send_stream_writable(8.into());
+ evts.send_stream_writable(9.into());
+ evts.send_stream_stop_sending(10.into(), 55);
+ evts.send_stream_stop_sending(11.into(), 56);
+ evts.send_stream_complete(12.into());
+ evts.client_0rtt_rejected();
+ assert_eq!(evts.events().count(), 1);
+
+ evts.send_stream_writable(9.into());
+ evts.send_stream_stop_sending(10.into(), 55);
+ evts.connection_state_change(State::Closed(ConnectionError::Transport(
+ Error::StreamStateError,
+ )));
+ assert_eq!(evts.events().count(), 1);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/flow_mgr.rs b/third_party/rust/neqo-transport/src/flow_mgr.rs
new file mode 100644
index 0000000000..0ae7fd3c00
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/flow_mgr.rs
@@ -0,0 +1,400 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracks possibly-redundant flow control signals from other code and converts
+// into flow control frames needing to be sent to the remote.
+
+use std::collections::HashMap;
+use std::mem;
+
+use neqo_common::{qinfo, qwarn, Encoder};
+use smallvec::{smallvec, SmallVec};
+
+use crate::frame::{Frame, StreamType};
+use crate::packet::PacketBuilder;
+use crate::recovery::RecoveryToken;
+use crate::recv_stream::RecvStreams;
+use crate::send_stream::SendStreams;
+use crate::stats::FrameStats;
+use crate::stream_id::{StreamId, StreamIndex, StreamIndexes};
+use crate::AppError;
+
+type FlowFrame = Frame<'static>;
+pub type FlowControlRecoveryToken = FlowFrame;
+
+#[derive(Debug, Default)]
+pub struct FlowMgr {
+ // Discriminant as key ensures only 1 of every frame type will be queued.
+ from_conn: HashMap<mem::Discriminant<FlowFrame>, FlowFrame>,
+
+ // (id, discriminant) as key ensures only 1 of every frame type per stream
+ // will be queued.
+ from_streams: HashMap<(StreamId, mem::Discriminant<FlowFrame>), FlowFrame>,
+
+ // (stream_type, discriminant) as key ensures only 1 of every frame type
+ // per stream type will be queued.
+ from_stream_types: HashMap<(StreamType, mem::Discriminant<FlowFrame>), FlowFrame>,
+
+ used_data: u64,
+ max_data: u64,
+}
+
+impl FlowMgr {
+ pub fn conn_credit_avail(&self) -> u64 {
+ self.max_data - self.used_data
+ }
+
+ pub fn conn_increase_credit_used(&mut self, amount: u64) {
+ self.used_data += amount;
+ assert!(self.used_data <= self.max_data)
+ }
+
+ // Dummy DataBlocked frame for discriminant use below
+
+ /// Returns whether max credit was actually increased.
+ pub fn conn_increase_max_credit(&mut self, new: u64) -> bool {
+ const DB_FRAME: FlowFrame = Frame::DataBlocked { data_limit: 0 };
+
+ if new > self.max_data {
+ self.max_data = new;
+ self.from_conn.remove(&mem::discriminant(&DB_FRAME));
+
+ true
+ } else {
+ false
+ }
+ }
+
+ // -- frames scoped on connection --
+
+ pub fn data_blocked(&mut self) {
+ let frame = Frame::DataBlocked {
+ data_limit: self.max_data,
+ };
+ self.from_conn.insert(mem::discriminant(&frame), frame);
+ }
+
+ pub fn path_response(&mut self, data: [u8; 8]) {
+ let frame = Frame::PathResponse { data };
+ self.from_conn.insert(mem::discriminant(&frame), frame);
+ }
+
+ pub fn max_data(&mut self, maximum_data: u64) {
+ let frame = Frame::MaxData { maximum_data };
+ self.from_conn.insert(mem::discriminant(&frame), frame);
+ }
+
+ // -- frames scoped on stream --
+
+ /// Indicate to receiving remote the stream is reset
+ pub fn stream_reset(
+ &mut self,
+ stream_id: StreamId,
+ application_error_code: AppError,
+ final_size: u64,
+ ) {
+ let frame = Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ final_size,
+ };
+ self.from_streams
+ .insert((stream_id, mem::discriminant(&frame)), frame);
+ }
+
+ /// Indicate to sending remote we are no longer interested in the stream
+ pub fn stop_sending(&mut self, stream_id: StreamId, application_error_code: AppError) {
+ let frame = Frame::StopSending {
+ stream_id,
+ application_error_code,
+ };
+ self.from_streams
+ .insert((stream_id, mem::discriminant(&frame)), frame);
+ }
+
+ /// Update sending remote with more credits
+ pub fn max_stream_data(&mut self, stream_id: StreamId, maximum_stream_data: u64) {
+ let frame = Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data,
+ };
+ self.from_streams
+ .insert((stream_id, mem::discriminant(&frame)), frame);
+ }
+
+ /// Don't send stream data updates if no more data is coming
+ pub fn clear_max_stream_data(&mut self, stream_id: StreamId) {
+ let frame = Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data: 0,
+ };
+ self.from_streams
+ .remove(&(stream_id, mem::discriminant(&frame)));
+ }
+
+ /// Indicate to receiving remote we need more credits
+ pub fn stream_data_blocked(&mut self, stream_id: StreamId, stream_data_limit: u64) {
+ let frame = Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit,
+ };
+ self.from_streams
+ .insert((stream_id, mem::discriminant(&frame)), frame);
+ }
+
+ // -- frames scoped on stream type --
+
+ pub fn max_streams(&mut self, stream_limit: StreamIndex, stream_type: StreamType) {
+ let frame = Frame::MaxStreams {
+ stream_type,
+ maximum_streams: stream_limit,
+ };
+ self.from_stream_types
+ .insert((stream_type, mem::discriminant(&frame)), frame);
+ }
+
+ pub fn streams_blocked(&mut self, stream_limit: StreamIndex, stream_type: StreamType) {
+ let frame = Frame::StreamsBlocked {
+ stream_type,
+ stream_limit,
+ };
+ self.from_stream_types
+ .insert((stream_type, mem::discriminant(&frame)), frame);
+ }
+
+ pub fn peek(&self) -> Option<&Frame> {
+ if let Some(key) = self.from_conn.keys().next() {
+ self.from_conn.get(key)
+ } else if let Some(key) = self.from_streams.keys().next() {
+ self.from_streams.get(key)
+ } else if let Some(key) = self.from_stream_types.keys().next() {
+ self.from_stream_types.get(key)
+ } else {
+ None
+ }
+ }
+
+ pub(crate) fn acked(
+ &mut self,
+ token: &FlowControlRecoveryToken,
+ send_streams: &mut SendStreams,
+ ) {
+ const RESET_STREAM: &Frame = &Frame::ResetStream {
+ stream_id: StreamId::new(0),
+ application_error_code: 0,
+ final_size: 0,
+ };
+
+ if let Frame::ResetStream { stream_id, .. } = token {
+ qinfo!("Reset received stream={}", stream_id.as_u64());
+
+ if self
+ .from_streams
+ .remove(&(*stream_id, mem::discriminant(RESET_STREAM)))
+ .is_some()
+ {
+ qinfo!("Removed RESET_STREAM frame for {}", stream_id.as_u64());
+ }
+
+ send_streams.reset_acked(*stream_id);
+ }
+ }
+
+ pub(crate) fn lost(
+ &mut self,
+ token: &FlowControlRecoveryToken,
+ send_streams: &mut SendStreams,
+ recv_streams: &mut RecvStreams,
+ indexes: &mut StreamIndexes,
+ ) {
+ match *token {
+ // Always resend ResetStream if lost
+ Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ final_size,
+ } => {
+ qinfo!(
+ "Reset lost stream={} err={} final_size={}",
+ stream_id.as_u64(),
+ application_error_code,
+ final_size
+ );
+ if send_streams.get(stream_id).is_ok() {
+ self.stream_reset(stream_id, application_error_code, final_size);
+ }
+ }
+ // Resend MaxStreams if lost (with updated value)
+ Frame::MaxStreams { stream_type, .. } => {
+ let local_max = match stream_type {
+ StreamType::BiDi => &mut indexes.local_max_stream_bidi,
+ StreamType::UniDi => &mut indexes.local_max_stream_uni,
+ };
+
+ self.max_streams(*local_max, stream_type)
+ }
+ // Only resend "*Blocked" frames if still blocked
+ Frame::DataBlocked { .. } => {
+ if self.conn_credit_avail() == 0 {
+ self.data_blocked()
+ }
+ }
+ Frame::StreamDataBlocked { stream_id, .. } => {
+ if let Ok(ss) = send_streams.get(stream_id) {
+ if ss.credit_avail() == 0 {
+ self.stream_data_blocked(stream_id, ss.max_stream_data())
+ }
+ }
+ }
+ Frame::StreamsBlocked { stream_type, .. } => match stream_type {
+ StreamType::UniDi => {
+ if indexes.remote_next_stream_uni >= indexes.remote_max_stream_uni {
+ self.streams_blocked(indexes.remote_max_stream_uni, StreamType::UniDi);
+ }
+ }
+ StreamType::BiDi => {
+ if indexes.remote_next_stream_bidi >= indexes.remote_max_stream_bidi {
+ self.streams_blocked(indexes.remote_max_stream_bidi, StreamType::BiDi);
+ }
+ }
+ },
+ // Resend StopSending
+ Frame::StopSending {
+ stream_id,
+ application_error_code,
+ } => self.stop_sending(stream_id, application_error_code),
+ Frame::MaxStreamData { stream_id, .. } => {
+ if let Some(rs) = recv_streams.get_mut(&stream_id) {
+ if let Some(msd) = rs.max_stream_data() {
+ self.max_stream_data(stream_id, msd)
+ }
+ }
+ }
+ Frame::PathResponse { .. } => qinfo!("Path Response lost, not re-sent"),
+ _ => qwarn!("Unexpected Flow frame {:?} lost, not re-sent", token),
+ }
+ }
+
+ pub(crate) fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ while let Some(frame) = self.peek() {
+ // All these frames are bags of varints, so we can just extract the
+ // varints and use common code for writing.
+ let values: SmallVec<[_; 3]> = match frame {
+ Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ final_size,
+ } => {
+ stats.reset_stream += 1;
+ smallvec![stream_id.as_u64(), *application_error_code, *final_size]
+ }
+ Frame::StopSending {
+ stream_id,
+ application_error_code,
+ } => {
+ stats.stop_sending += 1;
+ smallvec![stream_id.as_u64(), *application_error_code]
+ }
+
+ Frame::MaxStreams {
+ maximum_streams, ..
+ } => {
+ stats.max_streams += 1;
+ smallvec![maximum_streams.as_u64()]
+ }
+ Frame::StreamsBlocked { stream_limit, .. } => {
+ stats.streams_blocked += 1;
+ smallvec![stream_limit.as_u64()]
+ }
+
+ Frame::MaxData { maximum_data } => {
+ stats.max_data += 1;
+ smallvec![*maximum_data]
+ }
+ Frame::DataBlocked { data_limit } => {
+ stats.data_blocked += 1;
+ smallvec![*data_limit]
+ }
+
+ Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data,
+ } => {
+ stats.max_stream_data += 1;
+ smallvec![stream_id.as_u64(), *maximum_stream_data]
+ }
+ Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit,
+ } => {
+ stats.stream_data_blocked += 1;
+ smallvec![stream_id.as_u64(), *stream_data_limit]
+ }
+
+ // A special case, just write it out and move on..
+ Frame::PathResponse { data } => {
+ stats.path_response += 1;
+ if builder.remaining() >= Encoder::varint_len(frame.get_type()) + data.len() {
+ builder.encode_varint(frame.get_type());
+ builder.encode(data);
+ tokens.push(RecoveryToken::Flow(self.next().unwrap()));
+ continue;
+ } else {
+ return;
+ }
+ }
+
+ _ => unreachable!("{:?}", frame),
+ };
+ debug_assert!(!values.spilled());
+
+ if builder.remaining()
+ >= Encoder::varint_len(frame.get_type())
+ + values
+ .iter()
+ .map(|&v| Encoder::varint_len(v))
+ .sum::<usize>()
+ {
+ builder.encode_varint(frame.get_type());
+ for v in values {
+ builder.encode_varint(v);
+ }
+ tokens.push(RecoveryToken::Flow(self.next().unwrap()));
+ } else {
+ return;
+ }
+ }
+ }
+}
+
+impl Iterator for FlowMgr {
+ type Item = FlowFrame;
+
+ /// Used by generator to get a flow control frame.
+ fn next(&mut self) -> Option<Self::Item> {
+ let first_key = self.from_conn.keys().next();
+ if let Some(&first_key) = first_key {
+ return self.from_conn.remove(&first_key);
+ }
+
+ let first_key = self.from_streams.keys().next();
+ if let Some(&first_key) = first_key {
+ return self.from_streams.remove(&first_key);
+ }
+
+ let first_key = self.from_stream_types.keys().next();
+ if let Some(&first_key) = first_key {
+ return self.from_stream_types.remove(&first_key);
+ }
+
+ None
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs
new file mode 100644
index 0000000000..5f3e7c9b78
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/frame.rs
@@ -0,0 +1,835 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Directly relating to QUIC frames.
+
+use neqo_common::{qtrace, Decoder};
+
+use crate::cid::MAX_CONNECTION_ID_LEN;
+use crate::packet::PacketType;
+use crate::stream_id::{StreamId, StreamIndex};
+use crate::{AppError, ConnectionError, Error, Res, TransportError, ERROR_APPLICATION_CLOSE};
+
+use std::convert::TryFrom;
+use std::ops::RangeInclusive;
+
+#[allow(clippy::module_name_repetitions)]
+pub type FrameType = u64;
+
+const FRAME_TYPE_PADDING: FrameType = 0x0;
+pub const FRAME_TYPE_PING: FrameType = 0x1;
+pub const FRAME_TYPE_ACK: FrameType = 0x2;
+const FRAME_TYPE_ACK_ECN: FrameType = 0x3;
+const FRAME_TYPE_RST_STREAM: FrameType = 0x4;
+const FRAME_TYPE_STOP_SENDING: FrameType = 0x5;
+pub const FRAME_TYPE_CRYPTO: FrameType = 0x6;
+pub const FRAME_TYPE_NEW_TOKEN: FrameType = 0x7;
+const FRAME_TYPE_STREAM: FrameType = 0x8;
+const FRAME_TYPE_STREAM_MAX: FrameType = 0xf;
+const FRAME_TYPE_MAX_DATA: FrameType = 0x10;
+const FRAME_TYPE_MAX_STREAM_DATA: FrameType = 0x11;
+const FRAME_TYPE_MAX_STREAMS_BIDI: FrameType = 0x12;
+const FRAME_TYPE_MAX_STREAMS_UNIDI: FrameType = 0x13;
+const FRAME_TYPE_DATA_BLOCKED: FrameType = 0x14;
+const FRAME_TYPE_STREAM_DATA_BLOCKED: FrameType = 0x15;
+const FRAME_TYPE_STREAMS_BLOCKED_BIDI: FrameType = 0x16;
+const FRAME_TYPE_STREAMS_BLOCKED_UNIDI: FrameType = 0x17;
+const FRAME_TYPE_NEW_CONNECTION_ID: FrameType = 0x18;
+const FRAME_TYPE_RETIRE_CONNECTION_ID: FrameType = 0x19;
+const FRAME_TYPE_PATH_CHALLENGE: FrameType = 0x1a;
+const FRAME_TYPE_PATH_RESPONSE: FrameType = 0x1b;
+pub const FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT: FrameType = 0x1c;
+pub const FRAME_TYPE_CONNECTION_CLOSE_APPLICATION: FrameType = 0x1d;
+const FRAME_TYPE_HANDSHAKE_DONE: FrameType = 0x1e;
+
+const STREAM_FRAME_BIT_FIN: u64 = 0x01;
+const STREAM_FRAME_BIT_LEN: u64 = 0x02;
+const STREAM_FRAME_BIT_OFF: u64 = 0x04;
+
+/// `FRAME_APPLICATION_CLOSE` is the default CONNECTION_CLOSE frame that
+/// is sent when an application error code needs to be sent in an
+/// Initial or Handshake packet.
+const FRAME_APPLICATION_CLOSE: &Frame = &Frame::ConnectionClose {
+ error_code: CloseError::Transport(ERROR_APPLICATION_CLOSE),
+ frame_type: 0,
+ reason_phrase: Vec::new(),
+};
+
+#[derive(PartialEq, Debug, Copy, Clone, PartialOrd, Eq, Ord, Hash)]
+/// Bi-Directional or Uni-Directional.
+pub enum StreamType {
+ BiDi,
+ UniDi,
+}
+
+impl StreamType {
+ fn frame_type_bit(self) -> u64 {
+ match self {
+ Self::BiDi => 0,
+ Self::UniDi => 1,
+ }
+ }
+ fn from_type_bit(bit: u64) -> Self {
+ if (bit & 0x01) == 0 {
+ Self::BiDi
+ } else {
+ Self::UniDi
+ }
+ }
+}
+
+#[derive(PartialEq, Eq, Debug, PartialOrd, Ord, Clone, Copy)]
+pub enum CloseError {
+ Transport(TransportError),
+ Application(AppError),
+}
+
+impl CloseError {
+ fn frame_type_bit(self) -> u64 {
+ match self {
+ Self::Transport(_) => 0,
+ Self::Application(_) => 1,
+ }
+ }
+
+ fn from_type_bit(bit: u64, code: u64) -> Self {
+ if (bit & 0x01) == 0 {
+ Self::Transport(code)
+ } else {
+ Self::Application(code)
+ }
+ }
+
+ pub fn code(&self) -> u64 {
+ match self {
+ Self::Transport(c) | Self::Application(c) => *c,
+ }
+ }
+}
+
+impl From<ConnectionError> for CloseError {
+ fn from(err: ConnectionError) -> Self {
+ match err {
+ ConnectionError::Transport(c) => Self::Transport(c.code()),
+ ConnectionError::Application(c) => Self::Application(c),
+ }
+ }
+}
+
+#[derive(PartialEq, Debug, Default, Clone)]
+pub struct AckRange {
+ pub(crate) gap: u64,
+ pub(crate) range: u64,
+}
+
+#[derive(PartialEq, Debug, Clone)]
+pub enum Frame<'a> {
+ Padding,
+ Ping,
+ Ack {
+ largest_acknowledged: u64,
+ ack_delay: u64,
+ first_ack_range: u64,
+ ack_ranges: Vec<AckRange>,
+ },
+ ResetStream {
+ stream_id: StreamId,
+ application_error_code: AppError,
+ final_size: u64,
+ },
+ StopSending {
+ stream_id: StreamId,
+ application_error_code: AppError,
+ },
+ Crypto {
+ offset: u64,
+ data: &'a [u8],
+ },
+ NewToken {
+ token: &'a [u8],
+ },
+ Stream {
+ stream_id: StreamId,
+ offset: u64,
+ data: &'a [u8],
+ fin: bool,
+ fill: bool,
+ },
+ MaxData {
+ maximum_data: u64,
+ },
+ MaxStreamData {
+ stream_id: StreamId,
+ maximum_stream_data: u64,
+ },
+ MaxStreams {
+ stream_type: StreamType,
+ maximum_streams: StreamIndex,
+ },
+ DataBlocked {
+ data_limit: u64,
+ },
+ StreamDataBlocked {
+ stream_id: StreamId,
+ stream_data_limit: u64,
+ },
+ StreamsBlocked {
+ stream_type: StreamType,
+ stream_limit: StreamIndex,
+ },
+ NewConnectionId {
+ sequence_number: u64,
+ retire_prior: u64,
+ connection_id: &'a [u8],
+ stateless_reset_token: &'a [u8; 16],
+ },
+ RetireConnectionId {
+ sequence_number: u64,
+ },
+ PathChallenge {
+ data: [u8; 8],
+ },
+ PathResponse {
+ data: [u8; 8],
+ },
+ ConnectionClose {
+ error_code: CloseError,
+ frame_type: u64,
+ // Not a reference as we use this to hold the value.
+ // This is not used in optimized builds anyway.
+ reason_phrase: Vec<u8>,
+ },
+ HandshakeDone,
+}
+
+impl<'a> Frame<'a> {
+ pub fn get_type(&self) -> FrameType {
+ match self {
+ Self::Padding => FRAME_TYPE_PADDING,
+ Self::Ping => FRAME_TYPE_PING,
+ Self::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN.
+ Self::ResetStream { .. } => FRAME_TYPE_RST_STREAM,
+ Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING,
+ Self::Crypto { .. } => FRAME_TYPE_CRYPTO,
+ Self::NewToken { .. } => FRAME_TYPE_NEW_TOKEN,
+ Self::Stream {
+ fin, offset, fill, ..
+ } => Self::stream_type(*fin, *offset > 0, *fill),
+ Self::MaxData { .. } => FRAME_TYPE_MAX_DATA,
+ Self::MaxStreamData { .. } => FRAME_TYPE_MAX_STREAM_DATA,
+ Self::MaxStreams { stream_type, .. } => {
+ FRAME_TYPE_MAX_STREAMS_BIDI + stream_type.frame_type_bit()
+ }
+ Self::DataBlocked { .. } => FRAME_TYPE_DATA_BLOCKED,
+ Self::StreamDataBlocked { .. } => FRAME_TYPE_STREAM_DATA_BLOCKED,
+ Self::StreamsBlocked { stream_type, .. } => {
+ FRAME_TYPE_STREAMS_BLOCKED_BIDI + stream_type.frame_type_bit()
+ }
+ Self::NewConnectionId { .. } => FRAME_TYPE_NEW_CONNECTION_ID,
+ Self::RetireConnectionId { .. } => FRAME_TYPE_RETIRE_CONNECTION_ID,
+ Self::PathChallenge { .. } => FRAME_TYPE_PATH_CHALLENGE,
+ Self::PathResponse { .. } => FRAME_TYPE_PATH_RESPONSE,
+ Self::ConnectionClose { error_code, .. } => {
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT + error_code.frame_type_bit()
+ }
+ Self::HandshakeDone => FRAME_TYPE_HANDSHAKE_DONE,
+ }
+ }
+
+ pub fn stream_type(fin: bool, nonzero_offset: bool, fill: bool) -> u64 {
+ let mut t = FRAME_TYPE_STREAM;
+ if fin {
+ t |= STREAM_FRAME_BIT_FIN;
+ }
+ if nonzero_offset {
+ t |= STREAM_FRAME_BIT_OFF;
+ }
+ if !fill {
+ t |= STREAM_FRAME_BIT_LEN;
+ }
+ t
+ }
+
+ /// Convert a CONNECTION_CLOSE into a nicer CONNECTION_CLOSE.
+ pub fn sanitize_close(&self) -> &Self {
+ if let Self::ConnectionClose { error_code, .. } = &self {
+ if let CloseError::Application(_) = error_code {
+ FRAME_APPLICATION_CLOSE
+ } else {
+ self
+ }
+ } else {
+ panic!("Attempted to sanitize a non-close frame");
+ }
+ }
+
+ pub fn ack_eliciting(&self) -> bool {
+ !matches!(self, Self::Ack { .. } | Self::Padding | Self::ConnectionClose { .. })
+ }
+
+ /// Converts AckRanges as encoded in a ACK frame (see -transport
+ /// 19.3.1) into ranges of acked packets (end, start), inclusive of
+ /// start and end values.
+ pub fn decode_ack_frame(
+ largest_acked: u64,
+ first_ack_range: u64,
+ ack_ranges: &[AckRange],
+ ) -> Res<Vec<RangeInclusive<u64>>> {
+ let mut acked_ranges = Vec::with_capacity(ack_ranges.len() + 1);
+
+ if largest_acked < first_ack_range {
+ return Err(Error::FrameEncodingError);
+ }
+ acked_ranges.push((largest_acked - first_ack_range)..=largest_acked);
+ if !ack_ranges.is_empty() && largest_acked < first_ack_range + 1 {
+ return Err(Error::FrameEncodingError);
+ }
+ let mut cur = if ack_ranges.is_empty() {
+ 0
+ } else {
+ largest_acked - first_ack_range - 1
+ };
+ for r in ack_ranges {
+ if cur < r.gap + 1 {
+ return Err(Error::FrameEncodingError);
+ }
+ cur = cur - r.gap - 1;
+
+ if cur < r.range {
+ return Err(Error::FrameEncodingError);
+ }
+ acked_ranges.push((cur - r.range)..=cur);
+
+ if cur > r.range + 1 {
+ cur -= r.range + 1;
+ } else {
+ cur -= r.range;
+ }
+ }
+
+ Ok(acked_ranges)
+ }
+
+ pub fn dump(&self) -> Option<String> {
+ match self {
+ Self::Crypto { offset, data } => Some(format!(
+ "Crypto {{ offset: {}, len: {} }}",
+ offset,
+ data.len()
+ )),
+ Self::Stream {
+ stream_id,
+ offset,
+ fill,
+ data,
+ fin,
+ } => Some(format!(
+ "Stream {{ stream_id: {}, offset: {}, len: {}{}, fin: {} }}",
+ stream_id.as_u64(),
+ offset,
+ if *fill { ">>" } else { "" },
+ data.len(),
+ fin,
+ )),
+ Self::Padding => None,
+ _ => Some(format!("{:?}", self)),
+ }
+ }
+
+ pub fn is_allowed(&self, pt: PacketType) -> bool {
+ match self {
+ Self::Padding | Self::Ping => true,
+ Self::Crypto { .. }
+ | Self::Ack { .. }
+ | Self::ConnectionClose {
+ error_code: CloseError::Transport(_),
+ ..
+ } => pt != PacketType::ZeroRtt,
+ Self::NewToken { .. } | Self::ConnectionClose { .. } => pt == PacketType::Short,
+ _ => pt == PacketType::ZeroRtt || pt == PacketType::Short,
+ }
+ }
+
+ pub fn decode(dec: &mut Decoder<'a>) -> Res<Self> {
+ fn d<T>(v: Option<T>) -> Res<T> {
+ v.ok_or(Error::NoMoreData)
+ }
+ fn dv(dec: &mut Decoder) -> Res<u64> {
+ d(dec.decode_varint())
+ }
+
+ // TODO(ekr@rtfm.com): check for minimal encoding
+ let t = d(dec.decode_varint())?;
+ match t {
+ FRAME_TYPE_PADDING => Ok(Self::Padding),
+ FRAME_TYPE_PING => Ok(Self::Ping),
+ FRAME_TYPE_RST_STREAM => Ok(Self::ResetStream {
+ stream_id: StreamId::from(dv(dec)?),
+ application_error_code: d(dec.decode_varint())?,
+ final_size: match dec.decode_varint() {
+ Some(v) => v,
+ _ => return Err(Error::NoMoreData),
+ },
+ }),
+ FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => {
+ let la = dv(dec)?;
+ let ad = dv(dec)?;
+ let nr = dv(dec)?;
+ let fa = dv(dec)?;
+ let mut arr: Vec<AckRange> = Vec::with_capacity(nr as usize);
+ for _ in 0..nr {
+ let ar = AckRange {
+ gap: dv(dec)?,
+ range: dv(dec)?,
+ };
+ arr.push(ar);
+ }
+
+ // Now check for the values for ACK_ECN.
+ if t == FRAME_TYPE_ACK_ECN {
+ dv(dec)?;
+ dv(dec)?;
+ dv(dec)?;
+ }
+
+ Ok(Self::Ack {
+ largest_acknowledged: la,
+ ack_delay: ad,
+ first_ack_range: fa,
+ ack_ranges: arr,
+ })
+ }
+ FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending {
+ stream_id: StreamId::from(dv(dec)?),
+ application_error_code: d(dec.decode_varint())?,
+ }),
+ FRAME_TYPE_CRYPTO => {
+ let offset = dv(dec)?;
+ let data = d(dec.decode_vvec())?;
+ if offset + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) {
+ return Err(Error::FrameEncodingError);
+ }
+ Ok(Self::Crypto { offset, data })
+ }
+ FRAME_TYPE_NEW_TOKEN => {
+ let token = d(dec.decode_vvec())?;
+ if token.is_empty() {
+ return Err(Error::FrameEncodingError);
+ }
+ Ok(Self::NewToken { token })
+ }
+ FRAME_TYPE_STREAM..=FRAME_TYPE_STREAM_MAX => {
+ let s = dv(dec)?;
+ let o = if t & STREAM_FRAME_BIT_OFF == 0 {
+ 0
+ } else {
+ dv(dec)?
+ };
+ let fill = (t & STREAM_FRAME_BIT_LEN) == 0;
+ let data = if fill {
+ qtrace!("STREAM frame, extends to the end of the packet");
+ dec.decode_remainder()
+ } else {
+ qtrace!("STREAM frame, with length");
+ d(dec.decode_vvec())?
+ };
+ if o + u64::try_from(data.len()).unwrap() > ((1 << 62) - 1) {
+ return Err(Error::FrameEncodingError);
+ }
+ Ok(Self::Stream {
+ fin: (t & STREAM_FRAME_BIT_FIN) != 0,
+ stream_id: StreamId::from(s),
+ offset: o,
+ data,
+ fill,
+ })
+ }
+ FRAME_TYPE_MAX_DATA => Ok(Self::MaxData {
+ maximum_data: dv(dec)?,
+ }),
+ FRAME_TYPE_MAX_STREAM_DATA => Ok(Self::MaxStreamData {
+ stream_id: StreamId::from(dv(dec)?),
+ maximum_stream_data: dv(dec)?,
+ }),
+ FRAME_TYPE_MAX_STREAMS_BIDI | FRAME_TYPE_MAX_STREAMS_UNIDI => {
+ let m = dv(dec)?;
+ if m > (1 << 60) {
+ return Err(Error::StreamLimitError);
+ }
+ Ok(Self::MaxStreams {
+ stream_type: StreamType::from_type_bit(t),
+ maximum_streams: StreamIndex::new(m),
+ })
+ }
+ FRAME_TYPE_DATA_BLOCKED => Ok(Self::DataBlocked {
+ data_limit: dv(dec)?,
+ }),
+ FRAME_TYPE_STREAM_DATA_BLOCKED => Ok(Self::StreamDataBlocked {
+ stream_id: dv(dec)?.into(),
+ stream_data_limit: dv(dec)?,
+ }),
+ FRAME_TYPE_STREAMS_BLOCKED_BIDI | FRAME_TYPE_STREAMS_BLOCKED_UNIDI => {
+ Ok(Self::StreamsBlocked {
+ stream_type: StreamType::from_type_bit(t),
+ stream_limit: StreamIndex::new(dv(dec)?),
+ })
+ }
+ FRAME_TYPE_NEW_CONNECTION_ID => {
+ let sequence_number = dv(dec)?;
+ let retire_prior = dv(dec)?;
+ let connection_id = d(dec.decode_vec(1))?;
+ if connection_id.len() > MAX_CONNECTION_ID_LEN {
+ return Err(Error::DecodingFrame);
+ }
+ let srt = d(dec.decode(16))?;
+ let stateless_reset_token = <&[_; 16]>::try_from(srt).unwrap();
+
+ Ok(Self::NewConnectionId {
+ sequence_number,
+ retire_prior,
+ connection_id,
+ stateless_reset_token,
+ })
+ }
+ FRAME_TYPE_RETIRE_CONNECTION_ID => Ok(Self::RetireConnectionId {
+ sequence_number: dv(dec)?,
+ }),
+ FRAME_TYPE_PATH_CHALLENGE => {
+ let data = d(dec.decode(8))?;
+ let mut datav: [u8; 8] = [0; 8];
+ datav.copy_from_slice(&data);
+ Ok(Self::PathChallenge { data: datav })
+ }
+ FRAME_TYPE_PATH_RESPONSE => {
+ let data = d(dec.decode(8))?;
+ let mut datav: [u8; 8] = [0; 8];
+ datav.copy_from_slice(&data);
+ Ok(Self::PathResponse { data: datav })
+ }
+ FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT | FRAME_TYPE_CONNECTION_CLOSE_APPLICATION => {
+ let error_code = CloseError::from_type_bit(t, d(dec.decode_varint())?);
+ let frame_type = if t == FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT {
+ dv(dec)?
+ } else {
+ 0
+ };
+ // We can tolerate this copy for now.
+ let reason_phrase = d(dec.decode_vvec())?.to_vec();
+ Ok(Self::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ })
+ }
+ FRAME_TYPE_HANDSHAKE_DONE => Ok(Self::HandshakeDone),
+ _ => Err(Error::UnknownFrameType),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use neqo_common::{Decoder, Encoder};
+
+ fn just_dec(f: &Frame, s: &str) {
+ let encoded = Encoder::from_hex(s);
+ let decoded = Frame::decode(&mut encoded.as_decoder()).unwrap();
+ assert_eq!(*f, decoded);
+ }
+
+ #[test]
+ fn padding() {
+ let f = Frame::Padding;
+ just_dec(&f, "00");
+ }
+
+ #[test]
+ fn ping() {
+ let f = Frame::Ping;
+ just_dec(&f, "01");
+ }
+
+ #[test]
+ fn ack() {
+ let ar = vec![AckRange { gap: 1, range: 2 }, AckRange { gap: 3, range: 4 }];
+
+ let f = Frame::Ack {
+ largest_acknowledged: 0x1234,
+ ack_delay: 0x1235,
+ first_ack_range: 0x1236,
+ ack_ranges: ar,
+ };
+
+ just_dec(&f, "025234523502523601020304");
+
+ // Try to parse ACK_ECN without ECN values
+ let enc = Encoder::from_hex("035234523502523601020304");
+ let mut dec = enc.as_decoder();
+ assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData);
+
+ // Try to parse ACK_ECN without ECN values
+ let enc = Encoder::from_hex("035234523502523601020304010203");
+ let mut dec = enc.as_decoder();
+ assert_eq!(Frame::decode(&mut dec).unwrap(), f);
+ }
+
+ #[test]
+ fn reset_stream() {
+ let f = Frame::ResetStream {
+ stream_id: StreamId::from(0x1234),
+ application_error_code: 0x77,
+ final_size: 0x3456,
+ };
+
+ just_dec(&f, "04523440777456");
+ }
+
+ #[test]
+ fn stop_sending() {
+ let f = Frame::StopSending {
+ stream_id: StreamId::from(63),
+ application_error_code: 0x77,
+ };
+
+ just_dec(&f, "053F4077")
+ }
+
+ #[test]
+ fn crypto() {
+ let f = Frame::Crypto {
+ offset: 1,
+ data: &[1, 2, 3],
+ };
+
+ just_dec(&f, "060103010203");
+ }
+
+ #[test]
+ fn new_token() {
+ let f = Frame::NewToken {
+ token: &[0x12, 0x34, 0x56],
+ };
+
+ just_dec(&f, "0703123456");
+ }
+
+ #[test]
+ fn empty_new_token() {
+ let mut dec = Decoder::from(&[0x07, 0x00][..]);
+ assert_eq!(
+ Frame::decode(&mut dec).unwrap_err(),
+ Error::FrameEncodingError
+ );
+ }
+
+ #[test]
+ fn stream() {
+ // First, just set the length bit.
+ let f = Frame::Stream {
+ fin: false,
+ stream_id: StreamId::from(5),
+ offset: 0,
+ data: &[1, 2, 3],
+ fill: false,
+ };
+
+ just_dec(&f, "0a0503010203");
+
+ // Now with offset != 0 and FIN
+ let f = Frame::Stream {
+ fin: true,
+ stream_id: StreamId::from(5),
+ offset: 1,
+ data: &[1, 2, 3],
+ fill: false,
+ };
+ just_dec(&f, "0f050103010203");
+
+ // Now to fill the packet.
+ let f = Frame::Stream {
+ fin: true,
+ stream_id: StreamId::from(5),
+ offset: 0,
+ data: &[1, 2, 3],
+ fill: true,
+ };
+ just_dec(&f, "0905010203");
+ }
+
+ #[test]
+ fn max_data() {
+ let f = Frame::MaxData {
+ maximum_data: 0x1234,
+ };
+
+ just_dec(&f, "105234");
+ }
+
+ #[test]
+ fn max_stream_data() {
+ let f = Frame::MaxStreamData {
+ stream_id: StreamId::from(5),
+ maximum_stream_data: 0x1234,
+ };
+
+ just_dec(&f, "11055234");
+ }
+
+ #[test]
+ fn max_streams() {
+ let mut f = Frame::MaxStreams {
+ stream_type: StreamType::BiDi,
+ maximum_streams: StreamIndex::new(0x1234),
+ };
+
+ just_dec(&f, "125234");
+
+ f = Frame::MaxStreams {
+ stream_type: StreamType::UniDi,
+ maximum_streams: StreamIndex::new(0x1234),
+ };
+
+ just_dec(&f, "135234");
+ }
+
+ #[test]
+ fn data_blocked() {
+ let f = Frame::DataBlocked { data_limit: 0x1234 };
+
+ just_dec(&f, "145234");
+ }
+
+ #[test]
+ fn stream_data_blocked() {
+ let f = Frame::StreamDataBlocked {
+ stream_id: StreamId::from(5),
+ stream_data_limit: 0x1234,
+ };
+
+ just_dec(&f, "15055234");
+ }
+
+ #[test]
+ fn streams_blocked() {
+ let mut f = Frame::StreamsBlocked {
+ stream_type: StreamType::BiDi,
+ stream_limit: StreamIndex::new(0x1234),
+ };
+
+ just_dec(&f, "165234");
+
+ f = Frame::StreamsBlocked {
+ stream_type: StreamType::UniDi,
+ stream_limit: StreamIndex::new(0x1234),
+ };
+
+ just_dec(&f, "175234");
+ }
+
+ #[test]
+ fn new_connection_id() {
+ let f = Frame::NewConnectionId {
+ sequence_number: 0x1234,
+ retire_prior: 0,
+ connection_id: &[0x01, 0x02],
+ stateless_reset_token: &[9; 16],
+ };
+
+ just_dec(&f, "1852340002010209090909090909090909090909090909");
+ }
+
+ #[test]
+ fn too_large_new_connection_id() {
+ let mut enc = Encoder::from_hex("18523400"); // up to the CID
+ enc.encode_vvec(&[0x0c; MAX_CONNECTION_ID_LEN + 10]);
+ enc.encode(&[0x11; 16][..]);
+ assert_eq!(
+ Frame::decode(&mut enc.as_decoder()).unwrap_err(),
+ Error::DecodingFrame
+ );
+ }
+
+ #[test]
+ fn retire_connection_id() {
+ let f = Frame::RetireConnectionId {
+ sequence_number: 0x1234,
+ };
+
+ just_dec(&f, "195234");
+ }
+
+ #[test]
+ fn path_challenge() {
+ let f = Frame::PathChallenge { data: [9; 8] };
+
+ just_dec(&f, "1a0909090909090909");
+ }
+
+ #[test]
+ fn path_response() {
+ let f = Frame::PathResponse { data: [9; 8] };
+
+ just_dec(&f, "1b0909090909090909");
+ }
+
+ #[test]
+ fn connection_close_transport() {
+ let f = Frame::ConnectionClose {
+ error_code: CloseError::Transport(0x5678),
+ frame_type: 0x1234,
+ reason_phrase: vec![0x01, 0x02, 0x03],
+ };
+
+ just_dec(&f, "1c80005678523403010203");
+ }
+
+ #[test]
+ fn connection_close_application() {
+ let f = Frame::ConnectionClose {
+ error_code: CloseError::Application(0x5678),
+ frame_type: 0,
+ reason_phrase: vec![0x01, 0x02, 0x03],
+ };
+
+ just_dec(&f, "1d8000567803010203");
+ }
+
+ #[test]
+ fn test_compare() {
+ let f1 = Frame::Padding;
+ let f2 = Frame::Padding;
+ let f3 = Frame::Crypto {
+ offset: 0,
+ data: &[1, 2, 3],
+ };
+ let f4 = Frame::Crypto {
+ offset: 0,
+ data: &[1, 2, 3],
+ };
+ let f5 = Frame::Crypto {
+ offset: 1,
+ data: &[1, 2, 3],
+ };
+ let f6 = Frame::Crypto {
+ offset: 0,
+ data: &[1, 2, 4],
+ };
+
+ assert_eq!(f1, f2);
+ assert_ne!(f1, f3);
+ assert_eq!(f3, f4);
+ assert_ne!(f3, f5);
+ assert_ne!(f3, f6);
+ }
+
+ #[test]
+ fn decode_ack_frame() {
+ let res = Frame::decode_ack_frame(7, 2, &[AckRange { gap: 0, range: 3 }]);
+ assert!(res.is_ok());
+ assert_eq!(res.unwrap(), vec![5..=7, 0..=3]);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/lib.rs b/third_party/rust/neqo-transport/src/lib.rs
new file mode 100644
index 0000000000..30c9e783af
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/lib.rs
@@ -0,0 +1,195 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![cfg_attr(feature = "deny-warnings", deny(warnings))]
+#![warn(clippy::use_self)]
+
+use neqo_common::qinfo;
+
+mod addr_valid;
+mod cc;
+mod cid;
+mod connection;
+mod crypto;
+mod dump;
+mod events;
+mod flow_mgr;
+mod frame;
+mod pace;
+mod packet;
+mod path;
+mod qlog;
+mod recovery;
+mod recv_stream;
+mod send_stream;
+mod sender;
+pub mod server;
+mod stats;
+mod stream_id;
+pub mod tparams;
+mod tracking;
+
+pub use self::cc::CongestionControlAlgorithm;
+pub use self::cid::{ConnectionId, ConnectionIdManager};
+pub use self::connection::{
+ params::ConnectionParameters, Connection, FixedConnectionIdManager, Output, State,
+ ZeroRttState, LOCAL_STREAM_LIMIT_BIDI, LOCAL_STREAM_LIMIT_UNI,
+};
+pub use self::events::{ConnectionEvent, ConnectionEvents};
+pub use self::frame::{CloseError, StreamType};
+pub use self::packet::QuicVersion;
+pub use self::sender::PacketSender;
+pub use self::stats::Stats;
+pub use self::stream_id::StreamId;
+
+pub use self::recv_stream::RECV_BUFFER_SIZE;
+pub use self::send_stream::SEND_BUFFER_SIZE;
+
+pub type TransportError = u64;
+const ERROR_APPLICATION_CLOSE: TransportError = 12;
+const ERROR_AEAD_LIMIT_REACHED: TransportError = 15;
+
+#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
+#[allow(clippy::pub_enum_variant_names)]
+pub enum Error {
+ NoError,
+ InternalError,
+ ConnectionRefused,
+ FlowControlError,
+ StreamLimitError,
+ StreamStateError,
+ FinalSizeError,
+ FrameEncodingError,
+ TransportParameterError,
+ ProtocolViolation,
+ InvalidToken,
+ ApplicationError,
+ CryptoError(neqo_crypto::Error),
+ QlogError,
+ CryptoAlert(u8),
+
+ // All internal errors from here.
+ AckedUnsentPacket,
+ ConnectionState,
+ DecodingFrame,
+ DecryptError,
+ HandshakeFailed,
+ IdleTimeout,
+ IntegerOverflow,
+ InvalidInput,
+ InvalidMigration,
+ InvalidPacket,
+ InvalidResumptionToken,
+ InvalidRetry,
+ InvalidStreamId,
+ KeysDiscarded,
+ /// Packet protection keys are exhausted.
+ /// Also used when too many key updates have happened.
+ KeysExhausted,
+ /// Packet protection keys aren't available yet for the identified space.
+ KeysPending(crypto::CryptoSpace),
+ /// An attempt to update keys can be blocked if
+ /// a packet sent with the current keys hasn't been acknowledged.
+ KeyUpdateBlocked,
+ NoMoreData,
+ NotConnected,
+ PacketNumberOverlap,
+ PeerApplicationError(AppError),
+ PeerError(TransportError),
+ StatelessReset,
+ TooMuchData,
+ UnexpectedMessage,
+ UnknownFrameType,
+ VersionNegotiation,
+ WrongRole,
+}
+
+impl Error {
+ pub fn code(&self) -> TransportError {
+ match self {
+ Self::NoError
+ | Self::IdleTimeout
+ | Self::PeerError(_)
+ | Self::PeerApplicationError(_) => 0,
+ Self::ConnectionRefused => 2,
+ Self::FlowControlError => 3,
+ Self::StreamLimitError => 4,
+ Self::StreamStateError => 5,
+ Self::FinalSizeError => 6,
+ Self::FrameEncodingError => 7,
+ Self::TransportParameterError => 8,
+ Self::ProtocolViolation => 10,
+ Self::InvalidToken => 11,
+ Self::KeysExhausted => ERROR_AEAD_LIMIT_REACHED,
+ Self::ApplicationError => ERROR_APPLICATION_CLOSE,
+ Self::CryptoAlert(a) => 0x100 + u64::from(*a),
+ // All the rest are internal errors.
+ _ => 1,
+ }
+ }
+}
+
+impl From<neqo_crypto::Error> for Error {
+ fn from(err: neqo_crypto::Error) -> Self {
+ qinfo!("Crypto operation failed {:?}", err);
+ Self::CryptoError(err)
+ }
+}
+
+impl From<::qlog::Error> for Error {
+ fn from(_err: ::qlog::Error) -> Self {
+ Self::QlogError
+ }
+}
+
+impl From<std::num::TryFromIntError> for Error {
+ fn from(_: std::num::TryFromIntError) -> Self {
+ Self::IntegerOverflow
+ }
+}
+
+impl ::std::error::Error for Error {
+ fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
+ match self {
+ Self::CryptoError(e) => Some(e),
+ _ => None,
+ }
+ }
+}
+
+impl ::std::fmt::Display for Error {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Transport error: {:?}", self)
+ }
+}
+
+pub type AppError = u64;
+
+#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
+pub enum ConnectionError {
+ Transport(Error),
+ Application(AppError),
+}
+
+impl ConnectionError {
+ pub fn app_code(&self) -> Option<AppError> {
+ match self {
+ Self::Application(e) => Some(*e),
+ _ => None,
+ }
+ }
+}
+
+impl From<CloseError> for ConnectionError {
+ fn from(err: CloseError) -> Self {
+ match err {
+ CloseError::Transport(c) => Self::Transport(Error::PeerError(c)),
+ CloseError::Application(c) => Self::Application(c),
+ }
+ }
+}
+
+pub type Res<T> = std::result::Result<T, Error>;
diff --git a/third_party/rust/neqo-transport/src/pace.rs b/third_party/rust/neqo-transport/src/pace.rs
new file mode 100644
index 0000000000..624fb8622e
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/pace.rs
@@ -0,0 +1,138 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Pacer
+#![deny(clippy::pedantic)]
+
+use neqo_common::qtrace;
+
+use std::cmp::min;
+use std::convert::TryFrom;
+use std::fmt::{Debug, Display};
+use std::time::{Duration, Instant};
+
+/// This value determines how much faster the pacer operates than the
+/// congestion window.
+///
+/// A value of 1 would cause all packets to be spaced over the entire RTT,
+/// which is a little slow and might act as an additional restriction in
+/// the case the congestion controller increases the congestion window.
+/// This value spaces packets over half the congestion window, which matches
+/// our current congestion controller, which double the window every RTT.
+const PACER_SPEEDUP: usize = 2;
+
+/// A pacer that uses a leaky bucket.
+pub struct Pacer {
+ /// The last update time.
+ t: Instant,
+ /// The maximum capacity, or burst size, in bytes.
+ m: usize,
+ /// The current capacity, in bytes.
+ c: usize,
+ /// The packet size or minimum capacity for sending, in bytes.
+ p: usize,
+}
+
+impl Pacer {
+ /// Create a new `Pacer`. This takes the current time and the
+ /// initial congestion window.
+ ///
+ /// The value of `m` is the maximum capacity. `m` primes the pacer
+ /// with credit and determines the burst size. `m` must not exceed
+ /// the initial congestion window, but it should probably be lower.
+ ///
+ /// The value of `p` is the packet size, which determines the minimum
+ /// credit needed before a packet is sent. This should be a substantial
+ /// fraction of the maximum packet size, if not the packet size.
+ pub fn new(now: Instant, m: usize, p: usize) -> Self {
+ assert!(m >= p, "maximum capacity has to be at least one packet");
+ Self { t: now, m, c: m, p }
+ }
+
+ /// Determine when the next packet will be available based on the provided RTT
+ /// and congestion window. This doesn't update state.
+ /// This returns a time, which could be in the past (this object doesn't know what
+ /// the current time is).
+ pub fn next(&self, rtt: Duration, cwnd: usize) -> Instant {
+ if self.c >= self.p {
+ qtrace!([self], "next {}/{:?} no wait = {:?}", cwnd, rtt, self.t);
+ self.t
+ } else {
+ // This is the inverse of the function in `spend`:
+ // self.t + rtt * (self.p - self.c) / (PACER_SPEEDUP * cwnd)
+ let r = rtt.as_nanos();
+ let d = r.saturating_mul(u128::try_from(self.p - self.c).unwrap());
+ let add = d / u128::try_from(cwnd * PACER_SPEEDUP).unwrap();
+ let w = u64::try_from(add).map(Duration::from_nanos).unwrap_or(rtt);
+ let nxt = self.t + w;
+ qtrace!([self], "next {}/{:?} wait {:?} = {:?}", cwnd, rtt, w, nxt);
+ nxt
+ }
+ }
+
+ /// Spend credit. This cannot fail; users of this API are expected to call
+ /// next() to determine when to spend. This takes the current time (`now`),
+ /// an estimate of the round trip time (`rtt`), the estimated congestion
+ /// window (`cwnd`), and the number of bytes that were sent (`count`).
+ pub fn spend(&mut self, now: Instant, rtt: Duration, cwnd: usize, count: usize) {
+ qtrace!([self], "spend {} over {}, {:?}", count, cwnd, rtt);
+ // Increase the capacity by:
+ // `(now - self.t) * PACER_SPEEDUP * cwnd / rtt`
+ // That is, the elapsed fraction of the RTT times rate that data is added.
+ let incr = now
+ .saturating_duration_since(self.t)
+ .as_nanos()
+ .saturating_mul(u128::try_from(cwnd * PACER_SPEEDUP).unwrap())
+ .checked_div(rtt.as_nanos())
+ .and_then(|i| usize::try_from(i).ok())
+ .unwrap_or(self.m);
+
+ // Add the capacity up to a limit of `self.m`, then subtract `count`.
+ self.c = min(self.m, (self.c + incr).saturating_sub(count));
+ self.t = now;
+ }
+}
+
+impl Display for Pacer {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "Pacer {}/{}", self.c, self.p)
+ }
+}
+
+impl Debug for Pacer {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "Pacer@{:?} {}/{}..{}", self.t, self.c, self.p, self.m)
+ }
+}
+
+#[cfg(tests)]
+mod tests {
+ use super::Pacer;
+ use test_fixture::now;
+
+ const RTT: Duration = Duration::from_millis(1000);
+ const PACKET: usize = 1000;
+ const CWND: usize = PACKET * 10;
+
+ #[test]
+ fn even() {
+ let mut n = now();
+ let p = Pacer::new(n, PACKET, PACKET);
+ assert_eq!(p.next(RTT, CWND), None);
+ p.spend(n, RTT, CWND, PACKET);
+ assert_eq!(p.next(RTT, CWND), Some(n + (RTT / 10)));
+ }
+
+ #[test]
+ fn backwards_in_time() {
+ let mut n = now();
+ let p = Pacer::new(n + RTT, PACKET, PACKET);
+ assert_eq!(p.next(RTT, CWND), None);
+ // Now spend some credit in the past using a time machine.
+ p.spend(n, RTT, CWND, PACKET);
+ assert_eq!(p.next(RTT, CWND), Some(n + (RTT / 10)));
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs
new file mode 100644
index 0000000000..9ab4d811a1
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/packet/mod.rs
@@ -0,0 +1,1339 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Encoding and decoding packets off the wire.
+use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN};
+use crate::crypto::{CryptoDxState, CryptoSpace, CryptoStates};
+use crate::{Error, Res};
+
+use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder};
+use neqo_crypto::random;
+
+use std::cmp::min;
+use std::convert::TryFrom;
+use std::fmt;
+use std::iter::ExactSizeIterator;
+use std::ops::{Deref, DerefMut, Range};
+use std::time::Instant;
+
+const PACKET_TYPE_INITIAL: u8 = 0x0;
+const PACKET_TYPE_0RTT: u8 = 0x01;
+const PACKET_TYPE_HANDSHAKE: u8 = 0x2;
+const PACKET_TYPE_RETRY: u8 = 0x03;
+
+pub const PACKET_BIT_LONG: u8 = 0x80;
+const PACKET_BIT_SHORT: u8 = 0x00;
+const PACKET_BIT_FIXED_QUIC: u8 = 0x40;
+const PACKET_BIT_SPIN: u8 = 0x20;
+const PACKET_BIT_KEY_PHASE: u8 = 0x04;
+
+const PACKET_HP_MASK_LONG: u8 = 0x0f;
+const PACKET_HP_MASK_SHORT: u8 = 0x1f;
+
+const SAMPLE_SIZE: usize = 16;
+const SAMPLE_OFFSET: usize = 4;
+const MAX_PACKET_NUMBER_LEN: usize = 4;
+
+mod retry;
+
+pub type PacketNumber = u64;
+type Version = u32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum PacketType {
+ VersionNegotiation,
+ Initial,
+ Handshake,
+ ZeroRtt,
+ Retry,
+ Short,
+ OtherVersion,
+}
+
+impl PacketType {
+ #[must_use]
+ fn code(self) -> u8 {
+ match self {
+ Self::Initial => PACKET_TYPE_INITIAL,
+ Self::ZeroRtt => PACKET_TYPE_0RTT,
+ Self::Handshake => PACKET_TYPE_HANDSHAKE,
+ Self::Retry => PACKET_TYPE_RETRY,
+ _ => panic!("shouldn't be here"),
+ }
+ }
+}
+
+impl Into<CryptoSpace> for PacketType {
+ fn into(self) -> CryptoSpace {
+ match self {
+ Self::Initial => CryptoSpace::Initial,
+ Self::ZeroRtt => CryptoSpace::ZeroRtt,
+ Self::Handshake => CryptoSpace::Handshake,
+ Self::Short => CryptoSpace::ApplicationData,
+ _ => panic!("shouldn't be here"),
+ }
+ }
+}
+
+impl From<CryptoSpace> for PacketType {
+ fn from(cs: CryptoSpace) -> Self {
+ match cs {
+ CryptoSpace::Initial => Self::Initial,
+ CryptoSpace::ZeroRtt => Self::ZeroRtt,
+ CryptoSpace::Handshake => Self::Handshake,
+ CryptoSpace::ApplicationData => Self::Short,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum QuicVersion {
+ Draft27,
+ Draft28,
+ Draft29,
+ Draft30,
+ Draft31,
+ Draft32,
+}
+
+impl QuicVersion {
+ pub fn as_u32(self) -> Version {
+ match self {
+ Self::Draft27 => 0xff00_0000 + 27,
+ Self::Draft28 => 0xff00_0000 + 28,
+ Self::Draft29 => 0xff00_0000 + 29,
+ Self::Draft30 => 0xff00_0000 + 30,
+ Self::Draft31 => 0xff00_0000 + 31,
+ Self::Draft32 => 0xff00_0000 + 32,
+ }
+ }
+}
+
+impl Default for QuicVersion {
+ fn default() -> Self {
+ Self::Draft29
+ }
+}
+
+impl TryFrom<Version> for QuicVersion {
+ type Error = Error;
+
+ fn try_from(ver: Version) -> Res<Self> {
+ if ver == 0xff00_0000 + 27 {
+ Ok(Self::Draft27)
+ } else if ver == 0xff00_0000 + 28 {
+ Ok(Self::Draft28)
+ } else if ver == 0xff00_0000 + 29 {
+ Ok(Self::Draft29)
+ } else if ver == 0xff00_0000 + 30 {
+ Ok(Self::Draft30)
+ } else if ver == 0xff00_0000 + 31 {
+ Ok(Self::Draft31)
+ } else if ver == 0xff00_0000 + 32 {
+ Ok(Self::Draft32)
+ } else {
+ Err(Error::VersionNegotiation)
+ }
+ }
+}
+
+struct PacketBuilderOffsets {
+ /// The bits of the first octet that need masking.
+ first_byte_mask: u8,
+ /// The offset of the length field.
+ len: usize,
+ /// The location of the packet number field.
+ pn: Range<usize>,
+}
+
+/// A packet builder that can be used to produce short packets and long packets.
+/// This does not produce Retry or Version Negotiation.
+pub struct PacketBuilder {
+ encoder: Encoder,
+ pn: PacketNumber,
+ header: Range<usize>,
+ offsets: PacketBuilderOffsets,
+ limit: usize,
+}
+
+impl PacketBuilder {
+ fn infer_limit(encoder: &Encoder) -> usize {
+ if encoder.capacity() > 64 {
+ encoder.capacity()
+ } else {
+ 2048
+ }
+ }
+
+ /// Start building a short header packet.
+ #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45.
+ #[allow(clippy::reversed_empty_ranges)]
+ pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self {
+ let header_start = encoder.len();
+ encoder.encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2));
+ encoder.encode(dcid.as_ref());
+ let limit = Self::infer_limit(&encoder);
+ Self {
+ encoder,
+ pn: u64::max_value(),
+ header: header_start..header_start,
+ offsets: PacketBuilderOffsets {
+ first_byte_mask: PACKET_HP_MASK_SHORT,
+ pn: 0..0,
+ len: 0,
+ },
+ limit,
+ }
+ }
+
+ /// Start building a long header packet.
+ /// For an Initial packet you will need to call initial_token(),
+ /// even if the token is empty.
+ #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45.
+ #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range.
+ pub fn long(
+ mut encoder: Encoder,
+ pt: PacketType,
+ quic_version: QuicVersion,
+ dcid: impl AsRef<[u8]>,
+ scid: impl AsRef<[u8]>,
+ ) -> Self {
+ let header_start = encoder.len();
+ encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.code() << 4);
+ encoder.encode_uint(4, quic_version.as_u32());
+ encoder.encode_vec(1, dcid.as_ref());
+ encoder.encode_vec(1, scid.as_ref());
+ let limit = Self::infer_limit(&encoder);
+ Self {
+ encoder,
+ pn: u64::max_value(),
+ header: header_start..header_start,
+ offsets: PacketBuilderOffsets {
+ first_byte_mask: PACKET_HP_MASK_LONG,
+ pn: 0..0,
+ len: 0,
+ },
+ limit,
+ }
+ }
+
+ fn is_long(&self) -> bool {
+ self[self.header.start] & 0x80 == PACKET_BIT_LONG
+ }
+
+ /// This stores a value that can be used as a limit. This does not cause
+ /// this limit to be enforced until encryption occurs. Prior to that, it
+ /// is only used voluntarily by users of the builder, through `remaining()`.
+ pub fn set_limit(&mut self, limit: usize) {
+ self.limit = limit;
+ }
+
+ /// How many bytes remain against the size limit for the builder.
+ #[must_use]
+ pub fn remaining(&self) -> usize {
+ self.limit - self.encoder.len()
+ }
+
+ /// Pad with "PADDING" frames.
+ pub fn pad(&mut self) {
+ self.encoder.pad_to(self.limit, 0);
+ }
+
+ /// Add unpredictable values for unprotected parts of the packet.
+ pub fn scramble(&mut self, quic_bit: bool) {
+ let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 }
+ | if self.is_long() { 0 } else { PACKET_BIT_SPIN };
+ let first = self.header.start;
+ self[first] ^= random(1)[0] & mask;
+ }
+
+ /// For an Initial packet, encode the token.
+ /// If you fail to do this, then you will not get a valid packet.
+ pub fn initial_token(&mut self, token: &[u8]) {
+ debug_assert_eq!(
+ self.encoder[self.header.start] & 0xb0,
+ PACKET_BIT_LONG | PACKET_TYPE_INITIAL << 4
+ );
+ self.encoder.encode_vvec(token);
+ }
+
+ /// Add a packet number of the given size.
+ /// For a long header packet, this also inserts a dummy length.
+ /// The length is filled in after calling `build`.
+ pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) {
+ // Reserve space for a length in long headers.
+ if self.is_long() {
+ self.offsets.len = self.encoder.len();
+ self.encoder.encode(&[0; 2]);
+ }
+
+ // This allows the input to be >4, which is absurd, but we can eat that.
+ let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len);
+ debug_assert_ne!(pn_len, 0);
+ // Encode the packet number and save its offset.
+ let pn_offset = self.encoder.len();
+ self.encoder.encode_uint(pn_len, pn);
+ self.offsets.pn = pn_offset..self.encoder.len();
+
+ // Now encode the packet number length and save the header length.
+ self.encoder[self.header.start] |= u8::try_from(pn_len - 1).unwrap();
+ self.header.end = self.encoder.len();
+ self.pn = pn;
+ }
+
+ fn write_len(&mut self, expansion: usize) {
+ let len = self.encoder.len() - (self.offsets.len + 2) + expansion;
+ self.encoder[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8;
+ self.encoder[self.offsets.len + 1] = (len & 0xff) as u8;
+ }
+
+ fn pad_for_crypto(&mut self, crypto: &mut CryptoDxState) {
+ // Make sure that there is enough data in the packet.
+ // The length of the packet number plus the payload length needs to
+ // be at least 4 (MAX_PACKET_NUMBER_LEN) plus any amount by which
+ // the header protection sample exceeds the AEAD expansion.
+ let crypto_pad = crypto.extra_padding();
+ self.encoder.pad_to(
+ self.offsets.pn.start + MAX_PACKET_NUMBER_LEN + crypto_pad,
+ 0,
+ );
+ }
+
+ /// Build the packet and return the encoder.
+ pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> {
+ if self.len() > self.limit {
+ qwarn!("Packet contents are more than the limit");
+ debug_assert!(false);
+ return Err(Error::InternalError);
+ }
+
+ self.pad_for_crypto(crypto);
+ if self.offsets.len > 0 {
+ self.write_len(crypto.expansion());
+ }
+
+ let hdr = &self.encoder[self.header.clone()];
+ let body = &self.encoder[self.header.end..];
+ qtrace!(
+ "Packet build pn={} hdr={} body={}",
+ self.pn,
+ hex(hdr),
+ hex(body)
+ );
+ let ciphertext = crypto.encrypt(self.pn, hdr, body)?;
+
+ // Calculate the mask.
+ let offset = SAMPLE_OFFSET - self.offsets.pn.len();
+ assert!(offset + SAMPLE_SIZE <= ciphertext.len());
+ let sample = &ciphertext[offset..offset + SAMPLE_SIZE];
+ let mask = crypto.compute_mask(sample)?;
+
+ // Apply the mask.
+ self.encoder[self.header.start] ^= mask[0] & self.offsets.first_byte_mask;
+ for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) {
+ self.encoder[j] ^= mask[i];
+ }
+
+ // Finally, cut off the plaintext and add back the ciphertext.
+ self.encoder.truncate(self.header.end);
+ self.encoder.encode(&ciphertext);
+ qtrace!("Packet built {}", hex(&self.encoder));
+ Ok(self.encoder)
+ }
+
+ /// Abort writing of this packet and return the encoder.
+ #[must_use]
+ pub fn abort(mut self) -> Encoder {
+ self.encoder.truncate(self.header.start);
+ self.encoder
+ }
+
+ /// Work out if nothing was added after the header.
+ #[must_use]
+ pub fn packet_empty(&self) -> bool {
+ self.encoder.len() == self.header.end
+ }
+
+ /// Make a retry packet.
+ /// As this is a simple packet, this is just an associated function.
+ /// As Retry is odd (it has to be constructed with leading bytes),
+ /// this returns a Vec<u8> rather than building on an encoder.
+ pub fn retry(
+ quic_version: QuicVersion,
+ dcid: &[u8],
+ scid: &[u8],
+ token: &[u8],
+ odcid: &[u8],
+ ) -> Res<Vec<u8>> {
+ let mut encoder = Encoder::default();
+ encoder.encode_vec(1, odcid);
+ let start = encoder.len();
+ encoder.encode_byte(
+ PACKET_BIT_LONG
+ | PACKET_BIT_FIXED_QUIC
+ | (PACKET_TYPE_RETRY << 4)
+ | (random(1)[0] & 0xf),
+ );
+ encoder.encode_uint(4, quic_version.as_u32());
+ encoder.encode_vec(1, dcid);
+ encoder.encode_vec(1, scid);
+ debug_assert_ne!(token.len(), 0);
+ encoder.encode(token);
+ let tag = retry::use_aead(quic_version, |aead| {
+ let mut buf = vec![0; aead.expansion()];
+ Ok(aead.encrypt(0, &encoder, &[], &mut buf)?.to_vec())
+ })?;
+ encoder.encode(&tag);
+ let mut complete: Vec<u8> = encoder.into();
+ Ok(complete.split_off(start))
+ }
+
+ /// Make a Version Negotiation packet.
+ pub fn version_negotiation(dcid: &[u8], scid: &[u8]) -> Vec<u8> {
+ let mut encoder = Encoder::default();
+ let mut grease = random(5);
+ // This will not include the "QUIC bit" sometimes. Intentionally.
+ encoder.encode_byte(PACKET_BIT_LONG | (grease[4] & 0x7f));
+ encoder.encode(&[0; 4]); // Zero version == VN.
+ encoder.encode_vec(1, dcid);
+ encoder.encode_vec(1, scid);
+ encoder.encode_uint(4, QuicVersion::Draft27.as_u32());
+ encoder.encode_uint(4, QuicVersion::Draft28.as_u32());
+ encoder.encode_uint(4, QuicVersion::Draft29.as_u32());
+ encoder.encode_uint(4, QuicVersion::Draft30.as_u32());
+ encoder.encode_uint(4, QuicVersion::Draft31.as_u32());
+ encoder.encode_uint(4, QuicVersion::Draft32.as_u32());
+ // Add a greased version, using the randomness already generated.
+ for g in &mut grease[..4] {
+ *g = *g & 0xf0 | 0x0a;
+ }
+ encoder.encode(&grease[0..4]);
+ encoder.into()
+ }
+}
+
+impl Deref for PacketBuilder {
+ type Target = Encoder;
+
+ fn deref(&self) -> &Self::Target {
+ &self.encoder
+ }
+}
+
+impl DerefMut for PacketBuilder {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.encoder
+ }
+}
+
+impl Into<Encoder> for PacketBuilder {
+ fn into(self) -> Encoder {
+ self.encoder
+ }
+}
+
+/// PublicPacket holds information from packets that is public only. This allows for
+/// processing of packets prior to decryption.
+pub struct PublicPacket<'a> {
+ /// The packet type.
+ packet_type: PacketType,
+ /// The recovered destination connection ID.
+ dcid: ConnectionIdRef<'a>,
+ /// The source connection ID, if this is a long header packet.
+ scid: Option<ConnectionIdRef<'a>>,
+ /// Any token that is included in the packet (Retry always has a token; Initial sometimes does).
+ /// This is empty when there is no token.
+ token: &'a [u8],
+ /// The size of the header, not including the packet number.
+ header_len: usize,
+ /// Protocol version, if present in header.
+ quic_version: Option<QuicVersion>,
+ /// A reference to the entire packet, including the header.
+ data: &'a [u8],
+}
+
+impl<'a> PublicPacket<'a> {
+ fn opt<T>(v: Option<T>) -> Res<T> {
+ if let Some(v) = v {
+ Ok(v)
+ } else {
+ Err(Error::NoMoreData)
+ }
+ }
+
+ /// Decode the type-specific portions of a long header.
+ /// This includes reading the length and the remainder of the packet.
+ /// Returns a tuple of any token and the length of the header.
+ fn decode_long(
+ decoder: &mut Decoder<'a>,
+ packet_type: PacketType,
+ quic_version: QuicVersion,
+ ) -> Res<(&'a [u8], usize)> {
+ if packet_type == PacketType::Retry {
+ let header_len = decoder.offset();
+ let expansion = retry::expansion(quic_version);
+ let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?;
+ if token.is_empty() {
+ return Err(Error::InvalidPacket);
+ }
+ Self::opt(decoder.decode(expansion))?;
+ return Ok((token, header_len));
+ }
+ let token = if packet_type == PacketType::Initial {
+ Self::opt(decoder.decode_vvec())?
+ } else {
+ &[]
+ };
+ let len = Self::opt(decoder.decode_varint())?;
+ let header_len = decoder.offset();
+ let _body = Self::opt(decoder.decode(usize::try_from(len)?))?;
+ Ok((token, header_len))
+ }
+
+ /// Decode the common parts of a packet. This provides minimal parsing and validation.
+ /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram.
+ pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
+ let mut decoder = Decoder::new(data);
+ let first = Self::opt(decoder.decode_byte())?;
+
+ if first & 0x80 == PACKET_BIT_SHORT {
+ let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?;
+ if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
+ return Err(Error::InvalidPacket);
+ }
+ let header_len = decoder.offset();
+ return Ok((
+ Self {
+ packet_type: PacketType::Short,
+ dcid,
+ scid: None,
+ token: &[],
+ header_len,
+ quic_version: None,
+ data,
+ },
+ &[],
+ ));
+ }
+
+ // Generic long header.
+ let version = Version::try_from(Self::opt(decoder.decode_uint(4))?).unwrap();
+ let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
+ let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
+
+ // Version negotiation.
+ if version == 0 {
+ return Ok((
+ Self {
+ packet_type: PacketType::VersionNegotiation,
+ dcid,
+ scid: Some(scid),
+ token: &[],
+ header_len: decoder.offset(),
+ quic_version: None,
+ data,
+ },
+ &[],
+ ));
+ }
+
+ // Check that this is a long header from a supported version.
+ let quic_version = if let Ok(v) = QuicVersion::try_from(version) {
+ v
+ } else {
+ return Ok((
+ Self {
+ packet_type: PacketType::OtherVersion,
+ dcid,
+ scid: Some(scid),
+ token: &[],
+ header_len: decoder.offset(),
+ quic_version: None,
+ data,
+ },
+ &[],
+ ));
+ };
+
+ if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN {
+ return Err(Error::InvalidPacket);
+ }
+ let packet_type = match (first >> 4) & 3 {
+ PACKET_TYPE_INITIAL => PacketType::Initial,
+ PACKET_TYPE_0RTT => PacketType::ZeroRtt,
+ PACKET_TYPE_HANDSHAKE => PacketType::Handshake,
+ PACKET_TYPE_RETRY => PacketType::Retry,
+ _ => unreachable!(),
+ };
+
+ // The type-specific code includes a token. This consumes the remainder of the packet.
+ let (token, header_len) = Self::decode_long(&mut decoder, packet_type, quic_version)?;
+ let end = data.len() - decoder.remaining();
+ let (data, remainder) = data.split_at(end);
+ Ok((
+ Self {
+ packet_type,
+ dcid,
+ scid: Some(scid),
+ token,
+ header_len,
+ quic_version: Some(quic_version),
+ data,
+ },
+ remainder,
+ ))
+ }
+
+ /// Validate the given packet as though it were a retry.
+ pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool {
+ if self.packet_type != PacketType::Retry {
+ return false;
+ }
+ let version = self.quic_version.unwrap();
+ let expansion = retry::expansion(version);
+ if self.data.len() <= expansion {
+ return false;
+ }
+ let (header, tag) = self.data.split_at(self.data.len() - expansion);
+ let mut encoder = Encoder::with_capacity(self.data.len());
+ encoder.encode_vec(1, odcid);
+ encoder.encode(header);
+ retry::use_aead(version, |aead| {
+ let mut buf = vec![0; expansion];
+ Ok(aead.decrypt(0, &encoder, tag, &mut buf)?.is_empty())
+ })
+ .unwrap_or(false)
+ }
+
+ pub fn is_valid_initial(&self) -> bool {
+ // Packet has to be an initial, with a DCID of 8 bytes, or a token.
+ // Note: the Server class validates the token and checks the length.
+ self.packet_type == PacketType::Initial
+ && (self.dcid().len() >= 8 || !self.token.is_empty())
+ }
+
+ pub fn packet_type(&self) -> PacketType {
+ self.packet_type
+ }
+
+ pub fn dcid(&self) -> &ConnectionIdRef<'a> {
+ &self.dcid
+ }
+
+ pub fn scid(&self) -> &ConnectionIdRef<'a> {
+ self.scid
+ .as_ref()
+ .expect("should only be called for long header packets")
+ }
+
+ pub fn token(&self) -> &'a [u8] {
+ self.token
+ }
+
+ pub fn version(&self) -> Option<QuicVersion> {
+ self.quic_version
+ }
+
+ pub fn len(&self) -> usize {
+ self.data.len()
+ }
+
+ fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber {
+ let window = 1_u64 << (w * 8);
+ let candidate = (expected & !(window - 1)) | pn;
+ if candidate + (window / 2) <= expected {
+ candidate + window
+ } else if candidate > expected + (window / 2) {
+ match candidate.checked_sub(window) {
+ Some(pn_sub) => pn_sub,
+ None => candidate,
+ }
+ } else {
+ candidate
+ }
+ }
+
+ /// Decrypt the header of the packet.
+ fn decrypt_header(
+ &self,
+ crypto: &mut CryptoDxState,
+ ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> {
+ assert_ne!(self.packet_type, PacketType::Retry);
+ assert_ne!(self.packet_type, PacketType::VersionNegotiation);
+
+ qtrace!(
+ "unmask hdr={}",
+ hex(&self.data[..self.header_len + SAMPLE_OFFSET])
+ );
+
+ let sample_offset = self.header_len + SAMPLE_OFFSET;
+ let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE))
+ {
+ crypto.compute_mask(sample)
+ } else {
+ Err(Error::NoMoreData)
+ }?;
+
+ // Un-mask the leading byte.
+ let bits = if self.packet_type == PacketType::Short {
+ PACKET_HP_MASK_SHORT
+ } else {
+ PACKET_HP_MASK_LONG
+ };
+ let first_byte = self.data[0] ^ (mask[0] & bits);
+
+ // Make a copy of the header to work on.
+ let mut hdrbytes = self.data[..self.header_len + 4].to_vec();
+ hdrbytes[0] = first_byte;
+
+ // Unmask the PN.
+ let mut pn_encoded: u64 = 0;
+ for i in 0..MAX_PACKET_NUMBER_LEN {
+ hdrbytes[self.header_len + i] ^= mask[1 + i];
+ pn_encoded <<= 8;
+ pn_encoded += u64::from(hdrbytes[self.header_len + i]);
+ }
+
+ // Now decode the packet number length and apply it, hopefully in constant time.
+ let pn_len = usize::from((first_byte & 0x3) + 1);
+ hdrbytes.truncate(self.header_len + pn_len);
+ pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len);
+
+ qtrace!("unmasked hdr={}", hex(&hdrbytes));
+
+ let key_phase = self.packet_type == PacketType::Short
+ && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE;
+ let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len);
+ Ok((
+ key_phase,
+ pn,
+ hdrbytes,
+ &self.data[self.header_len + pn_len..],
+ ))
+ }
+
+ pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> {
+ let cspace: CryptoSpace = self.packet_type.into();
+ // This has to work in two stages because we need to remove header protection
+ // before picking the keys to use.
+ if let Some(rx) = crypto.rx_hp(cspace) {
+ // Note that this will dump early, which creates a side-channel.
+ // This is OK in this case because we the only reason this can
+ // fail is if the cryptographic module is bad or the packet is
+ // too small (which is public information).
+ let (key_phase, pn, header, body) = self.decrypt_header(rx)?;
+ qtrace!([rx], "decoded header: {:?}", header);
+ let rx = crypto.rx(cspace, key_phase).unwrap();
+ let d = rx.decrypt(pn, &header, body)?;
+ // If this is the first packet ever successfully decrypted
+ // using `rx`, make sure to initiate a key update.
+ if rx.needs_update() {
+ crypto.key_update_received(release_at)?;
+ }
+ crypto.check_pn_overlap()?;
+ Ok(DecryptedPacket {
+ pt: self.packet_type,
+ pn,
+ data: d,
+ })
+ } else if crypto.rx_pending(cspace) {
+ Err(Error::KeysPending(cspace))
+ } else {
+ qtrace!("keys for {:?} already discarded", cspace);
+ Err(Error::KeysDiscarded)
+ }
+ }
+
+ pub fn supported_versions(&self) -> Res<Vec<Version>> {
+ assert_eq!(self.packet_type, PacketType::VersionNegotiation);
+ let mut decoder = Decoder::new(&self.data[self.header_len..]);
+ let mut res = Vec::new();
+ while decoder.remaining() > 0 {
+ let version = Version::try_from(Self::opt(decoder.decode_uint(4))?)?;
+ res.push(version);
+ }
+ Ok(res)
+ }
+}
+
+impl fmt::Debug for PublicPacket<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{:?}: {} {}",
+ self.packet_type(),
+ hex_with_len(&self.data[..self.header_len]),
+ hex_with_len(&self.data[self.header_len..])
+ )
+ }
+}
+
+pub struct DecryptedPacket {
+ pt: PacketType,
+ pn: PacketNumber,
+ data: Vec<u8>,
+}
+
+impl DecryptedPacket {
+ pub fn packet_type(&self) -> PacketType {
+ self.pt
+ }
+
+ pub fn pn(&self) -> PacketNumber {
+ self.pn
+ }
+}
+
+impl Deref for DecryptedPacket {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.data[..]
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::crypto::{CryptoDxState, CryptoStates};
+ use crate::{FixedConnectionIdManager, QuicVersion};
+ use neqo_common::Encoder;
+ use test_fixture::{fixture_init, now};
+
+ const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
+ const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5];
+
+ /// This is a connection ID manager, which is only used for decoding short header packets.
+ fn cid_mgr() -> FixedConnectionIdManager {
+ FixedConnectionIdManager::new(SERVER_CID.len())
+ }
+
+ const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[
+ 0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03,
+ 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd,
+ 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04,
+ 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00,
+ 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14,
+ 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b,
+ 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04,
+ ];
+ const SAMPLE_INITIAL: &[u8] = &[
+ 0xc7, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x00, 0x40, 0x75, 0xfb, 0x12, 0xff, 0x07, 0x82, 0x3a, 0x5d, 0x24, 0x53, 0x4d, 0x90, 0x6c,
+ 0xe4, 0xc7, 0x67, 0x82, 0xa2, 0x16, 0x7e, 0x34, 0x79, 0xc0, 0xf7, 0xf6, 0x39, 0x5d, 0xc2,
+ 0xc9, 0x16, 0x76, 0x30, 0x2f, 0xe6, 0xd7, 0x0b, 0xb7, 0xcb, 0xeb, 0x11, 0x7b, 0x4d, 0xdb,
+ 0x7d, 0x17, 0x34, 0x98, 0x44, 0xfd, 0x61, 0xda, 0xe2, 0x00, 0xb8, 0x33, 0x8e, 0x1b, 0x93,
+ 0x29, 0x76, 0xb6, 0x1d, 0x91, 0xe6, 0x4a, 0x02, 0xe9, 0xe0, 0xee, 0x72, 0xe3, 0xa6, 0xf6,
+ 0x3a, 0xba, 0x4c, 0xee, 0xee, 0xc5, 0xbe, 0x2f, 0x24, 0xf2, 0xd8, 0x60, 0x27, 0x57, 0x29,
+ 0x43, 0x53, 0x38, 0x46, 0xca, 0xa1, 0x3e, 0x6f, 0x16, 0x3f, 0xb2, 0x57, 0x47, 0x3d, 0xcc,
+ 0xa2, 0x53, 0x96, 0xe8, 0x87, 0x24, 0xf1, 0xe5, 0xd9, 0x64, 0xde, 0xde, 0xe9, 0xb6, 0x33,
+ ];
+
+ #[test]
+ fn sample_server_initial() {
+ fixture_init();
+ let mut prot = CryptoDxState::test_default();
+
+ // The spec uses PN=1, but our crypto refuses to skip packet numbers.
+ // So burn an encryption:
+ let burn = prot.encrypt(0, &[], &[]).expect("burn OK");
+ assert_eq!(burn.len(), prot.expansion());
+
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Initial,
+ QuicVersion::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(SERVER_CID),
+ );
+ builder.initial_token(&[]);
+ builder.pn(1, 2);
+ builder.encode(&SAMPLE_INITIAL_PAYLOAD);
+ let packet = builder.build(&mut prot).expect("build");
+ assert_eq!(&packet[..], SAMPLE_INITIAL);
+ }
+
+ #[test]
+ fn decrypt_initial() {
+ const EXTRA: &[u8] = &[0xce; 33];
+
+ fixture_init();
+ let mut padded = SAMPLE_INITIAL.to_vec();
+ padded.extend_from_slice(EXTRA);
+ let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap();
+ assert_eq!(packet.packet_type(), PacketType::Initial);
+ assert_eq!(&packet.dcid()[..], &[] as &[u8]);
+ assert_eq!(&packet.scid()[..], SERVER_CID);
+ assert!(packet.token().is_empty());
+ assert_eq!(remainder, EXTRA);
+
+ let decrypted = packet
+ .decrypt(&mut CryptoStates::test_default(), now())
+ .unwrap();
+ assert_eq!(decrypted.pn(), 1);
+ }
+
+ #[test]
+ fn disallow_long_dcid() {
+ let mut enc = Encoder::new();
+ enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
+ enc.encode_uint(4, QuicVersion::default().as_u32());
+ enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]);
+ enc.encode_vec(1, &[]);
+ enc.encode(&[0xff; 40]); // junk
+
+ assert!(PublicPacket::decode(&enc, &cid_mgr()).is_err());
+ }
+
+ #[test]
+ fn disallow_long_scid() {
+ let mut enc = Encoder::new();
+ enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
+ enc.encode_uint(4, QuicVersion::default().as_u32());
+ enc.encode_vec(1, &[]);
+ enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]);
+ enc.encode(&[0xff; 40]); // junk
+
+ assert!(PublicPacket::decode(&enc, &cid_mgr()).is_err());
+ }
+
+ const SAMPLE_SHORT: &[u8] = &[
+ 0x55, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x99, 0x9c, 0xbd, 0x77, 0xf5, 0xd7,
+ 0x0a, 0x28, 0xe8, 0xfb, 0xc3, 0xed, 0xf5, 0x71, 0xb1, 0x04, 0x32, 0x2a, 0xae, 0xae,
+ ];
+ const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3];
+
+ #[test]
+ fn build_short() {
+ fixture_init();
+ let mut builder =
+ PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
+ builder.pn(0, 1);
+ builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling.
+ let packet = builder
+ .build(&mut CryptoDxState::test_default())
+ .expect("build");
+ assert_eq!(&packet[..], SAMPLE_SHORT);
+ }
+
+ #[test]
+ fn scramble_short() {
+ fixture_init();
+ let mut firsts = Vec::new();
+ for _ in 0..64 {
+ let mut builder =
+ PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
+ builder.scramble(true);
+ builder.pn(0, 1);
+ firsts.push(builder[0]);
+ }
+ let is_set = |bit| move |v| v & bit == bit;
+ // There should be at least one value with the QUIC bit set:
+ assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC)));
+ // ... but not all:
+ assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC)));
+ // There should be at least one value with the spin bit set:
+ assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN)));
+ // ... but not all:
+ assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN)));
+ }
+
+ #[test]
+ fn decode_short() {
+ fixture_init();
+ let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap();
+ assert_eq!(packet.packet_type(), PacketType::Short);
+ assert!(remainder.is_empty());
+ let decrypted = packet
+ .decrypt(&mut CryptoStates::test_default(), now())
+ .unwrap();
+ assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD);
+ }
+
+ /// By telling the decoder that the connection ID is shorter than it really is, we get a decryption error.
+ #[test]
+ fn decode_short_bad_cid() {
+ fixture_init();
+ let (packet, remainder) = PublicPacket::decode(
+ SAMPLE_SHORT,
+ &FixedConnectionIdManager::new(SERVER_CID.len() - 1),
+ )
+ .unwrap();
+ assert_eq!(packet.packet_type(), PacketType::Short);
+ assert!(remainder.is_empty());
+ assert!(packet
+ .decrypt(&mut CryptoStates::test_default(), now())
+ .is_err());
+ }
+
+ /// Saying that the connection ID is longer causes the initial decode to fail.
+ #[test]
+ fn decode_short_long_cid() {
+ assert!(PublicPacket::decode(
+ SAMPLE_SHORT,
+ &FixedConnectionIdManager::new(SERVER_CID.len() + 1)
+ )
+ .is_err());
+ }
+
+ #[test]
+ fn build_two() {
+ fixture_init();
+ let mut prot = CryptoDxState::test_default();
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Handshake,
+ QuicVersion::default(),
+ &ConnectionId::from(SERVER_CID),
+ &ConnectionId::from(CLIENT_CID),
+ );
+ builder.pn(0, 1);
+ builder.encode(&[0; 3]);
+ let encoder = builder.build(&mut prot).expect("build");
+ assert_eq!(encoder.len(), 45);
+ let first = encoder.clone();
+
+ let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID));
+ builder.pn(1, 3);
+ builder.encode(&[0]); // Minimal size (packet number is big enough).
+ let encoder = builder.build(&mut prot).expect("build");
+ assert_eq!(
+ &first[..],
+ &encoder[..first.len()],
+ "the first packet should be a prefix"
+ );
+ assert_eq!(encoder.len(), 45 + 29);
+ }
+
+ #[test]
+ fn build_long() {
+ const EXPECTED: &[u8] = &[
+ 0xe5, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x40, 0x14, 0xa8, 0x9d, 0xbf, 0x74, 0x70,
+ 0x32, 0xda, 0xba, 0xfb, 0x87, 0x61, 0xb8, 0x31, 0x90, 0xf3, 0x25, 0x52, 0x0b, 0xbe,
+ 0xdb,
+ ];
+
+ fixture_init();
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Handshake,
+ QuicVersion::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(&[][..]),
+ );
+ builder.pn(0, 1);
+ builder.encode(&[1, 2, 3]);
+ let packet = builder.build(&mut CryptoDxState::test_default()).unwrap();
+ assert_eq!(&packet[..], EXPECTED);
+ }
+
+ #[test]
+ fn scramble_long() {
+ fixture_init();
+ let mut found_unset = false;
+ let mut found_set = false;
+ for _ in 1..64 {
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Handshake,
+ QuicVersion::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(&[][..]),
+ );
+ builder.pn(0, 1);
+ builder.scramble(true);
+ if (builder[0] & PACKET_BIT_FIXED_QUIC) == 0 {
+ found_unset = true;
+ } else {
+ found_set = true;
+ }
+ }
+ assert!(found_unset);
+ assert!(found_set);
+ }
+
+ #[test]
+ fn build_abort() {
+ let mut builder = PacketBuilder::long(
+ Encoder::new(),
+ PacketType::Initial,
+ QuicVersion::default(),
+ &ConnectionId::from(&[][..]),
+ &ConnectionId::from(SERVER_CID),
+ );
+ builder.initial_token(&[]);
+ builder.pn(1, 2);
+ let encoder = builder.abort();
+ assert!(encoder.is_empty());
+ }
+
+ const SAMPLE_RETRY_27: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1b, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xa5, 0x23, 0xcb, 0x5b, 0xa5, 0x24, 0x69, 0x5f, 0x65, 0x69,
+ 0xf2, 0x93, 0xa1, 0x35, 0x9d, 0x8e,
+ ];
+
+ const SAMPLE_RETRY_28: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1c, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xf7, 0x1a, 0x5f, 0x12, 0xaf, 0xe3, 0xec, 0xf8, 0x00, 0x1a,
+ 0x92, 0x0e, 0x6f, 0xdf, 0x1d, 0x63,
+ ];
+
+ const SAMPLE_RETRY_29: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a,
+ 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49,
+ ];
+
+ const SAMPLE_RETRY_30: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94,
+ 0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61,
+ ];
+
+ const SAMPLE_RETRY_31: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1,
+ 0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86,
+ ];
+
+ const SAMPLE_RETRY_32: &[u8] = &[
+ 0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e,
+ 0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85,
+ ];
+
+ const RETRY_TOKEN: &[u8] = b"token";
+
+ fn build_retry_single(quic_version: QuicVersion, sample_retry: &[u8]) {
+ fixture_init();
+ let retry =
+ PacketBuilder::retry(quic_version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap();
+
+ let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap();
+ assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
+ assert!(remainder.is_empty());
+
+ // The builder adds randomness, which makes expectations hard.
+ // So only do a full check when that randomness matches up.
+ if retry[0] == sample_retry[0] {
+ assert_eq!(&retry, &sample_retry);
+ } else {
+ // Otherwise, just check that the header is OK.
+ assert_eq!(retry[0] & 0xf0, 0xf0);
+ let header_range = 1..retry.len() - 16;
+ assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]);
+ }
+ }
+
+ #[test]
+ fn build_retry_27() {
+ build_retry_single(QuicVersion::Draft27, SAMPLE_RETRY_27);
+ }
+
+ #[test]
+ fn build_retry_28() {
+ build_retry_single(QuicVersion::Draft28, SAMPLE_RETRY_28);
+ }
+
+ #[test]
+ fn build_retry_29() {
+ build_retry_single(QuicVersion::Draft29, SAMPLE_RETRY_29);
+ }
+
+ #[test]
+ fn build_retry_30() {
+ build_retry_single(QuicVersion::Draft30, SAMPLE_RETRY_30);
+ }
+
+ #[test]
+ fn build_retry_31() {
+ build_retry_single(QuicVersion::Draft31, SAMPLE_RETRY_31);
+ }
+
+ #[test]
+ fn build_retry_32() {
+ build_retry_single(QuicVersion::Draft32, SAMPLE_RETRY_32);
+ }
+
+ #[test]
+ fn build_retry_multiple() {
+ // Run the build_retry test a few times.
+ // Odds are approximately 1 in 8 that the full comparison doesn't happen
+ // for a given version.
+ for _ in 0..32 {
+ build_retry_27();
+ build_retry_28();
+ build_retry_29();
+ build_retry_30();
+ }
+ }
+
+ fn decode_retry(quic_version: QuicVersion, sample_retry: &[u8]) {
+ fixture_init();
+ let (packet, remainder) =
+ PublicPacket::decode(sample_retry, &FixedConnectionIdManager::new(5)).unwrap();
+ assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
+ assert_eq!(Some(quic_version), packet.quic_version);
+ assert!(packet.dcid().is_empty());
+ assert_eq!(&packet.scid()[..], SERVER_CID);
+ assert_eq!(packet.token(), RETRY_TOKEN);
+ assert!(remainder.is_empty());
+ }
+
+ #[test]
+ fn decode_retry_27() {
+ decode_retry(QuicVersion::Draft27, SAMPLE_RETRY_27);
+ }
+
+ #[test]
+ fn decode_retry_28() {
+ decode_retry(QuicVersion::Draft28, SAMPLE_RETRY_28);
+ }
+
+ #[test]
+ fn decode_retry_29() {
+ decode_retry(QuicVersion::Draft29, SAMPLE_RETRY_29);
+ }
+
+ #[test]
+ fn decode_retry_30() {
+ decode_retry(QuicVersion::Draft30, SAMPLE_RETRY_30);
+ }
+
+ #[test]
+ fn decode_retry_31() {
+ decode_retry(QuicVersion::Draft31, SAMPLE_RETRY_31);
+ }
+
+ #[test]
+ fn decode_retry_32() {
+ decode_retry(QuicVersion::Draft32, SAMPLE_RETRY_32);
+ }
+
+ /// Check some packets that are clearly not valid Retry packets.
+ #[test]
+ fn invalid_retry() {
+ fixture_init();
+ let cid_mgr = FixedConnectionIdManager::new(5);
+ let odcid = ConnectionId::from(CLIENT_CID);
+
+ assert!(PublicPacket::decode(&[], &cid_mgr).is_err());
+
+ let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_28, &cid_mgr).unwrap();
+ assert!(remainder.is_empty());
+ assert!(packet.is_valid_retry(&odcid));
+
+ let mut damaged_retry = SAMPLE_RETRY_28.to_vec();
+ let last = damaged_retry.len() - 1;
+ damaged_retry[last] ^= 66;
+ let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
+ assert!(remainder.is_empty());
+ assert!(!packet.is_valid_retry(&odcid));
+
+ damaged_retry.truncate(last);
+ let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
+ assert!(remainder.is_empty());
+ assert!(!packet.is_valid_retry(&odcid));
+
+ // An invalid token should be rejected sooner.
+ damaged_retry.truncate(last - 4);
+ assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
+
+ damaged_retry.truncate(last - 1);
+ assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
+ }
+
+ const SAMPLE_VN: &[u8] = &[
+ 0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08,
+ 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0xff, 0x00, 0x00, 0x1b, 0xff, 0x00, 0x00,
+ 0x1c, 0xff, 0x00, 0x00, 0x1d, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00,
+ 0x00, 0x20, 0x0a, 0x0a, 0x0a, 0x0a,
+ ];
+
+ #[test]
+ fn build_vn() {
+ fixture_init();
+ let mut vn = PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID);
+ // Erase randomness from greasing...
+ assert_eq!(vn.len(), SAMPLE_VN.len());
+ vn[0] &= 0x80;
+ for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) {
+ *v &= 0x0f;
+ }
+ assert_eq!(&vn, &SAMPLE_VN);
+ }
+
+ #[test]
+ fn parse_vn() {
+ let (packet, remainder) =
+ PublicPacket::decode(SAMPLE_VN, &FixedConnectionIdManager::new(5)).unwrap();
+ assert!(remainder.is_empty());
+ assert_eq!(&packet.dcid[..], SERVER_CID);
+ assert!(packet.scid.is_some());
+ assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID);
+ }
+
+ /// A Version Negotiation packet can have a long connection ID.
+ #[test]
+ fn parse_vn_big_cid() {
+ const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1];
+ const BIG_SCID: &[u8] = &[0xee; 255];
+
+ let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]);
+ enc.encode_vec(1, BIG_DCID);
+ enc.encode_vec(1, BIG_SCID);
+ enc.encode_uint(4, 0x1a2a_3a4a_u64);
+ enc.encode_uint(4, QuicVersion::default().as_u32());
+ enc.encode_uint(4, 0x5a6a_7a8a_u64);
+
+ let (packet, remainder) =
+ PublicPacket::decode(&enc, &FixedConnectionIdManager::new(5)).unwrap();
+ assert!(remainder.is_empty());
+ assert_eq!(&packet.dcid[..], BIG_DCID);
+ assert!(packet.scid.is_some());
+ assert_eq!(&packet.scid.unwrap()[..], BIG_SCID);
+ }
+
+ #[test]
+ fn decode_pn() {
+ // When the expected value is low, the value doesn't go negative.
+ assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0);
+ assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff);
+ assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0);
+ assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0);
+ assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100);
+ assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2);
+ assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff);
+ assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe);
+
+ // This is invalid by spec, as we are expected to check for overflow around 2^62-1,
+ // but we don't need to worry about overflow
+ // and hitting this is basically impossible in practice.
+ assert_eq!(
+ PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4),
+ 0x4000_0000_0000_0002
+ );
+ }
+
+ #[test]
+ fn chacha20_sample() {
+ const PACKET: &[u8] = &[
+ 0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
+ 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
+ ];
+ fixture_init();
+ let (packet, slice) =
+ PublicPacket::decode(PACKET, &FixedConnectionIdManager::new(0)).unwrap();
+ assert!(slice.is_empty());
+ let decrypted = packet
+ .decrypt(&mut CryptoStates::test_chacha(), now())
+ .unwrap();
+ assert_eq!(decrypted.packet_type(), PacketType::Short);
+ assert_eq!(decrypted.pn(), 654_360_564);
+ assert_eq!(&decrypted[..], &[0x01]);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/packet/retry.rs b/third_party/rust/neqo-transport/src/packet/retry.rs
new file mode 100644
index 0000000000..596714aa6d
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/packet/retry.rs
@@ -0,0 +1,63 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![deny(clippy::pedantic)]
+
+use crate::packet::QuicVersion;
+use crate::{Error, Res};
+
+use neqo_common::qerror;
+use neqo_crypto::{aead::Aead, hkdf, TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3};
+
+use std::cell::RefCell;
+
+const RETRY_SECRET_27: &[u8] = &[
+ 0x65, 0x6e, 0x61, 0xe3, 0x36, 0xae, 0x94, 0x17, 0xf7, 0xf0, 0xed, 0xd8, 0xd7, 0x8d, 0x46, 0x1e,
+ 0x2a, 0xa7, 0x08, 0x4a, 0xba, 0x7a, 0x14, 0xc1, 0xe9, 0xf7, 0x26, 0xd5, 0x57, 0x09, 0x16, 0x9a,
+];
+const RETRY_SECRET_29: &[u8] = &[
+ 0x8b, 0x0d, 0x37, 0xeb, 0x85, 0x35, 0x02, 0x2e, 0xbc, 0x8d, 0x76, 0xa2, 0x07, 0xd8, 0x0d, 0xf2,
+ 0x26, 0x46, 0xec, 0x06, 0xdc, 0x80, 0x96, 0x42, 0xc3, 0x0a, 0x8b, 0xaa, 0x2b, 0xaa, 0xff, 0x4c,
+];
+
+/// The AEAD used for Retry is fixed, so use thread local storage.
+fn make_aead(secret: &[u8]) -> Aead {
+ #[cfg(debug_assertions)]
+ ::neqo_crypto::assert_initialized();
+
+ let secret = hkdf::import_key(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, secret).unwrap();
+ Aead::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic ").unwrap()
+}
+thread_local!(static RETRY_AEAD_27: RefCell<Aead> = RefCell::new(make_aead(RETRY_SECRET_27)));
+thread_local!(static RETRY_AEAD_29: RefCell<Aead> = RefCell::new(make_aead(RETRY_SECRET_29)));
+
+/// Run a function with the appropriate Retry AEAD.
+pub fn use_aead<F, T>(quic_version: QuicVersion, f: F) -> Res<T>
+where
+ F: FnOnce(&Aead) -> Res<T>,
+{
+ match quic_version {
+ QuicVersion::Draft27 | QuicVersion::Draft28 => &RETRY_AEAD_27,
+ QuicVersion::Draft29
+ | QuicVersion::Draft30
+ | QuicVersion::Draft31
+ | QuicVersion::Draft32 => &RETRY_AEAD_29,
+ }
+ .try_with(|aead| f(&aead.borrow()))
+ .map_err(|e| {
+ qerror!("Unable to access Retry AEAD: {:?}", e);
+ Error::InternalError
+ })?
+}
+
+/// Determine how large the expansion is for a given key.
+pub fn expansion(quic_version: QuicVersion) -> usize {
+ if let Ok(ex) = use_aead(quic_version, |aead| Ok(aead.expansion())) {
+ ex
+ } else {
+ panic!("Unable to access Retry AEAD")
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/path.rs b/third_party/rust/neqo-transport/src/path.rs
new file mode 100644
index 0000000000..a4e6b2f361
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/path.rs
@@ -0,0 +1,109 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::net::SocketAddr;
+
+use crate::cid::{ConnectionId, ConnectionIdRef};
+
+use neqo_common::Datagram;
+
+/// This is the MTU that we assume when using IPv6.
+/// We use this size for Initial packets, so we don't need to worry about probing for support.
+/// If the path doesn't support this MTU, we will assume that it doesn't support QUIC.
+///
+/// This is a multiple of 16 greater than the largest possible short header (1 + 20 + 4).
+pub const PATH_MTU_V6: usize = 1337;
+/// The path MTU for IPv4 can be 20 bytes larger than for v6.
+pub const PATH_MTU_V4: usize = PATH_MTU_V6 + 20;
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct Path {
+ local: SocketAddr,
+ remote: SocketAddr,
+ local_cids: Vec<ConnectionId>,
+ remote_cid: ConnectionId,
+ reset_token: Option<[u8; 16]>,
+}
+
+impl Path {
+ /// Create a path from addresses and connection IDs.
+ pub fn new(
+ local: SocketAddr,
+ remote: SocketAddr,
+ local_cid: ConnectionId,
+ remote_cid: ConnectionId,
+ ) -> Self {
+ Self {
+ local,
+ remote,
+ local_cids: vec![local_cid],
+ remote_cid,
+ reset_token: None,
+ }
+ }
+
+ pub fn received_on(&self, d: &Datagram) -> bool {
+ self.local == d.destination() && self.remote == d.source()
+ }
+
+ pub fn mtu(&self) -> usize {
+ if self.local.is_ipv4() {
+ PATH_MTU_V4
+ } else {
+ PATH_MTU_V6 // IPv6
+ }
+ }
+
+ /// Add a connection ID to the local set.
+ pub fn add_local_cid(&mut self, cid: ConnectionId) {
+ self.local_cids.push(cid);
+ }
+
+ /// Determine if the given connection ID is valid.
+ pub fn valid_local_cid(&self, cid: &ConnectionIdRef) -> bool {
+ self.local_cids.iter().any(|c| c == cid)
+ }
+
+ /// Get the first local connection ID.
+ pub fn local_cid(&self) -> &ConnectionId {
+ self.local_cids.first().as_ref().unwrap()
+ }
+
+ /// Set the remote connection ID based on the peer's choice.
+ pub fn set_remote_cid(&mut self, cid: &ConnectionIdRef) {
+ self.remote_cid = ConnectionId::from(cid);
+ }
+
+ /// Access the remote connection ID.
+ pub fn remote_cid(&self) -> &ConnectionId {
+ &self.remote_cid
+ }
+
+ /// Set the stateless reset token for the connection ID that is currently in use.
+ pub fn set_reset_token(&mut self, token: [u8; 16]) {
+ self.reset_token = Some(token);
+ }
+
+ /// Access the reset token.
+ pub fn reset_token(&self) -> Option<&[u8; 16]> {
+ self.reset_token.as_ref()
+ }
+
+ /// Make a datagram.
+ pub fn datagram<V: Into<Vec<u8>>>(&self, payload: V) -> Datagram {
+ Datagram::new(self.local, self.remote, payload)
+ }
+
+ /// Get local address as `SocketAddr`
+ pub fn local_address(&self) -> SocketAddr {
+ self.local
+ }
+
+ /// Get remote address as `SocketAddr`
+ pub fn remote_address(&self) -> SocketAddr {
+ self.remote
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/qlog.rs b/third_party/rust/neqo-transport/src/qlog.rs
new file mode 100644
index 0000000000..b4cc1adf67
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/qlog.rs
@@ -0,0 +1,442 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Functions that handle capturing QLOG traces.
+
+use std::convert::TryFrom;
+use std::ops::RangeInclusive;
+use std::string::String;
+use std::time::Duration;
+
+use qlog::{self, event::Event, PacketHeader, QuicFrame};
+
+use neqo_common::{hex, qinfo, qlog::NeqoQlog, Decoder};
+
+use crate::connection::State;
+use crate::frame::{self, Frame};
+use crate::packet::{DecryptedPacket, PacketNumber, PacketType, PublicPacket};
+use crate::path::Path;
+use crate::tparams::{self, TransportParametersHandler};
+use crate::tracking::SentPacket;
+use crate::QuicVersion;
+
+pub fn connection_tparams_set(qlog: &mut NeqoQlog, tph: &TransportParametersHandler) {
+ qlog.add_event(|| {
+ let remote = tph.remote();
+ Some(Event::transport_parameters_set(
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ if let Some(ocid) = remote.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID) {
+ // Cannot use packet::ConnectionId's Display trait implementation
+ // because it does not include the 0x prefix.
+ Some(hex(ocid))
+ } else {
+ None
+ },
+ if let Some(srt) = remote.get_bytes(tparams::STATELESS_RESET_TOKEN) {
+ Some(hex(srt))
+ } else {
+ None
+ },
+ if remote.get_empty(tparams::DISABLE_MIGRATION) {
+ Some(true)
+ } else {
+ None
+ },
+ Some(remote.get_integer(tparams::IDLE_TIMEOUT)),
+ Some(remote.get_integer(tparams::MAX_UDP_PAYLOAD_SIZE)),
+ Some(remote.get_integer(tparams::ACK_DELAY_EXPONENT)),
+ Some(remote.get_integer(tparams::MAX_ACK_DELAY)),
+ // TODO(hawkinsw@obs.cr): We do not yet handle ACTIVE_CONNECTION_ID_LIMIT in tparams yet.
+ None,
+ Some(format!("{}", remote.get_integer(tparams::INITIAL_MAX_DATA))),
+ Some(format!(
+ "{}",
+ remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL)
+ )),
+ Some(format!(
+ "{}",
+ remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)
+ )),
+ Some(format!(
+ "{}",
+ remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)
+ )),
+ Some(format!(
+ "{}",
+ remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)
+ )),
+ Some(format!(
+ "{}",
+ remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI)
+ )),
+ // TODO(hawkinsw@obs.cr): We do not yet handle PREFERRED_ADDRESS in tparams yet.
+ None,
+ ))
+ })
+}
+
+pub fn server_connection_started(qlog: &mut NeqoQlog, path: &Path) {
+ connection_started(qlog, path)
+}
+
+pub fn client_connection_started(qlog: &mut NeqoQlog, path: &Path) {
+ connection_started(qlog, path)
+}
+
+fn connection_started(qlog: &mut NeqoQlog, path: &Path) {
+ qlog.add_event(|| {
+ Some(Event::connection_started(
+ if path.local_address().ip().is_ipv4() {
+ "ipv4".into()
+ } else {
+ "ipv6".into()
+ },
+ format!("{}", path.local_address().ip()),
+ format!("{}", path.remote_address().ip()),
+ Some("QUIC".into()),
+ path.local_address().port().into(),
+ path.remote_address().port().into(),
+ Some(format!("{:x}", QuicVersion::default().as_u32())),
+ Some(format!("{}", path.local_cid())),
+ Some(format!("{}", path.remote_cid())),
+ ))
+ })
+}
+
+pub fn connection_state_updated(qlog: &mut NeqoQlog, new: &State) {
+ qlog.add_event(|| {
+ Some(Event::connection_state_updated_min(match new {
+ State::Init => qlog::ConnectionState::Attempted,
+ State::WaitInitial => qlog::ConnectionState::Attempted,
+ State::Handshaking => qlog::ConnectionState::Handshake,
+ State::Connected => qlog::ConnectionState::Active,
+ State::Confirmed => qlog::ConnectionState::Active,
+ State::Closing { .. } => qlog::ConnectionState::Draining,
+ State::Draining { .. } => qlog::ConnectionState::Draining,
+ State::Closed { .. } => qlog::ConnectionState::Closed,
+ }))
+ })
+}
+
+pub fn packet_sent(
+ qlog: &mut NeqoQlog,
+ pt: PacketType,
+ pn: PacketNumber,
+ plen: usize,
+ body: &[u8],
+) {
+ qlog.add_event_with_stream(|stream| {
+ let mut d = Decoder::from(body);
+
+ stream.add_event(Event::packet_sent_min(
+ to_qlog_pkt_type(pt),
+ PacketHeader::new(
+ pn,
+ Some(u64::try_from(plen).unwrap()),
+ None,
+ None,
+ None,
+ None,
+ ),
+ Some(Vec::new()),
+ ))?;
+
+ while d.remaining() > 0 {
+ match Frame::decode(&mut d) {
+ Ok(f) => {
+ stream.add_frame(frame_to_qlogframe(&f), false)?;
+ }
+ Err(_) => {
+ qinfo!("qlog: invalid frame");
+ break;
+ }
+ }
+ }
+
+ stream.finish_frames()
+ })
+}
+
+pub fn packet_dropped(qlog: &mut NeqoQlog, payload: &PublicPacket) {
+ qlog.add_event(|| {
+ Some(Event::packet_dropped(
+ Some(to_qlog_pkt_type(payload.packet_type())),
+ Some(u64::try_from(payload.len()).unwrap()),
+ None,
+ ))
+ })
+}
+
+pub fn packets_lost(qlog: &mut NeqoQlog, pkts: &[SentPacket]) {
+ qlog.add_event_with_stream(|stream| {
+ for pkt in pkts {
+ stream.add_event(Event::packet_lost_min(
+ to_qlog_pkt_type(pkt.pt),
+ pkt.pn.to_string(),
+ Vec::new(),
+ ))?;
+
+ stream.finish_frames()?;
+ }
+ Ok(())
+ })
+}
+
+pub fn packet_received(
+ qlog: &mut NeqoQlog,
+ public_packet: &PublicPacket,
+ payload: &DecryptedPacket,
+) {
+ qlog.add_event_with_stream(|stream| {
+ let mut d = Decoder::from(&payload[..]);
+
+ stream.add_event(Event::packet_received(
+ to_qlog_pkt_type(payload.packet_type()),
+ PacketHeader::new(
+ payload.pn(),
+ Some(u64::try_from(public_packet.len()).unwrap()),
+ None,
+ None,
+ None,
+ None,
+ ),
+ Some(Vec::new()),
+ None,
+ None,
+ None,
+ ))?;
+
+ while d.remaining() > 0 {
+ match Frame::decode(&mut d) {
+ Ok(f) => stream.add_frame(frame_to_qlogframe(&f), false)?,
+ Err(_) => {
+ qinfo!("qlog: invalid frame");
+ break;
+ }
+ }
+ }
+
+ stream.finish_frames()
+ })
+}
+
+#[allow(dead_code)]
+pub enum QlogMetric {
+ MinRtt(Duration),
+ SmoothedRtt(Duration),
+ LatestRtt(Duration),
+ RttVariance(u64),
+ MaxAckDelay(u64),
+ PtoCount(usize),
+ CongestionWindow(usize),
+ BytesInFlight(usize),
+ SsThresh(usize),
+ PacketsInFlight(u64),
+ InRecovery(bool),
+ PacingRate(u64),
+}
+
+pub fn metrics_updated(qlog: &mut NeqoQlog, updated_metrics: &[QlogMetric]) {
+ debug_assert!(!updated_metrics.is_empty());
+
+ qlog.add_event(|| {
+ let mut min_rtt: Option<u64> = None;
+ let mut smoothed_rtt: Option<u64> = None;
+ let mut latest_rtt: Option<u64> = None;
+ let mut rtt_variance: Option<u64> = None;
+ let mut max_ack_delay: Option<u64> = None;
+ let mut pto_count: Option<u64> = None;
+ let mut congestion_window: Option<u64> = None;
+ let mut bytes_in_flight: Option<u64> = None;
+ let mut ssthresh: Option<u64> = None;
+ let mut packets_in_flight: Option<u64> = None;
+ let mut in_recovery: Option<bool> = None;
+ let mut pacing_rate: Option<u64> = None;
+
+ for metric in updated_metrics {
+ match metric {
+ QlogMetric::MinRtt(v) => min_rtt = Some(u64::try_from(v.as_millis()).unwrap()),
+ QlogMetric::SmoothedRtt(v) => {
+ smoothed_rtt = Some(u64::try_from(v.as_millis()).unwrap())
+ }
+ QlogMetric::LatestRtt(v) => {
+ latest_rtt = Some(u64::try_from(v.as_millis()).unwrap())
+ }
+ QlogMetric::RttVariance(v) => rtt_variance = Some(*v),
+ QlogMetric::MaxAckDelay(v) => max_ack_delay = Some(*v),
+ QlogMetric::PtoCount(v) => pto_count = Some(u64::try_from(*v).unwrap()),
+ QlogMetric::CongestionWindow(v) => {
+ congestion_window = Some(u64::try_from(*v).unwrap())
+ }
+ QlogMetric::BytesInFlight(v) => bytes_in_flight = Some(u64::try_from(*v).unwrap()),
+ QlogMetric::SsThresh(v) => ssthresh = Some(u64::try_from(*v).unwrap()),
+ QlogMetric::PacketsInFlight(v) => packets_in_flight = Some(*v),
+ QlogMetric::InRecovery(v) => in_recovery = Some(*v),
+ QlogMetric::PacingRate(v) => pacing_rate = Some(*v),
+ }
+ }
+
+ Some(Event::metrics_updated(
+ min_rtt,
+ smoothed_rtt,
+ latest_rtt,
+ rtt_variance,
+ max_ack_delay,
+ pto_count,
+ congestion_window,
+ bytes_in_flight,
+ ssthresh,
+ packets_in_flight,
+ in_recovery,
+ pacing_rate,
+ ))
+ })
+}
+
+// Helper functions
+
+fn frame_to_qlogframe(frame: &Frame) -> QuicFrame {
+ match frame {
+ Frame::Padding => QuicFrame::padding(),
+ Frame::Ping => QuicFrame::ping(),
+ Frame::Ack {
+ largest_acknowledged,
+ ack_delay,
+ first_ack_range,
+ ack_ranges,
+ } => {
+ let ranges =
+ Frame::decode_ack_frame(*largest_acknowledged, *first_ack_range, ack_ranges).ok();
+
+ QuicFrame::ack(
+ Some(ack_delay.to_string()),
+ ranges.map(|all| {
+ all.into_iter()
+ .map(RangeInclusive::into_inner)
+ .collect::<Vec<_>>()
+ }),
+ None,
+ None,
+ None,
+ )
+ }
+ Frame::ResetStream {
+ stream_id,
+ application_error_code,
+ final_size,
+ } => QuicFrame::reset_stream(
+ stream_id.as_u64().to_string(),
+ *application_error_code,
+ final_size.to_string(),
+ ),
+ Frame::StopSending {
+ stream_id,
+ application_error_code,
+ } => QuicFrame::stop_sending(stream_id.as_u64().to_string(), *application_error_code),
+ Frame::Crypto { offset, data } => {
+ QuicFrame::crypto(offset.to_string(), data.len().to_string())
+ }
+ Frame::NewToken { token } => QuicFrame::new_token(token.len().to_string(), hex(&token)),
+ Frame::Stream {
+ fin,
+ stream_id,
+ offset,
+ data,
+ ..
+ } => QuicFrame::stream(
+ stream_id.as_u64().to_string(),
+ offset.to_string(),
+ data.len().to_string(),
+ *fin,
+ None,
+ ),
+ Frame::MaxData { maximum_data } => QuicFrame::max_data(maximum_data.to_string()),
+ Frame::MaxStreamData {
+ stream_id,
+ maximum_stream_data,
+ } => QuicFrame::max_stream_data(
+ stream_id.as_u64().to_string(),
+ maximum_stream_data.to_string(),
+ ),
+ Frame::MaxStreams {
+ stream_type,
+ maximum_streams,
+ } => QuicFrame::max_streams(
+ match stream_type {
+ frame::StreamType::BiDi => qlog::StreamType::Bidirectional,
+ frame::StreamType::UniDi => qlog::StreamType::Unidirectional,
+ },
+ maximum_streams.as_u64().to_string(),
+ ),
+ Frame::DataBlocked { data_limit } => QuicFrame::data_blocked(data_limit.to_string()),
+ Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit,
+ } => QuicFrame::stream_data_blocked(
+ stream_id.as_u64().to_string(),
+ stream_data_limit.to_string(),
+ ),
+ Frame::StreamsBlocked {
+ stream_type,
+ stream_limit,
+ } => QuicFrame::streams_blocked(
+ match stream_type {
+ frame::StreamType::BiDi => qlog::StreamType::Bidirectional,
+ frame::StreamType::UniDi => qlog::StreamType::Unidirectional,
+ },
+ stream_limit.as_u64().to_string(),
+ ),
+ Frame::NewConnectionId {
+ sequence_number,
+ retire_prior,
+ connection_id,
+ stateless_reset_token,
+ } => QuicFrame::new_connection_id(
+ sequence_number.to_string(),
+ retire_prior.to_string(),
+ connection_id.len() as u64,
+ hex(&connection_id),
+ hex(stateless_reset_token),
+ ),
+ Frame::RetireConnectionId { sequence_number } => {
+ QuicFrame::retire_connection_id(sequence_number.to_string())
+ }
+ Frame::PathChallenge { data } => QuicFrame::path_challenge(Some(hex(data))),
+ Frame::PathResponse { data } => QuicFrame::path_response(Some(hex(data))),
+ Frame::ConnectionClose {
+ error_code,
+ frame_type,
+ reason_phrase,
+ } => QuicFrame::connection_close(
+ match error_code {
+ frame::CloseError::Transport(_) => qlog::ErrorSpace::TransportError,
+ frame::CloseError::Application(_) => qlog::ErrorSpace::ApplicationError,
+ },
+ error_code.code(),
+ 0,
+ String::from_utf8_lossy(&reason_phrase).to_string(),
+ Some(frame_type.to_string()),
+ ),
+ Frame::HandshakeDone => QuicFrame::handshake_done(),
+ }
+}
+
+fn to_qlog_pkt_type(ptype: PacketType) -> qlog::PacketType {
+ match ptype {
+ PacketType::Initial => qlog::PacketType::Initial,
+ PacketType::Handshake => qlog::PacketType::Handshake,
+ PacketType::ZeroRtt => qlog::PacketType::ZeroRtt,
+ PacketType::Short => qlog::PacketType::OneRtt,
+ PacketType::Retry => qlog::PacketType::Retry,
+ PacketType::VersionNegotiation => qlog::PacketType::VersionNegotiation,
+ PacketType::OtherVersion => qlog::PacketType::Unknown,
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs
new file mode 100644
index 0000000000..bac07dc988
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/recovery.rs
@@ -0,0 +1,1470 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of sent packets and detecting their loss.
+
+#![deny(clippy::pedantic)]
+
+use std::cmp::{max, min};
+use std::collections::BTreeMap;
+use std::mem;
+use std::ops::RangeInclusive;
+use std::time::{Duration, Instant};
+
+use smallvec::{smallvec, SmallVec};
+
+use neqo_common::{qdebug, qinfo, qlog::NeqoQlog, qtrace};
+
+use crate::cc::CongestionControlAlgorithm;
+use crate::connection::LOCAL_IDLE_TIMEOUT;
+use crate::crypto::CryptoRecoveryToken;
+use crate::flow_mgr::FlowControlRecoveryToken;
+use crate::packet::PacketNumber;
+use crate::qlog::{self, QlogMetric};
+use crate::send_stream::StreamRecoveryToken;
+use crate::stats::{Stats, StatsCell};
+use crate::tracking::{AckToken, PNSpace, PNSpaceSet, SentPacket};
+use crate::PacketSender;
+
+pub const GRANULARITY: Duration = Duration::from_millis(20);
+/// The default value for the maximum time a peer can delay acknowledgment
+/// of an ack-eliciting packet.
+pub const MAX_ACK_DELAY: Duration = Duration::from_millis(25);
+// Defined in -recovery 6.2 as 333ms but using lower value.
+const INITIAL_RTT: Duration = Duration::from_millis(100);
+pub(crate) const PACKET_THRESHOLD: u64 = 3;
+/// `ACK_ONLY_SIZE_LIMIT` is the minimum size of the congestion window.
+/// If the congestion window is this small, we will only send ACK frames.
+pub(crate) const ACK_ONLY_SIZE_LIMIT: usize = 256;
+/// The number of packets we send on a PTO.
+/// And the number to declare lost when the PTO timer is hit.
+pub const PTO_PACKET_COUNT: usize = 2;
+
+#[derive(Debug, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub enum RecoveryToken {
+ Ack(AckToken),
+ Stream(StreamRecoveryToken),
+ Crypto(CryptoRecoveryToken),
+ Flow(FlowControlRecoveryToken),
+ HandshakeDone,
+ NewToken(usize),
+}
+
+#[derive(Debug)]
+struct RttVals {
+ first_sample_time: Option<Instant>,
+ latest_rtt: Duration,
+ smoothed_rtt: Duration,
+ rttvar: Duration,
+ min_rtt: Duration,
+ max_ack_delay: Duration,
+}
+
+impl RttVals {
+ pub fn set_initial_rtt(&mut self, rtt: Duration) {
+ // Only allow this when there are no samples.
+ debug_assert!(self.first_sample_time.is_none());
+ self.latest_rtt = rtt;
+ self.min_rtt = rtt;
+ self.smoothed_rtt = rtt;
+ self.rttvar = rtt / 2;
+ }
+
+ pub fn set_peer_max_ack_delay(&mut self, mad: Duration) {
+ self.max_ack_delay = mad;
+ }
+
+ fn update_rtt(
+ &mut self,
+ mut qlog: &mut NeqoQlog,
+ mut rtt_sample: Duration,
+ ack_delay: Duration,
+ now: Instant,
+ ) {
+ // min_rtt ignores ack delay.
+ self.min_rtt = min(self.min_rtt, rtt_sample);
+ // Note: the caller adjusts `ack_delay` based on `max_ack_delay`.
+ // Adjust for ack delay unless it goes below `min_rtt`.
+ if rtt_sample - self.min_rtt >= ack_delay {
+ rtt_sample -= ack_delay;
+ }
+
+ if self.first_sample_time.is_none() {
+ self.set_initial_rtt(rtt_sample);
+ self.first_sample_time = Some(now);
+ } else {
+ // Calculate EWMA RTT (based on {{?RFC6298}}).
+ let rttvar_sample = if self.smoothed_rtt > rtt_sample {
+ self.smoothed_rtt - rtt_sample
+ } else {
+ rtt_sample - self.smoothed_rtt
+ };
+
+ self.latest_rtt = rtt_sample;
+ self.rttvar = (self.rttvar * 3 + rttvar_sample) / 4;
+ self.smoothed_rtt = (self.smoothed_rtt * 7 + rtt_sample) / 8;
+ }
+ qtrace!(
+ "RTT latest={:?} -> estimate={:?}~{:?}",
+ self.latest_rtt,
+ self.smoothed_rtt,
+ self.rttvar
+ );
+ qlog::metrics_updated(
+ &mut qlog,
+ &[
+ QlogMetric::LatestRtt(self.latest_rtt),
+ QlogMetric::MinRtt(self.min_rtt),
+ QlogMetric::SmoothedRtt(self.smoothed_rtt),
+ ],
+ );
+ }
+
+ pub fn rtt(&self) -> Duration {
+ self.smoothed_rtt
+ }
+
+ fn pto(&self, pn_space: PNSpace) -> Duration {
+ self.rtt()
+ + max(4 * self.rttvar, GRANULARITY)
+ + if pn_space == PNSpace::ApplicationData {
+ self.max_ack_delay
+ } else {
+ Duration::from_millis(0)
+ }
+ }
+
+ fn first_sample_time(&self) -> Option<Instant> {
+ self.first_sample_time
+ }
+}
+
+impl Default for RttVals {
+ fn default() -> Self {
+ Self {
+ first_sample_time: None,
+ latest_rtt: INITIAL_RTT,
+ smoothed_rtt: INITIAL_RTT,
+ rttvar: INITIAL_RTT / 2,
+ min_rtt: INITIAL_RTT,
+ max_ack_delay: MAX_ACK_DELAY,
+ }
+ }
+}
+
+/// `SendProfile` tells a sender how to send packets.
+#[derive(Debug)]
+pub struct SendProfile {
+ limit: usize,
+ pto: Option<PNSpace>,
+ probe: PNSpaceSet,
+ paced: bool,
+}
+
+impl SendProfile {
+ pub fn new_limited(limit: usize) -> Self {
+ // When the limit is too low, we only send ACK frames.
+ // Set the limit to `ACK_ONLY_SIZE_LIMIT - 1` to ensure that
+ // ACK-only packets are still limited in size.
+ Self {
+ limit: max(ACK_ONLY_SIZE_LIMIT - 1, limit),
+ pto: None,
+ probe: PNSpaceSet::default(),
+ paced: false,
+ }
+ }
+
+ pub fn new_paced() -> Self {
+ // When pacing, we still allow ACK frames to be sent.
+ Self {
+ limit: ACK_ONLY_SIZE_LIMIT - 1,
+ pto: None,
+ probe: PNSpaceSet::default(),
+ paced: true,
+ }
+ }
+
+ pub fn new_pto(pn_space: PNSpace, mtu: usize, probe: PNSpaceSet) -> Self {
+ debug_assert!(mtu > ACK_ONLY_SIZE_LIMIT);
+ debug_assert!(probe[pn_space]);
+ Self {
+ limit: mtu,
+ pto: Some(pn_space),
+ probe,
+ paced: false,
+ }
+ }
+
+ /// Whether probing this space is helpful. This isn't necessarily the space
+ /// that caused the timer to pop, but it is helpful to send a PING in a space
+ /// that has the PTO timer armed.
+ pub fn should_probe(&self, space: PNSpace) -> bool {
+ self.probe[space]
+ }
+
+ /// Determine whether an ACK-only packet should be sent for the given packet
+ /// number space.
+ /// Send only ACKs either: when the space available is too small, or when a PTO
+ /// exists for a later packet number space (which should get the most space).
+ pub fn ack_only(&self, space: PNSpace) -> bool {
+ self.limit < ACK_ONLY_SIZE_LIMIT || self.pto.map_or(false, |sp| space < sp)
+ }
+
+ pub fn paced(&self) -> bool {
+ self.paced
+ }
+
+ pub fn limit(&self) -> usize {
+ self.limit
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct LossRecoverySpace {
+ space: PNSpace,
+ largest_acked: Option<PacketNumber>,
+ largest_acked_sent_time: Option<Instant>,
+ /// The time used to calculate the PTO timer for this space.
+ /// This is the time that the last ACK-eliciting packet in this space
+ /// was sent. This might be the time that a probe was sent.
+ pto_base_time: Option<Instant>,
+ /// The number of outstanding packets in this space that are in flight.
+ /// This might be less than the number of ACK-eliciting packets,
+ /// because PTO packets don't count.
+ in_flight_outstanding: u64,
+ sent_packets: BTreeMap<u64, SentPacket>,
+ /// The time that the first out-of-order packet was sent.
+ /// This is `None` if there were no out-of-order packets detected.
+ /// When set to `Some(T)`, time-based loss detection should be enabled.
+ first_ooo_time: Option<Instant>,
+}
+
+impl LossRecoverySpace {
+ pub fn new(space: PNSpace) -> Self {
+ Self {
+ space,
+ largest_acked: None,
+ largest_acked_sent_time: None,
+ pto_base_time: None,
+ in_flight_outstanding: 0,
+ sent_packets: BTreeMap::default(),
+ first_ooo_time: None,
+ }
+ }
+
+ #[must_use]
+ pub fn space(&self) -> PNSpace {
+ self.space
+ }
+
+ /// Find the time we sent the first packet that is lower than the
+ /// largest acknowledged and that isn't yet declared lost.
+ /// Use the value we prepared earlier in `detect_lost_packets`.
+ #[must_use]
+ pub fn loss_recovery_timer_start(&self) -> Option<Instant> {
+ self.first_ooo_time
+ }
+
+ pub fn in_flight_outstanding(&self) -> bool {
+ self.in_flight_outstanding > 0
+ }
+
+ pub fn pto_packets(&mut self, count: usize) -> impl Iterator<Item = &SentPacket> {
+ self.sent_packets
+ .iter_mut()
+ .filter_map(|(pn, sent)| {
+ if sent.pto() {
+ qtrace!("PTO: marking packet {} lost ", pn);
+ Some(&*sent)
+ } else {
+ None
+ }
+ })
+ .take(count)
+ }
+
+ pub fn pto_base_time(&self) -> Option<Instant> {
+ if self.in_flight_outstanding() {
+ debug_assert!(self.pto_base_time.is_some());
+ self.pto_base_time
+ } else if self.space == PNSpace::ApplicationData {
+ None
+ } else {
+ // Nasty special case to prevent handshake deadlocks.
+ // A client needs to keep the PTO timer armed to prevent a stall
+ // of the handshake. Technically, this has to stop once we receive
+ // an ACK of Handshake or 1-RTT, or when we receive HANDSHAKE_DONE,
+ // but a few extra probes won't hurt.
+ // A server shouldn't arm its PTO timer this way. The server sends
+ // ack-eliciting, in-flight packets immediately so this only
+ // happens when the server has nothing outstanding. If we had
+ // client authentication, this might cause some extra probes,
+ // but they would be harmless anyway.
+ self.pto_base_time
+ }
+ }
+
+ pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
+ if sent_packet.ack_eliciting() {
+ self.pto_base_time = Some(sent_packet.time_sent);
+ self.in_flight_outstanding += 1;
+ } else if self.space != PNSpace::ApplicationData && self.pto_base_time.is_none() {
+ // For Initial and Handshake spaces, make sure that we have a PTO baseline
+ // always. See `LossRecoverySpace::pto_base_time()` for details.
+ self.pto_base_time = Some(sent_packet.time_sent);
+ }
+ self.sent_packets.insert(sent_packet.pn, sent_packet);
+ }
+
+ fn remove_packet(&mut self, p: &SentPacket) {
+ if p.ack_eliciting() {
+ debug_assert!(self.in_flight_outstanding > 0);
+ self.in_flight_outstanding -= 1;
+ if self.in_flight_outstanding == 0 {
+ qtrace!("remove_packet outstanding == 0 for space {}", self.space);
+
+ // See above comments; keep PTO armed for Initial/Handshake even
+ // if no outstanding packets.
+ if self.space == PNSpace::ApplicationData {
+ self.pto_base_time = None;
+ }
+ }
+ }
+ }
+
+ /// Remove all acknowledged packets.
+ /// Returns all the acknowledged packets, with the largest packet number first.
+ /// ...and a boolean indicating if any of those packets were ack-eliciting.
+ /// This operates more efficiently because it assumes that the input is sorted
+ /// in the order that an ACK frame is (from the top).
+ fn remove_acked<R>(&mut self, acked_ranges: R, stats: &mut Stats) -> (Vec<SentPacket>, bool)
+ where
+ R: IntoIterator<Item = RangeInclusive<u64>>,
+ R::IntoIter: ExactSizeIterator,
+ {
+ let acked_ranges = acked_ranges.into_iter();
+ let mut keep = Vec::with_capacity(acked_ranges.len());
+
+ let mut acked = Vec::new();
+ let mut eliciting = false;
+ for range in acked_ranges {
+ let first_keep = *range.end() + 1;
+ if let Some((&first, _)) = self.sent_packets.range(range).next() {
+ let mut tail = self.sent_packets.split_off(&first);
+ if let Some((&next, _)) = tail.range(first_keep..).next() {
+ keep.push(tail.split_off(&next));
+ }
+ for (_, p) in tail.into_iter().rev() {
+ self.remove_packet(&p);
+ eliciting |= p.ack_eliciting();
+ if p.lost() {
+ stats.late_ack += 1;
+ }
+ if p.pto_fired() {
+ stats.pto_ack += 1;
+ }
+ acked.push(p);
+ }
+ }
+ }
+
+ for mut k in keep.into_iter().rev() {
+ self.sent_packets.append(&mut k);
+ }
+
+ (acked, eliciting)
+ }
+
+ /// Remove all tracked packets from the space.
+ /// This is called by a client when 0-RTT packets are dropped, when a Retry is received
+ /// and when keys are dropped.
+ fn remove_ignored(&mut self) -> impl Iterator<Item = SentPacket> {
+ self.in_flight_outstanding = 0;
+ mem::take(&mut self.sent_packets)
+ .into_iter()
+ .map(|(_, v)| v)
+ }
+
+ /// Remove old packets that we've been tracking in case they get acknowledged.
+ /// We try to keep these around until a probe is sent for them, so it is
+ /// important that `cd` is set to at least the current PTO time; otherwise we
+ /// might remove all in-flight packets and stop sending probes.
+ #[allow(clippy::option_if_let_else, clippy::unknown_clippy_lints)] // Hard enough to read as-is.
+ fn remove_old_lost(&mut self, now: Instant, cd: Duration) {
+ let mut it = self.sent_packets.iter();
+ // If the first item is not expired, do nothing.
+ if it.next().map_or(false, |(_, p)| p.expired(now, cd)) {
+ // Find the index of the first unexpired packet.
+ let to_remove = if let Some(first_keep) =
+ it.find_map(|(i, p)| if p.expired(now, cd) { None } else { Some(*i) })
+ {
+ // Some packets haven't expired, so keep those.
+ let keep = self.sent_packets.split_off(&first_keep);
+ mem::replace(&mut self.sent_packets, keep)
+ } else {
+ // All packets are expired.
+ mem::take(&mut self.sent_packets)
+ };
+ for (_, p) in to_remove {
+ self.remove_packet(&p);
+ }
+ }
+ }
+
+ /// Detect lost packets.
+ /// `loss_delay` is the time we will wait before declaring something lost.
+ /// `cleanup_delay` is the time we will wait before cleaning up a lost packet.
+ pub fn detect_lost_packets(
+ &mut self,
+ now: Instant,
+ loss_delay: Duration,
+ cleanup_delay: Duration,
+ lost_packets: &mut Vec<SentPacket>,
+ ) {
+ // Housekeeping.
+ self.remove_old_lost(now, cleanup_delay);
+
+ // Packets sent before this time are deemed lost.
+ let lost_deadline = now - loss_delay;
+ qtrace!(
+ "detect lost {}: now={:?} delay={:?} deadline={:?}",
+ self.space,
+ now,
+ loss_delay,
+ lost_deadline
+ );
+ self.first_ooo_time = None;
+
+ let largest_acked = self.largest_acked;
+
+ // Lost for retrans/CC purposes
+ let mut lost_pns = SmallVec::<[_; 8]>::new();
+
+ for (pn, packet) in self
+ .sent_packets
+ .iter_mut()
+ // BTreeMap iterates in order of ascending PN
+ .take_while(|(&k, _)| Some(k) < largest_acked)
+ {
+ if packet.time_sent <= lost_deadline {
+ qtrace!(
+ "lost={}, time sent {:?} is before lost_deadline {:?}",
+ pn,
+ packet.time_sent,
+ lost_deadline
+ );
+ } else if largest_acked >= Some(*pn + PACKET_THRESHOLD) {
+ qtrace!(
+ "lost={}, is >= {} from largest acked {:?}",
+ pn,
+ PACKET_THRESHOLD,
+ largest_acked
+ );
+ } else {
+ self.first_ooo_time = Some(packet.time_sent);
+ // No more packets can be declared lost after this one.
+ break;
+ };
+
+ if packet.declare_lost(now) {
+ lost_pns.push(*pn);
+ }
+ }
+
+ lost_packets.extend(lost_pns.iter().map(|pn| self.sent_packets[pn].clone()));
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct LossRecoverySpaces {
+ /// When we have all of the loss recovery spaces, this will use a separate
+ /// allocation, but this is reduced once the handshake is done.
+ spaces: SmallVec<[LossRecoverySpace; 1]>,
+}
+
+impl LossRecoverySpaces {
+ fn idx(space: PNSpace) -> usize {
+ match space {
+ PNSpace::ApplicationData => 0,
+ PNSpace::Handshake => 1,
+ PNSpace::Initial => 2,
+ }
+ }
+
+ /// Drop a packet number space and return all the packets that were
+ /// outstanding, so that those can be marked as lost.
+ /// # Panics
+ /// If the space has already been removed.
+ pub fn drop_space(&mut self, space: PNSpace) -> Vec<SentPacket> {
+ let sp = match space {
+ PNSpace::Initial => self.spaces.pop(),
+ PNSpace::Handshake => {
+ let sp = self.spaces.pop();
+ self.spaces.shrink_to_fit();
+ sp
+ }
+ PNSpace::ApplicationData => panic!("discarding application space"),
+ };
+ let mut sp = sp.unwrap();
+ assert_eq!(sp.space(), space, "dropping spaces out of order");
+ sp.remove_ignored().collect()
+ }
+
+ pub fn get(&self, space: PNSpace) -> Option<&LossRecoverySpace> {
+ self.spaces.get(Self::idx(space))
+ }
+
+ pub fn get_mut(&mut self, space: PNSpace) -> Option<&mut LossRecoverySpace> {
+ self.spaces.get_mut(Self::idx(space))
+ }
+
+ fn iter(&self) -> impl Iterator<Item = &LossRecoverySpace> {
+ self.spaces.iter()
+ }
+
+ fn iter_mut(&mut self) -> impl Iterator<Item = &mut LossRecoverySpace> {
+ self.spaces.iter_mut()
+ }
+}
+
+impl Default for LossRecoverySpaces {
+ fn default() -> Self {
+ Self {
+ spaces: smallvec![
+ LossRecoverySpace::new(PNSpace::ApplicationData),
+ LossRecoverySpace::new(PNSpace::Handshake),
+ LossRecoverySpace::new(PNSpace::Initial),
+ ],
+ }
+ }
+}
+
+#[derive(Debug)]
+struct PtoState {
+ /// The packet number space that caused the PTO to fire.
+ space: PNSpace,
+ /// The number of probes that we have sent.
+ count: usize,
+ packets: usize,
+ /// The complete set of packet number spaces that can have probes sent.
+ probe: PNSpaceSet,
+}
+
+impl PtoState {
+ pub fn new(space: PNSpace, probe: PNSpaceSet) -> Self {
+ debug_assert!(probe[space]);
+ Self {
+ space,
+ count: 1,
+ packets: PTO_PACKET_COUNT,
+ probe,
+ }
+ }
+
+ pub fn pto(&mut self, space: PNSpace, probe: PNSpaceSet) {
+ debug_assert!(probe[space]);
+ self.space = space;
+ self.count += 1;
+ self.packets = PTO_PACKET_COUNT;
+ self.probe = probe;
+ }
+
+ pub fn count(&self) -> usize {
+ self.count
+ }
+
+ pub fn count_pto(&self, stats: &mut Stats) {
+ stats.add_pto_count(self.count);
+ }
+
+ /// Generate a sending profile, indicating what space it should be from.
+ /// This takes a packet from the supply or returns an ack-only profile if it can't.
+ pub fn send_profile(&mut self, mtu: usize) -> SendProfile {
+ if self.packets > 0 {
+ self.packets -= 1;
+ SendProfile::new_pto(self.space, mtu, self.probe)
+ } else {
+ SendProfile::new_limited(0)
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct LossRecovery {
+ /// When the handshake was confirmed, if it has been.
+ confirmed_time: Option<Instant>,
+ pto_state: Option<PtoState>,
+ rtt_vals: RttVals,
+ packet_sender: PacketSender,
+ spaces: LossRecoverySpaces,
+ qlog: NeqoQlog,
+ stats: StatsCell,
+}
+
+impl LossRecovery {
+ pub fn new(alg: CongestionControlAlgorithm, stats: StatsCell) -> Self {
+ Self {
+ confirmed_time: None,
+ pto_state: None,
+ rtt_vals: RttVals::default(),
+ packet_sender: PacketSender::new(alg),
+ spaces: LossRecoverySpaces::default(),
+ qlog: NeqoQlog::default(),
+ stats,
+ }
+ }
+
+ #[cfg(test)]
+ pub fn cwnd(&self) -> usize {
+ self.packet_sender.cwnd()
+ }
+
+ pub fn rtt(&self) -> Duration {
+ self.rtt_vals.rtt()
+ }
+
+ pub fn set_initial_rtt(&mut self, rtt: Duration) {
+ self.rtt_vals.set_initial_rtt(rtt)
+ }
+
+ pub fn set_peer_max_ack_delay(&mut self, mad: Duration) {
+ self.rtt_vals.set_peer_max_ack_delay(mad);
+ }
+
+ pub fn cwnd_avail(&self) -> usize {
+ self.packet_sender.cwnd_avail()
+ }
+
+ pub fn largest_acknowledged_pn(&self, pn_space: PNSpace) -> Option<PacketNumber> {
+ self.spaces.get(pn_space).and_then(|sp| sp.largest_acked)
+ }
+
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.packet_sender.set_qlog(qlog.clone());
+ self.qlog = qlog;
+ }
+
+ pub fn drop_0rtt(&mut self) -> Vec<SentPacket> {
+ // The largest acknowledged or loss_time should still be unset.
+ // The client should not have received any ACK frames when it drops 0-RTT.
+ assert!(self
+ .spaces
+ .get(PNSpace::ApplicationData)
+ .unwrap()
+ .largest_acked
+ .is_none());
+ self.spaces
+ .get_mut(PNSpace::ApplicationData)
+ .unwrap()
+ .remove_ignored()
+ .inspect(|p| self.packet_sender.discard(&p))
+ .collect()
+ }
+
+ pub fn on_packet_sent(&mut self, sent_packet: SentPacket) {
+ let pn_space = PNSpace::from(sent_packet.pt);
+ qdebug!([self], "packet {}-{} sent", pn_space, sent_packet.pn);
+ let rtt = self.rtt();
+ if let Some(space) = self.spaces.get_mut(pn_space) {
+ self.packet_sender.on_packet_sent(&sent_packet, rtt);
+ space.on_packet_sent(sent_packet);
+ } else {
+ qinfo!(
+ [self],
+ "ignoring {}-{} from dropped space",
+ pn_space,
+ sent_packet.pn
+ );
+ }
+ }
+
+ /// Record an RTT sample.
+ fn rtt_sample(&mut self, send_time: Instant, now: Instant, ack_delay: Duration) {
+ // Limit ack delay by max_ack_delay if confirmed.
+ let delay = self.confirmed_time.map_or(ack_delay, |confirmed| {
+ if confirmed < send_time {
+ ack_delay
+ } else {
+ min(ack_delay, self.rtt_vals.max_ack_delay)
+ }
+ });
+
+ let sample = now - send_time;
+ self.rtt_vals.update_rtt(&mut self.qlog, sample, delay, now);
+ }
+
+ /// Returns (acked packets, lost packets)
+ pub fn on_ack_received(
+ &mut self,
+ pn_space: PNSpace,
+ largest_acked: u64,
+ acked_ranges: Vec<RangeInclusive<u64>>,
+ ack_delay: Duration,
+ now: Instant,
+ ) -> (Vec<SentPacket>, Vec<SentPacket>) {
+ qdebug!(
+ [self],
+ "ACK for {} - largest_acked={}.",
+ pn_space,
+ largest_acked
+ );
+
+ let space = self
+ .spaces
+ .get_mut(pn_space)
+ .expect("ACK on discarded space");
+ let (acked_packets, any_ack_eliciting) =
+ space.remove_acked(acked_ranges, &mut *self.stats.borrow_mut());
+ if acked_packets.is_empty() {
+ // No new information.
+ return (Vec::new(), Vec::new());
+ }
+
+ // Track largest PN acked per space
+ let prev_largest_acked = space.largest_acked_sent_time;
+ if Some(largest_acked) > space.largest_acked {
+ space.largest_acked = Some(largest_acked);
+
+ // If the largest acknowledged is newly acked and any newly acked
+ // packet was ack-eliciting, update the RTT. (-recovery 5.1)
+ let largest_acked_pkt = acked_packets.first().expect("must be there");
+ space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent);
+ if any_ack_eliciting {
+ self.rtt_sample(largest_acked_pkt.time_sent, now, ack_delay);
+ }
+ }
+
+ // Perform loss detection.
+ // PTO is used to remove lost packets from in-flight accounting.
+ // We need to ensure that we have sent any PTO probes before they are removed
+ // as we rely on the count of in-flight packets to determine whether to send
+ // another probe. Removing them too soon would result in not sending on PTO.
+ let loss_delay = self.loss_delay();
+ let cleanup = self.pto_period(pn_space);
+ let mut lost = Vec::new();
+ self.spaces
+ .get_mut(pn_space)
+ .unwrap()
+ .detect_lost_packets(now, loss_delay, cleanup, &mut lost);
+ self.stats.borrow_mut().lost += lost.len();
+
+ // Tell the congestion controller about any lost packets.
+ // The PTO for congestion control is the raw number, without exponential
+ // backoff, so that we can determine persistent congestion.
+ let pto_raw = self.pto_raw(pn_space);
+ let first_rtt_sample = self.rtt_vals.first_sample_time();
+ self.packet_sender
+ .on_packets_lost(first_rtt_sample, prev_largest_acked, pto_raw, &lost);
+
+ // This must happen after on_packets_lost. If in recovery, this could
+ // take us out, and then lost packets will start a new recovery period
+ // when it shouldn't.
+ self.packet_sender.on_packets_acked(&acked_packets);
+
+ self.pto_state = None;
+
+ (acked_packets, lost)
+ }
+
+ fn loss_delay(&self) -> Duration {
+ // kTimeThreshold = 9/8
+ // loss_delay = kTimeThreshold * max(latest_rtt, smoothed_rtt)
+ // loss_delay = max(loss_delay, kGranularity)
+ let rtt = max(self.rtt_vals.latest_rtt, self.rtt_vals.smoothed_rtt);
+ max(rtt * 9 / 8, GRANULARITY)
+ }
+
+ /// When receiving a retry, get all the sent packets so that they can be flushed.
+ /// We also need to pretend that they never happened for the purposes of congestion control.
+ pub fn retry(&mut self) -> Vec<SentPacket> {
+ self.pto_state = None;
+ let packet_sender = &mut self.packet_sender;
+ self.spaces
+ .iter_mut()
+ .flat_map(LossRecoverySpace::remove_ignored)
+ .inspect(|p| packet_sender.discard(&p))
+ .collect()
+ }
+
+ fn confirmed(&mut self, now: Instant) {
+ debug_assert!(self.confirmed_time.is_none());
+ self.confirmed_time = Some(now);
+ // Up until now, the ApplicationData space has been ignored for PTO.
+ // So maybe fire a PTO.
+ if let Some(pto) = self.pto_time(PNSpace::ApplicationData) {
+ if pto < now {
+ let probes = PNSpaceSet::from(&[PNSpace::ApplicationData]);
+ self.fire_pto(PNSpace::ApplicationData, probes);
+ }
+ }
+ }
+
+ /// Discard state for a given packet number space.
+ pub fn discard(&mut self, space: PNSpace, now: Instant) {
+ qdebug!([self], "Reset loss recovery state for {}", space);
+ for p in self.spaces.drop_space(space) {
+ self.packet_sender.discard(&p);
+ }
+
+ // We just made progress, so discard PTO count.
+ // The spec says that clients should not do this until confirming that
+ // the server has completed address validation, but ignore that.
+ self.pto_state = None;
+
+ if space == PNSpace::Handshake {
+ self.confirmed(now);
+ }
+ }
+
+ /// Calculate when the next timeout is likely to be. This is the earlier of the loss timer
+ /// and the PTO timer; either or both might be disabled, so this can return `None`.
+ pub fn next_timeout(&mut self) -> Option<Instant> {
+ let loss_time = self.earliest_loss_time();
+ let pto_time = self.earliest_pto();
+ qtrace!(
+ [self],
+ "next_timeout loss={:?} pto={:?}",
+ loss_time,
+ pto_time
+ );
+ match (loss_time, pto_time) {
+ (Some(loss_time), Some(pto_time)) => Some(min(loss_time, pto_time)),
+ (Some(loss_time), None) => Some(loss_time),
+ (None, Some(pto_time)) => Some(pto_time),
+ _ => None,
+ }
+ }
+
+ /// Find when the earliest sent packet should be considered lost.
+ fn earliest_loss_time(&self) -> Option<Instant> {
+ self.spaces
+ .iter()
+ .filter_map(LossRecoverySpace::loss_recovery_timer_start)
+ .min()
+ .map(|val| val + self.loss_delay())
+ }
+
+ // The borrow checker is a harsh mistress.
+ // It's important that calls to `RttVals::pto()` are routed through a central point
+ // because that ensures consistency, but we often have a mutable borrow on other
+ // pieces of `self` that prevents that.
+ // An associated function avoids another borrow on `&self`.
+ fn pto_raw_inner(rtt_vals: &RttVals, space: PNSpace) -> Duration {
+ rtt_vals.pto(space)
+ }
+
+ // Borrow checker hack, see above.
+ fn pto_period_inner(
+ rtt_vals: &RttVals,
+ pto_state: &Option<PtoState>,
+ pn_space: PNSpace,
+ ) -> Duration {
+ Self::pto_raw_inner(rtt_vals, pn_space)
+ .checked_mul(1 << pto_state.as_ref().map_or(0, |p| p.count))
+ .unwrap_or(LOCAL_IDLE_TIMEOUT * 2)
+ }
+
+ /// Get the Base PTO value, which is derived only from the `RTT` and `RTTvar` values.
+ /// This is for those cases where you need a value for the time you might sensibly
+ /// wait for a packet to propagate. Using `3*pto_raw(..)` is common.
+ pub fn pto_raw(&self, space: PNSpace) -> Duration {
+ Self::pto_raw_inner(&self.rtt_vals, space)
+ }
+
+ /// Get the current PTO period for the given packet number space.
+ /// Unlike `pto_raw`, this includes calculation for the exponential backoff.
+ fn pto_period(&self, pn_space: PNSpace) -> Duration {
+ Self::pto_period_inner(&self.rtt_vals, &self.pto_state, pn_space)
+ }
+
+ // Calculate PTO time for the given space.
+ fn pto_time(&self, pn_space: PNSpace) -> Option<Instant> {
+ if self.confirmed_time.is_none() && pn_space == PNSpace::ApplicationData {
+ None
+ } else {
+ self.spaces
+ .get(pn_space)
+ .and_then(|space| space.pto_base_time().map(|t| t + self.pto_period(pn_space)))
+ }
+ }
+
+ /// Find the earliest PTO time for all active packet number spaces.
+ /// Ignore Application if either Initial or Handshake have an active PTO.
+ fn earliest_pto(&self) -> Option<Instant> {
+ if self.confirmed_time.is_some() {
+ self.pto_time(PNSpace::ApplicationData)
+ } else {
+ self.pto_time(PNSpace::Initial)
+ .iter()
+ .chain(self.pto_time(PNSpace::Handshake).iter())
+ .min()
+ .cloned()
+ }
+ }
+
+ fn fire_pto(&mut self, pn_space: PNSpace, allow_probes: PNSpaceSet) {
+ if let Some(st) = &mut self.pto_state {
+ st.pto(pn_space, allow_probes);
+ } else {
+ self.pto_state = Some(PtoState::new(pn_space, allow_probes));
+ }
+
+ self.pto_state
+ .as_mut()
+ .unwrap()
+ .count_pto(&mut *self.stats.borrow_mut());
+
+ qlog::metrics_updated(
+ &mut self.qlog,
+ &[QlogMetric::PtoCount(
+ self.pto_state.as_ref().unwrap().count(),
+ )],
+ );
+ }
+
+ /// This checks whether the PTO timer has fired and fires it if needed.
+ /// When it has, mark a few packets as "lost" for the purposes of having frames
+ /// regenerated in subsequent packets. The packets aren't truly lost, so
+ /// we have to clone the `SentPacket` instance.
+ fn maybe_fire_pto(&mut self, now: Instant, lost: &mut Vec<SentPacket>) {
+ let mut pto_space = None;
+ // The spaces in which we will allow probing.
+ let mut allow_probes = PNSpaceSet::default();
+ for pn_space in PNSpace::iter() {
+ if let Some(t) = self.pto_time(*pn_space) {
+ allow_probes[*pn_space] = true;
+ if t <= now {
+ qdebug!([self], "PTO timer fired for {}", pn_space);
+ let space = self.spaces.get_mut(*pn_space).unwrap();
+ lost.extend(space.pto_packets(PTO_PACKET_COUNT).cloned());
+
+ pto_space = pto_space.or(Some(*pn_space));
+ }
+ }
+ }
+
+ // This has to happen outside the loop. Increasing the PTO count here causes the
+ // pto_time to increase which might cause PTO for later packet number spaces to not fire.
+ if let Some(pn_space) = pto_space {
+ qtrace!([self], "PTO {}, probing {:?}", pn_space, allow_probes);
+ self.fire_pto(pn_space, allow_probes);
+ }
+ }
+
+ pub fn timeout(&mut self, now: Instant) -> Vec<SentPacket> {
+ qtrace!([self], "timeout {:?}", now);
+
+ let loss_delay = self.loss_delay();
+ let first_rtt_sample = self.rtt_vals.first_sample_time();
+
+ let mut lost_packets = Vec::new();
+ for space in self.spaces.iter_mut() {
+ let first = lost_packets.len(); // The first packet lost in this space.
+ let pto = Self::pto_period_inner(&self.rtt_vals, &self.pto_state, space.space());
+ space.detect_lost_packets(now, loss_delay, pto, &mut lost_packets);
+ self.packet_sender.on_packets_lost(
+ first_rtt_sample,
+ space.largest_acked_sent_time,
+ Self::pto_raw_inner(&self.rtt_vals, space.space()),
+ &lost_packets[first..],
+ );
+ }
+ self.stats.borrow_mut().lost += lost_packets.len();
+
+ self.maybe_fire_pto(now, &mut lost_packets);
+ lost_packets
+ }
+
+ /// Start the packet pacer.
+ pub fn start_pacer(&mut self, now: Instant) {
+ self.packet_sender.start_pacer(now);
+ }
+
+ /// Get the next time that a paced packet might be sent.
+ pub fn next_paced(&self) -> Option<Instant> {
+ self.packet_sender.next_paced(self.rtt())
+ }
+
+ /// Check how packets should be sent, based on whether there is a PTO,
+ /// what the current congestion window is, and what the pacer says.
+ #[allow(clippy::option_if_let_else, clippy::unknown_clippy_lints)]
+ pub fn send_profile(&mut self, now: Instant, mtu: usize) -> SendProfile {
+ qdebug!([self], "get send profile {:?}", now);
+ if let Some(pto) = self.pto_state.as_mut() {
+ pto.send_profile(mtu)
+ } else {
+ let cwnd = self.cwnd_avail();
+ if cwnd > mtu {
+ // More than an MTU available; we might need to pace.
+ if self.next_paced().map_or(false, |t| t > now) {
+ SendProfile::new_paced()
+ } else {
+ SendProfile::new_limited(mtu)
+ }
+ } else if self.packet_sender.recovery_packet() {
+ // After entering recovery, allow a packet to be sent immediately.
+ // This uses the PTO machinery, probing in all spaces. This will
+ // result in a PING being sent in every active space.
+ SendProfile::new_pto(PNSpace::Initial, mtu, PNSpaceSet::all())
+ } else {
+ SendProfile::new_limited(cwnd)
+ }
+ }
+ }
+}
+
+impl ::std::fmt::Display for LossRecovery {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "LossRecovery")
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{
+ CongestionControlAlgorithm, LossRecovery, LossRecoverySpace, PNSpace, SentPacket,
+ INITIAL_RTT, MAX_ACK_DELAY,
+ };
+ use crate::packet::PacketType;
+ use crate::stats::{Stats, StatsCell};
+ use std::convert::TryInto;
+ use std::time::{Duration, Instant};
+ use test_fixture::now;
+
+ const ON_SENT_SIZE: usize = 100;
+
+ fn assert_rtts(
+ lr: &LossRecovery,
+ latest_rtt: Duration,
+ smoothed_rtt: Duration,
+ rttvar: Duration,
+ min_rtt: Duration,
+ ) {
+ println!(
+ "rtts: {:?} {:?} {:?} {:?}",
+ lr.rtt_vals.latest_rtt,
+ lr.rtt_vals.smoothed_rtt,
+ lr.rtt_vals.rttvar,
+ lr.rtt_vals.min_rtt,
+ );
+ assert_eq!(lr.rtt_vals.latest_rtt, latest_rtt, "latest RTT");
+ assert_eq!(lr.rtt_vals.smoothed_rtt, smoothed_rtt, "smoothed RTT");
+ assert_eq!(lr.rtt_vals.rttvar, rttvar, "RTT variance");
+ assert_eq!(lr.rtt_vals.min_rtt, min_rtt, "min RTT");
+ }
+
+ fn assert_sent_times(
+ lr: &LossRecovery,
+ initial: Option<Instant>,
+ handshake: Option<Instant>,
+ app_data: Option<Instant>,
+ ) {
+ let est = |sp| {
+ lr.spaces
+ .get(sp)
+ .and_then(LossRecoverySpace::loss_recovery_timer_start)
+ };
+ println!(
+ "loss times: {:?} {:?} {:?}",
+ est(PNSpace::Initial),
+ est(PNSpace::Handshake),
+ est(PNSpace::ApplicationData),
+ );
+ assert_eq!(est(PNSpace::Initial), initial, "Initial earliest sent time");
+ assert_eq!(
+ est(PNSpace::Handshake),
+ handshake,
+ "Handshake earliest sent time"
+ );
+ assert_eq!(
+ est(PNSpace::ApplicationData),
+ app_data,
+ "AppData earliest sent time"
+ );
+ }
+
+ fn assert_no_sent_times(lr: &LossRecovery) {
+ assert_sent_times(lr, None, None, None);
+ }
+
+ // Time in milliseconds.
+ macro_rules! ms {
+ ($t:expr) => {
+ Duration::from_millis($t)
+ };
+ }
+
+ // In most of the tests below, packets are sent at a fixed cadence, with PACING between each.
+ const PACING: Duration = ms!(7);
+ fn pn_time(pn: u64) -> Instant {
+ now() + (PACING * pn.try_into().unwrap())
+ }
+
+ fn pace(lr: &mut LossRecovery, count: u64) {
+ for pn in 0..count {
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ pn,
+ pn_time(pn),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ }
+ }
+
+ const ACK_DELAY: Duration = ms!(24);
+ /// Acknowledge PN with the identified delay.
+ fn ack(lr: &mut LossRecovery, pn: u64, delay: Duration) {
+ lr.on_ack_received(
+ PNSpace::ApplicationData,
+ pn,
+ vec![pn..=pn],
+ ACK_DELAY,
+ pn_time(pn) + delay,
+ );
+ }
+
+ fn add_sent(lrs: &mut LossRecoverySpace, packet_numbers: &[u64]) {
+ for &pn in packet_numbers {
+ lrs.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ pn,
+ pn_time(pn),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ }
+ }
+
+ fn match_acked(acked: &[SentPacket], expected: &[u64]) {
+ assert!(acked.iter().map(|p| &p.pn).eq(expected));
+ }
+
+ #[test]
+ fn remove_acked() {
+ let mut lrs = LossRecoverySpace::new(PNSpace::ApplicationData);
+ let mut stats = Stats::default();
+ add_sent(&mut lrs, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+ let (acked, _) = lrs.remove_acked(vec![], &mut stats);
+ assert!(acked.is_empty());
+ let (acked, _) = lrs.remove_acked(vec![7..=8, 2..=4], &mut stats);
+ match_acked(&acked, &[8, 7, 4, 3, 2]);
+ let (acked, _) = lrs.remove_acked(vec![8..=11], &mut stats);
+ match_acked(&acked, &[10, 9]);
+ let (acked, _) = lrs.remove_acked(vec![0..=2], &mut stats);
+ match_acked(&acked, &[1]);
+ let (acked, _) = lrs.remove_acked(vec![5..=6], &mut stats);
+ match_acked(&acked, &[6, 5]);
+ }
+
+ #[test]
+ fn initial_rtt() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.start_pacer(now());
+ pace(&mut lr, 1);
+ let rtt = ms!(100);
+ ack(&mut lr, 0, rtt);
+ assert_rtts(&lr, rtt, rtt, rtt / 2, rtt);
+ assert_no_sent_times(&lr);
+ }
+
+ /// An initial RTT for using with `setup_lr`.
+ const TEST_RTT: Duration = ms!(80);
+ const TEST_RTTVAR: Duration = ms!(40);
+
+ /// Send `n` packets (using PACING), then acknowledge the first.
+ fn setup_lr(n: u64) -> LossRecovery {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.start_pacer(now());
+ pace(&mut lr, n);
+ ack(&mut lr, 0, TEST_RTT);
+ assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT);
+ assert_no_sent_times(&lr);
+ lr
+ }
+
+ // The ack delay is removed from any RTT estimate.
+ #[test]
+ fn ack_delay_adjusted() {
+ let mut lr = setup_lr(2);
+ ack(&mut lr, 1, TEST_RTT + ACK_DELAY);
+ // RTT stays the same, but the RTTVAR is adjusted downwards.
+ assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR * 3 / 4, TEST_RTT);
+ assert_no_sent_times(&lr);
+ }
+
+ // The ack delay is ignored when it would cause a sample to be less than min_rtt.
+ #[test]
+ fn ack_delay_ignored() {
+ let mut lr = setup_lr(2);
+ let extra = ms!(8);
+ assert!(extra < ACK_DELAY);
+ ack(&mut lr, 1, TEST_RTT + extra);
+ let expected_rtt = TEST_RTT + (extra / 8);
+ let expected_rttvar = (TEST_RTTVAR * 3 + extra) / 4;
+ assert_rtts(
+ &lr,
+ TEST_RTT + extra,
+ expected_rtt,
+ expected_rttvar,
+ TEST_RTT,
+ );
+ assert_no_sent_times(&lr);
+ }
+
+ // A lower observed RTT is used as min_rtt (and ack delay is ignored).
+ #[test]
+ fn reduce_min_rtt() {
+ let mut lr = setup_lr(2);
+ let delta = ms!(4);
+ let reduced_rtt = TEST_RTT - delta;
+ ack(&mut lr, 1, reduced_rtt);
+ let expected_rtt = TEST_RTT - (delta / 8);
+ let expected_rttvar = (TEST_RTTVAR * 3 + delta) / 4;
+ assert_rtts(&lr, reduced_rtt, expected_rtt, expected_rttvar, reduced_rtt);
+ assert_no_sent_times(&lr);
+ }
+
+ // Acknowledging something again has no effect.
+ #[test]
+ fn no_new_acks() {
+ let mut lr = setup_lr(1);
+ let check = |lr: &LossRecovery| {
+ assert_rtts(&lr, TEST_RTT, TEST_RTT, TEST_RTTVAR, TEST_RTT);
+ assert_no_sent_times(&lr);
+ };
+ check(&lr);
+
+ ack(&mut lr, 0, ms!(1339)); // much delayed ACK
+ check(&lr);
+
+ ack(&mut lr, 0, ms!(3)); // time travel!
+ check(&lr);
+ }
+
+ // Test time loss detection as part of handling a regular ACK.
+ #[test]
+ fn time_loss_detection_gap() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.start_pacer(now());
+ // Create a single packet gap, and have pn 0 time out.
+ // This can't use the default pacing, which is too tight.
+ // So send two packets with 1/4 RTT between them. Acknowledge pn 1 after 1 RTT.
+ // pn 0 should then be marked lost because it is then outstanding for 5RTT/4
+ // the loss time for packets is 9RTT/8.
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 0,
+ pn_time(0),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 1,
+ pn_time(0) + TEST_RTT / 4,
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ let (_, lost) = lr.on_ack_received(
+ PNSpace::ApplicationData,
+ 1,
+ vec![1..=1],
+ ACK_DELAY,
+ pn_time(0) + (TEST_RTT * 5 / 4),
+ );
+ assert_eq!(lost.len(), 1);
+ assert_no_sent_times(&lr);
+ }
+
+ // Test time loss detection as part of an explicit timeout.
+ #[test]
+ fn time_loss_detection_timeout() {
+ let mut lr = setup_lr(3);
+
+ // We want to declare PN 2 as acknowledged before we declare PN 1 as lost.
+ // For this to work, we need PACING above to be less than 1/8 of an RTT.
+ let pn1_sent_time = pn_time(1);
+ let pn1_loss_time = pn1_sent_time + (TEST_RTT * 9 / 8);
+ let pn2_ack_time = pn_time(2) + TEST_RTT;
+ assert!(pn1_loss_time > pn2_ack_time);
+
+ let (_, lost) = lr.on_ack_received(
+ PNSpace::ApplicationData,
+ 2,
+ vec![2..=2],
+ ACK_DELAY,
+ pn2_ack_time,
+ );
+ assert!(lost.is_empty());
+ // Run the timeout function here to force time-based loss recovery to be enabled.
+ let lost = lr.timeout(pn2_ack_time);
+ assert!(lost.is_empty());
+ assert_sent_times(&lr, None, None, Some(pn1_sent_time));
+
+ // After time elapses, pn 1 is marked lost.
+ let callback_time = lr.next_timeout();
+ assert_eq!(callback_time, Some(pn1_loss_time));
+ let packets = lr.timeout(pn1_loss_time);
+ assert_eq!(packets.len(), 1);
+ // Checking for expiration with zero delay lets us check the loss time.
+ assert!(packets[0].expired(pn1_loss_time, Duration::new(0, 0)));
+ assert_no_sent_times(&lr);
+ }
+
+ #[test]
+ fn big_gap_loss() {
+ let mut lr = setup_lr(5); // This sends packets 0-4 and acknowledges pn 0.
+ // Acknowledge just 2-4, which will cause pn 1 to be marked as lost.
+ assert_eq!(super::PACKET_THRESHOLD, 3);
+ let (_, lost) = lr.on_ack_received(
+ PNSpace::ApplicationData,
+ 4,
+ vec![2..=4],
+ ACK_DELAY,
+ pn_time(4),
+ );
+ assert_eq!(lost.len(), 1);
+ }
+
+ #[test]
+ #[should_panic(expected = "discarding application space")]
+ fn drop_app() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.discard(PNSpace::ApplicationData, now());
+ }
+
+ #[test]
+ #[should_panic(expected = "dropping spaces out of order")]
+ fn drop_out_of_order() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.discard(PNSpace::Handshake, now());
+ }
+
+ #[test]
+ #[should_panic(expected = "ACK on discarded space")]
+ fn ack_after_drop() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.start_pacer(now());
+ lr.discard(PNSpace::Initial, now());
+ lr.on_ack_received(
+ PNSpace::Initial,
+ 0,
+ vec![],
+ Duration::from_millis(0),
+ pn_time(0),
+ );
+ }
+
+ #[test]
+ fn drop_spaces() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.start_pacer(now());
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Initial,
+ 0,
+ pn_time(0),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Handshake,
+ 0,
+ pn_time(1),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 0,
+ pn_time(2),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+
+ // Now put all spaces on the LR timer so we can see them.
+ for sp in &[
+ PacketType::Initial,
+ PacketType::Handshake,
+ PacketType::Short,
+ ] {
+ let sent_pkt = SentPacket::new(*sp, 1, pn_time(3), true, Vec::new(), ON_SENT_SIZE);
+ let pn_space = PNSpace::from(sent_pkt.pt);
+ lr.on_packet_sent(sent_pkt);
+ lr.on_ack_received(pn_space, 1, vec![1..=1], Duration::from_secs(0), pn_time(3));
+ let mut lost = Vec::new();
+ lr.spaces.get_mut(pn_space).unwrap().detect_lost_packets(
+ pn_time(3),
+ TEST_RTT,
+ TEST_RTT * 3, // unused
+ &mut lost,
+ );
+ assert!(lost.is_empty());
+ }
+
+ lr.discard(PNSpace::Initial, pn_time(3));
+ assert_sent_times(&lr, None, Some(pn_time(1)), Some(pn_time(2)));
+
+ lr.discard(PNSpace::Handshake, pn_time(3));
+ assert_sent_times(&lr, None, None, Some(pn_time(2)));
+
+ // There are cases where we send a packet that is not subsequently tracked.
+ // So check that this works.
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Initial,
+ 0,
+ pn_time(3),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ assert_sent_times(&lr, None, None, Some(pn_time(2)));
+ }
+
+ #[test]
+ fn rearm_pto_after_confirmed() {
+ let mut lr = LossRecovery::new(CongestionControlAlgorithm::NewReno, StatsCell::default());
+ lr.start_pacer(now());
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Handshake,
+ 0,
+ now(),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+ lr.on_packet_sent(SentPacket::new(
+ PacketType::Short,
+ 0,
+ now(),
+ true,
+ Vec::new(),
+ ON_SENT_SIZE,
+ ));
+
+ assert_eq!(lr.pto_time(PNSpace::ApplicationData), None);
+ lr.discard(PNSpace::Initial, pn_time(1));
+ assert_eq!(lr.pto_time(PNSpace::ApplicationData), None);
+
+ // Expiring state after the PTO on the ApplicationData space has
+ // expired should result in setting a PTO state.
+ let expected_pto = pn_time(2) + (INITIAL_RTT * 3) + MAX_ACK_DELAY;
+ lr.discard(PNSpace::Handshake, expected_pto);
+ let profile = lr.send_profile(expected_pto, 10000);
+ assert!(profile.pto.is_some());
+ assert!(!profile.should_probe(PNSpace::Initial));
+ assert!(!profile.should_probe(PNSpace::Handshake));
+ assert!(profile.should_probe(PNSpace::ApplicationData));
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/recv_stream.rs b/third_party/rust/neqo-transport/src/recv_stream.rs
new file mode 100644
index 0000000000..e82e08d2ec
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/recv_stream.rs
@@ -0,0 +1,1110 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Building a stream of ordered bytes to give the application from a series of
+// incoming STREAM frames.
+
+use std::cell::RefCell;
+use std::cmp::max;
+use std::collections::BTreeMap;
+use std::convert::TryFrom;
+use std::mem;
+use std::ops::Bound::{Included, Unbounded};
+use std::rc::Rc;
+
+use smallvec::SmallVec;
+
+use crate::events::ConnectionEvents;
+use crate::flow_mgr::FlowMgr;
+use crate::stream_id::StreamId;
+use crate::{AppError, Error, Res};
+use neqo_common::qtrace;
+
+const RX_STREAM_DATA_WINDOW: u64 = 0x10_0000; // 1MiB
+
+// Export as usize for consistency with SEND_BUFFER_SIZE
+pub const RECV_BUFFER_SIZE: usize = RX_STREAM_DATA_WINDOW as usize;
+
+pub(crate) type RecvStreams = BTreeMap<StreamId, RecvStream>;
+
+/// Holds data not yet read by application. Orders and dedupes data ranges
+/// from incoming STREAM frames.
+#[derive(Debug, Default)]
+pub struct RxStreamOrderer {
+ data_ranges: BTreeMap<u64, Vec<u8>>, // (start_offset, data)
+ retired: u64, // Number of bytes the application has read
+}
+
+impl RxStreamOrderer {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Process an incoming stream frame off the wire. This may result in data
+ /// being available to upper layers if frame is not out of order (ooo) or
+ /// if the frame fills a gap.
+ pub fn inbound_frame(&mut self, mut new_start: u64, mut new_data: &[u8]) {
+ qtrace!("Inbound data offset={} len={}", new_start, new_data.len());
+
+ // Get entry before where new entry would go, so we can see if we already
+ // have the new bytes.
+ // Avoid copies and duplicated data.
+ let new_end = new_start + u64::try_from(new_data.len()).unwrap();
+
+ if new_end <= self.retired {
+ // Range already read by application, this frame is very late and unneeded.
+ return;
+ }
+
+ if new_start < self.retired {
+ new_data = &new_data[usize::try_from(self.retired - new_start).unwrap()..];
+ new_start = self.retired;
+ }
+
+ if new_data.is_empty() {
+ // No data to insert
+ return;
+ }
+
+ let extend = if let Some((&prev_start, prev_vec)) = self
+ .data_ranges
+ .range_mut((Unbounded, Included(new_start)))
+ .next_back()
+ {
+ let prev_end = prev_start + u64::try_from(prev_vec.len()).unwrap();
+ if new_end > prev_end {
+ // PPPPPP -> PPPPPP
+ // NNNNNN NN
+ // NNNNNNNN NN
+ // Add a range containing only new data
+ // (In-order frames will take this path, with no overlap)
+ let overlap = prev_end.saturating_sub(new_start);
+ qtrace!(
+ "New frame {}-{} received, overlap: {}",
+ new_start,
+ new_end,
+ overlap
+ );
+ new_start += overlap;
+ new_data = &new_data[usize::try_from(overlap).unwrap()..];
+ // If it is small enough, extend the previous buffer.
+ // This can't always extend, because otherwise the buffer could end up
+ // growing indefinitely without being released.
+ prev_vec.len() < 4096 && prev_end == new_start
+ } else {
+ // PPPPPP -> PPPPPP
+ // NNNN
+ // NNNN
+ // Do nothing
+ qtrace!(
+ "Dropping frame with already-received range {}-{}",
+ new_start,
+ new_end
+ );
+ return;
+ }
+ } else {
+ qtrace!("New frame {}-{} received", new_start, new_end);
+ false
+ };
+
+ // Now handle possible overlap with next entries
+ let mut to_remove = SmallVec::<[_; 8]>::new();
+ let mut to_add = new_data;
+
+ for (&next_start, next_data) in self.data_ranges.range_mut(new_start..) {
+ let next_end = next_start + u64::try_from(next_data.len()).unwrap();
+ let overlap = new_end.saturating_sub(next_start);
+ if overlap == 0 {
+ break;
+ } else if next_end >= new_end {
+ qtrace!(
+ "New frame {}-{} overlaps with next frame by {}, truncating",
+ new_start,
+ new_end,
+ overlap
+ );
+ let truncate_to = new_data.len() - usize::try_from(overlap).unwrap();
+ to_add = &new_data[..truncate_to];
+ break;
+ } else {
+ qtrace!(
+ "New frame {}-{} spans entire next frame {}-{}, replacing",
+ new_start,
+ new_end,
+ next_start,
+ next_end
+ );
+ to_remove.push(next_start);
+ }
+ }
+
+ for start in to_remove {
+ self.data_ranges.remove(&start);
+ }
+
+ if !to_add.is_empty() {
+ if extend {
+ let (_, buf) = self
+ .data_ranges
+ .range_mut((Unbounded, Included(new_start)))
+ .next_back()
+ .unwrap();
+ buf.extend_from_slice(to_add);
+ } else {
+ self.data_ranges.insert(new_start, to_add.to_vec());
+ }
+ }
+ }
+
+ /// Are any bytes readable?
+ pub fn data_ready(&self) -> bool {
+ self.data_ranges
+ .keys()
+ .next()
+ .map_or(false, |&start| start <= self.retired)
+ }
+
+ /// How many bytes are readable?
+ fn bytes_ready(&self) -> usize {
+ let mut prev_end = self.retired;
+ self.data_ranges
+ .iter()
+ .map(|(start_offset, data)| {
+ // All ranges don't overlap but we could have partially
+ // retired some of the first entry's data.
+ let data_len = data.len() as u64 - self.retired.saturating_sub(*start_offset);
+ (start_offset, data_len)
+ })
+ .take_while(|(start_offset, data_len)| {
+ if **start_offset <= prev_end {
+ prev_end += data_len;
+ true
+ } else {
+ false
+ }
+ })
+ .map(|(_, data_len)| data_len as usize)
+ .sum()
+ }
+
+ /// Bytes read by the application.
+ fn retired(&self) -> u64 {
+ self.retired
+ }
+
+ /// Data bytes buffered. Could be more than bytes_readable if there are
+ /// ranges missing.
+ fn buffered(&self) -> u64 {
+ self.data_ranges
+ .iter()
+ .map(|(&start, data)| data.len() as u64 - (self.retired.saturating_sub(start)))
+ .sum()
+ }
+
+ /// Copy received data (if any) into the buffer. Returns bytes copied.
+ fn read(&mut self, buf: &mut [u8]) -> usize {
+ qtrace!("Reading {} bytes, {} available", buf.len(), self.buffered());
+ let mut copied = 0;
+
+ for (&range_start, range_data) in &mut self.data_ranges {
+ let mut keep = false;
+ if self.retired >= range_start {
+ // Frame data has new contiguous bytes.
+ let copy_offset =
+ usize::try_from(max(range_start, self.retired) - range_start).unwrap();
+ assert!(range_data.len() >= copy_offset);
+ let available = range_data.len() - copy_offset;
+ let space = buf.len() - copied;
+ let copy_bytes = if available > space {
+ keep = true;
+ space
+ } else {
+ available
+ };
+
+ if copy_bytes > 0 {
+ let copy_slc = &range_data[copy_offset..copy_offset + copy_bytes];
+ buf[copied..copied + copy_bytes].copy_from_slice(copy_slc);
+ copied += copy_bytes;
+ self.retired += u64::try_from(copy_bytes).unwrap();
+ }
+ } else {
+ // The data in the buffer isn't contiguous.
+ keep = true;
+ }
+ if keep {
+ let mut keep = self.data_ranges.split_off(&range_start);
+ mem::swap(&mut self.data_ranges, &mut keep);
+ return copied;
+ }
+ }
+
+ self.data_ranges.clear();
+ copied
+ }
+
+ /// Extend the given Vector with any available data.
+ pub fn read_to_end(&mut self, buf: &mut Vec<u8>) -> usize {
+ let orig_len = buf.len();
+ buf.resize(orig_len + self.bytes_ready(), 0);
+ self.read(&mut buf[orig_len..])
+ }
+
+ fn highest_seen_offset(&self) -> u64 {
+ let maybe_ooo_last = self
+ .data_ranges
+ .iter()
+ .next_back()
+ .map(|(start, data)| *start + data.len() as u64);
+ maybe_ooo_last.unwrap_or(self.retired)
+ }
+}
+
+/// QUIC receiving states, based on -transport 3.2.
+#[derive(Debug)]
+#[allow(dead_code)]
+// Because a dead_code warning is easier than clippy::unused_self, see https://github.com/rust-lang/rust/issues/68408
+enum RecvStreamState {
+ Recv {
+ recv_buf: RxStreamOrderer,
+ max_bytes: u64, // Maximum size of recv_buf
+ max_stream_data: u64,
+ },
+ SizeKnown {
+ recv_buf: RxStreamOrderer,
+ final_size: u64,
+ },
+ DataRecvd {
+ recv_buf: RxStreamOrderer,
+ },
+ DataRead,
+ ResetRecvd,
+ // Defined by spec but we don't use it: ResetRead
+}
+
+impl RecvStreamState {
+ fn new(max_bytes: u64) -> Self {
+ Self::Recv {
+ recv_buf: RxStreamOrderer::new(),
+ max_bytes,
+ max_stream_data: max_bytes,
+ }
+ }
+
+ fn name(&self) -> &str {
+ match self {
+ Self::Recv { .. } => "Recv",
+ Self::SizeKnown { .. } => "SizeKnown",
+ Self::DataRecvd { .. } => "DataRecvd",
+ Self::DataRead => "DataRead",
+ Self::ResetRecvd => "ResetRecvd",
+ }
+ }
+
+ fn recv_buf(&self) -> Option<&RxStreamOrderer> {
+ match self {
+ Self::Recv { recv_buf, .. }
+ | Self::SizeKnown { recv_buf, .. }
+ | Self::DataRecvd { recv_buf } => Some(recv_buf),
+ Self::DataRead | Self::ResetRecvd => None,
+ }
+ }
+
+ fn final_size(&self) -> Option<u64> {
+ match self {
+ Self::SizeKnown { final_size, .. } => Some(*final_size),
+ _ => None,
+ }
+ }
+
+ fn max_stream_data(&self) -> Option<u64> {
+ match self {
+ Self::Recv {
+ max_stream_data, ..
+ } => Some(*max_stream_data),
+ _ => None,
+ }
+ }
+}
+
+/// Implement a QUIC receive stream.
+#[derive(Debug)]
+pub struct RecvStream {
+ stream_id: StreamId,
+ state: RecvStreamState,
+ flow_mgr: Rc<RefCell<FlowMgr>>,
+ conn_events: ConnectionEvents,
+}
+
+impl RecvStream {
+ pub fn new(
+ stream_id: StreamId,
+ max_stream_data: u64,
+ flow_mgr: Rc<RefCell<FlowMgr>>,
+ conn_events: ConnectionEvents,
+ ) -> Self {
+ Self {
+ stream_id,
+ state: RecvStreamState::new(max_stream_data),
+ flow_mgr,
+ conn_events,
+ }
+ }
+
+ fn set_state(&mut self, new_state: RecvStreamState) {
+ debug_assert_ne!(
+ mem::discriminant(&self.state),
+ mem::discriminant(&new_state)
+ );
+ qtrace!(
+ "RecvStream {} state {} -> {}",
+ self.stream_id.as_u64(),
+ self.state.name(),
+ new_state.name()
+ );
+
+ if let RecvStreamState::Recv { .. } = &self.state {
+ self.flow_mgr
+ .borrow_mut()
+ .clear_max_stream_data(self.stream_id)
+ }
+
+ if let RecvStreamState::DataRead = new_state {
+ self.conn_events.recv_stream_complete(self.stream_id);
+ }
+
+ self.state = new_state;
+ }
+
+ pub fn inbound_stream_frame(&mut self, fin: bool, offset: u64, data: &[u8]) -> Res<()> {
+ // We should post a DataReadable event only once when we change from no-data-ready to
+ // data-ready. Therefore remember the state before processing a new frame.
+ let already_data_ready = self.data_ready();
+ let new_end = offset + u64::try_from(data.len()).unwrap();
+
+ // Send final size errors even if stream is closed
+ if let Some(final_size) = self.state.final_size() {
+ if new_end > final_size || (fin && new_end != final_size) {
+ return Err(Error::FinalSizeError);
+ }
+ }
+
+ match &mut self.state {
+ RecvStreamState::Recv {
+ recv_buf,
+ max_stream_data,
+ ..
+ } => {
+ if new_end > *max_stream_data {
+ qtrace!("Stream RX window {} exceeded: {}", max_stream_data, new_end);
+ return Err(Error::FlowControlError);
+ }
+
+ if fin {
+ let final_size = offset + data.len() as u64;
+ if final_size < recv_buf.highest_seen_offset() {
+ return Err(Error::FinalSizeError);
+ }
+ recv_buf.inbound_frame(offset, data);
+
+ let buf = mem::replace(recv_buf, RxStreamOrderer::new());
+ if final_size == buf.retired() + buf.bytes_ready() as u64 {
+ self.set_state(RecvStreamState::DataRecvd { recv_buf: buf });
+ } else {
+ self.set_state(RecvStreamState::SizeKnown {
+ recv_buf: buf,
+ final_size,
+ });
+ }
+ } else {
+ recv_buf.inbound_frame(offset, data);
+ }
+ }
+ RecvStreamState::SizeKnown {
+ recv_buf,
+ final_size,
+ } => {
+ recv_buf.inbound_frame(offset, data);
+ if *final_size == recv_buf.retired() + recv_buf.bytes_ready() as u64 {
+ let buf = mem::replace(recv_buf, RxStreamOrderer::new());
+ self.set_state(RecvStreamState::DataRecvd { recv_buf: buf });
+ }
+ }
+ RecvStreamState::DataRecvd { .. }
+ | RecvStreamState::DataRead
+ | RecvStreamState::ResetRecvd => {
+ qtrace!("data received when we are in state {}", self.state.name())
+ }
+ }
+
+ if !already_data_ready && (self.data_ready() || self.needs_to_inform_app_about_fin()) {
+ self.conn_events.recv_stream_readable(self.stream_id)
+ }
+
+ Ok(())
+ }
+
+ pub fn reset(&mut self, application_error_code: AppError) {
+ match self.state {
+ RecvStreamState::Recv { .. } | RecvStreamState::SizeKnown { .. } => {
+ self.conn_events
+ .recv_stream_reset(self.stream_id, application_error_code);
+ self.set_state(RecvStreamState::ResetRecvd);
+ }
+ _ => {
+ // Ignore reset if in DataRecvd, DataRead, or ResetRecvd
+ }
+ }
+ }
+
+ /// If we should tell the sender they have more credit, return an offset
+ pub fn maybe_send_flowc_update(&mut self) {
+ // Only ever needed if actively receiving and not in SizeKnown state
+ if let RecvStreamState::Recv {
+ max_bytes,
+ max_stream_data,
+ recv_buf,
+ } = &mut self.state
+ {
+ // Algo: send an update if app has consumed more than half
+ // the data in the current window
+ // TODO(agrover@mozilla.com): This algo is not great but
+ // should prevent Silly Window Syndrome. Spec refers to using
+ // highest seen offset somehow? RTT maybe?
+ let maybe_new_max = recv_buf.retired() + *max_bytes;
+ if maybe_new_max > (*max_bytes / 2) + *max_stream_data {
+ *max_stream_data = maybe_new_max;
+ self.flow_mgr
+ .borrow_mut()
+ .max_stream_data(self.stream_id, maybe_new_max)
+ }
+ }
+ }
+
+ pub fn max_stream_data(&self) -> Option<u64> {
+ self.state.max_stream_data()
+ }
+
+ pub fn is_terminal(&self) -> bool {
+ matches!(
+ self.state,
+ RecvStreamState::ResetRecvd | RecvStreamState::DataRead
+ )
+ }
+
+ // App got all data but did not get the fin signal.
+ fn needs_to_inform_app_about_fin(&self) -> bool {
+ matches!(self.state, RecvStreamState::DataRecvd { .. })
+ }
+
+ fn data_ready(&self) -> bool {
+ self.state
+ .recv_buf()
+ .map_or(false, RxStreamOrderer::data_ready)
+ }
+
+ /// # Errors
+ /// `NoMoreData` if data and fin bit were previously read by the application.
+ pub fn read(&mut self, buf: &mut [u8]) -> Res<(usize, bool)> {
+ let res = match &mut self.state {
+ RecvStreamState::Recv { recv_buf, .. }
+ | RecvStreamState::SizeKnown { recv_buf, .. } => Ok((recv_buf.read(buf), false)),
+ RecvStreamState::DataRecvd { recv_buf } => {
+ let bytes_read = recv_buf.read(buf);
+ let fin_read = recv_buf.buffered() == 0;
+ if fin_read {
+ self.set_state(RecvStreamState::DataRead);
+ }
+ Ok((bytes_read, fin_read))
+ }
+ RecvStreamState::DataRead | RecvStreamState::ResetRecvd => Err(Error::NoMoreData),
+ };
+ self.maybe_send_flowc_update();
+ res
+ }
+
+ pub fn stop_sending(&mut self, err: AppError) {
+ qtrace!("stop_sending called when in state {}", self.state.name());
+ match &self.state {
+ RecvStreamState::Recv { .. } | RecvStreamState::SizeKnown { .. } => {
+ self.set_state(RecvStreamState::ResetRecvd);
+ self.flow_mgr.borrow_mut().stop_sending(self.stream_id, err)
+ }
+ RecvStreamState::DataRecvd { .. } => self.set_state(RecvStreamState::DataRead),
+ RecvStreamState::DataRead | RecvStreamState::ResetRecvd => {
+ // Already in terminal state
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::frame::Frame;
+ use std::ops::Range;
+
+ fn recv_ranges(ranges: &[Range<u64>], available: usize) {
+ const ZEROES: &[u8] = &[0; 100];
+ qtrace!("recv_ranges {:?}", ranges);
+
+ let mut s = RxStreamOrderer::default();
+ for r in ranges {
+ let data = &ZEROES[..usize::try_from(r.end - r.start).unwrap()];
+ s.inbound_frame(r.start, data);
+ }
+
+ let mut buf = [0xff; 100];
+ let mut total_recvd = 0;
+ loop {
+ let recvd = s.read(&mut buf[..]);
+ qtrace!("recv_ranges read {}", recvd);
+ total_recvd += recvd;
+ if recvd == 0 {
+ assert_eq!(total_recvd, available);
+ break;
+ }
+ }
+ }
+
+ #[test]
+ fn recv_noncontiguous() {
+ // Non-contiguous with the start, no data available.
+ recv_ranges(&[10..20], 0);
+ }
+
+ /// Overlaps with the start of a 10..20 range of bytes.
+ #[test]
+ fn recv_overlap_start() {
+ // Overlap the start, with a larger new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 4..18, 0..4], 20);
+ // Overlap the start, with a larger new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 2..15, 0..2], 20);
+ // Overlap the start, with a smaller new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 8..14, 0..8], 20);
+ // Overlap the start, with a smaller new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 6..13, 0..6], 20);
+
+ // Again with some of the first range split in two.
+ recv_ranges(&[10..11, 11..20, 4..18, 0..4], 20);
+ recv_ranges(&[10..11, 11..20, 2..15, 0..2], 20);
+ recv_ranges(&[10..11, 11..20, 8..14, 0..8], 20);
+ recv_ranges(&[10..11, 11..20, 6..13, 0..6], 20);
+
+ // Again with a gap in the first range.
+ recv_ranges(&[10..11, 12..20, 4..18, 0..4], 20);
+ recv_ranges(&[10..11, 12..20, 2..15, 0..2], 20);
+ recv_ranges(&[10..11, 12..20, 8..14, 0..8], 20);
+ recv_ranges(&[10..11, 12..20, 6..13, 0..6], 20);
+ }
+
+ /// Overlaps with the end of a 10..20 range of bytes.
+ #[test]
+ fn recv_overlap_end() {
+ // Overlap the end, with a larger new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 12..25, 0..10], 25);
+ // Overlap the end, with a larger new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 17..33, 0..10], 33);
+ // Overlap the end, with a smaller new value.
+ // More overlap than not.
+ recv_ranges(&[10..20, 15..21, 0..10], 21);
+ // Overlap the end, with a smaller new value.
+ // Less overlap than not.
+ recv_ranges(&[10..20, 17..25, 0..10], 25);
+
+ // Again with some of the first range split in two.
+ recv_ranges(&[10..19, 19..20, 12..25, 0..10], 25);
+ recv_ranges(&[10..19, 19..20, 17..33, 0..10], 33);
+ recv_ranges(&[10..19, 19..20, 15..21, 0..10], 21);
+ recv_ranges(&[10..19, 19..20, 17..25, 0..10], 25);
+
+ // Again with a gap in the first range.
+ recv_ranges(&[10..18, 19..20, 12..25, 0..10], 25);
+ recv_ranges(&[10..18, 19..20, 17..33, 0..10], 33);
+ recv_ranges(&[10..18, 19..20, 15..21, 0..10], 21);
+ recv_ranges(&[10..18, 19..20, 17..25, 0..10], 25);
+ }
+
+ /// Complete overlaps with the start of a 10..20 range of bytes.
+ #[test]
+ fn recv_overlap_complete() {
+ // Complete overlap, more at the end.
+ recv_ranges(&[10..20, 9..23, 0..9], 23);
+ // Complete overlap, more at the start.
+ recv_ranges(&[10..20, 3..23, 0..3], 23);
+ // Complete overlap, to end.
+ recv_ranges(&[10..20, 5..20, 0..5], 20);
+ // Complete overlap, from start.
+ recv_ranges(&[10..20, 10..27, 0..10], 27);
+ // Complete overlap, from 0 and more.
+ recv_ranges(&[10..20, 0..23], 23);
+
+ // Again with the first range split in two.
+ recv_ranges(&[10..14, 14..20, 9..23, 0..9], 23);
+ recv_ranges(&[10..14, 14..20, 3..23, 0..3], 23);
+ recv_ranges(&[10..14, 14..20, 5..20, 0..5], 20);
+ recv_ranges(&[10..14, 14..20, 10..27, 0..10], 27);
+ recv_ranges(&[10..14, 14..20, 0..23], 23);
+
+ // Again with the a gap in the first range.
+ recv_ranges(&[10..13, 14..20, 9..23, 0..9], 23);
+ recv_ranges(&[10..13, 14..20, 3..23, 0..3], 23);
+ recv_ranges(&[10..13, 14..20, 5..20, 0..5], 20);
+ recv_ranges(&[10..13, 14..20, 10..27, 0..10], 27);
+ recv_ranges(&[10..13, 14..20, 0..23], 23);
+ }
+
+ /// An overlap with no new bytes.
+ #[test]
+ fn recv_overlap_duplicate() {
+ recv_ranges(&[10..20, 11..12, 0..10], 20);
+ recv_ranges(&[10..20, 10..15, 0..10], 20);
+ recv_ranges(&[10..20, 14..20, 0..10], 20);
+ // Now with the first range split.
+ recv_ranges(&[10..14, 14..20, 10..15, 0..10], 20);
+ recv_ranges(&[10..15, 16..20, 21..25, 10..25, 0..10], 25);
+ }
+
+ /// Reading exactly one chunk works, when the next chunk starts immediately.
+ #[test]
+ fn stop_reading_at_chunk() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add three chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ // Read, providing only enough space for the first.
+ let mut buf = vec![0; 100];
+ let count = s.read(&mut buf[..CHUNK_SIZE]);
+ assert_eq!(count, CHUNK_SIZE);
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, EXTRA_SIZE * 2);
+ }
+
+ #[test]
+ fn recv_overlap_while_reading() {
+ let mut s = RxStreamOrderer::new();
+
+ // Add a chunk
+ s.inbound_frame(0, &[0; 150]);
+ assert_eq!(s.data_ranges.get(&0).unwrap().len(), 150);
+ // Read, providing only enough space for the first 100.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, 100);
+ assert_eq!(s.retired, 100);
+
+ // Add a second frame that overlaps.
+ // This shouldn't truncate the first frame, as we're already
+ // Reading from it.
+ s.inbound_frame(120, &[0; 60]);
+ assert_eq!(s.data_ranges.get(&0).unwrap().len(), 180);
+ // Read second part of first frame and all of the second frame
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, 80);
+ }
+
+ /// Reading exactly one chunk works, when there is a gap.
+ #[test]
+ fn stop_reading_at_gap() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add three chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE + EXTRA_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ // Read, providing only enough space for the first chunk.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..CHUNK_SIZE]);
+ assert_eq!(count, CHUNK_SIZE);
+
+ // Now fill the gap and ensure that everything can be read.
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, EXTRA_SIZE * 2);
+ }
+
+ /// Reading exactly one chunk works, when there is a gap.
+ #[test]
+ fn stop_reading_in_chunk() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add two chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ // Read, providing only enough space for some of the first chunk.
+ let mut buf = [0; 100];
+ let count = s.read(&mut buf[..CHUNK_SIZE - EXTRA_SIZE]);
+ assert_eq!(count, CHUNK_SIZE - EXTRA_SIZE);
+
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, EXTRA_SIZE * 2);
+ }
+
+ /// Read one byte at a time.
+ #[test]
+ fn read_byte_at_a_time() {
+ const CHUNK_SIZE: usize = 10;
+ const EXTRA_SIZE: usize = 3;
+ let mut s = RxStreamOrderer::new();
+
+ // Add two chunks.
+ s.inbound_frame(0, &[0; CHUNK_SIZE]);
+ let offset = u64::try_from(CHUNK_SIZE).unwrap();
+ s.inbound_frame(offset, &[0; EXTRA_SIZE]);
+
+ let mut buf = [0; 1];
+ for _ in 0..CHUNK_SIZE + EXTRA_SIZE {
+ let count = s.read(&mut buf[..]);
+ assert_eq!(count, 1);
+ }
+ assert_eq!(0, s.read(&mut buf[..]));
+ }
+
+ #[test]
+ fn stream_rx() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = RecvStream::new(StreamId::from(567), 1024, Rc::clone(&flow_mgr), conn_events);
+
+ // test receiving a contig frame and reading it works
+ s.inbound_stream_frame(false, 0, &[1; 10]).unwrap();
+ assert_eq!(s.data_ready(), true);
+ let mut buf = vec![0u8; 100];
+ assert_eq!(s.read(&mut buf).unwrap(), (10, false));
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 0);
+
+ // test receiving a noncontig frame
+ s.inbound_stream_frame(false, 12, &[2; 12]).unwrap();
+ assert_eq!(s.data_ready(), false);
+ assert_eq!(s.read(&mut buf).unwrap(), (0, false));
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12);
+
+ // another frame that overlaps the first
+ s.inbound_stream_frame(false, 14, &[3; 8]).unwrap();
+ assert_eq!(s.data_ready(), false);
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12);
+
+ // fill in the gap, but with a FIN
+ s.inbound_stream_frame(true, 10, &[4; 6]).unwrap_err();
+ assert_eq!(s.data_ready(), false);
+ assert_eq!(s.read(&mut buf).unwrap(), (0, false));
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 12);
+
+ // fill in the gap
+ s.inbound_stream_frame(false, 10, &[5; 10]).unwrap();
+ assert_eq!(s.data_ready(), true);
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 14);
+
+ // a legit FIN
+ s.inbound_stream_frame(true, 24, &[6; 18]).unwrap();
+ assert_eq!(s.state.recv_buf().unwrap().retired(), 10);
+ assert_eq!(s.state.recv_buf().unwrap().buffered(), 32);
+ assert_eq!(s.data_ready(), true);
+ assert_eq!(s.read(&mut buf).unwrap(), (32, true));
+
+ // Stream now no longer readable (is in DataRead state)
+ s.read(&mut buf).unwrap_err();
+ }
+
+ fn check_chunks(s: &mut RxStreamOrderer, expected: &[(u64, usize)]) {
+ assert_eq!(s.data_ranges.len(), expected.len());
+ for ((start, buf), (expected_start, expected_len)) in s.data_ranges.iter().zip(expected) {
+ assert_eq!((*start, buf.len()), (*expected_start, *expected_len));
+ }
+ }
+
+ // Test deduplication when the new data is at the end.
+ #[test]
+ fn stream_rx_dedupe_tail() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(0, &[1; 6]);
+ check_chunks(&mut s, &[(0, 6)]);
+
+ // New data that overlaps entirely (starting from the head), is ignored.
+ s.inbound_frame(0, &[2; 3]);
+ check_chunks(&mut s, &[(0, 6)]);
+
+ // New data that overlaps at the tail has any new data appended.
+ s.inbound_frame(2, &[3; 6]);
+ check_chunks(&mut s, &[(0, 8)]);
+
+ // New data that overlaps entirely (up to the tail), is ignored.
+ s.inbound_frame(4, &[4; 4]);
+ check_chunks(&mut s, &[(0, 8)]);
+
+ // New data that overlaps, starting from the beginning is appended too.
+ s.inbound_frame(0, &[5; 10]);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ // New data that is entirely subsumed is ignored.
+ s.inbound_frame(2, &[6; 2]);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 10);
+ assert_eq!(buf[..10], [1, 1, 1, 1, 1, 1, 3, 3, 5, 5]);
+ }
+
+ /// When chunks are added before existing data, they aren't merged.
+ #[test]
+ fn stream_rx_dedupe_head() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(1, &[6; 6]);
+ check_chunks(&mut s, &[(1, 6)]);
+
+ // Insertion before an existing chunk causes truncation of the new chunk.
+ s.inbound_frame(0, &[7; 6]);
+ check_chunks(&mut s, &[(0, 1), (1, 6)]);
+
+ // Perfect overlap with existing slices has no effect.
+ s.inbound_frame(0, &[8; 7]);
+ check_chunks(&mut s, &[(0, 1), (1, 6)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 7);
+ assert_eq!(buf[..7], [7, 6, 6, 6, 6, 6, 6]);
+ }
+
+ #[test]
+ fn stream_rx_dedupe_new_tail() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(1, &[6; 6]);
+ check_chunks(&mut s, &[(1, 6)]);
+
+ // Insertion before an existing chunk causes truncation of the new chunk.
+ s.inbound_frame(0, &[7; 6]);
+ check_chunks(&mut s, &[(0, 1), (1, 6)]);
+
+ // New data at the end causes the tail to be added to the first chunk,
+ // replacing later chunks entirely.
+ s.inbound_frame(0, &[9; 8]);
+ check_chunks(&mut s, &[(0, 8)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 8);
+ assert_eq!(buf[..8], [7, 9, 9, 9, 9, 9, 9, 9]);
+ }
+
+ #[test]
+ fn stream_rx_dedupe_replace() {
+ let mut s = RxStreamOrderer::new();
+
+ s.inbound_frame(2, &[6; 6]);
+ check_chunks(&mut s, &[(2, 6)]);
+
+ // Insertion before an existing chunk causes truncation of the new chunk.
+ s.inbound_frame(1, &[7; 6]);
+ check_chunks(&mut s, &[(1, 1), (2, 6)]);
+
+ // New data at the start and end replaces all the slices.
+ s.inbound_frame(0, &[9; 10]);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ let mut buf = [0; 16];
+ assert_eq!(s.read(&mut buf[..]), 10);
+ assert_eq!(buf[..10], [9; 10]);
+ }
+
+ #[test]
+ fn trim_retired() {
+ let mut s = RxStreamOrderer::new();
+
+ let mut buf = [0; 18];
+ s.inbound_frame(0, &[1; 10]);
+
+ // Partially read slices are retained.
+ assert_eq!(s.read(&mut buf[..6]), 6);
+ check_chunks(&mut s, &[(0, 10)]);
+
+ // Partially read slices are kept and so are added to.
+ s.inbound_frame(3, &buf[..10]);
+ check_chunks(&mut s, &[(0, 13)]);
+
+ // Wholly read pieces are dropped.
+ assert_eq!(s.read(&mut buf[..]), 7);
+ assert!(s.data_ranges.is_empty());
+
+ // New data that overlaps with retired data is trimmed.
+ s.inbound_frame(0, &buf[..]);
+ check_chunks(&mut s, &[(13, 5)]);
+ }
+
+ #[test]
+ fn stream_flowc_update() {
+ let flow_mgr = Rc::default();
+ let conn_events = ConnectionEvents::default();
+
+ let frame1 = vec![0; RECV_BUFFER_SIZE];
+
+ let mut s = RecvStream::new(
+ StreamId::from(4),
+ RX_STREAM_DATA_WINDOW,
+ Rc::clone(&flow_mgr),
+ conn_events,
+ );
+
+ let mut buf = vec![0u8; RECV_BUFFER_SIZE + 100]; // Make it overlarge
+
+ s.maybe_send_flowc_update();
+ assert_eq!(s.flow_mgr.borrow().peek(), None);
+ s.inbound_stream_frame(false, 0, &frame1).unwrap();
+ s.maybe_send_flowc_update();
+ assert_eq!(s.flow_mgr.borrow().peek(), None);
+ assert_eq!(s.read(&mut buf).unwrap(), (RECV_BUFFER_SIZE, false));
+ assert_eq!(s.data_ready(), false);
+ s.maybe_send_flowc_update();
+
+ // flow msg generated!
+ assert!(s.flow_mgr.borrow().peek().is_some());
+
+ // consume it
+ s.flow_mgr.borrow_mut().next().unwrap();
+
+ // it should be gone
+ s.maybe_send_flowc_update();
+ assert_eq!(s.flow_mgr.borrow().peek(), None);
+ }
+
+ #[test]
+ fn stream_max_stream_data() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ let conn_events = ConnectionEvents::default();
+
+ let frame1 = vec![0; RECV_BUFFER_SIZE];
+ let mut s = RecvStream::new(
+ StreamId::from(67),
+ RX_STREAM_DATA_WINDOW,
+ Rc::clone(&flow_mgr),
+ conn_events,
+ );
+
+ s.maybe_send_flowc_update();
+ assert_eq!(s.flow_mgr.borrow().peek(), None);
+ s.inbound_stream_frame(false, 0, &frame1).unwrap();
+ s.inbound_stream_frame(false, RX_STREAM_DATA_WINDOW, &[1; 1])
+ .unwrap_err();
+ }
+
+ #[test]
+ fn stream_orderer_bytes_ready() {
+ let mut rx_ord = RxStreamOrderer::new();
+
+ rx_ord.inbound_frame(0, &[1; 6]);
+ assert_eq!(rx_ord.bytes_ready(), 6);
+ assert_eq!(rx_ord.buffered(), 6);
+ assert_eq!(rx_ord.retired(), 0);
+
+ // read some so there's an offset into the first frame
+ let mut buf = [0u8; 10];
+ rx_ord.read(&mut buf[..2]);
+ assert_eq!(rx_ord.bytes_ready(), 4);
+ assert_eq!(rx_ord.buffered(), 4);
+ assert_eq!(rx_ord.retired(), 2);
+
+ // an overlapping frame
+ rx_ord.inbound_frame(5, &[2; 6]);
+ assert_eq!(rx_ord.bytes_ready(), 9);
+ assert_eq!(rx_ord.buffered(), 9);
+ assert_eq!(rx_ord.retired(), 2);
+
+ // a noncontig frame
+ rx_ord.inbound_frame(20, &[3; 6]);
+ assert_eq!(rx_ord.bytes_ready(), 9);
+ assert_eq!(rx_ord.buffered(), 15);
+ assert_eq!(rx_ord.retired(), 2);
+
+ // an old frame
+ rx_ord.inbound_frame(0, &[4; 2]);
+ assert_eq!(rx_ord.bytes_ready(), 9);
+ assert_eq!(rx_ord.buffered(), 15);
+ assert_eq!(rx_ord.retired(), 2);
+ }
+
+ #[test]
+ fn no_stream_flowc_event_after_exiting_recv() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ let conn_events = ConnectionEvents::default();
+
+ let frame1 = vec![0; RECV_BUFFER_SIZE];
+ let stream_id = StreamId::from(67);
+ let mut s = RecvStream::new(
+ stream_id,
+ RX_STREAM_DATA_WINDOW,
+ Rc::clone(&flow_mgr),
+ conn_events,
+ );
+
+ s.inbound_stream_frame(false, 0, &frame1).unwrap();
+ flow_mgr.borrow_mut().max_stream_data(stream_id, 100);
+ assert!(matches!(s.flow_mgr.borrow().peek().unwrap(), Frame::MaxStreamData{..}));
+ s.inbound_stream_frame(true, RX_STREAM_DATA_WINDOW, &[])
+ .unwrap();
+ assert!(matches!(s.flow_mgr.borrow().peek(), None));
+ }
+
+ #[test]
+ fn resend_flowc_if_lost() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ let conn_events = ConnectionEvents::default();
+
+ let frame1 = &[0; RECV_BUFFER_SIZE];
+ let stream_id = StreamId::from(67);
+ let mut s = RecvStream::new(
+ stream_id,
+ RX_STREAM_DATA_WINDOW,
+ Rc::clone(&flow_mgr),
+ conn_events,
+ );
+
+ // A flow control update is queued
+ s.inbound_stream_frame(false, 0, frame1).unwrap();
+ flow_mgr.borrow_mut().max_stream_data(stream_id, 100);
+ // Generates frame
+ assert!(matches!(
+ s.flow_mgr.borrow_mut().next().unwrap(),
+ Frame::MaxStreamData { .. }
+ ));
+ // Nothing else queued
+ assert!(matches!(s.flow_mgr.borrow().peek(), None));
+ // Asking for another one won't get you one
+ s.maybe_send_flowc_update();
+ assert!(matches!(s.flow_mgr.borrow().peek(), None));
+ // But if lost, another frame is generated
+ flow_mgr.borrow_mut().max_stream_data(stream_id, 100);
+ assert!(matches!(s.flow_mgr.borrow_mut().next().unwrap(), Frame::MaxStreamData{..}));
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs
new file mode 100644
index 0000000000..b6b9eea5f5
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/send_stream.rs
@@ -0,0 +1,1746 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Buffering data to send until it is acked.
+
+use std::cell::RefCell;
+use std::cmp::{max, min};
+use std::collections::{BTreeMap, VecDeque};
+use std::convert::{TryFrom, TryInto};
+use std::mem;
+use std::rc::Rc;
+
+use indexmap::IndexMap;
+use smallvec::SmallVec;
+
+use neqo_common::{qdebug, qerror, qinfo, qtrace, Encoder};
+
+use crate::events::ConnectionEvents;
+use crate::flow_mgr::FlowMgr;
+use crate::frame::Frame;
+use crate::packet::PacketBuilder;
+use crate::recovery::RecoveryToken;
+use crate::stats::FrameStats;
+use crate::stream_id::StreamId;
+use crate::{AppError, Error, Res};
+
+pub const SEND_BUFFER_SIZE: usize = 0x10_0000; // 1 MiB
+
+#[derive(Debug, PartialEq, Clone, Copy)]
+enum RangeState {
+ Sent,
+ Acked,
+}
+
+/// Track ranges in the stream as sent or acked. Acked implies sent. Not in a
+/// range implies needing-to-be-sent, either initially or as a retransmission.
+#[derive(Debug, Default, PartialEq)]
+struct RangeTracker {
+ // offset, (len, RangeState). Use u64 for len because ranges can exceed 32bits.
+ used: BTreeMap<u64, (u64, RangeState)>,
+}
+
+impl RangeTracker {
+ fn highest_offset(&self) -> u64 {
+ self.used
+ .range(..)
+ .next_back()
+ .map_or(0, |(k, (v, _))| *k + *v)
+ }
+
+ fn acked_from_zero(&self) -> u64 {
+ self.used
+ .get(&0)
+ .filter(|(_, state)| *state == RangeState::Acked)
+ .map_or(0, |(v, _)| *v)
+ }
+
+ /// Find the first unmarked range. If all are contiguous, this will return
+ /// (highest_offset(), None).
+ fn first_unmarked_range(&self) -> (u64, Option<u64>) {
+ let mut prev_end = 0;
+
+ for (cur_off, (cur_len, _)) in &self.used {
+ if prev_end == *cur_off {
+ prev_end = cur_off + cur_len;
+ } else {
+ return (prev_end, Some(cur_off - prev_end));
+ }
+ }
+ (prev_end, None)
+ }
+
+ /// Turn one range into a list of subranges that align with existing
+ /// ranges.
+ /// Check impermissible overlaps in subregions: Sent cannot overwrite Acked.
+ //
+ // e.g. given N is new and ABC are existing:
+ // NNNNNNNNNNNNNNNN
+ // AAAAA BBBCCCCC ...then we want 5 chunks:
+ // 1122222333444555
+ //
+ // but also if we have this:
+ // NNNNNNNNNNNNNNNN
+ // AAAAAAAAAA BBBB ...then break existing A and B ranges up:
+ //
+ // 1111111122222233
+ // aaAAAAAAAA BBbb
+ //
+ // Doing all this work up front should make handling each chunk much
+ // easier.
+ fn chunk_range_on_edges(
+ &mut self,
+ new_off: u64,
+ new_len: u64,
+ new_state: RangeState,
+ ) -> Vec<(u64, u64, RangeState)> {
+ let mut tmp_off = new_off;
+ let mut tmp_len = new_len;
+ let mut v = Vec::new();
+
+ // cut previous overlapping range if needed
+ let prev = self.used.range_mut(..tmp_off).next_back();
+ if let Some((prev_off, (prev_len, prev_state))) = prev {
+ let prev_state = *prev_state;
+ let overlap = (*prev_off + *prev_len).saturating_sub(new_off);
+ *prev_len -= overlap;
+ if overlap > 0 {
+ self.used.insert(new_off, (overlap, prev_state));
+ }
+ }
+
+ let mut last_existing_remaining = None;
+ for (off, (len, state)) in self.used.range(tmp_off..tmp_off + tmp_len) {
+ // Create chunk for "overhang" before an existing range
+ if tmp_off < *off {
+ let sub_len = off - tmp_off;
+ v.push((tmp_off, sub_len, new_state));
+ tmp_off += sub_len;
+ tmp_len -= sub_len;
+ }
+
+ // Create chunk to match existing range
+ let sub_len = min(*len, tmp_len);
+ let remaining_len = len - sub_len;
+ if new_state == RangeState::Sent && *state == RangeState::Acked {
+ qinfo!(
+ "Attempted to downgrade overlapping range Acked range {}-{} with Sent {}-{}",
+ off,
+ len,
+ new_off,
+ new_len
+ );
+ } else {
+ v.push((tmp_off, sub_len, new_state));
+ }
+ tmp_off += sub_len;
+ tmp_len -= sub_len;
+
+ if remaining_len > 0 {
+ last_existing_remaining = Some((*off, sub_len, remaining_len, *state));
+ }
+ }
+
+ // Maybe break last existing range in two so that a final chunk will
+ // have the same length as an existing range entry
+ if let Some((off, sub_len, remaining_len, state)) = last_existing_remaining {
+ *self.used.get_mut(&off).expect("must be there") = (sub_len, state);
+ self.used.insert(off + sub_len, (remaining_len, state));
+ }
+
+ // Create final chunk if anything remains of the new range
+ if tmp_len > 0 {
+ v.push((tmp_off, tmp_len, new_state))
+ }
+
+ v
+ }
+
+ /// Merge contiguous Acked ranges into the first entry (0). This range may
+ /// be dropped from the send buffer.
+ fn coalesce_acked_from_zero(&mut self) {
+ let acked_range_from_zero = self
+ .used
+ .get_mut(&0)
+ .filter(|(_, state)| *state == RangeState::Acked)
+ .map(|(len, _)| *len);
+
+ if let Some(len_from_zero) = acked_range_from_zero {
+ let mut to_remove = SmallVec::<[_; 8]>::new();
+
+ let mut new_len_from_zero = len_from_zero;
+
+ // See if there's another Acked range entry contiguous to this one
+ while let Some((next_len, _)) = self
+ .used
+ .get(&new_len_from_zero)
+ .filter(|(_, state)| *state == RangeState::Acked)
+ {
+ to_remove.push(new_len_from_zero);
+ new_len_from_zero += *next_len;
+ }
+
+ if len_from_zero != new_len_from_zero {
+ self.used.get_mut(&0).expect("must be there").0 = new_len_from_zero;
+ }
+
+ for val in to_remove {
+ self.used.remove(&val);
+ }
+ }
+ }
+
+ fn mark_range(&mut self, off: u64, len: usize, state: RangeState) {
+ if len == 0 {
+ qinfo!("mark 0-length range at {}", off);
+ return;
+ }
+
+ let subranges = self.chunk_range_on_edges(off, len as u64, state);
+
+ for (sub_off, sub_len, sub_state) in subranges {
+ self.used.insert(sub_off, (sub_len, sub_state));
+ }
+
+ self.coalesce_acked_from_zero()
+ }
+
+ fn unmark_range(&mut self, off: u64, len: usize) {
+ if len == 0 {
+ qdebug!("unmark 0-length range at {}", off);
+ return;
+ }
+
+ let len = u64::try_from(len).unwrap();
+ let end_off = off + len;
+
+ let mut to_remove = SmallVec::<[_; 8]>::new();
+ let mut to_add = None;
+
+ // Walk backwards through possibly affected existing ranges
+ for (cur_off, (cur_len, cur_state)) in self.used.range_mut(..off + len).rev() {
+ // Maybe fixup range preceding the removed range
+ if *cur_off < off {
+ // Check for overlap
+ if *cur_off + *cur_len > off {
+ if *cur_state == RangeState::Acked {
+ qdebug!(
+ "Attempted to unmark Acked range {}-{} with unmark_range {}-{}",
+ cur_off,
+ cur_len,
+ off,
+ off + len
+ );
+ } else {
+ *cur_len = off - cur_off;
+ }
+ }
+ break;
+ }
+
+ if *cur_state == RangeState::Acked {
+ qdebug!(
+ "Attempted to unmark Acked range {}-{} with unmark_range {}-{}",
+ cur_off,
+ cur_len,
+ off,
+ off + len
+ );
+ continue;
+ }
+
+ // Add a new range for old subrange extending beyond
+ // to-be-unmarked range
+ let cur_end_off = cur_off + *cur_len;
+ if cur_end_off > end_off {
+ let new_cur_off = off + len;
+ let new_cur_len = cur_end_off - end_off;
+ assert_eq!(to_add, None);
+ to_add = Some((new_cur_off, new_cur_len, *cur_state));
+ }
+
+ to_remove.push(*cur_off);
+ }
+
+ for remove_off in to_remove {
+ self.used.remove(&remove_off);
+ }
+
+ if let Some((new_cur_off, new_cur_len, cur_state)) = to_add {
+ self.used.insert(new_cur_off, (new_cur_len, cur_state));
+ }
+ }
+
+ /// Unmark all sent ranges.
+ pub fn unmark_sent(&mut self) {
+ self.unmark_range(0, usize::try_from(self.highest_offset()).unwrap());
+ }
+}
+
+/// Buffer to contain queued bytes and track their state.
+#[derive(Debug, Default, PartialEq)]
+pub struct TxBuffer {
+ retired: u64, // contig acked bytes, no longer in buffer
+ send_buf: VecDeque<u8>, // buffer of not-acked bytes
+ ranges: RangeTracker, // ranges in buffer that have been sent or acked
+}
+
+impl TxBuffer {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Attempt to add some or all of the passed-in buffer to the TxBuffer.
+ pub fn send(&mut self, buf: &[u8]) -> usize {
+ let can_buffer = min(SEND_BUFFER_SIZE - self.buffered(), buf.len());
+ if can_buffer > 0 {
+ self.send_buf.extend(&buf[..can_buffer]);
+ assert!(self.send_buf.len() <= SEND_BUFFER_SIZE);
+ }
+ can_buffer
+ }
+
+ pub fn next_bytes(&self) -> Option<(u64, &[u8])> {
+ let (start, maybe_len) = self.ranges.first_unmarked_range();
+
+ if start == self.retired + u64::try_from(self.buffered()).unwrap() {
+ return None;
+ }
+
+ // Convert from ranges-relative-to-zero to
+ // ranges-relative-to-buffer-start
+ let buff_off = usize::try_from(start - self.retired).unwrap();
+
+ // Deque returns two slices. Create a subslice from whichever
+ // one contains the first unmarked data.
+ let slc = if buff_off < self.send_buf.as_slices().0.len() {
+ &self.send_buf.as_slices().0[buff_off..]
+ } else {
+ &self.send_buf.as_slices().1[buff_off - self.send_buf.as_slices().0.len()..]
+ };
+
+ let len = if let Some(range_len) = maybe_len {
+ // Truncate if range crosses deque slices
+ min(usize::try_from(range_len).unwrap(), slc.len())
+ } else {
+ slc.len()
+ };
+
+ debug_assert!(len > 0);
+ debug_assert!(len <= slc.len());
+
+ Some((start, &slc[..len]))
+ }
+
+ pub fn mark_as_sent(&mut self, offset: u64, len: usize) {
+ self.ranges.mark_range(offset, len, RangeState::Sent)
+ }
+
+ pub fn mark_as_acked(&mut self, offset: u64, len: usize) {
+ self.ranges.mark_range(offset, len, RangeState::Acked);
+
+ // We can drop contig acked range from the buffer
+ let new_retirable = self.ranges.acked_from_zero() - self.retired;
+ debug_assert!(new_retirable <= self.buffered() as u64);
+ let keep_len =
+ self.buffered() - usize::try_from(new_retirable).expect("should fit in usize");
+
+ // Truncate front
+ self.send_buf.rotate_left(self.buffered() - keep_len);
+ self.send_buf.truncate(keep_len);
+
+ self.retired += new_retirable;
+ }
+
+ pub fn mark_as_lost(&mut self, offset: u64, len: usize) {
+ self.ranges.unmark_range(offset, len)
+ }
+
+ /// Forget about anything that was marked as sent.
+ pub fn unmark_sent(&mut self) {
+ self.ranges.unmark_sent();
+ }
+
+ fn data_limit(&self) -> u64 {
+ self.buffered() as u64 + self.retired
+ }
+
+ fn buffered(&self) -> usize {
+ self.send_buf.len()
+ }
+
+ fn avail(&self) -> usize {
+ SEND_BUFFER_SIZE - self.buffered()
+ }
+
+ pub fn highest_sent(&self) -> u64 {
+ self.ranges.highest_offset()
+ }
+}
+
+/// QUIC sending stream states, based on -transport 3.1.
+#[derive(Debug, PartialEq)]
+enum SendStreamState {
+ Ready,
+ Send {
+ send_buf: TxBuffer,
+ },
+ DataSent {
+ send_buf: TxBuffer,
+ final_size: u64,
+ fin_sent: bool,
+ fin_acked: bool,
+ },
+ DataRecvd {
+ final_size: u64,
+ },
+ ResetSent,
+ ResetRecvd,
+}
+
+impl SendStreamState {
+ fn tx_buf(&self) -> Option<&TxBuffer> {
+ match self {
+ Self::Send { send_buf } | Self::DataSent { send_buf, .. } => Some(send_buf),
+ Self::Ready | Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => None,
+ }
+ }
+
+ fn tx_buf_mut(&mut self) -> Option<&mut TxBuffer> {
+ match self {
+ Self::Send { send_buf } | Self::DataSent { send_buf, .. } => Some(send_buf),
+ Self::Ready | Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => None,
+ }
+ }
+
+ fn tx_avail(&self) -> u64 {
+ match self {
+ // In Ready, TxBuffer not yet allocated but size is known
+ Self::Ready => SEND_BUFFER_SIZE.try_into().unwrap(),
+ Self::Send { send_buf } | Self::DataSent { send_buf, .. } => {
+ send_buf.avail().try_into().unwrap()
+ }
+ Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => 0,
+ }
+ }
+
+ fn final_size(&self) -> Option<u64> {
+ match self {
+ Self::DataSent { final_size, .. } | Self::DataRecvd { final_size } => Some(*final_size),
+ Self::Ready | Self::Send { .. } | Self::ResetSent | Self::ResetRecvd => None,
+ }
+ }
+
+ fn name(&self) -> &str {
+ match self {
+ Self::Ready => "Ready",
+ Self::Send { .. } => "Send",
+ Self::DataSent { .. } => "DataSent",
+ Self::DataRecvd { .. } => "DataRecvd",
+ Self::ResetSent => "ResetSent",
+ Self::ResetRecvd => "ResetRecvd",
+ }
+ }
+
+ fn transition(&mut self, new_state: Self) {
+ qtrace!("SendStream state {} -> {}", self.name(), new_state.name());
+ *self = new_state;
+ }
+}
+
+/// Implement a QUIC send stream.
+#[derive(Debug)]
+pub struct SendStream {
+ stream_id: StreamId,
+ max_stream_data: u64,
+ state: SendStreamState,
+ flow_mgr: Rc<RefCell<FlowMgr>>,
+ conn_events: ConnectionEvents,
+}
+
+impl SendStream {
+ pub fn new(
+ stream_id: StreamId,
+ max_stream_data: u64,
+ flow_mgr: Rc<RefCell<FlowMgr>>,
+ conn_events: ConnectionEvents,
+ ) -> Self {
+ let ss = Self {
+ stream_id,
+ max_stream_data,
+ state: SendStreamState::Ready,
+ flow_mgr,
+ conn_events,
+ };
+ if ss.avail() > 0 {
+ ss.conn_events.send_stream_writable(stream_id);
+ }
+ ss
+ }
+
+ /// Return the next range to be sent, if any.
+ pub fn next_bytes(&mut self) -> Option<(u64, &[u8])> {
+ match self.state {
+ SendStreamState::Send { ref send_buf } => send_buf.next_bytes(),
+ SendStreamState::DataSent {
+ ref send_buf,
+ fin_sent,
+ final_size,
+ ..
+ } => {
+ let bytes = send_buf.next_bytes();
+ if bytes.is_some() {
+ // Must be a resend
+ bytes
+ } else if fin_sent {
+ None
+ } else {
+ // Send empty stream frame with fin set
+ Some((final_size, &[]))
+ }
+ }
+ SendStreamState::Ready
+ | SendStreamState::DataRecvd { .. }
+ | SendStreamState::ResetSent
+ | SendStreamState::ResetRecvd => None,
+ }
+ }
+
+ /// Calculate how many bytes (length) can fit into available space and whether
+ /// the remainder of the space can be filled (or if a length field is needed).
+ fn length_and_fill(data_len: usize, space: usize) -> (usize, bool) {
+ if data_len >= space {
+ // Either more data than space allows, or an exact fit.
+ qtrace!("SendStream::length_and_fill fill {}", space);
+ return (space, true);
+ }
+
+ // Estimate size of the length field based on the available space,
+ // less 1, which is the worst case.
+ let length = min(space.saturating_sub(1), data_len);
+ let length_len = Encoder::varint_len(u64::try_from(length).unwrap());
+ if length_len > space {
+ qtrace!(
+ "SendStream::length_and_fill no room for length of {} in {}",
+ length,
+ space
+ );
+ return (0, false);
+ }
+
+ let length = min(data_len, space - length_len);
+ qtrace!("SendStream::length_and_fill {} in {}", length, space);
+ (length, false)
+ }
+
+ pub fn write_frame(&mut self, builder: &mut PacketBuilder) -> Option<RecoveryToken> {
+ let id = self.stream_id;
+ let final_size = self.final_size();
+ if let Some((offset, data)) = self.next_bytes() {
+ let overhead = 1 // Frame type
+ + Encoder::varint_len(id.as_u64())
+ + if offset > 0 {
+ Encoder::varint_len(offset)
+ } else {
+ 0
+ };
+ if overhead > builder.remaining() {
+ qtrace!("SendStream::write_frame no space for header");
+ return None;
+ }
+
+ let (length, fill) = Self::length_and_fill(data.len(), builder.remaining() - overhead);
+ let fin = final_size.map_or(false, |fs| fs == offset + u64::try_from(length).unwrap());
+ if length == 0 && !fin {
+ qtrace!("SendStream::write_frame no data, no fin");
+ return None;
+ }
+
+ // Write the stream out.
+ builder.encode_varint(Frame::stream_type(fin, offset > 0, fill));
+ builder.encode_varint(id.as_u64());
+ if offset > 0 {
+ builder.encode_varint(offset);
+ }
+ if fill {
+ builder.encode(&data[..length]);
+ } else {
+ builder.encode_vvec(&data[..length]);
+ }
+
+ self.mark_as_sent(offset, length, fin);
+ Some(RecoveryToken::Stream(StreamRecoveryToken {
+ id,
+ offset,
+ length,
+ fin,
+ }))
+ } else {
+ None
+ }
+ }
+
+ pub fn mark_as_sent(&mut self, offset: u64, len: usize, fin: bool) {
+ if let Some(buf) = self.state.tx_buf_mut() {
+ buf.mark_as_sent(offset, len);
+ self.send_blocked_if_space_needed(0);
+ };
+
+ if fin {
+ if let SendStreamState::DataSent { fin_sent, .. } = &mut self.state {
+ *fin_sent = true;
+ }
+ }
+ }
+
+ pub fn mark_as_acked(&mut self, offset: u64, len: usize, fin: bool) {
+ match self.state {
+ SendStreamState::Send { ref mut send_buf } => {
+ send_buf.mark_as_acked(offset, len);
+ if self.avail() > 0 {
+ self.conn_events.send_stream_writable(self.stream_id)
+ }
+ }
+ SendStreamState::DataSent {
+ ref mut send_buf,
+ final_size,
+ ref mut fin_acked,
+ ..
+ } => {
+ send_buf.mark_as_acked(offset, len);
+ if fin {
+ *fin_acked = true;
+ }
+ if *fin_acked && send_buf.buffered() == 0 {
+ self.conn_events.send_stream_complete(self.stream_id);
+ self.state
+ .transition(SendStreamState::DataRecvd { final_size });
+ }
+ }
+ _ => qtrace!("mark_as_acked called from state {}", self.state.name()),
+ }
+ }
+
+ pub fn mark_as_lost(&mut self, offset: u64, len: usize, fin: bool) {
+ if let Some(buf) = self.state.tx_buf_mut() {
+ buf.mark_as_lost(offset, len);
+ }
+
+ if fin {
+ if let SendStreamState::DataSent {
+ fin_sent,
+ fin_acked,
+ ..
+ } = &mut self.state
+ {
+ *fin_sent = *fin_acked;
+ }
+ }
+ }
+
+ pub fn final_size(&self) -> Option<u64> {
+ self.state.final_size()
+ }
+
+ /// Stream credit available
+ pub fn credit_avail(&self) -> u64 {
+ if self.state == SendStreamState::Ready {
+ self.max_stream_data
+ } else {
+ self.state
+ .tx_buf()
+ .map_or(0, |tx| self.max_stream_data - tx.data_limit())
+ }
+ }
+
+ /// Bytes sendable on stream. Constrained by stream credit available,
+ /// connection credit available, and space in the tx buffer.
+ pub fn avail(&self) -> usize {
+ min(
+ min(self.state.tx_avail(), self.credit_avail()),
+ self.flow_mgr.borrow().conn_credit_avail(),
+ )
+ .try_into()
+ .unwrap()
+ }
+
+ pub fn max_stream_data(&self) -> u64 {
+ self.max_stream_data
+ }
+
+ pub fn set_max_stream_data(&mut self, value: u64) {
+ let stream_was_blocked = self.avail() == 0;
+ self.max_stream_data = max(self.max_stream_data, value);
+ if stream_was_blocked && self.avail() > 0 {
+ self.conn_events.send_stream_writable(self.stream_id)
+ }
+ }
+
+ pub fn reset_acked(&mut self) {
+ match self.state {
+ SendStreamState::Ready
+ | SendStreamState::Send { .. }
+ | SendStreamState::DataSent { .. }
+ | SendStreamState::DataRecvd { .. } => {
+ qtrace!("Reset acked while in {} state?", self.state.name())
+ }
+ SendStreamState::ResetSent => self.state.transition(SendStreamState::ResetRecvd),
+ SendStreamState::ResetRecvd => qtrace!("already in ResetRecvd state"),
+ };
+ }
+
+ pub fn is_terminal(&self) -> bool {
+ matches!(self.state, SendStreamState::DataRecvd { .. } | SendStreamState::ResetRecvd)
+ }
+
+ pub fn send(&mut self, buf: &[u8]) -> Res<usize> {
+ self.send_internal(buf, false)
+ }
+
+ pub fn send_atomic(&mut self, buf: &[u8]) -> Res<usize> {
+ self.send_internal(buf, true)
+ }
+
+ fn send_blocked_if_space_needed(&mut self, needed_space: u64) {
+ if self.credit_avail() <= needed_space {
+ self.flow_mgr
+ .borrow_mut()
+ .stream_data_blocked(self.stream_id, self.max_stream_data);
+ }
+
+ if self.flow_mgr.borrow().conn_credit_avail() <= needed_space {
+ self.flow_mgr.borrow_mut().data_blocked();
+ }
+ }
+
+ fn send_internal(&mut self, buf: &[u8], atomic: bool) -> Res<usize> {
+ if buf.is_empty() {
+ qerror!("zero-length send on stream {}", self.stream_id.as_u64());
+ return Err(Error::InvalidInput);
+ }
+
+ if let SendStreamState::Ready = self.state {
+ self.state.transition(SendStreamState::Send {
+ send_buf: TxBuffer::new(),
+ });
+ }
+
+ if !matches!(self.state, SendStreamState::Send{..}) {
+ return Err(Error::FinalSizeError);
+ }
+
+ let buf = if buf.is_empty() || (self.avail() == 0) {
+ return Ok(0);
+ } else if self.avail() < buf.len() {
+ if atomic {
+ self.send_blocked_if_space_needed(buf.len() as u64);
+ return Ok(0);
+ } else {
+ &buf[..self.avail()]
+ }
+ } else {
+ buf
+ };
+
+ let sent = match &mut self.state {
+ SendStreamState::Ready => unreachable!(),
+ SendStreamState::Send { send_buf } => send_buf.send(buf),
+ _ => return Err(Error::FinalSizeError),
+ };
+
+ self.flow_mgr
+ .borrow_mut()
+ .conn_increase_credit_used(sent as u64);
+
+ Ok(sent)
+ }
+
+ pub fn close(&mut self) {
+ match &mut self.state {
+ SendStreamState::Ready => {
+ self.state.transition(SendStreamState::DataSent {
+ send_buf: TxBuffer::new(),
+ final_size: 0,
+ fin_sent: false,
+ fin_acked: false,
+ });
+ }
+ SendStreamState::Send { send_buf } => {
+ let final_size = send_buf.retired + send_buf.buffered() as u64;
+ let owned_buf = mem::replace(send_buf, TxBuffer::new());
+ self.state.transition(SendStreamState::DataSent {
+ send_buf: owned_buf,
+ final_size,
+ fin_sent: false,
+ fin_acked: false,
+ });
+ }
+ SendStreamState::DataSent { .. } => qtrace!("already in DataSent state"),
+ SendStreamState::DataRecvd { .. } => qtrace!("already in DataRecvd state"),
+ SendStreamState::ResetSent => qtrace!("already in ResetSent state"),
+ SendStreamState::ResetRecvd => qtrace!("already in ResetRecvd state"),
+ }
+ }
+
+ pub fn reset(&mut self, err: AppError) {
+ match &self.state {
+ SendStreamState::Ready => {
+ self.flow_mgr
+ .borrow_mut()
+ .stream_reset(self.stream_id, err, 0);
+
+ self.state.transition(SendStreamState::ResetSent);
+ }
+ SendStreamState::Send { send_buf } => {
+ self.flow_mgr.borrow_mut().stream_reset(
+ self.stream_id,
+ err,
+ send_buf.highest_sent(),
+ );
+
+ self.state.transition(SendStreamState::ResetSent);
+ }
+ SendStreamState::DataSent { final_size, .. } => {
+ self.flow_mgr
+ .borrow_mut()
+ .stream_reset(self.stream_id, err, *final_size);
+
+ self.state.transition(SendStreamState::ResetSent);
+ }
+ SendStreamState::DataRecvd { .. } => qtrace!("already in DataRecvd state"),
+ SendStreamState::ResetSent => qtrace!("already in ResetSent state"),
+ SendStreamState::ResetRecvd => qtrace!("already in ResetRecvd state"),
+ };
+ }
+}
+
+#[derive(Debug, Default)]
+pub(crate) struct SendStreams(IndexMap<StreamId, SendStream>);
+
+impl SendStreams {
+ pub fn get(&self, id: StreamId) -> Res<&SendStream> {
+ self.0.get(&id).ok_or(Error::InvalidStreamId)
+ }
+
+ pub fn get_mut(&mut self, id: StreamId) -> Res<&mut SendStream> {
+ self.0.get_mut(&id).ok_or(Error::InvalidStreamId)
+ }
+
+ pub fn exists(&self, id: StreamId) -> bool {
+ self.0.contains_key(&id)
+ }
+
+ pub fn insert(&mut self, id: StreamId, stream: SendStream) {
+ self.0.insert(id, stream);
+ }
+
+ pub fn acked(&mut self, token: &StreamRecoveryToken) {
+ if let Some(ss) = self.0.get_mut(&token.id) {
+ ss.mark_as_acked(token.offset, token.length, token.fin);
+ }
+ }
+
+ pub fn reset_acked(&mut self, id: StreamId) {
+ if let Some(ss) = self.0.get_mut(&id) {
+ ss.reset_acked()
+ }
+ }
+
+ pub fn lost(&mut self, token: &StreamRecoveryToken) {
+ if let Some(ss) = self.0.get_mut(&token.id) {
+ ss.mark_as_lost(token.offset, token.length, token.fin);
+ }
+ }
+
+ pub fn clear(&mut self) {
+ self.0.clear()
+ }
+
+ pub fn clear_terminal(&mut self) {
+ self.0.retain(|_, stream| !stream.is_terminal())
+ }
+
+ pub(crate) fn write_frames(
+ &mut self,
+ builder: &mut PacketBuilder,
+ tokens: &mut Vec<RecoveryToken>,
+ stats: &mut FrameStats,
+ ) {
+ for (_, stream) in self {
+ if let Some(t) = stream.write_frame(builder) {
+ tokens.push(t);
+ stats.stream += 1;
+ }
+ }
+ }
+}
+
+impl<'a> IntoIterator for &'a mut SendStreams {
+ type Item = (&'a StreamId, &'a mut SendStream);
+ type IntoIter = indexmap::map::IterMut<'a, StreamId, SendStream>;
+
+ fn into_iter(self) -> indexmap::map::IterMut<'a, StreamId, SendStream> {
+ self.0.iter_mut()
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamRecoveryToken {
+ pub(crate) id: StreamId,
+ offset: u64,
+ length: usize,
+ fin: bool,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::events::ConnectionEvent;
+ use neqo_common::{event::Provider, hex_with_len, qtrace};
+
+ #[test]
+ fn test_mark_range() {
+ let mut rt = RangeTracker::default();
+
+ // ranges can go from nothing->Sent if queued for retrans and then
+ // acks arrive
+ rt.mark_range(5, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 10);
+ assert_eq!(rt.acked_from_zero(), 0);
+ rt.mark_range(10, 4, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 14);
+ assert_eq!(rt.acked_from_zero(), 0);
+
+ rt.mark_range(0, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 14);
+ assert_eq!(rt.acked_from_zero(), 0);
+ rt.mark_range(0, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 14);
+ assert_eq!(rt.acked_from_zero(), 14);
+
+ rt.mark_range(12, 20, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 32);
+ assert_eq!(rt.acked_from_zero(), 32);
+
+ // ack the lot
+ rt.mark_range(0, 400, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 400);
+ assert_eq!(rt.acked_from_zero(), 400);
+
+ // acked trumps sent
+ rt.mark_range(0, 200, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 400);
+ assert_eq!(rt.acked_from_zero(), 400);
+ }
+
+ #[test]
+ fn unmark_sent_start() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(0, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 0);
+
+ rt.unmark_sent();
+ assert_eq!(rt.highest_offset(), 0);
+ assert_eq!(rt.acked_from_zero(), 0);
+ assert_eq!(rt.first_unmarked_range(), (0, None));
+ }
+
+ #[test]
+ fn unmark_sent_middle() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(0, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 5);
+ rt.mark_range(5, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 10);
+ assert_eq!(rt.acked_from_zero(), 5);
+ rt.mark_range(10, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 15);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (15, None));
+
+ rt.unmark_sent();
+ assert_eq!(rt.highest_offset(), 15);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (5, Some(5)));
+ }
+
+ #[test]
+ fn unmark_sent_end() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(0, 5, RangeState::Acked);
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 5);
+ rt.mark_range(5, 5, RangeState::Sent);
+ assert_eq!(rt.highest_offset(), 10);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (10, None));
+
+ rt.unmark_sent();
+ assert_eq!(rt.highest_offset(), 5);
+ assert_eq!(rt.acked_from_zero(), 5);
+ assert_eq!(rt.first_unmarked_range(), (5, None));
+ }
+
+ #[test]
+ fn truncate_front() {
+ let mut v = VecDeque::new();
+ v.push_back(5);
+ v.push_back(6);
+ v.push_back(7);
+ v.push_front(4usize);
+
+ v.rotate_left(1);
+ v.truncate(3);
+ assert_eq!(*v.front().unwrap(), 5);
+ assert_eq!(*v.back().unwrap(), 7);
+ }
+
+ #[test]
+ fn test_unmark_range() {
+ let mut rt = RangeTracker::default();
+
+ rt.mark_range(5, 5, RangeState::Acked);
+ rt.mark_range(10, 5, RangeState::Sent);
+
+ // Should unmark sent but not acked range
+ rt.unmark_range(7, 6);
+
+ let res = rt.first_unmarked_range();
+ assert_eq!(res, (0, Some(5)));
+ assert_eq!(
+ rt.used.iter().next().unwrap(),
+ (&5, &(5, RangeState::Acked))
+ );
+ assert_eq!(
+ rt.used.iter().nth(1).unwrap(),
+ (&13, &(2, RangeState::Sent))
+ );
+ assert!(rt.used.iter().nth(2).is_none());
+ rt.mark_range(0, 5, RangeState::Sent);
+
+ let res = rt.first_unmarked_range();
+ assert_eq!(res, (10, Some(3)));
+ rt.mark_range(10, 3, RangeState::Sent);
+
+ let res = rt.first_unmarked_range();
+ assert_eq!(res, (15, None));
+ }
+
+ #[test]
+ #[allow(clippy::cognitive_complexity)]
+ fn tx_buffer_next_bytes_1() {
+ let mut txb = TxBuffer::new();
+
+ assert_eq!(txb.avail(), SEND_BUFFER_SIZE);
+
+ // Fill the buffer
+ assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE);
+ assert!(matches!(txb.next_bytes(),
+ Some((0, x)) if x.len()==SEND_BUFFER_SIZE
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark almost all as sent. Get what's left
+ let one_byte_from_end = SEND_BUFFER_SIZE as u64 - 1;
+ txb.mark_as_sent(0, one_byte_from_end as usize);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 1
+ && start == one_byte_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark all as sent. Get nothing
+ txb.mark_as_sent(0, SEND_BUFFER_SIZE);
+ assert!(matches!(txb.next_bytes(), None));
+
+ // Mark as lost. Get it again
+ txb.mark_as_lost(one_byte_from_end, 1);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 1
+ && start == one_byte_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark a larger range lost, including beyond what's in the buffer even.
+ // Get a little more
+ let five_bytes_from_end = SEND_BUFFER_SIZE as u64 - 5;
+ txb.mark_as_lost(five_bytes_from_end, 100);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 5
+ && start == five_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Contig acked range at start means it can be removed from buffer
+ // Impl of vecdeque should now result in a split buffer when more data
+ // is sent
+ txb.mark_as_acked(0, five_bytes_from_end as usize);
+ assert_eq!(txb.send(&[2; 30]), 30);
+ // Just get 5 even though there is more
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 5
+ && start == five_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+ assert_eq!(txb.retired, five_bytes_from_end);
+ assert_eq!(txb.buffered(), 35);
+
+ // Marking that bit as sent should let the last contig bit be returned
+ // when called again
+ txb.mark_as_sent(five_bytes_from_end, 5);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 30
+ && start == SEND_BUFFER_SIZE as u64
+ && x.iter().all(|ch| *ch == 2)));
+ }
+
+ #[test]
+ fn tx_buffer_next_bytes_2() {
+ let mut txb = TxBuffer::new();
+
+ assert_eq!(txb.avail(), SEND_BUFFER_SIZE);
+
+ // Fill the buffer
+ assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE);
+ assert!(matches!(txb.next_bytes(),
+ Some((0, x)) if x.len()==SEND_BUFFER_SIZE
+ && x.iter().all(|ch| *ch == 1)));
+
+ // As above
+ let forty_bytes_from_end = SEND_BUFFER_SIZE as u64 - 40;
+
+ txb.mark_as_acked(0, forty_bytes_from_end as usize);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 40
+ && start == forty_bytes_from_end
+ ));
+
+ // Valid new data placed in split locations
+ assert_eq!(txb.send(&[2; 100]), 100);
+
+ // Mark a little more as sent
+ txb.mark_as_sent(forty_bytes_from_end, 10);
+ let thirty_bytes_from_end = forty_bytes_from_end + 10;
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 30
+ && start == thirty_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Mark a range 'A' in second slice as sent. Should still return the same
+ let range_a_start = SEND_BUFFER_SIZE as u64 + 30;
+ let range_a_end = range_a_start + 10;
+ txb.mark_as_sent(range_a_start, 10);
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 30
+ && start == thirty_bytes_from_end
+ && x.iter().all(|ch| *ch == 1)));
+
+ // Ack entire first slice and into second slice
+ let ten_bytes_past_end = SEND_BUFFER_SIZE as u64 + 10;
+ txb.mark_as_acked(0, ten_bytes_past_end as usize);
+
+ // Get up to marked range A
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 20
+ && start == ten_bytes_past_end
+ && x.iter().all(|ch| *ch == 2)));
+
+ txb.mark_as_sent(ten_bytes_past_end, 20);
+
+ // Get bit after earlier marked range A
+ assert!(matches!(txb.next_bytes(),
+ Some((start, x)) if x.len() == 60
+ && start == range_a_end
+ && x.iter().all(|ch| *ch == 2)));
+
+ // No more bytes.
+ txb.mark_as_sent(range_a_end, 60);
+ assert!(matches!(txb.next_bytes(), None));
+ }
+
+ #[test]
+ fn test_stream_tx() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(4096);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(4.into(), 1024, Rc::clone(&flow_mgr), conn_events);
+
+ let res = s.send(&[4; 100]).unwrap();
+ assert_eq!(res, 100);
+ s.mark_as_sent(0, 50, false);
+ assert_eq!(s.state.tx_buf().unwrap().data_limit(), 100);
+
+ // Should hit stream flow control limit before filling up send buffer
+ let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap();
+ assert_eq!(res, 1024 - 100);
+
+ // should do nothing, max stream data already 1024
+ s.set_max_stream_data(1024);
+ let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap();
+ assert_eq!(res, 0);
+
+ // should now hit the conn flow control (4096)
+ s.set_max_stream_data(1_048_576);
+ let res = s.send(&[4; SEND_BUFFER_SIZE]).unwrap();
+ assert_eq!(res, 3072);
+
+ // should now hit the tx buffer size
+ flow_mgr
+ .borrow_mut()
+ .conn_increase_max_credit(SEND_BUFFER_SIZE as u64);
+ let res = s.send(&[4; SEND_BUFFER_SIZE + 100]).unwrap();
+ assert_eq!(res, SEND_BUFFER_SIZE - 4096);
+
+ // TODO(agrover@mozilla.com): test ooo acks somehow
+ s.mark_as_acked(0, 40, false);
+ }
+
+ #[test]
+ fn test_tx_buffer_acks() {
+ let mut tx = TxBuffer::new();
+ assert_eq!(tx.send(&[4; 100]), 100);
+ let res = tx.next_bytes().unwrap();
+ assert_eq!(res.0, 0);
+ assert_eq!(res.1.len(), 100);
+ tx.mark_as_sent(0, 100);
+ let res = tx.next_bytes();
+ assert_eq!(res, None);
+
+ tx.mark_as_acked(0, 100);
+ let res = tx.next_bytes();
+ assert_eq!(res, None);
+ }
+
+ #[test]
+ fn send_stream_writable_event_gen() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(2);
+ let mut conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(4.into(), 0, Rc::clone(&flow_mgr), conn_events.clone());
+
+ // Stream is initially blocked (conn:2, stream:0)
+ // and will not accept data.
+ assert_eq!(s.send(b"hi").unwrap(), 0);
+
+ // increasing to (conn:2, stream:2) will allow 2 bytes, and also
+ // generate a SendStreamWritable event.
+ s.set_max_stream_data(2);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 1);
+ assert!(matches!(evts[0], ConnectionEvent::SendStreamWritable{..}));
+ assert_eq!(s.send(b"hello").unwrap(), 2);
+
+ // increasing to (conn:2, stream:4) will not generate an event or allow
+ // sending anything.
+ s.set_max_stream_data(4);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 0);
+ assert_eq!(s.send(b"hello").unwrap(), 0);
+
+ // Increasing conn max (conn:4, stream:4) will unblock but not emit
+ // event b/c that happens in Connection::emit_frame() (tested in
+ // connection.rs)
+ assert_eq!(flow_mgr.borrow_mut().conn_increase_max_credit(4), true);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 0);
+ assert_eq!(s.avail(), 2);
+ assert_eq!(s.send(b"hello").unwrap(), 2);
+
+ // No event because still blocked by conn
+ s.set_max_stream_data(1_000_000_000);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 0);
+
+ // No event because happens in emit_frame()
+ flow_mgr
+ .borrow_mut()
+ .conn_increase_max_credit(1_000_000_000);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 0);
+
+ // Unblocking both by a large amount will cause avail() to be limited by
+ // tx buffer size.
+ assert_eq!(s.avail(), SEND_BUFFER_SIZE - 4);
+
+ assert_eq!(
+ s.send(&[b'a'; SEND_BUFFER_SIZE]).unwrap(),
+ SEND_BUFFER_SIZE - 4
+ );
+
+ // No event because still blocked by tx buffer full
+ s.set_max_stream_data(2_000_000_000);
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 0);
+ assert_eq!(s.send(b"hello").unwrap(), 0);
+ }
+
+ #[test]
+ fn send_stream_writable_event_new_stream() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(2);
+ let mut conn_events = ConnectionEvents::default();
+
+ let _s = SendStream::new(4.into(), 100, flow_mgr, conn_events.clone());
+
+ // Creating a new stream with conn and stream credits should result in
+ // an event.
+ let evts = conn_events.events().collect::<Vec<_>>();
+ assert_eq!(evts.len(), 1);
+ assert!(matches!(evts[0], ConnectionEvent::SendStreamWritable{..}));
+ }
+
+ #[test]
+ // Verify lost frames handle fin properly
+ fn send_stream_get_frame_data() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(100);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(0.into(), 100, flow_mgr, conn_events);
+ s.send(&[0; 10]).unwrap();
+ s.close();
+
+ let mut ss = SendStreams::default();
+ ss.insert(StreamId::from(0), s);
+
+ let mut tokens = Vec::new();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+
+ // Write a small frame: no fin.
+ let written = builder.len();
+ builder.set_limit(written + 6);
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(builder.len(), written + 6);
+ assert_eq!(tokens.len(), 1);
+ let f1_token = tokens.remove(0);
+ assert!(matches!(&f1_token, RecoveryToken::Stream(x) if !x.fin));
+
+ // Write the rest: fin.
+ let written = builder.len();
+ builder.set_limit(written + 200);
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(builder.len(), written + 10);
+ assert_eq!(tokens.len(), 1);
+ let f2_token = tokens.remove(0);
+ assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.fin));
+
+ // Should be no more data to frame.
+ let written = builder.len();
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(builder.len(), written);
+ assert!(tokens.is_empty());
+
+ // Mark frame 1 as lost
+ if let RecoveryToken::Stream(rt) = f1_token {
+ ss.lost(&rt);
+ } else {
+ panic!();
+ }
+
+ // Next frame should not set fin even though stream has fin but frame
+ // does not include end of stream
+ let written = builder.len();
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(builder.len(), written + 7); // Needs a length this time.
+ assert_eq!(tokens.len(), 1);
+ let f4_token = tokens.remove(0);
+ assert!(matches!(&f4_token, RecoveryToken::Stream(x) if !x.fin));
+
+ // Mark frame 2 as lost
+ if let RecoveryToken::Stream(rt) = f2_token {
+ ss.lost(&rt);
+ } else {
+ panic!();
+ }
+
+ // Next frame should set fin because it includes end of stream
+ let written = builder.len();
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert_eq!(builder.len(), written + 10);
+ assert_eq!(tokens.len(), 1);
+ let f5_token = tokens.remove(0);
+ assert!(matches!(&f5_token, RecoveryToken::Stream(x) if x.fin));
+ }
+
+ #[test]
+ #[allow(clippy::cognitive_complexity)]
+ // Verify lost frames handle fin properly with zero length fin
+ fn send_stream_get_frame_zerolength_fin() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(100);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(0.into(), 100, flow_mgr, conn_events);
+ s.send(&[0; 10]).unwrap();
+
+ let mut ss = SendStreams::default();
+ ss.insert(StreamId::from(0), s);
+
+ let mut tokens = Vec::new();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ let f1_token = tokens.remove(0);
+ assert!(matches!(&f1_token, RecoveryToken::Stream(x) if x.offset == 0));
+ assert!(matches!(&f1_token, RecoveryToken::Stream(x) if x.length == 10));
+ assert!(matches!(&f1_token, RecoveryToken::Stream(x) if !x.fin));
+
+ // Should be no more data to frame
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ assert!(tokens.is_empty());
+
+ ss.get_mut(StreamId::from(0)).unwrap().close();
+
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ let f2_token = tokens.remove(0);
+ assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.offset == 10));
+ assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.length == 0));
+ assert!(matches!(&f2_token, RecoveryToken::Stream(x) if x.fin));
+
+ // Mark frame 2 as lost
+ if let RecoveryToken::Stream(rt) = f2_token {
+ ss.lost(&rt);
+ } else {
+ panic!();
+ }
+
+ // Next frame should set fin
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ let f3_token = tokens.remove(0);
+ assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.offset == 10));
+ assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.length == 0));
+ assert!(matches!(&f3_token, RecoveryToken::Stream(x) if x.fin));
+
+ // Mark frame 1 as lost
+ if let RecoveryToken::Stream(rt) = f1_token {
+ ss.lost(&rt);
+ } else {
+ panic!();
+ }
+
+ // Next frame should set fin and include all data
+ ss.write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
+ let f4_token = tokens.remove(0);
+ assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.offset == 0));
+ assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.length == 10));
+ assert!(matches!(&f4_token, RecoveryToken::Stream(x) if x.fin));
+ }
+
+ #[test]
+ fn send_atomic() {
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(5);
+ let conn_events = ConnectionEvents::default();
+
+ let stream_id = StreamId::from(4);
+ let mut s = SendStream::new(stream_id, 0, Rc::clone(&flow_mgr), conn_events);
+ s.set_max_stream_data(2);
+
+ // Stream is initially blocked (conn:5, stream:2)
+ // and will not accept atomic write of 3 bytes.
+ assert_eq!(s.send_atomic(b"abc").unwrap(), 0);
+
+ // assert that STREAM_DATA_BLOCKED is sent.
+ assert_eq!(
+ flow_mgr.borrow_mut().next().unwrap(),
+ Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit: 0x2
+ }
+ );
+
+ // assert non-atomic write works
+ assert_eq!(s.send(b"abc").unwrap(), 2);
+ assert_eq!(s.next_bytes(), Some((0, &b"ab"[..])));
+ // STREAM_DATA_BLOCKED is not sent yet.
+ assert!(flow_mgr.borrow_mut().next().is_none());
+
+ // STREAM_DATA_BLOCKED is queued once bytes using all credit are sent.
+ s.mark_as_sent(0, 2, false);
+ assert_eq!(
+ flow_mgr.borrow_mut().next().unwrap(),
+ Frame::StreamDataBlocked {
+ stream_id,
+ stream_data_limit: 0x2
+ }
+ );
+
+ // increasing to (conn:5, stream:10)
+ s.set_max_stream_data(10);
+ // will not accept atomic write of 4 bytes.
+ assert_eq!(s.send_atomic(b"abcd").unwrap(), 0);
+
+ // assert that STREAM_DATA_BLOCKED is sent.
+ assert_eq!(
+ flow_mgr.borrow_mut().next().unwrap(),
+ Frame::DataBlocked { data_limit: 0x5 }
+ );
+
+ // assert non-atomic write works
+ assert_eq!(s.send(b"abcd").unwrap(), 3);
+ assert_eq!(s.next_bytes(), Some((2, &b"abc"[..])));
+ // DATA_BLOCKED is not sent yet.
+ assert!(flow_mgr.borrow_mut().next().is_none());
+
+ // DATA_BLOCKED is queued once bytes using all credit are sent.
+ s.mark_as_sent(2, 3, false);
+ assert_eq!(
+ flow_mgr.borrow_mut().next().unwrap(),
+ Frame::DataBlocked { data_limit: 0x5 }
+ );
+
+ // increasing to (conn:15, stream:15)
+ s.set_max_stream_data(15);
+ flow_mgr.borrow_mut().conn_increase_max_credit(15);
+
+ // assert that atomic writing 10 byte works
+ assert_eq!(s.send_atomic(b"abcdefghij").unwrap(), 10);
+ }
+
+ #[test]
+ fn ack_fin_first() {
+ const MESSAGE: &[u8] = b"hello";
+ let len_u64 = u64::try_from(MESSAGE.len()).unwrap();
+
+ let flow_mgr = Rc::new(RefCell::new(FlowMgr::default()));
+ flow_mgr.borrow_mut().conn_increase_max_credit(len_u64);
+ let conn_events = ConnectionEvents::default();
+
+ let mut s = SendStream::new(StreamId::new(100), 0, Rc::clone(&flow_mgr), conn_events);
+ s.set_max_stream_data(len_u64);
+
+ // Send all the data, then the fin.
+ let _ = s.send(MESSAGE).unwrap();
+ s.mark_as_sent(0, MESSAGE.len(), false);
+ s.close();
+ s.mark_as_sent(len_u64, 0, true);
+
+ // Ack the fin, then the data.
+ s.mark_as_acked(len_u64, 0, true);
+ s.mark_as_acked(0, MESSAGE.len(), false);
+ assert!(s.is_terminal());
+ }
+
+ #[test]
+ fn ack_then_lose_fin() {
+ const MESSAGE: &[u8] = b"hello";
+ let len_u64 = u64::try_from(MESSAGE.len()).unwrap();
+
+ let mut flow_mgr = FlowMgr::default();
+ flow_mgr.conn_increase_max_credit(len_u64);
+ let conn_events = ConnectionEvents::default();
+
+ let id = StreamId::new(100);
+ let mut s = SendStream::new(id, 0, Rc::new(RefCell::new(flow_mgr)), conn_events);
+ s.set_max_stream_data(len_u64);
+
+ // Send all the data, then the fin.
+ let _ = s.send(MESSAGE).unwrap();
+ s.mark_as_sent(0, MESSAGE.len(), false);
+ s.close();
+ s.mark_as_sent(len_u64, 0, true);
+
+ // Ack the fin, then mark it lost.
+ s.mark_as_acked(len_u64, 0, true);
+ s.mark_as_lost(len_u64, 0, true);
+
+ // No frame should be sent here.
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ assert!(s.write_frame(&mut builder).is_none());
+ }
+
+ /// Create a `SendStream` and force it into a state where it believes that
+ /// `offset` bytes have already been sent and acknowledged.
+ fn stream_with_sent(stream: u64, offset: usize) -> SendStream {
+ const MAX_VARINT: u64 = (1 << 62) - 1;
+
+ let mut flow_mgr = FlowMgr::default();
+ flow_mgr.conn_increase_max_credit(MAX_VARINT);
+
+ let mut s = SendStream::new(
+ StreamId::from(stream),
+ MAX_VARINT,
+ Rc::new(RefCell::new(flow_mgr)),
+ ConnectionEvents::default(),
+ );
+
+ let mut send_buf = TxBuffer::new();
+ send_buf.retired = u64::try_from(offset).unwrap();
+ send_buf.ranges.mark_range(0, offset, RangeState::Acked);
+ s.state = SendStreamState::Send { send_buf };
+ s
+ }
+
+ fn frame_sent_sid(stream: u64, offset: usize, len: usize, fin: bool, space: usize) -> bool {
+ const BUF: &[u8] = &[0x42; 128];
+ let mut s = stream_with_sent(stream, offset);
+
+ // Now write out the proscribed data and maybe close.
+ if len > 0 {
+ s.send(&BUF[..len]).unwrap();
+ }
+ if fin {
+ s.close();
+ }
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ let header_len = builder.len();
+ builder.set_limit(header_len + space);
+ let token = s.write_frame(&mut builder);
+ qtrace!("STREAM frame: {}", hex_with_len(&builder[header_len..]));
+ token.is_some()
+ }
+
+ fn frame_sent(offset: usize, len: usize, fin: bool, space: usize) -> bool {
+ frame_sent_sid(0, offset, len, fin, space)
+ }
+
+ #[test]
+ fn stream_frame_empty() {
+ // Stream frames with empty data and no fin never work.
+ assert!(!frame_sent(10, 0, false, 2));
+ assert!(!frame_sent(10, 0, false, 3));
+ assert!(!frame_sent(10, 0, false, 4));
+ assert!(!frame_sent(10, 0, false, 5));
+ assert!(!frame_sent(10, 0, false, 100));
+
+ // Empty data with fin is only a problem if there is no space.
+ assert!(!frame_sent(0, 0, true, 1));
+ assert!(frame_sent(0, 0, true, 2));
+ assert!(!frame_sent(10, 0, true, 2));
+ assert!(frame_sent(10, 0, true, 3));
+ assert!(frame_sent(10, 0, true, 4));
+ assert!(frame_sent(10, 0, true, 5));
+ assert!(frame_sent(10, 0, true, 100));
+ }
+
+ #[test]
+ fn stream_frame_minimum() {
+ // Add minimum data
+ assert!(!frame_sent(10, 1, false, 3));
+ assert!(!frame_sent(10, 1, true, 3));
+ assert!(frame_sent(10, 1, false, 4));
+ assert!(frame_sent(10, 1, true, 4));
+ assert!(frame_sent(10, 1, false, 5));
+ assert!(frame_sent(10, 1, true, 5));
+ assert!(frame_sent(10, 1, false, 100));
+ assert!(frame_sent(10, 1, true, 100));
+ }
+
+ #[test]
+ fn stream_frame_more() {
+ // Try more data
+ assert!(!frame_sent(10, 100, false, 3));
+ assert!(!frame_sent(10, 100, true, 3));
+ assert!(frame_sent(10, 100, false, 4));
+ assert!(frame_sent(10, 100, true, 4));
+ assert!(frame_sent(10, 100, false, 5));
+ assert!(frame_sent(10, 100, true, 5));
+ assert!(frame_sent(10, 100, false, 100));
+ assert!(frame_sent(10, 100, true, 100));
+
+ assert!(frame_sent(10, 100, false, 1000));
+ assert!(frame_sent(10, 100, true, 1000));
+ }
+
+ #[test]
+ fn stream_frame_big_id() {
+ // A value that encodes to the largest varint.
+ const BIG: u64 = 1 << 30;
+ const BIGSZ: usize = 1 << 30;
+
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 16));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, true, 16));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 17));
+ assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 17));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 0, false, 18));
+ assert!(frame_sent_sid(BIG, BIGSZ, 0, true, 18));
+
+ assert!(!frame_sent_sid(BIG, BIGSZ, 1, false, 17));
+ assert!(!frame_sent_sid(BIG, BIGSZ, 1, true, 17));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 18));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 18));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 19));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 19));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, false, 100));
+ assert!(frame_sent_sid(BIG, BIGSZ, 1, true, 100));
+ }
+
+ #[test]
+ fn stream_frame_16384() {
+ const DATA16384: &[u8] = &[0x43; 16384];
+
+ // 16383/16384 is an odd boundary in STREAM frame construction.
+ // That is the boundary where a length goes from 2 bytes to 4 bytes.
+ // If the data fits in the available space, then it is simple:
+ let mut s = stream_with_sent(0, 0);
+ s.send(DATA16384).unwrap();
+ s.close();
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ let header_len = builder.len();
+ builder.set_limit(header_len + DATA16384.len() + 2);
+ let token = s.write_frame(&mut builder);
+ assert!(token.is_some());
+ // Expect STREAM + FIN only.
+ assert_eq!(&builder[header_len..header_len + 2], &[0b1001, 0]);
+ assert_eq!(&builder[header_len + 2..], DATA16384);
+
+ s.mark_as_lost(0, DATA16384.len(), true);
+
+ // However, if there is one extra byte of space, we will try to add a length.
+ // That length will then make the frame to be too large and the data will be
+ // truncated. The frame could carry one more byte of data, but it's a corner
+ // case we don't want to address as it should be rare (if not impossible).
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ let header_len = builder.len();
+ builder.set_limit(header_len + DATA16384.len() + 3);
+ let token = s.write_frame(&mut builder);
+ assert!(token.is_some());
+ // Expect STREAM + LEN + FIN.
+ assert_eq!(
+ &builder[header_len..header_len + 4],
+ &[0b1010, 0, 0x7f, 0xfd]
+ );
+ assert_eq!(
+ &builder[header_len + 4..],
+ &DATA16384[..DATA16384.len() - 3]
+ );
+ }
+
+ #[test]
+ fn stream_frame_64() {
+ const DATA64: &[u8] = &[0x43; 64];
+
+ // Unlike 16383/16384, the boundary at 63/64 is easy because the difference
+ // is just one byte. We lose just the last byte when there is more space.
+ let mut s = stream_with_sent(0, 0);
+ s.send(DATA64).unwrap();
+ s.close();
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ let header_len = builder.len();
+ builder.set_limit(header_len + 66);
+ let token = s.write_frame(&mut builder);
+ assert!(token.is_some());
+ // Expect STREAM + FIN only.
+ assert_eq!(&builder[header_len..header_len + 2], &[0b1001, 0]);
+ assert_eq!(&builder[header_len + 2..], DATA64);
+
+ s.mark_as_lost(0, DATA64.len(), true);
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ let header_len = builder.len();
+ builder.set_limit(header_len + 67);
+ let token = s.write_frame(&mut builder);
+ assert!(token.is_some());
+ // Expect STREAM + LEN, not FIN.
+ assert_eq!(&builder[header_len..header_len + 3], &[0b1010, 0, 63]);
+ assert_eq!(&builder[header_len + 3..], &DATA64[..63]);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/sender.rs b/third_party/rust/neqo-transport/src/sender.rs
new file mode 100644
index 0000000000..15675f4b2b
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/sender.rs
@@ -0,0 +1,124 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Congestion control
+#![deny(clippy::pedantic)]
+#![allow(clippy::module_name_repetitions)]
+
+use crate::cc::{
+ ClassicCongestionControl, CongestionControl, CongestionControlAlgorithm, NewReno,
+ MAX_DATAGRAM_SIZE,
+};
+use crate::pace::Pacer;
+use crate::tracking::SentPacket;
+use neqo_common::qlog::NeqoQlog;
+
+use std::fmt::{self, Debug, Display};
+use std::time::{Duration, Instant};
+
+/// The number of packets we allow to burst from the pacer.
+pub const PACING_BURST_SIZE: usize = 2;
+
+#[derive(Debug)]
+pub struct PacketSender {
+ cc: Box<dyn CongestionControl>,
+ pacer: Option<Pacer>,
+}
+
+impl Display for PacketSender {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", self.cc)?;
+ if let Some(p) = &self.pacer {
+ write!(f, " {}", p)?;
+ }
+ Ok(())
+ }
+}
+
+impl PacketSender {
+ #[must_use]
+ pub fn new(alg: CongestionControlAlgorithm) -> Self {
+ Self {
+ cc: match alg {
+ CongestionControlAlgorithm::NewReno => {
+ Box::new(ClassicCongestionControl::new(NewReno::default()))
+ }
+ },
+ pacer: None,
+ }
+ }
+
+ pub fn set_qlog(&mut self, qlog: NeqoQlog) {
+ self.cc.set_qlog(qlog);
+ }
+
+ #[cfg(test)]
+ #[must_use]
+ pub fn cwnd(&self) -> usize {
+ self.cc.cwnd()
+ }
+
+ #[must_use]
+ pub fn cwnd_avail(&self) -> usize {
+ self.cc.cwnd_avail()
+ }
+
+ // Multi-packet version of OnPacketAckedCC
+ pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]) {
+ self.cc.on_packets_acked(acked_pkts);
+ }
+
+ pub fn on_packets_lost(
+ &mut self,
+ first_rtt_sample_time: Option<Instant>,
+ prev_largest_acked_sent: Option<Instant>,
+ pto: Duration,
+ lost_packets: &[SentPacket],
+ ) {
+ self.cc.on_packets_lost(
+ first_rtt_sample_time,
+ prev_largest_acked_sent,
+ pto,
+ lost_packets,
+ );
+ }
+
+ pub fn discard(&mut self, pkt: &SentPacket) {
+ self.cc.discard(pkt);
+ }
+
+ pub fn on_packet_sent(&mut self, pkt: &SentPacket, rtt: Duration) {
+ self.pacer
+ .as_mut()
+ .unwrap()
+ .spend(pkt.time_sent, rtt, self.cc.cwnd(), pkt.size);
+ self.cc.on_packet_sent(pkt);
+ }
+
+ pub fn start_pacer(&mut self, now: Instant) {
+ // Start the pacer with a small burst size.
+ self.pacer = Some(Pacer::new(
+ now,
+ MAX_DATAGRAM_SIZE * PACING_BURST_SIZE,
+ MAX_DATAGRAM_SIZE,
+ ));
+ }
+
+ #[must_use]
+ pub fn next_paced(&self, rtt: Duration) -> Option<Instant> {
+ // Only pace if there are bytes in flight.
+ if self.cc.bytes_in_flight() > 0 {
+ Some(self.pacer.as_ref().unwrap().next(rtt, self.cc.cwnd()))
+ } else {
+ None
+ }
+ }
+
+ #[must_use]
+ pub fn recovery_packet(&self) -> bool {
+ self.cc.recovery_packet()
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs
new file mode 100644
index 0000000000..ef161a9fe4
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/server.rs
@@ -0,0 +1,636 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// This file implements a server that can handle multiple connections.
+
+use neqo_common::{
+ self as common, event::Provider, hex, qdebug, qerror, qinfo, qlog::NeqoQlog, qtrace, qwarn,
+ timer::Timer, Datagram, Decoder, Role,
+};
+use neqo_crypto::{AntiReplay, Cipher, ZeroRttCheckResult, ZeroRttChecker};
+
+pub use crate::addr_valid::ValidateAddress;
+use crate::addr_valid::{AddressValidation, AddressValidationResult};
+use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdManager, ConnectionIdRef};
+use crate::connection::{Connection, Output, State};
+use crate::packet::{PacketBuilder, PacketType, PublicPacket};
+use crate::{ConnectionParameters, QuicVersion, Res};
+
+use std::cell::RefCell;
+use std::collections::{HashMap, HashSet, VecDeque};
+use std::fs::OpenOptions;
+use std::mem;
+use std::net::SocketAddr;
+use std::ops::{Deref, DerefMut};
+use std::path::PathBuf;
+use std::rc::{Rc, Weak};
+use std::time::{Duration, Instant};
+
+pub enum InitialResult {
+ Accept,
+ Drop,
+ Retry(Vec<u8>),
+}
+
+/// MIN_INITIAL_PACKET_SIZE is the smallest packet that can be used to establish
+/// a new connection across all QUIC versions this server supports.
+const MIN_INITIAL_PACKET_SIZE: usize = 1200;
+const TIMER_GRANULARITY: Duration = Duration::from_millis(10);
+const TIMER_CAPACITY: usize = 16384;
+
+type StateRef = Rc<RefCell<ServerConnectionState>>;
+type CidMgr = Rc<RefCell<dyn ConnectionIdManager>>;
+type ConnectionTableRef = Rc<RefCell<HashMap<ConnectionId, StateRef>>>;
+
+#[derive(Debug)]
+pub struct ServerConnectionState {
+ c: Connection,
+ active_attempt: Option<AttemptKey>,
+ last_timer: Instant,
+}
+
+impl Deref for ServerConnectionState {
+ type Target = Connection;
+ fn deref(&self) -> &Self::Target {
+ &self.c
+ }
+}
+
+impl DerefMut for ServerConnectionState {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.c
+ }
+}
+
+/// A `AttemptKey` is used to disambiguate connection attempts.
+/// Multiple connection attempts with the same key won't produce multiple connections.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+struct AttemptKey {
+ // Using the remote address is sufficient for disambiguation,
+ // until we support multiple local socket addresses.
+ remote_address: SocketAddr,
+ odcid: ConnectionId,
+}
+
+/// A `ServerZeroRttChecker` is a simple wrapper around a single checker.
+/// It uses `RefCell` so that the wrapped checker can be shared between
+/// multiple connections created by the server.
+#[derive(Clone, Debug)]
+struct ServerZeroRttChecker {
+ checker: Rc<RefCell<Box<dyn ZeroRttChecker>>>,
+}
+
+impl ServerZeroRttChecker {
+ pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self {
+ Self {
+ checker: Rc::new(RefCell::new(checker)),
+ }
+ }
+}
+
+impl ZeroRttChecker for ServerZeroRttChecker {
+ fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
+ self.checker.borrow().check(token)
+ }
+}
+
+/// `InitialDetails` holds important information for processing `Initial` packets.
+struct InitialDetails {
+ src_cid: ConnectionId,
+ dst_cid: ConnectionId,
+ token: Vec<u8>,
+ quic_version: QuicVersion,
+}
+
+impl InitialDetails {
+ fn new(packet: &PublicPacket) -> Self {
+ Self {
+ src_cid: ConnectionId::from(packet.scid()),
+ dst_cid: ConnectionId::from(packet.dcid()),
+ token: packet.token().to_vec(),
+ quic_version: packet.version().unwrap(),
+ }
+ }
+}
+
+pub struct Server {
+ /// The names of certificates.
+ certs: Vec<String>,
+ /// The ALPN values that the server supports.
+ protocols: Vec<String>,
+ /// The cipher suites that the server supports.
+ ciphers: Vec<Cipher>,
+ /// Anti-replay configuration for 0-RTT.
+ anti_replay: AntiReplay,
+ /// A function for determining if 0-RTT can be accepted.
+ zero_rtt_checker: ServerZeroRttChecker,
+ /// A connection ID manager.
+ cid_manager: CidMgr,
+ /// Connection parameters.
+ conn_params: ConnectionParameters,
+ /// Active connection attempts, keyed by `AttemptKey`. Initial packets with
+ /// the same key are routed to the connection that was first accepted.
+ /// This is cleared out when the connection is closed or established.
+ active_attempts: HashMap<AttemptKey, StateRef>,
+ /// All connections, keyed by ConnectionId.
+ connections: ConnectionTableRef,
+ /// The connections that have new events.
+ active: HashSet<ActiveConnectionRef>,
+ /// The set of connections that need immediate processing.
+ waiting: VecDeque<StateRef>,
+ /// Outstanding timers for connections.
+ timers: Timer<StateRef>,
+ /// Address validation logic, which determines whether we send a Retry.
+ address_validation: Rc<RefCell<AddressValidation>>,
+ /// Directory to create qlog traces in
+ qlog_dir: Option<PathBuf>,
+}
+
+impl Server {
+ /// Construct a new server.
+ /// * `now` is the time that the server is instantiated.
+ /// * `certs` is a list of the certificates that should be configured.
+ /// * `protocols` is the preference list of ALPN values.
+ /// * `anti_replay` is an anti-replay context.
+ /// * `zero_rtt_checker` determines whether 0-RTT should be accepted. This
+ /// will be passed the value of the `extra` argument that was passed to
+ /// `Connection::send_ticket` to see if it is OK.
+ /// * `cid_manager` is responsible for generating connection IDs and parsing them;
+ /// connection IDs produced by the manager cannot be zero-length.
+ pub fn new(
+ now: Instant,
+ certs: &[impl AsRef<str>],
+ protocols: &[impl AsRef<str>],
+ anti_replay: AntiReplay,
+ zero_rtt_checker: Box<dyn ZeroRttChecker>,
+ cid_manager: CidMgr,
+ conn_params: ConnectionParameters,
+ ) -> Res<Self> {
+ let validation = AddressValidation::new(now, ValidateAddress::Never)?;
+ Ok(Self {
+ certs: certs.iter().map(|x| String::from(x.as_ref())).collect(),
+ protocols: protocols.iter().map(|x| String::from(x.as_ref())).collect(),
+ ciphers: Vec::new(),
+ anti_replay,
+ zero_rtt_checker: ServerZeroRttChecker::new(zero_rtt_checker),
+ cid_manager,
+ conn_params,
+ active_attempts: HashMap::default(),
+ connections: Rc::default(),
+ active: HashSet::default(),
+ waiting: VecDeque::default(),
+ timers: Timer::new(now, TIMER_GRANULARITY, TIMER_CAPACITY),
+ address_validation: Rc::new(RefCell::new(validation)),
+ qlog_dir: None,
+ })
+ }
+
+ /// Set or clear directory to create logs of connection events in QLOG format.
+ pub fn set_qlog_dir(&mut self, dir: Option<PathBuf>) {
+ self.qlog_dir = dir;
+ }
+
+ /// Set the policy for address validation.
+ pub fn set_validation(&mut self, v: ValidateAddress) {
+ self.address_validation.borrow_mut().set_validation(v);
+ }
+
+ /// Set the cipher suites that should be used. Set an empty value to use
+ /// default values.
+ pub fn set_ciphers(&mut self, ciphers: impl AsRef<[Cipher]>) {
+ self.ciphers = Vec::from(ciphers.as_ref());
+ }
+
+ fn remove_timer(&mut self, c: &StateRef) {
+ let last = c.borrow().last_timer;
+ self.timers.remove(last, |t| Rc::ptr_eq(t, c));
+ }
+
+ fn process_connection(
+ &mut self,
+ c: StateRef,
+ dgram: Option<Datagram>,
+ now: Instant,
+ ) -> Option<Datagram> {
+ qtrace!([self], "Process connection {:?}", c);
+ let out = c.borrow_mut().process(dgram, now);
+ match out {
+ Output::Datagram(_) => {
+ qtrace!([self], "Sending packet, added to waiting connections");
+ self.waiting.push_back(Rc::clone(&c));
+ }
+ Output::Callback(delay) => {
+ let next = now + delay;
+ if next != c.borrow().last_timer {
+ qtrace!([self], "Change timer to {:?}", next);
+ self.remove_timer(&c);
+ c.borrow_mut().last_timer = next;
+ self.timers.add(next, Rc::clone(&c));
+ }
+ }
+ _ => {
+ self.remove_timer(&c);
+ }
+ }
+ if c.borrow().has_events() {
+ qtrace!([self], "Connection active: {:?}", c);
+ self.active.insert(ActiveConnectionRef { c: Rc::clone(&c) });
+ }
+
+ if *c.borrow().state() > State::Handshaking {
+ // Remove any active connection attempt now that this is no longer handshaking.
+ if let Some(k) = c.borrow_mut().active_attempt.take() {
+ self.active_attempts.remove(&k);
+ }
+ }
+
+ if matches!(c.borrow().state(), State::Closed(_)) {
+ c.borrow_mut().set_qlog(NeqoQlog::disabled());
+ self.connections
+ .borrow_mut()
+ .retain(|_, v| !Rc::ptr_eq(v, &c));
+ }
+ out.dgram()
+ }
+
+ fn connection(&self, cid: &ConnectionIdRef) -> Option<StateRef> {
+ if let Some(c) = self.connections.borrow().get(&cid[..]) {
+ Some(Rc::clone(&c))
+ } else {
+ None
+ }
+ }
+
+ fn handle_initial(
+ &mut self,
+ initial: InitialDetails,
+ dgram: Datagram,
+ now: Instant,
+ ) -> Option<Datagram> {
+ qdebug!([self], "Handle initial");
+ let res = self
+ .address_validation
+ .borrow()
+ .validate(&initial.token, dgram.source(), now);
+ match res {
+ AddressValidationResult::Invalid => None,
+ AddressValidationResult::Pass => self.connection_attempt(initial, dgram, None, now),
+ AddressValidationResult::ValidRetry(orig_dcid) => {
+ self.connection_attempt(initial, dgram, Some(orig_dcid), now)
+ }
+ AddressValidationResult::Validate => {
+ qinfo!([self], "Send retry for {:?}", initial.dst_cid);
+
+ let res = self.address_validation.borrow().generate_retry_token(
+ &initial.dst_cid,
+ dgram.source(),
+ now,
+ );
+ let token = if let Ok(t) = res {
+ t
+ } else {
+ qerror!([self], "unable to generate token, dropping packet");
+ return None;
+ };
+ let new_dcid = self.cid_manager.borrow_mut().generate_cid();
+ let packet = PacketBuilder::retry(
+ initial.quic_version,
+ &initial.src_cid,
+ &new_dcid,
+ &token,
+ &initial.dst_cid,
+ );
+ if let Ok(p) = packet {
+ let retry = Datagram::new(dgram.destination(), dgram.source(), p);
+ Some(retry)
+ } else {
+ qerror!([self], "unable to encode retry, dropping packet");
+ None
+ }
+ }
+ }
+ }
+
+ fn connection_attempt(
+ &mut self,
+ initial: InitialDetails,
+ dgram: Datagram,
+ orig_dcid: Option<ConnectionId>,
+ now: Instant,
+ ) -> Option<Datagram> {
+ let attempt_key = AttemptKey {
+ remote_address: dgram.source(),
+ odcid: orig_dcid.as_ref().unwrap_or(&initial.dst_cid).clone(),
+ };
+ if let Some(c) = self.active_attempts.get(&attempt_key) {
+ qdebug!(
+ [self],
+ "Handle Initial for existing connection attempt {:?}",
+ attempt_key
+ );
+ let c = Rc::clone(c);
+ self.process_connection(c, Some(dgram), now)
+ } else {
+ self.accept_connection(attempt_key, initial, dgram, orig_dcid, now)
+ }
+ }
+
+ fn create_qlog_trace(&self, attempt_key: &AttemptKey) -> NeqoQlog {
+ if let Some(qlog_dir) = &self.qlog_dir {
+ let mut qlog_path = qlog_dir.to_path_buf();
+
+ qlog_path.push(format!("{}.qlog", attempt_key.odcid));
+
+ // The original DCID is chosen by the client. Using create_new()
+ // prevents attackers from overwriting existing logs.
+ match OpenOptions::new()
+ .write(true)
+ .create_new(true)
+ .open(&qlog_path)
+ {
+ Ok(f) => {
+ qinfo!("Qlog output to {}", qlog_path.display());
+
+ let streamer = ::qlog::QlogStreamer::new(
+ qlog::QLOG_VERSION.to_string(),
+ Some("Neqo server qlog".to_string()),
+ Some("Neqo server qlog".to_string()),
+ None,
+ std::time::Instant::now(),
+ common::qlog::new_trace(Role::Server),
+ Box::new(f),
+ );
+ let n_qlog = NeqoQlog::enabled(streamer, qlog_path);
+ match n_qlog {
+ Ok(nql) => nql,
+ Err(e) => {
+ // Keep going but w/o qlogging
+ qerror!("NeqoQlog error: {}", e);
+ NeqoQlog::disabled()
+ }
+ }
+ }
+ Err(e) => {
+ qerror!(
+ "Could not open file {} for qlog output: {}",
+ qlog_path.display(),
+ e
+ );
+ NeqoQlog::disabled()
+ }
+ }
+ } else {
+ NeqoQlog::disabled()
+ }
+ }
+
+ fn accept_connection(
+ &mut self,
+ attempt_key: AttemptKey,
+ initial: InitialDetails,
+ dgram: Datagram,
+ orig_dcid: Option<ConnectionId>,
+ now: Instant,
+ ) -> Option<Datagram> {
+ qinfo!([self], "Accept connection {:?}", attempt_key);
+ // The internal connection ID manager that we use is not used directly.
+ // Instead, wrap it so that we can save connection IDs.
+
+ let cid_mgr = Rc::new(RefCell::new(ServerConnectionIdManager {
+ c: Weak::new(),
+ cid_manager: Rc::clone(&self.cid_manager),
+ connections: Rc::clone(&self.connections),
+ saved_cids: Vec::new(),
+ }));
+
+ let sconn = Connection::new_server(
+ &self.certs,
+ &self.protocols,
+ Rc::clone(&cid_mgr) as _,
+ &self.conn_params.clone().quic_version(initial.quic_version),
+ );
+
+ if let Ok(mut c) = sconn {
+ let zcheck = self.zero_rtt_checker.clone();
+ if c.server_enable_0rtt(&self.anti_replay, zcheck).is_err() {
+ qwarn!([self], "Unable to enable 0-RTT");
+ }
+ if let Some(odcid) = orig_dcid {
+ // There was a retry, so set the connection IDs for.
+ c.set_retry_cids(odcid, initial.src_cid, initial.dst_cid);
+ }
+ c.set_validation(Rc::clone(&self.address_validation));
+ c.set_qlog(self.create_qlog_trace(&attempt_key));
+ let c = Rc::new(RefCell::new(ServerConnectionState {
+ c,
+ last_timer: now,
+ active_attempt: Some(attempt_key.clone()),
+ }));
+ cid_mgr.borrow_mut().set_connection(Rc::clone(&c));
+ let previous_attempt = self.active_attempts.insert(attempt_key, Rc::clone(&c));
+ debug_assert!(previous_attempt.is_none());
+ self.process_connection(c, Some(dgram), now)
+ } else {
+ qwarn!([self], "Unable to create connection");
+ None
+ }
+ }
+
+ fn process_input(&mut self, dgram: Datagram, now: Instant) -> Option<Datagram> {
+ qtrace!("Process datagram: {}", hex(&dgram[..]));
+
+ // This is only looking at the first packet header in the datagram.
+ // All packets in the datagram are routed to the same connection.
+ let res = PublicPacket::decode(&dgram[..], self.cid_manager.borrow().as_decoder());
+ let (packet, _remainder) = match res {
+ Ok(res) => res,
+ _ => {
+ qtrace!([self], "Discarding {:?}", dgram);
+ return None;
+ }
+ };
+
+ // Finding an existing connection. Should be the most common case.
+ if let Some(c) = self.connection(packet.dcid()) {
+ return self.process_connection(c, Some(dgram), now);
+ }
+
+ if packet.packet_type() == PacketType::Short {
+ // TODO send a stateless reset here.
+ qtrace!([self], "Short header packet for an unknown connection");
+ return None;
+ }
+
+ if dgram.len() < MIN_INITIAL_PACKET_SIZE {
+ qtrace!([self], "Bogus packet: too short");
+ return None;
+ }
+ match packet.packet_type() {
+ PacketType::Initial => {
+ // Copy values from `packet` because they are currently still borrowing from `dgram`.
+ let initial = InitialDetails::new(&packet);
+ self.handle_initial(initial, dgram, now)
+ }
+ PacketType::OtherVersion => {
+ let vn = PacketBuilder::version_negotiation(packet.scid(), packet.dcid());
+ Some(Datagram::new(dgram.destination(), dgram.source(), vn))
+ }
+ _ => {
+ qtrace!([self], "Not an initial packet");
+ None
+ }
+ }
+ }
+
+ /// Iterate through the pending connections looking for any that might want
+ /// to send a datagram. Stop at the first one that does.
+ fn process_next_output(&mut self, now: Instant) -> Option<Datagram> {
+ qtrace!([self], "No packet to send, look at waiting connections");
+ while let Some(c) = self.waiting.pop_front() {
+ if let Some(d) = self.process_connection(c, None, now) {
+ return Some(d);
+ }
+ }
+ qtrace!([self], "No packet to send still, run timers");
+ while let Some(c) = self.timers.take_next(now) {
+ if let Some(d) = self.process_connection(c, None, now) {
+ return Some(d);
+ }
+ }
+ None
+ }
+
+ fn next_time(&mut self, now: Instant) -> Option<Duration> {
+ if self.waiting.is_empty() {
+ self.timers.next_time().map(|x| x - now)
+ } else {
+ Some(Duration::new(0, 0))
+ }
+ }
+
+ pub fn process(&mut self, dgram: Option<Datagram>, now: Instant) -> Output {
+ let out = if let Some(d) = dgram {
+ self.process_input(d, now)
+ } else {
+ None
+ };
+ let out = out.or_else(|| self.process_next_output(now));
+ match out {
+ Some(d) => {
+ qtrace!([self], "Send packet: {:?}", d);
+ Output::Datagram(d)
+ }
+ _ => match self.next_time(now) {
+ Some(delay) => {
+ qtrace!([self], "Wait: {:?}", delay);
+ Output::Callback(delay)
+ }
+ _ => {
+ qtrace!([self], "Go dormant");
+ Output::None
+ }
+ },
+ }
+ }
+
+ /// This lists the connections that have received new events
+ /// as a result of calling `process()`.
+ pub fn active_connections(&mut self) -> Vec<ActiveConnectionRef> {
+ mem::take(&mut self.active).into_iter().collect()
+ }
+
+ pub fn add_to_waiting(&mut self, c: ActiveConnectionRef) {
+ self.waiting.push_back(c.connection());
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ActiveConnectionRef {
+ c: StateRef,
+}
+
+impl ActiveConnectionRef {
+ pub fn borrow<'a>(&'a self) -> impl Deref<Target = Connection> + 'a {
+ std::cell::Ref::map(self.c.borrow(), |c| &c.c)
+ }
+
+ pub fn borrow_mut<'a>(&'a mut self) -> impl DerefMut<Target = Connection> + 'a {
+ std::cell::RefMut::map(self.c.borrow_mut(), |c| &mut c.c)
+ }
+
+ pub fn connection(&self) -> StateRef {
+ Rc::clone(&self.c)
+ }
+}
+
+impl std::hash::Hash for ActiveConnectionRef {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ let ptr: *const _ = self.c.as_ref();
+ ptr.hash(state)
+ }
+}
+
+impl PartialEq for ActiveConnectionRef {
+ fn eq(&self, other: &Self) -> bool {
+ Rc::ptr_eq(&self.c, &other.c)
+ }
+}
+
+impl Eq for ActiveConnectionRef {}
+
+struct ServerConnectionIdManager {
+ c: Weak<RefCell<ServerConnectionState>>,
+ connections: ConnectionTableRef,
+ cid_manager: CidMgr,
+ saved_cids: Vec<ConnectionId>,
+}
+
+impl ServerConnectionIdManager {
+ pub fn set_connection(&mut self, c: StateRef) {
+ let saved = std::mem::replace(&mut self.saved_cids, Vec::with_capacity(0));
+ for cid in saved {
+ qtrace!("ServerConnectionIdManager inserting saved cid {}", cid);
+ self.insert_cid(cid, Rc::clone(&c));
+ }
+ self.c = Rc::downgrade(&c);
+ }
+
+ fn insert_cid(&mut self, cid: ConnectionId, rc: StateRef) {
+ debug_assert!(!cid.is_empty());
+ self.connections.borrow_mut().insert(cid, rc);
+ }
+}
+
+impl ConnectionIdDecoder for ServerConnectionIdManager {
+ fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option<ConnectionIdRef<'a>> {
+ self.cid_manager.borrow_mut().decode_cid(dec)
+ }
+}
+
+impl ConnectionIdManager for ServerConnectionIdManager {
+ fn generate_cid(&mut self) -> ConnectionId {
+ let cid = self.cid_manager.borrow_mut().generate_cid();
+ if let Some(rc) = self.c.upgrade() {
+ self.insert_cid(cid.clone(), rc);
+ } else {
+ // This function can be called before the connection is set.
+ // So save any connection IDs until that hookup happens.
+ qtrace!("ServerConnectionIdManager saving cid {}", cid);
+ self.saved_cids.push(cid.clone());
+ }
+ cid
+ }
+
+ fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
+ self
+ }
+}
+
+impl ::std::fmt::Display for Server {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Server")
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/stats.rs b/third_party/rust/neqo-transport/src/stats.rs
new file mode 100644
index 0000000000..9d01dfe211
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/stats.rs
@@ -0,0 +1,195 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of some useful statistics.
+#![deny(clippy::pedantic)]
+
+use crate::packet::PacketNumber;
+use neqo_common::qinfo;
+use std::cell::RefCell;
+use std::fmt::{self, Debug};
+use std::ops::Deref;
+use std::rc::Rc;
+
+pub(crate) const MAX_PTO_COUNTS: usize = 16;
+
+#[derive(Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct FrameStats {
+ pub all: usize,
+ pub ack: usize,
+ pub largest_acknowledged: PacketNumber,
+
+ pub crypto: usize,
+ pub stream: usize,
+ pub reset_stream: usize,
+ pub stop_sending: usize,
+
+ pub ping: usize,
+ pub padding: usize,
+
+ pub max_streams: usize,
+ pub streams_blocked: usize,
+ pub max_data: usize,
+ pub data_blocked: usize,
+ pub max_stream_data: usize,
+ pub stream_data_blocked: usize,
+
+ pub new_connection_id: usize,
+ pub retire_connection_id: usize,
+
+ pub path_challenge: usize,
+ pub path_response: usize,
+
+ pub connection_close: usize,
+ pub handshake_done: usize,
+ pub new_token: usize,
+}
+
+impl Debug for FrameStats {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(
+ f,
+ " crypto {} done {} token {} close {}",
+ self.crypto, self.handshake_done, self.new_token, self.connection_close,
+ )?;
+ writeln!(
+ f,
+ " ack {} (max {}) ping {} padding {}",
+ self.ack, self.largest_acknowledged, self.ping, self.padding
+ )?;
+ writeln!(
+ f,
+ " stream {} reset {} stop {}",
+ self.stream, self.reset_stream, self.stop_sending,
+ )?;
+ writeln!(
+ f,
+ " max: stream {} data {} stream_data {}",
+ self.max_streams, self.max_data, self.max_stream_data,
+ )?;
+ writeln!(
+ f,
+ " blocked: stream {} data {} stream_data {}",
+ self.streams_blocked, self.data_blocked, self.stream_data_blocked,
+ )?;
+ writeln!(
+ f,
+ " ncid {} rcid {} pchallenge {} presponse {}",
+ self.new_connection_id,
+ self.retire_connection_id,
+ self.path_challenge,
+ self.path_response,
+ )
+ }
+}
+
+/// Connection statistics
+#[derive(Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct Stats {
+ info: String,
+
+ /// Total packets received, including all the bad ones.
+ pub packets_rx: usize,
+ /// Duplicate packets received.
+ pub dups_rx: usize,
+ /// Dropped packets or dropped garbage.
+ pub dropped_rx: usize,
+ /// The number of packet that were saved for later processing.
+ pub saved_datagrams: usize,
+
+ /// Total packets sent.
+ pub packets_tx: usize,
+ /// Total number of packets that are declared lost.
+ pub lost: usize,
+ /// Late acknowledgments, for packets that were declared lost already.
+ pub late_ack: usize,
+ /// Acknowledgments for packets that contained data that was marked
+ /// for retransmission when the PTO timer popped.
+ pub pto_ack: usize,
+
+ /// Whether the connection was resumed successfully.
+ pub resumed: bool,
+
+ /// Count PTOs. Single PTOs, 2 PTOs in a row, 3 PTOs in row, etc. are counted
+ /// separately.
+ pub pto_counts: [usize; MAX_PTO_COUNTS],
+
+ /// Count frames received.
+ pub frame_rx: FrameStats,
+ /// Count frames sent.
+ pub frame_tx: FrameStats,
+}
+
+impl Stats {
+ pub fn init(&mut self, info: String) {
+ self.info = info;
+ }
+
+ pub fn pkt_dropped(&mut self, reason: impl AsRef<str>) {
+ self.dropped_rx += 1;
+ qinfo!(
+ [self.info],
+ "Dropped received packet: {}; Total: {}",
+ reason.as_ref(),
+ self.dropped_rx
+ )
+ }
+
+ pub fn add_pto_count(&mut self, count: usize) {
+ debug_assert!(count > 0);
+ if count >= MAX_PTO_COUNTS {
+ // We can't move this count any further, so stop.
+ return;
+ }
+ self.pto_counts[count - 1] += 1;
+ if count > 1 {
+ debug_assert!(self.pto_counts[count - 2] > 0);
+ self.pto_counts[count - 2] -= 1;
+ }
+ }
+}
+
+impl Debug for Stats {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(f, "stats for {}", self.info)?;
+ writeln!(
+ f,
+ " rx: {} drop {} dup {} saved {}",
+ self.packets_rx, self.dropped_rx, self.dups_rx, self.saved_datagrams
+ )?;
+ writeln!(
+ f,
+ " tx: {} lost {} lateack {} ptoack {}",
+ self.packets_tx, self.lost, self.late_ack, self.pto_ack
+ )?;
+ writeln!(f, " resumed: {} ", self.resumed)?;
+ writeln!(f, " frames rx:")?;
+ self.frame_rx.fmt(f)?;
+ writeln!(f, " frames tx:")?;
+ self.frame_tx.fmt(f)
+ }
+}
+
+#[derive(Default, Clone)]
+#[allow(clippy::module_name_repetitions)]
+pub struct StatsCell {
+ stats: Rc<RefCell<Stats>>,
+}
+
+impl Deref for StatsCell {
+ type Target = RefCell<Stats>;
+ fn deref(&self) -> &Self::Target {
+ &*self.stats
+ }
+}
+
+impl Debug for StatsCell {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ self.stats.borrow().fmt(f)
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/stream_id.rs b/third_party/rust/neqo-transport/src/stream_id.rs
new file mode 100644
index 0000000000..486bfb937c
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/stream_id.rs
@@ -0,0 +1,205 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Stream ID and stream index handling.
+
+use std::ops::AddAssign;
+
+use neqo_common::Role;
+
+use crate::connection::{LOCAL_STREAM_LIMIT_BIDI, LOCAL_STREAM_LIMIT_UNI};
+use crate::frame::StreamType;
+
+pub struct StreamIndexes {
+ pub local_max_stream_uni: StreamIndex,
+ pub local_max_stream_bidi: StreamIndex,
+ pub local_next_stream_uni: StreamIndex,
+ pub local_next_stream_bidi: StreamIndex,
+ pub remote_max_stream_uni: StreamIndex,
+ pub remote_max_stream_bidi: StreamIndex,
+ pub remote_next_stream_uni: StreamIndex,
+ pub remote_next_stream_bidi: StreamIndex,
+}
+
+impl StreamIndexes {
+ pub fn new() -> Self {
+ Self {
+ local_max_stream_bidi: StreamIndex::new(LOCAL_STREAM_LIMIT_BIDI),
+ local_max_stream_uni: StreamIndex::new(LOCAL_STREAM_LIMIT_UNI),
+ local_next_stream_uni: StreamIndex::new(0),
+ local_next_stream_bidi: StreamIndex::new(0),
+ remote_max_stream_bidi: StreamIndex::new(0),
+ remote_max_stream_uni: StreamIndex::new(0),
+ remote_next_stream_uni: StreamIndex::new(0),
+ remote_next_stream_bidi: StreamIndex::new(0),
+ }
+ }
+}
+
+#[derive(Debug, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)]
+pub struct StreamId(u64);
+
+impl StreamId {
+ pub const fn new(id: u64) -> Self {
+ Self(id)
+ }
+
+ pub fn as_u64(self) -> u64 {
+ self.0
+ }
+
+ pub fn is_bidi(self) -> bool {
+ self.as_u64() & 0x02 == 0
+ }
+
+ pub fn is_uni(self) -> bool {
+ !self.is_bidi()
+ }
+
+ pub fn stream_type(self) -> StreamType {
+ if self.is_bidi() {
+ StreamType::BiDi
+ } else {
+ StreamType::UniDi
+ }
+ }
+
+ pub fn is_client_initiated(self) -> bool {
+ self.as_u64() & 0x01 == 0
+ }
+
+ pub fn is_server_initiated(self) -> bool {
+ !self.is_client_initiated()
+ }
+
+ pub fn role(self) -> Role {
+ if self.is_client_initiated() {
+ Role::Client
+ } else {
+ Role::Server
+ }
+ }
+
+ pub fn is_self_initiated(self, my_role: Role) -> bool {
+ match my_role {
+ Role::Client if self.is_client_initiated() => true,
+ Role::Server if self.is_server_initiated() => true,
+ _ => false,
+ }
+ }
+
+ pub fn is_remote_initiated(self, my_role: Role) -> bool {
+ !self.is_self_initiated(my_role)
+ }
+
+ pub fn is_send_only(self, my_role: Role) -> bool {
+ self.is_uni() && self.is_self_initiated(my_role)
+ }
+
+ pub fn is_recv_only(self, my_role: Role) -> bool {
+ self.is_uni() && self.is_remote_initiated(my_role)
+ }
+}
+
+impl From<u64> for StreamId {
+ fn from(val: u64) -> Self {
+ Self::new(val)
+ }
+}
+
+impl PartialEq<u64> for StreamId {
+ fn eq(&self, other: &u64) -> bool {
+ self.as_u64() == *other
+ }
+}
+
+impl ::std::fmt::Display for StreamId {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}", self.as_u64())
+ }
+}
+
+#[derive(Debug, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)]
+pub struct StreamIndex(u64);
+
+impl StreamIndex {
+ pub fn new(val: u64) -> Self {
+ Self(val)
+ }
+
+ pub fn to_stream_id(self, stream_type: StreamType, role: Role) -> StreamId {
+ let type_val = match stream_type {
+ StreamType::BiDi => 0,
+ StreamType::UniDi => 2,
+ };
+ let role_val = match role {
+ Role::Server => 1,
+ Role::Client => 0,
+ };
+
+ StreamId::from((self.0 << 2) + type_val + role_val)
+ }
+
+ pub fn as_u64(self) -> u64 {
+ self.0
+ }
+}
+
+impl From<StreamId> for StreamIndex {
+ fn from(val: StreamId) -> Self {
+ Self(val.as_u64() >> 2)
+ }
+}
+
+impl AddAssign<u64> for StreamIndex {
+ fn add_assign(&mut self, other: u64) {
+ *self = Self::new(self.as_u64() + other)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::{StreamIndex, StreamType};
+ use neqo_common::Role;
+
+ #[test]
+ fn bidi_stream_properties() {
+ let id1 = StreamIndex::new(4).to_stream_id(StreamType::BiDi, Role::Client);
+ assert_eq!(id1.is_bidi(), true);
+ assert_eq!(id1.is_uni(), false);
+ assert_eq!(id1.is_client_initiated(), true);
+ assert_eq!(id1.is_server_initiated(), false);
+ assert_eq!(id1.role(), Role::Client);
+ assert_eq!(id1.is_self_initiated(Role::Client), true);
+ assert_eq!(id1.is_self_initiated(Role::Server), false);
+ assert_eq!(id1.is_remote_initiated(Role::Client), false);
+ assert_eq!(id1.is_remote_initiated(Role::Server), true);
+ assert_eq!(id1.is_send_only(Role::Server), false);
+ assert_eq!(id1.is_send_only(Role::Client), false);
+ assert_eq!(id1.is_recv_only(Role::Server), false);
+ assert_eq!(id1.is_recv_only(Role::Client), false);
+ assert_eq!(id1.as_u64(), 16);
+ }
+
+ #[test]
+ fn uni_stream_properties() {
+ let id2 = StreamIndex::new(8).to_stream_id(StreamType::UniDi, Role::Server);
+ assert_eq!(id2.is_bidi(), false);
+ assert_eq!(id2.is_uni(), true);
+ assert_eq!(id2.is_client_initiated(), false);
+ assert_eq!(id2.is_server_initiated(), true);
+ assert_eq!(id2.role(), Role::Server);
+ assert_eq!(id2.is_self_initiated(Role::Client), false);
+ assert_eq!(id2.is_self_initiated(Role::Server), true);
+ assert_eq!(id2.is_remote_initiated(Role::Client), true);
+ assert_eq!(id2.is_remote_initiated(Role::Server), false);
+ assert_eq!(id2.is_send_only(Role::Server), true);
+ assert_eq!(id2.is_send_only(Role::Client), false);
+ assert_eq!(id2.is_recv_only(Role::Server), false);
+ assert_eq!(id2.is_recv_only(Role::Client), true);
+ assert_eq!(id2.as_u64(), 35);
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/tparams.rs b/third_party/rust/neqo-transport/src/tparams.rs
new file mode 100644
index 0000000000..f0cfbf2203
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/tparams.rs
@@ -0,0 +1,541 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Transport parameters. See -transport section 7.3.
+
+#![allow(dead_code)]
+use crate::{Error, Res};
+use neqo_common::{hex, qdebug, qinfo, qtrace, Decoder, Encoder};
+use neqo_crypto::constants::{TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS};
+use neqo_crypto::ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult};
+use neqo_crypto::{HandshakeMessage, ZeroRttCheckResult, ZeroRttChecker};
+use std::cell::RefCell;
+use std::collections::HashMap;
+use std::rc::Rc;
+
+struct PreferredAddress {
+ // TODO(ekr@rtfm.com): Implement.
+}
+
+pub type TransportParameterId = u64;
+macro_rules! tpids {
+ { $($n:ident = $v:expr),+ $(,)? } => {
+ $(pub const $n: TransportParameterId = $v as TransportParameterId;)+
+ };
+ }
+tpids! {
+ ORIGINAL_DESTINATION_CONNECTION_ID = 0x00,
+ IDLE_TIMEOUT = 0x01,
+ STATELESS_RESET_TOKEN = 0x02,
+ MAX_UDP_PAYLOAD_SIZE = 0x03,
+ INITIAL_MAX_DATA = 0x04,
+ INITIAL_MAX_STREAM_DATA_BIDI_LOCAL = 0x05,
+ INITIAL_MAX_STREAM_DATA_BIDI_REMOTE = 0x06,
+ INITIAL_MAX_STREAM_DATA_UNI = 0x07,
+ INITIAL_MAX_STREAMS_BIDI = 0x08,
+ INITIAL_MAX_STREAMS_UNI = 0x09,
+ ACK_DELAY_EXPONENT = 0x0a,
+ MAX_ACK_DELAY = 0x0b,
+ DISABLE_MIGRATION = 0x0c,
+ PREFERRED_ADDRESS = 0x0d,
+ ACTIVE_CONNECTION_ID_LIMIT = 0x0e,
+ INITIAL_SOURCE_CONNECTION_ID = 0x0f,
+ RETRY_SOURCE_CONNECTION_ID = 0x10,
+ GREASE_QUIC_BIT = 0x2ab2,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum TransportParameter {
+ Bytes(Vec<u8>),
+ Integer(u64),
+ Empty,
+}
+
+impl TransportParameter {
+ fn encode(&self, enc: &mut Encoder, tp: TransportParameterId) {
+ qdebug!("TP encoded; type 0x{:02x} val {:?}", tp, self);
+ enc.encode_varint(tp);
+ match self {
+ Self::Bytes(a) => {
+ enc.encode_vvec(a);
+ }
+ Self::Integer(a) => {
+ enc.encode_vvec_with(|enc_inner| {
+ enc_inner.encode_varint(*a);
+ });
+ }
+ Self::Empty => {
+ enc.encode_varint(0_u64);
+ }
+ };
+ }
+
+ fn decode(dec: &mut Decoder) -> Res<Option<(TransportParameterId, Self)>> {
+ let tp = match dec.decode_varint() {
+ Some(v) => v,
+ _ => return Err(Error::NoMoreData),
+ };
+ let content = match dec.decode_vvec() {
+ Some(v) => v,
+ _ => return Err(Error::NoMoreData),
+ };
+ qtrace!("TP {:x} length {:x}", tp, content.len());
+ let mut d = Decoder::from(content);
+ let value = match tp {
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()),
+ STATELESS_RESET_TOKEN => {
+ if d.remaining() != 16 {
+ return Err(Error::TransportParameterError);
+ }
+ Self::Bytes(d.decode_remainder().to_vec())
+ }
+ IDLE_TIMEOUT
+ | INITIAL_MAX_DATA
+ | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
+ | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
+ | INITIAL_MAX_STREAM_DATA_UNI
+ | MAX_ACK_DELAY => match d.decode_varint() {
+ Some(v) => Self::Integer(v),
+ None => return Err(Error::TransportParameterError),
+ },
+
+ INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI => match d.decode_varint() {
+ Some(v) if v <= (1 << 60) => Self::Integer(v),
+ _ => return Err(Error::StreamLimitError),
+ },
+
+ MAX_UDP_PAYLOAD_SIZE => match d.decode_varint() {
+ Some(v) if v >= 1200 => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+
+ ACK_DELAY_EXPONENT => match d.decode_varint() {
+ Some(v) if v <= 20 => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+ ACTIVE_CONNECTION_ID_LIMIT => match d.decode_varint() {
+ Some(v) if v >= 2 => Self::Integer(v),
+ _ => return Err(Error::TransportParameterError),
+ },
+
+ DISABLE_MIGRATION | GREASE_QUIC_BIT => Self::Empty,
+ // Skip.
+ _ => return Ok(None),
+ };
+ if d.remaining() > 0 {
+ return Err(Error::TooMuchData);
+ }
+ qdebug!("TP decoded; type 0x{:02x} val {:?}", tp, value);
+ Ok(Some((tp, value)))
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq)]
+pub struct TransportParameters {
+ params: HashMap<TransportParameterId, TransportParameter>,
+}
+
+impl TransportParameters {
+ /// Set a value.
+ pub fn set(&mut self, k: TransportParameterId, v: TransportParameter) {
+ self.params.insert(k, v);
+ }
+
+ /// Clear a key.
+ pub fn remove(&mut self, k: TransportParameterId) {
+ self.params.remove(&k);
+ }
+
+ /// Decode is a static function that parses transport parameters
+ /// using the provided decoder.
+ pub(crate) fn decode(d: &mut Decoder) -> Res<Self> {
+ let mut tps = Self::default();
+ qtrace!("Parsed fixed TP header");
+
+ while d.remaining() > 0 {
+ match TransportParameter::decode(d) {
+ Ok(Some((tipe, tp))) => {
+ tps.set(tipe, tp);
+ }
+ Ok(None) => {}
+ Err(e) => return Err(e),
+ }
+ }
+ Ok(tps)
+ }
+
+ pub(crate) fn encode(&self, enc: &mut Encoder) {
+ for (tipe, tp) in &self.params {
+ tp.encode(enc, *tipe);
+ }
+ }
+
+ // Get an integer type or a default.
+ pub fn get_integer(&self, tp: TransportParameterId) -> u64 {
+ let default = match tp {
+ IDLE_TIMEOUT
+ | INITIAL_MAX_DATA
+ | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
+ | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
+ | INITIAL_MAX_STREAM_DATA_UNI
+ | INITIAL_MAX_STREAMS_BIDI
+ | INITIAL_MAX_STREAMS_UNI => 0,
+ MAX_UDP_PAYLOAD_SIZE => 65527,
+ ACK_DELAY_EXPONENT => 3,
+ MAX_ACK_DELAY => 25,
+ ACTIVE_CONNECTION_ID_LIMIT => 2,
+ _ => panic!("Transport parameter not known or not an Integer"),
+ };
+ match self.params.get(&tp) {
+ None => default,
+ Some(TransportParameter::Integer(x)) => *x,
+ _ => panic!("Internal error"),
+ }
+ }
+
+ // Set an integer type or a default.
+ pub fn set_integer(&mut self, tp: TransportParameterId, value: u64) {
+ match tp {
+ IDLE_TIMEOUT
+ | INITIAL_MAX_DATA
+ | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
+ | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
+ | INITIAL_MAX_STREAM_DATA_UNI
+ | INITIAL_MAX_STREAMS_BIDI
+ | INITIAL_MAX_STREAMS_UNI
+ | MAX_UDP_PAYLOAD_SIZE
+ | ACK_DELAY_EXPONENT
+ | MAX_ACK_DELAY
+ | ACTIVE_CONNECTION_ID_LIMIT => {
+ self.set(tp, TransportParameter::Integer(value));
+ }
+ _ => panic!("Transport parameter not known"),
+ }
+ }
+
+ pub fn get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]> {
+ match tp {
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID
+ | STATELESS_RESET_TOKEN => {}
+ _ => panic!("Transport parameter not known or not type bytes"),
+ }
+
+ match self.params.get(&tp) {
+ None => None,
+ Some(TransportParameter::Bytes(x)) => Some(&x),
+ _ => panic!("Internal error"),
+ }
+ }
+
+ pub fn set_bytes(&mut self, tp: TransportParameterId, value: Vec<u8>) {
+ match tp {
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID
+ | STATELESS_RESET_TOKEN => {
+ self.set(tp, TransportParameter::Bytes(value));
+ }
+ _ => panic!("Transport parameter not known or not type bytes"),
+ }
+ }
+
+ pub fn set_empty(&mut self, tp: TransportParameterId) {
+ match tp {
+ DISABLE_MIGRATION | GREASE_QUIC_BIT => {
+ self.set(tp, TransportParameter::Empty);
+ }
+ _ => panic!("Transport parameter not known or not type empty"),
+ }
+ }
+
+ pub fn get_empty(&self, tipe: TransportParameterId) -> bool {
+ match self.params.get(&tipe) {
+ None => false,
+ Some(TransportParameter::Empty) => true,
+ _ => panic!("Internal error"),
+ }
+ }
+
+ /// Return true if the remembered transport parameters are OK for 0-RTT.
+ /// Generally this means that any value that is currently in effect is greater than
+ /// or equal to the promised value.
+ pub(crate) fn ok_for_0rtt(&self, remembered: &Self) -> bool {
+ for (k, v_rem) in &remembered.params {
+ // Skip checks for these, which don't affect 0-RTT.
+ if matches!(
+ *k,
+ ORIGINAL_DESTINATION_CONNECTION_ID
+ | INITIAL_SOURCE_CONNECTION_ID
+ | RETRY_SOURCE_CONNECTION_ID
+ | STATELESS_RESET_TOKEN
+ | IDLE_TIMEOUT
+ | ACK_DELAY_EXPONENT
+ | MAX_ACK_DELAY
+ | ACTIVE_CONNECTION_ID_LIMIT
+ ) {
+ continue;
+ }
+ match self.params.get(k) {
+ Some(v_self) => match (v_self, v_rem) {
+ (TransportParameter::Integer(i_self), TransportParameter::Integer(i_rem)) => {
+ if *i_self < *i_rem {
+ return false;
+ }
+ }
+ (TransportParameter::Empty, TransportParameter::Empty) => {}
+ _ => return false,
+ },
+ _ => return false,
+ }
+ }
+ true
+ }
+
+ fn was_sent(&self, tp: TransportParameterId) -> bool {
+ self.params.contains_key(&tp)
+ }
+}
+
+#[derive(Default, Debug)]
+pub struct TransportParametersHandler {
+ pub(crate) local: TransportParameters,
+ pub(crate) remote: Option<TransportParameters>,
+ pub(crate) remote_0rtt: Option<TransportParameters>,
+}
+
+impl TransportParametersHandler {
+ pub fn remote(&self) -> &TransportParameters {
+ match (self.remote.as_ref(), self.remote_0rtt.as_ref()) {
+ (Some(tp), _) | (_, Some(tp)) => tp,
+ _ => panic!("no transport parameters from peer"),
+ }
+ }
+}
+
+impl ExtensionHandler for TransportParametersHandler {
+ fn write(&mut self, msg: HandshakeMessage, d: &mut [u8]) -> ExtensionWriterResult {
+ if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
+ return ExtensionWriterResult::Skip;
+ }
+
+ qdebug!("Writing transport parameters, msg={:?}", msg);
+
+ // TODO(ekr@rtfm.com): Modify to avoid a copy.
+ let mut enc = Encoder::default();
+ self.local.encode(&mut enc);
+ assert!(enc.len() <= d.len());
+ d[..enc.len()].copy_from_slice(&enc);
+ ExtensionWriterResult::Write(enc.len())
+ }
+
+ fn handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult {
+ qtrace!(
+ "Handling transport parameters, msg={:?} value={}",
+ msg,
+ hex(d),
+ );
+
+ if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
+ return ExtensionHandlerResult::Alert(110); // unsupported_extension
+ }
+
+ let mut dec = Decoder::from(d);
+ match TransportParameters::decode(&mut dec) {
+ Ok(tp) => {
+ self.remote = Some(tp);
+ ExtensionHandlerResult::Ok
+ }
+ _ => ExtensionHandlerResult::Alert(47), // illegal_parameter
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct TpZeroRttChecker<T> {
+ handler: Rc<RefCell<TransportParametersHandler>>,
+ app_checker: T,
+}
+
+impl<T> TpZeroRttChecker<T>
+where
+ T: ZeroRttChecker + 'static,
+{
+ pub fn wrap(
+ handler: Rc<RefCell<TransportParametersHandler>>,
+ app_checker: T,
+ ) -> Box<dyn ZeroRttChecker> {
+ Box::new(Self {
+ handler,
+ app_checker,
+ })
+ }
+}
+
+impl<T> ZeroRttChecker for TpZeroRttChecker<T>
+where
+ T: ZeroRttChecker,
+{
+ fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
+ // Reject 0-RTT if there is no token.
+ if token.is_empty() {
+ qdebug!("0-RTT: no token, no 0-RTT");
+ return ZeroRttCheckResult::Reject;
+ }
+ let mut dec = Decoder::from(token);
+ let tpslice = if let Some(v) = dec.decode_vvec() {
+ v
+ } else {
+ qinfo!("0-RTT: token code error");
+ return ZeroRttCheckResult::Fail;
+ };
+ let mut dec_tp = Decoder::from(tpslice);
+ let remembered = if let Ok(v) = TransportParameters::decode(&mut dec_tp) {
+ v
+ } else {
+ qinfo!("0-RTT: transport parameter decode error");
+ return ZeroRttCheckResult::Fail;
+ };
+ if self.handler.borrow().local.ok_for_0rtt(&remembered) {
+ qinfo!("0-RTT: transport parameters OK, passing to application checker");
+ self.app_checker.check(dec.decode_remainder())
+ } else {
+ qinfo!("0-RTT: transport parameters bad, rejecting");
+ ZeroRttCheckResult::Reject
+ }
+ }
+}
+
+// TODO(ekr@rtfm.com): Need to write more TP unit tests.
+#[cfg(test)]
+#[allow(unused_variables)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn basic_tps() {
+ const RESET_TOKEN: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8];
+ let mut tps = TransportParameters::default();
+ tps.set(
+ STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(RESET_TOKEN.to_vec()),
+ );
+ tps.params
+ .insert(INITIAL_MAX_STREAMS_BIDI, TransportParameter::Integer(10));
+
+ let mut enc = Encoder::default();
+ tps.encode(&mut enc);
+
+ let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
+ assert_eq!(tps, tps2);
+
+ println!("TPS = {:?}", tps);
+ assert_eq!(tps2.get_integer(IDLE_TIMEOUT), 0); // Default
+ assert_eq!(tps2.get_integer(MAX_ACK_DELAY), 25); // Default
+ assert_eq!(tps2.get_integer(ACTIVE_CONNECTION_ID_LIMIT), 2); // Default
+ assert_eq!(tps2.get_integer(INITIAL_MAX_STREAMS_BIDI), 10); // Sent
+ assert_eq!(tps2.get_bytes(STATELESS_RESET_TOKEN), Some(RESET_TOKEN));
+ assert_eq!(tps2.get_bytes(ORIGINAL_DESTINATION_CONNECTION_ID), None);
+ assert_eq!(tps2.get_bytes(INITIAL_SOURCE_CONNECTION_ID), None);
+ assert_eq!(tps2.get_bytes(RETRY_SOURCE_CONNECTION_ID), None);
+ assert_eq!(tps2.was_sent(ORIGINAL_DESTINATION_CONNECTION_ID), false);
+ assert_eq!(tps2.was_sent(INITIAL_SOURCE_CONNECTION_ID), false);
+ assert_eq!(tps2.was_sent(RETRY_SOURCE_CONNECTION_ID), false);
+ assert_eq!(tps2.was_sent(STATELESS_RESET_TOKEN), true);
+
+ let mut enc = Encoder::default();
+ tps.encode(&mut enc);
+
+ let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
+ }
+
+ #[test]
+ fn compatible_0rtt_ignored_values() {
+ let mut tps_a = TransportParameters::default();
+ tps_a.set(
+ STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(vec![1, 2, 3]),
+ );
+ tps_a.set(IDLE_TIMEOUT, TransportParameter::Integer(10));
+ tps_a.set(MAX_ACK_DELAY, TransportParameter::Integer(22));
+ tps_a.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(33));
+
+ let mut tps_b = TransportParameters::default();
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+
+ tps_b.set(
+ STATELESS_RESET_TOKEN,
+ TransportParameter::Bytes(vec![8, 9, 10]),
+ );
+ tps_b.set(IDLE_TIMEOUT, TransportParameter::Integer(100));
+ tps_b.set(MAX_ACK_DELAY, TransportParameter::Integer(2));
+ tps_b.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(44));
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+ }
+
+ #[test]
+ fn compatible_0rtt_integers() {
+ let mut tps_a = TransportParameters::default();
+ const INTEGER_KEYS: &[TransportParameterId] = &[
+ INITIAL_MAX_DATA,
+ INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
+ INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
+ INITIAL_MAX_STREAM_DATA_UNI,
+ INITIAL_MAX_STREAMS_BIDI,
+ INITIAL_MAX_STREAMS_UNI,
+ MAX_UDP_PAYLOAD_SIZE,
+ ];
+ for i in INTEGER_KEYS {
+ tps_a.set(*i, TransportParameter::Integer(12));
+ }
+
+ let tps_b = tps_a.clone();
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+
+ // For each integer key, increase the value by one.
+ for i in INTEGER_KEYS {
+ let mut tps_b = tps_a.clone();
+ tps_b.set(*i, TransportParameter::Integer(13));
+ // If an increased value is remembered, then we can't attempt 0-RTT with these parameters.
+ assert!(!tps_a.ok_for_0rtt(&tps_b));
+ // If an increased value is lower, then we can attempt 0-RTT with these parameters.
+ assert!(tps_b.ok_for_0rtt(&tps_a));
+ }
+
+ // Drop integer values and check.
+ for i in INTEGER_KEYS {
+ let mut tps_b = tps_a.clone();
+ tps_b.remove(*i);
+ // A value that is missing from what is rememebered is OK.
+ assert!(tps_a.ok_for_0rtt(&tps_b));
+ // A value that is rememebered, but not current is not OK.
+ assert!(!tps_b.ok_for_0rtt(&tps_a));
+ }
+ }
+
+ #[test]
+ fn active_connection_id_limit_lt_2_is_error() {
+ let mut tps = TransportParameters::default();
+
+ // Intentionally set an invalid value for the ACTIVE_CONNECTION_ID_LIMIT transport parameter.
+ tps.params
+ .insert(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(1));
+
+ let mut enc = Encoder::default();
+ tps.encode(&mut enc);
+
+ // When decoding a set of transport parameters with an invalid ACTIVE_CONNECTION_ID_LIMIT
+ // the result should be an error.
+ let invalid_decode_result = TransportParameters::decode(&mut enc.as_decoder());
+ assert!(invalid_decode_result.is_err());
+ }
+}
diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs
new file mode 100644
index 0000000000..efd3b06069
--- /dev/null
+++ b/third_party/rust/neqo-transport/src/tracking.rs
@@ -0,0 +1,992 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// Tracking of received packets and generating acks thereof.
+
+#![deny(clippy::pedantic)]
+
+use std::cmp::min;
+use std::collections::VecDeque;
+use std::convert::TryFrom;
+use std::ops::{Index, IndexMut};
+use std::time::{Duration, Instant};
+
+use neqo_common::{qdebug, qinfo, qtrace, qwarn};
+use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL};
+
+use crate::packet::{PacketBuilder, PacketNumber, PacketType};
+use crate::recovery::RecoveryToken;
+use crate::stats::FrameStats;
+
+use smallvec::{smallvec, SmallVec};
+
+// TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)]
+pub enum PNSpace {
+ Initial,
+ Handshake,
+ ApplicationData,
+}
+
+#[allow(clippy::use_self)] // https://github.com/rust-lang/rust-clippy/issues/3410
+impl PNSpace {
+ pub fn iter() -> impl Iterator<Item = &'static PNSpace> {
+ const SPACES: &[PNSpace] = &[
+ PNSpace::Initial,
+ PNSpace::Handshake,
+ PNSpace::ApplicationData,
+ ];
+ SPACES.iter()
+ }
+}
+
+impl From<Epoch> for PNSpace {
+ fn from(epoch: Epoch) -> Self {
+ match epoch {
+ TLS_EPOCH_INITIAL => Self::Initial,
+ TLS_EPOCH_HANDSHAKE => Self::Handshake,
+ _ => Self::ApplicationData,
+ }
+ }
+}
+
+impl From<PacketType> for PNSpace {
+ fn from(pt: PacketType) -> Self {
+ match pt {
+ PacketType::Initial => Self::Initial,
+ PacketType::Handshake => Self::Handshake,
+ PacketType::ZeroRtt | PacketType::Short => Self::ApplicationData,
+ _ => panic!("Attempted to get space from wrong packet type"),
+ }
+ }
+}
+
+#[derive(Clone, Copy, Default)]
+pub struct PNSpaceSet {
+ initial: bool,
+ handshake: bool,
+ application_data: bool,
+}
+
+impl PNSpaceSet {
+ pub fn all() -> Self {
+ Self {
+ initial: true,
+ handshake: true,
+ application_data: true,
+ }
+ }
+}
+
+impl Index<PNSpace> for PNSpaceSet {
+ type Output = bool;
+
+ fn index(&self, space: PNSpace) -> &Self::Output {
+ match space {
+ PNSpace::Initial => &self.initial,
+ PNSpace::Handshake => &self.handshake,
+ PNSpace::ApplicationData => &self.application_data,
+ }
+ }
+}
+
+impl IndexMut<PNSpace> for PNSpaceSet {
+ fn index_mut(&mut self, space: PNSpace) -> &mut Self::Output {
+ match space {
+ PNSpace::Initial => &mut self.initial,
+ PNSpace::Handshake => &mut self.handshake,
+ PNSpace::ApplicationData => &mut self.application_data,
+ }
+ }
+}
+
+impl<T: AsRef<[PNSpace]>> From<T> for PNSpaceSet {
+ fn from(spaces: T) -> Self {
+ let mut v = Self::default();
+ for sp in spaces.as_ref() {
+ v[*sp] = true;
+ }
+ v
+ }
+}
+
+impl std::fmt::Debug for PNSpaceSet {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ let mut first = true;
+ f.write_str("(")?;
+ for sp in PNSpace::iter() {
+ if self[*sp] {
+ if !first {
+ f.write_str("+")?;
+ first = false;
+ }
+ std::fmt::Display::fmt(sp, f)?;
+ }
+ }
+ f.write_str(")")
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SentPacket {
+ pub pt: PacketType,
+ pub pn: PacketNumber,
+ ack_eliciting: bool,
+ pub time_sent: Instant,
+ pub tokens: Vec<RecoveryToken>,
+
+ time_declared_lost: Option<Instant>,
+ /// After a PTO, this is true when the packet has been released.
+ pto: bool,
+
+ pub size: usize,
+}
+
+impl SentPacket {
+ pub fn new(
+ pt: PacketType,
+ pn: PacketNumber,
+ time_sent: Instant,
+ ack_eliciting: bool,
+ tokens: Vec<RecoveryToken>,
+ size: usize,
+ ) -> Self {
+ Self {
+ pt,
+ pn,
+ time_sent,
+ ack_eliciting,
+ tokens,
+ time_declared_lost: None,
+ pto: false,
+ size,
+ }
+ }
+
+ /// Returns `true` if the packet will elicit an ACK.
+ pub fn ack_eliciting(&self) -> bool {
+ self.ack_eliciting
+ }
+
+ /// Whether the packet has been declared lost.
+ pub fn lost(&self) -> bool {
+ self.time_declared_lost.is_some()
+ }
+
+ /// Whether accounting for the loss or acknowledgement in the
+ /// congestion controller is pending.
+ /// Returns `true` if the packet counts as being "in flight",
+ /// and has not previously been declared lost.
+ /// Note that this should count packets that contain only ACK and PADDING,
+ /// but we don't send PADDING, so we don't track that.
+ pub fn cc_outstanding(&self) -> bool {
+ self.ack_eliciting() && !self.lost()
+ }
+
+ /// Declare the packet as lost. Returns `true` if this is the first time.
+ pub fn declare_lost(&mut self, now: Instant) -> bool {
+ if self.lost() {
+ false
+ } else {
+ self.time_declared_lost = Some(now);
+ true
+ }
+ }
+
+ /// Ask whether this tracked packet has been declared lost for long enough
+ /// that it can be expired and no longer tracked.
+ pub fn expired(&self, now: Instant, expiration_period: Duration) -> bool {
+ self.time_declared_lost
+ .map_or(false, |loss_time| (loss_time + expiration_period) <= now)
+ }
+
+ /// Whether the packet contents were cleared out after a PTO.
+ pub fn pto_fired(&self) -> bool {
+ self.pto
+ }
+
+ /// On PTO, we need to get the recovery tokens so that we can ensure that
+ /// the frames we sent can be sent again in the PTO packet(s). Do that just once.
+ pub fn pto(&mut self) -> bool {
+ if self.pto || self.lost() {
+ false
+ } else {
+ self.pto = true;
+ true
+ }
+ }
+}
+
+impl std::fmt::Display for PNSpace {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.write_str(match self {
+ Self::Initial => "in",
+ Self::Handshake => "hs",
+ Self::ApplicationData => "ap",
+ })
+ }
+}
+
+/// `InsertionResult` tracks whether something was inserted for `PacketRange::add()`.
+pub enum InsertionResult {
+ Largest,
+ Smallest,
+ NotInserted,
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct PacketRange {
+ largest: PacketNumber,
+ smallest: PacketNumber,
+ ack_needed: bool,
+}
+
+impl PacketRange {
+ /// Make a single packet range.
+ pub fn new(pn: PacketNumber) -> Self {
+ Self {
+ largest: pn,
+ smallest: pn,
+ ack_needed: true,
+ }
+ }
+
+ /// Get the number of acknowleged packets in the range.
+ pub fn len(&self) -> u64 {
+ self.largest - self.smallest + 1
+ }
+
+ /// Returns whether this needs to be sent.
+ pub fn ack_needed(&self) -> bool {
+ self.ack_needed
+ }
+
+ /// Return whether the given number is in the range.
+ pub fn contains(&self, pn: PacketNumber) -> bool {
+ (pn >= self.smallest) && (pn <= self.largest)
+ }
+
+ /// Maybe add a packet number to the range. Returns true if it was added
+ /// at the small end (which indicates that this might need merging with a
+ /// preceding range).
+ pub fn add(&mut self, pn: PacketNumber) -> InsertionResult {
+ assert!(!self.contains(pn));
+ // Only insert if this is adjacent the current range.
+ if (self.largest + 1) == pn {
+ qtrace!([self], "Adding largest {}", pn);
+ self.largest += 1;
+ self.ack_needed = true;
+ InsertionResult::Largest
+ } else if self.smallest == (pn + 1) {
+ qtrace!([self], "Adding smallest {}", pn);
+ self.smallest -= 1;
+ self.ack_needed = true;
+ InsertionResult::Smallest
+ } else {
+ InsertionResult::NotInserted
+ }
+ }
+
+ /// Maybe merge a higher-numbered range into this.
+ fn merge_larger(&mut self, other: &Self) {
+ qinfo!([self], "Merging {}", other);
+ // This only works if they are immediately adjacent.
+ assert_eq!(self.largest + 1, other.smallest);
+
+ self.largest = other.largest;
+ self.ack_needed = self.ack_needed || other.ack_needed;
+ }
+
+ /// When a packet containing the range `other` is acknowledged,
+ /// clear the `ack_needed` attribute on this.
+ /// Requires that other is equal to this, or a larger range.
+ pub fn acknowledged(&mut self, other: &Self) {
+ if (other.smallest <= self.smallest) && (other.largest >= self.largest) {
+ self.ack_needed = false;
+ }
+ }
+}
+
+impl ::std::fmt::Display for PacketRange {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "{}->{}", self.largest, self.smallest)
+ }
+}
+
+/// The ACK delay we use.
+pub const ACK_DELAY: Duration = Duration::from_millis(20); // 20ms
+pub const MAX_UNACKED_PKTS: usize = 1;
+const MAX_TRACKED_RANGES: usize = 32;
+const MAX_ACKS_PER_FRAME: usize = 32;
+
+/// A structure that tracks what was included in an ACK.
+#[derive(Debug, Clone)]
+pub struct AckToken {
+ space: PNSpace,
+ ranges: Vec<PacketRange>,
+}
+
+/// A structure that tracks what packets have been received,
+/// and what needs acknowledgement for a packet number space.
+#[derive(Debug)]
+pub struct RecvdPackets {
+ space: PNSpace,
+ ranges: VecDeque<PacketRange>,
+ /// The packet number of the lowest number packet that we are tracking.
+ min_tracked: PacketNumber,
+ /// The time we got the largest acknowledged.
+ largest_pn_time: Option<Instant>,
+ // The time that we should be sending an ACK.
+ ack_time: Option<Instant>,
+ pkts_since_last_ack: usize,
+}
+
+impl RecvdPackets {
+ /// Make a new `RecvdPackets` for the indicated packet number space.
+ pub fn new(space: PNSpace) -> Self {
+ Self {
+ space,
+ ranges: VecDeque::new(),
+ min_tracked: 0,
+ largest_pn_time: None,
+ ack_time: None,
+ pkts_since_last_ack: 0,
+ }
+ }
+
+ /// Get the time at which the next ACK should be sent.
+ pub fn ack_time(&self) -> Option<Instant> {
+ self.ack_time
+ }
+
+ /// Returns true if an ACK frame should be sent now.
+ fn ack_now(&self, now: Instant) -> bool {
+ match self.ack_time {
+ Some(t) => t <= now,
+ None => false,
+ }
+ }
+
+ // A simple addition of a packet number to the tracked set.
+ // This doesn't do a binary search on the assumption that
+ // new packets will generally be added to the start of the list.
+ fn add(&mut self, pn: PacketNumber) {
+ for i in 0..self.ranges.len() {
+ match self.ranges[i].add(pn) {
+ InsertionResult::Largest => return,
+ InsertionResult::Smallest => {
+ // If this was the smallest, it might have filled a gap.
+ let nxt = i + 1;
+ if (nxt < self.ranges.len()) && (pn - 1 == self.ranges[nxt].largest) {
+ let larger = self.ranges.remove(i).unwrap();
+ self.ranges[i].merge_larger(&larger);
+ }
+ return;
+ }
+ InsertionResult::NotInserted => {
+ if self.ranges[i].largest < pn {
+ self.ranges.insert(i, PacketRange::new(pn));
+ return;
+ }
+ }
+ }
+ }
+ self.ranges.push_back(PacketRange::new(pn));
+ }
+
+ fn trim_ranges(&mut self) {
+ // Limit the number of ranges that are tracked to MAX_TRACKED_RANGES.
+ if self.ranges.len() > MAX_TRACKED_RANGES {
+ let oldest = self.ranges.pop_back().unwrap();
+ if oldest.ack_needed {
+ qwarn!([self], "Dropping unacknowledged ACK range: {}", oldest);
+ // TODO(mt) Record some statistics about this so we can tune MAX_TRACKED_RANGES.
+ } else {
+ qdebug!([self], "Drop ACK range: {}", oldest);
+ }
+ self.min_tracked = oldest.largest + 1;
+ }
+ }
+
+ /// Add the packet to the tracked set.
+ pub fn set_received(&mut self, now: Instant, pn: PacketNumber, ack_eliciting: bool) {
+ let next_in_order_pn = self.ranges.front().map_or(0, |pr| pr.largest + 1);
+ qdebug!(
+ [self],
+ "received {}, next in order pn: {}",
+ pn,
+ next_in_order_pn
+ );
+
+ self.add(pn);
+ self.trim_ranges();
+
+ // The new addition was the largest, so update the time we use for calculating ACK delay.
+ if pn >= next_in_order_pn {
+ self.largest_pn_time = Some(now);
+ }
+
+ if ack_eliciting {
+ self.pkts_since_last_ack += 1;
+
+ // Send ACK right away if out-of-order
+ // On the first in-order ack-eliciting packet since sending an ACK,
+ // set a delay.
+ // Count packets until we exceed MAX_UNACKED_PKTS, then remove the
+ // delay.
+ if pn != next_in_order_pn {
+ self.ack_time = Some(now);
+ } else if self.space == PNSpace::ApplicationData {
+ match &mut self.pkts_since_last_ack {
+ 0 => unreachable!(),
+ 1 => self.ack_time = Some(now + ACK_DELAY),
+ x if *x > MAX_UNACKED_PKTS => self.ack_time = Some(now),
+ _ => debug_assert!(self.ack_time.is_some()),
+ }
+ } else {
+ self.ack_time = Some(now);
+ }
+ qdebug!([self], "Set ACK timer to {:?}", self.ack_time);
+ }
+ }
+
+ /// Check if the packet is a duplicate.
+ pub fn is_duplicate(&self, pn: PacketNumber) -> bool {
+ if pn < self.min_tracked {
+ return true;
+ }
+ // TODO(mt) consider a binary search or early exit.
+ for range in &self.ranges {
+ if range.contains(pn) {
+ return true;
+ }
+ }
+ false
+ }
+
+ /// Mark the given range as having been acknowledged.
+ pub fn acknowledged(&mut self, acked: &[PacketRange]) {
+ let mut range_iter = self.ranges.iter_mut();
+ let mut cur = range_iter.next().expect("should have at least one range");
+ for ack in acked {
+ while cur.smallest > ack.largest {
+ cur = match range_iter.next() {
+ Some(c) => c,
+ None => return,
+ };
+ }
+ cur.acknowledged(&ack);
+ }
+ }
+
+ /// Generate an ACK frame for this packet number space.
+ ///
+ /// Unlike other frame generators this doesn't modify the underlying instance
+ /// to track what has been sent. This only clears the delayed ACK timer.
+ ///
+ /// When sending ACKs, we want to always send the most recent ranges,
+ /// even if they have been sent in other packets.
+ ///
+ /// We don't send ranges that have been acknowledged, but they still need
+ /// to be tracked so that duplicates can be detected.
+ fn write_frame(
+ &mut self,
+ now: Instant,
+ builder: &mut PacketBuilder,
+ stats: &mut FrameStats,
+ ) -> Option<RecoveryToken> {
+ // The worst possible ACK frame, assuming only one range.
+ // Note that this assumes one byte for the type and count of extra ranges.
+ const LONGEST_ACK_HEADER: usize = 1 + 8 + 8 + 1 + 8;
+
+ // Check that we aren't delaying ACKs.
+ if !self.ack_now(now) {
+ return None;
+ }
+
+ // Drop extra ACK ranges to fit the available space. Do this based on
+ // a worst-case estimate of frame size for simplicity.
+ //
+ // When congestion limited, ACK-only packets are 255 bytes at most
+ // (`recovery::ACK_ONLY_SIZE_LIMIT - 1`). This results in limiting the
+ // ranges to 13 here.
+ let max_ranges = if let Some(avail) = builder.remaining().checked_sub(LONGEST_ACK_HEADER) {
+ // Apply a hard maximum to keep plenty of space for other stuff.
+ min(1 + (avail / 16), MAX_ACKS_PER_FRAME)
+ } else {
+ return None;
+ };
+
+ let ranges = self
+ .ranges
+ .iter()
+ .filter(|r| r.ack_needed())
+ .take(max_ranges)
+ .cloned()
+ .collect::<Vec<_>>();
+
+ builder.encode_varint(crate::frame::FRAME_TYPE_ACK);
+ let mut iter = ranges.iter();
+ let first = match iter.next() {
+ Some(v) => v,
+ None => return None, // Nothing to send.
+ };
+ builder.encode_varint(first.largest);
+ stats.largest_acknowledged = first.largest;
+ stats.ack += 1;
+
+ let elapsed = now.duration_since(self.largest_pn_time.unwrap());
+ // We use the default exponent, so delay is in multiples of 8 microseconds.
+ let ack_delay = u64::try_from(elapsed.as_micros() / 8).unwrap_or(u64::MAX);
+ let ack_delay = min((1 << 62) - 1, ack_delay);
+ builder.encode_varint(ack_delay);
+ builder.encode_varint(u64::try_from(ranges.len() - 1).unwrap()); // extra ranges
+ builder.encode_varint(first.len() - 1); // first range
+
+ let mut last = first.smallest;
+ for r in iter {
+ // the difference must be at least 2 because 0-length gaps,
+ // (difference 1) are illegal.
+ builder.encode_varint(last - r.largest - 2); // Gap
+ builder.encode_varint(r.len() - 1); // Range
+ last = r.smallest;
+ }
+
+ // We've sent an ACK, reset the timer.
+ self.ack_time = None;
+ self.pkts_since_last_ack = 0;
+
+ Some(RecoveryToken::Ack(AckToken {
+ space: self.space,
+ ranges,
+ }))
+ }
+}
+
+impl ::std::fmt::Display for RecvdPackets {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Recvd-{}", self.space)
+ }
+}
+
+#[derive(Debug)]
+pub struct AckTracker {
+ /// This stores information about received packets in *reverse* order
+ /// by spaces. Why reverse? Because we ultimately only want to keep
+ /// `ApplicationData` and this allows us to drop other spaces easily.
+ spaces: SmallVec<[RecvdPackets; 1]>,
+}
+
+impl AckTracker {
+ pub fn drop_space(&mut self, space: PNSpace) {
+ let sp = match space {
+ PNSpace::Initial => self.spaces.pop(),
+ PNSpace::Handshake => {
+ let sp = self.spaces.pop();
+ self.spaces.shrink_to_fit();
+ sp
+ }
+ PNSpace::ApplicationData => panic!("discarding application space"),
+ };
+ assert_eq!(sp.unwrap().space, space, "dropping spaces out of order");
+ }
+
+ pub fn get_mut(&mut self, space: PNSpace) -> Option<&mut RecvdPackets> {
+ self.spaces.get_mut(match space {
+ PNSpace::ApplicationData => 0,
+ PNSpace::Handshake => 1,
+ PNSpace::Initial => 2,
+ })
+ }
+
+ /// Determine the earliest time that an ACK might be needed.
+ pub fn ack_time(&self, now: Instant) -> Option<Instant> {
+ if self.spaces.len() == 1 {
+ self.spaces[0].ack_time()
+ } else {
+ // Ignore any time that is in the past relative to `now`.
+ // That is something of a hack, but there are cases where we can't send ACK
+ // frames for all spaces, which can mean that one space is stuck in the past.
+ // That isn't a problem because we guarantee that earlier spaces will always
+ // be able to send ACK frames.
+ self.spaces
+ .iter()
+ .filter_map(|recvd| recvd.ack_time().filter(|t| *t > now))
+ .min()
+ }
+ }
+
+ pub fn acked(&mut self, token: &AckToken) {
+ if let Some(space) = self.get_mut(token.space) {
+ space.acknowledged(&token.ranges);
+ }
+ }
+
+ pub(crate) fn write_frame(
+ &mut self,
+ pn_space: PNSpace,
+ now: Instant,
+ builder: &mut PacketBuilder,
+ stats: &mut FrameStats,
+ ) -> Option<RecoveryToken> {
+ self.get_mut(pn_space)
+ .and_then(|space| space.write_frame(now, builder, stats))
+ }
+}
+
+impl Default for AckTracker {
+ fn default() -> Self {
+ Self {
+ spaces: smallvec![
+ RecvdPackets::new(PNSpace::ApplicationData),
+ RecvdPackets::new(PNSpace::Handshake),
+ RecvdPackets::new(PNSpace::Initial),
+ ],
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{
+ AckTracker, Duration, Instant, PNSpace, PNSpaceSet, RecoveryToken, RecvdPackets, ACK_DELAY,
+ MAX_TRACKED_RANGES, MAX_UNACKED_PKTS,
+ };
+ use crate::frame::Frame;
+ use crate::packet::PacketBuilder;
+ use crate::stats::FrameStats;
+ use lazy_static::lazy_static;
+ use neqo_common::Encoder;
+ use std::collections::HashSet;
+ use std::convert::TryFrom;
+
+ lazy_static! {
+ static ref NOW: Instant = Instant::now();
+ }
+
+ fn test_ack_range(pns: &[u64], nranges: usize) {
+ let mut rp = RecvdPackets::new(PNSpace::Initial); // Any space will do.
+ let mut packets = HashSet::new();
+
+ for pn in pns {
+ rp.set_received(*NOW, *pn, true);
+ packets.insert(*pn);
+ }
+
+ assert_eq!(rp.ranges.len(), nranges);
+
+ // Check that all these packets will be detected as duplicates.
+ for pn in pns {
+ assert!(rp.is_duplicate(*pn));
+ }
+
+ // Check that the ranges decrease monotonically and don't overlap.
+ let mut iter = rp.ranges.iter();
+ let mut last = iter.next().expect("should have at least one");
+ for n in iter {
+ assert!(n.largest + 1 < last.smallest);
+ last = n;
+ }
+
+ // Check that the ranges include the right values.
+ let mut in_ranges = HashSet::new();
+ for range in &rp.ranges {
+ for included in range.smallest..=range.largest {
+ in_ranges.insert(included);
+ }
+ }
+ assert_eq!(packets, in_ranges);
+ }
+
+ #[test]
+ fn pn0() {
+ test_ack_range(&[0], 1);
+ }
+
+ #[test]
+ fn pn1() {
+ test_ack_range(&[1], 1);
+ }
+
+ #[test]
+ fn two_ranges() {
+ test_ack_range(&[0, 1, 2, 5, 6, 7], 2);
+ }
+
+ #[test]
+ fn fill_in_range() {
+ test_ack_range(&[0, 1, 2, 5, 6, 7, 3, 4], 1);
+ }
+
+ #[test]
+ fn too_many_ranges() {
+ let mut rp = RecvdPackets::new(PNSpace::Initial); // Any space will do.
+
+ // This will add one too many disjoint ranges.
+ for i in 0..=MAX_TRACKED_RANGES {
+ rp.set_received(*NOW, (i * 2) as u64, true);
+ }
+
+ assert_eq!(rp.ranges.len(), MAX_TRACKED_RANGES);
+ assert_eq!(rp.ranges.back().unwrap().largest, 2);
+
+ // Even though the range was dropped, we still consider it a duplicate.
+ assert!(rp.is_duplicate(0));
+ assert!(!rp.is_duplicate(1));
+ assert!(rp.is_duplicate(2));
+ }
+
+ #[test]
+ fn ack_delay() {
+ // Only application data packets are delayed.
+ let mut rp = RecvdPackets::new(PNSpace::ApplicationData);
+ assert!(rp.ack_time().is_none());
+ assert!(!rp.ack_now(*NOW));
+
+ // Some packets won't cause an ACK to be needed.
+ let max_unacked = u64::try_from(MAX_UNACKED_PKTS).unwrap();
+ for num in 0..max_unacked {
+ rp.set_received(*NOW, num, true);
+ assert_eq!(Some(*NOW + ACK_DELAY), rp.ack_time());
+ assert!(!rp.ack_now(*NOW));
+ assert!(rp.ack_now(*NOW + ACK_DELAY));
+ }
+
+ // Exceeding MAX_UNACKED_PKTS will move the ACK time to now.
+ rp.set_received(*NOW, max_unacked, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ assert!(rp.ack_now(*NOW));
+ }
+
+ #[test]
+ fn no_ack_delay() {
+ for space in &[PNSpace::Initial, PNSpace::Handshake] {
+ let mut rp = RecvdPackets::new(*space);
+ assert!(rp.ack_time().is_none());
+ assert!(!rp.ack_now(*NOW));
+
+ // Any packet will be acknowledged straight away.
+ rp.set_received(*NOW, 0, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ assert!(rp.ack_now(*NOW));
+ }
+ }
+
+ #[test]
+ fn ooo_no_ack_delay() {
+ for space in &[
+ PNSpace::Initial,
+ PNSpace::Handshake,
+ PNSpace::ApplicationData,
+ ] {
+ let mut rp = RecvdPackets::new(*space);
+ assert!(rp.ack_time().is_none());
+ assert!(!rp.ack_now(*NOW));
+
+ // Any OoO packet will be acknowledged straight away.
+ rp.set_received(*NOW, 3, true);
+ assert_eq!(Some(*NOW), rp.ack_time());
+ assert!(rp.ack_now(*NOW));
+ }
+ }
+
+ #[test]
+ fn aggregate_ack_time() {
+ let mut tracker = AckTracker::default();
+ // This packet won't trigger an ACK.
+ tracker
+ .get_mut(PNSpace::Handshake)
+ .unwrap()
+ .set_received(*NOW, 0, false);
+ assert_eq!(None, tracker.ack_time(*NOW));
+
+ // This should be delayed.
+ tracker
+ .get_mut(PNSpace::ApplicationData)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ assert_eq!(Some(*NOW + ACK_DELAY), tracker.ack_time(*NOW));
+
+ // This should move the time forward.
+ let later = *NOW + ACK_DELAY.checked_div(2).unwrap();
+ tracker
+ .get_mut(PNSpace::Initial)
+ .unwrap()
+ .set_received(later, 0, true);
+ assert_eq!(Some(later), tracker.ack_time(*NOW));
+ }
+
+ #[test]
+ #[should_panic(expected = "discarding application space")]
+ fn drop_app() {
+ let mut tracker = AckTracker::default();
+ tracker.drop_space(PNSpace::ApplicationData);
+ }
+
+ #[test]
+ #[should_panic(expected = "dropping spaces out of order")]
+ fn drop_out_of_order() {
+ let mut tracker = AckTracker::default();
+ tracker.drop_space(PNSpace::Handshake);
+ }
+
+ #[test]
+ fn drop_spaces() {
+ let mut tracker = AckTracker::default();
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ tracker
+ .get_mut(PNSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ // The reference time for `ack_time` has to be in the past or we filter out the timer.
+ assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some());
+ let token = tracker.write_frame(
+ PNSpace::Initial,
+ *NOW,
+ &mut builder,
+ &mut FrameStats::default(),
+ );
+ assert!(token.is_some());
+
+ // Mark another packet as received so we have cause to send another ACK in that space.
+ tracker
+ .get_mut(PNSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 1, true);
+ assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some());
+
+ // Now drop that space.
+ tracker.drop_space(PNSpace::Initial);
+
+ assert!(tracker.get_mut(PNSpace::Initial).is_none());
+ assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_none());
+ assert!(tracker
+ .write_frame(
+ PNSpace::Initial,
+ *NOW,
+ &mut builder,
+ &mut FrameStats::default()
+ )
+ .is_none());
+ if let RecoveryToken::Ack(tok) = token.unwrap() {
+ tracker.acked(&tok); // Should be a noop.
+ } else {
+ panic!("not an ACK token");
+ }
+ }
+
+ #[test]
+ fn no_room_for_ack() {
+ let mut tracker = AckTracker::default();
+ tracker
+ .get_mut(PNSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some());
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ builder.set_limit(10);
+
+ let token = tracker.write_frame(
+ PNSpace::Initial,
+ *NOW,
+ &mut builder,
+ &mut FrameStats::default(),
+ );
+ assert!(token.is_none());
+ assert_eq!(builder.len(), 1); // Only the short packet header has been added.
+ }
+
+ #[test]
+ fn no_room_for_extra_range() {
+ let mut tracker = AckTracker::default();
+ tracker
+ .get_mut(PNSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 0, true);
+ tracker
+ .get_mut(PNSpace::Initial)
+ .unwrap()
+ .set_received(*NOW, 2, true);
+ assert!(tracker.ack_time(*NOW - Duration::from_millis(1)).is_some());
+
+ let mut builder = PacketBuilder::short(Encoder::new(), false, &[]);
+ builder.set_limit(32);
+
+ let token = tracker.write_frame(
+ PNSpace::Initial,
+ *NOW,
+ &mut builder,
+ &mut FrameStats::default(),
+ );
+ assert!(token.is_some());
+
+ let mut dec = builder.as_decoder();
+ let _ = dec.decode_byte().unwrap(); // Skip the short header.
+ let frame = Frame::decode(&mut dec).unwrap();
+ if let Frame::Ack { ack_ranges, .. } = frame {
+ assert_eq!(ack_ranges.len(), 0);
+ } else {
+ panic!("not an ACK!");
+ }
+ }
+
+ #[test]
+ fn ack_time_elapsed() {
+ let mut tracker = AckTracker::default();
+
+ // While we have multiple PN spaces, we ignore ACK timers from the past.
+ // Send out of order to cause the delayed ack timer to be set to `*NOW`.
+ tracker
+ .get_mut(PNSpace::ApplicationData)
+ .unwrap()
+ .set_received(*NOW, 3, true);
+ assert!(tracker.ack_time(*NOW + Duration::from_millis(1)).is_none());
+
+ // When we are reduced to one space, that filter is off.
+ tracker.drop_space(PNSpace::Initial);
+ tracker.drop_space(PNSpace::Handshake);
+ assert_eq!(
+ tracker.ack_time(*NOW + Duration::from_millis(1)),
+ Some(*NOW)
+ );
+ }
+
+ #[test]
+ fn pnspaceset_default() {
+ let set = PNSpaceSet::default();
+ assert!(!set[PNSpace::Initial]);
+ assert!(!set[PNSpace::Handshake]);
+ assert!(!set[PNSpace::ApplicationData]);
+ }
+
+ #[test]
+ fn pnspaceset_from() {
+ let set = PNSpaceSet::from(&[PNSpace::Initial]);
+ assert!(set[PNSpace::Initial]);
+ assert!(!set[PNSpace::Handshake]);
+ assert!(!set[PNSpace::ApplicationData]);
+
+ let set = PNSpaceSet::from(&[PNSpace::Handshake, PNSpace::Initial]);
+ assert!(set[PNSpace::Initial]);
+ assert!(set[PNSpace::Handshake]);
+ assert!(!set[PNSpace::ApplicationData]);
+
+ let set = PNSpaceSet::from(&[PNSpace::ApplicationData, PNSpace::ApplicationData]);
+ assert!(!set[PNSpace::Initial]);
+ assert!(!set[PNSpace::Handshake]);
+ assert!(set[PNSpace::ApplicationData]);
+ }
+
+ #[test]
+ fn pnspaceset_copy() {
+ let set = PNSpaceSet::from(&[PNSpace::Handshake, PNSpace::ApplicationData]);
+ let copy = set;
+ assert!(!copy[PNSpace::Initial]);
+ assert!(copy[PNSpace::Handshake]);
+ assert!(copy[PNSpace::ApplicationData]);
+ }
+}