diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/neqo-common/src | |
parent | Initial commit. (diff) | |
download | firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.tar.xz firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.zip |
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-common/src')
-rw-r--r-- | third_party/rust/neqo-common/src/codec.rs | 808 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/datagram.rs | 57 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/event.rs | 53 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/incrdecoder.rs | 260 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/lib.rs | 97 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/log.rs | 98 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/qlog.rs | 146 | ||||
-rw-r--r-- | third_party/rust/neqo-common/src/timer.rs | 386 |
8 files changed, 1905 insertions, 0 deletions
diff --git a/third_party/rust/neqo-common/src/codec.rs b/third_party/rust/neqo-common/src/codec.rs new file mode 100644 index 0000000000..b253a4acdb --- /dev/null +++ b/third_party/rust/neqo-common/src/codec.rs @@ -0,0 +1,808 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::convert::TryFrom; +use std::fmt::Debug; +use std::ops::{Deref, DerefMut}; + +use crate::hex_with_len; + +/// Decoder is a view into a byte array that has a read offset. Use it for parsing. +pub struct Decoder<'a> { + buf: &'a [u8], + offset: usize, +} + +impl<'a> Decoder<'a> { + /// Make a new view of the provided slice. + #[must_use] + pub fn new(buf: &[u8]) -> Decoder { + Decoder { buf, offset: 0 } + } + + /// Get the number of bytes remaining until the end. + #[must_use] + pub fn remaining(&self) -> usize { + self.buf.len() - self.offset + } + + /// The number of bytes from the underlying slice that have been decoded. + #[must_use] + pub fn offset(&self) -> usize { + self.offset + } + + /// Skip n bytes. Panics if `n` is too large. + pub fn skip(&mut self, n: usize) { + assert!(self.remaining() >= n); + self.offset += n; + } + + /// Skip helper that panics if `n` is `None` or not able to fit in `usize`. + fn skip_inner(&mut self, n: Option<u64>) { + self.skip(usize::try_from(n.unwrap()).unwrap()); + } + + /// Skip a vector. Panics if there isn't enough space. + /// Only use this for tests because we panic rather than reporting a result. + pub fn skip_vec(&mut self, n: usize) { + let len = self.decode_uint(n); + self.skip_inner(len); + } + + /// Skip a variable length vector. Panics if there isn't enough space. + /// Only use this for tests because we panic rather than reporting a result. + pub fn skip_vvec(&mut self) { + let len = self.decode_varint(); + self.skip_inner(len); + } + + /// Decodes (reads) a single byte. + pub fn decode_byte(&mut self) -> Option<u8> { + if self.remaining() < 1 { + return None; + } + let b = self.buf[self.offset]; + self.offset += 1; + Some(b) + } + + /// Provides the next byte without moving the read position. + pub fn peek_byte(&mut self) -> Option<u8> { + if self.remaining() < 1 { + None + } else { + Some(self.buf[self.offset]) + } + } + + /// Decodes arbitrary data. + pub fn decode(&mut self, n: usize) -> Option<&'a [u8]> { + if self.remaining() < n { + return None; + } + let res = &self.buf[self.offset..self.offset + n]; + self.offset += n; + Some(res) + } + + /// Decodes an unsigned integer of length 1..8. + pub fn decode_uint(&mut self, n: usize) -> Option<u64> { + assert!(n > 0 && n <= 8); + if self.remaining() < n { + return None; + } + let mut v = 0_u64; + for i in 0..n { + let b = self.buf[self.offset + i]; + v = v << 8 | u64::from(b); + } + self.offset += n; + Some(v) + } + + /// Decodes a QUIC varint. + pub fn decode_varint(&mut self) -> Option<u64> { + let b1 = match self.decode_byte() { + Some(b) => b, + None => return None, + }; + match b1 >> 6 { + 0 => Some(u64::from(b1 & 0x3f)), + 1 => Some((u64::from(b1 & 0x3f) << 8) | self.decode_uint(1)?), + 2 => Some((u64::from(b1 & 0x3f) << 24) | self.decode_uint(3)?), + 3 => Some((u64::from(b1 & 0x3f) << 56) | self.decode_uint(7)?), + _ => unreachable!(), + } + } + + /// Decodes the rest of the buffer. Infallible. + pub fn decode_remainder(&mut self) -> &'a [u8] { + let res = &self.buf[self.offset..]; + self.offset = self.buf.len(); + res + } + + fn decode_checked(&mut self, n: Option<u64>) -> Option<&'a [u8]> { + let len = match n { + Some(l) => l, + None => return None, + }; + if let Ok(l) = usize::try_from(len) { + self.decode(l) + } else { + // sizeof(usize) < sizeof(u64) and the value is greater than + // usize can hold. Throw away the rest of the input. + self.offset = self.buf.len(); + None + } + } + + /// Decodes a TLS-style length-prefixed buffer. + pub fn decode_vec(&mut self, n: usize) -> Option<&'a [u8]> { + let len = self.decode_uint(n); + self.decode_checked(len) + } + + /// Decodes a QUIC varint-length-prefixed buffer. + pub fn decode_vvec(&mut self) -> Option<&'a [u8]> { + let len = self.decode_varint(); + self.decode_checked(len) + } +} + +// Implement `Deref` for `Decoder` so that values can be examined without moving the cursor. +impl<'a> Deref for Decoder<'a> { + type Target = [u8]; + #[must_use] + fn deref(&self) -> &[u8] { + &self.buf[self.offset..] + } +} + +impl<'a> Debug for Decoder<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(&hex_with_len(&self[..])) + } +} + +impl<'a> From<&'a [u8]> for Decoder<'a> { + #[must_use] + fn from(buf: &'a [u8]) -> Decoder<'a> { + Decoder::new(buf) + } +} + +impl<'a, T> From<&'a T> for Decoder<'a> +where + T: AsRef<[u8]>, +{ + #[must_use] + fn from(buf: &'a T) -> Decoder<'a> { + Decoder::new(buf.as_ref()) + } +} + +impl<'a, 'b> PartialEq<Decoder<'b>> for Decoder<'a> { + #[must_use] + fn eq(&self, other: &Decoder<'b>) -> bool { + self.buf == other.buf + } +} + +/// Encoder is good for building data structures. +#[derive(Clone, Default, PartialEq)] +pub struct Encoder { + buf: Vec<u8>, +} + +impl Encoder { + /// Static helper function for previewing the results of encoding without doing it. + #[must_use] + pub fn varint_len(v: u64) -> usize { + match () { + _ if v < (1 << 6) => 1, + _ if v < (1 << 14) => 2, + _ if v < (1 << 30) => 4, + _ if v < (1 << 62) => 8, + _ => panic!("Varint value too large"), + } + } + + /// Static helper to determine how long a varint-prefixed array encodes to. + #[must_use] + pub fn vvec_len(len: usize) -> usize { + Self::varint_len(u64::try_from(len).unwrap()) + len + } + + /// Default construction of an empty buffer. + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Construction of a buffer with a predetermined capacity. + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buf: Vec::with_capacity(capacity), + } + } + + /// Get the capacity of the underlying buffer: the number of bytes that can be + /// written without causing an allocation to occur. + #[must_use] + pub fn capacity(&self) -> usize { + self.buf.capacity() + } + + /// Create a view of the current contents of the buffer. + /// Note: for a view of a slice, use `Decoder::new(&enc[s..e])` + #[must_use] + pub fn as_decoder(&self) -> Decoder { + Decoder::new(&self) + } + + /// Don't use this except in testing. + #[must_use] + pub fn from_hex(s: impl AsRef<str>) -> Self { + let s = s.as_ref(); + if s.len() % 2 != 0 { + panic!("Needs to be even length"); + } + + let cap = s.len() / 2; + let mut enc = Self::with_capacity(cap); + + for i in 0..cap { + let v = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16).unwrap(); + enc.encode_byte(v); + } + enc + } + + /// Generic encode routine for arbitrary data. + pub fn encode(&mut self, data: &[u8]) -> &mut Self { + self.buf.extend_from_slice(data); + self + } + + /// Encode a single byte. + pub fn encode_byte(&mut self, data: u8) -> &mut Self { + self.buf.push(data); + self + } + + /// Encode an integer of any size up to u64. + #[allow(clippy::cast_possible_truncation)] + pub fn encode_uint<T: Into<u64>>(&mut self, n: usize, v: T) -> &mut Self { + let v = v.into(); + assert!(n > 0 && n <= 8); + for i in 0..n { + self.encode_byte(((v >> (8 * (n - i - 1))) & 0xff) as u8); + } + self + } + + /// Encode a QUIC varint. + pub fn encode_varint<T: Into<u64>>(&mut self, v: T) -> &mut Self { + let v = v.into(); + match () { + _ if v < (1 << 6) => self.encode_uint(1, v), + _ if v < (1 << 14) => self.encode_uint(2, v | (1 << 14)), + _ if v < (1 << 30) => self.encode_uint(4, v | (2 << 30)), + _ if v < (1 << 62) => self.encode_uint(8, v | (3 << 62)), + _ => panic!("Varint value too large"), + }; + self + } + + /// Encode a vector in TLS style. + pub fn encode_vec(&mut self, n: usize, v: &[u8]) -> &mut Self { + self.encode_uint(n, u64::try_from(v.len()).unwrap()) + .encode(v) + } + + /// Encode a vector in TLS style using a closure for the contents. + #[allow(clippy::cast_possible_truncation)] + pub fn encode_vec_with<F: FnOnce(&mut Self)>(&mut self, n: usize, f: F) -> &mut Self { + let start = self.buf.len(); + self.buf.resize(self.buf.len() + n, 0); + f(self); + let len = self.buf.len() - start - n; + assert!(len < (1 << (n * 8))); + for i in 0..n { + self.buf[start + i] = ((len >> (8 * (n - i - 1))) & 0xff) as u8 + } + self + } + + /// Encode a vector with a varint length. + pub fn encode_vvec(&mut self, v: &[u8]) -> &mut Self { + self.encode_varint(u64::try_from(v.len()).unwrap()) + .encode(v) + } + + /// Encode a vector with a varint length using a closure. + #[allow(clippy::cast_possible_truncation)] + pub fn encode_vvec_with<F: FnOnce(&mut Self)>(&mut self, f: F) -> &mut Self { + let start = self.buf.len(); + // Optimize for short buffers, reserve a single byte for the length. + self.buf.resize(self.buf.len() + 1, 0); + f(self); + let len = self.buf.len() - start - 1; + + // Now to insert a varint for `len` before the encoded block. + // + // We now have one zero byte at `start`, followed by `len` encoded bytes: + // | 0 | ... encoded ... | + // We are going to encode a varint by putting the low bytes in that spare byte. + // Any additional bytes for the varint are put after the encoded blob: + // | low | ... encoded ... | varint high | + // Then we will rotate that entire piece right, by however many bytes we add: + // | varint high | low | ... encoded ... | + // As long as encoding more than 63 bytes is rare, this won't cost much relative + // to the convenience of being able to use this function. + + let v = u64::try_from(len).expect("encoded value fits in a u64"); + // The lower order byte fits before the inserted block of bytes. + self.buf[start] = (v & 0xff) as u8; + let (count, bits) = match () { + // Great. The byte we have is enough. + _ if v < (1 << 6) => return self, + _ if v < (1 << 14) => (1, 1 << 6), + _ if v < (1 << 30) => (3, 2 << 22), + _ if v < (1 << 62) => (7, 3 << 54), + _ => panic!("Varint value too large"), + }; + // Now, we need to encode the high bits after the main block, ... + self.encode_uint(count, (v >> 8) | bits); + // ..., then rotate the entire thing right by the same amount. + self.buf[start..].rotate_right(count); + self + } + + /// Truncate the encoder to the given size. + pub fn truncate(&mut self, len: usize) { + self.buf.truncate(len); + } + + /// Pad the buffer to `len` with bytes set to `v`. + pub fn pad_to(&mut self, len: usize, v: u8) { + if len > self.buf.len() { + self.buf.resize(len, v); + } + } +} + +impl Debug for Encoder { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(&hex_with_len(self)) + } +} + +impl AsRef<[u8]> for Encoder { + fn as_ref(&self) -> &[u8] { + self.buf.as_ref() + } +} + +impl<'a> From<Decoder<'a>> for Encoder { + #[must_use] + fn from(dec: Decoder<'a>) -> Self { + Self::from(&dec.buf[dec.offset..]) + } +} + +impl From<&[u8]> for Encoder { + #[must_use] + fn from(buf: &[u8]) -> Self { + Self { + buf: Vec::from(buf), + } + } +} + +impl Into<Vec<u8>> for Encoder { + #[must_use] + fn into(self) -> Vec<u8> { + self.buf + } +} + +impl Deref for Encoder { + type Target = [u8]; + #[must_use] + fn deref(&self) -> &[u8] { + &self.buf[..] + } +} + +impl DerefMut for Encoder { + fn deref_mut(&mut self) -> &mut [u8] { + &mut self.buf[..] + } +} + +#[cfg(test)] +mod tests { + use super::{Decoder, Encoder}; + + #[test] + fn decode() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode(2).unwrap(), &[0x01, 0x23]); + assert!(dec.decode(2).is_none()); + } + + #[test] + fn decode_byte() { + let enc = Encoder::from_hex("0123"); + let mut dec = enc.as_decoder(); + + assert_eq!(dec.decode_byte().unwrap(), 0x01); + assert_eq!(dec.decode_byte().unwrap(), 0x23); + assert!(dec.decode_byte().is_none()); + } + + #[test] + fn decode_byte_short() { + let enc = Encoder::from_hex(""); + let mut dec = enc.as_decoder(); + assert!(dec.decode_byte().is_none()); + } + + #[test] + fn decode_remainder() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_remainder(), &[0x01, 0x23, 0x45]); + assert!(dec.decode(2).is_none()); + + let mut dec = Decoder::from(&[]); + assert_eq!(dec.decode_remainder().len(), 0); + } + + #[test] + fn decode_vec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vec(1).expect("read one octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + + let enc = Encoder::from_hex("00012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vec(2).expect("read two octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + } + + #[test] + fn decode_vec_short() { + // The length is too short. + let enc = Encoder::from_hex("02"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vec(2).is_none()); + + // The body is too short. + let enc = Encoder::from_hex("0200"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vec(1).is_none()); + } + + #[test] + fn decode_vvec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vvec().expect("read one octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + + let enc = Encoder::from_hex("40012345"); + let mut dec = enc.as_decoder(); + assert_eq!(dec.decode_vvec().expect("read two octet length"), &[0x23]); + assert_eq!(dec.remaining(), 1); + } + + #[test] + fn decode_vvec_short() { + // The length field is too short. + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vvec().is_none()); + + let enc = Encoder::from_hex("405500"); + let mut dec = enc.as_decoder(); + assert!(dec.decode_vvec().is_none()); + } + + #[test] + fn skip() { + let enc = Encoder::from_hex("ffff"); + let mut dec = enc.as_decoder(); + dec.skip(1); + assert_eq!(dec.remaining(), 1); + } + + #[test] + #[should_panic] + fn skip_too_much() { + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + dec.skip(2); + } + + #[test] + fn skip_vec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + dec.skip_vec(1); + assert_eq!(dec.remaining(), 1); + } + + #[test] + #[should_panic] + fn skip_vec_too_much() { + let enc = Encoder::from_hex("ff1234"); + let mut dec = enc.as_decoder(); + dec.skip_vec(1); + } + + #[test] + #[should_panic] + fn skip_vec_short_length() { + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + dec.skip_vec(4); + } + #[test] + fn skip_vvec() { + let enc = Encoder::from_hex("012345"); + let mut dec = enc.as_decoder(); + dec.skip_vvec(); + assert_eq!(dec.remaining(), 1); + } + + #[test] + #[should_panic] + fn skip_vvec_too_much() { + let enc = Encoder::from_hex("0f1234"); + let mut dec = enc.as_decoder(); + dec.skip_vvec(); + } + + #[test] + #[should_panic] + fn skip_vvec_short_length() { + let enc = Encoder::from_hex("ff"); + let mut dec = enc.as_decoder(); + dec.skip_vvec(); + } + + #[test] + fn encoded_lengths() { + assert_eq!(Encoder::varint_len(0), 1); + assert_eq!(Encoder::varint_len(0x3f), 1); + assert_eq!(Encoder::varint_len(0x40), 2); + assert_eq!(Encoder::varint_len(0x3fff), 2); + assert_eq!(Encoder::varint_len(0x4000), 4); + assert_eq!(Encoder::varint_len(0x3fff_ffff), 4); + assert_eq!(Encoder::varint_len(0x4000_0000), 8); + } + + #[test] + #[should_panic] + fn encoded_length_oob() { + let _ = Encoder::varint_len(1 << 62); + } + + #[test] + fn encoded_vvec_lengths() { + assert_eq!(Encoder::vvec_len(0), 1); + assert_eq!(Encoder::vvec_len(0x3f), 0x40); + assert_eq!(Encoder::vvec_len(0x40), 0x42); + assert_eq!(Encoder::vvec_len(0x3fff), 0x4001); + assert_eq!(Encoder::vvec_len(0x4000), 0x4004); + assert_eq!(Encoder::vvec_len(0x3fff_ffff), 0x4000_0003); + assert_eq!(Encoder::vvec_len(0x4000_0000), 0x4000_0008); + } + + #[test] + #[should_panic] + fn encoded_vvec_length_oob() { + let _ = Encoder::vvec_len(1 << 62); + } + + #[test] + fn encode_byte() { + let mut enc = Encoder::default(); + + enc.encode_byte(1); + assert_eq!(enc, Encoder::from_hex("01")); + + enc.encode_byte(0xfe); + assert_eq!(enc, Encoder::from_hex("01fe")); + } + + #[test] + fn encode() { + let mut enc = Encoder::default(); + enc.encode(&[1, 2, 3]); + assert_eq!(enc, Encoder::from_hex("010203")); + } + + #[test] + fn encode_uint() { + let mut enc = Encoder::default(); + enc.encode_uint(2, 10_u8); // 000a + enc.encode_uint(1, 257_u16); // 01 + enc.encode_uint(3, 0xff_ffff_u32); // ffffff + enc.encode_uint(8, 0xfedc_ba98_7654_3210_u64); + assert_eq!(enc, Encoder::from_hex("000a01fffffffedcba9876543210")); + } + + #[test] + fn builder_from_slice() { + let slice = &[1, 2, 3]; + let enc = Encoder::from(&slice[..]); + assert_eq!(enc, Encoder::from_hex("010203")) + } + + #[test] + fn builder_inas_decoder() { + let enc = Encoder::from_hex("010203"); + let buf = &[1, 2, 3]; + assert_eq!(enc.as_decoder(), Decoder::new(buf)); + } + + struct UintTestCase { + v: u64, + b: String, + } + + macro_rules! uint_tc { + [$( $v:expr => $b:expr ),+ $(,)?] => { + vec![ $( UintTestCase { v: $v, b: String::from($b) } ),+] + }; + } + + #[test] + fn varint_encode_decode() { + let cases = uint_tc![ + 0 => "00", + 1 => "01", + 63 => "3f", + 64 => "4040", + 16383 => "7fff", + 16384 => "80004000", + (1 << 30) - 1 => "bfffffff", + 1 << 30 => "c000000040000000", + (1 << 62) - 1 => "ffffffffffffffff", + ]; + + for c in cases { + assert_eq!(Encoder::varint_len(c.v), c.b.len() / 2); + + let mut enc = Encoder::default(); + enc.encode_varint(c.v); + let encoded = Encoder::from_hex(&c.b); + assert_eq!(enc, encoded); + + let mut dec = encoded.as_decoder(); + let v = dec.decode_varint().expect("should decode"); + assert_eq!(dec.remaining(), 0); + assert_eq!(v, c.v); + } + } + + #[test] + fn varint_decode_long_zero() { + for c in &["4000", "80000000", "c000000000000000"] { + let encoded = Encoder::from_hex(c); + let mut dec = encoded.as_decoder(); + let v = dec.decode_varint().expect("should decode"); + assert_eq!(dec.remaining(), 0); + assert_eq!(v, 0); + } + } + + #[test] + fn varint_decode_short() { + for c in &["40", "800000", "c0000000000000"] { + let encoded = Encoder::from_hex(c); + let mut dec = encoded.as_decoder(); + assert!(dec.decode_varint().is_none()); + } + } + + #[test] + fn encode_vec() { + let mut enc = Encoder::default(); + enc.encode_vec(2, &[1, 2, 0x34]); + assert_eq!(enc, Encoder::from_hex("0003010234")); + } + + #[test] + fn encode_vec_with() { + let mut enc = Encoder::default(); + enc.encode_vec_with(2, |enc_inner| { + enc_inner.encode(&Encoder::from_hex("02")); + }); + assert_eq!(enc, Encoder::from_hex("000102")); + } + + #[test] + #[should_panic] + fn encode_vec_with_overflow() { + let mut enc = Encoder::default(); + enc.encode_vec_with(1, |enc_inner| { + enc_inner.encode(&[0xb0; 256]); + }); + } + + #[test] + fn encode_vvec() { + let mut enc = Encoder::default(); + enc.encode_vvec(&[1, 2, 0x34]); + assert_eq!(enc, Encoder::from_hex("03010234")); + } + + #[test] + fn encode_vvec_with() { + let mut enc = Encoder::default(); + enc.encode_vvec_with(|enc_inner| { + enc_inner.encode(&Encoder::from_hex("02")); + }); + assert_eq!(enc, Encoder::from_hex("0102")); + } + + #[test] + fn encode_vvec_with_longer() { + let mut enc = Encoder::default(); + enc.encode_vvec_with(|enc_inner| { + enc_inner.encode(&[0xa5; 65]); + }); + let mut v: Vec<u8> = enc.into(); + let _ = v.split_off(3); + assert_eq!(v, vec![0x40, 0x41, 0xa5]); + } + + // Test that Deref to &[u8] works for Encoder. + #[test] + fn encode_builder() { + let mut enc = Encoder::from_hex("ff"); + let enc2 = Encoder::from_hex("010234"); + enc.encode(&enc2); + assert_eq!(enc, Encoder::from_hex("ff010234")); + } + + // Test that Deref to &[u8] works for Decoder. + #[test] + fn encode_view() { + let mut enc = Encoder::from_hex("ff"); + let enc2 = Encoder::from_hex("010234"); + let v = enc2.as_decoder(); + enc.encode(&v); + assert_eq!(enc, Encoder::from_hex("ff010234")); + } + + #[test] + fn encode_mutate() { + let mut enc = Encoder::from_hex("010234"); + enc[0] = 0xff; + assert_eq!(enc, Encoder::from_hex("ff0234")); + } + + #[test] + fn pad() { + let mut enc = Encoder::from_hex("010234"); + enc.pad_to(5, 0); + assert_eq!(enc, Encoder::from_hex("0102340000")); + enc.pad_to(4, 0); + assert_eq!(enc, Encoder::from_hex("0102340000")); + enc.pad_to(7, 0xc2); + assert_eq!(enc, Encoder::from_hex("0102340000c2c2")); + } +} diff --git a/third_party/rust/neqo-common/src/datagram.rs b/third_party/rust/neqo-common/src/datagram.rs new file mode 100644 index 0000000000..34b4d921e9 --- /dev/null +++ b/third_party/rust/neqo-common/src/datagram.rs @@ -0,0 +1,57 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::net::SocketAddr; +use std::ops::Deref; + +use crate::hex_with_len; + +#[derive(PartialEq, Clone)] +pub struct Datagram { + src: SocketAddr, + dst: SocketAddr, + d: Vec<u8>, +} + +impl Datagram { + pub fn new<V: Into<Vec<u8>>>(src: SocketAddr, dst: SocketAddr, d: V) -> Self { + Self { + src, + dst, + d: d.into(), + } + } + + #[must_use] + pub fn source(&self) -> SocketAddr { + self.src + } + + #[must_use] + pub fn destination(&self) -> SocketAddr { + self.dst + } +} + +impl Deref for Datagram { + type Target = Vec<u8>; + #[must_use] + fn deref(&self) -> &Self::Target { + &self.d + } +} + +impl std::fmt::Debug for Datagram { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Datagram {:?}->{:?}: {}", + self.src, + self.dst, + hex_with_len(&self.d) + ) + } +} diff --git a/third_party/rust/neqo-common/src/event.rs b/third_party/rust/neqo-common/src/event.rs new file mode 100644 index 0000000000..8598383e76 --- /dev/null +++ b/third_party/rust/neqo-common/src/event.rs @@ -0,0 +1,53 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::iter::Iterator; +use std::marker::PhantomData; + +/// An event provider is able to generate a stream of events. +pub trait Provider { + type Event; + + /// Get the next event. + #[must_use] + fn next_event(&mut self) -> Option<Self::Event>; + + /// Determine whether there are pending events. + #[must_use] + fn has_events(&self) -> bool; + + /// Construct an iterator that produces all events. + fn events(&'_ mut self) -> Iter<'_, Self, Self::Event> { + Iter::new(self) + } +} + +pub struct Iter<'a, P, E> +where + P: ?Sized, +{ + p: &'a mut P, + _e: PhantomData<E>, +} + +impl<'a, P, E> Iter<'a, P, E> +where + P: Provider<Event = E> + ?Sized, +{ + fn new(p: &'a mut P) -> Self { + Self { p, _e: PhantomData } + } +} + +impl<'a, P, E> Iterator for Iter<'a, P, E> +where + P: Provider<Event = E>, +{ + type Item = E; + fn next(&mut self) -> Option<Self::Item> { + self.p.next_event() + } +} diff --git a/third_party/rust/neqo-common/src/incrdecoder.rs b/third_party/rust/neqo-common/src/incrdecoder.rs new file mode 100644 index 0000000000..1ece719e2c --- /dev/null +++ b/third_party/rust/neqo-common/src/incrdecoder.rs @@ -0,0 +1,260 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::cmp::min; +use std::mem; + +use crate::codec::Decoder; + +#[derive(Clone, Debug, Default)] +pub struct IncrementalDecoderUint { + v: u64, + remaining: Option<usize>, +} + +impl IncrementalDecoderUint { + #[must_use] + pub fn min_remaining(&self) -> usize { + self.remaining.unwrap_or(1) + } + + pub fn consume(&mut self, dv: &mut Decoder) -> Option<u64> { + if let Some(r) = &mut self.remaining { + let amount = min(*r, dv.remaining()); + if amount < 8 { + self.v <<= amount * 8; + } + self.v |= dv.decode_uint(amount).unwrap(); + *r -= amount; + if *r == 0 { + Some(self.v) + } else { + None + } + } else { + let (v, remaining) = match dv.decode_byte() { + Some(b) => ( + u64::from(b & 0x3f), + match b >> 6 { + 0 => 0, + 1 => 1, + 2 => 3, + 3 => 7, + _ => unreachable!(), + }, + ), + None => unreachable!(), + }; + self.remaining = Some(remaining); + self.v = v; + if remaining == 0 { + Some(v) + } else { + None + } + } + } + + #[must_use] + pub fn decoding_in_progress(&self) -> bool { + self.remaining.is_some() + } +} + +#[derive(Clone, Debug)] +pub struct IncrementalDecoderBuffer { + v: Vec<u8>, + remaining: usize, +} + +impl IncrementalDecoderBuffer { + #[must_use] + pub fn new(n: usize) -> Self { + Self { + v: Vec::new(), + remaining: n, + } + } + + #[must_use] + pub fn min_remaining(&self) -> usize { + self.remaining + } + + pub fn consume(&mut self, dv: &mut Decoder) -> Option<Vec<u8>> { + let amount = min(self.remaining, dv.remaining()); + let b = dv.decode(amount).unwrap(); + self.v.extend_from_slice(b); + self.remaining -= amount; + if self.remaining == 0 { + Some(mem::replace(&mut self.v, Vec::new())) + } else { + None + } + } +} + +#[derive(Clone, Debug)] +pub struct IncrementalDecoderIgnore { + remaining: usize, +} + +impl IncrementalDecoderIgnore { + #[must_use] + pub fn new(n: usize) -> Self { + Self { remaining: n } + } + + #[must_use] + pub fn min_remaining(&self) -> usize { + self.remaining + } + + pub fn consume(&mut self, dv: &mut Decoder) -> bool { + let amount = min(self.remaining, dv.remaining()); + let _ = dv.decode(amount); + self.remaining -= amount; + self.remaining == 0 + } +} + +#[cfg(test)] +mod tests { + use super::{ + Decoder, IncrementalDecoderBuffer, IncrementalDecoderIgnore, IncrementalDecoderUint, + }; + use crate::codec::Encoder; + + #[test] + fn buffer_incremental() { + let b = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut dec = IncrementalDecoderBuffer::new(b.len()); + let mut i = 0; + while i < b.len() { + // Feed in b in increasing-sized chunks. + let incr = if i < b.len() / 2 { i + 1 } else { b.len() - i }; + let mut dv = Decoder::from(&b[i..i + incr]); + i += incr; + match dec.consume(&mut dv) { + None => { + assert!(i < b.len()); + } + Some(res) => { + assert_eq!(i, b.len()); + assert_eq!(res, b); + } + } + } + } + + struct UintTestCase { + b: String, + v: u64, + } + + impl UintTestCase { + pub fn run(&self) { + eprintln!( + "IncrementalDecoderUint decoder with {:?} ; expect {:?}", + self.b, self.v + ); + + let decoder = IncrementalDecoderUint::default(); + let mut db = Encoder::from_hex(&self.b); + // Add padding so that we can verify that the reader doesn't over-consume. + db.encode_byte(0xff); + + for tail in 1..db.len() { + let split = db.len() - tail; + let mut dv = Decoder::from(&db[0..split]); + eprintln!(" split at {}: {:?}", split, dv); + + // Clone the basic decoder for each iteration of the loop. + let mut dec = decoder.clone(); + let mut res = None; + while dv.remaining() > 0 { + res = dec.consume(&mut dv); + } + assert!(dec.min_remaining() < tail); + + if tail > 1 { + assert_eq!(res, None); + assert!(dec.min_remaining() > 0); + let mut dv = Decoder::from(&db[split..]); + eprintln!(" split remainder {}: {:?}", split, dv); + res = dec.consume(&mut dv); + assert_eq!(dv.remaining(), 1); + } + + assert_eq!(dec.min_remaining(), 0); + assert_eq!(res.unwrap(), self.v); + } + } + } + + macro_rules! uint_tc { + [$( $b:expr => $v:expr ),+ $(,)?] => { + vec![ $( UintTestCase { b: String::from($b), v: $v, } ),+] + }; + } + + #[test] + fn varint() { + for c in uint_tc![ + "00" => 0, + "01" => 1, + "3f" => 63, + "4040" => 64, + "7fff" => 16383, + "80004000" => 16384, + "bfffffff" => (1 << 30) - 1, + "c000000040000000" => 1 << 30, + "ffffffffffffffff" => (1 << 62) - 1, + ] { + c.run(); + } + } + + #[test] + fn zero_len() { + let enc = Encoder::from_hex("ff"); + let mut dec = Decoder::new(&enc); + let mut incr = IncrementalDecoderBuffer::new(0); + assert_eq!(incr.consume(&mut dec), Some(Vec::new())); + assert_eq!(dec.remaining(), enc.len()); + } + + #[test] + fn ignore() { + let db = Encoder::from_hex("12345678ff"); + + let decoder = IncrementalDecoderIgnore::new(4); + + for tail in 1..db.len() { + let split = db.len() - tail; + let mut dv = Decoder::from(&db[0..split]); + eprintln!(" split at {}: {:?}", split, dv); + + // Clone the basic decoder for each iteration of the loop. + let mut dec = decoder.clone(); + let mut res = dec.consume(&mut dv); + assert_eq!(dv.remaining(), 0); + assert!(dec.min_remaining() < tail); + + if tail > 1 { + assert_eq!(res, false); + assert!(dec.min_remaining() > 0); + let mut dv = Decoder::from(&db[split..]); + eprintln!(" split remainder {}: {:?}", split, dv); + res = dec.consume(&mut dv); + assert_eq!(dv.remaining(), 1); + } + + assert_eq!(dec.min_remaining(), 0); + assert_eq!(res, true); + } + } +} diff --git a/third_party/rust/neqo-common/src/lib.rs b/third_party/rust/neqo-common/src/lib.rs new file mode 100644 index 0000000000..ba049e9532 --- /dev/null +++ b/third_party/rust/neqo-common/src/lib.rs @@ -0,0 +1,97 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![cfg_attr(feature = "deny-warnings", deny(warnings))] +#![warn(clippy::pedantic)] + +mod codec; +mod datagram; +pub mod event; +mod incrdecoder; +pub mod log; +pub mod qlog; +pub mod timer; + +pub use self::codec::{Decoder, Encoder}; +pub use self::datagram::Datagram; +pub use self::incrdecoder::{ + IncrementalDecoderBuffer, IncrementalDecoderIgnore, IncrementalDecoderUint, +}; + +#[macro_use] +extern crate lazy_static; + +#[must_use] +pub fn hex(buf: impl AsRef<[u8]>) -> String { + let mut ret = String::with_capacity(buf.as_ref().len() * 2); + for b in buf.as_ref() { + ret.push_str(&format!("{:02x}", b)); + } + ret +} + +#[must_use] +pub fn hex_snip_middle(buf: impl AsRef<[u8]>) -> String { + const SHOW_LEN: usize = 8; + let buf = buf.as_ref(); + if buf.len() <= SHOW_LEN * 2 { + hex_with_len(buf) + } else { + let mut ret = String::with_capacity(SHOW_LEN * 2 + 16); + ret.push_str(&format!("[{}]: ", buf.len())); + for b in &buf[..SHOW_LEN] { + ret.push_str(&format!("{:02x}", b)); + } + ret.push_str(".."); + for b in &buf[buf.len() - SHOW_LEN..] { + ret.push_str(&format!("{:02x}", b)); + } + ret + } +} + +#[must_use] +pub fn hex_with_len(buf: impl AsRef<[u8]>) -> String { + let buf = buf.as_ref(); + let mut ret = String::with_capacity(10 + buf.len() * 2); + ret.push_str(&format!("[{}]: ", buf.len())); + for b in buf { + ret.push_str(&format!("{:02x}", b)); + } + ret +} + +#[must_use] +pub const fn const_max(a: usize, b: usize) -> usize { + [a, b][(a < b) as usize] +} +#[must_use] +pub const fn const_min(a: usize, b: usize) -> usize { + [a, b][(a >= b) as usize] +} + +#[derive(Debug, PartialEq, Copy, Clone)] +/// Client or Server. +pub enum Role { + Client, + Server, +} + +impl Role { + #[must_use] + pub fn remote(self) -> Self { + match self { + Self::Client => Self::Server, + Self::Server => Self::Client, + } + } +} + +impl ::std::fmt::Display for Role { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{:?}", self) + } +} diff --git a/third_party/rust/neqo-common/src/log.rs b/third_party/rust/neqo-common/src/log.rs new file mode 100644 index 0000000000..0573cb45e9 --- /dev/null +++ b/third_party/rust/neqo-common/src/log.rs @@ -0,0 +1,98 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::io::Write; +use std::sync::Once; +use std::time::Instant; + +#[macro_export] +macro_rules! do_log { + (target: $target:expr, $lvl:expr, $($arg:tt)+) => ({ + let lvl = $lvl; + if lvl <= ::log::max_level() { + ::log::__private_api_log( + ::log::__log_format_args!($($arg)+), + lvl, + &($target, ::log::__log_module_path!(), ::log::__log_file!(), ::log::__log_line!()), + ); + } + }); + ($lvl:expr, $($arg:tt)+) => ($crate::do_log!(target: ::log::__log_module_path!(), $lvl, $($arg)+)) +} + +#[macro_export] +macro_rules! log_subject { + ($lvl:expr, $subject:expr) => {{ + if $lvl <= ::log::max_level() { + format!("{}", $subject) + } else { + String::new() + } + }}; +} + +use env_logger::Builder; + +static INIT_ONCE: Once = Once::new(); + +lazy_static! { + static ref START_TIME: Instant = Instant::now(); +} + +pub fn init() { + INIT_ONCE.call_once(|| { + let mut builder = Builder::from_env("RUST_LOG"); + builder.format(|buf, record| { + let elapsed = START_TIME.elapsed(); + writeln!( + buf, + "{}s{:3}ms {} {}", + elapsed.as_secs(), + elapsed.as_millis() % 1000, + record.level(), + record.args() + ) + }); + if let Err(e) = builder.try_init() { + do_log!(::log::Level::Info, "Logging initialization error {:?}", e); + } else { + do_log!(::log::Level::Info, "Logging initialized"); + } + }); +} + +#[macro_export] +macro_rules! log_invoke { + ($lvl:expr, $ctx:expr, $($arg:tt)*) => ( { + ::neqo_common::log::init(); + ::neqo_common::do_log!($lvl, "[{}] {}", $ctx, format!($($arg)*)); + } ) +} +#[macro_export] +macro_rules! qerror { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Error, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Error, $($arg)*); } ); +} +#[macro_export] +macro_rules! qwarn { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Warn, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Warn, $($arg)*); } ); +} +#[macro_export] +macro_rules! qinfo { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Info, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Info, $($arg)*); } ); +} +#[macro_export] +macro_rules! qdebug { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Debug, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Debug, $($arg)*); } ); +} +#[macro_export] +macro_rules! qtrace { + ([$ctx:expr], $($arg:tt)*) => (::neqo_common::log_invoke!(::log::Level::Trace, $ctx, $($arg)*);); + ($($arg:tt)*) => ( { ::neqo_common::log::init(); ::neqo_common::do_log!(::log::Level::Trace, $($arg)*); } ); +} diff --git a/third_party/rust/neqo-common/src/qlog.rs b/third_party/rust/neqo-common/src/qlog.rs new file mode 100644 index 0000000000..8405d091d6 --- /dev/null +++ b/third_party/rust/neqo-common/src/qlog.rs @@ -0,0 +1,146 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::cell::RefCell; +use std::fmt; +use std::path::{Path, PathBuf}; +use std::rc::Rc; +use std::time::SystemTime; + +use chrono::{DateTime, Utc}; +use qlog::{ + self, CommonFields, Configuration, QlogStreamer, TimeUnits, Trace, VantagePoint, + VantagePointType, +}; + +use crate::Role; + +#[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone)] +pub struct NeqoQlog { + inner: Rc<RefCell<Option<NeqoQlogShared>>>, +} + +pub struct NeqoQlogShared { + qlog_path: PathBuf, + streamer: QlogStreamer, +} + +impl NeqoQlog { + /// Create an enabled `NeqoQlog` configuration. + /// # Errors + /// + /// Will return `qlog::Error` if cannot write to the new log. + pub fn enabled( + mut streamer: QlogStreamer, + qlog_path: impl AsRef<Path>, + ) -> Result<Self, qlog::Error> { + streamer.start_log()?; + + Ok(Self { + inner: Rc::new(RefCell::new(Some(NeqoQlogShared { + streamer, + qlog_path: qlog_path.as_ref().to_owned(), + }))), + }) + } + + /// Create a disabled `NeqoQlog` configuration. + #[must_use] + pub fn disabled() -> Self { + Self { + inner: Rc::new(RefCell::new(None)), + } + } + + /// If logging enabled, closure may generate an event to be logged. + pub fn add_event<F>(&mut self, f: F) + where + F: FnOnce() -> Option<qlog::event::Event>, + { + self.add_event_with_stream(|s| { + if let Some(evt) = f() { + s.add_event(evt)?; + } + Ok(()) + }); + } + + /// If logging enabled, closure is given the Qlog stream to write events and + /// frames to. + pub fn add_event_with_stream<F>(&mut self, f: F) + where + F: FnOnce(&mut QlogStreamer) -> Result<(), qlog::Error>, + { + if let Some(inner) = self.inner.borrow_mut().as_mut() { + if let Err(e) = f(&mut inner.streamer) { + crate::do_log!( + ::log::Level::Error, + "Qlog event generation failed with error {}; closing qlog.", + e + ); + *self.inner.borrow_mut() = None; + } + } + } +} + +impl Default for NeqoQlog { + fn default() -> Self { + Self::disabled() + } +} + +impl fmt::Debug for NeqoQlogShared { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NeqoQlog writing to {}", self.qlog_path.display()) + } +} + +impl Drop for NeqoQlogShared { + fn drop(&mut self) { + if let Err(e) = self.streamer.finish_log() { + crate::do_log!(::log::Level::Error, "Error dropping NeqoQlog: {}", e) + } + } +} + +#[must_use] +pub fn new_trace(role: Role) -> qlog::Trace { + Trace { + vantage_point: VantagePoint { + name: Some(format!("neqo-{}", role)), + ty: match role { + Role::Client => VantagePointType::Client, + Role::Server => VantagePointType::Server, + }, + flow: None, + }, + title: Some(format!("neqo-{} trace", role)), + description: Some("Example qlog trace description".to_string()), + configuration: Some(Configuration { + time_offset: Some("0".into()), + time_units: Some(TimeUnits::Us), + original_uris: None, + }), + common_fields: Some(CommonFields { + group_id: None, + protocol_type: None, + reference_time: Some({ + let system_time = SystemTime::now(); + let datetime: DateTime<Utc> = system_time.into(); + datetime.to_rfc3339() + }), + }), + event_fields: vec![ + "relative_time".to_string(), + "category".to_string(), + "event".to_string(), + "data".to_string(), + ], + events: Vec::new(), + } +} diff --git a/third_party/rust/neqo-common/src/timer.rs b/third_party/rust/neqo-common/src/timer.rs new file mode 100644 index 0000000000..4ddaf94fa6 --- /dev/null +++ b/third_party/rust/neqo-common/src/timer.rs @@ -0,0 +1,386 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::convert::TryFrom; +use std::mem; +use std::time::{Duration, Instant}; + +/// Internal structure for a timer item. +struct TimerItem<T> { + time: Instant, + item: T, +} + +impl<T> TimerItem<T> { + fn time(ti: &Self) -> Instant { + ti.time + } +} + +/// A timer queue. +/// This uses a classic timer wheel arrangement, with some characteristics that might be considered peculiar. +/// Each slot in the wheel is sorted (complexity O(N) insertions, but O(logN) to find cut points). +/// Time is relative, the wheel has an origin time and it is unable to represent times that are more than +/// `granularity * capacity` past that time. +pub struct Timer<T> { + items: Vec<Vec<TimerItem<T>>>, + now: Instant, + granularity: Duration, + cursor: usize, +} + +impl<T> Timer<T> { + /// Construct a new wheel at the given granularity, starting at the given time. + pub fn new(now: Instant, granularity: Duration, capacity: usize) -> Self { + assert!(u32::try_from(capacity).is_ok()); + assert!(granularity.as_nanos() > 0); + let mut items = Vec::with_capacity(capacity); + items.resize_with(capacity, Default::default); + Self { + items, + now, + granularity, + cursor: 0, + } + } + + /// Return a reference to the time of the next entry. + #[must_use] + pub fn next_time(&self) -> Option<Instant> { + for i in 0..self.items.len() { + let idx = self.bucket(i); + if let Some(t) = self.items[idx].first() { + return Some(t.time); + } + } + None + } + + /// Get the full span of time that this can cover. + /// Two timers cannot be more than this far apart. + /// In practice, this value is less by one amount of the timer granularity. + #[inline] + #[allow(clippy::cast_possible_truncation)] // guarded by assertion + #[must_use] + pub fn span(&self) -> Duration { + self.granularity * (self.items.len() as u32) + } + + /// For the given `time`, get the number of whole buckets in the future that is. + #[inline] + #[allow(clippy::cast_possible_truncation)] // guarded by assertion + fn delta(&self, time: Instant) -> usize { + // This really should use Instant::div_duration(), but it can't yet. + ((time - self.now).as_nanos() / self.granularity.as_nanos()) as usize + } + + #[inline] + fn time_bucket(&self, time: Instant) -> usize { + self.bucket(self.delta(time)) + } + + #[inline] + fn bucket(&self, delta: usize) -> usize { + debug_assert!(delta < self.items.len()); + (self.cursor + delta) % self.items.len() + } + + /// Slide forward in time by `n * self.granularity`. + #[allow(clippy::unknown_clippy_lints)] // Until we require rust 1.45. + #[allow(clippy::cast_possible_truncation, clippy::reversed_empty_ranges)] + // cast_possible_truncation is ok because we have an assertion guard. + // reversed_empty_ranges is to avoid different types on the if/else. + fn tick(&mut self, n: usize) { + let new = self.bucket(n); + let iter = if new < self.cursor { + (self.cursor..self.items.len()).chain(0..new) + } else { + (self.cursor..new).chain(0..0) + }; + for i in iter { + assert!(self.items[i].is_empty()); + } + self.now += self.granularity * (n as u32); + self.cursor = new; + } + + /// Asserts if the time given is in the past or too far in the future. + pub fn add(&mut self, time: Instant, item: T) { + assert!(time >= self.now); + // Skip forward quickly if there is too large a gap. + let short_span = self.span() - self.granularity; + if time >= (self.now + self.span() + short_span) { + // Assert that there aren't any items. + for i in &self.items { + assert!(i.is_empty()); + } + self.now = time - short_span; + self.cursor = 0; + } + + // Adjust time forward the minimum amount necessary. + let mut d = self.delta(time); + if d >= self.items.len() { + self.tick(1 + d - self.items.len()); + d = self.items.len() - 1; + } + + let bucket = self.bucket(d); + let ins = match self.items[bucket].binary_search_by_key(&time, TimerItem::time) { + Ok(j) | Err(j) => j, + }; + self.items[bucket].insert(ins, TimerItem { time, item }); + } + + /// Given knowledge of the time an item was added, remove it. + /// This requires use of a predicate that identifies matching items. + pub fn remove<F>(&mut self, time: Instant, mut selector: F) -> Option<T> + where + F: FnMut(&T) -> bool, + { + if time < self.now { + return None; + } + if time > self.now + self.span() { + return None; + } + let bucket = self.time_bucket(time); + let start_index = match self.items[bucket].binary_search_by_key(&time, TimerItem::time) { + Ok(idx) => idx, + Err(_) => return None, + }; + // start_index is just one of potentially many items with the same time. + // Search backwards for a match, ... + for i in (0..=start_index).rev() { + if self.items[bucket][i].time != time { + break; + } + if selector(&self.items[bucket][i].item) { + return Some(self.items[bucket].remove(i).item); + } + } + // ... then forwards. + for i in (start_index + 1)..self.items[bucket].len() { + if self.items[bucket][i].time != time { + break; + } + if selector(&self.items[bucket][i].item) { + return Some(self.items[bucket].remove(i).item); + } + } + None + } + + /// Take the next item, unless there are no items with + /// a timeout in the past relative to `until`. + pub fn take_next(&mut self, until: Instant) -> Option<T> { + for i in 0..self.items.len() { + let idx = self.bucket(i); + if !self.items[idx].is_empty() && self.items[idx][0].time <= until { + return Some(self.items[idx].remove(0).item); + } + } + None + } + + /// Create an iterator that takes all items until the given time. + /// Note: Items might be removed even if the iterator is not fully exhausted. + pub fn take_until(&mut self, until: Instant) -> impl Iterator<Item = T> { + let get_item = move |x: TimerItem<T>| x.item; + if until >= self.now + self.span() { + // Drain everything, so a clean sweep. + let mut empty_items = Vec::with_capacity(self.items.len()); + empty_items.resize_with(self.items.len(), Vec::default); + let mut items = mem::replace(&mut self.items, empty_items); + self.now = until; + self.cursor = 0; + + let tail = items.split_off(self.cursor); + return tail.into_iter().chain(items).flatten().map(get_item); + } + + // Only returning a partial span, so do it bucket at a time. + let delta = self.delta(until); + let mut buckets = Vec::with_capacity(delta + 1); + + // First, the whole buckets. + for i in 0..delta { + let idx = self.bucket(i); + buckets.push(mem::take(&mut self.items[idx])); + } + self.tick(delta); + + // Now we need to split the last bucket, because there might be + // some items with `item.time > until`. + let bucket = &mut self.items[self.cursor]; + let last_idx = match bucket.binary_search_by_key(&until, TimerItem::time) { + Ok(mut m) => { + // If there are multiple values, the search will hit any of them. + // Make sure to get them all. + while m < bucket.len() && bucket[m].time == until { + m += 1; + } + m + } + Err(ins) => ins, + }; + let tail = bucket.split_off(last_idx); + buckets.push(mem::replace(bucket, tail)); + // This tomfoolery with the empty vector ensures that + // the returned type here matches the one above precisely + // without having to invoke the `either` crate. + buckets.into_iter().chain(vec![]).flatten().map(get_item) + } +} + +#[cfg(test)] +mod test { + use super::{Duration, Instant, Timer}; + use lazy_static::lazy_static; + + lazy_static! { + static ref NOW: Instant = Instant::now(); + } + + const GRANULARITY: Duration = Duration::from_millis(10); + const CAPACITY: usize = 10; + #[test] + fn create() { + let t: Timer<()> = Timer::new(*NOW, GRANULARITY, CAPACITY); + assert_eq!(t.span(), Duration::from_millis(100)); + assert_eq!(None, t.next_time()); + } + + #[test] + fn immediate_entry() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + t.add(*NOW, 12); + assert_eq!(*NOW, t.next_time().expect("should have an entry")); + let values: Vec<_> = t.take_until(*NOW).collect(); + assert_eq!(vec![12], values); + } + + #[test] + fn same_time() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let v1 = 12; + let v2 = 13; + t.add(*NOW, v1); + t.add(*NOW, v2); + assert_eq!(*NOW, t.next_time().expect("should have an entry")); + let values: Vec<_> = t.take_until(*NOW).collect(); + assert!(values.contains(&v1)); + assert!(values.contains(&v2)); + } + + #[test] + fn add() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let near_future = *NOW + Duration::from_millis(17); + let v = 9; + t.add(near_future, v); + assert_eq!(near_future, t.next_time().expect("should return a value")); + let values: Vec<_> = t + .take_until(near_future - Duration::from_millis(1)) + .collect(); + assert!(values.is_empty()); + let values: Vec<_> = t + .take_until(near_future + Duration::from_millis(1)) + .collect(); + assert!(values.contains(&v)); + } + + #[test] + fn add_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let future = *NOW + Duration::from_millis(117); + let v = 9; + t.add(future, v); + assert_eq!(future, t.next_time().expect("should return a value")); + let values: Vec<_> = t.take_until(future).collect(); + assert!(values.contains(&v)); + } + + #[test] + fn add_far_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let far_future = *NOW + Duration::from_millis(892); + let v = 9; + t.add(far_future, v); + assert_eq!(far_future, t.next_time().expect("should return a value")); + let values: Vec<_> = t.take_until(far_future).collect(); + assert!(values.contains(&v)); + } + + const TIMES: &[Duration] = &[ + Duration::from_millis(40), + Duration::from_millis(91), + Duration::from_millis(6), + Duration::from_millis(3), + Duration::from_millis(22), + Duration::from_millis(40), + ]; + + fn with_times() -> Timer<usize> { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + for (i, time) in TIMES.iter().enumerate() { + t.add(*NOW + *time, i); + } + assert_eq!( + *NOW + *TIMES.iter().min().unwrap(), + t.next_time().expect("should have a time") + ); + t + } + + #[test] + fn multiple_values() { + let mut t = with_times(); + let values: Vec<_> = t.take_until(*NOW + *TIMES.iter().max().unwrap()).collect(); + for i in 0..TIMES.len() { + assert!(values.contains(&i)); + } + } + + #[test] + fn take_far_future() { + let mut t = with_times(); + let values: Vec<_> = t.take_until(*NOW + Duration::from_secs(100)).collect(); + for i in 0..TIMES.len() { + assert!(values.contains(&i)); + } + } + + #[test] + fn remove_each() { + let mut t = with_times(); + for (i, time) in TIMES.iter().enumerate() { + assert_eq!(Some(i), t.remove(*NOW + *time, |&x| x == i)); + } + assert_eq!(None, t.next_time()); + } + + #[test] + fn remove_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let future = *NOW + Duration::from_millis(117); + let v = 9; + t.add(future, v); + + assert_eq!(Some(v), t.remove(future, |candidate| *candidate == v)); + } + + #[test] + fn remove_too_far_future() { + let mut t = Timer::new(*NOW, GRANULARITY, CAPACITY); + let future = *NOW + Duration::from_millis(117); + let too_far_future = *NOW + t.span() + Duration::from_millis(117); + let v = 9; + t.add(future, v); + + assert_eq!(None, t.remove(too_far_future, |candidate| *candidate == v)); + } +} |