author    Daniel Baumann <daniel.baumann@progress-linux.org> 2024-04-15 16:49:24 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org> 2024-04-15 16:49:24 +0000
commit    2415e66f889f38503b73e8ebc5f43ca342390e5c
tree      ac48ab69d1d96bae3d83756134921e0d90593aa5 /yt_dlp/networking
parent    Initial commit.
Adding upstream version 2024.03.10.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'yt_dlp/networking')
-rw-r--r--  yt_dlp/networking/__init__.py      30
-rw-r--r--  yt_dlp/networking/_helper.py      283
-rw-r--r--  yt_dlp/networking/_requests.py    408
-rw-r--r--  yt_dlp/networking/_urllib.py      422
-rw-r--r--  yt_dlp/networking/_websockets.py  173
-rw-r--r--  yt_dlp/networking/common.py       565
-rw-r--r--  yt_dlp/networking/exceptions.py   103
-rw-r--r--  yt_dlp/networking/websocket.py     23
8 files changed, 2007 insertions, 0 deletions
diff --git a/yt_dlp/networking/__init__.py b/yt_dlp/networking/__init__.py
new file mode 100644
index 0000000..acadc01
--- /dev/null
+++ b/yt_dlp/networking/__init__.py
@@ -0,0 +1,30 @@
+# flake8: noqa: F401
+import warnings
+
+from .common import (
+ HEADRequest,
+ PUTRequest,
+ Request,
+ RequestDirector,
+ RequestHandler,
+ Response,
+)
+
+# isort: split
+# TODO: all request handlers should be safely imported
+from . import _urllib
+from ..utils import bug_reports_message
+
+try:
+ from . import _requests
+except ImportError:
+ pass
+except Exception as e:
+ warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
+
+try:
+ from . import _websockets
+except ImportError:
+ pass
+except Exception as e:
+ warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message())
diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py
new file mode 100644
index 0000000..d79dd79
--- /dev/null
+++ b/yt_dlp/networking/_helper.py
@@ -0,0 +1,283 @@
+from __future__ import annotations
+
+import contextlib
+import functools
+import socket
+import ssl
+import sys
+import typing
+import urllib.parse
+import urllib.request
+
+from .exceptions import RequestError, UnsupportedRequest
+from ..dependencies import certifi
+from ..socks import ProxyType, sockssocket
+from ..utils import format_field, traverse_obj
+
+if typing.TYPE_CHECKING:
+ from collections.abc import Iterable
+
+ from ..utils.networking import HTTPHeaderDict
+
+
+def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
+ if certifi and use_certifi:
+ context.load_verify_locations(cafile=certifi.where())
+ else:
+ try:
+ context.load_default_certs()
+ # Work around the issue in load_default_certs when there are bad certificates. See:
+ # https://github.com/yt-dlp/yt-dlp/issues/1060,
+ # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
+ except ssl.SSLError:
+ # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
+ if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
+ for storename in ('CA', 'ROOT'):
+ ssl_load_windows_store_certs(context, storename)
+ context.set_default_verify_paths()
+
+
+def ssl_load_windows_store_certs(ssl_context, storename):
+ # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py
+ try:
+ certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename)
+ if encoding == 'x509_asn' and (
+ trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)]
+ except PermissionError:
+ return
+ for cert in certs:
+ with contextlib.suppress(ssl.SSLError):
+ ssl_context.load_verify_locations(cadata=cert)
+
+
+def make_socks_proxy_opts(socks_proxy):
+ url_components = urllib.parse.urlparse(socks_proxy)
+ if url_components.scheme.lower() == 'socks5':
+ socks_type = ProxyType.SOCKS5
+ rdns = False
+ elif url_components.scheme.lower() == 'socks5h':
+ socks_type = ProxyType.SOCKS5
+ rdns = True
+ elif url_components.scheme.lower() == 'socks4':
+ socks_type = ProxyType.SOCKS4
+ rdns = False
+ elif url_components.scheme.lower() == 'socks4a':
+ socks_type = ProxyType.SOCKS4A
+ rdns = True
+ else:
+ raise ValueError(f'Unknown SOCKS proxy version: {url_components.scheme.lower()}')
+
+ def unquote_if_non_empty(s):
+ if not s:
+ return s
+ return urllib.parse.unquote_plus(s)
+ return {
+ 'proxytype': socks_type,
+ 'addr': url_components.hostname,
+ 'port': url_components.port or 1080,
+ 'rdns': rdns,
+ 'username': unquote_if_non_empty(url_components.username),
+ 'password': unquote_if_non_empty(url_components.password),
+ }
+
+
+def select_proxy(url, proxies):
+ """Unified proxy selector for all backends"""
+ url_components = urllib.parse.urlparse(url)
+ if 'no' in proxies:
+ hostport = url_components.hostname + format_field(url_components.port, None, ':%s')
+ if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}):
+ return
+ elif urllib.request.proxy_bypass(hostport): # check system settings
+ return
+
+ return traverse_obj(proxies, url_components.scheme or 'http', 'all')
+
+
+def get_redirect_method(method, status):
+ """Unified redirect method handling"""
+
+ # A 303 must either use GET or HEAD for subsequent request
+ # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
+ if status == 303 and method != 'HEAD':
+ method = 'GET'
+ # 301 and 302 redirects are commonly turned into a GET from a POST
+ # for subsequent requests by browsers, so we'll do the same.
+ # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
+ # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
+ if status in (301, 302) and method == 'POST':
+ method = 'GET'
+ return method
+
+
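# Editor's note (not part of the commit): a quick sanity sketch of the mapping
# implemented by get_redirect_method() below, grounded directly in its two rules.
assert get_redirect_method('POST', 301) == 'GET'    # browsers turn POST into GET
assert get_redirect_method('POST', 302) == 'GET'
assert get_redirect_method('POST', 303) == 'GET'    # 303 forces GET...
assert get_redirect_method('HEAD', 303) == 'HEAD'   # ...unless the method is HEAD
assert get_redirect_method('POST', 307) == 'POST'   # 307/308 preserve the method
assert get_redirect_method('PUT', 308) == 'PUT'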
+def make_ssl_context(
+ verify=True,
+ client_certificate=None,
+ client_certificate_key=None,
+ client_certificate_password=None,
+ legacy_support=False,
+ use_certifi=True,
+):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context.check_hostname = verify
+ context.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE
+
+ # Some servers may reject requests if ALPN extension is not sent. See:
+ # https://github.com/python/cpython/issues/85140
+ # https://github.com/yt-dlp/yt-dlp/issues/3878
+ with contextlib.suppress(NotImplementedError):
+ context.set_alpn_protocols(['http/1.1'])
+ if verify:
+ ssl_load_certs(context, use_certifi)
+
+ if legacy_support:
+ context.options |= 4 # SSL_OP_LEGACY_SERVER_CONNECT
+ context.set_ciphers('DEFAULT') # compat
+
+ elif ssl.OPENSSL_VERSION_INFO >= (1, 1, 1) and not ssl.OPENSSL_VERSION.startswith('LibreSSL'):
+ # Use the default SSL ciphers and minimum TLS version settings from Python 3.10 [1].
+ # This is to ensure consistent behavior across Python versions and libraries, and help avoid fingerprinting
+ # in some situations [2][3].
+ # Python 3.10 only supports OpenSSL 1.1.1+ [4]. Because this change is likely
+ # untested on older versions, we only apply this to OpenSSL 1.1.1+ to be safe.
+ # LibreSSL is excluded until further investigation due to cipher support issues [5][6].
+ # 1. https://github.com/python/cpython/commit/e983252b516edb15d4338b0a47631b59ef1e2536
+ # 2. https://github.com/yt-dlp/yt-dlp/issues/4627
+ # 3. https://github.com/yt-dlp/yt-dlp/pull/5294
+ # 4. https://peps.python.org/pep-0644/
+ # 5. https://peps.python.org/pep-0644/#libressl-support
+ # 6. https://github.com/yt-dlp/yt-dlp/commit/5b9f253fa0aee996cf1ed30185d4b502e00609c4#commitcomment-89054368
+ context.set_ciphers(
+ '@SECLEVEL=2:ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES:DHE+AES:!aNULL:!eNULL:!aDSS:!SHA1:!AESCCM')
+ context.minimum_version = ssl.TLSVersion.TLSv1_2
+
+ if client_certificate:
+ try:
+ context.load_cert_chain(
+ client_certificate, keyfile=client_certificate_key,
+ password=client_certificate_password)
+ except ssl.SSLError:
+ raise RequestError('Unable to load client certificate')
+
+ if getattr(context, 'post_handshake_auth', None) is not None:
+ context.post_handshake_auth = True
+ return context
+
+
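# Editor's note (not part of the commit): a minimal usage sketch of make_ssl_context()
# above. The client-certificate variant is commented out because the paths are
# hypothetical placeholders and loading them would fail if the files do not exist.
ctx = make_ssl_context(verify=True, use_certifi=True)
# mtls_ctx = make_ssl_context(
#     client_certificate='/path/to/client.pem',      # hypothetical path
#     client_certificate_key='/path/to/client.key',  # hypothetical path
# )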
+class InstanceStoreMixin:
+ def __init__(self, **kwargs):
+ self.__instances = []
+ super().__init__(**kwargs) # So that both MROs work
+
+ @staticmethod
+ def _create_instance(**kwargs):
+ raise NotImplementedError
+
+ def _get_instance(self, **kwargs):
+ for key, instance in self.__instances:
+ if key == kwargs:
+ return instance
+
+ instance = self._create_instance(**kwargs)
+ self.__instances.append((kwargs, instance))
+ return instance
+
+ def _close_instance(self, instance):
+ if callable(getattr(instance, 'close', None)):
+ instance.close()
+
+ def _clear_instances(self):
+ for _, instance in self.__instances:
+ self._close_instance(instance)
+ self.__instances.clear()
+
+
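# Editor's note (not part of the commit): a minimal sketch of the caching behavior of
# InstanceStoreMixin above; `_DictStore` is a hypothetical subclass for illustration only.
class _DictStore(InstanceStoreMixin):
    def _create_instance(self, **kwargs):
        return dict(kwargs)  # stand-in for an expensive session/opener object

store = _DictStore()
first = store._get_instance(proxy=None, verify=True)
second = store._get_instance(proxy=None, verify=True)
assert first is second  # identical kwargs -> the cached instance is reused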
+def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]):
+ if 'Accept-Encoding' not in headers:
+ headers['Accept-Encoding'] = ', '.join(supported_encodings) or 'identity'
+
+
+def wrap_request_errors(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return func(self, *args, **kwargs)
+ except UnsupportedRequest as e:
+ if e.handler is None:
+ e.handler = self
+ raise
+ return wrapper
+
+
+def _socket_connect(ip_addr, timeout, source_address):
+ af, socktype, proto, canonname, sa = ip_addr
+ sock = socket.socket(af, socktype, proto)
+ try:
+ if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
+ sock.settimeout(timeout)
+ if source_address:
+ sock.bind(source_address)
+ sock.connect(sa)
+ return sock
+ except OSError:
+ sock.close()
+ raise
+
+
+def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address):
+ af, socktype, proto, canonname, sa = proxy_ip_addr
+ sock = sockssocket(af, socktype, proto)
+ try:
+ connect_proxy_args = proxy_args.copy()
+ connect_proxy_args.update({'addr': sa[0], 'port': sa[1]})
+ sock.setproxy(**connect_proxy_args)
+ if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721
+ sock.settimeout(timeout)
+ if source_address:
+ sock.bind(source_address)
+ sock.connect(dest_addr)
+ return sock
+ except OSError:
+ sock.close()
+ raise
+
+
+def create_connection(
+ address,
+ timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None,
+ *,
+ _create_socket_func=_socket_connect
+):
+ # Work around socket.create_connection(), which tries all addresses from getaddrinfo(), including IPv6.
+ # This variant filters the candidate addresses to match the address family of the given source_address.
+ # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
+ host, port = address
+ ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
+ if not ip_addrs:
+ raise OSError('getaddrinfo returns an empty list')
+ if source_address is not None:
+ af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6
+ ip_addrs = [addr for addr in ip_addrs if addr[0] == af]
+ if not ip_addrs:
+ raise OSError(
+ f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
+ f'Can\'t use "{source_address[0]}" as source address')
+
+ err = None
+ for ip_addr in ip_addrs:
+ try:
+ sock = _create_socket_func(ip_addr, timeout, source_address)
+ # Explicitly break __traceback__ reference cycle
+ # https://bugs.python.org/issue36820
+ err = None
+ return sock
+ except OSError as e:
+ err = e
+
+ try:
+ raise err
+ finally:
+ # Explicitly break __traceback__ reference cycle
+ # https://bugs.python.org/issue36820
+ err = None
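[Editor's note -- not part of the commit] A quick illustration of what
make_socks_proxy_opts() above produces; the proxy URL is a hypothetical example:

    from yt_dlp.networking._helper import make_socks_proxy_opts
    from yt_dlp.socks import ProxyType

    opts = make_socks_proxy_opts('socks5h://user:pass@127.0.0.1:1080')
    assert opts['proxytype'] == ProxyType.SOCKS5
    assert opts['rdns'] is True  # the 'h' suffix requests remote (proxy-side) DNS
    assert (opts['addr'], opts['port']) == ('127.0.0.1', 1080)
    assert (opts['username'], opts['password']) == ('user', 'pass')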
diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py
new file mode 100644
index 0000000..6545028
--- /dev/null
+++ b/yt_dlp/networking/_requests.py
@@ -0,0 +1,408 @@
+import contextlib
+import functools
+import http.client
+import logging
+import re
+import socket
+import warnings
+
+from ..dependencies import brotli, requests, urllib3
+from ..utils import bug_reports_message, int_or_none, variadic
+from ..utils.networking import normalize_url
+
+if requests is None:
+ raise ImportError('requests module is not installed')
+
+if urllib3 is None:
+ raise ImportError('urllib3 module is not installed')
+
+urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.'))
+
+if urllib3_version < (1, 26, 17):
+ raise ImportError('Only urllib3 >= 1.26.17 is supported')
+
+if requests.__build__ < 0x023100:
+ raise ImportError('Only requests >= 2.31.0 is supported')
+
+import requests.adapters
+import requests.utils
+import urllib3.connection
+import urllib3.exceptions
+
+from ._helper import (
+ InstanceStoreMixin,
+ add_accept_encoding_header,
+ create_connection,
+ create_socks_proxy_socket,
+ get_redirect_method,
+ make_socks_proxy_opts,
+ select_proxy,
+)
+from .common import (
+ Features,
+ RequestHandler,
+ Response,
+ register_preference,
+ register_rh,
+)
+from .exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ IncompleteRead,
+ ProxyError,
+ RequestError,
+ SSLError,
+ TransportError,
+)
+from ..socks import ProxyError as SocksProxyError
+
+SUPPORTED_ENCODINGS = [
+ 'gzip', 'deflate'
+]
+
+if brotli is not None:
+ SUPPORTED_ENCODINGS.append('br')
+
+"""
+Override urllib3's behavior so that it does not convert lower-case percent-encoded
+characters to upper-case during the url normalization process.
+
+RFC 3986 defines that lower- and upper-case percent-encoded hexadecimal characters are equivalent
+and normalizers should convert them to uppercase for consistency [1].
+
+However, some sites may have an incorrect implementation where they provide
+a percent-encoded url that is then compared case-sensitively [2].
+
+While this is a very rare case, since urllib does not do this normalization step, it
+is best to avoid it in requests too for compatibility reasons.
+
+1: https://tools.ietf.org/html/rfc3986#section-2.1
+2: https://github.com/streamlink/streamlink/pull/4003
+"""
+
+
+class Urllib3PercentREOverride:
+ def __init__(self, r: re.Pattern):
+ self.re = r
+
+ # pass through all other attribute calls to the original re
+ def __getattr__(self, item):
+ return self.re.__getattribute__(item)
+
+ def subn(self, repl, string, *args, **kwargs):
+ return string, self.re.subn(repl, string, *args, **kwargs)[1]
+
+
+# urllib3 >= 1.25.8 uses subn:
+# https://github.com/urllib3/urllib3/commit/a2697e7c6b275f05879b60f593c5854a816489f0
+import urllib3.util.url # noqa: E305
+
+if hasattr(urllib3.util.url, 'PERCENT_RE'):
+ urllib3.util.url.PERCENT_RE = Urllib3PercentREOverride(urllib3.util.url.PERCENT_RE)
+elif hasattr(urllib3.util.url, '_PERCENT_RE'): # urllib3 >= 2.0.0
+ urllib3.util.url._PERCENT_RE = Urllib3PercentREOverride(urllib3.util.url._PERCENT_RE)
+else:
+ warnings.warn('Failed to patch PERCENT_RE in urllib3 (does the attribute exist?)' + bug_reports_message())
+
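# Editor's note (not part of the commit): with the override in place, subn() reports how
# many percent-escapes it saw but returns the string unmodified, so a lower-case escape
# survives urllib3's normalization, e.g. (illustrative):
#   urllib3.util.url.parse_url('http://example.com/a%2fb').url == 'http://example.com/a%2fb'
# rather than the uppercased 'http://example.com/a%2Fb'.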
+"""
+Workaround for an issue in urllib3.util.ssl_: ssl_wrap_socket does not pass
+server_hostname to SSLContext.wrap_socket if server_hostname is an IP,
+which is a problem because we set check_hostname to True in our SSLContext.
+
+Monkey-patching IS_SECURETRANSPORT forces ssl_wrap_socket to pass server_hostname regardless.
+
+This has been fixed in urllib3 2.0+.
+See: https://github.com/urllib3/urllib3/issues/517
+"""
+
+if urllib3_version < (2, 0, 0):
+ with contextlib.suppress(Exception):
+ urllib3.util.IS_SECURETRANSPORT = urllib3.util.ssl_.IS_SECURETRANSPORT = True
+
+
+# Requests will not automatically handle no_proxy by default
+# due to buggy no_proxy handling with proxy dict [1].
+# 1. https://github.com/psf/requests/issues/5000
+requests.adapters.select_proxy = select_proxy
+
+
+class RequestsResponseAdapter(Response):
+ def __init__(self, res: requests.models.Response):
+ super().__init__(
+ fp=res.raw, headers=res.headers, url=res.url,
+ status=res.status_code, reason=res.reason)
+
+ self._requests_response = res
+
+ def read(self, amt: int = None):
+ try:
+ # Interact with urllib3 response directly.
+ return self.fp.read(amt, decode_content=True)
+
+ # See urllib3.response.HTTPResponse.read() for exceptions raised on read
+ except urllib3.exceptions.SSLError as e:
+ raise SSLError(cause=e) from e
+
+ except urllib3.exceptions.ProtocolError as e:
+ # IncompleteRead is always contained within ProtocolError
+ # See urllib3.response.HTTPResponse._error_catcher()
+ ir_err = next(
+ (err for err in (e.__context__, e.__cause__, *variadic(e.args))
+ if isinstance(err, http.client.IncompleteRead)), None)
+ if ir_err is not None:
+ # `urllib3.exceptions.IncompleteRead` is subclass of `http.client.IncompleteRead`
+ # but uses an `int` for its `partial` property.
+ partial = ir_err.partial if isinstance(ir_err.partial, int) else len(ir_err.partial)
+ raise IncompleteRead(partial=partial, expected=ir_err.expected) from e
+ raise TransportError(cause=e) from e
+
+ except urllib3.exceptions.HTTPError as e:
+ # catch-all for any other urllib3 response exceptions
+ raise TransportError(cause=e) from e
+
+
+class RequestsHTTPAdapter(requests.adapters.HTTPAdapter):
+ def __init__(self, ssl_context=None, proxy_ssl_context=None, source_address=None, **kwargs):
+ self._pm_args = {}
+ if ssl_context:
+ self._pm_args['ssl_context'] = ssl_context
+ if source_address:
+ self._pm_args['source_address'] = (source_address, 0)
+ self._proxy_ssl_context = proxy_ssl_context or ssl_context
+ super().__init__(**kwargs)
+
+ def init_poolmanager(self, *args, **kwargs):
+ return super().init_poolmanager(*args, **kwargs, **self._pm_args)
+
+ def proxy_manager_for(self, proxy, **proxy_kwargs):
+ extra_kwargs = {}
+ if not proxy.lower().startswith('socks') and self._proxy_ssl_context:
+ extra_kwargs['proxy_ssl_context'] = self._proxy_ssl_context
+ return super().proxy_manager_for(proxy, **proxy_kwargs, **self._pm_args, **extra_kwargs)
+
+ def cert_verify(*args, **kwargs):
+ # lean on SSLContext for cert verification
+ pass
+
+
+class RequestsSession(requests.sessions.Session):
+ """
+ Ensure unified redirect method handling with our urllib redirect handler.
+ """
+
+ def rebuild_method(self, prepared_request, response):
+ new_method = get_redirect_method(prepared_request.method, response.status_code)
+
+ # HACK: requests removes headers/body on redirect unless code was a 307/308.
+ if new_method == prepared_request.method:
+ response._real_status_code = response.status_code
+ response.status_code = 308
+
+ prepared_request.method = new_method
+
+ # Requests fails to resolve dot segments on absolute redirect locations
+ # See: https://github.com/yt-dlp/yt-dlp/issues/9020
+ prepared_request.url = normalize_url(prepared_request.url)
+
+ def rebuild_auth(self, prepared_request, response):
+ # HACK: undo status code change from rebuild_method, if applicable.
+ # rebuild_auth runs after requests would remove headers/body based on status code
+ if hasattr(response, '_real_status_code'):
+ response.status_code = response._real_status_code
+ del response._real_status_code
+ return super().rebuild_auth(prepared_request, response)
+
+
+class Urllib3LoggingFilter(logging.Filter):
+
+ def filter(self, record):
+ # Ignore HTTP request messages since HTTPConnection prints those
+ if record.msg == '%s://%s:%s "%s %s %s" %s %s':
+ return False
+ return True
+
+
+class Urllib3LoggingHandler(logging.Handler):
+ """Redirect urllib3 logs to our logger"""
+
+ def __init__(self, logger, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._logger = logger
+
+ def emit(self, record):
+ try:
+ msg = self.format(record)
+ if record.levelno >= logging.ERROR:
+ self._logger.error(msg)
+ else:
+ self._logger.stdout(msg)
+
+ except Exception:
+ self.handleError(record)
+
+
+@register_rh
+class RequestsRH(RequestHandler, InstanceStoreMixin):
+
+ """Requests RequestHandler
+ https://github.com/psf/requests
+ """
+ _SUPPORTED_URL_SCHEMES = ('http', 'https')
+ _SUPPORTED_ENCODINGS = tuple(SUPPORTED_ENCODINGS)
+ _SUPPORTED_PROXY_SCHEMES = ('http', 'https', 'socks4', 'socks4a', 'socks5', 'socks5h')
+ _SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY)
+ RH_NAME = 'requests'
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Forward urllib3 debug messages to our logger
+ logger = logging.getLogger('urllib3')
+ self.__logging_handler = Urllib3LoggingHandler(logger=self._logger)
+ self.__logging_handler.setFormatter(logging.Formatter('requests: %(message)s'))
+ self.__logging_handler.addFilter(Urllib3LoggingFilter())
+ logger.addHandler(self.__logging_handler)
+ # TODO: Use a logger filter to suppress pool reuse warning instead
+ logger.setLevel(logging.ERROR)
+
+ if self.verbose:
+ # Setting this globally is not ideal, but is easier than hacking with urllib3.
+ # It could technically be problematic for scripts embedding yt-dlp.
+ # However, it is unlikely debug traffic is used in that context in a way this will cause problems.
+ urllib3.connection.HTTPConnection.debuglevel = 1
+ logger.setLevel(logging.DEBUG)
+ # this is expected if we are using --no-check-certificate
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+ def close(self):
+ self._clear_instances()
+ # Remove the logging handler that contains a reference to our logger
+ # See: https://github.com/yt-dlp/yt-dlp/issues/8922
+ logging.getLogger('urllib3').removeHandler(self.__logging_handler)
+
+ def _check_extensions(self, extensions):
+ super()._check_extensions(extensions)
+ extensions.pop('cookiejar', None)
+ extensions.pop('timeout', None)
+
+ def _create_instance(self, cookiejar):
+ session = RequestsSession()
+ http_adapter = RequestsHTTPAdapter(
+ ssl_context=self._make_sslcontext(),
+ source_address=self.source_address,
+ max_retries=urllib3.util.retry.Retry(False),
+ )
+ session.adapters.clear()
+ session.headers = requests.models.CaseInsensitiveDict({'Connection': 'keep-alive'})
+ session.mount('https://', http_adapter)
+ session.mount('http://', http_adapter)
+ session.cookies = cookiejar
+ session.trust_env = False # no need, we already load proxies from env
+ return session
+
+ def _send(self, request):
+
+ headers = self._merge_headers(request.headers)
+ add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+
+ max_redirects_exceeded = False
+
+ session = self._get_instance(
+ cookiejar=request.extensions.get('cookiejar') or self.cookiejar)
+
+ try:
+ requests_res = session.request(
+ method=request.method,
+ url=request.url,
+ data=request.data,
+ headers=headers,
+ timeout=float(request.extensions.get('timeout') or self.timeout),
+ proxies=request.proxies or self.proxies,
+ allow_redirects=True,
+ stream=True
+ )
+
+ except requests.exceptions.TooManyRedirects as e:
+ max_redirects_exceeded = True
+ requests_res = e.response
+
+ except requests.exceptions.SSLError as e:
+ if 'CERTIFICATE_VERIFY_FAILED' in str(e):
+ raise CertificateVerifyError(cause=e) from e
+ raise SSLError(cause=e) from e
+
+ except requests.exceptions.ProxyError as e:
+ raise ProxyError(cause=e) from e
+
+ except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
+ raise TransportError(cause=e) from e
+
+ except urllib3.exceptions.HTTPError as e:
+ # Catch any urllib3 exceptions that may leak through
+ raise TransportError(cause=e) from e
+
+ except requests.exceptions.RequestException as e:
+ # Miscellaneous Requests exceptions. May not necessarily be network related, e.g. InvalidURL
+ raise RequestError(cause=e) from e
+
+ res = RequestsResponseAdapter(requests_res)
+
+ if not 200 <= res.status < 300:
+ raise HTTPError(res, redirect_loop=max_redirects_exceeded)
+
+ return res
+
+
+@register_preference(RequestsRH)
+def requests_preference(rh, request):
+ return 100
+
+
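# Editor's note (not part of the commit): RequestDirector sums every registered
# preference callable per handler and tries handlers in descending order of that
# total, so the +100 above ranks "requests" ahead of handlers left at the default
# score of 0 (e.g. "urllib").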
+# Use our socks proxy implementation with requests to avoid an extra dependency.
+class SocksHTTPConnection(urllib3.connection.HTTPConnection):
+ def __init__(self, _socks_options, *args, **kwargs): # must use _socks_options to pass PoolKey checks
+ self._proxy_args = _socks_options
+ super().__init__(*args, **kwargs)
+
+ def _new_conn(self):
+ try:
+ return create_connection(
+ address=(self._proxy_args['addr'], self._proxy_args['port']),
+ timeout=self.timeout,
+ source_address=self.source_address,
+ _create_socket_func=functools.partial(
+ create_socks_proxy_socket, (self.host, self.port), self._proxy_args))
+ except (socket.timeout, TimeoutError) as e:
+ raise urllib3.exceptions.ConnectTimeoutError(
+ self, f'Connection to {self.host} timed out. (connect timeout={self.timeout})') from e
+ except SocksProxyError as e:
+ raise urllib3.exceptions.ProxyError(str(e), e) from e
+ except OSError as e:
+ raise urllib3.exceptions.NewConnectionError(
+ self, f'Failed to establish a new connection: {e}') from e
+
+
+class SocksHTTPSConnection(SocksHTTPConnection, urllib3.connection.HTTPSConnection):
+ pass
+
+
+class SocksHTTPConnectionPool(urllib3.HTTPConnectionPool):
+ ConnectionCls = SocksHTTPConnection
+
+
+class SocksHTTPSConnectionPool(urllib3.HTTPSConnectionPool):
+ ConnectionCls = SocksHTTPSConnection
+
+
+class SocksProxyManager(urllib3.PoolManager):
+
+ def __init__(self, socks_proxy, username=None, password=None, num_pools=10, headers=None, **connection_pool_kw):
+ connection_pool_kw['_socks_options'] = make_socks_proxy_opts(socks_proxy)
+ super().__init__(num_pools, headers, **connection_pool_kw)
+ self.pool_classes_by_scheme = {
+ 'http': SocksHTTPConnectionPool,
+ 'https': SocksHTTPSConnectionPool
+ }
+
+
+requests.adapters.SOCKSProxyManager = SocksProxyManager
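[Editor's note -- not part of the commit] A minimal sketch of driving RequestsRH
directly, assuming only that the logger object offers the stdout()/error() methods
the handlers call:

    from yt_dlp.networking import Request
    from yt_dlp.networking._requests import RequestsRH

    class _Logger:  # minimal stand-in for yt-dlp's logger
        def stdout(self, msg):
            print(msg)

        def error(self, msg, *, is_error=True):
            print(msg)

    with RequestsRH(logger=_Logger()) as rh:
        response = rh.send(Request('https://example.com'))
        print(response.status, len(response.read()))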
diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py
new file mode 100644
index 0000000..cb4dae3
--- /dev/null
+++ b/yt_dlp/networking/_urllib.py
@@ -0,0 +1,422 @@
+from __future__ import annotations
+
+import functools
+import http.client
+import io
+import ssl
+import urllib.error
+import urllib.parse
+import urllib.request
+import urllib.response
+import zlib
+from urllib.request import (
+ DataHandler,
+ FileHandler,
+ FTPHandler,
+ HTTPCookieProcessor,
+ HTTPDefaultErrorHandler,
+ HTTPErrorProcessor,
+ UnknownHandler,
+)
+
+from ._helper import (
+ InstanceStoreMixin,
+ add_accept_encoding_header,
+ create_connection,
+ create_socks_proxy_socket,
+ get_redirect_method,
+ make_socks_proxy_opts,
+ select_proxy,
+)
+from .common import Features, RequestHandler, Response, register_rh
+from .exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ IncompleteRead,
+ ProxyError,
+ RequestError,
+ SSLError,
+ TransportError,
+)
+from ..dependencies import brotli
+from ..socks import ProxyError as SocksProxyError
+from ..utils import update_url_query
+from ..utils.networking import normalize_url
+
+SUPPORTED_ENCODINGS = ['gzip', 'deflate']
+CONTENT_DECODE_ERRORS = [zlib.error, OSError]
+
+if brotli:
+ SUPPORTED_ENCODINGS.append('br')
+ CONTENT_DECODE_ERRORS.append(brotli.error)
+
+
+def _create_http_connection(http_class, source_address, *args, **kwargs):
+ hc = http_class(*args, **kwargs)
+
+ if hasattr(hc, '_create_connection'):
+ hc._create_connection = create_connection
+
+ if source_address is not None:
+ hc.source_address = (source_address, 0)
+
+ return hc
+
+
+class HTTPHandler(urllib.request.AbstractHTTPHandler):
+ """Handler for HTTP requests and responses.
+
+ This class, when installed with an OpenerDirector, automatically adds
+ the standard headers to every HTTP request and handles gzipped, deflated and
+ brotli responses from web servers.
+
+ Part of this code was copied from:
+
+ http://techknack.net/python-urllib2-handlers/
+
+ Andrew Rowls, the author of that code, agreed to release it to the
+ public domain.
+ """
+
+ def __init__(self, context=None, source_address=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._source_address = source_address
+ self._context = context
+
+ @staticmethod
+ def _make_conn_class(base, req):
+ conn_class = base
+ socks_proxy = req.headers.pop('Ytdl-socks-proxy', None)
+ if socks_proxy:
+ conn_class = make_socks_conn_class(conn_class, socks_proxy)
+ return conn_class
+
+ def http_open(self, req):
+ conn_class = self._make_conn_class(http.client.HTTPConnection, req)
+ return self.do_open(functools.partial(
+ _create_http_connection, conn_class, self._source_address), req)
+
+ def https_open(self, req):
+ conn_class = self._make_conn_class(http.client.HTTPSConnection, req)
+ return self.do_open(
+ functools.partial(
+ _create_http_connection, conn_class, self._source_address),
+ req, context=self._context)
+
+ @staticmethod
+ def deflate(data):
+ if not data:
+ return data
+ try:
+ return zlib.decompress(data, -zlib.MAX_WBITS)
+ except zlib.error:
+ return zlib.decompress(data)
+
+ @staticmethod
+ def brotli(data):
+ if not data:
+ return data
+ return brotli.decompress(data)
+
+ @staticmethod
+ def gz(data):
+ # There may be junk added to the end of the file
+ # We ignore it by only ever decoding a single gzip payload
+ if not data:
+ return data
+ return zlib.decompress(data, wbits=zlib.MAX_WBITS | 16)
+
+ def http_request(self, req):
+ # According to RFC 3986, URLs cannot contain non-ASCII characters; however, this is not
+ # always respected by websites, and some tend to give out URLs with non-percent-encoded
+ # non-ASCII characters (see telemb.py, ard.py [#3412])
+ # urllib chokes on URLs with non-ASCII characters (see http://bugs.python.org/issue3991)
+ # To work around aforementioned issue we will replace request's original URL with
+ # percent-encoded one
+ # Since redirects are also affected (e.g. http://www.southpark.de/alle-episoden/s18e09)
+ # the code of this workaround has been moved here from YoutubeDL.urlopen()
+ url = req.get_full_url()
+ url_escaped = normalize_url(url)
+
+ # Substitute URL if any change after escaping
+ if url != url_escaped:
+ req = update_Request(req, url=url_escaped)
+
+ return super().do_request_(req)
+
+ def http_response(self, req, resp):
+ old_resp = resp
+
+ # Content-Encoding header lists the encodings in order that they were applied [1].
+ # To decompress, we simply do the reverse.
+ # [1]: https://datatracker.ietf.org/doc/html/rfc9110#name-content-encoding
+ decoded_response = None
+ for encoding in (e.strip() for e in reversed(resp.headers.get('Content-encoding', '').split(','))):
+ if encoding == 'gzip':
+ decoded_response = self.gz(decoded_response or resp.read())
+ elif encoding == 'deflate':
+ decoded_response = self.deflate(decoded_response or resp.read())
+ elif encoding == 'br' and brotli:
+ decoded_response = self.brotli(decoded_response or resp.read())
+
+ if decoded_response is not None:
+ resp = urllib.request.addinfourl(io.BytesIO(decoded_response), old_resp.headers, old_resp.url, old_resp.code)
+ resp.msg = old_resp.msg
+ # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
+ # https://github.com/ytdl-org/youtube-dl/issues/6457).
+ if 300 <= resp.code < 400:
+ location = resp.headers.get('Location')
+ if location:
+ # Per RFC 2616, the default charset is iso-8859-1, which is respected by Python 3
+ location = location.encode('iso-8859-1').decode()
+ location_escaped = normalize_url(location)
+ if location != location_escaped:
+ del resp.headers['Location']
+ resp.headers['Location'] = location_escaped
+ return resp
+
+ https_request = http_request
+ https_response = http_response
+
+
+def make_socks_conn_class(base_class, socks_proxy):
+ assert issubclass(base_class, (
+ http.client.HTTPConnection, http.client.HTTPSConnection))
+
+ proxy_args = make_socks_proxy_opts(socks_proxy)
+
+ class SocksConnection(base_class):
+ _create_connection = create_connection
+
+ def connect(self):
+ self.sock = create_connection(
+ (proxy_args['addr'], proxy_args['port']),
+ timeout=self.timeout,
+ source_address=self.source_address,
+ _create_socket_func=functools.partial(
+ create_socks_proxy_socket, (self.host, self.port), proxy_args))
+ if isinstance(self, http.client.HTTPSConnection):
+ self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
+
+ return SocksConnection
+
+
+class RedirectHandler(urllib.request.HTTPRedirectHandler):
+ """YoutubeDL redirect handler
+
+ The code is based on HTTPRedirectHandler implementation from CPython [1].
+
+ This redirect handler fixes and improves the logic to better align with RFC 7231
+ and what browsers tend to do [2][3].
+
+ 1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py
+ 2. https://datatracker.ietf.org/doc/html/rfc7231
+ 3. https://github.com/python/cpython/issues/91306
+ """
+
+ http_error_301 = http_error_303 = http_error_307 = http_error_308 = urllib.request.HTTPRedirectHandler.http_error_302
+
+ def redirect_request(self, req, fp, code, msg, headers, newurl):
+ if code not in (301, 302, 303, 307, 308):
+ raise urllib.error.HTTPError(req.full_url, code, msg, headers, fp)
+
+ new_data = req.data
+
+ # Technically the Cookie header should be in unredirected_hdrs,
+ # however in practice some may set it in normal headers anyway.
+ # We will remove it here to prevent any leaks.
+ remove_headers = ['Cookie']
+
+ new_method = get_redirect_method(req.get_method(), code)
+ # only remove payload if method changed (e.g. POST to GET)
+ if new_method != req.get_method():
+ new_data = None
+ remove_headers.extend(['Content-Length', 'Content-Type'])
+
+ new_headers = {k: v for k, v in req.headers.items() if k.title() not in remove_headers}
+
+ return urllib.request.Request(
+ newurl, headers=new_headers, origin_req_host=req.origin_req_host,
+ unverifiable=True, method=new_method, data=new_data)
+
+
+class ProxyHandler(urllib.request.BaseHandler):
+ handler_order = 100
+
+ def __init__(self, proxies=None):
+ self.proxies = proxies
+ # Set default handlers
+ for type in ('http', 'https', 'ftp'):
+ setattr(self, '%s_open' % type, lambda r, meth=self.proxy_open: meth(r))
+
+ def proxy_open(self, req):
+ proxy = select_proxy(req.get_full_url(), self.proxies)
+ if proxy is None:
+ return
+ if urllib.parse.urlparse(proxy).scheme.lower() in ('socks4', 'socks4a', 'socks5', 'socks5h'):
+ req.add_header('Ytdl-socks-proxy', proxy)
+ # yt-dlp's http/https handlers handle wrapping the socket with socks
+ return None
+ return urllib.request.ProxyHandler.proxy_open(
+ self, req, proxy, None)
+
+
+class PUTRequest(urllib.request.Request):
+ def get_method(self):
+ return 'PUT'
+
+
+class HEADRequest(urllib.request.Request):
+ def get_method(self):
+ return 'HEAD'
+
+
+def update_Request(req, url=None, data=None, headers=None, query=None):
+ req_headers = req.headers.copy()
+ req_headers.update(headers or {})
+ req_data = data if data is not None else req.data
+ req_url = update_url_query(url or req.get_full_url(), query)
+ req_get_method = req.get_method()
+ if req_get_method == 'HEAD':
+ req_type = HEADRequest
+ elif req_get_method == 'PUT':
+ req_type = PUTRequest
+ else:
+ req_type = urllib.request.Request
+ new_req = req_type(
+ req_url, data=req_data, headers=req_headers,
+ origin_req_host=req.origin_req_host, unverifiable=req.unverifiable)
+ if hasattr(req, 'timeout'):
+ new_req.timeout = req.timeout
+ return new_req
+
+
+class UrllibResponseAdapter(Response):
+ """
+ HTTP Response adapter class for urllib addinfourl and http.client.HTTPResponse
+ """
+
+ def __init__(self, res: http.client.HTTPResponse | urllib.response.addinfourl):
+ # addinfourl: In Python 3.9+, .status was introduced and .getcode() was deprecated [1]
+ # HTTPResponse: .getcode() was deprecated, .status always existed [2]
+ # 1. https://docs.python.org/3/library/urllib.request.html#urllib.response.addinfourl.getcode
+ # 2. https://docs.python.org/3.10/library/http.client.html#http.client.HTTPResponse.status
+ super().__init__(
+ fp=res, headers=res.headers, url=res.url,
+ status=getattr(res, 'status', None) or res.getcode(), reason=getattr(res, 'reason', None))
+
+ def read(self, amt=None):
+ try:
+ return self.fp.read(amt)
+ except Exception as e:
+ handle_response_read_exceptions(e)
+ raise e
+
+
+def handle_sslerror(e: ssl.SSLError):
+ if not isinstance(e, ssl.SSLError):
+ return
+ if isinstance(e, ssl.SSLCertVerificationError):
+ raise CertificateVerifyError(cause=e) from e
+ raise SSLError(cause=e) from e
+
+
+def handle_response_read_exceptions(e):
+ if isinstance(e, http.client.IncompleteRead):
+ raise IncompleteRead(partial=len(e.partial), cause=e, expected=e.expected) from e
+ elif isinstance(e, ssl.SSLError):
+ handle_sslerror(e)
+ elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)):
+ # OSErrors raised here should mostly be network related
+ raise TransportError(cause=e) from e
+
+
+@register_rh
+class UrllibRH(RequestHandler, InstanceStoreMixin):
+ _SUPPORTED_URL_SCHEMES = ('http', 'https', 'data', 'ftp')
+ _SUPPORTED_PROXY_SCHEMES = ('http', 'socks4', 'socks4a', 'socks5', 'socks5h')
+ _SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY)
+ RH_NAME = 'urllib'
+
+ def __init__(self, *, enable_file_urls: bool = False, **kwargs):
+ super().__init__(**kwargs)
+ self.enable_file_urls = enable_file_urls
+ if self.enable_file_urls:
+ self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
+
+ def _check_extensions(self, extensions):
+ super()._check_extensions(extensions)
+ extensions.pop('cookiejar', None)
+ extensions.pop('timeout', None)
+
+ def _create_instance(self, proxies, cookiejar):
+ opener = urllib.request.OpenerDirector()
+ handlers = [
+ ProxyHandler(proxies),
+ HTTPHandler(
+ debuglevel=int(bool(self.verbose)),
+ context=self._make_sslcontext(),
+ source_address=self.source_address),
+ HTTPCookieProcessor(cookiejar),
+ DataHandler(),
+ UnknownHandler(),
+ HTTPDefaultErrorHandler(),
+ FTPHandler(),
+ HTTPErrorProcessor(),
+ RedirectHandler(),
+ ]
+
+ if self.enable_file_urls:
+ handlers.append(FileHandler())
+
+ for handler in handlers:
+ opener.add_handler(handler)
+
+ # Delete the default user-agent header, which would otherwise apply in
+ # cases where our custom HTTP handler doesn't come into play
+ # (See https://github.com/ytdl-org/youtube-dl/issues/1309 for details)
+ opener.addheaders = []
+ return opener
+
+ def _send(self, request):
+ headers = self._merge_headers(request.headers)
+ add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+ urllib_req = urllib.request.Request(
+ url=request.url,
+ data=request.data,
+ headers=dict(headers),
+ method=request.method
+ )
+
+ opener = self._get_instance(
+ proxies=request.proxies or self.proxies,
+ cookiejar=request.extensions.get('cookiejar') or self.cookiejar
+ )
+ try:
+ res = opener.open(urllib_req, timeout=float(request.extensions.get('timeout') or self.timeout))
+ except urllib.error.HTTPError as e:
+ if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)):
+ # Prevent file object from being closed when urllib.error.HTTPError is destroyed.
+ e._closer.close_called = True
+ raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e
+ raise # unexpected
+ except urllib.error.URLError as e:
+ cause = e.reason # NOTE: cause may be a string
+
+ # proxy errors
+ if 'tunnel connection failed' in str(cause).lower() or isinstance(cause, SocksProxyError):
+ raise ProxyError(cause=e) from e
+
+ handle_response_read_exceptions(cause)
+ raise TransportError(cause=e) from e
+ except (http.client.InvalidURL, ValueError) as e:
+ # Validation errors
+ # http.client.HTTPConnection raises ValueError in some validation cases
+ # such as if request method contains illegal control characters [1]
+ # 1. https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256
+ raise RequestError(cause=e) from e
+ except Exception as e:
+ handle_response_read_exceptions(e)
+ raise # unexpected
+
+ return UrllibResponseAdapter(res)
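[Editor's note -- not part of the commit] UrllibRH follows the same interface; a small
validation sketch showing the effect of enable_file_urls (with _Logger as defined in
the previous note):

    from yt_dlp.networking import Request
    from yt_dlp.networking._urllib import UrllibRH
    from yt_dlp.networking.exceptions import UnsupportedRequest

    rh = UrllibRH(logger=_Logger())              # file:// is disabled by default
    rh.validate(Request('https://example.com'))  # passes
    try:
        rh.validate(Request('file:///etc/hosts'))
    except UnsupportedRequest as e:
        print(e)                                 # Unsupported url scheme: "file"
    rh.close()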
diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py
new file mode 100644
index 0000000..1597932
--- /dev/null
+++ b/yt_dlp/networking/_websockets.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import io
+import logging
+import ssl
+import sys
+
+from ._helper import (
+ create_connection,
+ create_socks_proxy_socket,
+ make_socks_proxy_opts,
+ select_proxy,
+)
+from .common import Features, Response, register_rh
+from .exceptions import (
+ CertificateVerifyError,
+ HTTPError,
+ ProxyError,
+ RequestError,
+ SSLError,
+ TransportError,
+)
+from .websocket import WebSocketRequestHandler, WebSocketResponse
+from ..compat import functools
+from ..dependencies import websockets
+from ..socks import ProxyError as SocksProxyError
+from ..utils import int_or_none
+
+if not websockets:
+ raise ImportError('websockets is not installed')
+
+import websockets.version
+
+websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
+if websockets_version < (12, 0):
+ raise ImportError('Only websockets>=12.0 is supported')
+
+import websockets.sync.client
+from websockets.uri import parse_uri
+
+
+class WebsocketsResponseAdapter(WebSocketResponse):
+
+ def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
+ super().__init__(
+ fp=io.BytesIO(wsw.response.body or b''),
+ url=url,
+ headers=wsw.response.headers,
+ status=wsw.response.status_code,
+ reason=wsw.response.reason_phrase,
+ )
+ self.wsw = wsw
+
+ def close(self):
+ self.wsw.close()
+ super().close()
+
+ def send(self, message):
+ # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
+ try:
+ return self.wsw.send(message)
+ except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
+ raise TransportError(cause=e) from e
+ except SocksProxyError as e:
+ raise ProxyError(cause=e) from e
+ except TypeError as e:
+ raise RequestError(cause=e) from e
+
+ def recv(self):
+ # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
+ try:
+ return self.wsw.recv()
+ except SocksProxyError as e:
+ raise ProxyError(cause=e) from e
+ except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
+ raise TransportError(cause=e) from e
+
+
+@register_rh
+class WebsocketsRH(WebSocketRequestHandler):
+ """
+ Websockets request handler
+ https://websockets.readthedocs.io
+ https://github.com/python-websockets/websockets
+ """
+ _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
+ _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
+ _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
+ RH_NAME = 'websockets'
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__logging_handlers = {}
+ for name in ('websockets.client', 'websockets.server'):
+ logger = logging.getLogger(name)
+ handler = logging.StreamHandler(stream=sys.stdout)
+ handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
+ self.__logging_handlers[name] = handler
+ logger.addHandler(handler)
+ if self.verbose:
+ logger.setLevel(logging.DEBUG)
+
+ def _check_extensions(self, extensions):
+ super()._check_extensions(extensions)
+ extensions.pop('timeout', None)
+ extensions.pop('cookiejar', None)
+
+ def close(self):
+ # Remove the logging handler that contains a reference to our logger
+ # See: https://github.com/yt-dlp/yt-dlp/issues/8922
+ for name, handler in self.__logging_handlers.items():
+ logging.getLogger(name).removeHandler(handler)
+
+ def _send(self, request):
+ timeout = float(request.extensions.get('timeout') or self.timeout)
+ headers = self._merge_headers(request.headers)
+ if 'cookie' not in headers:
+ cookiejar = request.extensions.get('cookiejar') or self.cookiejar
+ cookie_header = cookiejar.get_cookie_header(request.url)
+ if cookie_header:
+ headers['cookie'] = cookie_header
+
+ wsuri = parse_uri(request.url)
+ create_conn_kwargs = {
+ 'source_address': (self.source_address, 0) if self.source_address else None,
+ 'timeout': timeout
+ }
+ proxy = select_proxy(request.url, request.proxies or self.proxies or {})
+ try:
+ if proxy:
+ socks_proxy_options = make_socks_proxy_opts(proxy)
+ sock = create_connection(
+ address=(socks_proxy_options['addr'], socks_proxy_options['port']),
+ _create_socket_func=functools.partial(
+ create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
+ **create_conn_kwargs
+ )
+ else:
+ sock = create_connection(
+ address=(wsuri.host, wsuri.port),
+ **create_conn_kwargs
+ )
+ conn = websockets.sync.client.connect(
+ sock=sock,
+ uri=request.url,
+ additional_headers=headers,
+ open_timeout=timeout,
+ user_agent_header=None,
+ ssl_context=self._make_sslcontext() if wsuri.secure else None,
+ close_timeout=0, # not ideal, but prevents yt-dlp hanging
+ )
+ return WebsocketsResponseAdapter(conn, url=request.url)
+
+ # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
+ except SocksProxyError as e:
+ raise ProxyError(cause=e) from e
+ except websockets.exceptions.InvalidURI as e:
+ raise RequestError(cause=e) from e
+ except ssl.SSLCertVerificationError as e:
+ raise CertificateVerifyError(cause=e) from e
+ except ssl.SSLError as e:
+ raise SSLError(cause=e) from e
+ except websockets.exceptions.InvalidStatus as e:
+ raise HTTPError(
+ Response(
+ fp=io.BytesIO(e.response.body),
+ url=request.url,
+ headers=e.response.headers,
+ status=e.response.status_code,
+ reason=e.response.reason_phrase),
+ ) from e
+ except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
+ raise TransportError(cause=e) from e
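[Editor's note -- not part of the commit] A minimal sketch of the websockets handler;
the echo-server URL is a hypothetical placeholder, and _Logger is as defined in the
note after _requests.py:

    from yt_dlp.networking import Request
    from yt_dlp.networking._websockets import WebsocketsRH

    with WebsocketsRH(logger=_Logger()) as rh:
        ws = rh.send(Request('wss://echo.example.com/socket'))  # hypothetical server
        ws.send('ping')
        print(ws.recv())
        ws.close()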
diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py
new file mode 100644
index 0000000..39442ba
--- /dev/null
+++ b/yt_dlp/networking/common.py
@@ -0,0 +1,565 @@
+from __future__ import annotations
+
+import abc
+import copy
+import enum
+import functools
+import io
+import typing
+import urllib.parse
+import urllib.request
+import urllib.response
+from collections.abc import Iterable, Mapping
+from email.message import Message
+from http import HTTPStatus
+
+from ._helper import make_ssl_context, wrap_request_errors
+from .exceptions import (
+ NoSupportingHandlers,
+ RequestError,
+ TransportError,
+ UnsupportedRequest,
+)
+from ..compat.types import NoneType
+from ..cookies import YoutubeDLCookieJar
+from ..utils import (
+ bug_reports_message,
+ classproperty,
+ deprecation_warning,
+ error_to_str,
+ update_url_query,
+)
+from ..utils.networking import HTTPHeaderDict, normalize_url
+
+
+def register_preference(*handlers: type[RequestHandler]):
+ assert all(issubclass(handler, RequestHandler) for handler in handlers)
+
+ def outer(preference: Preference):
+ @functools.wraps(preference)
+ def inner(handler, *args, **kwargs):
+ if not handlers or isinstance(handler, handlers):
+ return preference(handler, *args, **kwargs)
+ return 0
+ _RH_PREFERENCES.add(inner)
+ return inner
+ return outer
+
+
+class RequestDirector:
+ """RequestDirector class
+
+ Helper class that, when given a request, forwards it to a RequestHandler that supports it.
+
+ Preference functions in the form of func(handler, request) -> int
+ can be registered into the `preferences` set. These are used to sort handlers
+ in order of preference.
+
+ @param logger: Logger instance.
+ @param verbose: Print debug request information to stdout.
+ """
+
+ def __init__(self, logger, verbose=False):
+ self.handlers: dict[str, RequestHandler] = {}
+ self.preferences: set[Preference] = set()
+ self.logger = logger # TODO(Grub4k): default logger
+ self.verbose = verbose
+
+ def close(self):
+ for handler in self.handlers.values():
+ handler.close()
+ self.handlers.clear()
+
+ def add_handler(self, handler: RequestHandler):
+ """Add a handler. If a handler of the same RH_KEY exists, it will overwrite it"""
+ assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
+ self.handlers[handler.RH_KEY] = handler
+
+ def _get_handlers(self, request: Request) -> list[RequestHandler]:
+ """Sorts handlers by preference, given a request"""
+ preferences = {
+ rh: sum(pref(rh, request) for pref in self.preferences)
+ for rh in self.handlers.values()
+ }
+ self._print_verbose('Handler preferences for this request: %s' % ', '.join(
+ f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
+ return sorted(self.handlers.values(), key=preferences.get, reverse=True)
+
+ def _print_verbose(self, msg):
+ if self.verbose:
+ self.logger.stdout(f'director: {msg}')
+
+ def send(self, request: Request) -> Response:
+ """
+ Passes a request onto a suitable RequestHandler
+ """
+ if not self.handlers:
+ raise RequestError('No request handlers configured')
+
+ assert isinstance(request, Request)
+
+ unexpected_errors = []
+ unsupported_errors = []
+ for handler in self._get_handlers(request):
+ self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
+ try:
+ handler.validate(request)
+ except UnsupportedRequest as e:
+ self._print_verbose(
+ f'"{handler.RH_NAME}" cannot handle this request (reason: {error_to_str(e)})')
+ unsupported_errors.append(e)
+ continue
+
+ self._print_verbose(f'Sending request via "{handler.RH_NAME}"')
+ try:
+ response = handler.send(request)
+ except RequestError:
+ raise
+ except Exception as e:
+ self.logger.error(
+ f'[{handler.RH_NAME}] Unexpected error: {error_to_str(e)}{bug_reports_message()}',
+ is_error=False)
+ unexpected_errors.append(e)
+ continue
+
+ assert isinstance(response, Response)
+ return response
+
+ raise NoSupportingHandlers(unsupported_errors, unexpected_errors)
+
+
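# Editor's note (not part of the commit): a minimal wiring sketch for the director,
# assuming a registered handler class such as UrllibRH and a logger object offering
# the stdout()/error() methods used above:
#
#   director = RequestDirector(logger=my_logger)
#   director.add_handler(UrllibRH(logger=my_logger))
#   response = director.send(Request('https://example.com'))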
+_REQUEST_HANDLERS = {}
+
+
+def register_rh(handler):
+ """Register a RequestHandler class"""
+ assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler'
+ assert handler.RH_KEY not in _REQUEST_HANDLERS, f'RequestHandler {handler.RH_KEY} already registered'
+ _REQUEST_HANDLERS[handler.RH_KEY] = handler
+ return handler
+
+
+class Features(enum.Enum):
+ ALL_PROXY = enum.auto()
+ NO_PROXY = enum.auto()
+
+
+class RequestHandler(abc.ABC):
+
+ """Request Handler class
+
+ Request handlers are classes that, given a Request,
+ process the request from start to finish and return a Response.
+
+ Concrete subclasses need to redefine the _send(request) method,
+ which handles the underlying request logic and returns a Response.
+
+ RH_NAME class variable may contain a display name for the RequestHandler.
+ By default, this is generated from the class name.
+
+ The concrete request handler MUST have "RH" as the suffix in the class name.
+
+ All exceptions raised by a RequestHandler should be an instance of RequestError.
+ Any other exception raised will be treated as a handler issue.
+
+ If a Request is not supported by the handler, an UnsupportedRequest
+ should be raised with a reason.
+
+ By default, some checks are done on the request in _validate() based on the following class variables:
+ - `_SUPPORTED_URL_SCHEMES`: a tuple of supported url schemes.
+ Any Request with a url scheme not in this list will raise an UnsupportedRequest.
+
+ - `_SUPPORTED_PROXY_SCHEMES`: a tuple of supported proxy url schemes. Any Request that contains
+ a proxy url with a url scheme not in this list will raise an UnsupportedRequest.
+
+ - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
+
+ The above may be set to None to disable the checks.
+
+ Parameters:
+ @param logger: logger instance
+ @param headers: HTTP Headers to include when sending requests.
+ @param cookiejar: Cookiejar to use for requests.
+ @param timeout: Socket timeout to use when sending requests.
+ @param proxies: Proxies to use for sending requests.
+ @param source_address: Client-side IP address to bind to for requests.
+ @param verbose: Print debug request and traffic information to stdout.
+ @param prefer_system_certs: Whether to prefer system certificates over other means (e.g. certifi).
+ @param client_cert: SSL client certificate configuration.
+ dict with {client_certificate, client_certificate_key, client_certificate_password}
+ @param verify: Verify SSL certificates
+ @param legacy_ssl_support: Enable legacy SSL options such as legacy server connect and older cipher support.
+
+ Some configuration options may be available for individual Requests too. In this case,
+ either the Request configuration option takes precedence or they are merged.
+
+ Requests may have additional optional parameters defined as extensions.
+ RequestHandler subclasses may choose to support custom extensions.
+
+ If an extension is supported, subclasses should extend _check_extensions(extensions)
+ to pop and validate the extension.
+ - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
+
+ The following extensions are defined for RequestHandler:
+ - `cookiejar`: Cookiejar to use for this request.
+ - `timeout`: socket timeout to use for this request.
+ To enable these, add extensions.pop('<extension>', None) to _check_extensions
+
+ Apart from the url protocol, proxies dict may contain the following keys:
+ - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
+ - `no`: comma-separated list of hostnames (optionally with port) to not use a proxy for.
+ Note: a RequestHandler may not support these, as defined in `_SUPPORTED_FEATURES`.
+
+ """
+
+ _SUPPORTED_URL_SCHEMES = ()
+ _SUPPORTED_PROXY_SCHEMES = ()
+ _SUPPORTED_FEATURES = ()
+
+ def __init__(
+ self, *,
+ logger, # TODO(Grub4k): default logger
+ headers: HTTPHeaderDict = None,
+ cookiejar: YoutubeDLCookieJar = None,
+ timeout: float | int | None = None,
+ proxies: dict = None,
+ source_address: str = None,
+ verbose: bool = False,
+ prefer_system_certs: bool = False,
+ client_cert: dict[str, str | None] = None,
+ verify: bool = True,
+ legacy_ssl_support: bool = False,
+ **_,
+ ):
+
+ self._logger = logger
+ self.headers = headers or {}
+ self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar()
+ self.timeout = float(timeout or 20)
+ self.proxies = proxies or {}
+ self.source_address = source_address
+ self.verbose = verbose
+ self.prefer_system_certs = prefer_system_certs
+ self._client_cert = client_cert or {}
+ self.verify = verify
+ self.legacy_ssl_support = legacy_ssl_support
+ super().__init__()
+
+ def _make_sslcontext(self):
+ return make_ssl_context(
+ verify=self.verify,
+ legacy_support=self.legacy_ssl_support,
+ use_certifi=not self.prefer_system_certs,
+ **self._client_cert,
+ )
+
+ def _merge_headers(self, request_headers):
+ return HTTPHeaderDict(self.headers, request_headers)
+
+ def _check_url_scheme(self, request: Request):
+ scheme = urllib.parse.urlparse(request.url).scheme.lower()
+ if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES:
+ raise UnsupportedRequest(f'Unsupported url scheme: "{scheme}"')
+ return scheme # for further processing
+
+ def _check_proxies(self, proxies):
+ for proxy_key, proxy_url in proxies.items():
+ if proxy_url is None:
+ continue
+ if proxy_key == 'no':
+ if self._SUPPORTED_FEATURES is not None and Features.NO_PROXY not in self._SUPPORTED_FEATURES:
+ raise UnsupportedRequest('"no" proxy is not supported')
+ continue
+ if (
+ proxy_key == 'all'
+ and self._SUPPORTED_FEATURES is not None
+ and Features.ALL_PROXY not in self._SUPPORTED_FEATURES
+ ):
+ raise UnsupportedRequest('"all" proxy is not supported')
+
+ # Unlikely this handler will use this proxy, so ignore.
+ # This allows a proxy to be set for a protocol supported by one handler,
+ # even when that protocol (and thus the proxy) is not supported by another handler.
+ if self._SUPPORTED_URL_SCHEMES is not None and proxy_key not in (*self._SUPPORTED_URL_SCHEMES, 'all'):
+ continue
+
+ if self._SUPPORTED_PROXY_SCHEMES is None:
+ # Skip proxy scheme checks
+ continue
+
+ try:
+ if urllib.request._parse_proxy(proxy_url)[0] is None:
+ # Scheme-less proxies are not supported
+ raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme')
+ except ValueError as e:
+ # parse_proxy may raise on some invalid proxy urls such as "/a/b/c"
+ raise UnsupportedRequest(f'Invalid proxy url "{proxy_url}": {e}')
+
+ scheme = urllib.parse.urlparse(proxy_url).scheme.lower()
+ if scheme not in self._SUPPORTED_PROXY_SCHEMES:
+ raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
+
+ def _check_extensions(self, extensions):
+ """Check extensions for unsupported extensions. Subclasses should extend this."""
+ assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
+ assert isinstance(extensions.get('timeout'), (float, int, NoneType))
+
+ def _validate(self, request):
+ self._check_url_scheme(request)
+ self._check_proxies(request.proxies or self.proxies)
+ extensions = request.extensions.copy()
+ self._check_extensions(extensions)
+ if extensions:
+ # TODO: add support for optional extensions
+ raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
+
+ @wrap_request_errors
+ def validate(self, request: Request):
+ if not isinstance(request, Request):
+ raise TypeError('Expected an instance of Request')
+ self._validate(request)
+
+ @wrap_request_errors
+ def send(self, request: Request) -> Response:
+ if not isinstance(request, Request):
+ raise TypeError('Expected an instance of Request')
+ return self._send(request)
+
+ @abc.abstractmethod
+ def _send(self, request: Request):
+ """Handle a request from start to finish. Redefine in subclasses."""
+ pass
+
+ def close(self):
+ pass
+
+ @classproperty
+ def RH_NAME(cls):
+ return cls.__name__[:-2]
+
+ @classproperty
+ def RH_KEY(cls):
+ assert cls.__name__.endswith('RH'), 'RequestHandler class names must end with "RH"'
+ return cls.__name__[:-2]
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.close()
+
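+ # A minimal sketch (not part of upstream; names are hypothetical): a concrete
+ # handler only needs to implement _send() and declare what it supports. Note
+ # that RH_KEY above requires the class name to end in "RH".
+ #
+ #   class ExampleRH(RequestHandler):
+ #       _SUPPORTED_URL_SCHEMES = ('http', 'https')
+ #       _SUPPORTED_PROXY_SCHEMES = ('http',)
+ #       _SUPPORTED_FEATURES = ()
+ #
+ #       def _send(self, request):
+ #           raw = some_backend.fetch(  # hypothetical backend
+ #               request.url, headers=self._merge_headers(request.headers))
+ #           return Response(fp=raw, url=request.url,
+ #                           headers=raw.headers, status=raw.status)
+ #
+ # validate() rejects unsupported schemes, proxies and extensions before
+ # send() ever reaches _send().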
+
+class Request:
+ """
+ Represents a request to be made.
+ Partially backwards-compatible with urllib.request.Request.
+
+ @param url: url to send. Will be sanitized.
+ @param data: payload data to send. Must be bytes, iterable of bytes, a file-like object or None
+ @param headers: headers to send.
+ @param proxies: proxy dict mapping of proto:proxy to use for the request and any redirects.
+ @param query: URL query parameters to update the url with.
+ @param method: HTTP method to use. If not specified, POST is used when payload data is present, else GET
+ @param extensions: Dictionary of Request extensions to add, as supported by handlers.
+ """
+
+ def __init__(
+ self,
+ url: str,
+ data: RequestData = None,
+ headers: typing.Mapping = None,
+ proxies: dict = None,
+ query: dict = None,
+ method: str = None,
+ extensions: dict = None
+ ):
+
+ self._headers = HTTPHeaderDict()
+ self._data = None
+
+ if query:
+ url = update_url_query(url, query)
+
+ self.url = url
+ self.method = method
+ if headers:
+ self.headers = headers
+ self.data = data # note: must be done after setting headers
+ self.proxies = proxies or {}
+ self.extensions = extensions or {}
+
+ @property
+ def url(self):
+ return self._url
+
+ @url.setter
+ def url(self, url):
+ if not isinstance(url, str):
+ raise TypeError('url must be a string')
+ elif url.startswith('//'):
+ url = 'http:' + url
+ self._url = normalize_url(url)
+
+ @property
+ def method(self):
+ return self._method or ('POST' if self.data is not None else 'GET')
+
+ @method.setter
+ def method(self, method):
+ if method is None:
+ self._method = None
+ elif isinstance(method, str):
+ self._method = method.upper()
+ else:
+ raise TypeError('method must be a string')
+
+ @property
+ def data(self):
+ return self._data
+
+ @data.setter
+ def data(self, data: RequestData):
+ # Try to catch some common mistakes
+ if data is not None and (
+ not isinstance(data, (bytes, io.IOBase, Iterable)) or isinstance(data, (str, Mapping))
+ ):
+ raise TypeError('data must be bytes, iterable of bytes, or a file-like object')
+
+ if data == self._data and self._data is None:
+ self.headers.pop('Content-Length', None)
+
+ # https://docs.python.org/3/library/urllib.request.html#urllib.request.Request.data
+ if data != self._data:
+ if self._data is not None:
+ self.headers.pop('Content-Length', None)
+ self._data = data
+
+ if self._data is None:
+ self.headers.pop('Content-Type', None)
+
+ if 'Content-Type' not in self.headers and self._data is not None:
+ self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
+
+ @property
+ def headers(self) -> HTTPHeaderDict:
+ return self._headers
+
+ @headers.setter
+ def headers(self, new_headers: Mapping):
+ """Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one."""
+ if isinstance(new_headers, HTTPHeaderDict):
+ self._headers = new_headers
+ elif isinstance(new_headers, Mapping):
+ self._headers = HTTPHeaderDict(new_headers)
+ else:
+ raise TypeError('headers must be a mapping')
+
+ def update(self, url=None, data=None, headers=None, query=None):
+ self.data = data if data is not None else self.data
+ self.headers.update(headers or {})
+ self.url = update_url_query(url or self.url, query or {})
+
+ def copy(self):
+ return self.__class__(
+ url=self.url,
+ headers=copy.deepcopy(self.headers),
+ proxies=copy.deepcopy(self.proxies),
+ data=self._data,
+ extensions=copy.copy(self.extensions),
+ method=self._method,
+ )
+
+
+HEADRequest = functools.partial(Request, method='HEAD')
+PUTRequest = functools.partial(Request, method='PUT')
+
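+ # Illustrative usage (not part of upstream): with the defaults above,
+ #
+ #   Request('https://example.com').method             # 'GET'
+ #   Request('https://example.com', data=b'x').method  # 'POST' (inferred)
+ #   HEADRequest('https://example.com').method         # 'HEAD'
+ #
+ # and query parameters can be merged into the URL at construction time:
+ #
+ #   Request('https://example.com/api', query={'page': '2'})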
+
+class Response(io.IOBase):
+ """
+ Base class for HTTP response adapters.
+
+ By default, it provides a basic wrapper for a file-like response object.
+
+ Interface partially backwards-compatible with addinfourl and http.client.HTTPResponse.
+
+ @param fp: Original, file-like, response.
+ @param url: URL that this response corresponds to.
+ @param headers: response headers.
+ @param status: Response HTTP status code. Default is 200 OK.
+ @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided.
+ """
+
+ def __init__(
+ self,
+ fp: typing.IO,
+ url: str,
+ headers: Mapping[str, str],
+ status: int = 200,
+ reason: str = None):
+
+ self.fp = fp
+ self.headers = Message()
+ for name, value in headers.items():
+ self.headers.add_header(name, value)
+ self.status = status
+ self.url = url
+ try:
+ self.reason = reason or HTTPStatus(status).phrase
+ except ValueError:
+ self.reason = None
+
+ def readable(self):
+ return self.fp.readable()
+
+ def read(self, amt: int = None) -> bytes:
+ # Expected errors raised here should be of type RequestError or subclasses.
+ # Subclasses should redefine this method with more precise error handling.
+ try:
+ return self.fp.read(amt)
+ except Exception as e:
+ raise TransportError(cause=e) from e
+
+ def close(self):
+ self.fp.close()
+ return super().close()
+
+ def get_header(self, name, default=None):
+ """Get header for name.
+ If there are multiple matching headers, return all seperated by comma."""
+ headers = self.headers.get_all(name)
+ if not headers:
+ return default
+ if name.title() == 'Set-Cookie':
+ # Special case, only get the first one
+ # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.3-4.1
+ return headers[0]
+ return ', '.join(headers)
+
+ # The following methods are for compatibility reasons and are deprecated
+ @property
+ def code(self):
+ deprecation_warning('Response.code is deprecated, use Response.status', stacklevel=2)
+ return self.status
+
+ def getcode(self):
+ deprecation_warning('Response.getcode() is deprecated, use Response.status', stacklevel=2)
+ return self.status
+
+ def geturl(self):
+ deprecation_warning('Response.geturl() is deprecated, use Response.url', stacklevel=2)
+ return self.url
+
+ def info(self):
+ deprecation_warning('Response.info() is deprecated, use Response.headers', stacklevel=2)
+ return self.headers
+
+ def getheader(self, name, default=None):
+ deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
+ return self.get_header(name, default)
+
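+ # Illustrative usage (not part of upstream): for a response carrying repeated
+ # headers, get_header() joins the values, except for Set-Cookie:
+ #
+ #   res.get_header('Vary')        # e.g. 'Accept-Encoding, Origin'
+ #   res.get_header('Set-Cookie')  # only the first Set-Cookie header
+ #   res.get_header('X-Missing', default='')  # falls back to the default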
+
+if typing.TYPE_CHECKING:
+ RequestData = bytes | Iterable[bytes] | typing.IO | None
+ Preference = typing.Callable[[RequestHandler, Request], int]
+
+_RH_PREFERENCES: set[Preference] = set()
diff --git a/yt_dlp/networking/exceptions.py b/yt_dlp/networking/exceptions.py
new file mode 100644
index 0000000..9037f18
--- /dev/null
+++ b/yt_dlp/networking/exceptions.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import typing
+
+from ..utils import YoutubeDLError
+
+if typing.TYPE_CHECKING:
+ from .common import RequestHandler, Response
+
+
+class RequestError(YoutubeDLError):
+ def __init__(
+ self,
+ msg: str | None = None,
+ cause: Exception | str | None = None,
+ handler: RequestHandler = None
+ ):
+ self.handler = handler
+ self.cause = cause
+ if not msg and cause:
+ msg = str(cause)
+ super().__init__(msg)
+
+
+class UnsupportedRequest(RequestError):
+ """raised when a handler cannot handle a request"""
+ pass
+
+
+class NoSupportingHandlers(RequestError):
+ """raised when no handlers can support a request for various reasons"""
+
+ def __init__(self, unsupported_errors: list[UnsupportedRequest], unexpected_errors: list[Exception]):
+ self.unsupported_errors = unsupported_errors or []
+ self.unexpected_errors = unexpected_errors or []
+
+ # Print a quick summary of the errors
+ err_handler_map = {}
+ for err in unsupported_errors:
+ err_handler_map.setdefault(err.msg, []).append(err.handler.RH_NAME)
+
+ reason_str = ', '.join([f'{msg} ({", ".join(handlers)})' for msg, handlers in err_handler_map.items()])
+ if unexpected_errors:
+ reason_str = ' + '.join(filter(None, [reason_str, f'{len(unexpected_errors)} unexpected error(s)']))
+
+ err_str = 'Unable to handle request'
+ if reason_str:
+ err_str += f': {reason_str}'
+
+ super().__init__(msg=err_str)
+
+
+class TransportError(RequestError):
+ """Network related errors"""
+
+
+class HTTPError(RequestError):
+ def __init__(self, response: Response, redirect_loop=False):
+ self.response = response
+ self.status = response.status
+ self.reason = response.reason
+ self.redirect_loop = redirect_loop
+ msg = f'HTTP Error {response.status}: {response.reason}'
+ if redirect_loop:
+ msg += ' (redirect loop detected)'
+
+ super().__init__(msg=msg)
+
+ def close(self):
+ self.response.close()
+
+ def __repr__(self):
+ return f'<HTTPError {self.status}: {self.reason}>'
+
+
+class IncompleteRead(TransportError):
+ def __init__(self, partial: int, expected: int | None = None, **kwargs):
+ self.partial = partial
+ self.expected = expected
+ msg = f'{partial} bytes read'
+ if expected is not None:
+ msg += f', {expected} more expected'
+
+ super().__init__(msg=msg, **kwargs)
+
+ def __repr__(self):
+ return f'<IncompleteRead: {self.msg}>'
+
+
+class SSLError(TransportError):
+ pass
+
+
+class CertificateVerifyError(SSLError):
+ """Raised when certificate validated has failed"""
+ pass
+
+
+class ProxyError(TransportError):
+ pass
+
+
+network_exceptions = (HTTPError, TransportError)
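+
+
+# Illustrative usage (not part of upstream): callers typically separate
+# protocol-level errors from transport-level ones, e.g. with a RequestDirector
+# instance `director`:
+#
+#   try:
+#       response = director.send(request)
+#   except HTTPError as e:       # server answered with an error status
+#       print(e.status, e.reason)
+#       e.close()
+#   except TransportError:       # network failure (incl. SSLError, ProxyError)
+#       ...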
diff --git a/yt_dlp/networking/websocket.py b/yt_dlp/networking/websocket.py
new file mode 100644
index 0000000..0e7e73c
--- /dev/null
+++ b/yt_dlp/networking/websocket.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import abc
+
+from .common import RequestHandler, Response
+
+
+class WebSocketResponse(Response):
+
+ def send(self, message: bytes | str):
+ """
+ Send a message to the server.
+
+ @param message: The message to send. A string (str) is sent as a text frame; bytes are sent as a binary frame.
+ """
+ raise NotImplementedError
+
+ def recv(self):
+ raise NotImplementedError
+
+
+class WebSocketRequestHandler(RequestHandler, abc.ABC):
+ pass
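+
+
+# Illustrative sketch (not part of upstream): a concrete websockets handler
+# returns a WebSocketResponse subclass whose send()/recv() wrap the underlying
+# connection, so a caller might do:
+#
+#   ws = director.send(Request('wss://example.com/socket'))  # `director`: a RequestDirector
+#   ws.send('hello')       # str -> text frame
+#   message = ws.recv()
+#   ws.close()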