diff options
Diffstat (limited to 'third_party/rust/neqo-common/src')
-rw-r--r-- | third_party/rust/neqo-common/src/codec.rs | 847 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/datagram.rs | 91 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/event.rs | 52 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/header.rs | 48 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/hrtime.rs | 485 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/incrdecoder.rs | 275 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/lib.rs | 109 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/log.rs | 104 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/qlog.rs | 188 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/timer.rs | 396 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/tos.rs | 290 |
11 files changed, 2885 insertions, 0 deletions
diff --git a/third_party/rust/neqo-common/src/codec.rs b/third_party/rust/neqo-common/src/codec.rs new file mode 100644 index 0000000000..57ff13f39f --- /dev/null +++ b/third_party/rust/neqo-common/src/codec.rs @@ -0,0 +1,847 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{convert::TryFrom, fmt::Debug}; + +use crate::hex_with_len; + +/// Decoder is a view into a byte array that has a read offset. Use it for parsing. +pub struct Decoder<'a> { + buf: &'a [u8], + offset: usize, +} + +impl<'a> Decoder<'a> { + /// Make a new view of the provided slice. + #[must_use] + pub fn new(buf: &[u8]) -> Decoder { + Decoder { buf, offset: 0 } + } + + /// Get the number of bytes remaining until the end. + #[must_use] + pub fn remaining(&self) -> usize { + self.buf.len() - self.offset + } + + /// The number of bytes from the underlying slice that have been decoded. + #[must_use] + pub fn offset(&self) -> usize { + self.offset + } + + /// Skip n bytes. + /// + /// # Panics + /// + /// If the remaining quantity is less than `n`. + pub fn skip(&mut self, n: usize) { + assert!(self.remaining() >= n, "insufficient data"); + self.offset += n; + } + + /// Skip helper that panics if `n` is `None` or not able to fit in `usize`. + fn skip_inner(&mut self, n: Option<u64>) { + self.skip(usize::try_from(n.expect("invalid length")).unwrap()); + } + + /// Skip a vector. Panics if there isn't enough space. + /// Only use this for tests because we panic rather than reporting a result. + pub fn skip_vec(&mut self, n: usize) { + let len = self.decode_uint(n); + self.skip_inner(len); + } + + /// Skip a variable length vector. Panics if there isn't enough space. + /// Only use this for tests because we panic rather than reporting a result. + pub fn skip_vvec(&mut self) { + let len = self.decode_varint(); + self.skip_inner(len); + } + + /// Decodes (reads) a single byte. + pub fn decode_byte(&mut self) -> Option<u8> { + if self.remaining() < 1 { + return None; + } + let b = self.buf[self.offset]; + self.offset += 1; + Some(b) + } + + /// Provides the next byte without moving the read position. + pub fn peek_byte(&mut self) -> Option<u8> { + if self.remaining() < 1 { + None + } else { + Some(self.buf[self.offset]) + } + } + + /// Decodes arbitrary data. + pub fn decode(&mut self, n: usize) -> Option<&'a [u8]> { + if self.remaining() < n { + return None; + } + let res = &self.buf[self.offset..self.offset + n]; + self.offset += n; + Some(res) + } + + /// Decodes an unsigned integer of length 1..=8. + /// + /// # Panics + /// + /// This panics if `n` is not in the range `1..=8`. + pub fn decode_uint(&mut self, n: usize) -> Option<u64> { + assert!(n > 0 && n <= 8); + if self.remaining() < n { + return None; + } + let mut v = 0_u64; + for i in 0..n { + let b = self.buf[self.offset + i]; + v = v << 8 | u64::from(b); + } + self.offset += n; + Some(v) + } + + /// Decodes a QUIC varint. + pub fn decode_varint(&mut self) -> Option<u64> { + let Some(b1) = self.decode_byte() else { + return None; + }; + match b1 >> 6 { + 0 => Some(u64::from(b1 & 0x3f)), + 1 => Some((u64::from(b1 & 0x3f) << 8) | self.decode_uint(1)?), + 2 => Some((u64::from(b1 & 0x3f) << 24) | self.decode_uint(3)?), + 3 => Some((u64::from(b1 & 0x3f) << 56) | self.decode_uint(7)?), + _ => unreachable!(), + } + } + + /// Decodes the rest of the buffer. Infallible. + pub fn decode_remainder(&mut self) -> &'a [u8] { + let res = &self.buf[self.offset..]; + self.offset = self.buf.len(); + res + } + + fn decode_checked(&mut self, n: Option<u64>) -> Option<&'a [u8]> { + if let Ok(l) = usize::try_from(n?) { + self.decode(l) + } else { + // sizeof(usize) < sizeof(u64) and the value is greater than + // usize can hold. Throw away the rest of the input. + self.offset = self.buf.len(); + None + } + } + + /// Decodes a TLS-style length-prefixed buffer. + pub fn decode_vec(&mut self, n: usize) -> Option<&'a [u8]> { + let len = self.decode_uint(n); + self.decode_checked(len) + } + + /// Decodes a QUIC varint-length-prefixed buffer. + pub fn decode_vvec(&mut self) -> Option<&'a [u8]> { + let len = self.decode_varint(); + self.decode_checked(len) + } +} + +// Implement `AsRef` for `Decoder` so that values can be examined without +// moving the cursor. +impl<'a> AsRef<[u8]> for Decoder<'a> { + #[must_use] + fn as_ref(&self) -> &'a [u8] { + &self.buf[self.offset..] + } +} + +impl<'a> Debug for Decoder<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(&hex_with_len(self.as_ref())) + } +} + +impl<'a> From<&'a [u8]> for Decoder<'a> { + #[must_use] + fn from(buf: &'a [u8]) -> Decoder<'a> { + Decoder::new(buf) + } +} + +impl<'a, T> From<&'a T> for Decoder<'a> +where + T: AsRef<[u8]>, +{ + #[must_use] + fn from(buf: &'a T) -> Decoder<'a> { + Decoder::new(buf.as_ref()) + } +} + +impl<'a, 'b> PartialEq<Decoder<'b>> for Decoder<'a> { + #[must_use] + fn eq(&self, other: &Decoder<'b>) -> bool { + self.buf == other.buf + } +} + +/// Encoder is good for building data structures. +#[derive(Clone, Default, PartialEq, Eq)] +pub struct Encoder { + buf: Vec<u8>, +} + +impl Encoder { + /// Static helper function for previewing the results of encoding without doing it. + /// + /// # Panics + /// + /// When `v` is too large. + #[must_use] + pub const fn varint_len(v: u64) -> usize { + match () { + () if v < (1 << 6) => 1, + () if v < (1 << 14) => 2, + () if v < (1 << 30) => 4, + () if v < (1 << 62) => 8, + () => panic!("Varint value too large"), + } + } + + /// Static helper to determine how long a varint-prefixed array encodes to. + /// + /// # Panics + /// + /// When `len` doesn't fit in a `u64`. + #[must_use] + pub fn vvec_len(len: usize) -> usize { + Self::varint_len(u64::try_from(len).unwrap()) + len + } + + /// Default construction of an empty buffer. + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Construction of a buffer with a predetermined capacity. + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buf: Vec::with_capacity(capacity), + } + } + + /// Get the capacity of the underlying buffer: the number of bytes that can be + /// written without causing an allocation to occur. + #[must_use] + pub fn capacity(&self) -> usize { + self.buf.capacity() + } + + /// Get the length of the underlying buffer: the number of bytes that have + /// been written to the buffer. + #[must_use] + pub fn len(&self) -> usize { + self.buf.len() + } + + /// Returns true if the encoder buffer contains no elements. + #[must_use] + pub fn is_empty(&self) -> bool { + self.buf.is_empty() + } + + /// Create a view of the current contents of the buffer. + /// Note: for a view of a slice, use `Decoder::new(&enc[s..e])` + #[must_use] + pub fn as_decoder(&self) -> Decoder { + Decoder::new(self.as_ref()) + } + + /// Don't use this except in testing. + /// + /// # Panics + /// + /// When `s` contains non-hex values or an odd number of values. + #[must_use] + pub fn from_hex(s: impl AsRef<str>) -> Self { + let s = s.as_ref(); + assert_eq!(s.len() % 2, 0, "Needs to be even length"); + + let cap = s.len() / 2; + let mut enc = Self::with_capacity(cap); + + for i in 0..cap { + let v = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16).unwrap(); + enc.encode_byte(v); + } + enc + } + + /// Generic encode routine for arbitrary data. + pub fn encode(&mut self, data: &[u8]) -> &mut Self { + self.buf.extend_from_slice(data.as_ref()); + self + } + + /// Encode a single byte. + pub fn encode_byte(&mut self, data: u8) -> &mut Self { + self.buf.push(data); + self + } + + /// Encode an integer of any size up to u64. + /// + /// # Panics + /// + /// When `n` is outside the range `1..=8`. + #[allow(clippy::cast_possible_truncation)] + pub fn encode_uint<T: Into<u64>>(&mut self, n: usize, v: T) -> &mut Self { + let v = v.into(); + assert!(n > 0 && n <= 8); + for i in 0..n { + self.encode_byte(((v >> (8 * (n - i - 1))) & 0xff) as u8); + } + self + } + + /// Encode a QUIC varint. + /// + /// # Panics + /// + /// When `v >= 1<<62`. + pub fn encode_varint<T: Into<u64>>(&mut self, v: T) -> &mut Self { + let v = v.into(); + match () { + () if v < (1 << 6) => self.encode_uint(1, v), + () if v < (1 << 14) => self.encode_uint(2, v | (1 << 14)), + () if v < (1 << 30) => self.encode_uint(4, v | (2 << 30)), + () if v < (1 << 62) => self.encode_uint(8, v | (3 << 62)), + () => panic!("Varint value too large"), + }; + self + } + + /// Encode a vector in TLS style. + /// + /// # Panics + /// + /// When `v` is longer than 2^64. + pub fn encode_vec(&mut self, n: usize, v: &[u8]) -> &mut Self { + self.encode_uint(n, u64::try_from(v.as_ref().len()).unwrap()) + .encode(v) + } + + /// Encode a vector in TLS style using a closure for the contents. + /// + /// # Panics + /// + /// When `f()` returns a length larger than `2^8n`. + #[allow(clippy::cast_possible_truncation)] + pub fn encode_vec_with<F: FnOnce(&mut Self)>(&mut self, n: usize, f: F) -> &mut Self { + let start = self.buf.len(); + self.buf.resize(self.buf.len() + n, 0); + f(self); + let len = self.buf.len() - start - n; + assert!(len < (1 << (n * 8))); + for i in 0..n { + self.buf[start + i] = ((len >> (8 * (n - i - 1))) & 0xff) as u8; + } + self + } + + /// Encode a vector with a varint length. + /// + /// # Panics + /// + /// When `v` is longer than 2^64. + pub fn encode_vvec(&mut self, v: &[u8]) -> &mut Self { + self.encode_varint(u64::try_from(v.as_ref().len()).unwrap()) + .encode(v) + } + + /// Encode a vector with a varint length using a closure. + /// + /// # Panics + /// + /// When `f()` writes more than 2^62 bytes. + #[allow(clippy::cast_possible_truncation)] + pub fn encode_vvec_with<F: FnOnce(&mut Self)>(&mut self, f: F) -> &mut Self { + let start = self.buf.len(); + // Optimize for short buffers, reserve a single byte for the length. + self.buf.resize(self.buf.len() + 1, 0); + f(self); + let len = self.buf.len() - start - 1; + + // Now to insert a varint for `len` before the encoded block. + // + // We now have one zero byte at `start`, followed by `len` encoded bytes: + // | 0 | ... encoded ... | + // We are going to encode a varint by putting the low bytes in that spare byte. + // Any additional bytes for the varint are put after the encoded blob: + // | low | ... encoded ... | varint high | + // Then we will rotate that entire piece right, by however many bytes we add: + // | varint high | low | ... encoded ... | + // As long as encoding more than 63 bytes is rare, this won't cost much relative + // to the convenience of being able to use this function. + + let v = u64::try_from(len).expect("encoded value fits in a u64"); + // The lower order byte fits before the inserted block of bytes. + self.buf[start] = (v & 0xff) as u8; + let (count, bits) = match () { + // Great. The byte we have is enough. + () if v < (1 << 6) => return self, + () if v < (1 << 14) => (1, 1 << 6), + () if v < (1 << 30) => (3, 2 << 22), + () if v < (1 << 62) => (7, 3 << 54), + () => panic!("Varint value too large"), + }; + // Now, we need to encode the high bits after the main block, ... + self.encode_uint(count, (v >> 8) | bits); + // ..., then rotate the entire thing right by the same amount. + self.buf[start..].rotate_right(count); + self + } + + /// Truncate the encoder to the given size. + pub fn truncate(&mut self, len: usize) { + self.buf.truncate(len); + } + + /// Pad the buffer to `len` with bytes set to `v`. + pub fn pad_to(&mut self, len: usize, v: u8) { + if len > self.buf.len() { + self.buf.resize(len, v); + } + } +} + +impl Debug for Encoder { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(&hex_with_len(self)) + } +} + +impl AsRef<[u8]> for Encoder { + fn as_ref(&self) -> &[u8] { + self.buf.as_ref() + } +} + +impl AsMut<[u8]> for Encoder { + fn as_mut(&mut self) -> &mut [u8] { + self.buf.as_mut() + } +} + +impl<'a> From<Decoder<'a>> for Encoder { + #[must_use] + fn from(dec: Decoder<'a>) -> Self { + Self::from(&dec.buf[dec.offset..]) + } +} + +impl From<&[u8]> for Encoder { + #[must_use] + fn from(buf: &[u8]) -> Self { + Self { + buf: Vec::from(buf), + } + } +} + +impl From<Encoder> for Vec<u8> { + #[must_use] + fn from(buf: Encoder) -> Self { + buf.buf + } +} + +#[cfg(test)] +mod tests { + use super::{Decoder, Encoder}; + + #[test] + fn decode() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode(2).unwrap(), &[0x01, 0x23]); + assert!(dec.decode(2).is_none()); + } + + #[test] + fn decode_byte() { + let enc = Encoder::from_hex("0123"); + let mut dec = enc.as_decoder(); + + assert_eq!(dec.decode_byte().unwrap(), 0x01); + assert_eq!(dec.decode_byte().unwrap(), 0x23); + assert!(dec.decode_byte().is_none()); + } + + #[test] + fn decode_byte_short() { + let enc = Encoder::from_hex(""); + let mut dec = enc.as_decoder(); + assert!(dec.decode_byte().is_none()); + } + + #[test] + fn decode_remainder() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_remainder(), &[0x01, 0x23, 0x45]); + assert!(dec.decode(2).is_none()); + + let mut dec = Decoder::from(&[]); + assert_eq!(dec.decode_remainder().len(), 0); + } + + #[test] + fn decode_vec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vec(1).expect("read one octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + + let enc = Encoder::from_hex("00012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vec(2).expect("read two octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + } + + #[test] + fn decode_vec_short() { + // The length is too short. + let enc = Encoder::from_hex("02"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vec(2).is_none()); + + // The body is too short. + let enc = Encoder::from_hex("0200"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vec(1).is_none()); + } + + #[test] + fn decode_vvec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vvec().expect("read one octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + + let enc = Encoder::from_hex("40012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vvec().expect("read two octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + } + + #[test] + fn decode_vvec_short() { + // The length field is too short. + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vvec().is_none()); + + let enc = Encoder::from_hex("405500"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vvec().is_none()); + } + + #[test] + fn skip() { + let enc = Encoder::from_hex("ffff"); + let mut dec = enc.as_decoder(); + dec.skip(1); + assert_eq!(dec.remaining(), 1); + } + + #[test] + #[should_panic(expected = "insufficient data")] + fn skip_too_much() { + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + dec.skip(2); + } + + #[test] + fn skip_vec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + dec.skip_vec(1); + assert_eq!(dec.remaining(), 1); + } + + #[test] + #[should_panic(expected = "insufficient data")] + fn skip_vec_too_much() { + let enc = Encoder::from_hex("ff1234"); + let mut dec = enc.as_decoder(); + dec.skip_vec(1); + } + + #[test] + #[should_panic(expected = "invalid length")] + fn skip_vec_short_length() { + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + dec.skip_vec(4); + } + #[test] + fn skip_vvec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + dec.skip_vvec(); + assert_eq!(dec.remaining(), 1); + } + + #[test] + #[should_panic(expected = "insufficient data")] + fn skip_vvec_too_much() { + let enc = Encoder::from_hex("0f1234"); + let mut dec = enc.as_decoder(); + dec.skip_vvec(); + } + + #[test] + #[should_panic(expected = "invalid length")] + fn skip_vvec_short_length() { + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + dec.skip_vvec(); + } + + #[test] + fn encoded_lengths() { + assert_eq!(Encoder::varint_len(0), 1); + assert_eq!(Encoder::varint_len(0x3f), 1); + assert_eq!(Encoder::varint_len(0x40), 2); + assert_eq!(Encoder::varint_len(0x3fff), 2); + assert_eq!(Encoder::varint_len(0x4000), 4); + assert_eq!(Encoder::varint_len(0x3fff_ffff), 4); + assert_eq!(Encoder::varint_len(0x4000_0000), 8); + } + + #[test] + #[should_panic(expected = "Varint value too large")] + fn encoded_length_oob() { + _ = Encoder::varint_len(1 << 62); + } + + #[test] + fn encoded_vvec_lengths() { + assert_eq!(Encoder::vvec_len(0), 1); + assert_eq!(Encoder::vvec_len(0x3f), 0x40); + assert_eq!(Encoder::vvec_len(0x40), 0x42); + assert_eq!(Encoder::vvec_len(0x3fff), 0x4001); + assert_eq!(Encoder::vvec_len(0x4000), 0x4004); + assert_eq!(Encoder::vvec_len(0x3fff_ffff), 0x4000_0003); + assert_eq!(Encoder::vvec_len(0x4000_0000), 0x4000_0008); + } + + #[test] + #[should_panic(expected = "Varint value too large")] + fn encoded_vvec_length_oob() { + _ = Encoder::vvec_len(1 << 62); + } + + #[test] + fn encode_byte() { + let mut enc = Encoder::default(); + + enc.encode_byte(1); + assert_eq!(enc, Encoder::from_hex("01")); + + enc.encode_byte(0xfe); + assert_eq!(enc, Encoder::from_hex("01fe")); + } + + #[test] + fn encode() { + let mut enc = Encoder::default(); + enc.encode(&[1, 2, 3]); + assert_eq!(enc, Encoder::from_hex("010203")); + } + + #[test] + fn encode_uint() { + let mut enc = Encoder::default(); + enc.encode_uint(2, 10_u8); // 000a + enc.encode_uint(1, 257_u16); // 01 + enc.encode_uint(3, 0xff_ffff_u32); // ffffff + enc.encode_uint(8, 0xfedc_ba98_7654_3210_u64); + assert_eq!(enc, Encoder::from_hex("000a01fffffffedcba9876543210")); + } + + #[test] + fn builder_from_slice() { + let slice = &[1, 2, 3]; + let enc = Encoder::from(&slice[..]); + assert_eq!(enc, Encoder::from_hex("010203")); + } + + #[test] + fn builder_inas_decoder() { + let enc = Encoder::from_hex("010203"); + let buf = &[1, 2, 3]; + assert_eq!(enc.as_decoder(), Decoder::new(buf)); + } + + struct UintTestCase { + v: u64, + b: String, + } + + macro_rules! uint_tc { + [$( $v:expr => $b:expr ),+ $(,)?] => { + vec![ $( UintTestCase { v: $v, b: String::from($b) } ),+] + }; + } + + #[test] + fn varint_encode_decode() { + let cases = uint_tc![ + 0 => "00", + 1 => "01", + 63 => "3f", + 64 => "4040", + 16383 => "7fff", + 16384 => "80004000", + (1 << 30) - 1 => "bfffffff", + 1 << 30 => "c000000040000000", + (1 << 62) - 1 => "ffffffffffffffff", + ]; + + for c in cases { + assert_eq!(Encoder::varint_len(c.v), c.b.len() / 2); + + let mut enc = Encoder::default(); + enc.encode_varint(c.v); + let encoded = Encoder::from_hex(&c.b); + assert_eq!(enc, encoded); + + let mut dec = encoded.as_decoder(); + let v = dec.decode_varint().expect("should decode"); + assert_eq!(dec.remaining(), 0); + assert_eq!(v, c.v); + } + } + + #[test] + fn varint_decode_long_zero() { + for c in &["4000", "80000000", "c000000000000000"] { + let encoded = Encoder::from_hex(c); + let mut dec = encoded.as_decoder(); + let v = dec.decode_varint().expect("should decode"); + assert_eq!(dec.remaining(), 0); + assert_eq!(v, 0); + } + } + + #[test] + fn varint_decode_short() { + for c in &["40", "800000", "c0000000000000"] { + let encoded = Encoder::from_hex(c); + let mut dec = encoded.as_decoder(); + assert!(dec.decode_varint().is_none()); + } + } + + #[test] + fn encode_vec() { + let mut enc = Encoder::default(); + enc.encode_vec(2, &[1, 2, 0x34]); + assert_eq!(enc, Encoder::from_hex("0003010234")); + } + + #[test] + fn encode_vec_with() { + let mut enc = Encoder::default(); + enc.encode_vec_with(2, |enc_inner| { + enc_inner.encode(Encoder::from_hex("02").as_ref()); + }); + assert_eq!(enc, Encoder::from_hex("000102")); + } + + #[test] + #[should_panic(expected = "assertion failed")] + fn encode_vec_with_overflow() { + let mut enc = Encoder::default(); + enc.encode_vec_with(1, |enc_inner| { + enc_inner.encode(&[0xb0; 256]); + }); + } + + #[test] + fn encode_vvec() { + let mut enc = Encoder::default(); + enc.encode_vvec(&[1, 2, 0x34]); + assert_eq!(enc, Encoder::from_hex("03010234")); + } + + #[test] + fn encode_vvec_with() { + let mut enc = Encoder::default(); + enc.encode_vvec_with(|enc_inner| { + enc_inner.encode(Encoder::from_hex("02").as_ref()); + }); + assert_eq!(enc, Encoder::from_hex("0102")); + } + + #[test] + fn encode_vvec_with_longer() { + let mut enc = Encoder::default(); + enc.encode_vvec_with(|enc_inner| { + enc_inner.encode(&[0xa5; 65]); + }); + let v: Vec<u8> = enc.into(); + assert_eq!(&v[..3], &[0x40, 0x41, 0xa5]); + } + + // Test that Deref to &[u8] works for Encoder. + #[test] + fn encode_builder() { + let mut enc = Encoder::from_hex("ff"); + let enc2 = Encoder::from_hex("010234"); + enc.encode(enc2.as_ref()); + assert_eq!(enc, Encoder::from_hex("ff010234")); + } + + // Test that Deref to &[u8] works for Decoder. + #[test] + fn encode_view() { + let mut enc = Encoder::from_hex("ff"); + let enc2 = Encoder::from_hex("010234"); + let v = enc2.as_decoder(); + enc.encode(v.as_ref()); + assert_eq!(enc, Encoder::from_hex("ff010234")); + } + + #[test] + fn encode_mutate() { + let mut enc = Encoder::from_hex("010234"); + enc.as_mut()[0] = 0xff; + assert_eq!(enc, Encoder::from_hex("ff0234")); + } + + #[test] + fn pad() { + let mut enc = Encoder::from_hex("010234"); + enc.pad_to(5, 0); + assert_eq!(enc, Encoder::from_hex("0102340000")); + enc.pad_to(4, 0); + assert_eq!(enc, Encoder::from_hex("0102340000")); + enc.pad_to(7, 0xc2); + assert_eq!(enc, Encoder::from_hex("0102340000c2c2")); + } +} diff --git a/third_party/rust/neqo-common/src/datagram.rs b/third_party/rust/neqo-common/src/datagram.rs new file mode 100644 index 0000000000..1729c8ed8d --- /dev/null +++ b/third_party/rust/neqo-common/src/datagram.rs @@ -0,0 +1,91 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{net::SocketAddr, ops::Deref}; + +use crate::{hex_with_len, IpTos}; + +#[derive(Clone, PartialEq, Eq)] +pub struct Datagram { + src: SocketAddr, + dst: SocketAddr, + tos: IpTos, + ttl: Option<u8>, + d: Vec<u8>, +} + +impl Datagram { + pub fn new<V: Into<Vec<u8>>>( + src: SocketAddr, + dst: SocketAddr, + tos: IpTos, + ttl: Option<u8>, + d: V, + ) -> Self { + Self { + src, + dst, + tos, + ttl, + d: d.into(), + } + } + + #[must_use] + pub fn source(&self) -> SocketAddr { + self.src + } + + #[must_use] + pub fn destination(&self) -> SocketAddr { + self.dst + } + + #[must_use] + pub fn tos(&self) -> IpTos { + self.tos + } + + #[must_use] + pub fn ttl(&self) -> Option<u8> { + self.ttl + } +} + +impl Deref for Datagram { + type Target = Vec<u8>; + #[must_use] + fn deref(&self) -> &Self::Target { + &self.d + } +} + +impl std::fmt::Debug for Datagram { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Datagram {:?} TTL {:?} {:?}->{:?}: {}", + self.tos, + self.ttl, + self.src, + self.dst, + hex_with_len(&self.d) + ) + } +} + +#[cfg(test)] +use test_fixture::datagram; + +#[test] +fn fmt_datagram() { + let d = datagram([0; 1].to_vec()); + assert_eq!( + format!("{d:?}"), + "Datagram IpTos(Cs0, NotEct) TTL Some(128) [fe80::1]:443->[fe80::1]:443: [1]: 00" + .to_string() + ); +} diff --git a/third_party/rust/neqo-common/src/event.rs b/third_party/rust/neqo-common/src/event.rs new file mode 100644 index 0000000000..26052b7571 --- /dev/null +++ b/third_party/rust/neqo-common/src/event.rs @@ -0,0 +1,52 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{iter::Iterator, marker::PhantomData}; + +/// An event provider is able to generate a stream of events. +pub trait Provider { + type Event; + + /// Get the next event. + #[must_use] + fn next_event(&mut self) -> Option<Self::Event>; + + /// Determine whether there are pending events. + #[must_use] + fn has_events(&self) -> bool; + + /// Construct an iterator that produces all events. + fn events(&'_ mut self) -> Iter<'_, Self, Self::Event> { + Iter::new(self) + } +} + +pub struct Iter<'a, P, E> +where + P: ?Sized, +{ + p: &'a mut P, + _e: PhantomData<E>, +} + +impl<'a, P, E> Iter<'a, P, E> +where + P: Provider<Event = E> + ?Sized, +{ + fn new(p: &'a mut P) -> Self { + Self { p, _e: PhantomData } + } +} + +impl<'a, P, E> Iterator for Iter<'a, P, E> +where + P: Provider<Event = E>, +{ + type Item = E; + fn next(&mut self) -> Option<Self::Item> { + self.p.next_event() + } +} diff --git a/third_party/rust/neqo-common/src/header.rs b/third_party/rust/neqo-common/src/header.rs new file mode 100644 index 0000000000..112fcf0057 --- /dev/null +++ b/third_party/rust/neqo-common/src/header.rs @@ -0,0 +1,48 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[derive(Debug, PartialEq, PartialOrd, Eq, Ord, Clone)] +pub struct Header { + name: String, + value: String, +} + +impl Header { + pub fn new<N, V>(name: N, value: V) -> Self + where + N: Into<String> + ?Sized, + V: Into<String> + ?Sized, + { + Self { + name: name.into(), + value: value.into(), + } + } + + #[must_use] + pub fn is_allowed_for_response(&self) -> bool { + !matches!( + self.name.as_str(), + "connection" + | "host" + | "keep-alive" + | "proxy-connection" + | "te" + | "transfer-encoding" + | "upgrade" + ) + } + + #[must_use] + pub fn name(&self) -> &str { + &self.name + } + + #[must_use] + pub fn value(&self) -> &str { + &self.value + } +} diff --git a/third_party/rust/neqo-common/src/hrtime.rs b/third_party/rust/neqo-common/src/hrtime.rs new file mode 100644 index 0000000000..62d2567d42 --- /dev/null +++ b/third_party/rust/neqo-common/src/hrtime.rs @@ -0,0 +1,485 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + convert::TryFrom, + rc::{Rc, Weak}, + time::Duration, +}; + +#[cfg(windows)] +use winapi::shared::minwindef::UINT; +#[cfg(windows)] +use winapi::um::timeapi::{timeBeginPeriod, timeEndPeriod}; + +/// A quantized `Duration`. This currently just produces 16 discrete values +/// corresponding to whole milliseconds. Future implementations might choose +/// a different allocation, such as a logarithmic scale. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +struct Period(u8); + +impl Period { + const MAX: Period = Period(16); + const MIN: Period = Period(1); + + #[cfg(windows)] + fn as_uint(self) -> UINT { + UINT::from(self.0) + } + + #[cfg(target_os = "macos")] + fn scaled(self, scale: f64) -> f64 { + scale * f64::from(self.0) + } +} + +impl From<Duration> for Period { + fn from(p: Duration) -> Self { + let rounded = u8::try_from(p.as_millis()).unwrap_or(Self::MAX.0); + Self(rounded.clamp(Self::MIN.0, Self::MAX.0)) + } +} + +/// This counts instances of `Period`, except those of `Period::MAX`. +#[derive(Default)] +struct PeriodSet { + counts: [usize; (Period::MAX.0 - Period::MIN.0) as usize], +} + +impl PeriodSet { + fn idx(&mut self, p: Period) -> &mut usize { + debug_assert!(p >= Period::MIN); + &mut self.counts[usize::from(p.0 - Period::MIN.0)] + } + + fn add(&mut self, p: Period) { + if p != Period::MAX { + *self.idx(p) += 1; + } + } + + fn remove(&mut self, p: Period) { + if p != Period::MAX { + debug_assert_ne!(*self.idx(p), 0); + *self.idx(p) -= 1; + } + } + + fn min(&self) -> Option<Period> { + for (i, v) in self.counts.iter().enumerate() { + if *v > 0 { + return Some(Period(u8::try_from(i).unwrap() + Period::MIN.0)); + } + } + None + } +} + +#[cfg(target_os = "macos")] +#[allow(non_camel_case_types)] +mod mac { + use std::{mem::size_of, ptr::addr_of_mut}; + + // These are manually extracted from the many bindings generated + // by bindgen when provided with the simple header: + // #include <mach/mach_init.h> + // #include <mach/mach_time.h> + // #include <mach/thread_policy.h> + // #include <pthread.h> + + type __darwin_natural_t = ::std::os::raw::c_uint; + type __darwin_mach_port_name_t = __darwin_natural_t; + type __darwin_mach_port_t = __darwin_mach_port_name_t; + type mach_port_t = __darwin_mach_port_t; + type thread_t = mach_port_t; + type natural_t = __darwin_natural_t; + type thread_policy_flavor_t = natural_t; + type integer_t = ::std::os::raw::c_int; + type thread_policy_t = *mut integer_t; + type mach_msg_type_number_t = natural_t; + type boolean_t = ::std::os::raw::c_uint; + type kern_return_t = ::std::os::raw::c_int; + + #[repr(C)] + #[derive(Debug, Copy, Clone, Default)] + struct mach_timebase_info { + numer: u32, + denom: u32, + } + type mach_timebase_info_t = *mut mach_timebase_info; + type mach_timebase_info_data_t = mach_timebase_info; + extern "C" { + fn mach_timebase_info(info: mach_timebase_info_t) -> kern_return_t; + } + + #[repr(C)] + #[derive(Debug, Copy, Clone, Default)] + pub struct thread_time_constraint_policy { + period: u32, + computation: u32, + constraint: u32, + preemptible: boolean_t, + } + + const THREAD_TIME_CONSTRAINT_POLICY: thread_policy_flavor_t = 2; + #[allow(clippy::cast_possible_truncation)] + const THREAD_TIME_CONSTRAINT_POLICY_COUNT: mach_msg_type_number_t = + (size_of::<thread_time_constraint_policy>() / size_of::<integer_t>()) + as mach_msg_type_number_t; + + // These function definitions are taken from a comment in <thread_policy.h>. + // Why they are inaccessible is unknown, but they work as declared. + extern "C" { + fn thread_policy_set( + thread: thread_t, + flavor: thread_policy_flavor_t, + policy_info: thread_policy_t, + count: mach_msg_type_number_t, + ) -> kern_return_t; + fn thread_policy_get( + thread: thread_t, + flavor: thread_policy_flavor_t, + policy_info: thread_policy_t, + count: *mut mach_msg_type_number_t, + get_default: *mut boolean_t, + ) -> kern_return_t; + } + + enum _opaque_pthread_t {} // An opaque type is fine here. + type __darwin_pthread_t = *mut _opaque_pthread_t; + type pthread_t = __darwin_pthread_t; + + extern "C" { + fn pthread_self() -> pthread_t; + fn pthread_mach_thread_np(thread: pthread_t) -> mach_port_t; + } + + /// Set a thread time policy. + pub fn set_thread_policy(mut policy: thread_time_constraint_policy) { + _ = unsafe { + thread_policy_set( + pthread_mach_thread_np(pthread_self()), + THREAD_TIME_CONSTRAINT_POLICY, + addr_of_mut!(policy).cast(), // horror! + THREAD_TIME_CONSTRAINT_POLICY_COUNT, + ) + }; + } + + pub fn get_scale() -> f64 { + const NANOS_PER_MSEC: f64 = 1_000_000.0; + let mut timebase_info = mach_timebase_info_data_t::default(); + unsafe { + mach_timebase_info(&mut timebase_info); + } + f64::from(timebase_info.denom) * NANOS_PER_MSEC / f64::from(timebase_info.numer) + } + + /// Create a realtime policy and set it. + pub fn set_realtime(base: f64) { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let policy = thread_time_constraint_policy { + period: base as u32, // Base interval + computation: (base * 0.5) as u32, + constraint: (base * 1.0) as u32, + preemptible: 1, + }; + set_thread_policy(policy); + } + + /// Get the default policy. + pub fn get_default_policy() -> thread_time_constraint_policy { + let mut policy = thread_time_constraint_policy::default(); + let mut count = THREAD_TIME_CONSTRAINT_POLICY_COUNT; + let mut get_default = 0; + _ = unsafe { + thread_policy_get( + pthread_mach_thread_np(pthread_self()), + THREAD_TIME_CONSTRAINT_POLICY, + addr_of_mut!(policy).cast(), // horror! + &mut count, + &mut get_default, + ) + }; + policy + } +} + +/// A handle for a high-resolution timer of a specific period. +pub struct Handle { + hrt: Rc<RefCell<Time>>, + active: Period, + hysteresis: [Period; Self::HISTORY], + hysteresis_index: usize, +} + +impl Handle { + const HISTORY: usize = 8; + + fn new(hrt: Rc<RefCell<Time>>, active: Period) -> Self { + Self { + hrt, + active, + hysteresis: [Period::MAX; Self::HISTORY], + hysteresis_index: 0, + } + } + + /// Update shortcut. Equivalent to dropping the current reference and + /// calling `HrTime::get` again with the new period, except that this applies + /// a little hysteresis that smoothes out fluctuations. + pub fn update(&mut self, period: Duration) { + self.hysteresis[self.hysteresis_index] = Period::from(period); + self.hysteresis_index += 1; + self.hysteresis_index %= self.hysteresis.len(); + + let mut first = Period::MAX; + let mut second = Period::MAX; + for i in &self.hysteresis { + if *i < first { + second = first; + first = *i; + } else if *i < second { + second = *i; + } + } + + if second != self.active { + let mut b = self.hrt.borrow_mut(); + b.periods.remove(self.active); + self.active = second; + b.periods.add(self.active); + b.update(); + } + } +} + +impl Drop for Handle { + fn drop(&mut self) { + self.hrt.borrow_mut().remove(self.active); + } +} + +/// Holding an instance of this indicates that high resolution timers are enabled. +pub struct Time { + periods: PeriodSet, + active: Option<Period>, + + #[cfg(target_os = "macos")] + scale: f64, + #[cfg(target_os = "macos")] + deflt: mac::thread_time_constraint_policy, +} +impl Time { + fn new() -> Self { + Self { + periods: PeriodSet::default(), + active: None, + + #[cfg(target_os = "macos")] + scale: mac::get_scale(), + #[cfg(target_os = "macos")] + deflt: mac::get_default_policy(), + } + } + + #[allow(clippy::unused_self)] // Only on some platforms is it unused. + fn start(&self) { + #[cfg(target_os = "macos")] + { + if let Some(p) = self.active { + mac::set_realtime(p.scaled(self.scale)); + } else { + mac::set_thread_policy(self.deflt); + } + } + + #[cfg(windows)] + { + if let Some(p) = self.active { + _ = unsafe { timeBeginPeriod(p.as_uint()) }; + } + } + } + + #[allow(clippy::unused_self)] // Only on some platforms is it unused. + fn stop(&self) { + #[cfg(windows)] + { + if let Some(p) = self.active { + _ = unsafe { timeEndPeriod(p.as_uint()) }; + } + } + } + + fn update(&mut self) { + let next = self.periods.min(); + if next != self.active { + self.stop(); + self.active = next; + self.start(); + } + } + + fn add(&mut self, p: Period) { + self.periods.add(p); + self.update(); + } + + fn remove(&mut self, p: Period) { + self.periods.remove(p); + self.update(); + } + + /// Enable high resolution time. Returns a thread-bound handle that + /// needs to be held until the high resolution time is no longer needed. + /// The handle can also be used to update the resolution. + #[must_use] + pub fn get(period: Duration) -> Handle { + thread_local! { + static HR_TIME: RefCell<Weak<RefCell<Time>>> = RefCell::default(); + } + + HR_TIME.with(|r| { + let mut b = r.borrow_mut(); + let hrt = b.upgrade().unwrap_or_else(|| { + let hrt = Rc::new(RefCell::new(Time::new())); + *b = Rc::downgrade(&hrt); + hrt + }); + + let p = Period::from(period); + hrt.borrow_mut().add(p); + Handle::new(hrt, p) + }) + } +} + +impl Drop for Time { + fn drop(&mut self) { + self.stop(); + + #[cfg(target_os = "macos")] + { + if self.active.is_some() { + mac::set_thread_policy(self.deflt); + } + } + } +} + +// Only run these tests in CI on platforms other than MacOS and Windows, where the timer +// inaccuracies are too high to pass the tests. +#[cfg(all( + test, + not(all(any(target_os = "macos", target_os = "windows"), feature = "ci")) +))] +mod test { + use std::{ + thread::{sleep, spawn}, + time::{Duration, Instant}, + }; + + use super::Time; + + const ONE: Duration = Duration::from_millis(1); + const ONE_AND_A_BIT: Duration = Duration::from_micros(1500); + /// A limit for when high resolution timers are disabled. + const GENEROUS: Duration = Duration::from_millis(30); + + fn validate_delays(max_lag: Duration) -> Result<(), ()> { + const DELAYS: &[u64] = &[1, 2, 3, 5, 8, 10, 12, 15, 20, 25, 30]; + let durations = DELAYS.iter().map(|&d| Duration::from_millis(d)); + + let mut s = Instant::now(); + for d in durations { + sleep(d); + let e = Instant::now(); + let actual = e - s; + let lag = actual - d; + println!("sleep({d:?}) \u{2192} {actual:?} \u{394}{lag:?}"); + if lag > max_lag { + return Err(()); + } + s = Instant::now(); + } + Ok(()) + } + + /// Validate the delays twice. Sometimes the first run can stall. + /// Reliability in CI is more important than reliable timers. + fn check_delays(max_lag: Duration) { + if validate_delays(max_lag).is_err() { + sleep(Duration::from_millis(50)); + validate_delays(max_lag).unwrap(); + } + } + + /// Note that you have to run this test alone or other tests will + /// grab the high resolution timer and this will run faster. + #[test] + fn baseline() { + check_delays(GENEROUS); + } + + #[test] + fn one_ms() { + let _hrt = Time::get(ONE); + check_delays(ONE_AND_A_BIT); + } + + #[test] + fn multithread_baseline() { + let thr = spawn(move || { + baseline(); + }); + baseline(); + thr.join().unwrap(); + } + + #[test] + fn one_ms_multi() { + let thr = spawn(move || { + one_ms(); + }); + one_ms(); + thr.join().unwrap(); + } + + #[test] + fn mixed_multi() { + let thr = spawn(move || { + one_ms(); + }); + let _hrt = Time::get(Duration::from_millis(4)); + check_delays(Duration::from_millis(5)); + thr.join().unwrap(); + } + + #[test] + fn update() { + let mut hrt = Time::get(Duration::from_millis(4)); + check_delays(Duration::from_millis(5)); + hrt.update(ONE); + check_delays(ONE_AND_A_BIT); + } + + #[test] + fn update_multi() { + let thr = spawn(move || { + update(); + }); + update(); + thr.join().unwrap(); + } + + #[test] + fn max() { + let _hrt = Time::get(Duration::from_secs(1)); + check_delays(GENEROUS); + } +} diff --git a/third_party/rust/neqo-common/src/incrdecoder.rs b/third_party/rust/neqo-common/src/incrdecoder.rs new file mode 100644 index 0000000000..8468102cb6 --- /dev/null +++ b/third_party/rust/neqo-common/src/incrdecoder.rs @@ -0,0 +1,275 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{cmp::min, mem}; + +use crate::codec::Decoder; + +#[derive(Clone, Debug, Default)] +pub struct IncrementalDecoderUint { + v: u64, + remaining: Option<usize>, +} + +impl IncrementalDecoderUint { + #[must_use] + pub fn min_remaining(&self) -> usize { + self.remaining.unwrap_or(1) + } + + /// Consume some data. + /// + /// # Panics + /// + /// Never, but this is not something the compiler can tell. + pub fn consume(&mut self, dv: &mut Decoder) -> Option<u64> { + if let Some(r) = &mut self.remaining { + let amount = min(*r, dv.remaining()); + if amount < 8 { + self.v <<= amount * 8; + } + self.v |= dv.decode_uint(amount).unwrap(); + *r -= amount; + if *r == 0 { + Some(self.v) + } else { + None + } + } else { + let (v, remaining) = match dv.decode_byte() { + Some(b) => ( + u64::from(b & 0x3f), + match b >> 6 { + 0 => 0, + 1 => 1, + 2 => 3, + 3 => 7, + _ => unreachable!(), + }, + ), + None => unreachable!(), + }; + self.remaining = Some(remaining); + self.v = v; + if remaining == 0 { + Some(v) + } else { + None + } + } + } + + #[must_use] + pub fn decoding_in_progress(&self) -> bool { + self.remaining.is_some() + } +} + +#[derive(Clone, Debug)] +pub struct IncrementalDecoderBuffer { + v: Vec<u8>, + remaining: usize, +} + +impl IncrementalDecoderBuffer { + #[must_use] + pub fn new(n: usize) -> Self { + Self { + v: Vec::new(), + remaining: n, + } + } + + #[must_use] + pub fn min_remaining(&self) -> usize { + self.remaining + } + + /// Consume some bytes from the decoder. + /// + /// # Panics + /// + /// Never; but rust doesn't know that. + pub fn consume(&mut self, dv: &mut Decoder) -> Option<Vec<u8>> { + let amount = min(self.remaining, dv.remaining()); + let b = dv.decode(amount).unwrap(); + self.v.extend_from_slice(b); + self.remaining -= amount; + if self.remaining == 0 { + Some(mem::take(&mut self.v)) + } else { + None + } + } +} + +#[derive(Clone, Debug)] +pub struct IncrementalDecoderIgnore { + remaining: usize, +} + +impl IncrementalDecoderIgnore { + /// Make a new ignoring decoder. + /// + /// # Panics + /// + /// If the amount to ignore is zero. + #[must_use] + pub fn new(n: usize) -> Self { + assert_ne!(n, 0); + Self { remaining: n } + } + + #[must_use] + pub fn min_remaining(&self) -> usize { + self.remaining + } + + pub fn consume(&mut self, dv: &mut Decoder) -> bool { + let amount = min(self.remaining, dv.remaining()); + _ = dv.decode(amount); + self.remaining -= amount; + self.remaining == 0 + } +} + +#[cfg(test)] +mod tests { + use super::{ + Decoder, IncrementalDecoderBuffer, IncrementalDecoderIgnore, IncrementalDecoderUint, + }; + use crate::codec::Encoder; + + #[test] + fn buffer_incremental() { + let b = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut dec = IncrementalDecoderBuffer::new(b.len()); + let mut i = 0; + while i < b.len() { + // Feed in b in increasing-sized chunks. + let incr = if i < b.len() / 2 { i + 1 } else { b.len() - i }; + let mut dv = Decoder::from(&b[i..i + incr]); + i += incr; + match dec.consume(&mut dv) { + None => { + assert!(i < b.len()); + } + Some(res) => { + assert_eq!(i, b.len()); + assert_eq!(res, b); + } + } + } + } + + struct UintTestCase { + b: String, + v: u64, + } + + impl UintTestCase { + pub fn run(&self) { + eprintln!( + "IncrementalDecoderUint decoder with {:?} ; expect {:?}", + self.b, self.v + ); + + let decoder = IncrementalDecoderUint::default(); + let mut db = Encoder::from_hex(&self.b); + // Add padding so that we can verify that the reader doesn't over-consume. + db.encode_byte(0xff); + + for tail in 1..db.len() { + let split = db.len() - tail; + let mut dv = Decoder::from(&db.as_ref()[0..split]); + eprintln!(" split at {split}: {dv:?}"); + + // Clone the basic decoder for each iteration of the loop. + let mut dec = decoder.clone(); + let mut res = None; + while dv.remaining() > 0 { + res = dec.consume(&mut dv); + } + assert!(dec.min_remaining() < tail); + + if tail > 1 { + assert_eq!(res, None); + assert!(dec.min_remaining() > 0); + let mut dv = Decoder::from(&db.as_ref()[split..]); + eprintln!(" split remainder {split}: {dv:?}"); + res = dec.consume(&mut dv); + assert_eq!(dv.remaining(), 1); + } + + assert_eq!(dec.min_remaining(), 0); + assert_eq!(res.unwrap(), self.v); + } + } + } + + macro_rules! uint_tc { + [$( $b:expr => $v:expr ),+ $(,)?] => { + vec![ $( UintTestCase { b: String::from($b), v: $v, } ),+] + }; + } + + #[test] + fn varint() { + for c in uint_tc![ + "00" => 0, + "01" => 1, + "3f" => 63, + "4040" => 64, + "7fff" => 16383, + "80004000" => 16384, + "bfffffff" => (1 << 30) - 1, + "c000000040000000" => 1 << 30, + "ffffffffffffffff" => (1 << 62) - 1, + ] { + c.run(); + } + } + + #[test] + fn zero_len() { + let enc = Encoder::from_hex("ff"); + let mut dec = Decoder::new(enc.as_ref()); + let mut incr = IncrementalDecoderBuffer::new(0); + assert_eq!(incr.consume(&mut dec), Some(Vec::new())); + assert_eq!(dec.remaining(), enc.len()); + } + + #[test] + fn ignore() { + let db = Encoder::from_hex("12345678ff"); + + let decoder = IncrementalDecoderIgnore::new(4); + + for tail in 1..db.len() { + let split = db.len() - tail; + let mut dv = Decoder::from(&db.as_ref()[0..split]); + eprintln!(" split at {split}: {dv:?}"); + + // Clone the basic decoder for each iteration of the loop. + let mut dec = decoder.clone(); + let mut res = dec.consume(&mut dv); + assert_eq!(dv.remaining(), 0); + assert!(dec.min_remaining() < tail); + + if tail > 1 { + assert!(!res); + assert!(dec.min_remaining() > 0); + let mut dv = Decoder::from(&db.as_ref()[split..]); + eprintln!(" split remainder {split}: {dv:?}"); + res = dec.consume(&mut dv); + assert_eq!(dv.remaining(), 1); + } + + assert_eq!(dec.min_remaining(), 0); + assert!(res); + } + } +} diff --git a/third_party/rust/neqo-common/src/lib.rs b/third_party/rust/neqo-common/src/lib.rs new file mode 100644 index 0000000000..853b05705b --- /dev/null +++ b/third_party/rust/neqo-common/src/lib.rs @@ -0,0 +1,109 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![cfg_attr(feature = "deny-warnings", deny(warnings))] +#![warn(clippy::pedantic)] + +mod codec; +mod datagram; +pub mod event; +pub mod header; +pub mod hrtime; +mod incrdecoder; +pub mod log; +pub mod qlog; +pub mod timer; +pub mod tos; + +use std::fmt::Write; + +use enum_map::Enum; + +pub use self::{ + codec::{Decoder, Encoder}, + datagram::Datagram, + header::Header, + incrdecoder::{IncrementalDecoderBuffer, IncrementalDecoderIgnore, IncrementalDecoderUint}, + tos::{IpTos, IpTosDscp, IpTosEcn}, +}; + +#[must_use] +pub fn hex(buf: impl AsRef<[u8]>) -> String { + let mut ret = String::with_capacity(buf.as_ref().len() * 2); + for b in buf.as_ref() { + write!(&mut ret, "{b:02x}").unwrap(); + } + ret +} + +#[must_use] +pub fn hex_snip_middle(buf: impl AsRef<[u8]>) -> String { + const SHOW_LEN: usize = 8; + let buf = buf.as_ref(); + if buf.len() <= SHOW_LEN * 2 { + hex_with_len(buf) + } else { + let mut ret = String::with_capacity(SHOW_LEN * 2 + 16); + write!(&mut ret, "[{}]: ", buf.len()).unwrap(); + for b in &buf[..SHOW_LEN] { + write!(&mut ret, "{b:02x}").unwrap(); + } + ret.push_str(".."); + for b in &buf[buf.len() - SHOW_LEN..] { + write!(&mut ret, "{b:02x}").unwrap(); + } + ret + } +} + +#[must_use] +pub fn hex_with_len(buf: impl AsRef<[u8]>) -> String { + let buf = buf.as_ref(); + let mut ret = String::with_capacity(10 + buf.len() * 2); + write!(&mut ret, "[{}]: ", buf.len()).unwrap(); + for b in buf { + write!(&mut ret, "{b:02x}").unwrap(); + } + ret +} + +#[must_use] +pub const fn const_max(a: usize, b: usize) -> usize { + [a, b][(a < b) as usize] +} +#[must_use] +pub const fn const_min(a: usize, b: usize) -> usize { + [a, b][(a >= b) as usize] +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone, Enum)] +/// Client or Server. +pub enum Role { + Client, + Server, +} + +impl Role { + #[must_use] + pub fn remote(self) -> Self { + match self { + Self::Client => Self::Server, + Self::Server => Self::Client, + } + } +} + +impl ::std::fmt::Display for Role { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{self:?}") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageType { + Request, + Response, +} diff --git a/third_party/rust/neqo-common/src/log.rs b/third_party/rust/neqo-common/src/log.rs new file mode 100644 index 0000000000..d9c30b98b1 --- /dev/null +++ b/third_party/rust/neqo-common/src/log.rs @@ -0,0 +1,104 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(clippy::module_name_repetitions)] + +use std::{io::Write, sync::Once, time::Instant}; + +use env_logger::Builder; +use lazy_static::lazy_static; + +#[macro_export] +macro_rules! do_log { + (target: $target:expr, $lvl:expr, $($arg:tt)+) => ({ + let lvl = $lvl; + if lvl <= ::log::max_level() { + ::log::logger().log( + &::log::Record::builder() + .args(format_args!($($arg)+)) + .level(lvl) + .target($target) + .module_path_static(Some(module_path!())) + .file_static(Some(file!())) + .line(Some(line!())) + .build() + ); + } + }); + ($lvl:expr, $($arg:tt)+) => ($crate::do_log!(target: module_path!(), $lvl, $($arg)+)) +} + +#[macro_export] +macro_rules! log_subject { + ($lvl:expr, $subject:expr) => {{ + if $lvl <= ::log::max_level() { + format!("{}", $subject) + } else { + String::new() + } + }}; +} + +static INIT_ONCE: Once = Once::new(); + +lazy_static! { + static ref START_TIME: Instant = Instant::now(); +} + +pub fn init() { + INIT_ONCE.call_once(|| { + let mut builder = Builder::from_env("RUST_LOG"); + builder.format(|buf, record| { + let elapsed = START_TIME.elapsed(); + writeln!( + buf, + "{}s{:3}ms {} {}", + elapsed.as_secs(), + elapsed.as_millis() % 1000, + record.level(), + record.args() + ) + }); + if let Err(e) = builder.try_init() { + do_log!(::log::Level::Info, "Logging initialization error {:?}", e); + } else { + do_log!(::log::Level::Info, "Logging initialized"); + } + }); +} + +#[macro_export] +macro_rules! log_invoke { + ($lvl:expr, $ctx:expr, $($arg:tt)*) => ( { + ::neqo_common::log::init(); + ::neqo_common::do_log!($lvl, "[{}] {}", $ctx, format!($($arg)*)); + } ) +} +#[macro_export] +macro_rules! qerror { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Error, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Error, $($arg)*); } ); +} +#[macro_export] +macro_rules! qwarn { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Warn, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Warn, $($arg)*); } ); +} +#[macro_export] +macro_rules! qinfo { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Info, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Info, $($arg)*); } ); +} +#[macro_export] +macro_rules! qdebug { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Debug, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Debug, $($arg)*); } ); +} +#[macro_export] +macro_rules! qtrace { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Trace, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Trace, $($arg)*); } ); +} diff --git a/third_party/rust/neqo-common/src/qlog.rs b/third_party/rust/neqo-common/src/qlog.rs new file mode 100644 index 0000000000..3da8350990 --- /dev/null +++ b/third_party/rust/neqo-common/src/qlog.rs @@ -0,0 +1,188 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + fmt, + path::{Path, PathBuf}, + rc::Rc, +}; + +use qlog::{ + self, streamer::QlogStreamer, CommonFields, Configuration, TraceSeq, VantagePoint, + VantagePointType, +}; + +use crate::Role; + +#[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone, Default)] +pub struct NeqoQlog { + inner: Rc<RefCell<Option<NeqoQlogShared>>>, +} + +pub struct NeqoQlogShared { + qlog_path: PathBuf, + streamer: QlogStreamer, +} + +impl NeqoQlog { + /// Create an enabled `NeqoQlog` configuration. + /// + /// # Errors + /// + /// Will return `qlog::Error` if cannot write to the new log. + pub fn enabled( + mut streamer: QlogStreamer, + qlog_path: impl AsRef<Path>, + ) -> Result<Self, qlog::Error> { + streamer.start_log()?; + + Ok(Self { + inner: Rc::new(RefCell::new(Some(NeqoQlogShared { + streamer, + qlog_path: qlog_path.as_ref().to_owned(), + }))), + }) + } + + #[must_use] + pub fn inner(&self) -> Rc<RefCell<Option<NeqoQlogShared>>> { + Rc::clone(&self.inner) + } + + /// Create a disabled `NeqoQlog` configuration. + #[must_use] + pub fn disabled() -> Self { + Self::default() + } + + /// If logging enabled, closure may generate an event to be logged. + pub fn add_event<F>(&mut self, f: F) + where + F: FnOnce() -> Option<qlog::events::Event>, + { + self.add_event_with_stream(|s| { + if let Some(evt) = f() { + s.add_event(evt)?; + } + Ok(()) + }); + } + + /// If logging enabled, closure may generate an event to be logged. + pub fn add_event_data<F>(&mut self, f: F) + where + F: FnOnce() -> Option<qlog::events::EventData>, + { + self.add_event_with_stream(|s| { + if let Some(ev_data) = f() { + s.add_event_data_now(ev_data)?; + } + Ok(()) + }); + } + + /// If logging enabled, closure is given the Qlog stream to write events and + /// frames to. + pub fn add_event_with_stream<F>(&mut self, f: F) + where + F: FnOnce(&mut QlogStreamer) -> Result<(), qlog::Error>, + { + if let Some(inner) = self.inner.borrow_mut().as_mut() { + if let Err(e) = f(&mut inner.streamer) { + crate::do_log!( + ::log::Level::Error, + "Qlog event generation failed with error {}; closing qlog.", + e + ); + *self.inner.borrow_mut() = None; + } + } + } +} + +impl fmt::Debug for NeqoQlogShared { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NeqoQlog writing to {}", self.qlog_path.display()) + } +} + +impl Drop for NeqoQlogShared { + fn drop(&mut self) { + if let Err(e) = self.streamer.finish_log() { + crate::do_log!(::log::Level::Error, "Error dropping NeqoQlog: {}", e); + } + } +} + +#[must_use] +pub fn new_trace(role: Role) -> qlog::TraceSeq { + TraceSeq { + vantage_point: VantagePoint { + name: Some(format!("neqo-{role}")), + ty: match role { + Role::Client => VantagePointType::Client, + Role::Server => VantagePointType::Server, + }, + flow: None, + }, + title: Some(format!("neqo-{role} trace")), + description: Some("Example qlog trace description".to_string()), + configuration: Some(Configuration { + time_offset: Some(0.0), + original_uris: None, + }), + common_fields: Some(CommonFields { + group_id: None, + protocol_type: None, + reference_time: { + // It is better to allow this than deal with a conversion from i64 to f64. + // We can't do the obvious two-step conversion with f64::from(i32::try_from(...)), + // because that overflows earlier than is ideal. This should be fine for a while. + #[allow(clippy::cast_precision_loss)] + Some(time::OffsetDateTime::now_utc().unix_timestamp() as f64) + }, + time_format: Some("relative".to_string()), + }), + } +} + +#[cfg(test)] +mod test { + use qlog::events::Event; + use test_fixture::EXPECTED_LOG_HEADER; + + const EV_DATA: qlog::events::EventData = + qlog::events::EventData::SpinBitUpdated(qlog::events::connectivity::SpinBitUpdated { + state: true, + }); + + const EXPECTED_LOG_EVENT: &str = concat!( + "\u{1e}", + r#"{"time":0.0,"name":"connectivity:spin_bit_updated","data":{"state":true}}"#, + "\n" + ); + + #[test] + fn new_neqo_qlog() { + let (_log, contents) = test_fixture::new_neqo_qlog(); + assert_eq!(contents.to_string(), EXPECTED_LOG_HEADER); + } + + #[test] + fn add_event() { + let (mut log, contents) = test_fixture::new_neqo_qlog(); + log.add_event(|| Some(Event::with_time(1.1, EV_DATA))); + assert_eq!( + contents.to_string(), + format!( + "{EXPECTED_LOG_HEADER}{e}", + e = EXPECTED_LOG_EVENT.replace("\"time\":0.0,", "\"time\":1.1,") + ) + ); + } +} diff --git a/third_party/rust/neqo-common/src/timer.rs b/third_party/rust/neqo-common/src/timer.rs new file mode 100644 index 0000000000..e8532af442 --- /dev/null +++ b/third_party/rust/neqo-common/src/timer.rs @@ -0,0 +1,396 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + convert::TryFrom, + mem, + time::{Duration, Instant}, +}; + +/// Internal structure for a timer item. +struct TimerItem<T> { + time: Instant, + item: T, +} + +impl<T> TimerItem<T> { + fn time(ti: &Self) -> Instant { + ti.time + } +} + +/// A timer queue. +/// This uses a classic timer wheel arrangement, with some characteristics that might be considered +/// peculiar. Each slot in the wheel is sorted (complexity O(N) insertions, but O(logN) to find cut +/// points). Time is relative, the wheel has an origin time and it is unable to represent times that +/// are more than `granularity * capacity` past that time. +pub struct Timer<T> { + items: Vec<Vec<TimerItem<T>>>, + now: Instant, + granularity: Duration, + cursor: usize, +} + +impl<T> Timer<T> { + /// Construct a new wheel at the given granularity, starting at the given time. + /// + /// # Panics + /// + /// When `capacity` is too large to fit in `u32` or `granularity` is zero. + pub fn new(now: Instant, granularity: Duration, capacity: usize) -> Self { + assert!(u32::try_from(capacity).is_ok()); + assert!(granularity.as_nanos() > 0); + let mut items = Vec::with_capacity(capacity); + items.resize_with(capacity, Default::default); + Self { + items, + now, + granularity, + cursor: 0, + } + } + + /// Return a reference to the time of the next entry. + #[must_use] + pub fn next_time(&self) -> Option<Instant> { + for i in 0..self.items.len() { + let idx = self.bucket(i); + if let Some(t) = self.items[idx].first() { + return Some(t.time); + } + } + None + } + + /// Get the full span of time that this can cover. + /// Two timers cannot be more than this far apart. + /// In practice, this value is less by one amount of the timer granularity. + #[inline] + #[allow(clippy::cast_possible_truncation)] // guarded by assertion + #[must_use] + pub fn span(&self) -> Duration { + self.granularity * (self.items.len() as u32) + } + + /// For the given `time`, get the number of whole buckets in the future that is. + #[inline] + #[allow(clippy::cast_possible_truncation)] // guarded by assertion + fn delta(&self, time: Instant) -> usize { + // This really should use Duration::div_duration_f??(), but it can't yet. + ((time - self.now).as_nanos() / self.granularity.as_nanos()) as usize + } + + #[inline] + fn time_bucket(&self, time: Instant) -> usize { + self.bucket(self.delta(time)) + } + + #[inline] + fn bucket(&self, delta: usize) -> usize { + debug_assert!(delta < self.items.len()); + (self.cursor + delta) % self.items.len() + } + + /// Slide forward in time by `n * self.granularity`. + #[allow(clippy::cast_possible_truncation, clippy::reversed_empty_ranges)] + // cast_possible_truncation is ok because we have an assertion guard. + // reversed_empty_ranges is to avoid different types on the if/else. + fn tick(&mut self, n: usize) { + let new = self.bucket(n); + let iter = if new < self.cursor { + (self.cursor..self.items.len()).chain(0..new) + } else { + (self.cursor..new).chain(0..0) + }; + for i in iter { + assert!(self.items[i].is_empty()); + } + self.now += self.granularity * (n as u32); + self.cursor = new; + } + + /// Asserts if the time given is in the past or too far in the future. + /// + /// # Panics + /// + /// When `time` is in the past relative to previous calls. + pub fn add(&mut self, time: Instant, item: T) { + assert!(time >= self.now); + // Skip forward quickly if there is too large a gap. + let short_span = self.span() - self.granularity; + if time >= (self.now + self.span() + short_span) { + // Assert that there aren't any items. + for i in &self.items { + debug_assert!(i.is_empty()); + } + self.now = time.checked_sub(short_span).unwrap(); + self.cursor = 0; + } + + // Adjust time forward the minimum amount necessary. + let mut d = self.delta(time); + if d >= self.items.len() { + self.tick(1 + d - self.items.len()); + d = self.items.len() - 1; + } + + let bucket = self.bucket(d); + let ins = match self.items[bucket].binary_search_by_key(&time, TimerItem::time) { + Ok(j) | Err(j) => j, + }; + self.items[bucket].insert(ins, TimerItem { time, item }); + } + + /// Given knowledge of the time an item was added, remove it. + /// This requires use of a predicate that identifies matching items. + pub fn remove<F>(&mut self, time: Instant, mut selector: F) -> Option<T> + where + F: FnMut(&T) -> bool, + { + if time < self.now { + return None; + } + if time > self.now + self.span() { + return None; + } + let bucket = self.time_bucket(time); + let Ok(start_index) = self.items[bucket].binary_search_by_key(&time, TimerItem::time) + else { + return None; + }; + // start_index is just one of potentially many items with the same time. + // Search backwards for a match, ... + for i in (0..=start_index).rev() { + if self.items[bucket][i].time != time { + break; + } + if selector(&self.items[bucket][i].item) { + return Some(self.items[bucket].remove(i).item); + } + } + // ... then forwards. + for i in (start_index + 1)..self.items[bucket].len() { + if self.items[bucket][i].time != time { + break; + } + if selector(&self.items[bucket][i].item) { + return Some(self.items[bucket].remove(i).item); + } + } + None + } + + /// Take the next item, unless there are no items with + /// a timeout in the past relative to `until`. + pub fn take_next(&mut self, until: Instant) -> Option<T> { + for i in 0..self.items.len() { + let idx = self.bucket(i); + if !self.items[idx].is_empty() && self.items[idx][0].time <= until { + return Some(self.items[idx].remove(0).item); + } + } + None + } + + /// Create an iterator that takes all items until the given time. + /// Note: Items might be removed even if the iterator is not fully exhausted. + pub fn take_until(&mut self, until: Instant) -> impl Iterator<Item = T> { + let get_item = move |x: TimerItem<T>| x.item; + if until >= self.now + self.span() { + // Drain everything, so a clean sweep. + let mut empty_items = Vec::with_capacity(self.items.len()); + empty_items.resize_with(self.items.len(), Vec::default); + let mut items = mem::replace(&mut self.items, empty_items); + self.now = until; + self.cursor = 0; + + let tail = items.split_off(self.cursor); + return tail.into_iter().chain(items).flatten().map(get_item); + } + + // Only returning a partial span, so do it bucket at a time. + let delta = self.delta(until); + let mut buckets = Vec::with_capacity(delta + 1); + + // First, the whole buckets. + for i in 0..delta { + let idx = self.bucket(i); + buckets.push(mem::take(&mut self.items[idx])); + } + self.tick(delta); + + // Now we need to split the last bucket, because there might be + // some items with `item.time > until`. + let bucket = &mut self.items[self.cursor]; + let last_idx = match bucket.binary_search_by_key(&until, TimerItem::time) { + Ok(mut m) => { + // If there are multiple values, the search will hit any of them. + // Make sure to get them all. + while m < bucket.len() && bucket[m].time == until { + m += 1; + } + m + } + Err(ins) => ins, + }; + let tail = bucket.split_off(last_idx); + buckets.push(mem::replace(bucket, tail)); + // This tomfoolery with the empty vector ensures that + // the returned type here matches the one above precisely + // without having to invoke the `either` crate. + buckets.into_iter().chain(vec![]).flatten().map(get_item) + } +} + +#[cfg(test)] +mod test { + use lazy_static::lazy_static; + + use super::{Duration, Instant, Timer}; + + lazy_static! { + static ref NOW: Instant = Instant::now(); + } + + const GRANULARITY: Duration = Duration::from_millis(10); + const CAPACITY: usize = 10; + #[test] + fn create() { + let t: Timer<()> = Timer::new(*NOW, GRANULARITY, CAPACITY); + assert_eq!(t.span(), Duration::from_millis(100)); + assert_eq!(None, t.next_time()); + } + + #[test] + fn immediate_entry() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + t.add(*NOW, 12); + assert_eq!(*NOW, t.next_time().expect("should have an entry")); + let values: Vec<_> = t.take_until(*NOW).collect(); + assert_eq!(vec![12], values); + } + + #[test] + fn same_time() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let v1 = 12; + let v2 = 13; + t.add(*NOW, v1); + t.add(*NOW, v2); + assert_eq!(*NOW, t.next_time().expect("should have an entry")); + let values: Vec<_> = t.take_until(*NOW).collect(); + assert!(values.contains(&v1)); + assert!(values.contains(&v2)); + } + + #[test] + fn add() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let near_future = *NOW + Duration::from_millis(17); + let v = 9; + t.add(near_future, v); + assert_eq!(near_future, t.next_time().expect("should return a value")); + assert_eq!( + t.take_until(near_future.checked_sub(Duration::from_millis(1)).unwrap()) + .count(), + 0 + ); + assert!(t + .take_until(near_future + Duration::from_millis(1)) + .any(|x| x == v)); + } + + #[test] + fn add_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let future = *NOW + Duration::from_millis(117); + let v = 9; + t.add(future, v); + assert_eq!(future, t.next_time().expect("should return a value")); + assert!(t.take_until(future).any(|x| x == v)); + } + + #[test] + fn add_far_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let far_future = *NOW + Duration::from_millis(892); + let v = 9; + t.add(far_future, v); + assert_eq!(far_future, t.next_time().expect("should return a value")); + assert!(t.take_until(far_future).any(|x| x == v)); + } + + const TIMES: &[Duration] = &[ + Duration::from_millis(40), + Duration::from_millis(91), + Duration::from_millis(6), + Duration::from_millis(3), + Duration::from_millis(22), + Duration::from_millis(40), + ]; + + fn with_times() -> Timer<usize> { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + for (i, time) in TIMES.iter().enumerate() { + t.add(*NOW + *time, i); + } + assert_eq!( + *NOW + *TIMES.iter().min().unwrap(), + t.next_time().expect("should have a time") + ); + t + } + + #[test] + #[allow(clippy::needless_collect)] // false positive + fn multiple_values() { + let mut t = with_times(); + let values: Vec<_> = t.take_until(*NOW + *TIMES.iter().max().unwrap()).collect(); + for i in 0..TIMES.len() { + assert!(values.contains(&i)); + } + } + + #[test] + #[allow(clippy::needless_collect)] // false positive + fn take_far_future() { + let mut t = with_times(); + let values: Vec<_> = t.take_until(*NOW + Duration::from_secs(100)).collect(); + for i in 0..TIMES.len() { + assert!(values.contains(&i)); + } + } + + #[test] + fn remove_each() { + let mut t = with_times(); + for (i, time) in TIMES.iter().enumerate() { + assert_eq!(Some(i), t.remove(*NOW + *time, |&x| x == i)); + } + assert_eq!(None, t.next_time()); + } + + #[test] + fn remove_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let future = *NOW + Duration::from_millis(117); + let v = 9; + t.add(future, v); + + assert_eq!(Some(v), t.remove(future, |candidate| *candidate == v)); + } + + #[test] + fn remove_too_far_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let future = *NOW + Duration::from_millis(117); + let too_far_future = *NOW + t.span() + Duration::from_millis(117); + let v = 9; + t.add(future, v); + + assert_eq!(None, t.remove(too_far_future, |candidate| *candidate == v)); + } +} diff --git a/third_party/rust/neqo-common/src/tos.rs b/third_party/rust/neqo-common/src/tos.rs new file mode 100644 index 0000000000..aa360d1d53 --- /dev/null +++ b/third_party/rust/neqo-common/src/tos.rs @@ -0,0 +1,290 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::fmt::Debug; + +use enum_map::Enum; + +/// ECN (Explicit Congestion Notification) codepoints mapped to the +/// lower 2 bits of the TOS field. +/// <https://www.iana.org/assignments/dscp-registry/dscp-registry.xhtml> +#[derive(Copy, Clone, PartialEq, Eq, Enum, Default, Debug)] +#[repr(u8)] +pub enum IpTosEcn { + #[default] + /// Not-ECT, Not ECN-Capable Transport, RFC3168 + NotEct = 0b00, + + /// ECT(1), ECN-Capable Transport(1), RFC8311 and RFC9331 + Ect1 = 0b01, + + /// ECT(0), ECN-Capable Transport(0), RFC3168 + Ect0 = 0b10, + + /// CE, Congestion Experienced, RFC3168 + Ce = 0b11, +} + +impl From<IpTosEcn> for u8 { + fn from(v: IpTosEcn) -> Self { + v as u8 + } +} + +impl From<u8> for IpTosEcn { + fn from(v: u8) -> Self { + match v & 0b11 { + 0b00 => IpTosEcn::NotEct, + 0b01 => IpTosEcn::Ect1, + 0b10 => IpTosEcn::Ect0, + 0b11 => IpTosEcn::Ce, + _ => unreachable!(), + } + } +} + +/// Diffserv Codepoints, mapped to the upper six bits of the TOS field. +/// <https://www.iana.org/assignments/dscp-registry/dscp-registry.xhtml> +#[derive(Copy, Clone, PartialEq, Eq, Enum, Default, Debug)] +#[repr(u8)] +pub enum IpTosDscp { + #[default] + /// Class Selector 0, RFC2474 + Cs0 = 0b0000_0000, + + /// Class Selector 1, RFC2474 + Cs1 = 0b0010_0000, + + /// Class Selector 2, RFC2474 + Cs2 = 0b0100_0000, + + /// Class Selector 3, RFC2474 + Cs3 = 0b0110_0000, + + /// Class Selector 4, RFC2474 + Cs4 = 0b1000_0000, + + /// Class Selector 5, RFC2474 + Cs5 = 0b1010_0000, + + /// Class Selector 6, RFC2474 + Cs6 = 0b1100_0000, + + /// Class Selector 7, RFC2474 + Cs7 = 0b1110_0000, + + /// Assured Forwarding 11, RFC2597 + Af11 = 0b0010_1000, + + /// Assured Forwarding 12, RFC2597 + Af12 = 0b0011_0000, + + /// Assured Forwarding 13, RFC2597 + Af13 = 0b0011_1000, + + /// Assured Forwarding 21, RFC2597 + Af21 = 0b0100_1000, + + /// Assured Forwarding 22, RFC2597 + Af22 = 0b0101_0000, + + /// Assured Forwarding 23, RFC2597 + Af23 = 0b0101_1000, + + /// Assured Forwarding 31, RFC2597 + Af31 = 0b0110_1000, + + /// Assured Forwarding 32, RFC2597 + Af32 = 0b0111_0000, + + /// Assured Forwarding 33, RFC2597 + Af33 = 0b0111_1000, + + /// Assured Forwarding 41, RFC2597 + Af41 = 0b1000_1000, + + /// Assured Forwarding 42, RFC2597 + Af42 = 0b1001_0000, + + /// Assured Forwarding 43, RFC2597 + Af43 = 0b1001_1000, + + /// Expedited Forwarding, RFC3246 + Ef = 0b1011_1000, + + /// Capacity-Admitted Traffic, RFC5865 + VoiceAdmit = 0b1011_0000, + + /// Lower-Effort, RFC8622 + Le = 0b0000_0100, +} + +impl From<IpTosDscp> for u8 { + fn from(v: IpTosDscp) -> Self { + v as u8 + } +} + +impl From<u8> for IpTosDscp { + fn from(v: u8) -> Self { + match v & 0b1111_1100 { + 0b0000_0000 => IpTosDscp::Cs0, + 0b0010_0000 => IpTosDscp::Cs1, + 0b0100_0000 => IpTosDscp::Cs2, + 0b0110_0000 => IpTosDscp::Cs3, + 0b1000_0000 => IpTosDscp::Cs4, + 0b1010_0000 => IpTosDscp::Cs5, + 0b1100_0000 => IpTosDscp::Cs6, + 0b1110_0000 => IpTosDscp::Cs7, + 0b0010_1000 => IpTosDscp::Af11, + 0b0011_0000 => IpTosDscp::Af12, + 0b0011_1000 => IpTosDscp::Af13, + 0b0100_1000 => IpTosDscp::Af21, + 0b0101_0000 => IpTosDscp::Af22, + 0b0101_1000 => IpTosDscp::Af23, + 0b0110_1000 => IpTosDscp::Af31, + 0b0111_0000 => IpTosDscp::Af32, + 0b0111_1000 => IpTosDscp::Af33, + 0b1000_1000 => IpTosDscp::Af41, + 0b1001_0000 => IpTosDscp::Af42, + 0b1001_1000 => IpTosDscp::Af43, + 0b1011_1000 => IpTosDscp::Ef, + 0b1011_0000 => IpTosDscp::VoiceAdmit, + 0b0000_0100 => IpTosDscp::Le, + _ => unreachable!(), + } + } +} + +/// The type-of-service field in an IP packet. +#[allow(clippy::module_name_repetitions)] +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct IpTos(u8); + +impl From<IpTosEcn> for IpTos { + fn from(v: IpTosEcn) -> Self { + Self(u8::from(v)) + } +} +impl From<IpTosDscp> for IpTos { + fn from(v: IpTosDscp) -> Self { + Self(u8::from(v)) + } +} +impl From<(IpTosDscp, IpTosEcn)> for IpTos { + fn from(v: (IpTosDscp, IpTosEcn)) -> Self { + Self(u8::from(v.0) | u8::from(v.1)) + } +} +impl From<IpTos> for u8 { + fn from(v: IpTos) -> Self { + v.0 + } +} + +impl Debug for IpTos { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("IpTos") + .field(&IpTosDscp::from(self.0 & 0xfc)) + .field(&IpTosEcn::from(self.0 & 0x3)) + .finish() + } +} + +impl Default for IpTos { + fn default() -> Self { + (IpTosDscp::default(), IpTosEcn::default()).into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn iptosecn_into_u8() { + assert_eq!(u8::from(IpTosEcn::NotEct), 0b00); + assert_eq!(u8::from(IpTosEcn::Ect1), 0b01); + assert_eq!(u8::from(IpTosEcn::Ect0), 0b10); + assert_eq!(u8::from(IpTosEcn::Ce), 0b11); + } + + #[test] + fn u8_into_iptosecn() { + assert_eq!(IpTosEcn::from(0b00), IpTosEcn::NotEct); + assert_eq!(IpTosEcn::from(0b01), IpTosEcn::Ect1); + assert_eq!(IpTosEcn::from(0b10), IpTosEcn::Ect0); + assert_eq!(IpTosEcn::from(0b11), IpTosEcn::Ce); + } + + #[test] + fn iptosdscp_into_u8() { + assert_eq!(u8::from(IpTosDscp::Cs0), 0b0000_0000); + assert_eq!(u8::from(IpTosDscp::Cs1), 0b0010_0000); + assert_eq!(u8::from(IpTosDscp::Cs2), 0b0100_0000); + assert_eq!(u8::from(IpTosDscp::Cs3), 0b0110_0000); + assert_eq!(u8::from(IpTosDscp::Cs4), 0b1000_0000); + assert_eq!(u8::from(IpTosDscp::Cs5), 0b1010_0000); + assert_eq!(u8::from(IpTosDscp::Cs6), 0b1100_0000); + assert_eq!(u8::from(IpTosDscp::Cs7), 0b1110_0000); + assert_eq!(u8::from(IpTosDscp::Af11), 0b0010_1000); + assert_eq!(u8::from(IpTosDscp::Af12), 0b0011_0000); + assert_eq!(u8::from(IpTosDscp::Af13), 0b0011_1000); + assert_eq!(u8::from(IpTosDscp::Af21), 0b0100_1000); + assert_eq!(u8::from(IpTosDscp::Af22), 0b0101_0000); + assert_eq!(u8::from(IpTosDscp::Af23), 0b0101_1000); + assert_eq!(u8::from(IpTosDscp::Af31), 0b0110_1000); + assert_eq!(u8::from(IpTosDscp::Af32), 0b0111_0000); + assert_eq!(u8::from(IpTosDscp::Af33), 0b0111_1000); + assert_eq!(u8::from(IpTosDscp::Af41), 0b1000_1000); + assert_eq!(u8::from(IpTosDscp::Af42), 0b1001_0000); + assert_eq!(u8::from(IpTosDscp::Af43), 0b1001_1000); + assert_eq!(u8::from(IpTosDscp::Ef), 0b1011_1000); + assert_eq!(u8::from(IpTosDscp::VoiceAdmit), 0b1011_0000); + assert_eq!(u8::from(IpTosDscp::Le), 0b0000_0100); + } + + #[test] + fn u8_into_iptosdscp() { + assert_eq!(IpTosDscp::from(0b0000_0000), IpTosDscp::Cs0); + assert_eq!(IpTosDscp::from(0b0010_0000), IpTosDscp::Cs1); + assert_eq!(IpTosDscp::from(0b0100_0000), IpTosDscp::Cs2); + assert_eq!(IpTosDscp::from(0b0110_0000), IpTosDscp::Cs3); + assert_eq!(IpTosDscp::from(0b1000_0000), IpTosDscp::Cs4); + assert_eq!(IpTosDscp::from(0b1010_0000), IpTosDscp::Cs5); + assert_eq!(IpTosDscp::from(0b1100_0000), IpTosDscp::Cs6); + assert_eq!(IpTosDscp::from(0b1110_0000), IpTosDscp::Cs7); + assert_eq!(IpTosDscp::from(0b0010_1000), IpTosDscp::Af11); + assert_eq!(IpTosDscp::from(0b0011_0000), IpTosDscp::Af12); + assert_eq!(IpTosDscp::from(0b0011_1000), IpTosDscp::Af13); + assert_eq!(IpTosDscp::from(0b0100_1000), IpTosDscp::Af21); + assert_eq!(IpTosDscp::from(0b0101_0000), IpTosDscp::Af22); + assert_eq!(IpTosDscp::from(0b0101_1000), IpTosDscp::Af23); + assert_eq!(IpTosDscp::from(0b0110_1000), IpTosDscp::Af31); + assert_eq!(IpTosDscp::from(0b0111_0000), IpTosDscp::Af32); + assert_eq!(IpTosDscp::from(0b0111_1000), IpTosDscp::Af33); + assert_eq!(IpTosDscp::from(0b1000_1000), IpTosDscp::Af41); + assert_eq!(IpTosDscp::from(0b1001_0000), IpTosDscp::Af42); + assert_eq!(IpTosDscp::from(0b1001_1000), IpTosDscp::Af43); + assert_eq!(IpTosDscp::from(0b1011_1000), IpTosDscp::Ef); + assert_eq!(IpTosDscp::from(0b1011_0000), IpTosDscp::VoiceAdmit); + assert_eq!(IpTosDscp::from(0b0000_0100), IpTosDscp::Le); + } + + #[test] + fn iptosecn_into_iptos() { + let ecn = IpTosEcn::default(); + let iptos_ecn: IpTos = ecn.into(); + assert_eq!(u8::from(iptos_ecn), ecn as u8); + } + + #[test] + fn iptosdscp_into_iptos() { + let dscp = IpTosDscp::default(); + let iptos_dscp: IpTos = dscp.into(); + assert_eq!(u8::from(iptos_dscp), dscp as u8); + } +} |