// Licensed under the Apache License, Version 2.0 or the MIT license // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. use std::{cell::RefCell, collections::BTreeSet, mem, rc::Rc}; use neqo_common::{qtrace, Encoder, Header, MessageType, Role}; use neqo_qpack::{QPackDecoder, QPackEncoder}; use neqo_transport::{Connection, DatagramTracking, StreamId}; use super::{ExtendedConnectEvents, ExtendedConnectType, SessionCloseReason}; use crate::{ frames::{FrameReader, StreamReaderRecvStreamWrapper, WebTransportFrame}, recv_message::{RecvMessage, RecvMessageInfo}, send_message::SendMessage, CloseType, Error, HFrame, Http3StreamInfo, Http3StreamType, HttpRecvStream, HttpRecvStreamEvents, Priority, PriorityHandler, ReceiveOutput, RecvStream, RecvStreamEvents, Res, SendStream, SendStreamEvents, Stream, }; #[derive(Debug, PartialEq)] enum SessionState { Negotiating, Active, FinPending, Done, } impl SessionState { pub fn closing_state(&self) -> bool { matches!(self, Self::FinPending | Self::Done) } } #[derive(Debug)] pub(crate) struct WebTransportSession { control_stream_recv: Box, control_stream_send: Box, stream_event_listener: Rc>, session_id: StreamId, state: SessionState, frame_reader: FrameReader, events: Box, send_streams: BTreeSet, recv_streams: BTreeSet, role: Role, } impl ::std::fmt::Display for WebTransportSession { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { write!(f, "WebTransportSession session={}", self.session_id,) } } impl WebTransportSession { #[must_use] pub fn new( session_id: StreamId, events: Box, role: Role, qpack_encoder: Rc>, qpack_decoder: Rc>, ) -> Self { let stream_event_listener = Rc::new(RefCell::new(WebTransportSessionListener::default())); Self { control_stream_recv: Box::new(RecvMessage::new( &RecvMessageInfo { message_type: MessageType::Response, stream_type: Http3StreamType::ExtendedConnect, stream_id: session_id, header_frame_type_read: false, }, qpack_decoder, Box::new(stream_event_listener.clone()), None, PriorityHandler::new(false, Priority::default()), )), control_stream_send: Box::new(SendMessage::new( MessageType::Request, Http3StreamType::ExtendedConnect, session_id, qpack_encoder, Box::new(stream_event_listener.clone()), )), stream_event_listener, session_id, state: SessionState::Negotiating, frame_reader: FrameReader::new(), events, send_streams: BTreeSet::new(), recv_streams: BTreeSet::new(), role, } } /// # Panics /// /// This function is only called with `RecvStream` and `SendStream` that also implement /// the http specific functions and `http_stream()` will never return `None`. #[must_use] pub fn new_with_http_streams( session_id: StreamId, events: Box, role: Role, mut control_stream_recv: Box, mut control_stream_send: Box, ) -> Self { let stream_event_listener = Rc::new(RefCell::new(WebTransportSessionListener::default())); control_stream_recv .http_stream() .unwrap() .set_new_listener(Box::new(stream_event_listener.clone())); control_stream_send .http_stream() .unwrap() .set_new_listener(Box::new(stream_event_listener.clone())); Self { control_stream_recv, control_stream_send, stream_event_listener, session_id, state: SessionState::Active, frame_reader: FrameReader::new(), events, send_streams: BTreeSet::new(), recv_streams: BTreeSet::new(), role, } } /// # Errors /// /// The function can only fail if supplied headers are not valid http headers. /// /// # Panics /// /// `control_stream_send` implements the http specific functions and `http_stream()` /// will never return `None`. pub fn send_request(&mut self, headers: &[Header], conn: &mut Connection) -> Res<()> { self.control_stream_send .http_stream() .unwrap() .send_headers(headers, conn) } fn receive(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> { qtrace!([self], "receive control data"); let (out, _) = self.control_stream_recv.receive(conn)?; debug_assert!(out == ReceiveOutput::NoOutput); self.maybe_check_headers(); self.read_control_stream(conn)?; Ok((ReceiveOutput::NoOutput, self.state == SessionState::Done)) } fn header_unblocked(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> { let (out, _) = self .control_stream_recv .http_stream() .unwrap() .header_unblocked(conn)?; debug_assert!(out == ReceiveOutput::NoOutput); self.maybe_check_headers(); self.read_control_stream(conn)?; Ok((ReceiveOutput::NoOutput, self.state == SessionState::Done)) } fn maybe_update_priority(&mut self, priority: Priority) -> bool { self.control_stream_recv .http_stream() .unwrap() .maybe_update_priority(priority) } fn priority_update_frame(&mut self) -> Option { self.control_stream_recv .http_stream() .unwrap() .priority_update_frame() } fn priority_update_sent(&mut self) { self.control_stream_recv .http_stream() .unwrap() .priority_update_sent(); } fn send(&mut self, conn: &mut Connection) -> Res<()> { self.control_stream_send.send(conn)?; if self.control_stream_send.done() { self.state = SessionState::Done; } Ok(()) } fn has_data_to_send(&self) -> bool { self.control_stream_send.has_data_to_send() } fn done(&self) -> bool { self.state == SessionState::Done } fn close(&mut self, close_type: CloseType) { if self.state.closing_state() { return; } qtrace!("ExtendedConnect close the session"); self.state = SessionState::Done; if !close_type.locally_initiated() { self.events.session_end( ExtendedConnectType::WebTransport, self.session_id, SessionCloseReason::from(close_type), None, ); } } /// # Panics /// /// This cannot panic because headers are checked before this function called. pub fn maybe_check_headers(&mut self) { if SessionState::Negotiating != self.state { return; } if let Some((headers, interim, fin)) = self.stream_event_listener.borrow_mut().get_headers() { qtrace!( "ExtendedConnect response headers {:?}, fin={}", headers, fin ); if interim { if fin { self.events.session_end( ExtendedConnectType::WebTransport, self.session_id, SessionCloseReason::Clean { error: 0, message: String::new(), }, Some(headers), ); self.state = SessionState::Done; } } else { let status = headers .iter() .find_map(|h| { if h.name() == ":status" { h.value().parse::().ok() } else { None } }) .unwrap(); self.state = if (200..300).contains(&status) { if fin { self.events.session_end( ExtendedConnectType::WebTransport, self.session_id, SessionCloseReason::Clean { error: 0, message: String::new(), }, Some(headers), ); SessionState::Done } else { self.events.session_start( ExtendedConnectType::WebTransport, self.session_id, status, headers, ); SessionState::Active } } else { self.events.session_end( ExtendedConnectType::WebTransport, self.session_id, SessionCloseReason::Status(status), Some(headers), ); SessionState::Done }; } } } pub fn add_stream(&mut self, stream_id: StreamId) { if let SessionState::Active = self.state { if stream_id.is_bidi() { self.send_streams.insert(stream_id); self.recv_streams.insert(stream_id); } else if stream_id.is_self_initiated(self.role) { self.send_streams.insert(stream_id); } else { self.recv_streams.insert(stream_id); } if !stream_id.is_self_initiated(self.role) { self.events .extended_connect_new_stream(Http3StreamInfo::new( stream_id, ExtendedConnectType::WebTransport.get_stream_type(self.session_id), )); } } } pub fn remove_recv_stream(&mut self, stream_id: StreamId) { self.recv_streams.remove(&stream_id); } pub fn remove_send_stream(&mut self, stream_id: StreamId) { self.send_streams.remove(&stream_id); } #[must_use] pub fn is_active(&self) -> bool { matches!(self.state, SessionState::Active) } pub fn take_sub_streams(&mut self) -> (BTreeSet, BTreeSet) { ( mem::take(&mut self.recv_streams), mem::take(&mut self.send_streams), ) } /// # Errors /// /// It may return an error if the frame is not correctly decoded. pub fn read_control_stream(&mut self, conn: &mut Connection) -> Res<()> { let (f, fin) = self .frame_reader .receive::(&mut StreamReaderRecvStreamWrapper::new( conn, &mut self.control_stream_recv, )) .map_err(|_| Error::HttpGeneralProtocolStream)?; qtrace!([self], "Received frame: {:?} fin={}", f, fin); if let Some(WebTransportFrame::CloseSession { error, message }) = f { self.events.session_end( ExtendedConnectType::WebTransport, self.session_id, SessionCloseReason::Clean { error, message }, None, ); self.state = if fin { SessionState::Done } else { SessionState::FinPending }; } else if fin { self.events.session_end( ExtendedConnectType::WebTransport, self.session_id, SessionCloseReason::Clean { error: 0, message: String::new(), }, None, ); self.state = SessionState::Done; } Ok(()) } /// # Errors /// /// Return an error if the stream was closed on the transport layer, but that information is not /// yet consumed on the http/3 layer. pub fn close_session(&mut self, conn: &mut Connection, error: u32, message: &str) -> Res<()> { self.state = SessionState::Done; let close_frame = WebTransportFrame::CloseSession { error, message: message.to_string(), }; let mut encoder = Encoder::default(); close_frame.encode(&mut encoder); self.control_stream_send .send_data_atomic(conn, encoder.as_ref())?; self.control_stream_send.close(conn)?; self.state = if self.control_stream_send.done() { SessionState::Done } else { SessionState::FinPending }; Ok(()) } fn send_data(&mut self, conn: &mut Connection, buf: &[u8]) -> Res { self.control_stream_send.send_data(conn, buf) } /// # Errors /// /// Returns an error if the datagram exceeds the remote datagram size limit. pub fn send_datagram( &self, conn: &mut Connection, buf: &[u8], id: impl Into, ) -> Res<()> { qtrace!([self], "send_datagram state={:?}", self.state); if let SessionState::Active = self.state { let mut dgram_data = Encoder::default(); dgram_data.encode_varint(self.session_id.as_u64() / 4); dgram_data.encode(buf); conn.send_datagram(dgram_data.as_ref(), id)?; } else { debug_assert!(false); return Err(Error::Unavailable); } Ok(()) } pub fn datagram(&mut self, datagram: Vec) { if let SessionState::Active = self.state { self.events.new_datagram(self.session_id, datagram); } } } impl Stream for Rc> { fn stream_type(&self) -> Http3StreamType { Http3StreamType::ExtendedConnect } } impl RecvStream for Rc> { fn receive(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> { self.borrow_mut().receive(conn) } fn reset(&mut self, close_type: CloseType) -> Res<()> { self.borrow_mut().close(close_type); Ok(()) } fn http_stream(&mut self) -> Option<&mut dyn HttpRecvStream> { Some(self) } fn webtransport(&self) -> Option>> { Some(self.clone()) } } impl HttpRecvStream for Rc> { fn header_unblocked(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> { self.borrow_mut().header_unblocked(conn) } fn maybe_update_priority(&mut self, priority: Priority) -> bool { self.borrow_mut().maybe_update_priority(priority) } fn priority_update_frame(&mut self) -> Option { self.borrow_mut().priority_update_frame() } fn priority_update_sent(&mut self) { self.borrow_mut().priority_update_sent(); } } impl SendStream for Rc> { fn send(&mut self, conn: &mut Connection) -> Res<()> { self.borrow_mut().send(conn) } fn send_data(&mut self, conn: &mut Connection, buf: &[u8]) -> Res { self.borrow_mut().send_data(conn, buf) } fn has_data_to_send(&self) -> bool { self.borrow_mut().has_data_to_send() } fn stream_writable(&self) {} fn done(&self) -> bool { self.borrow_mut().done() } fn close(&mut self, conn: &mut Connection) -> Res<()> { self.borrow_mut().close_session(conn, 0, "") } fn close_with_message(&mut self, conn: &mut Connection, error: u32, message: &str) -> Res<()> { self.borrow_mut().close_session(conn, error, message) } fn handle_stop_sending(&mut self, close_type: CloseType) { self.borrow_mut().close(close_type); } } #[derive(Debug, Default)] struct WebTransportSessionListener { headers: Option<(Vec
, bool, bool)>, } impl WebTransportSessionListener { fn set_headers(&mut self, headers: Vec
, interim: bool, fin: bool) { self.headers = Some((headers, interim, fin)); } pub fn get_headers(&mut self) -> Option<(Vec
, bool, bool)> { mem::take(&mut self.headers) } } impl RecvStreamEvents for Rc> {} impl HttpRecvStreamEvents for Rc> { fn header_ready( &self, _stream_info: Http3StreamInfo, headers: Vec
, interim: bool, fin: bool, ) { if !interim || fin { self.borrow_mut().set_headers(headers, interim, fin); } } } impl SendStreamEvents for Rc> {}