summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-transport/tests/common/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/neqo-transport/tests/common/mod.rs')
-rw-r--r--third_party/rust/neqo-transport/tests/common/mod.rs232
1 files changed, 232 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/tests/common/mod.rs b/third_party/rust/neqo-transport/tests/common/mod.rs
new file mode 100644
index 0000000000..ea36f59203
--- /dev/null
+++ b/third_party/rust/neqo-transport/tests/common/mod.rs
@@ -0,0 +1,232 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![cfg_attr(feature = "deny-warnings", deny(warnings))]
+#![warn(clippy::pedantic)]
+#![allow(unused)]
+
+use neqo_common::{event::Provider, hex_with_len, qtrace, Datagram, Decoder, Role};
+use neqo_crypto::{
+ constants::{TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3},
+ hkdf,
+ hp::HpKey,
+ Aead, AllowZeroRtt, AuthenticationStatus, ResumptionToken,
+};
+use neqo_transport::{
+ server::{ActiveConnectionRef, Server, ValidateAddress},
+ Connection, ConnectionEvent, ConnectionParameters, State,
+};
+use test_fixture::{self, default_client, now, CountingConnectionIdGenerator};
+
+use std::cell::RefCell;
+use std::convert::TryFrom;
+use std::mem;
+use std::ops::Range;
+use std::rc::Rc;
+
+/// Create a server. This is different than the one in the fixture, which is a single connection.
+pub fn new_server(params: ConnectionParameters) -> Server {
+ Server::new(
+ now(),
+ test_fixture::DEFAULT_KEYS,
+ test_fixture::DEFAULT_ALPN,
+ test_fixture::anti_replay(),
+ Box::new(AllowZeroRtt {}),
+ Rc::new(RefCell::new(CountingConnectionIdGenerator::default())),
+ params,
+ )
+ .expect("should create a server")
+}
+
+/// Create a server. This is different than the one in the fixture, which is a single connection.
+pub fn default_server() -> Server {
+ new_server(ConnectionParameters::default())
+}
+
+// Check that there is at least one connection. Returns a ref to the first confirmed connection.
+pub fn connected_server(server: &mut Server) -> ActiveConnectionRef {
+ let server_connections = server.active_connections();
+ // Find confirmed connections. There should only be one.
+ let mut confirmed = server_connections
+ .iter()
+ .filter(|c: &&ActiveConnectionRef| *c.borrow().state() == State::Confirmed);
+ let c = confirmed.next().expect("one confirmed");
+ assert!(confirmed.next().is_none(), "only one confirmed");
+ c.clone()
+}
+
+/// Connect. This returns a reference to the server connection.
+pub fn connect(client: &mut Connection, server: &mut Server) -> ActiveConnectionRef {
+ server.set_validation(ValidateAddress::Never);
+
+ assert_eq!(*client.state(), State::Init);
+ let dgram = client.process(None, now()).dgram(); // ClientHello
+ assert!(dgram.is_some());
+ let dgram = server.process(dgram, now()).dgram(); // ServerHello...
+ assert!(dgram.is_some());
+
+ // Ingest the server Certificate.
+ let dgram = client.process(dgram, now()).dgram();
+ assert!(dgram.is_some()); // This should just be an ACK.
+ let dgram = server.process(dgram, now()).dgram();
+ assert!(dgram.is_none()); // So the server should have nothing to say.
+
+ // Now mark the server as authenticated.
+ client.authenticated(AuthenticationStatus::Ok, now());
+ let dgram = client.process(None, now()).dgram();
+ assert!(dgram.is_some());
+ assert_eq!(*client.state(), State::Connected);
+ let dgram = server.process(dgram, now()).dgram();
+ assert!(dgram.is_some()); // ACK + HANDSHAKE_DONE + NST
+
+ // Have the client process the HANDSHAKE_DONE.
+ let dgram = client.process(dgram, now()).dgram();
+ assert!(dgram.is_none());
+ assert_eq!(*client.state(), State::Confirmed);
+
+ connected_server(server)
+}
+
+// Decode the header of a client Initial packet, returning three values:
+// * the entire header short of the packet number,
+// * just the DCID,
+// * just the SCID, and
+// * the protected payload including the packet number.
+// Any token is thrown away.
+#[must_use]
+pub fn decode_initial_header(dgram: &Datagram, role: Role) -> (&[u8], &[u8], &[u8], &[u8]) {
+ let mut dec = Decoder::new(&dgram[..]);
+ let type_and_ver = dec.decode(5).unwrap().to_vec();
+ // The client sets the QUIC bit, the server might not.
+ match role {
+ Role::Client => assert_eq!(type_and_ver[0] & 0xf0, 0xc0),
+ Role::Server => assert_eq!(type_and_ver[0] & 0xb0, 0x80),
+ }
+ let dest_cid = dec.decode_vec(1).unwrap();
+ let src_cid = dec.decode_vec(1).unwrap();
+ dec.skip_vvec(); // Ignore any the token.
+
+ // Need to read of the length separately so that we can find the packet number.
+ let payload_len = usize::try_from(dec.decode_varint().unwrap()).unwrap();
+ let pn_offset = dgram.len() - dec.remaining();
+ (
+ &dgram[..pn_offset],
+ dest_cid,
+ src_cid,
+ dec.decode(payload_len).unwrap(),
+ )
+}
+
+/// Generate an AEAD and header protection object for a client Initial.
+/// Note that this works for QUIC version 1 only.
+#[must_use]
+pub fn initial_aead_and_hp(dcid: &[u8], role: Role) -> (Aead, HpKey) {
+ const INITIAL_SALT: &[u8] = &[
+ 0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c,
+ 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
+ ];
+ let initial_secret = hkdf::extract(
+ TLS_VERSION_1_3,
+ TLS_AES_128_GCM_SHA256,
+ Some(
+ hkdf::import_key(TLS_VERSION_1_3, INITIAL_SALT)
+ .as_ref()
+ .unwrap(),
+ ),
+ hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(),
+ )
+ .unwrap();
+
+ let secret = hkdf::expand_label(
+ TLS_VERSION_1_3,
+ TLS_AES_128_GCM_SHA256,
+ &initial_secret,
+ &[],
+ match role {
+ Role::Client => "client in",
+ Role::Server => "server in",
+ },
+ )
+ .unwrap();
+ (
+ Aead::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic ").unwrap(),
+ HpKey::extract(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic hp").unwrap(),
+ )
+}
+
+// Remove header protection, returning the unmasked header and the packet number.
+#[must_use]
+pub fn remove_header_protection(hp: &HpKey, header: &[u8], payload: &[u8]) -> (Vec<u8>, u64) {
+ // Make a copy of the header that can be modified.
+ let mut fixed_header = header.to_vec();
+ let pn_offset = header.len();
+ // Save 4 extra in case the packet number is that long.
+ fixed_header.extend_from_slice(&payload[..4]);
+
+ // Sample for masking and apply the mask.
+ let mask = hp.mask(&payload[4..20]).unwrap();
+ fixed_header[0] ^= mask[0] & 0xf;
+ let pn_len = 1 + usize::from(fixed_header[0] & 0x3);
+ for i in 0..pn_len {
+ fixed_header[pn_offset + i] ^= mask[1 + i];
+ }
+ // Trim down to size.
+ fixed_header.truncate(pn_offset + pn_len);
+ // The packet number should be 1.
+ let pn = Decoder::new(&fixed_header[pn_offset..])
+ .decode_uint(pn_len)
+ .unwrap();
+
+ (fixed_header, pn)
+}
+
+pub fn apply_header_protection(hp: &HpKey, packet: &mut [u8], pn_bytes: Range<usize>) {
+ let sample_start = pn_bytes.start + 4;
+ let sample_end = sample_start + 16;
+ let mask = hp.mask(&packet[sample_start..sample_end]).unwrap();
+ qtrace!(
+ "sample={} mask={}",
+ hex_with_len(&packet[sample_start..sample_end]),
+ hex_with_len(&mask)
+ );
+ packet[0] ^= mask[0] & 0xf;
+ for i in 0..(pn_bytes.end - pn_bytes.start) {
+ packet[pn_bytes.start + i] ^= mask[1 + i];
+ }
+}
+
+/// Scrub through client events to find a resumption token.
+pub fn find_ticket(client: &mut Connection) -> ResumptionToken {
+ client
+ .events()
+ .find_map(|e| {
+ if let ConnectionEvent::ResumptionToken(token) = e {
+ Some(token)
+ } else {
+ None
+ }
+ })
+ .unwrap()
+}
+
+/// Connect to the server and have it generate a ticket.
+pub fn generate_ticket(server: &mut Server) -> ResumptionToken {
+ let mut client = default_client();
+ let mut server_conn = connect(&mut client, server);
+
+ server_conn.borrow_mut().send_ticket(now(), &[]).unwrap();
+ let dgram = server.process(None, now()).dgram();
+ client.process_input(dgram.unwrap(), now()); // Consume ticket, ignore output.
+ let ticket = find_ticket(&mut client);
+
+ // Have the client close the connection and then let the server clean up.
+ client.close(now(), 0, "got a ticket");
+ let dgram = client.process_output(now()).dgram();
+ mem::drop(server.process(dgram, now()));
+ // Calling active_connections clears the set of active connections.
+ assert_eq!(server.active_connections().len(), 1);
+ ticket
+}