diff options
Diffstat (limited to 'yt_dlp/networking/_websockets.py')
-rw-r--r-- | yt_dlp/networking/_websockets.py | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py new file mode 100644 index 0000000..1597932 --- /dev/null +++ b/yt_dlp/networking/_websockets.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import io +import logging +import ssl +import sys + +from ._helper import ( + create_connection, + create_socks_proxy_socket, + make_socks_proxy_opts, + select_proxy, +) +from .common import Features, Response, register_rh +from .exceptions import ( + CertificateVerifyError, + HTTPError, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from .websocket import WebSocketRequestHandler, WebSocketResponse +from ..compat import functools +from ..dependencies import websockets +from ..socks import ProxyError as SocksProxyError +from ..utils import int_or_none + +if not websockets: + raise ImportError('websockets is not installed') + +import websockets.version + +websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'))) +if websockets_version < (12, 0): + raise ImportError('Only websockets>=12.0 is supported') + +import websockets.sync.client +from websockets.uri import parse_uri + + +class WebsocketsResponseAdapter(WebSocketResponse): + + def __init__(self, wsw: websockets.sync.client.ClientConnection, url): + super().__init__( + fp=io.BytesIO(wsw.response.body or b''), + url=url, + headers=wsw.response.headers, + status=wsw.response.status_code, + reason=wsw.response.reason_phrase, + ) + self.wsw = wsw + + def close(self): + self.wsw.close() + super().close() + + def send(self, message): + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send + try: + return self.wsw.send(message) + except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: + raise TransportError(cause=e) from e + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except TypeError as e: + raise RequestError(cause=e) from e + + def recv(self): + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv + try: + return self.wsw.recv() + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: + raise TransportError(cause=e) from e + + +@register_rh +class WebsocketsRH(WebSocketRequestHandler): + """ + Websockets request handler + https://websockets.readthedocs.io + https://github.com/python-websockets/websockets + """ + _SUPPORTED_URL_SCHEMES = ('wss', 'ws') + _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) + RH_NAME = 'websockets' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__logging_handlers = {} + for name in ('websockets.client', 'websockets.server'): + logger = logging.getLogger(name) + handler = logging.StreamHandler(stream=sys.stdout) + handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) + self.__logging_handlers[name] = handler + logger.addHandler(handler) + if self.verbose: + logger.setLevel(logging.DEBUG) + + def _check_extensions(self, extensions): + super()._check_extensions(extensions) + extensions.pop('timeout', None) + extensions.pop('cookiejar', None) + + def close(self): + # Remove the logging handler that contains a reference to our logger + # See: https://github.com/yt-dlp/yt-dlp/issues/8922 + for name, handler in self.__logging_handlers.items(): + logging.getLogger(name).removeHandler(handler) + + def _send(self, request): + timeout = float(request.extensions.get('timeout') or self.timeout) + headers = self._merge_headers(request.headers) + if 'cookie' not in headers: + cookiejar = request.extensions.get('cookiejar') or self.cookiejar + cookie_header = cookiejar.get_cookie_header(request.url) + if cookie_header: + headers['cookie'] = cookie_header + + wsuri = parse_uri(request.url) + create_conn_kwargs = { + 'source_address': (self.source_address, 0) if self.source_address else None, + 'timeout': timeout + } + proxy = select_proxy(request.url, request.proxies or self.proxies or {}) + try: + if proxy: + socks_proxy_options = make_socks_proxy_opts(proxy) + sock = create_connection( + address=(socks_proxy_options['addr'], socks_proxy_options['port']), + _create_socket_func=functools.partial( + create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options), + **create_conn_kwargs + ) + else: + sock = create_connection( + address=(wsuri.host, wsuri.port), + **create_conn_kwargs + ) + conn = websockets.sync.client.connect( + sock=sock, + uri=request.url, + additional_headers=headers, + open_timeout=timeout, + user_agent_header=None, + ssl_context=self._make_sslcontext() if wsuri.secure else None, + close_timeout=0, # not ideal, but prevents yt-dlp hanging + ) + return WebsocketsResponseAdapter(conn, url=request.url) + + # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except websockets.exceptions.InvalidURI as e: + raise RequestError(cause=e) from e + except ssl.SSLCertVerificationError as e: + raise CertificateVerifyError(cause=e) from e + except ssl.SSLError as e: + raise SSLError(cause=e) from e + except websockets.exceptions.InvalidStatus as e: + raise HTTPError( + Response( + fp=io.BytesIO(e.response.body), + url=request.url, + headers=e.response.headers, + status=e.response.status_code, + reason=e.response.reason_phrase), + ) from e + except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: + raise TransportError(cause=e) from e |