diff options
Diffstat (limited to 'third_party/rust/neqo-transport/tests')
-rw-r--r-- | third_party/rust/neqo-transport/tests/common/mod.rs | 108 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/tests/connection.rs | 41 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/tests/network.rs | 10 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/tests/retry.rs | 20 | ||||
-rw-r--r-- | third_party/rust/neqo-transport/tests/server.rs | 93 |
5 files changed, 122 insertions, 150 deletions
diff --git a/third_party/rust/neqo-transport/tests/common/mod.rs b/third_party/rust/neqo-transport/tests/common/mod.rs index e36e66f753..ecbbe1c3ce 100644 --- a/third_party/rust/neqo-transport/tests/common/mod.rs +++ b/third_party/rust/neqo-transport/tests/common/mod.rs @@ -84,114 +84,6 @@ pub fn connect(client: &mut Connection, server: &mut Server) -> ActiveConnection connected_server(server) } -// Decode the header of a client Initial packet, returning three values: -// * the entire header short of the packet number, -// * just the DCID, -// * just the SCID, and -// * the protected payload including the packet number. -// Any token is thrown away. -#[must_use] -pub fn decode_initial_header(dgram: &Datagram, role: Role) -> (&[u8], &[u8], &[u8], &[u8]) { - let mut dec = Decoder::new(&dgram[..]); - let type_and_ver = dec.decode(5).unwrap().to_vec(); - // The client sets the QUIC bit, the server might not. - match role { - Role::Client => assert_eq!(type_and_ver[0] & 0xf0, 0xc0), - Role::Server => assert_eq!(type_and_ver[0] & 0xb0, 0x80), - } - let dest_cid = dec.decode_vec(1).unwrap(); - let src_cid = dec.decode_vec(1).unwrap(); - dec.skip_vvec(); // Ignore any the token. - - // Need to read of the length separately so that we can find the packet number. - let payload_len = usize::try_from(dec.decode_varint().unwrap()).unwrap(); - let pn_offset = dgram.len() - dec.remaining(); - ( - &dgram[..pn_offset], - dest_cid, - src_cid, - dec.decode(payload_len).unwrap(), - ) -} - -/// Generate an AEAD and header protection object for a client Initial. -/// Note that this works for QUIC version 1 only. -#[must_use] -pub fn initial_aead_and_hp(dcid: &[u8], role: Role) -> (Aead, HpKey) { - const INITIAL_SALT: &[u8] = &[ - 0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, - 0xad, 0xcc, 0xbb, 0x7f, 0x0a, - ]; - let initial_secret = hkdf::extract( - TLS_VERSION_1_3, - TLS_AES_128_GCM_SHA256, - Some( - hkdf::import_key(TLS_VERSION_1_3, INITIAL_SALT) - .as_ref() - .unwrap(), - ), - hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(), - ) - .unwrap(); - - let secret = hkdf::expand_label( - TLS_VERSION_1_3, - TLS_AES_128_GCM_SHA256, - &initial_secret, - &[], - match role { - Role::Client => "client in", - Role::Server => "server in", - }, - ) - .unwrap(); - ( - Aead::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic ").unwrap(), - HpKey::extract(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic hp").unwrap(), - ) -} - -// Remove header protection, returning the unmasked header and the packet number. -#[must_use] -pub fn remove_header_protection(hp: &HpKey, header: &[u8], payload: &[u8]) -> (Vec<u8>, u64) { - // Make a copy of the header that can be modified. - let mut fixed_header = header.to_vec(); - let pn_offset = header.len(); - // Save 4 extra in case the packet number is that long. - fixed_header.extend_from_slice(&payload[..4]); - - // Sample for masking and apply the mask. - let mask = hp.mask(&payload[4..20]).unwrap(); - fixed_header[0] ^= mask[0] & 0xf; - let pn_len = 1 + usize::from(fixed_header[0] & 0x3); - for i in 0..pn_len { - fixed_header[pn_offset + i] ^= mask[1 + i]; - } - // Trim down to size. - fixed_header.truncate(pn_offset + pn_len); - // The packet number should be 1. - let pn = Decoder::new(&fixed_header[pn_offset..]) - .decode_uint(pn_len) - .unwrap(); - - (fixed_header, pn) -} - -pub fn apply_header_protection(hp: &HpKey, packet: &mut [u8], pn_bytes: Range<usize>) { - let sample_start = pn_bytes.start + 4; - let sample_end = sample_start + 16; - let mask = hp.mask(&packet[sample_start..sample_end]).unwrap(); - qtrace!( - "sample={} mask={}", - hex_with_len(&packet[sample_start..sample_end]), - hex_with_len(&mask) - ); - packet[0] ^= mask[0] & 0xf; - for i in 0..(pn_bytes.end - pn_bytes.start) { - packet[pn_bytes.start + i] ^= mask[1 + i]; - } -} - /// Scrub through client events to find a resumption token. pub fn find_ticket(client: &mut Connection) -> ResumptionToken { client diff --git a/third_party/rust/neqo-transport/tests/connection.rs b/third_party/rust/neqo-transport/tests/connection.rs index b8877b946d..3cc711f80b 100644 --- a/third_party/rust/neqo-transport/tests/connection.rs +++ b/third_party/rust/neqo-transport/tests/connection.rs @@ -6,12 +6,16 @@ mod common; -use common::{ - apply_header_protection, decode_initial_header, initial_aead_and_hp, remove_header_protection, -}; use neqo_common::{Datagram, Decoder, Encoder, Role}; -use neqo_transport::{ConnectionError, ConnectionParameters, Error, State, Version}; -use test_fixture::{default_client, default_server, new_client, now, split_datagram}; +use neqo_transport::{CloseReason, ConnectionParameters, Error, State, Version}; +use test_fixture::{ + default_client, default_server, + header_protection::{ + apply_header_protection, decode_initial_header, initial_aead_and_hp, + remove_header_protection, + }, + new_client, now, split_datagram, +}; #[test] fn connect() { @@ -58,8 +62,8 @@ fn truncate_long_packet() { /// Test that reordering parts of the server Initial doesn't change things. #[test] fn reorder_server_initial() { - // A simple ACK frame for a single packet with packet number 0. - const ACK_FRAME: &[u8] = &[0x02, 0x00, 0x00, 0x00, 0x00]; + // A simple ACK_ECN frame for a single packet with packet number 0 with a single ECT(0) mark. + const ACK_FRAME: &[u8] = &[0x03, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]; let mut client = new_client( ConnectionParameters::default().versions(Version::Version1, vec![Version::Version1]), @@ -68,12 +72,13 @@ fn reorder_server_initial() { let client_initial = client.process_output(now()); let (_, client_dcid, _, _) = - decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client); + decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client).unwrap(); let client_dcid = client_dcid.to_owned(); let server_packet = server.process(client_initial.as_dgram_ref(), now()).dgram(); let (server_initial, server_hs) = split_datagram(server_packet.as_ref().unwrap()); - let (protected_header, _, _, payload) = decode_initial_header(&server_initial, Role::Server); + let (protected_header, _, _, payload) = + decode_initial_header(&server_initial, Role::Server).unwrap(); // Now decrypt the packet. let (aead, hp) = initial_aead_and_hp(&client_dcid, Role::Server); @@ -130,7 +135,7 @@ fn reorder_server_initial() { fn set_payload(server_packet: &Option<Datagram>, client_dcid: &[u8], payload: &[u8]) -> Datagram { let (server_initial, _server_hs) = split_datagram(server_packet.as_ref().unwrap()); let (protected_header, _, _, orig_payload) = - decode_initial_header(&server_initial, Role::Server); + decode_initial_header(&server_initial, Role::Server).unwrap(); // Now decrypt the packet. let (aead, hp) = initial_aead_and_hp(client_dcid, Role::Server); @@ -168,14 +173,14 @@ fn packet_without_frames() { let client_initial = client.process_output(now()); let (_, client_dcid, _, _) = - decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client); + decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client).unwrap(); let server_packet = server.process(client_initial.as_dgram_ref(), now()).dgram(); let modified = set_payload(&server_packet, client_dcid, &[]); client.process_input(&modified, now()); assert_eq!( client.state(), - &State::Closed(ConnectionError::Transport(Error::ProtocolViolation)) + &State::Closed(CloseReason::Transport(Error::ProtocolViolation)) ); } @@ -189,7 +194,7 @@ fn packet_with_only_padding() { let client_initial = client.process_output(now()); let (_, client_dcid, _, _) = - decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client); + decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client).unwrap(); let server_packet = server.process(client_initial.as_dgram_ref(), now()).dgram(); let modified = set_payload(&server_packet, client_dcid, &[0]); @@ -208,7 +213,7 @@ fn overflow_crypto() { let client_initial = client.process_output(now()).dgram(); let (_, client_dcid, _, _) = - decode_initial_header(client_initial.as_ref().unwrap(), Role::Client); + decode_initial_header(client_initial.as_ref().unwrap(), Role::Client).unwrap(); let client_dcid = client_dcid.to_owned(); let server_packet = server.process(client_initial.as_ref(), now()).dgram(); @@ -217,7 +222,8 @@ fn overflow_crypto() { // Now decrypt the server packet to get AEAD and HP instances. // We won't be using the packet, but making new ones. let (aead, hp) = initial_aead_and_hp(&client_dcid, Role::Server); - let (_, server_dcid, server_scid, _) = decode_initial_header(&server_initial, Role::Server); + let (_, server_dcid, server_scid, _) = + decode_initial_header(&server_initial, Role::Server).unwrap(); // Send in 100 packets, each with 1000 bytes of crypto frame data each, // eventually this will overrun the buffer we keep for crypto data. @@ -260,10 +266,7 @@ fn overflow_crypto() { client.process_input(&dgram, now()); if let State::Closing { error, .. } = client.state() { assert!( - matches!( - error, - ConnectionError::Transport(Error::CryptoBufferExceeded), - ), + matches!(error, CloseReason::Transport(Error::CryptoBufferExceeded),), "the connection need to abort on crypto buffer" ); assert!(pn > 64, "at least 64000 bytes of data is buffered"); diff --git a/third_party/rust/neqo-transport/tests/network.rs b/third_party/rust/neqo-transport/tests/network.rs index 27e5a83cd6..68a835a436 100644 --- a/third_party/rust/neqo-transport/tests/network.rs +++ b/third_party/rust/neqo-transport/tests/network.rs @@ -6,7 +6,7 @@ use std::{ops::Range, time::Duration}; -use neqo_transport::{ConnectionError, ConnectionParameters, Error, State}; +use neqo_transport::{CloseReason, ConnectionParameters, Error, State}; use test_fixture::{ boxed, sim::{ @@ -48,10 +48,10 @@ simulate!( idle_timeout, [ ConnectionNode::default_client(boxed![ReachState::new(State::Closed( - ConnectionError::Transport(Error::IdleTimeout) + CloseReason::Transport(Error::IdleTimeout) ))]), ConnectionNode::default_server(boxed![ReachState::new(State::Closed( - ConnectionError::Transport(Error::IdleTimeout) + CloseReason::Transport(Error::IdleTimeout) ))]), ] ); @@ -62,7 +62,7 @@ simulate!( ConnectionNode::new_client( ConnectionParameters::default().idle_timeout(weeks(1000)), boxed![ReachState::new(State::Confirmed),], - boxed![ReachState::new(State::Closed(ConnectionError::Transport( + boxed![ReachState::new(State::Closed(CloseReason::Transport( Error::IdleTimeout )))] ), @@ -71,7 +71,7 @@ simulate!( ConnectionNode::new_server( ConnectionParameters::default().idle_timeout(weeks(1000)), boxed![ReachState::new(State::Confirmed),], - boxed![ReachState::new(State::Closed(ConnectionError::Transport( + boxed![ReachState::new(State::Closed(CloseReason::Transport( Error::IdleTimeout )))] ), diff --git a/third_party/rust/neqo-transport/tests/retry.rs b/third_party/rust/neqo-transport/tests/retry.rs index 36eff71e7b..3f95511c3e 100644 --- a/third_party/rust/neqo-transport/tests/retry.rs +++ b/third_party/rust/neqo-transport/tests/retry.rs @@ -14,14 +14,18 @@ use std::{ time::Duration, }; -use common::{ - apply_header_protection, connected_server, decode_initial_header, default_server, - generate_ticket, initial_aead_and_hp, remove_header_protection, -}; +use common::{connected_server, default_server, generate_ticket}; use neqo_common::{hex_with_len, qdebug, qtrace, Datagram, Encoder, Role}; use neqo_crypto::AuthenticationStatus; -use neqo_transport::{server::ValidateAddress, ConnectionError, Error, State, StreamType}; -use test_fixture::{assertions, datagram, default_client, now, split_datagram}; +use neqo_transport::{server::ValidateAddress, CloseReason, Error, State, StreamType}; +use test_fixture::{ + assertions, datagram, default_client, + header_protection::{ + apply_header_protection, decode_initial_header, initial_aead_and_hp, + remove_header_protection, + }, + now, split_datagram, +}; #[test] fn retry_basic() { @@ -400,7 +404,7 @@ fn mitm_retry() { // rewriting the header to remove the token, and then re-encrypting. let client_initial2 = client_initial2.unwrap(); let (protected_header, d_cid, s_cid, payload) = - decode_initial_header(&client_initial2, Role::Client); + decode_initial_header(&client_initial2, Role::Client).unwrap(); // Now we have enough information to make keys. let (aead, hp) = initial_aead_and_hp(d_cid, Role::Client); @@ -465,7 +469,7 @@ fn mitm_retry() { assert!(matches!( *client.state(), State::Closing { - error: ConnectionError::Transport(Error::ProtocolViolation), + error: CloseReason::Transport(Error::ProtocolViolation), .. } )); diff --git a/third_party/rust/neqo-transport/tests/server.rs b/third_party/rust/neqo-transport/tests/server.rs index 7388e0fee7..4740d26ded 100644 --- a/third_party/rust/neqo-transport/tests/server.rs +++ b/third_party/rust/neqo-transport/tests/server.rs @@ -8,21 +8,22 @@ mod common; use std::{cell::RefCell, mem, net::SocketAddr, rc::Rc, time::Duration}; -use common::{ - apply_header_protection, connect, connected_server, decode_initial_header, default_server, - find_ticket, generate_ticket, initial_aead_and_hp, new_server, remove_header_protection, -}; +use common::{connect, connected_server, default_server, find_ticket, generate_ticket, new_server}; use neqo_common::{qtrace, Datagram, Decoder, Encoder, Role}; use neqo_crypto::{ generate_ech_keys, AllowZeroRtt, AuthenticationStatus, ZeroRttCheckResult, ZeroRttChecker, }; use neqo_transport::{ server::{ActiveConnectionRef, Server, ValidateAddress}, - Connection, ConnectionError, ConnectionParameters, Error, Output, State, StreamType, Version, + CloseReason, Connection, ConnectionParameters, Error, Output, State, StreamType, Version, }; use test_fixture::{ - assertions, datagram, default_client, new_client, now, split_datagram, - CountingConnectionIdGenerator, + assertions, datagram, default_client, + header_protection::{ + apply_header_protection, decode_initial_header, initial_aead_and_hp, + remove_header_protection, + }, + new_client, now, split_datagram, CountingConnectionIdGenerator, }; /// Take a pair of connections in any state and complete the handshake. @@ -389,7 +390,7 @@ fn bad_client_initial() { let mut server = default_server(); let dgram = client.process(None, now()).dgram().expect("a datagram"); - let (header, d_cid, s_cid, payload) = decode_initial_header(&dgram, Role::Client); + let (header, d_cid, s_cid, payload) = decode_initial_header(&dgram, Role::Client).unwrap(); let (aead, hp) = initial_aead_and_hp(d_cid, Role::Client); let (fixed_header, pn) = remove_header_protection(&hp, header, payload); let payload = &payload[(fixed_header.len() - header.len())..]; @@ -462,13 +463,13 @@ fn bad_client_initial() { assert_ne!(delay, Duration::from_secs(0)); assert!(matches!( *client.state(), - State::Draining { error: ConnectionError::Transport(Error::PeerError(code)), .. } if code == Error::ProtocolViolation.code() + State::Draining { error: CloseReason::Transport(Error::PeerError(code)), .. } if code == Error::ProtocolViolation.code() )); for server in server.active_connections() { assert_eq!( *server.borrow().state(), - State::Closed(ConnectionError::Transport(Error::ProtocolViolation)) + State::Closed(CloseReason::Transport(Error::ProtocolViolation)) ); } @@ -478,6 +479,65 @@ fn bad_client_initial() { } #[test] +fn bad_client_initial_connection_close() { + let mut client = default_client(); + let mut server = default_server(); + + let dgram = client.process(None, now()).dgram().expect("a datagram"); + let (header, d_cid, s_cid, payload) = decode_initial_header(&dgram, Role::Client).unwrap(); + let (aead, hp) = initial_aead_and_hp(d_cid, Role::Client); + let (_, pn) = remove_header_protection(&hp, header, payload); + + let mut payload_enc = Encoder::with_capacity(1200); + payload_enc.encode(&[0x1c, 0x01, 0x00, 0x00]); // Add a CONNECTION_CLOSE frame. + + // Make a new header with a 1 byte packet number length. + let mut header_enc = Encoder::new(); + header_enc + .encode_byte(0xc0) // Initial with 1 byte packet number. + .encode_uint(4, Version::default().wire_version()) + .encode_vec(1, d_cid) + .encode_vec(1, s_cid) + .encode_vvec(&[]) + .encode_varint(u64::try_from(payload_enc.len() + aead.expansion() + 1).unwrap()) + .encode_byte(u8::try_from(pn).unwrap()); + + let mut ciphertext = header_enc.as_ref().to_vec(); + ciphertext.resize(header_enc.len() + payload_enc.len() + aead.expansion(), 0); + let v = aead + .encrypt( + pn, + header_enc.as_ref(), + payload_enc.as_ref(), + &mut ciphertext[header_enc.len()..], + ) + .unwrap(); + assert_eq!(header_enc.len() + v.len(), ciphertext.len()); + // Pad with zero to get up to 1200. + ciphertext.resize(1200, 0); + + apply_header_protection( + &hp, + &mut ciphertext, + (header_enc.len() - 1)..header_enc.len(), + ); + let bad_dgram = Datagram::new( + dgram.source(), + dgram.destination(), + dgram.tos(), + dgram.ttl(), + ciphertext, + ); + + // The server should ignore this and go to Draining. + let mut now = now(); + let response = server.process(Some(&bad_dgram), now); + now += response.callback(); + let response = server.process(None, now); + assert_eq!(response, Output::None); +} + +#[test] fn version_negotiation_ignored() { let mut server = default_server(); let mut client = default_client(); @@ -774,3 +834,16 @@ fn ech() { .ech_accepted() .unwrap()); } + +#[test] +fn has_active_connections() { + let mut server = default_server(); + let mut client = default_client(); + + assert!(!server.has_active_connections()); + + let initial = client.process(None, now()); + let _ = server.process(initial.as_dgram_ref(), now()).dgram(); + + assert!(server.has_active_connections()); +} |