From 36d22d82aa202bb199967e9512281e9a53db42c9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 21:33:14 +0200 Subject: Adding upstream version 115.7.0esr. Signed-off-by: Daniel Baumann --- .../tests/tools/webtransport/h3/__init__.py | 0 .../tests/tools/webtransport/h3/capsule.py | 116 ++++ .../tests/tools/webtransport/h3/handler.py | 76 +++ .../tests/tools/webtransport/h3/test_capsule.py | 130 +++++ .../webtransport/h3/webtransport_h3_server.py | 607 +++++++++++++++++++++ 5 files changed, 929 insertions(+) create mode 100644 testing/web-platform/tests/tools/webtransport/h3/__init__.py create mode 100644 testing/web-platform/tests/tools/webtransport/h3/capsule.py create mode 100644 testing/web-platform/tests/tools/webtransport/h3/handler.py create mode 100644 testing/web-platform/tests/tools/webtransport/h3/test_capsule.py create mode 100644 testing/web-platform/tests/tools/webtransport/h3/webtransport_h3_server.py (limited to 'testing/web-platform/tests/tools/webtransport/h3') diff --git a/testing/web-platform/tests/tools/webtransport/h3/__init__.py b/testing/web-platform/tests/tools/webtransport/h3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testing/web-platform/tests/tools/webtransport/h3/capsule.py b/testing/web-platform/tests/tools/webtransport/h3/capsule.py new file mode 100644 index 0000000000..74ca71ade9 --- /dev/null +++ b/testing/web-platform/tests/tools/webtransport/h3/capsule.py @@ -0,0 +1,116 @@ +# mypy: no-warn-return-any + +from enum import IntEnum +from typing import Iterator, Optional + +# TODO(bashi): Remove import check suppressions once aioquic dependency is +# resolved. +from aioquic.buffer import UINT_VAR_MAX_SIZE, Buffer, BufferReadError # type: ignore + + +class CapsuleType(IntEnum): + # Defined in + # https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-04#section-8.2 + DATAGRAM_DRAFT04 = 0xff37a0 + REGISTER_DATAGRAM_CONTEXT_DRAFT04 = 0xff37a1 + REGISTER_DATAGRAM_NO_CONTEXT_DRAFT04 = 0xff37a2 + CLOSE_DATAGRAM_CONTEXT_DRAFT04 = 0xff37a3 + # Defined in + # https://datatracker.ietf.org/doc/html/rfc9297#section-5.4 + DATAGRAM_RFC = 0x00 + # Defined in + # https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-01.html. + CLOSE_WEBTRANSPORT_SESSION = 0x2843 + + +class H3Capsule: + """ + Represents the Capsule concept defined in + https://ietf-wg-masque.github.io/draft-ietf-masque-h3-datagram/draft-ietf-masque-h3-datagram.html#name-capsules. + """ + + def __init__(self, type: int, data: bytes) -> None: + """ + :param type the type of this Capsule. We don't use CapsuleType here + because this may be a capsule of an unknown type. + :param data the payload + """ + self.type = type + self.data = data + + def encode(self) -> bytes: + """ + Encodes this H3Capsule and return the bytes. + """ + buffer = Buffer(capacity=len(self.data) + 2 * UINT_VAR_MAX_SIZE) + buffer.push_uint_var(self.type) + buffer.push_uint_var(len(self.data)) + buffer.push_bytes(self.data) + return buffer.data + + +class H3CapsuleDecoder: + """ + A decoder of H3Capsule. This is a streaming decoder and can handle multiple + decoders. + """ + + def __init__(self) -> None: + self._buffer: Optional[Buffer] = None + self._type: Optional[int] = None + self._length: Optional[int] = None + self._final: bool = False + + def append(self, data: bytes) -> None: + """ + Appends the given bytes to this decoder. + """ + assert not self._final + + if len(data) == 0: + return + if self._buffer: + remaining = self._buffer.pull_bytes( + self._buffer.capacity - self._buffer.tell()) + self._buffer = Buffer(data=(remaining + data)) + else: + self._buffer = Buffer(data=data) + + def final(self) -> None: + """ + Pushes the end-of-stream mark to this decoder. After calling this, + calling append() will be invalid. + """ + self._final = True + + def __iter__(self) -> Iterator[H3Capsule]: + """ + Yields decoded capsules. + """ + try: + while self._buffer is not None: + if self._type is None: + self._type = self._buffer.pull_uint_var() + if self._length is None: + self._length = self._buffer.pull_uint_var() + if self._buffer.capacity - self._buffer.tell() < self._length: + if self._final: + raise ValueError('insufficient buffer') + return + capsule = H3Capsule( + self._type, self._buffer.pull_bytes(self._length)) + self._type = None + self._length = None + if self._buffer.tell() == self._buffer.capacity: + self._buffer = None + yield capsule + except BufferReadError as e: + if self._final: + raise e + if not self._buffer: + return 0 + size = self._buffer.capacity - self._buffer.tell() + if size >= UINT_VAR_MAX_SIZE: + raise e + # Ignore the error because there may not be sufficient input. + return diff --git a/testing/web-platform/tests/tools/webtransport/h3/handler.py b/testing/web-platform/tests/tools/webtransport/h3/handler.py new file mode 100644 index 0000000000..9b6cb1ab20 --- /dev/null +++ b/testing/web-platform/tests/tools/webtransport/h3/handler.py @@ -0,0 +1,76 @@ +from typing import List, Optional, Tuple + +from .webtransport_h3_server import WebTransportSession + +# This file exists for documentation purpose. + + +def connect_received(request_headers: List[Tuple[bytes, bytes]], + response_headers: List[Tuple[bytes, bytes]]) -> None: + """ + Called whenever an extended CONNECT method is received. + + :param request_headers: The request headers received from the peer. + :param response_headers: The response headers which will be sent to the peer. ``:status`` is set + to 200 when it isn't specified. + """ + pass + + +def session_established(session: WebTransportSession) -> None: + """ + Called whenever an WebTransport session is established. + + :param session: A WebTransport session object. + """ + + +def stream_data_received(session: WebTransportSession, stream_id: int, + data: bytes, stream_ended: bool) -> None: + """ + Called whenever data is received on a WebTransport stream. + + :param session: A WebTransport session object. + :param stream_id: The ID of the stream. + :param data: The received data. + :param stream_ended: Whether the stream is ended. + """ + pass + + +def datagram_received(session: WebTransportSession, data: bytes) -> None: + """ + Called whenever a datagram is received on a WebTransport session. + + :param session: A WebTransport session object. + :param data: The received data. + """ + pass + + +def session_closed(session: WebTransportSession, + close_info: Optional[Tuple[int, bytes]], + abruptly: bool) -> None: + """ + Called when a WebTransport session is closed. + + :param session: A WebTransport session. + :param close_info: The code and reason attached to the + CLOSE_WEBTRANSPORT_SESSION capsule. + :param abruptly: True when the session is closed forcibly + (by a CLOSE_CONNECTION QUIC frame for example). + """ + pass + + +def stream_reset(session: WebTransportSession, + stream_id: int, + error_code: int) -> None: + """ + Called when a stream is reset with RESET_STREAM. + + :param session: A WebTransport session. + :param stream_id: The ID of the stream. + :param error_code: The reason of the reset. + """ + pass diff --git a/testing/web-platform/tests/tools/webtransport/h3/test_capsule.py b/testing/web-platform/tests/tools/webtransport/h3/test_capsule.py new file mode 100644 index 0000000000..4321775e93 --- /dev/null +++ b/testing/web-platform/tests/tools/webtransport/h3/test_capsule.py @@ -0,0 +1,130 @@ +# type: ignore + +import unittest +import importlib.util +import pytest + +if importlib.util.find_spec('aioquic'): + has_aioquic = True + from .capsule import H3Capsule, H3CapsuleDecoder + from aioquic.buffer import BufferReadError +else: + has_aioquic = False + + +class H3CapsuleTest(unittest.TestCase): + @pytest.mark.skipif(not has_aioquic, reason='not having aioquic') + def test_capsule(self) -> None: + capsule1 = H3Capsule(0x12345, b'abcde') + bs = capsule1.encode() + decoder = H3CapsuleDecoder() + decoder.append(bs) + capsule2 = next(iter(decoder)) + + self.assertEqual(bs, b'\x80\x01\x23\x45\x05abcde', 'bytes') + self.assertEqual(capsule1.type, capsule2.type, 'type') + self.assertEqual(capsule1.data, capsule2.data, 'data') + + @pytest.mark.skipif( + not has_aioquic, reason='not having aioquic') + def test_small_capsule(self) -> None: + capsule1 = H3Capsule(0, b'') + bs = capsule1.encode() + decoder = H3CapsuleDecoder() + decoder.append(bs) + capsule2 = next(iter(decoder)) + + self.assertEqual(bs, b'\x00\x00', 'bytes') + self.assertEqual(capsule1.type, capsule2.type, 'type') + self.assertEqual(capsule1.data, capsule2.data, 'data') + + @pytest.mark.skipif(not has_aioquic, reason='not having aioquic') + def test_capsule_append(self) -> None: + decoder = H3CapsuleDecoder() + decoder.append(b'\x80') + + with self.assertRaises(StopIteration): + next(iter(decoder)) + + decoder.append(b'\x01\x23') + with self.assertRaises(StopIteration): + next(iter(decoder)) + + decoder.append(b'\x45\x05abcd') + with self.assertRaises(StopIteration): + next(iter(decoder)) + + decoder.append(b'e\x00') + capsule1 = next(iter(decoder)) + + self.assertEqual(capsule1.type, 0x12345, 'type') + self.assertEqual(capsule1.data, b'abcde', 'data') + + with self.assertRaises(StopIteration): + next(iter(decoder)) + + decoder.append(b'\x00') + capsule2 = next(iter(decoder)) + self.assertEqual(capsule2.type, 0, 'type') + self.assertEqual(capsule2.data, b'', 'data') + + @pytest.mark.skipif(not has_aioquic, reason='not having aioquic') + def test_multiple_values(self) -> None: + decoder = H3CapsuleDecoder() + decoder.append(b'\x01\x02ab\x03\x04cdef') + + it = iter(decoder) + capsule1 = next(it) + capsule2 = next(it) + with self.assertRaises(StopIteration): + next(it) + with self.assertRaises(StopIteration): + next(iter(decoder)) + + self.assertEqual(capsule1.type, 1, 'type') + self.assertEqual(capsule1.data, b'ab', 'data') + self.assertEqual(capsule2.type, 3, 'type') + self.assertEqual(capsule2.data, b'cdef', 'data') + + @pytest.mark.skipif(not has_aioquic, reason='not having aioquic') + def test_final(self) -> None: + decoder = H3CapsuleDecoder() + decoder.append(b'\x01') + + with self.assertRaises(StopIteration): + next(iter(decoder)) + + decoder.append(b'\x01a') + decoder.final() + capsule1 = next(iter(decoder)) + with self.assertRaises(StopIteration): + next(iter(decoder)) + + self.assertEqual(capsule1.type, 1, 'type') + self.assertEqual(capsule1.data, b'a', 'data') + + @pytest.mark.skipif(not has_aioquic, reason='not having aioquic') + def test_empty_bytes_before_fin(self) -> None: + decoder = H3CapsuleDecoder() + decoder.append(b'') + decoder.final() + + it = iter(decoder) + with self.assertRaises(StopIteration): + next(it) + + @pytest.mark.skipif(not has_aioquic, reason='not having aioquic') + def test_final_invalid(self) -> None: + decoder = H3CapsuleDecoder() + decoder.append(b'\x01') + + with self.assertRaises(StopIteration): + next(iter(decoder)) + + decoder.final() + with self.assertRaises(BufferReadError): + next(iter(decoder)) + + +if __name__ == '__main__': + unittest.main() diff --git a/testing/web-platform/tests/tools/webtransport/h3/webtransport_h3_server.py b/testing/web-platform/tests/tools/webtransport/h3/webtransport_h3_server.py new file mode 100644 index 0000000000..0fc367e1ff --- /dev/null +++ b/testing/web-platform/tests/tools/webtransport/h3/webtransport_h3_server.py @@ -0,0 +1,607 @@ +# mypy: allow-subclassing-any, no-warn-return-any + +import asyncio +import logging +import os +import ssl +import sys +import threading +import traceback +from enum import IntEnum +from urllib.parse import urlparse +from typing import Any, Dict, List, Optional, Tuple + +# TODO(bashi): Remove import check suppressions once aioquic dependency is resolved. +from aioquic.buffer import Buffer # type: ignore +from aioquic.asyncio import QuicConnectionProtocol, serve # type: ignore +from aioquic.asyncio.client import connect # type: ignore +from aioquic.h3.connection import H3_ALPN, FrameType, H3Connection, ProtocolError, SettingsError # type: ignore +from aioquic.h3.events import H3Event, HeadersReceived, WebTransportStreamDataReceived, DatagramReceived, DataReceived # type: ignore +from aioquic.quic.configuration import QuicConfiguration # type: ignore +from aioquic.quic.connection import logger as quic_connection_logger # type: ignore +from aioquic.quic.connection import stream_is_unidirectional +from aioquic.quic.events import QuicEvent, ProtocolNegotiated, ConnectionTerminated, StreamReset # type: ignore +from aioquic.tls import SessionTicket # type: ignore + +from tools import localpaths # noqa: F401 +from wptserve import stash +from .capsule import H3Capsule, H3CapsuleDecoder, CapsuleType + +""" +A WebTransport over HTTP/3 server for testing. + +The server interprets the underlying protocols (WebTransport, HTTP/3 and QUIC) +and passes events to a particular webtransport handler. From the standpoint of +test authors, a webtransport handler is a Python script which contains some +callback functions. See handler.py for available callbacks. +""" + +SERVER_NAME = 'webtransport-h3-server' + +_logger: logging.Logger = logging.getLogger(__name__) +_doc_root: str = "" + +# Set aioquic's log level to WARNING to suppress some INFO logs which are +# recorded every connection close. +quic_connection_logger.setLevel(logging.WARNING) + + +class H3DatagramSetting(IntEnum): + # https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-04#section-8.1 + DRAFT04 = 0xffd277 + # https://datatracker.ietf.org/doc/html/rfc9220#section-5-2.2.1 + RFC = 0x33 + + +class H3ConnectionWithDatagram(H3Connection): + """ + A H3Connection subclass, to make it work with the latest + HTTP Datagram protocol. + """ + # https://datatracker.ietf.org/doc/html/rfc9220#name-iana-considerations + ENABLE_CONNECT_PROTOCOL = 0x08 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._datagram_setting: Optional[H3DatagramSetting] = None + + def _validate_settings(self, settings: Dict[int, int]) -> None: + # aioquic doesn't recognize the RFC version of HTTP Datagrams yet. + # Intentionally don't call `super()._validate_settings(settings)` since + # it raises a SettingsError when only the RFC version is negotiated. + if settings.get(H3DatagramSetting.RFC) == 1: + self._datagram_setting = H3DatagramSetting.RFC + elif settings.get(H3DatagramSetting.DRAFT04) == 1: + self._datagram_setting = H3DatagramSetting.DRAFT04 + + if self._datagram_setting is None: + raise SettingsError("HTTP Datagrams support required") + + + def _get_local_settings(self) -> Dict[int, int]: + settings = super()._get_local_settings() + settings[H3DatagramSetting.RFC] = 1 + settings[H3DatagramSetting.DRAFT04] = 1 + settings[H3ConnectionWithDatagram.ENABLE_CONNECT_PROTOCOL] = 1 + return settings + + @property + def datagram_setting(self) -> Optional[H3DatagramSetting]: + return self._datagram_setting + + +class WebTransportH3Protocol(QuicConnectionProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._handler: Optional[Any] = None + self._http: Optional[H3ConnectionWithDatagram] = None + self._session_stream_id: Optional[int] = None + self._close_info: Optional[Tuple[int, bytes]] = None + self._capsule_decoder_for_session_stream: H3CapsuleDecoder =\ + H3CapsuleDecoder() + self._allow_calling_session_closed = True + self._allow_datagrams = False + + def quic_event_received(self, event: QuicEvent) -> None: + if isinstance(event, ProtocolNegotiated): + self._http = H3ConnectionWithDatagram( + self._quic, enable_webtransport=True) + if self._http.datagram_setting != H3DatagramSetting.DRAFT04: + self._allow_datagrams = True + + if self._http is not None: + for http_event in self._http.handle_event(event): + self._h3_event_received(http_event) + + if isinstance(event, ConnectionTerminated): + self._call_session_closed(close_info=None, abruptly=True) + if isinstance(event, StreamReset): + if self._handler: + self._handler.stream_reset(event.stream_id, event.error_code) + + def _h3_event_received(self, event: H3Event) -> None: + if isinstance(event, HeadersReceived): + # Convert from List[Tuple[bytes, bytes]] to Dict[bytes, bytes]. + # Only the last header will be kept when there are duplicate + # headers. + headers = {} + for header, value in event.headers: + headers[header] = value + + method = headers.get(b":method") + protocol = headers.get(b":protocol") + origin = headers.get(b"origin") + # Accept any Origin but the client must send it. + if method == b"CONNECT" and protocol == b"webtransport" and origin: + self._session_stream_id = event.stream_id + self._handshake_webtransport(event, headers) + else: + status_code = 404 if origin else 403 + self._send_error_response(event.stream_id, status_code) + + if isinstance(event, DataReceived) and\ + self._session_stream_id == event.stream_id: + if self._http and not self._http.datagram_setting and\ + len(event.data) > 0: + raise ProtocolError('Unexpected data on the session stream') + self._receive_data_on_session_stream( + event.data, event.stream_ended) + elif self._handler is not None: + if isinstance(event, WebTransportStreamDataReceived): + self._handler.stream_data_received( + stream_id=event.stream_id, + data=event.data, + stream_ended=event.stream_ended) + elif isinstance(event, DatagramReceived): + if self._allow_datagrams: + self._handler.datagram_received(data=event.data) + + def _receive_data_on_session_stream(self, data: bytes, fin: bool) -> None: + assert self._http is not None + if len(data) > 0: + self._capsule_decoder_for_session_stream.append(data) + if fin: + self._capsule_decoder_for_session_stream.final() + for capsule in self._capsule_decoder_for_session_stream: + if self._close_info is not None: + raise ProtocolError(( + "Receiving a capsule with type = {} after receiving " + + "CLOSE_WEBTRANSPORT_SESSION").format(capsule.type)) + assert self._http.datagram_setting is not None + if self._http.datagram_setting == H3DatagramSetting.RFC: + self._receive_h3_datagram_rfc_capsule_data( + capsule=capsule, fin=fin) + elif self._http.datagram_setting == H3DatagramSetting.DRAFT04: + self._receive_h3_datagram_draft04_capsule_data( + capsule=capsule, fin=fin) + + def _receive_h3_datagram_rfc_capsule_data(self, capsule: H3Capsule, fin: bool) -> None: + if capsule.type == CapsuleType.DATAGRAM_RFC: + raise ProtocolError( + f"Unimplemented capsule type: {capsule.type}") + elif capsule.type == CapsuleType.CLOSE_WEBTRANSPORT_SESSION: + self._set_close_info_and_may_close_session( + data=capsule.data, fin=fin) + else: + # Ignore unknown capsules. + return + + def _receive_h3_datagram_draft04_capsule_data( + self, capsule: H3Capsule, fin: bool) -> None: + if capsule.type in {CapsuleType.DATAGRAM_DRAFT04, + CapsuleType.REGISTER_DATAGRAM_CONTEXT_DRAFT04, + CapsuleType.CLOSE_DATAGRAM_CONTEXT_DRAFT04}: + raise ProtocolError( + f"Unimplemented capsule type: {capsule.type}") + if capsule.type in {CapsuleType.REGISTER_DATAGRAM_NO_CONTEXT_DRAFT04, + CapsuleType.CLOSE_WEBTRANSPORT_SESSION}: + # We'll handle this case below. + pass + else: + # We should ignore unknown capsules. + return + + if capsule.type == CapsuleType.REGISTER_DATAGRAM_NO_CONTEXT_DRAFT04: + buffer = Buffer(data=capsule.data) + format_type = buffer.pull_uint_var() + # https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#name-datagram-format-type + WEBTRANPORT_FORMAT_TYPE = 0xff7c00 + if format_type != WEBTRANPORT_FORMAT_TYPE: + raise ProtocolError( + "Unexpected datagram format type: {}".format( + format_type)) + self._allow_datagrams = True + elif capsule.type == CapsuleType.CLOSE_WEBTRANSPORT_SESSION: + self._set_close_info_and_may_close_session( + data=capsule.data, fin=fin) + + def _set_close_info_and_may_close_session( + self, data: bytes, fin: bool) -> None: + buffer = Buffer(data=data) + code = buffer.pull_uint32() + # 4 bytes for the uint32. + reason = buffer.pull_bytes(len(data) - 4) + # TODO(bashi): Make sure `reason` is a UTF-8 text. + self._close_info = (code, reason) + if fin: + self._call_session_closed(self._close_info, abruptly=False) + + def _send_error_response(self, stream_id: int, status_code: int) -> None: + assert self._http is not None + headers = [(b":status", str(status_code).encode()), + (b"server", SERVER_NAME.encode())] + self._http.send_headers(stream_id=stream_id, + headers=headers, + end_stream=True) + + def _handshake_webtransport(self, event: HeadersReceived, + request_headers: Dict[bytes, bytes]) -> None: + assert self._http is not None + path = request_headers.get(b":path") + if path is None: + # `:path` must be provided. + self._send_error_response(event.stream_id, 400) + return + + # Create a handler using `:path`. + try: + self._handler = self._create_event_handler( + session_id=event.stream_id, + path=path, + request_headers=event.headers) + except OSError: + self._send_error_response(event.stream_id, 404) + return + + response_headers = [ + (b"server", SERVER_NAME.encode()), + (b"sec-webtransport-http3-draft", b"draft02"), + ] + self._handler.connect_received(response_headers=response_headers) + + status_code = None + for name, value in response_headers: + if name == b":status": + status_code = value + response_headers.remove((b":status", status_code)) + response_headers.insert(0, (b":status", status_code)) + break + if not status_code: + response_headers.insert(0, (b":status", b"200")) + self._http.send_headers(stream_id=event.stream_id, + headers=response_headers) + + if status_code is None or status_code == b"200": + self._handler.session_established() + + def _create_event_handler(self, session_id: int, path: bytes, + request_headers: List[Tuple[bytes, bytes]]) -> Any: + parsed = urlparse(path.decode()) + file_path = os.path.join(_doc_root, parsed.path.lstrip("/")) + callbacks = {"__file__": file_path} + with open(file_path) as f: + exec(compile(f.read(), path, "exec"), callbacks) + session = WebTransportSession(self, session_id, request_headers) + return WebTransportEventHandler(session, callbacks) + + def _call_session_closed( + self, close_info: Optional[Tuple[int, bytes]], + abruptly: bool) -> None: + allow_calling_session_closed = self._allow_calling_session_closed + self._allow_calling_session_closed = False + if self._handler and allow_calling_session_closed: + self._handler.session_closed(close_info, abruptly) + + +class WebTransportSession: + """ + A WebTransport session. + """ + + def __init__(self, protocol: WebTransportH3Protocol, session_id: int, + request_headers: List[Tuple[bytes, bytes]]) -> None: + self.session_id = session_id + self.request_headers = request_headers + + self._protocol: WebTransportH3Protocol = protocol + self._http: H3Connection = protocol._http + + # Use the a shared default path for all handlers so that different + # WebTransport sessions can access the same store easily. + self._stash_path = '/webtransport/handlers' + self._stash: Optional[stash.Stash] = None + self._dict_for_handlers: Dict[str, Any] = {} + + @property + def stash(self) -> stash.Stash: + """A Stash object for storing cross-session state.""" + if self._stash is None: + address, authkey = stash.load_env_config() # type: ignore + self._stash = stash.Stash(self._stash_path, address, authkey) # type: ignore + return self._stash + + @property + def dict_for_handlers(self) -> Dict[str, Any]: + """A dictionary that handlers can attach arbitrary data.""" + return self._dict_for_handlers + + def stream_is_unidirectional(self, stream_id: int) -> bool: + """Return True if the stream is unidirectional.""" + return stream_is_unidirectional(stream_id) + + def close(self, close_info: Optional[Tuple[int, bytes]]) -> None: + """ + Close the session. + + :param close_info The close information to send. + """ + self._protocol._allow_calling_session_closed = False + assert self._protocol._session_stream_id is not None + session_stream_id = self._protocol._session_stream_id + if close_info is not None: + code = close_info[0] + reason = close_info[1] + buffer = Buffer(capacity=len(reason) + 4) + buffer.push_uint32(code) + buffer.push_bytes(reason) + capsule =\ + H3Capsule(CapsuleType.CLOSE_WEBTRANSPORT_SESSION, buffer.data) + self._http.send_data( + session_stream_id, capsule.encode(), end_stream=False) + + self._http.send_data(session_stream_id, b'', end_stream=True) + # TODO(yutakahirano): Reset all other streams. + # TODO(yutakahirano): Reject future stream open requests + # We need to wait for the stream data to arrive at the client, and then + # we need to close the connection. At this moment we're relying on the + # client's behavior. + # TODO(yutakahirano): Implement the above. + + def create_unidirectional_stream(self) -> int: + """ + Create a unidirectional WebTransport stream and return the stream ID. + """ + return self._http.create_webtransport_stream( + session_id=self.session_id, is_unidirectional=True) + + def create_bidirectional_stream(self) -> int: + """ + Create a bidirectional WebTransport stream and return the stream ID. + """ + stream_id = self._http.create_webtransport_stream( + session_id=self.session_id, is_unidirectional=False) + # TODO(bashi): Remove this workaround when aioquic supports receiving + # data on server-initiated bidirectional streams. + stream = self._http._get_or_create_stream(stream_id) + assert stream.frame_type is None + assert stream.session_id is None + stream.frame_type = FrameType.WEBTRANSPORT_STREAM + stream.session_id = self.session_id + return stream_id + + def send_stream_data(self, + stream_id: int, + data: bytes, + end_stream: bool = False) -> None: + """ + Send data on the specific stream. + + :param stream_id: The stream ID on which to send the data. + :param data: The data to send. + :param end_stream: If set to True, the stream will be closed. + """ + self._http._quic.send_stream_data(stream_id=stream_id, + data=data, + end_stream=end_stream) + + def send_datagram(self, data: bytes) -> None: + """ + Send data using a datagram frame. + + :param data: The data to send. + """ + if not self._protocol._allow_datagrams: + _logger.warn( + "Sending a datagram while that's now allowed - discarding it") + return + flow_id = self.session_id + if self._http.datagram_setting is not None: + # We must have a WebTransport Session ID at this point because + # an extended CONNECT request is already received. + assert self._protocol._session_stream_id is not None + # TODO(yutakahirano): Make sure if this is the correct logic. + # Chrome always use 0 for the initial stream and the initial flow + # ID, we cannot check the correctness with it. + flow_id = self._protocol._session_stream_id // 4 + self._http.send_datagram(flow_id=flow_id, data=data) + + def stop_stream(self, stream_id: int, code: int) -> None: + """ + Send a STOP_SENDING frame to the given stream. + :param code: the reason of the error. + """ + self._http._quic.stop_stream(stream_id, code) + + def reset_stream(self, stream_id: int, code: int) -> None: + """ + Send a RESET_STREAM frame to the given stream. + :param code: the reason of the error. + """ + self._http._quic.reset_stream(stream_id, code) + + +class WebTransportEventHandler: + def __init__(self, session: WebTransportSession, + callbacks: Dict[str, Any]) -> None: + self._session = session + self._callbacks = callbacks + + def _run_callback(self, callback_name: str, + *args: Any, **kwargs: Any) -> None: + if callback_name not in self._callbacks: + return + try: + self._callbacks[callback_name](*args, **kwargs) + except Exception as e: + _logger.warn(str(e)) + traceback.print_exc() + + def connect_received(self, response_headers: List[Tuple[bytes, + bytes]]) -> None: + self._run_callback("connect_received", self._session.request_headers, + response_headers) + + def session_established(self) -> None: + self._run_callback("session_established", self._session) + + def stream_data_received(self, stream_id: int, data: bytes, + stream_ended: bool) -> None: + self._run_callback("stream_data_received", self._session, stream_id, + data, stream_ended) + + def datagram_received(self, data: bytes) -> None: + self._run_callback("datagram_received", self._session, data) + + def session_closed( + self, + close_info: Optional[Tuple[int, bytes]], + abruptly: bool) -> None: + self._run_callback( + "session_closed", self._session, close_info, abruptly=abruptly) + + def stream_reset(self, stream_id: int, error_code: int) -> None: + self._run_callback( + "stream_reset", self._session, stream_id, error_code) + + +class SessionTicketStore: + """ + Simple in-memory store for session tickets. + """ + + def __init__(self) -> None: + self.tickets: Dict[bytes, SessionTicket] = {} + + def add(self, ticket: SessionTicket) -> None: + self.tickets[ticket.ticket] = ticket + + def pop(self, label: bytes) -> Optional[SessionTicket]: + return self.tickets.pop(label, None) + + +class WebTransportH3Server: + """ + A WebTransport over HTTP/3 for testing. + + :param host: Host from which to serve. + :param port: Port from which to serve. + :param doc_root: Document root for serving handlers. + :param cert_path: Path to certificate file to use. + :param key_path: Path to key file to use. + :param logger: a Logger object for this server. + """ + + def __init__(self, host: str, port: int, doc_root: str, cert_path: str, + key_path: str, logger: Optional[logging.Logger]) -> None: + self.host = host + self.port = port + self.doc_root = doc_root + self.cert_path = cert_path + self.key_path = key_path + self.started = False + global _doc_root + _doc_root = self.doc_root + global _logger + if logger is not None: + _logger = logger + + def start(self) -> None: + """Start the server.""" + self.server_thread = threading.Thread( + target=self._start_on_server_thread, daemon=True) + self.server_thread.start() + self.started = True + + def _start_on_server_thread(self) -> None: + secrets_log_file = None + if "SSLKEYLOGFILE" in os.environ: + try: + secrets_log_file = open(os.environ["SSLKEYLOGFILE"], "a") + except Exception as e: + _logger.warn(str(e)) + + configuration = QuicConfiguration( + alpn_protocols=H3_ALPN, + is_client=False, + max_datagram_frame_size=65536, + secrets_log_file=secrets_log_file, + ) + + _logger.info("Starting WebTransport over HTTP/3 server on %s:%s", + self.host, self.port) + + configuration.load_cert_chain(self.cert_path, self.key_path) + + ticket_store = SessionTicketStore() + + # On Windows, the default event loop is ProactorEventLoop but it + # doesn't seem to work when aioquic detects a connection loss. + # Use SelectorEventLoop to work around the problem. + if sys.platform == "win32": + asyncio.set_event_loop_policy( + asyncio.WindowsSelectorEventLoopPolicy()) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.loop.run_until_complete( + serve( + self.host, + self.port, + configuration=configuration, + create_protocol=WebTransportH3Protocol, + session_ticket_fetcher=ticket_store.pop, + session_ticket_handler=ticket_store.add, + )) + self.loop.run_forever() + + def stop(self) -> None: + """Stop the server.""" + if self.started: + asyncio.run_coroutine_threadsafe(self._stop_on_server_thread(), + self.loop) + self.server_thread.join() + _logger.info("Stopped WebTransport over HTTP/3 server on %s:%s", + self.host, self.port) + self.started = False + + async def _stop_on_server_thread(self) -> None: + self.loop.stop() + + +def server_is_running(host: str, port: int, timeout: float) -> bool: + """ + Check the WebTransport over HTTP/3 server is running at the given `host` and + `port`. + """ + loop = asyncio.get_event_loop() + return loop.run_until_complete(_connect_server_with_timeout(host, port, timeout)) + + +async def _connect_server_with_timeout(host: str, port: int, timeout: float) -> bool: + try: + await asyncio.wait_for(_connect_to_server(host, port), timeout=timeout) + except asyncio.TimeoutError: + _logger.warning("Failed to connect WebTransport over HTTP/3 server") + return False + return True + + +async def _connect_to_server(host: str, port: int) -> None: + configuration = QuicConfiguration( + alpn_protocols=H3_ALPN, + is_client=True, + verify_mode=ssl.CERT_NONE, + ) + + async with connect(host, port, configuration=configuration) as protocol: + await protocol.ping() -- cgit v1.2.3