summaryrefslogtreecommitdiffstats
path: root/test/test_websockets.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r--test/test_websockets.py383
1 files changed, 383 insertions, 0 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py
new file mode 100644
index 0000000..13b3a1e
--- /dev/null
+++ b/test/test_websockets.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+
+# Allow direct execution
+import os
+import sys
+
+import pytest
+
+from test.helper import verify_address_availability
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import http.client
+import http.cookiejar
+import http.server
+import json
+import random
+import ssl
+import threading
+
+from yt_dlp import socks
+from yt_dlp.cookies import YoutubeDLCookieJar
+from yt_dlp.dependencies import websockets
+from yt_dlp.networking import Request
+from yt_dlp.networking.exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ ProxyError,
+ RequestError,
+ SSLError,
+ TransportError,
+)
+from yt_dlp.utils.networking import HTTPHeaderDict
+
+from test.conftest import validate_and_send
+
+TEST_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def websocket_handler(websocket):
+ for message in websocket:
+ if isinstance(message, bytes):
+ if message == b'bytes':
+ return websocket.send('2')
+ elif isinstance(message, str):
+ if message == 'headers':
+ return websocket.send(json.dumps(dict(websocket.request.headers)))
+ elif message == 'path':
+ return websocket.send(websocket.request.path)
+ elif message == 'source_address':
+ return websocket.send(websocket.remote_address[0])
+ elif message == 'str':
+ return websocket.send('1')
+ return websocket.send(message)
+
+
+def process_request(self, request):
+ if request.path.startswith('/gen_'):
+ status = http.HTTPStatus(int(request.path[5:]))
+ if 300 <= status.value <= 300:
+ return websockets.http11.Response(
+ status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
+ return self.protocol.reject(status.value, status.phrase)
+ return self.protocol.accept(request)
+
+
+def create_websocket_server(**ws_kwargs):
+ import websockets.sync.server
+ wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
+ ws_port = wsd.socket.getsockname()[1]
+ ws_server_thread = threading.Thread(target=wsd.serve_forever)
+ ws_server_thread.daemon = True
+ ws_server_thread.start()
+ return ws_server_thread, ws_port
+
+
+def create_ws_websocket_server():
+ return create_websocket_server()
+
+
+def create_wss_websocket_server():
+ certfn = os.path.join(TEST_DIR, 'testcert.pem')
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslctx.load_cert_chain(certfn, None)
+ return create_websocket_server(ssl_context=sslctx)
+
+
+MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
+
+
+def create_mtls_wss_websocket_server():
+ certfn = os.path.join(TEST_DIR, 'testcert.pem')
+ cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
+
+ sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslctx.verify_mode = ssl.CERT_REQUIRED
+ sslctx.load_verify_locations(cafile=cacertfn)
+ sslctx.load_cert_chain(certfn, None)
+
+ return create_websocket_server(ssl_context=sslctx)
+
+
+@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
+class TestWebsSocketRequestHandlerConformance:
+ @classmethod
+ def setup_class(cls):
+ cls.ws_thread, cls.ws_port = create_ws_websocket_server()
+ cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
+
+ cls.wss_thread, cls.wss_port = create_wss_websocket_server()
+ cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
+
+ cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
+ cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
+
+ cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
+ cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_basic_websockets(self, handler):
+ with handler() as rh:
+ ws = validate_and_send(rh, Request(self.ws_base_url))
+ assert 'upgrade' in ws.headers
+ assert ws.status == 101
+ ws.send('foo')
+ assert ws.recv() == 'foo'
+ ws.close()
+
+ # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
+ @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_send_types(self, handler, msg, opcode):
+ with handler() as rh:
+ ws = validate_and_send(rh, Request(self.ws_base_url))
+ ws.send(msg)
+ assert int(ws.recv()) == opcode
+ ws.close()
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_verify_cert(self, handler):
+ with handler() as rh:
+ with pytest.raises(CertificateVerifyError):
+ validate_and_send(rh, Request(self.wss_base_url))
+
+ with handler(verify=False) as rh:
+ ws = validate_and_send(rh, Request(self.wss_base_url))
+ assert ws.status == 101
+ ws.close()
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_ssl_error(self, handler):
+ with handler(verify=False) as rh:
+ with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
+ validate_and_send(rh, Request(self.bad_wss_host))
+ assert not issubclass(exc_info.type, CertificateVerifyError)
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ @pytest.mark.parametrize('path,expected', [
+ # Unicode characters should be encoded with uppercase percent-encoding
+ ('/中文', '/%E4%B8%AD%E6%96%87'),
+ # don't normalize existing percent encodings
+ ('/%c7%9f', '/%c7%9f'),
+ ])
+ def test_percent_encode(self, handler, path, expected):
+ with handler() as rh:
+ ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
+ ws.send('path')
+ assert ws.recv() == expected
+ assert ws.status == 101
+ ws.close()
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_remove_dot_segments(self, handler):
+ with handler() as rh:
+ # This isn't a comprehensive test,
+ # but it should be enough to check whether the handler is removing dot segments
+ ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
+ assert ws.status == 101
+ ws.send('path')
+ assert ws.recv() == '/test'
+ ws.close()
+
+ # We are restricted to known HTTP status codes in http.HTTPStatus
+ # Redirects are not supported for websockets
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
+ def test_raise_http_error(self, handler, status):
+ with handler() as rh:
+ with pytest.raises(HTTPError) as exc_info:
+ validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
+ assert exc_info.value.status == status
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ @pytest.mark.parametrize('params,extensions', [
+ ({'timeout': sys.float_info.min}, {}),
+ ({}, {'timeout': sys.float_info.min}),
+ ])
+ def test_timeout(self, handler, params, extensions):
+ with handler(**params) as rh:
+ with pytest.raises(TransportError):
+ validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_cookies(self, handler):
+ cookiejar = YoutubeDLCookieJar()
+ cookiejar.set_cookie(http.cookiejar.Cookie(
+ version=0, name='test', value='ytdlp', port=None, port_specified=False,
+ domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
+ path_specified=True, secure=False, expires=None, discard=False, comment=None,
+ comment_url=None, rest={}))
+
+ with handler(cookiejar=cookiejar) as rh:
+ ws = validate_and_send(rh, Request(self.ws_base_url))
+ ws.send('headers')
+ assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+ ws.close()
+
+ with handler() as rh:
+ ws = validate_and_send(rh, Request(self.ws_base_url))
+ ws.send('headers')
+ assert 'cookie' not in json.loads(ws.recv())
+ ws.close()
+
+ ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
+ ws.send('headers')
+ assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+ ws.close()
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_source_address(self, handler):
+ source_address = f'127.0.0.{random.randint(5, 255)}'
+ verify_address_availability(source_address)
+ with handler(source_address=source_address) as rh:
+ ws = validate_and_send(rh, Request(self.ws_base_url))
+ ws.send('source_address')
+ assert source_address == ws.recv()
+ ws.close()
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_response_url(self, handler):
+ with handler() as rh:
+ url = f'{self.ws_base_url}/something'
+ ws = validate_and_send(rh, Request(url))
+ assert ws.url == url
+ ws.close()
+
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_request_headers(self, handler):
+ with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
+ # Global Headers
+ ws = validate_and_send(rh, Request(self.ws_base_url))
+ ws.send('headers')
+ headers = HTTPHeaderDict(json.loads(ws.recv()))
+ assert headers['test1'] == 'test'
+ ws.close()
+
+ # Per request headers, merged with global
+ ws = validate_and_send(rh, Request(
+ self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
+ ws.send('headers')
+ headers = HTTPHeaderDict(json.loads(ws.recv()))
+ assert headers['test1'] == 'test'
+ assert headers['test2'] == 'changed'
+ assert headers['test3'] == 'test3'
+ ws.close()
+
+ @pytest.mark.parametrize('client_cert', (
+ {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
+ {
+ 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
+ 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
+ },
+ {
+ 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
+ 'client_certificate_password': 'foobar',
+ },
+ {
+ 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
+ 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
+ 'client_certificate_password': 'foobar',
+ }
+ ))
+ @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+ def test_mtls(self, handler, client_cert):
+ with handler(
+ # Disable client-side validation of unacceptable self-signed testcert.pem
+ # The test is of a check on the server side, so unaffected
+ verify=False,
+ client_cert=client_cert
+ ) as rh:
+ validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
+
+
+def create_fake_ws_connection(raised):
+ import websockets.sync.client
+
+ class FakeWsConnection(websockets.sync.client.ClientConnection):
+ def __init__(self, *args, **kwargs):
+ class FakeResponse:
+ body = b''
+ headers = {}
+ status_code = 101
+ reason_phrase = 'test'
+
+ self.response = FakeResponse()
+
+ def send(self, *args, **kwargs):
+ raise raised()
+
+ def recv(self, *args, **kwargs):
+ raise raised()
+
+ def close(self, *args, **kwargs):
+ return
+
+ return FakeWsConnection()
+
+
+@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
+class TestWebsocketsRequestHandler:
+ @pytest.mark.parametrize('raised,expected', [
+ # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
+ (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
+ # Requires a response object. Should be covered by HTTP error tests.
+ # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
+ (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
+ # These are subclasses of InvalidHandshake
+ (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
+ (lambda: websockets.exceptions.NegotiationError(), TransportError),
+ # Catch-all
+ (lambda: websockets.exceptions.WebSocketException(), TransportError),
+ (lambda: TimeoutError(), TransportError),
+ # These may be raised by our create_connection implementation, which should also be caught
+ (lambda: OSError(), TransportError),
+ (lambda: ssl.SSLError(), SSLError),
+ (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
+ (lambda: socks.ProxyError(), ProxyError),
+ ])
+ def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
+ import websockets.sync.client
+
+ import yt_dlp.networking._websockets
+ with handler() as rh:
+ def fake_connect(*args, **kwargs):
+ raise raised()
+ monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
+ monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
+ with pytest.raises(expected) as exc_info:
+ rh.send(Request('ws://fake-url'))
+ assert exc_info.type is expected
+
+ @pytest.mark.parametrize('raised,expected,match', [
+ # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
+ (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
+ (lambda: RuntimeError(), TransportError, None),
+ (lambda: TimeoutError(), TransportError, None),
+ (lambda: TypeError(), RequestError, None),
+ (lambda: socks.ProxyError(), ProxyError, None),
+ # Catch-all
+ (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
+ ])
+ def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
+ from yt_dlp.networking._websockets import WebsocketsResponseAdapter
+ ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
+ with pytest.raises(expected, match=match) as exc_info:
+ ws.send('test')
+ assert exc_info.type is expected
+
+ @pytest.mark.parametrize('raised,expected,match', [
+ # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
+ (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
+ (lambda: RuntimeError(), TransportError, None),
+ (lambda: TimeoutError(), TransportError, None),
+ (lambda: socks.ProxyError(), ProxyError, None),
+ # Catch-all
+ (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
+ ])
+ def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
+ from yt_dlp.networking._websockets import WebsocketsResponseAdapter
+ ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
+ with pytest.raises(expected, match=match) as exc_info:
+ ws.recv()
+ assert exc_info.type is expected