diff options
Diffstat (limited to 'third_party/rust/neqo-transport/src/packet/mod.rs')
-rw-r--r-- | third_party/rust/neqo-transport/src/packet/mod.rs | 1452 |
1 files changed, 1452 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/packet/mod.rs b/third_party/rust/neqo-transport/src/packet/mod.rs new file mode 100644 index 0000000000..631bf84795 --- /dev/null +++ b/third_party/rust/neqo-transport/src/packet/mod.rs @@ -0,0 +1,1452 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Encoding and decoding packets off the wire. +use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN}; +use crate::crypto::{CryptoDxState, CryptoSpace, CryptoStates}; +use crate::version::{Version, WireVersion}; +use crate::{Error, Res}; + +use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder}; +use neqo_crypto::random; + +use std::cmp::min; +use std::convert::TryFrom; +use std::fmt; +use std::iter::ExactSizeIterator; +use std::ops::{Deref, DerefMut, Range}; +use std::time::Instant; + +pub const PACKET_BIT_LONG: u8 = 0x80; +const PACKET_BIT_SHORT: u8 = 0x00; +const PACKET_BIT_FIXED_QUIC: u8 = 0x40; +const PACKET_BIT_SPIN: u8 = 0x20; +const PACKET_BIT_KEY_PHASE: u8 = 0x04; + +const PACKET_HP_MASK_LONG: u8 = 0x0f; +const PACKET_HP_MASK_SHORT: u8 = 0x1f; + +const SAMPLE_SIZE: usize = 16; +const SAMPLE_OFFSET: usize = 4; +const MAX_PACKET_NUMBER_LEN: usize = 4; + +mod retry; + +pub type PacketNumber = u64; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + VersionNegotiation, + Initial, + Handshake, + ZeroRtt, + Retry, + Short, + OtherVersion, +} + +impl PacketType { + #[must_use] + fn from_byte(t: u8, v: Version) -> Self { + // Version2 adds one to the type, modulo 4 + match t.wrapping_sub(u8::from(v == Version::Version2)) & 3 { + 0 => Self::Initial, + 1 => Self::ZeroRtt, + 2 => Self::Handshake, + 3 => Self::Retry, + _ => panic!("packet type out of range"), + } + } + + #[must_use] + fn to_byte(self, v: Version) -> u8 { + let t = match self { + Self::Initial => 0, + Self::ZeroRtt => 1, + Self::Handshake => 2, + Self::Retry => 3, + _ => panic!("not a long header packet type"), + }; + // Version2 adds one to the type, modulo 4 + (t + u8::from(v == Version::Version2)) & 3 + } +} + +impl From<PacketType> for CryptoSpace { + fn from(v: PacketType) -> Self { + match v { + PacketType::Initial => Self::Initial, + PacketType::ZeroRtt => Self::ZeroRtt, + PacketType::Handshake => Self::Handshake, + PacketType::Short => Self::ApplicationData, + _ => panic!("shouldn't be here"), + } + } +} + +impl From<CryptoSpace> for PacketType { + fn from(cs: CryptoSpace) -> Self { + match cs { + CryptoSpace::Initial => Self::Initial, + CryptoSpace::ZeroRtt => Self::ZeroRtt, + CryptoSpace::Handshake => Self::Handshake, + CryptoSpace::ApplicationData => Self::Short, + } + } +} + +struct PacketBuilderOffsets { + /// The bits of the first octet that need masking. + first_byte_mask: u8, + /// The offset of the length field. + len: usize, + /// The location of the packet number field. + pn: Range<usize>, +} + +/// A packet builder that can be used to produce short packets and long packets. +/// This does not produce Retry or Version Negotiation. +pub struct PacketBuilder { + encoder: Encoder, + pn: PacketNumber, + header: Range<usize>, + offsets: PacketBuilderOffsets, + limit: usize, + /// Whether to pad the packet before construction. + padding: bool, +} + +impl PacketBuilder { + /// The minimum useful frame size. If space is less than this, we will claim to be full. + pub const MINIMUM_FRAME_SIZE: usize = 2; + + fn infer_limit(encoder: &Encoder) -> usize { + if encoder.capacity() > 64 { + encoder.capacity() + } else { + 2048 + } + } + + /// Start building a short header packet. + /// + /// This doesn't fail if there isn't enough space; instead it returns a builder that + /// has no available space left. This allows the caller to extract the encoder + /// and any packets that might have been added before as adding a packet header is + /// only likely to fail if there are other packets already written. + /// + /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get + /// the encoder back. + #[allow(clippy::reversed_empty_ranges)] + pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self { + let mut limit = Self::infer_limit(&encoder); + let header_start = encoder.len(); + // Check that there is enough space for the header. + // 5 = 1 (first byte) + 4 (packet number) + if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() { + encoder + .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); + encoder.encode(dcid.as_ref()); + } else { + limit = 0; + } + Self { + encoder, + pn: u64::max_value(), + header: header_start..header_start, + offsets: PacketBuilderOffsets { + first_byte_mask: PACKET_HP_MASK_SHORT, + pn: 0..0, + len: 0, + }, + limit, + padding: false, + } + } + + /// Start building a long header packet. + /// For an Initial packet you will need to call initial_token(), + /// even if the token is empty. + /// + /// See `short()` for more on how to handle this in cases where there is no space. + #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range. + pub fn long( + mut encoder: Encoder, + pt: PacketType, + version: Version, + dcid: impl AsRef<[u8]>, + scid: impl AsRef<[u8]>, + ) -> Self { + let mut limit = Self::infer_limit(&encoder); + let header_start = encoder.len(); + // Check that there is enough space for the header. + // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number) + if limit > encoder.len() + && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len() + { + encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4); + encoder.encode_uint(4, version.wire_version()); + encoder.encode_vec(1, dcid.as_ref()); + encoder.encode_vec(1, scid.as_ref()); + } else { + limit = 0; + } + + Self { + encoder, + pn: u64::max_value(), + header: header_start..header_start, + offsets: PacketBuilderOffsets { + first_byte_mask: PACKET_HP_MASK_LONG, + pn: 0..0, + len: 0, + }, + limit, + padding: false, + } + } + + fn is_long(&self) -> bool { + self.as_ref()[self.header.start] & 0x80 == PACKET_BIT_LONG + } + + /// This stores a value that can be used as a limit. This does not cause + /// this limit to be enforced until encryption occurs. Prior to that, it + /// is only used voluntarily by users of the builder, through `remaining()`. + pub fn set_limit(&mut self, limit: usize) { + self.limit = limit; + } + + /// Get the current limit. + #[must_use] + pub fn limit(&mut self) -> usize { + self.limit + } + + /// How many bytes remain against the size limit for the builder. + #[must_use] + pub fn remaining(&self) -> usize { + self.limit.saturating_sub(self.encoder.len()) + } + + /// Returns true if the packet has no more space for frames. + #[must_use] + pub fn is_full(&self) -> bool { + // No useful frame is smaller than 2 bytes long. + self.limit < self.encoder.len() + Self::MINIMUM_FRAME_SIZE + } + + /// Adjust the limit to ensure that no more data is added. + pub fn mark_full(&mut self) { + self.limit = self.encoder.len() + } + + /// Mark the packet as needing padding (or not). + pub fn enable_padding(&mut self, needs_padding: bool) { + self.padding = needs_padding; + } + + /// Maybe pad with "PADDING" frames. + /// Only does so if padding was needed and this is a short packet. + /// Returns true if padding was added. + pub fn pad(&mut self) -> bool { + if self.padding && !self.is_long() { + self.encoder.pad_to(self.limit, 0); + true + } else { + false + } + } + + /// Add unpredictable values for unprotected parts of the packet. + pub fn scramble(&mut self, quic_bit: bool) { + debug_assert!(self.len() > self.header.start); + let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 } + | if self.is_long() { 0 } else { PACKET_BIT_SPIN }; + let first = self.header.start; + self.encoder.as_mut()[first] ^= random(1)[0] & mask; + } + + /// For an Initial packet, encode the token. + /// If you fail to do this, then you will not get a valid packet. + pub fn initial_token(&mut self, token: &[u8]) { + if Encoder::vvec_len(token.len()) < self.remaining() { + self.encoder.encode_vvec(token); + } else { + self.limit = 0; + } + } + + /// Add a packet number of the given size. + /// For a long header packet, this also inserts a dummy length. + /// The length is filled in after calling `build`. + /// Does nothing if there isn't 4 bytes available other than render this builder + /// unusable; if `remaining()` returns 0 at any point, call `abort()`. + pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) { + if self.remaining() < 4 { + self.limit = 0; + return; + } + + // Reserve space for a length in long headers. + if self.is_long() { + self.offsets.len = self.encoder.len(); + self.encoder.encode(&[0; 2]); + } + + // This allows the input to be >4, which is absurd, but we can eat that. + let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len); + debug_assert_ne!(pn_len, 0); + // Encode the packet number and save its offset. + let pn_offset = self.encoder.len(); + self.encoder.encode_uint(pn_len, pn); + self.offsets.pn = pn_offset..self.encoder.len(); + + // Now encode the packet number length and save the header length. + self.encoder.as_mut()[self.header.start] |= u8::try_from(pn_len - 1).unwrap(); + self.header.end = self.encoder.len(); + self.pn = pn; + } + + fn write_len(&mut self, expansion: usize) { + let len = self.encoder.len() - (self.offsets.len + 2) + expansion; + self.encoder.as_mut()[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8; + self.encoder.as_mut()[self.offsets.len + 1] = (len & 0xff) as u8; + } + + fn pad_for_crypto(&mut self, crypto: &mut CryptoDxState) { + // Make sure that there is enough data in the packet. + // The length of the packet number plus the payload length needs to + // be at least 4 (MAX_PACKET_NUMBER_LEN) plus any amount by which + // the header protection sample exceeds the AEAD expansion. + let crypto_pad = crypto.extra_padding(); + self.encoder.pad_to( + self.offsets.pn.start + MAX_PACKET_NUMBER_LEN + crypto_pad, + 0, + ); + } + + /// A lot of frames here are just a collection of varints. + /// This helper functions writes a frame like that safely, returning `true` if + /// a frame was written. + pub fn write_varint_frame(&mut self, values: &[u64]) -> bool { + let write = self.remaining() + >= values + .iter() + .map(|&v| Encoder::varint_len(v)) + .sum::<usize>(); + if write { + for v in values { + self.encode_varint(*v); + } + debug_assert!(self.len() <= self.limit()); + }; + write + } + + /// Build the packet and return the encoder. + pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> { + if self.len() > self.limit { + qwarn!("Packet contents are more than the limit"); + debug_assert!(false); + return Err(Error::InternalError(5)); + } + + self.pad_for_crypto(crypto); + if self.offsets.len > 0 { + self.write_len(crypto.expansion()); + } + + let hdr = &self.encoder.as_ref()[self.header.clone()]; + let body = &self.encoder.as_ref()[self.header.end..]; + qtrace!( + "Packet build pn={} hdr={} body={}", + self.pn, + hex(hdr), + hex(body) + ); + let ciphertext = crypto.encrypt(self.pn, hdr, body)?; + + // Calculate the mask. + let offset = SAMPLE_OFFSET - self.offsets.pn.len(); + assert!(offset + SAMPLE_SIZE <= ciphertext.len()); + let sample = &ciphertext[offset..offset + SAMPLE_SIZE]; + let mask = crypto.compute_mask(sample)?; + + // Apply the mask. + self.encoder.as_mut()[self.header.start] ^= mask[0] & self.offsets.first_byte_mask; + for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) { + self.encoder.as_mut()[j] ^= mask[i]; + } + + // Finally, cut off the plaintext and add back the ciphertext. + self.encoder.truncate(self.header.end); + self.encoder.encode(&ciphertext); + qtrace!("Packet built {}", hex(&self.encoder)); + Ok(self.encoder) + } + + /// Abort writing of this packet and return the encoder. + #[must_use] + pub fn abort(mut self) -> Encoder { + self.encoder.truncate(self.header.start); + self.encoder + } + + /// Work out if nothing was added after the header. + #[must_use] + pub fn packet_empty(&self) -> bool { + self.encoder.len() == self.header.end + } + + /// Make a retry packet. + /// As this is a simple packet, this is just an associated function. + /// As Retry is odd (it has to be constructed with leading bytes), + /// this returns a Vec<u8> rather than building on an encoder. + pub fn retry( + version: Version, + dcid: &[u8], + scid: &[u8], + token: &[u8], + odcid: &[u8], + ) -> Res<Vec<u8>> { + let mut encoder = Encoder::default(); + encoder.encode_vec(1, odcid); + let start = encoder.len(); + encoder.encode_byte( + PACKET_BIT_LONG + | PACKET_BIT_FIXED_QUIC + | (PacketType::Retry.to_byte(version) << 4) + | (random(1)[0] & 0xf), + ); + encoder.encode_uint(4, version.wire_version()); + encoder.encode_vec(1, dcid); + encoder.encode_vec(1, scid); + debug_assert_ne!(token.len(), 0); + encoder.encode(token); + let tag = retry::use_aead(version, |aead| { + let mut buf = vec![0; aead.expansion()]; + Ok(aead.encrypt(0, encoder.as_ref(), &[], &mut buf)?.to_vec()) + })?; + encoder.encode(&tag); + let mut complete: Vec<u8> = encoder.into(); + Ok(complete.split_off(start)) + } + + /// Make a Version Negotiation packet. + pub fn version_negotiation( + dcid: &[u8], + scid: &[u8], + client_version: u32, + versions: &[Version], + ) -> Vec<u8> { + let mut encoder = Encoder::default(); + let mut grease = random(4); + // This will not include the "QUIC bit" sometimes. Intentionally. + encoder.encode_byte(PACKET_BIT_LONG | (grease[3] & 0x7f)); + encoder.encode(&[0; 4]); // Zero version == VN. + encoder.encode_vec(1, dcid); + encoder.encode_vec(1, scid); + + for v in versions { + encoder.encode_uint(4, v.wire_version()); + } + // Add a greased version, using the randomness already generated. + for g in &mut grease[..3] { + *g = *g & 0xf0 | 0x0a; + } + + // Ensure our greased version does not collide with the client version + // by making the last byte differ from the client initial. + grease[3] = (client_version.wrapping_add(0x10) & 0xf0) as u8 | 0x0a; + encoder.encode(&grease[..4]); + + Vec::from(encoder) + } +} + +impl Deref for PacketBuilder { + type Target = Encoder; + + fn deref(&self) -> &Self::Target { + &self.encoder + } +} + +impl DerefMut for PacketBuilder { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.encoder + } +} + +impl From<PacketBuilder> for Encoder { + fn from(v: PacketBuilder) -> Self { + v.encoder + } +} + +/// PublicPacket holds information from packets that is public only. This allows for +/// processing of packets prior to decryption. +pub struct PublicPacket<'a> { + /// The packet type. + packet_type: PacketType, + /// The recovered destination connection ID. + dcid: ConnectionIdRef<'a>, + /// The source connection ID, if this is a long header packet. + scid: Option<ConnectionIdRef<'a>>, + /// Any token that is included in the packet (Retry always has a token; Initial sometimes does). + /// This is empty when there is no token. + token: &'a [u8], + /// The size of the header, not including the packet number. + header_len: usize, + /// Protocol version, if present in header. + version: Option<WireVersion>, + /// A reference to the entire packet, including the header. + data: &'a [u8], +} + +impl<'a> PublicPacket<'a> { + fn opt<T>(v: Option<T>) -> Res<T> { + if let Some(v) = v { + Ok(v) + } else { + Err(Error::NoMoreData) + } + } + + /// Decode the type-specific portions of a long header. + /// This includes reading the length and the remainder of the packet. + /// Returns a tuple of any token and the length of the header. + fn decode_long( + decoder: &mut Decoder<'a>, + packet_type: PacketType, + version: Version, + ) -> Res<(&'a [u8], usize)> { + if packet_type == PacketType::Retry { + let header_len = decoder.offset(); + let expansion = retry::expansion(version); + let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?; + if token.is_empty() { + return Err(Error::InvalidPacket); + } + Self::opt(decoder.decode(expansion))?; + return Ok((token, header_len)); + } + let token = if packet_type == PacketType::Initial { + Self::opt(decoder.decode_vvec())? + } else { + &[] + }; + let len = Self::opt(decoder.decode_varint())?; + let header_len = decoder.offset(); + let _body = Self::opt(decoder.decode(usize::try_from(len)?))?; + Ok((token, header_len)) + } + + /// Decode the common parts of a packet. This provides minimal parsing and validation. + /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram. + pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> { + let mut decoder = Decoder::new(data); + let first = Self::opt(decoder.decode_byte())?; + + if first & 0x80 == PACKET_BIT_SHORT { + // Conveniently, this also guarantees that there is enough space + // for a connection ID of any size. + if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { + return Err(Error::InvalidPacket); + } + let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?; + if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { + return Err(Error::InvalidPacket); + } + let header_len = decoder.offset(); + return Ok(( + Self { + packet_type: PacketType::Short, + dcid, + scid: None, + token: &[], + header_len, + version: None, + data, + }, + &[], + )); + } + + // Generic long header. + let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?).unwrap(); + let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + + // Version negotiation. + if version == 0 { + return Ok(( + Self { + packet_type: PacketType::VersionNegotiation, + dcid, + scid: Some(scid), + token: &[], + header_len: decoder.offset(), + version: None, + data, + }, + &[], + )); + } + + // Check that this is a long header from a supported version. + let version = if let Ok(v) = Version::try_from(version) { + v + } else { + return Ok(( + Self { + packet_type: PacketType::OtherVersion, + dcid, + scid: Some(scid), + token: &[], + header_len: decoder.offset(), + version: Some(version), + data, + }, + &[], + )); + }; + + if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN { + return Err(Error::InvalidPacket); + } + let packet_type = PacketType::from_byte((first >> 4) & 3, version); + + // The type-specific code includes a token. This consumes the remainder of the packet. + let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?; + let end = data.len() - decoder.remaining(); + let (data, remainder) = data.split_at(end); + Ok(( + Self { + packet_type, + dcid, + scid: Some(scid), + token, + header_len, + version: Some(version.wire_version()), + data, + }, + remainder, + )) + } + + /// Validate the given packet as though it were a retry. + pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool { + if self.packet_type != PacketType::Retry { + return false; + } + let version = self.version().unwrap(); + let expansion = retry::expansion(version); + if self.data.len() <= expansion { + return false; + } + let (header, tag) = self.data.split_at(self.data.len() - expansion); + let mut encoder = Encoder::with_capacity(self.data.len()); + encoder.encode_vec(1, odcid); + encoder.encode(header); + retry::use_aead(version, |aead| { + let mut buf = vec![0; expansion]; + Ok(aead.decrypt(0, encoder.as_ref(), tag, &mut buf)?.is_empty()) + }) + .unwrap_or(false) + } + + pub fn is_valid_initial(&self) -> bool { + // Packet has to be an initial, with a DCID of 8 bytes, or a token. + // Note: the Server class validates the token and checks the length. + self.packet_type == PacketType::Initial + && (self.dcid().len() >= 8 || !self.token.is_empty()) + } + + pub fn packet_type(&self) -> PacketType { + self.packet_type + } + + pub fn dcid(&self) -> &ConnectionIdRef<'a> { + &self.dcid + } + + pub fn scid(&self) -> &ConnectionIdRef<'a> { + self.scid + .as_ref() + .expect("should only be called for long header packets") + } + + pub fn token(&self) -> &'a [u8] { + self.token + } + + pub fn version(&self) -> Option<Version> { + self.version.and_then(|v| Version::try_from(v).ok()) + } + + pub fn wire_version(&self) -> WireVersion { + debug_assert!(self.version.is_some()); + self.version.unwrap_or(0) + } + + pub fn len(&self) -> usize { + self.data.len() + } + + fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber { + let window = 1_u64 << (w * 8); + let candidate = (expected & !(window - 1)) | pn; + if candidate + (window / 2) <= expected { + candidate + window + } else if candidate > expected + (window / 2) { + match candidate.checked_sub(window) { + Some(pn_sub) => pn_sub, + None => candidate, + } + } else { + candidate + } + } + + /// Decrypt the header of the packet. + fn decrypt_header( + &self, + crypto: &mut CryptoDxState, + ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> { + assert_ne!(self.packet_type, PacketType::Retry); + assert_ne!(self.packet_type, PacketType::VersionNegotiation); + + qtrace!( + "unmask hdr={}", + hex(&self.data[..self.header_len + SAMPLE_OFFSET]) + ); + + let sample_offset = self.header_len + SAMPLE_OFFSET; + let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE)) + { + crypto.compute_mask(sample) + } else { + Err(Error::NoMoreData) + }?; + + // Un-mask the leading byte. + let bits = if self.packet_type == PacketType::Short { + PACKET_HP_MASK_SHORT + } else { + PACKET_HP_MASK_LONG + }; + let first_byte = self.data[0] ^ (mask[0] & bits); + + // Make a copy of the header to work on. + let mut hdrbytes = self.data[..self.header_len + 4].to_vec(); + hdrbytes[0] = first_byte; + + // Unmask the PN. + let mut pn_encoded: u64 = 0; + for i in 0..MAX_PACKET_NUMBER_LEN { + hdrbytes[self.header_len + i] ^= mask[1 + i]; + pn_encoded <<= 8; + pn_encoded += u64::from(hdrbytes[self.header_len + i]); + } + + // Now decode the packet number length and apply it, hopefully in constant time. + let pn_len = usize::from((first_byte & 0x3) + 1); + hdrbytes.truncate(self.header_len + pn_len); + pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len); + + qtrace!("unmasked hdr={}", hex(&hdrbytes)); + + let key_phase = self.packet_type == PacketType::Short + && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE; + let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len); + Ok(( + key_phase, + pn, + hdrbytes, + &self.data[self.header_len + pn_len..], + )) + } + + pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> { + let cspace: CryptoSpace = self.packet_type.into(); + // When we don't have a version, the crypto code doesn't need a version + // for lookup, so use the default, but fix it up if decryption succeeds. + let version = self.version().unwrap_or_default(); + // This has to work in two stages because we need to remove header protection + // before picking the keys to use. + if let Some(rx) = crypto.rx_hp(version, cspace) { + // Note that this will dump early, which creates a side-channel. + // This is OK in this case because we the only reason this can + // fail is if the cryptographic module is bad or the packet is + // too small (which is public information). + let (key_phase, pn, header, body) = self.decrypt_header(rx)?; + qtrace!([rx], "decoded header: {:?}", header); + let rx = crypto.rx(version, cspace, key_phase).unwrap(); + let version = rx.version(); // Version fixup; see above. + let d = rx.decrypt(pn, &header, body)?; + // If this is the first packet ever successfully decrypted + // using `rx`, make sure to initiate a key update. + if rx.needs_update() { + crypto.key_update_received(release_at)?; + } + crypto.check_pn_overlap()?; + Ok(DecryptedPacket { + version, + pt: self.packet_type, + pn, + data: d, + }) + } else if crypto.rx_pending(cspace) { + Err(Error::KeysPending(cspace)) + } else { + qtrace!("keys for {:?} already discarded", cspace); + Err(Error::KeysDiscarded(cspace)) + } + } + + pub fn supported_versions(&self) -> Res<Vec<WireVersion>> { + assert_eq!(self.packet_type, PacketType::VersionNegotiation); + let mut decoder = Decoder::new(&self.data[self.header_len..]); + let mut res = Vec::new(); + while decoder.remaining() > 0 { + let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?; + res.push(version); + } + Ok(res) + } +} + +impl fmt::Debug for PublicPacket<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{:?}: {} {}", + self.packet_type(), + hex_with_len(&self.data[..self.header_len]), + hex_with_len(&self.data[self.header_len..]) + ) + } +} + +pub struct DecryptedPacket { + version: Version, + pt: PacketType, + pn: PacketNumber, + data: Vec<u8>, +} + +impl DecryptedPacket { + pub fn version(&self) -> Version { + self.version + } + + pub fn packet_type(&self) -> PacketType { + self.pt + } + + pub fn pn(&self) -> PacketNumber { + self.pn + } +} + +impl Deref for DecryptedPacket { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.data[..] + } +} + +#[cfg(all(test, not(feature = "fuzzing")))] +mod tests { + use super::*; + use crate::crypto::{CryptoDxState, CryptoStates}; + use crate::{EmptyConnectionIdGenerator, RandomConnectionIdGenerator, Version}; + use neqo_common::Encoder; + use test_fixture::{fixture_init, now}; + + const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5]; + + /// This is a connection ID manager, which is only used for decoding short header packets. + fn cid_mgr() -> RandomConnectionIdGenerator { + RandomConnectionIdGenerator::new(SERVER_CID.len()) + } + + const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[ + 0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03, + 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd, + 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04, + 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, + 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14, + 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, + 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04, + ]; + const SAMPLE_INITIAL: &[u8] = &[ + 0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x00, 0x40, 0x75, 0xc0, 0xd9, 0x5a, 0x48, 0x2c, 0xd0, 0x99, 0x1c, 0xd2, 0x5b, 0x0a, 0xac, + 0x40, 0x6a, 0x58, 0x16, 0xb6, 0x39, 0x41, 0x00, 0xf3, 0x7a, 0x1c, 0x69, 0x79, 0x75, 0x54, + 0x78, 0x0b, 0xb3, 0x8c, 0xc5, 0xa9, 0x9f, 0x5e, 0xde, 0x4c, 0xf7, 0x3c, 0x3e, 0xc2, 0x49, + 0x3a, 0x18, 0x39, 0xb3, 0xdb, 0xcb, 0xa3, 0xf6, 0xea, 0x46, 0xc5, 0xb7, 0x68, 0x4d, 0xf3, + 0x54, 0x8e, 0x7d, 0xde, 0xb9, 0xc3, 0xbf, 0x9c, 0x73, 0xcc, 0x3f, 0x3b, 0xde, 0xd7, 0x4b, + 0x56, 0x2b, 0xfb, 0x19, 0xfb, 0x84, 0x02, 0x2f, 0x8e, 0xf4, 0xcd, 0xd9, 0x37, 0x95, 0xd7, + 0x7d, 0x06, 0xed, 0xbb, 0x7a, 0xaf, 0x2f, 0x58, 0x89, 0x18, 0x50, 0xab, 0xbd, 0xca, 0x3d, + 0x20, 0x39, 0x8c, 0x27, 0x64, 0x56, 0xcb, 0xc4, 0x21, 0x58, 0x40, 0x7d, 0xd0, 0x74, 0xee, + ]; + + #[test] + fn sample_server_initial() { + fixture_init(); + let mut prot = CryptoDxState::test_default(); + + // The spec uses PN=1, but our crypto refuses to skip packet numbers. + // So burn an encryption: + let burn = prot.encrypt(0, &[], &[]).expect("burn OK"); + assert_eq!(burn.len(), prot.expansion()); + + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Initial, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(SERVER_CID), + ); + builder.initial_token(&[]); + builder.pn(1, 2); + builder.encode(SAMPLE_INITIAL_PAYLOAD); + let packet = builder.build(&mut prot).expect("build"); + assert_eq!(packet.as_ref(), SAMPLE_INITIAL); + } + + #[test] + fn decrypt_initial() { + const EXTRA: &[u8] = &[0xce; 33]; + + fixture_init(); + let mut padded = SAMPLE_INITIAL.to_vec(); + padded.extend_from_slice(EXTRA); + let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap(); + assert_eq!(packet.packet_type(), PacketType::Initial); + assert_eq!(&packet.dcid()[..], &[] as &[u8]); + assert_eq!(&packet.scid()[..], SERVER_CID); + assert!(packet.token().is_empty()); + assert_eq!(remainder, EXTRA); + + let decrypted = packet + .decrypt(&mut CryptoStates::test_default(), now()) + .unwrap(); + assert_eq!(decrypted.pn(), 1); + } + + #[test] + fn disallow_long_dcid() { + let mut enc = Encoder::new(); + enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC); + enc.encode_uint(4, Version::default().wire_version()); + enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]); + enc.encode_vec(1, &[]); + enc.encode(&[0xff; 40]); // junk + + assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err()); + } + + #[test] + fn disallow_long_scid() { + let mut enc = Encoder::new(); + enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC); + enc.encode_uint(4, Version::default().wire_version()); + enc.encode_vec(1, &[]); + enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]); + enc.encode(&[0xff; 40]); // junk + + assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err()); + } + + const SAMPLE_SHORT: &[u8] = &[ + 0x40, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0xf4, 0xa8, 0x30, 0x39, 0xc4, 0x7d, + 0x99, 0xe3, 0x94, 0x1c, 0x9b, 0xb9, 0x7a, 0x30, 0x1d, 0xd5, 0x8f, 0xf3, 0xdd, 0xa9, + ]; + const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3]; + + #[test] + fn build_short() { + fixture_init(); + let mut builder = + PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); + builder.pn(0, 1); + builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. + let packet = builder + .build(&mut CryptoDxState::test_default()) + .expect("build"); + assert_eq!(packet.as_ref(), SAMPLE_SHORT); + } + + #[test] + fn scramble_short() { + fixture_init(); + let mut firsts = Vec::new(); + for _ in 0..64 { + let mut builder = + PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); + builder.scramble(true); + builder.pn(0, 1); + firsts.push(builder.as_ref()[0]); + } + let is_set = |bit| move |v| v & bit == bit; + // There should be at least one value with the QUIC bit set: + assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC))); + // ... but not all: + assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC))); + // There should be at least one value with the spin bit set: + assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN))); + // ... but not all: + assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN))); + } + + #[test] + fn decode_short() { + fixture_init(); + let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap(); + assert_eq!(packet.packet_type(), PacketType::Short); + assert!(remainder.is_empty()); + let decrypted = packet + .decrypt(&mut CryptoStates::test_default(), now()) + .unwrap(); + assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD); + } + + /// By telling the decoder that the connection ID is shorter than it really is, we get a decryption error. + #[test] + fn decode_short_bad_cid() { + fixture_init(); + let (packet, remainder) = PublicPacket::decode( + SAMPLE_SHORT, + &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1), + ) + .unwrap(); + assert_eq!(packet.packet_type(), PacketType::Short); + assert!(remainder.is_empty()); + assert!(packet + .decrypt(&mut CryptoStates::test_default(), now()) + .is_err()); + } + + /// Saying that the connection ID is longer causes the initial decode to fail. + #[test] + fn decode_short_long_cid() { + assert!(PublicPacket::decode( + SAMPLE_SHORT, + &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1) + ) + .is_err()); + } + + #[test] + fn build_two() { + fixture_init(); + let mut prot = CryptoDxState::test_default(); + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + Version::default(), + &ConnectionId::from(SERVER_CID), + &ConnectionId::from(CLIENT_CID), + ); + builder.pn(0, 1); + builder.encode(&[0; 3]); + let encoder = builder.build(&mut prot).expect("build"); + assert_eq!(encoder.len(), 45); + let first = encoder.clone(); + + let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID)); + builder.pn(1, 3); + builder.encode(&[0]); // Minimal size (packet number is big enough). + let encoder = builder.build(&mut prot).expect("build"); + assert_eq!( + first.as_ref(), + &encoder.as_ref()[..first.len()], + "the first packet should be a prefix" + ); + assert_eq!(encoder.len(), 45 + 29); + } + + #[test] + fn build_long() { + const EXPECTED: &[u8] = &[ + 0xe4, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x40, 0x14, 0xfb, 0xa9, 0x32, 0x3a, 0xf8, + 0xbb, 0x18, 0x63, 0xc6, 0xbd, 0x78, 0x0e, 0xba, 0x0c, 0x98, 0x65, 0x58, 0xc9, 0x62, + 0x31, + ]; + + fixture_init(); + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(&[][..]), + ); + builder.pn(0, 1); + builder.encode(&[1, 2, 3]); + let packet = builder.build(&mut CryptoDxState::test_default()).unwrap(); + assert_eq!(packet.as_ref(), EXPECTED); + } + + #[test] + fn scramble_long() { + fixture_init(); + let mut found_unset = false; + let mut found_set = false; + for _ in 1..64 { + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Handshake, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(&[][..]), + ); + builder.pn(0, 1); + builder.scramble(true); + if (builder.as_ref()[0] & PACKET_BIT_FIXED_QUIC) == 0 { + found_unset = true; + } else { + found_set = true; + } + } + assert!(found_unset); + assert!(found_set); + } + + #[test] + fn build_abort() { + let mut builder = PacketBuilder::long( + Encoder::new(), + PacketType::Initial, + Version::default(), + &ConnectionId::from(&[][..]), + &ConnectionId::from(SERVER_CID), + ); + assert_ne!(builder.remaining(), 0); + builder.initial_token(&[]); + assert_ne!(builder.remaining(), 0); + builder.pn(1, 2); + assert_ne!(builder.remaining(), 0); + let encoder = builder.abort(); + assert!(encoder.is_empty()); + } + + #[test] + fn build_insufficient_space() { + fixture_init(); + + let mut builder = PacketBuilder::short( + Encoder::with_capacity(100), + true, + &ConnectionId::from(SERVER_CID), + ); + builder.pn(0, 1); + // Pad, but not up to the full capacity. Leave enough space for the + // AEAD expansion and some extra, but not for an entire long header. + builder.set_limit(75); + builder.enable_padding(true); + assert!(builder.pad()); + let encoder = builder.build(&mut CryptoDxState::test_default()).unwrap(); + let encoder_copy = encoder.clone(); + + let builder = PacketBuilder::long( + encoder, + PacketType::Initial, + Version::default(), + &ConnectionId::from(SERVER_CID), + &ConnectionId::from(SERVER_CID), + ); + assert_eq!(builder.remaining(), 0); + assert_eq!(builder.abort(), encoder_copy); + } + + const SAMPLE_RETRY_V2: &[u8] = &[ + 0xcf, 0x70, 0x9a, 0x50, 0xc4, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x1d, 0xc7, 0x11, 0x30, 0xcd, 0x1e, 0xd3, 0x9d, 0x6e, 0xfc, + 0xee, 0x5c, 0x85, 0x80, 0x65, 0x01, + ]; + + const SAMPLE_RETRY_V1: &[u8] = &[ + 0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58, + 0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba, + ]; + + const SAMPLE_RETRY_29: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a, + 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49, + ]; + + const SAMPLE_RETRY_30: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94, + 0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61, + ]; + + const SAMPLE_RETRY_31: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1, + 0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86, + ]; + + const SAMPLE_RETRY_32: &[u8] = &[ + 0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e, + 0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85, + ]; + + const RETRY_TOKEN: &[u8] = b"token"; + + fn build_retry_single(version: Version, sample_retry: &[u8]) { + fixture_init(); + let retry = + PacketBuilder::retry(version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap(); + + let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap(); + assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); + assert!(remainder.is_empty()); + + // The builder adds randomness, which makes expectations hard. + // So only do a full check when that randomness matches up. + if retry[0] == sample_retry[0] { + assert_eq!(&retry, &sample_retry); + } else { + // Otherwise, just check that the header is OK. + assert_eq!( + retry[0] & 0xf0, + 0xc0 | (PacketType::Retry.to_byte(version) << 4) + ); + let header_range = 1..retry.len() - 16; + assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]); + } + } + + #[test] + fn build_retry_v2() { + build_retry_single(Version::Version2, SAMPLE_RETRY_V2); + } + + #[test] + fn build_retry_v1() { + build_retry_single(Version::Version1, SAMPLE_RETRY_V1); + } + + #[test] + fn build_retry_29() { + build_retry_single(Version::Draft29, SAMPLE_RETRY_29); + } + + #[test] + fn build_retry_30() { + build_retry_single(Version::Draft30, SAMPLE_RETRY_30); + } + + #[test] + fn build_retry_31() { + build_retry_single(Version::Draft31, SAMPLE_RETRY_31); + } + + #[test] + fn build_retry_32() { + build_retry_single(Version::Draft32, SAMPLE_RETRY_32); + } + + #[test] + fn build_retry_multiple() { + // Run the build_retry test a few times. + // Odds are approximately 1 in 8 that the full comparison doesn't happen + // for a given version. + for _ in 0..32 { + build_retry_v2(); + build_retry_v1(); + build_retry_29(); + build_retry_30(); + build_retry_31(); + build_retry_32(); + } + } + + fn decode_retry(version: Version, sample_retry: &[u8]) { + fixture_init(); + let (packet, remainder) = + PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap(); + assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); + assert_eq!(Some(version), packet.version()); + assert!(packet.dcid().is_empty()); + assert_eq!(&packet.scid()[..], SERVER_CID); + assert_eq!(packet.token(), RETRY_TOKEN); + assert!(remainder.is_empty()); + } + + #[test] + fn decode_retry_v2() { + decode_retry(Version::Version2, SAMPLE_RETRY_V2); + } + + #[test] + fn decode_retry_v1() { + decode_retry(Version::Version1, SAMPLE_RETRY_V1); + } + + #[test] + fn decode_retry_29() { + decode_retry(Version::Draft29, SAMPLE_RETRY_29); + } + + #[test] + fn decode_retry_30() { + decode_retry(Version::Draft30, SAMPLE_RETRY_30); + } + + #[test] + fn decode_retry_31() { + decode_retry(Version::Draft31, SAMPLE_RETRY_31); + } + + #[test] + fn decode_retry_32() { + decode_retry(Version::Draft32, SAMPLE_RETRY_32); + } + + /// Check some packets that are clearly not valid Retry packets. + #[test] + fn invalid_retry() { + fixture_init(); + let cid_mgr = RandomConnectionIdGenerator::new(5); + let odcid = ConnectionId::from(CLIENT_CID); + + assert!(PublicPacket::decode(&[], &cid_mgr).is_err()); + + let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_V1, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(packet.is_valid_retry(&odcid)); + + let mut damaged_retry = SAMPLE_RETRY_V1.to_vec(); + let last = damaged_retry.len() - 1; + damaged_retry[last] ^= 66; + let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(!packet.is_valid_retry(&odcid)); + + damaged_retry.truncate(last); + let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + assert!(remainder.is_empty()); + assert!(!packet.is_valid_retry(&odcid)); + + // An invalid token should be rejected sooner. + damaged_retry.truncate(last - 4); + assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + + damaged_retry.truncate(last - 1); + assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + } + + const SAMPLE_VN: &[u8] = &[ + 0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08, + 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0x70, 0x9a, 0x50, 0xc4, 0x00, 0x00, 0x00, + 0x01, 0xff, 0x00, 0x00, 0x20, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00, + 0x00, 0x1d, 0x0a, 0x0a, 0x0a, 0x0a, + ]; + + #[test] + fn build_vn() { + fixture_init(); + let mut vn = + PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID, 0x0a0a0a0a, &Version::all()); + // Erase randomness from greasing... + assert_eq!(vn.len(), SAMPLE_VN.len()); + vn[0] &= 0x80; + for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) { + *v &= 0x0f; + } + assert_eq!(&vn, &SAMPLE_VN); + } + + #[test] + fn vn_do_not_repeat_client_grease() { + fixture_init(); + let vn = + PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID, 0x0a0a0a0a, &Version::all()); + assert_ne!(&vn[SAMPLE_VN.len() - 4..], &[0x0a, 0x0a, 0x0a, 0x0a]); + } + + #[test] + fn parse_vn() { + let (packet, remainder) = + PublicPacket::decode(SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(&packet.dcid[..], SERVER_CID); + assert!(packet.scid.is_some()); + assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID); + } + + /// A Version Negotiation packet can have a long connection ID. + #[test] + fn parse_vn_big_cid() { + const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1]; + const BIG_SCID: &[u8] = &[0xee; 255]; + + let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]); + enc.encode_vec(1, BIG_DCID); + enc.encode_vec(1, BIG_SCID); + enc.encode_uint(4, 0x1a2a_3a4a_u64); + enc.encode_uint(4, Version::default().wire_version()); + enc.encode_uint(4, 0x5a6a_7a8a_u64); + + let (packet, remainder) = + PublicPacket::decode(enc.as_ref(), &EmptyConnectionIdGenerator::default()).unwrap(); + assert!(remainder.is_empty()); + assert_eq!(&packet.dcid[..], BIG_DCID); + assert!(packet.scid.is_some()); + assert_eq!(&packet.scid.unwrap()[..], BIG_SCID); + } + + #[test] + fn decode_pn() { + // When the expected value is low, the value doesn't go negative. + assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff); + assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0); + assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100); + assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2); + assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff); + assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe); + + // This is invalid by spec, as we are expected to check for overflow around 2^62-1, + // but we don't need to worry about overflow + // and hitting this is basically impossible in practice. + assert_eq!( + PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4), + 0x4000_0000_0000_0002 + ); + } + + #[test] + fn chacha20_sample() { + const PACKET: &[u8] = &[ + 0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57, + 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb, + ]; + fixture_init(); + let (packet, slice) = + PublicPacket::decode(PACKET, &EmptyConnectionIdGenerator::default()).unwrap(); + assert!(slice.is_empty()); + let decrypted = packet + .decrypt(&mut CryptoStates::test_chacha(), now()) + .unwrap(); + assert_eq!(decrypted.packet_type(), PacketType::Short); + assert_eq!(decrypted.pn(), 654_360_564); + assert_eq!(&decrypted[..], &[0x01]); + } +} |