summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-transport/src/cid.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/neqo-transport/src/cid.rs')
-rw-r--r--third_party/rust/neqo-transport/src/cid.rs98
1 files changed, 49 insertions, 49 deletions
diff --git a/third_party/rust/neqo-transport/src/cid.rs b/third_party/rust/neqo-transport/src/cid.rs
index be202daf25..6b3a95eaf0 100644
--- a/third_party/rust/neqo-transport/src/cid.rs
+++ b/third_party/rust/neqo-transport/src/cid.rs
@@ -10,14 +10,13 @@ use std::{
borrow::Borrow,
cell::{Ref, RefCell},
cmp::{max, min},
- convert::{AsRef, TryFrom},
ops::Deref,
rc::Rc,
};
use neqo_common::{hex, hex_with_len, qinfo, Decoder, Encoder};
-use neqo_crypto::random;
-use smallvec::SmallVec;
+use neqo_crypto::{random, randomize};
+use smallvec::{smallvec, SmallVec};
use crate::{
frame::FRAME_TYPE_NEW_CONNECTION_ID, packet::PacketBuilder, recovery::RecoveryToken,
@@ -39,19 +38,26 @@ pub struct ConnectionId {
}
impl ConnectionId {
+ /// # Panics
+ /// When `len` is larger than `MAX_CONNECTION_ID_LEN`.
+ #[must_use]
pub fn generate(len: usize) -> Self {
assert!(matches!(len, 0..=MAX_CONNECTION_ID_LEN));
- Self::from(random(len))
+ let mut cid = smallvec![0; len];
+ randomize(&mut cid);
+ Self { cid }
}
// Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long.
+ #[must_use]
pub fn generate_initial() -> Self {
- let v = random(1);
+ let v = random::<1>()[0];
// Bias selection toward picking 8 (>50% of the time).
- let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into();
+ let len: usize = max(8, 5 + (v & (v >> 4))).into();
Self::generate(len)
}
+ #[must_use]
pub fn as_cid_ref(&self) -> ConnectionIdRef {
ConnectionIdRef::from(&self.cid[..])
}
@@ -75,12 +81,6 @@ impl From<SmallVec<[u8; MAX_CONNECTION_ID_LEN]>> for ConnectionId {
}
}
-impl From<Vec<u8>> for ConnectionId {
- fn from(cid: Vec<u8>) -> Self {
- Self::from(SmallVec::from(cid))
- }
-}
-
impl<T: AsRef<[u8]> + ?Sized> From<&T> for ConnectionId {
fn from(buf: &T) -> Self {
Self::from(SmallVec::from(buf.as_ref()))
@@ -201,7 +201,7 @@ impl ConnectionIdGenerator for EmptyConnectionIdGenerator {
}
}
-/// An RandomConnectionIdGenerator produces connection IDs of
+/// An `RandomConnectionIdGenerator` produces connection IDs of
/// a fixed length and random content. No effort is made to
/// prevent collisions.
pub struct RandomConnectionIdGenerator {
@@ -209,6 +209,7 @@ pub struct RandomConnectionIdGenerator {
}
impl RandomConnectionIdGenerator {
+ #[must_use]
pub fn new(len: usize) -> Self {
Self { len }
}
@@ -222,7 +223,9 @@ impl ConnectionIdDecoder for RandomConnectionIdGenerator {
impl ConnectionIdGenerator for RandomConnectionIdGenerator {
fn generate_cid(&mut self) -> Option<ConnectionId> {
- Some(ConnectionId::from(&random(self.len)))
+ let mut buf = smallvec![0; self.len];
+ randomize(&mut buf);
+ Some(ConnectionId::from(buf))
}
fn as_decoder(&self) -> &dyn ConnectionIdDecoder {
@@ -234,7 +237,7 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator {
}
}
-/// A single connection ID, as saved from NEW_CONNECTION_ID.
+/// A single connection ID, as saved from `NEW_CONNECTION_ID`.
/// This is templated so that the connection ID entries from a peer can be
/// saved with a stateless reset token. Local entries don't need that.
#[derive(Debug, PartialEq, Eq, Clone)]
@@ -250,8 +253,8 @@ pub struct ConnectionIdEntry<SRT: Clone + PartialEq> {
impl ConnectionIdEntry<[u8; 16]> {
/// Create a random stateless reset token so that it is hard to guess the correct
/// value and reset the connection.
- fn random_srt() -> [u8; 16] {
- <[u8; 16]>::try_from(&random(16)[..]).unwrap()
+ pub fn random_srt() -> [u8; 16] {
+ random::<16>()
}
/// Create the first entry, which won't have a stateless reset token.
@@ -294,6 +297,23 @@ impl ConnectionIdEntry<[u8; 16]> {
pub fn sequence_number(&self) -> u64 {
self.seqno
}
+
+ /// Write the entry out in a `NEW_CONNECTION_ID` frame.
+ /// Returns `true` if the frame was written, `false` if there is insufficient space.
+ pub fn write(&self, builder: &mut PacketBuilder, stats: &mut FrameStats) -> bool {
+ let len = 1 + Encoder::varint_len(self.seqno) + 1 + 1 + self.cid.len() + 16;
+ if builder.remaining() < len {
+ return false;
+ }
+
+ builder.encode_varint(FRAME_TYPE_NEW_CONNECTION_ID);
+ builder.encode_varint(self.seqno);
+ builder.encode_varint(0u64);
+ builder.encode_vec(1, &self.cid);
+ builder.encode(&self.srt);
+ stats.new_connection_id += 1;
+ true
+ }
}
impl ConnectionIdEntry<()> {
@@ -430,7 +450,7 @@ pub struct ConnectionIdManager {
limit: usize,
/// The next sequence number that will be used for sending `NEW_CONNECTION_ID` frames.
next_seqno: u64,
- /// Outstanding, but lost NEW_CONNECTION_ID frames will be stored here.
+ /// Outstanding, but lost `NEW_CONNECTION_ID` frames will be stored here.
lost_new_connection_id: Vec<ConnectionIdEntry<[u8; 16]>>,
}
@@ -476,7 +496,7 @@ impl ConnectionIdManager {
.add_local(ConnectionIdEntry::new(self.next_seqno, cid.clone(), ()));
self.next_seqno += 1;
- let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap();
+ let srt = ConnectionIdEntry::random_srt();
Ok((cid, srt))
} else {
Err(Error::ConnectionIdsExhausted)
@@ -516,39 +536,19 @@ impl ConnectionIdManager {
);
}
- fn write_entry(
- &mut self,
- entry: &ConnectionIdEntry<[u8; 16]>,
- builder: &mut PacketBuilder,
- stats: &mut FrameStats,
- ) -> Res<bool> {
- let len = 1 + Encoder::varint_len(entry.seqno) + 1 + 1 + entry.cid.len() + 16;
- if builder.remaining() < len {
- return Ok(false);
- }
-
- builder.encode_varint(FRAME_TYPE_NEW_CONNECTION_ID);
- builder.encode_varint(entry.seqno);
- builder.encode_varint(0u64);
- builder.encode_vec(1, &entry.cid);
- builder.encode(&entry.srt);
- stats.new_connection_id += 1;
- Ok(true)
- }
-
pub fn write_frames(
&mut self,
builder: &mut PacketBuilder,
tokens: &mut Vec<RecoveryToken>,
stats: &mut FrameStats,
- ) -> Res<()> {
+ ) {
if self.generator.deref().borrow().generates_empty_cids() {
debug_assert_eq!(self.generator.borrow_mut().generate_cid().unwrap().len(), 0);
- return Ok(());
+ return;
}
while let Some(entry) = self.lost_new_connection_id.pop() {
- if self.write_entry(&entry, builder, stats)? {
+ if entry.write(builder, stats) {
tokens.push(RecoveryToken::NewConnectionId(entry));
} else {
// This shouldn't happen often.
@@ -565,7 +565,7 @@ impl ConnectionIdManager {
if let Some(cid) = maybe_cid {
assert_ne!(cid.len(), 0);
// TODO: generate the stateless reset tokens from the connection ID and a key.
- let srt = <[u8; 16]>::try_from(&random(16)[..]).unwrap();
+ let srt = ConnectionIdEntry::random_srt();
let seqno = self.next_seqno;
self.next_seqno += 1;
@@ -573,11 +573,10 @@ impl ConnectionIdManager {
.add_local(ConnectionIdEntry::new(seqno, cid.clone(), ()));
let entry = ConnectionIdEntry::new(seqno, cid, srt);
- debug_assert!(self.write_entry(&entry, builder, stats)?);
+ entry.write(builder, stats);
tokens.push(RecoveryToken::NewConnectionId(entry));
}
}
- Ok(())
}
pub fn lost(&mut self, entry: &ConnectionIdEntry<[u8; 16]>) {
@@ -594,16 +593,17 @@ impl ConnectionIdManager {
mod tests {
use test_fixture::fixture_init;
- use super::*;
+ use crate::{cid::MAX_CONNECTION_ID_LEN, ConnectionId};
#[test]
fn generate_initial_cid() {
fixture_init();
for _ in 0..100 {
let cid = ConnectionId::generate_initial();
- if !matches!(cid.len(), 8..=MAX_CONNECTION_ID_LEN) {
- panic!("connection ID {:?}", cid);
- }
+ assert!(
+ matches!(cid.len(), 8..=MAX_CONNECTION_ID_LEN),
+ "connection ID length {cid:?}",
+ );
}
}
}