summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-transport/src/packet/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/rust/neqo-transport/src/packet/mod.rs99
1 files changed, 84 insertions, 15 deletions
diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs
index 8458f69779..ce611a9664 100644
--- a/third_party/rust/neqo-transport/src/packet/mod.rs
+++ b/third_party/rust/neqo-transport/src/packet/mod.rs
@@ -18,6 +18,7 @@ use neqo_crypto::random;
use crate::{
cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN},
crypto::{CryptoDxState, CryptoSpace, CryptoStates},
+ frame::FRAME_TYPE_PADDING,
version::{Version, WireVersion},
Error, Res,
};
@@ -157,7 +158,7 @@ impl PacketBuilder {
}
Self {
encoder,
- pn: u64::max_value(),
+ pn: u64::MAX,
header: header_start..header_start,
offsets: PacketBuilderOffsets {
first_byte_mask: PACKET_HP_MASK_SHORT,
@@ -200,7 +201,7 @@ impl PacketBuilder {
Self {
encoder,
- pn: u64::max_value(),
+ pn: u64::MAX,
header: header_start..header_start,
offsets: PacketBuilderOffsets {
first_byte_mask: PACKET_HP_MASK_LONG,
@@ -255,9 +256,14 @@ impl PacketBuilder {
/// Maybe pad with "PADDING" frames.
/// Only does so if padding was needed and this is a short packet.
/// Returns true if padding was added.
+ ///
+ /// # Panics
+ ///
+ /// Cannot happen.
pub fn pad(&mut self) -> bool {
if self.padding && !self.is_long() {
- self.encoder.pad_to(self.limit, 0);
+ self.encoder
+ .pad_to(self.limit, FRAME_TYPE_PADDING.try_into().unwrap());
true
} else {
false
@@ -288,6 +294,10 @@ impl PacketBuilder {
/// The length is filled in after calling `build`.
/// Does nothing if there isn't 4 bytes available other than render this builder
/// unusable; if `remaining()` returns 0 at any point, call `abort()`.
+ ///
+ /// # Panics
+ ///
+ /// This will panic if the packet number length is too large.
pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) {
if self.remaining() < 4 {
self.limit = 0;
@@ -352,6 +362,10 @@ impl PacketBuilder {
}
/// Build the packet and return the encoder.
+ ///
+ /// # Errors
+ ///
+ /// This will return an error if the packet is too large.
pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> {
if self.len() > self.limit {
qwarn!("Packet contents are more than the limit");
@@ -376,7 +390,9 @@ impl PacketBuilder {
// Calculate the mask.
let offset = SAMPLE_OFFSET - self.offsets.pn.len();
- assert!(offset + SAMPLE_SIZE <= ciphertext.len());
+ if offset + SAMPLE_SIZE > ciphertext.len() {
+ return Err(Error::InternalError);
+ }
let sample = &ciphertext[offset..offset + SAMPLE_SIZE];
let mask = crypto.compute_mask(sample)?;
@@ -410,6 +426,10 @@ impl PacketBuilder {
/// As this is a simple packet, this is just an associated function.
/// As Retry is odd (it has to be constructed with leading bytes),
/// this returns a [`Vec<u8>`] rather than building on an encoder.
+ ///
+ /// # Errors
+ ///
+ /// This will return an error if AEAD encrypt fails.
#[allow(clippy::similar_names)] // scid and dcid are fine here.
pub fn retry(
version: Version,
@@ -443,6 +463,7 @@ impl PacketBuilder {
/// Make a Version Negotiation packet.
#[allow(clippy::similar_names)] // scid and dcid are fine here.
+ #[must_use]
pub fn version_negotiation(
dcid: &[u8],
scid: &[u8],
@@ -534,7 +555,10 @@ impl<'a> PublicPacket<'a> {
if packet_type == PacketType::Retry {
let header_len = decoder.offset();
let expansion = retry::expansion(version);
- let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?;
+ let token = decoder
+ .remaining()
+ .checked_sub(expansion)
+ .map_or(Err(Error::InvalidPacket), |v| Self::opt(decoder.decode(v)))?;
if token.is_empty() {
return Err(Error::InvalidPacket);
}
@@ -554,6 +578,10 @@ impl<'a> PublicPacket<'a> {
/// Decode the common parts of a packet. This provides minimal parsing and validation.
/// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram.
+ ///
+ /// # Errors
+ ///
+ /// This will return an error if the packet could not be decoded.
#[allow(clippy::similar_names)] // For dcid and scid, which are fine.
pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
let mut decoder = Decoder::new(data);
@@ -585,7 +613,7 @@ impl<'a> PublicPacket<'a> {
}
// Generic long header.
- let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?).unwrap();
+ let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
@@ -645,11 +673,14 @@ impl<'a> PublicPacket<'a> {
}
/// Validate the given packet as though it were a retry.
+ #[must_use]
pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool {
if self.packet_type != PacketType::Retry {
return false;
}
- let version = self.version().unwrap();
+ let Some(version) = self.version() else {
+ return false;
+ };
let expansion = retry::expansion(version);
if self.data.len() <= expansion {
return false;
@@ -665,6 +696,7 @@ impl<'a> PublicPacket<'a> {
.unwrap_or(false)
}
+ #[must_use]
pub fn is_valid_initial(&self) -> bool {
// Packet has to be an initial, with a DCID of 8 bytes, or a token.
// Note: the Server class validates the token and checks the length.
@@ -672,32 +704,42 @@ impl<'a> PublicPacket<'a> {
&& (self.dcid().len() >= 8 || !self.token.is_empty())
}
+ #[must_use]
pub fn packet_type(&self) -> PacketType {
self.packet_type
}
+ #[must_use]
pub fn dcid(&self) -> ConnectionIdRef<'a> {
self.dcid
}
+ /// # Panics
+ ///
+ /// This will panic if called for a short header packet.
+ #[must_use]
pub fn scid(&self) -> ConnectionIdRef<'a> {
self.scid
.expect("should only be called for long header packets")
}
+ #[must_use]
pub fn token(&self) -> &'a [u8] {
self.token
}
+ #[must_use]
pub fn version(&self) -> Option<Version> {
self.version.and_then(|v| Version::try_from(v).ok())
}
+ #[must_use]
pub fn wire_version(&self) -> WireVersion {
debug_assert!(self.version.is_some());
self.version.unwrap_or(0)
}
+ #[must_use]
pub fn len(&self) -> usize {
self.data.len()
}
@@ -725,14 +767,10 @@ impl<'a> PublicPacket<'a> {
assert_ne!(self.packet_type, PacketType::Retry);
assert_ne!(self.packet_type, PacketType::VersionNegotiation);
- qtrace!(
- "unmask hdr={}",
- hex(&self.data[..self.header_len + SAMPLE_OFFSET])
- );
-
let sample_offset = self.header_len + SAMPLE_OFFSET;
let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE))
{
+ qtrace!("unmask hdr={}", hex(&self.data[..sample_offset]));
crypto.compute_mask(sample)
} else {
Err(Error::NoMoreData)
@@ -776,6 +814,9 @@ impl<'a> PublicPacket<'a> {
))
}
+ /// # Errors
+ ///
+ /// This will return an error if the packet cannot be decrypted.
pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> {
let cspace: CryptoSpace = self.packet_type.into();
// When we don't have a version, the crypto code doesn't need a version
@@ -790,7 +831,9 @@ impl<'a> PublicPacket<'a> {
// too small (which is public information).
let (key_phase, pn, header, body) = self.decrypt_header(rx)?;
qtrace!([rx], "decoded header: {:?}", header);
- let rx = crypto.rx(version, cspace, key_phase).unwrap();
+ let Some(rx) = crypto.rx(version, cspace, key_phase) else {
+ return Err(Error::DecryptError);
+ };
let version = rx.version(); // Version fixup; see above.
let d = rx.decrypt(pn, &header, body)?;
// If this is the first packet ever successfully decrypted
@@ -813,8 +856,14 @@ impl<'a> PublicPacket<'a> {
}
}
+ /// # Errors
+ ///
+ /// This will return an error if the packet is not a version negotiation packet
+ /// or if the versions cannot be decoded.
pub fn supported_versions(&self) -> Res<Vec<WireVersion>> {
- assert_eq!(self.packet_type, PacketType::VersionNegotiation);
+ if self.packet_type != PacketType::VersionNegotiation {
+ return Err(Error::InvalidPacket);
+ }
let mut decoder = Decoder::new(&self.data[self.header_len..]);
let mut res = Vec::new();
while decoder.remaining() > 0 {
@@ -845,14 +894,17 @@ pub struct DecryptedPacket {
}
impl DecryptedPacket {
+ #[must_use]
pub fn version(&self) -> Version {
self.version
}
+ #[must_use]
pub fn packet_type(&self) -> PacketType {
self.pt
}
+ #[must_use]
pub fn pn(&self) -> PacketNumber {
self.pn
}
@@ -866,7 +918,7 @@ impl Deref for DecryptedPacket {
}
}
-#[cfg(all(test, not(feature = "fuzzing")))]
+#[cfg(all(test, not(feature = "disable-encryption")))]
mod tests {
use neqo_common::Encoder;
use test_fixture::{fixture_init, now};
@@ -1469,4 +1521,21 @@ mod tests {
assert_eq!(decrypted.pn(), 654_360_564);
assert_eq!(&decrypted[..], &[0x01]);
}
+
+ #[test]
+ fn decode_empty() {
+ neqo_crypto::init().unwrap();
+ let res = PublicPacket::decode(&[], &EmptyConnectionIdGenerator::default());
+ assert!(res.is_err());
+ }
+
+ #[test]
+ fn decode_too_short() {
+ neqo_crypto::init().unwrap();
+ let res = PublicPacket::decode(
+ &[179, 255, 0, 0, 32, 0, 0],
+ &EmptyConnectionIdGenerator::default(),
+ );
+ assert!(res.is_err());
+ }
}