diff options
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r-- | test/test_websockets.py | 383 |
1 files changed, 383 insertions, 0 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 0000000..13b3a1e --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 + +# Allow direct execution +import os +import sys + +import pytest + +from test.helper import verify_address_availability + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import http.client +import http.cookiejar +import http.server +import json +import random +import ssl +import threading + +from yt_dlp import socks +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.dependencies import websockets +from yt_dlp.networking import Request +from yt_dlp.networking.exceptions import ( + CertificateVerifyError, + HTTPError, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from yt_dlp.utils.networking import HTTPHeaderDict + +from test.conftest import validate_and_send + +TEST_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def websocket_handler(websocket): + for message in websocket: + if isinstance(message, bytes): + if message == b'bytes': + return websocket.send('2') + elif isinstance(message, str): + if message == 'headers': + return websocket.send(json.dumps(dict(websocket.request.headers))) + elif message == 'path': + return websocket.send(websocket.request.path) + elif message == 'source_address': + return websocket.send(websocket.remote_address[0]) + elif message == 'str': + return websocket.send('1') + return websocket.send(message) + + +def process_request(self, request): + if request.path.startswith('/gen_'): + status = http.HTTPStatus(int(request.path[5:])) + if 300 <= status.value <= 300: + return websockets.http11.Response( + status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'') + return self.protocol.reject(status.value, status.phrase) + return self.protocol.accept(request) + + +def create_websocket_server(**ws_kwargs): + import websockets.sync.server + wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs) + ws_port = wsd.socket.getsockname()[1] + ws_server_thread = threading.Thread(target=wsd.serve_forever) + ws_server_thread.daemon = True + ws_server_thread.start() + return ws_server_thread, ws_port + + +def create_ws_websocket_server(): + return create_websocket_server() + + +def create_wss_websocket_server(): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(certfn, None) + return create_websocket_server(ssl_context=sslctx) + + +MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate') + + +def create_mtls_wss_websocket_server(): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt') + + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.verify_mode = ssl.CERT_REQUIRED + sslctx.load_verify_locations(cafile=cacertfn) + sslctx.load_cert_chain(certfn, None) + + return create_websocket_server(ssl_context=sslctx) + + +@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') +class TestWebsSocketRequestHandlerConformance: + @classmethod + def setup_class(cls): + cls.ws_thread, cls.ws_port = create_ws_websocket_server() + cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}' + + cls.wss_thread, cls.wss_port = create_wss_websocket_server() + cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}' + + cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)) + cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}' + + cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server() + cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}' + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_basic_websockets(self, handler): + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + assert 'upgrade' in ws.headers + assert ws.status == 101 + ws.send('foo') + assert ws.recv() == 'foo' + ws.close() + + # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)]) + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_send_types(self, handler, msg, opcode): + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send(msg) + assert int(ws.recv()) == opcode + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_verify_cert(self, handler): + with handler() as rh: + with pytest.raises(CertificateVerifyError): + validate_and_send(rh, Request(self.wss_base_url)) + + with handler(verify=False) as rh: + ws = validate_and_send(rh, Request(self.wss_base_url)) + assert ws.status == 101 + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_ssl_error(self, handler): + with handler(verify=False) as rh: + with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info: + validate_and_send(rh, Request(self.bad_wss_host)) + assert not issubclass(exc_info.type, CertificateVerifyError) + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('path,expected', [ + # Unicode characters should be encoded with uppercase percent-encoding + ('/中文', '/%E4%B8%AD%E6%96%87'), + # don't normalize existing percent encodings + ('/%c7%9f', '/%c7%9f'), + ]) + def test_percent_encode(self, handler, path, expected): + with handler() as rh: + ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) + ws.send('path') + assert ws.recv() == expected + assert ws.status == 101 + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_remove_dot_segments(self, handler): + with handler() as rh: + # This isn't a comprehensive test, + # but it should be enough to check whether the handler is removing dot segments + ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) + assert ws.status == 101 + ws.send('path') + assert ws.recv() == '/test' + ws.close() + + # We are restricted to known HTTP status codes in http.HTTPStatus + # Redirects are not supported for websockets + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511)) + def test_raise_http_error(self, handler, status): + with handler() as rh: + with pytest.raises(HTTPError) as exc_info: + validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) + assert exc_info.value.status == status + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('params,extensions', [ + ({'timeout': sys.float_info.min}, {}), + ({}, {'timeout': sys.float_info.min}), + ]) + def test_timeout(self, handler, params, extensions): + with handler(**params) as rh: + with pytest.raises(TransportError): + validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_cookies(self, handler): + cookiejar = YoutubeDLCookieJar() + cookiejar.set_cookie(http.cookiejar.Cookie( + version=0, name='test', value='ytdlp', port=None, port_specified=False, + domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/', + path_specified=True, secure=False, expires=None, discard=False, comment=None, + comment_url=None, rest={})) + + with handler(cookiejar=cookiejar) as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + ws.close() + + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + assert 'cookie' not in json.loads(ws.recv()) + ws.close() + + ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) + ws.send('headers') + assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_source_address(self, handler): + source_address = f'127.0.0.{random.randint(5, 255)}' + verify_address_availability(source_address) + with handler(source_address=source_address) as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('source_address') + assert source_address == ws.recv() + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_response_url(self, handler): + with handler() as rh: + url = f'{self.ws_base_url}/something' + ws = validate_and_send(rh, Request(url)) + assert ws.url == url + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_request_headers(self, handler): + with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: + # Global Headers + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + headers = HTTPHeaderDict(json.loads(ws.recv())) + assert headers['test1'] == 'test' + ws.close() + + # Per request headers, merged with global + ws = validate_and_send(rh, Request( + self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'})) + ws.send('headers') + headers = HTTPHeaderDict(json.loads(ws.recv())) + assert headers['test1'] == 'test' + assert headers['test2'] == 'changed' + assert headers['test3'] == 'test3' + ws.close() + + @pytest.mark.parametrize('client_cert', ( + {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'), + 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'), + }, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'), + 'client_certificate_password': 'foobar', + }, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'), + 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'), + 'client_certificate_password': 'foobar', + } + )) + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_mtls(self, handler, client_cert): + with handler( + # Disable client-side validation of unacceptable self-signed testcert.pem + # The test is of a check on the server side, so unaffected + verify=False, + client_cert=client_cert + ) as rh: + validate_and_send(rh, Request(self.mtls_wss_base_url)).close() + + +def create_fake_ws_connection(raised): + import websockets.sync.client + + class FakeWsConnection(websockets.sync.client.ClientConnection): + def __init__(self, *args, **kwargs): + class FakeResponse: + body = b'' + headers = {} + status_code = 101 + reason_phrase = 'test' + + self.response = FakeResponse() + + def send(self, *args, **kwargs): + raise raised() + + def recv(self, *args, **kwargs): + raise raised() + + def close(self, *args, **kwargs): + return + + return FakeWsConnection() + + +@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) +class TestWebsocketsRequestHandler: + @pytest.mark.parametrize('raised,expected', [ + # https://websockets.readthedocs.io/en/stable/reference/exceptions.html + (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError), + # Requires a response object. Should be covered by HTTP error tests. + # (lambda: websockets.exceptions.InvalidStatus(), TransportError), + (lambda: websockets.exceptions.InvalidHandshake(), TransportError), + # These are subclasses of InvalidHandshake + (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError), + (lambda: websockets.exceptions.NegotiationError(), TransportError), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError), + (lambda: TimeoutError(), TransportError), + # These may be raised by our create_connection implementation, which should also be caught + (lambda: OSError(), TransportError), + (lambda: ssl.SSLError(), SSLError), + (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError), + (lambda: socks.ProxyError(), ProxyError), + ]) + def test_request_error_mapping(self, handler, monkeypatch, raised, expected): + import websockets.sync.client + + import yt_dlp.networking._websockets + with handler() as rh: + def fake_connect(*args, **kwargs): + raise raised() + monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None) + monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect) + with pytest.raises(expected) as exc_info: + rh.send(Request('ws://fake-url')) + assert exc_info.type is expected + + @pytest.mark.parametrize('raised,expected,match', [ + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send + (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None), + (lambda: RuntimeError(), TransportError, None), + (lambda: TimeoutError(), TransportError, None), + (lambda: TypeError(), RequestError, None), + (lambda: socks.ProxyError(), ProxyError, None), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError, None), + ]) + def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match): + from yt_dlp.networking._websockets import WebsocketsResponseAdapter + ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url') + with pytest.raises(expected, match=match) as exc_info: + ws.send('test') + assert exc_info.type is expected + + @pytest.mark.parametrize('raised,expected,match', [ + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv + (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None), + (lambda: RuntimeError(), TransportError, None), + (lambda: TimeoutError(), TransportError, None), + (lambda: socks.ProxyError(), ProxyError, None), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError, None), + ]) + def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match): + from yt_dlp.networking._websockets import WebsocketsResponseAdapter + ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url') + with pytest.raises(expected, match=match) as exc_info: + ws.recv() + assert exc_info.type is expected |