From 2415e66f889f38503b73e8ebc5f43ca342390e5c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 15 Apr 2024 18:49:24 +0200 Subject: Adding upstream version 2024.03.10. Signed-off-by: Daniel Baumann --- yt_dlp/networking/__init__.py | 30 +++ yt_dlp/networking/_helper.py | 283 ++++++++++++++++++++ yt_dlp/networking/_requests.py | 408 ++++++++++++++++++++++++++++ yt_dlp/networking/_urllib.py | 422 +++++++++++++++++++++++++++++ yt_dlp/networking/_websockets.py | 173 ++++++++++++ yt_dlp/networking/common.py | 565 +++++++++++++++++++++++++++++++++++++++ yt_dlp/networking/exceptions.py | 103 +++++++ yt_dlp/networking/websocket.py | 23 ++ 8 files changed, 2007 insertions(+) create mode 100644 yt_dlp/networking/__init__.py create mode 100644 yt_dlp/networking/_helper.py create mode 100644 yt_dlp/networking/_requests.py create mode 100644 yt_dlp/networking/_urllib.py create mode 100644 yt_dlp/networking/_websockets.py create mode 100644 yt_dlp/networking/common.py create mode 100644 yt_dlp/networking/exceptions.py create mode 100644 yt_dlp/networking/websocket.py (limited to 'yt_dlp/networking') diff --git a/yt_dlp/networking/__init__.py b/yt_dlp/networking/__init__.py new file mode 100644 index 0000000..acadc01 --- /dev/null +++ b/yt_dlp/networking/__init__.py @@ -0,0 +1,30 @@ +# flake8: noqa: F401 +import warnings + +from .common import ( + HEADRequest, + PUTRequest, + Request, + RequestDirector, + RequestHandler, + Response, +) + +# isort: split +# TODO: all request handlers should be safely imported +from . import _urllib +from ..utils import bug_reports_message + +try: + from . import _requests +except ImportError: + pass +except Exception as e: + warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message()) + +try: + from . import _websockets +except ImportError: + pass +except Exception as e: + warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message()) diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py new file mode 100644 index 0000000..d79dd79 --- /dev/null +++ b/yt_dlp/networking/_helper.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import contextlib +import functools +import socket +import ssl +import sys +import typing +import urllib.parse +import urllib.request + +from .exceptions import RequestError, UnsupportedRequest +from ..dependencies import certifi +from ..socks import ProxyType, sockssocket +from ..utils import format_field, traverse_obj + +if typing.TYPE_CHECKING: + from collections.abc import Iterable + + from ..utils.networking import HTTPHeaderDict + + +def ssl_load_certs(context: ssl.SSLContext, use_certifi=True): + if certifi and use_certifi: + context.load_verify_locations(cafile=certifi.where()) + else: + try: + context.load_default_certs() + # Work around the issue in load_default_certs when there are bad certificates. See: + # https://github.com/yt-dlp/yt-dlp/issues/1060, + # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312 + except ssl.SSLError: + # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151 + if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'): + for storename in ('CA', 'ROOT'): + ssl_load_windows_store_certs(context, storename) + context.set_default_verify_paths() + + +def ssl_load_windows_store_certs(ssl_context, storename): + # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py + try: + certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename) + if encoding == 'x509_asn' and ( + trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)] + except PermissionError: + return + for cert in certs: + with contextlib.suppress(ssl.SSLError): + ssl_context.load_verify_locations(cadata=cert) + + +def make_socks_proxy_opts(socks_proxy): + url_components = urllib.parse.urlparse(socks_proxy) + if url_components.scheme.lower() == 'socks5': + socks_type = ProxyType.SOCKS5 + rdns = False + elif url_components.scheme.lower() == 'socks5h': + socks_type = ProxyType.SOCKS5 + rdns = True + elif url_components.scheme.lower() == 'socks4': + socks_type = ProxyType.SOCKS4 + rdns = False + elif url_components.scheme.lower() == 'socks4a': + socks_type = ProxyType.SOCKS4A + rdns = True + else: + raise ValueError(f'Unknown SOCKS proxy version: {url_components.scheme.lower()}') + + def unquote_if_non_empty(s): + if not s: + return s + return urllib.parse.unquote_plus(s) + return { + 'proxytype': socks_type, + 'addr': url_components.hostname, + 'port': url_components.port or 1080, + 'rdns': rdns, + 'username': unquote_if_non_empty(url_components.username), + 'password': unquote_if_non_empty(url_components.password), + } + + +def select_proxy(url, proxies): + """Unified proxy selector for all backends""" + url_components = urllib.parse.urlparse(url) + if 'no' in proxies: + hostport = url_components.hostname + format_field(url_components.port, None, ':%s') + if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}): + return + elif urllib.request.proxy_bypass(hostport): # check system settings + return + + return traverse_obj(proxies, url_components.scheme or 'http', 'all') + + +def get_redirect_method(method, status): + """Unified redirect method handling""" + + # A 303 must either use GET or HEAD for subsequent request + # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4 + if status == 303 and method != 'HEAD': + method = 'GET' + # 301 and 302 redirects are commonly turned into a GET from a POST + # for subsequent requests by browsers, so we'll do the same. + # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2 + # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3 + if status in (301, 302) and method == 'POST': + method = 'GET' + return method + + +def make_ssl_context( + verify=True, + client_certificate=None, + client_certificate_key=None, + client_certificate_password=None, + legacy_support=False, + use_certifi=True, +): + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = verify + context.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE + + # Some servers may reject requests if ALPN extension is not sent. See: + # https://github.com/python/cpython/issues/85140 + # https://github.com/yt-dlp/yt-dlp/issues/3878 + with contextlib.suppress(NotImplementedError): + context.set_alpn_protocols(['http/1.1']) + if verify: + ssl_load_certs(context, use_certifi) + + if legacy_support: + context.options |= 4 # SSL_OP_LEGACY_SERVER_CONNECT + context.set_ciphers('DEFAULT') # compat + + elif ssl.OPENSSL_VERSION_INFO >= (1, 1, 1) and not ssl.OPENSSL_VERSION.startswith('LibreSSL'): + # Use the default SSL ciphers and minimum TLS version settings from Python 3.10 [1]. + # This is to ensure consistent behavior across Python versions and libraries, and help avoid fingerprinting + # in some situations [2][3]. + # Python 3.10 only supports OpenSSL 1.1.1+ [4]. Because this change is likely + # untested on older versions, we only apply this to OpenSSL 1.1.1+ to be safe. + # LibreSSL is excluded until further investigation due to cipher support issues [5][6]. + # 1. https://github.com/python/cpython/commit/e983252b516edb15d4338b0a47631b59ef1e2536 + # 2. https://github.com/yt-dlp/yt-dlp/issues/4627 + # 3. https://github.com/yt-dlp/yt-dlp/pull/5294 + # 4. https://peps.python.org/pep-0644/ + # 5. https://peps.python.org/pep-0644/#libressl-support + # 6. https://github.com/yt-dlp/yt-dlp/commit/5b9f253fa0aee996cf1ed30185d4b502e00609c4#commitcomment-89054368 + context.set_ciphers( + '@SECLEVEL=2:ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES:DHE+AES:!aNULL:!eNULL:!aDSS:!SHA1:!AESCCM') + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + if client_certificate: + try: + context.load_cert_chain( + client_certificate, keyfile=client_certificate_key, + password=client_certificate_password) + except ssl.SSLError: + raise RequestError('Unable to load client certificate') + + if getattr(context, 'post_handshake_auth', None) is not None: + context.post_handshake_auth = True + return context + + +class InstanceStoreMixin: + def __init__(self, **kwargs): + self.__instances = [] + super().__init__(**kwargs) # So that both MRO works + + @staticmethod + def _create_instance(**kwargs): + raise NotImplementedError + + def _get_instance(self, **kwargs): + for key, instance in self.__instances: + if key == kwargs: + return instance + + instance = self._create_instance(**kwargs) + self.__instances.append((kwargs, instance)) + return instance + + def _close_instance(self, instance): + if callable(getattr(instance, 'close', None)): + instance.close() + + def _clear_instances(self): + for _, instance in self.__instances: + self._close_instance(instance) + self.__instances.clear() + + +def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]): + if 'Accept-Encoding' not in headers: + headers['Accept-Encoding'] = ', '.join(supported_encodings) or 'identity' + + +def wrap_request_errors(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except UnsupportedRequest as e: + if e.handler is None: + e.handler = self + raise + return wrapper + + +def _socket_connect(ip_addr, timeout, source_address): + af, socktype, proto, canonname, sa = ip_addr + sock = socket.socket(af, socktype, proto) + try: + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + except OSError: + sock.close() + raise + + +def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address): + af, socktype, proto, canonname, sa = proxy_ip_addr + sock = sockssocket(af, socktype, proto) + try: + connect_proxy_args = proxy_args.copy() + connect_proxy_args.update({'addr': sa[0], 'port': sa[1]}) + sock.setproxy(**connect_proxy_args) + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721 + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(dest_addr) + return sock + except OSError: + sock.close() + raise + + +def create_connection( + address, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, + *, + _create_socket_func=_socket_connect +): + # Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6. + # This filters the addresses based on the given source_address. + # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810 + host, port = address + ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM) + if not ip_addrs: + raise OSError('getaddrinfo returns an empty list') + if source_address is not None: + af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6 + ip_addrs = [addr for addr in ip_addrs if addr[0] == af] + if not ip_addrs: + raise OSError( + f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. ' + f'Can\'t use "{source_address[0]}" as source address') + + err = None + for ip_addr in ip_addrs: + try: + sock = _create_socket_func(ip_addr, timeout, source_address) + # Explicitly break __traceback__ reference cycle + # https://bugs.python.org/issue36820 + err = None + return sock + except OSError as e: + err = e + + try: + raise err + finally: + # Explicitly break __traceback__ reference cycle + # https://bugs.python.org/issue36820 + err = None diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py new file mode 100644 index 0000000..6545028 --- /dev/null +++ b/yt_dlp/networking/_requests.py @@ -0,0 +1,408 @@ +import contextlib +import functools +import http.client +import logging +import re +import socket +import warnings + +from ..dependencies import brotli, requests, urllib3 +from ..utils import bug_reports_message, int_or_none, variadic +from ..utils.networking import normalize_url + +if requests is None: + raise ImportError('requests module is not installed') + +if urllib3 is None: + raise ImportError('urllib3 module is not installed') + +urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.')) + +if urllib3_version < (1, 26, 17): + raise ImportError('Only urllib3 >= 1.26.17 is supported') + +if requests.__build__ < 0x023100: + raise ImportError('Only requests >= 2.31.0 is supported') + +import requests.adapters +import requests.utils +import urllib3.connection +import urllib3.exceptions + +from ._helper import ( + InstanceStoreMixin, + add_accept_encoding_header, + create_connection, + create_socks_proxy_socket, + get_redirect_method, + make_socks_proxy_opts, + select_proxy, +) +from .common import ( + Features, + RequestHandler, + Response, + register_preference, + register_rh, +) +from .exceptions import ( + CertificateVerifyError, + HTTPError, + IncompleteRead, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from ..socks import ProxyError as SocksProxyError + +SUPPORTED_ENCODINGS = [ + 'gzip', 'deflate' +] + +if brotli is not None: + SUPPORTED_ENCODINGS.append('br') + +""" +Override urllib3's behavior to not convert lower-case percent-encoded characters +to upper-case during url normalization process. + +RFC3986 defines that the lower or upper case percent-encoded hexidecimal characters are equivalent +and normalizers should convert them to uppercase for consistency [1]. + +However, some sites may have an incorrect implementation where they provide +a percent-encoded url that is then compared case-sensitively.[2] + +While this is a very rare case, since urllib does not do this normalization step, it +is best to avoid it in requests too for compatability reasons. + +1: https://tools.ietf.org/html/rfc3986#section-2.1 +2: https://github.com/streamlink/streamlink/pull/4003 +""" + + +class Urllib3PercentREOverride: + def __init__(self, r: re.Pattern): + self.re = r + + # pass through all other attribute calls to the original re + def __getattr__(self, item): + return self.re.__getattribute__(item) + + def subn(self, repl, string, *args, **kwargs): + return string, self.re.subn(repl, string, *args, **kwargs)[1] + + +# urllib3 >= 1.25.8 uses subn: +# https://github.com/urllib3/urllib3/commit/a2697e7c6b275f05879b60f593c5854a816489f0 +import urllib3.util.url # noqa: E305 + +if hasattr(urllib3.util.url, 'PERCENT_RE'): + urllib3.util.url.PERCENT_RE = Urllib3PercentREOverride(urllib3.util.url.PERCENT_RE) +elif hasattr(urllib3.util.url, '_PERCENT_RE'): # urllib3 >= 2.0.0 + urllib3.util.url._PERCENT_RE = Urllib3PercentREOverride(urllib3.util.url._PERCENT_RE) +else: + warnings.warn('Failed to patch PERCENT_RE in urllib3 (does the attribute exist?)' + bug_reports_message()) + +""" +Workaround for issue in urllib.util.ssl_.py: ssl_wrap_context does not pass +server_hostname to SSLContext.wrap_socket if server_hostname is an IP, +however this is an issue because we set check_hostname to True in our SSLContext. + +Monkey-patching IS_SECURETRANSPORT forces ssl_wrap_context to pass server_hostname regardless. + +This has been fixed in urllib3 2.0+. +See: https://github.com/urllib3/urllib3/issues/517 +""" + +if urllib3_version < (2, 0, 0): + with contextlib.suppress(Exception): + urllib3.util.IS_SECURETRANSPORT = urllib3.util.ssl_.IS_SECURETRANSPORT = True + + +# Requests will not automatically handle no_proxy by default +# due to buggy no_proxy handling with proxy dict [1]. +# 1. https://github.com/psf/requests/issues/5000 +requests.adapters.select_proxy = select_proxy + + +class RequestsResponseAdapter(Response): + def __init__(self, res: requests.models.Response): + super().__init__( + fp=res.raw, headers=res.headers, url=res.url, + status=res.status_code, reason=res.reason) + + self._requests_response = res + + def read(self, amt: int = None): + try: + # Interact with urllib3 response directly. + return self.fp.read(amt, decode_content=True) + + # See urllib3.response.HTTPResponse.read() for exceptions raised on read + except urllib3.exceptions.SSLError as e: + raise SSLError(cause=e) from e + + except urllib3.exceptions.ProtocolError as e: + # IncompleteRead is always contained within ProtocolError + # See urllib3.response.HTTPResponse._error_catcher() + ir_err = next( + (err for err in (e.__context__, e.__cause__, *variadic(e.args)) + if isinstance(err, http.client.IncompleteRead)), None) + if ir_err is not None: + # `urllib3.exceptions.IncompleteRead` is subclass of `http.client.IncompleteRead` + # but uses an `int` for its `partial` property. + partial = ir_err.partial if isinstance(ir_err.partial, int) else len(ir_err.partial) + raise IncompleteRead(partial=partial, expected=ir_err.expected) from e + raise TransportError(cause=e) from e + + except urllib3.exceptions.HTTPError as e: + # catch-all for any other urllib3 response exceptions + raise TransportError(cause=e) from e + + +class RequestsHTTPAdapter(requests.adapters.HTTPAdapter): + def __init__(self, ssl_context=None, proxy_ssl_context=None, source_address=None, **kwargs): + self._pm_args = {} + if ssl_context: + self._pm_args['ssl_context'] = ssl_context + if source_address: + self._pm_args['source_address'] = (source_address, 0) + self._proxy_ssl_context = proxy_ssl_context or ssl_context + super().__init__(**kwargs) + + def init_poolmanager(self, *args, **kwargs): + return super().init_poolmanager(*args, **kwargs, **self._pm_args) + + def proxy_manager_for(self, proxy, **proxy_kwargs): + extra_kwargs = {} + if not proxy.lower().startswith('socks') and self._proxy_ssl_context: + extra_kwargs['proxy_ssl_context'] = self._proxy_ssl_context + return super().proxy_manager_for(proxy, **proxy_kwargs, **self._pm_args, **extra_kwargs) + + def cert_verify(*args, **kwargs): + # lean on SSLContext for cert verification + pass + + +class RequestsSession(requests.sessions.Session): + """ + Ensure unified redirect method handling with our urllib redirect handler. + """ + + def rebuild_method(self, prepared_request, response): + new_method = get_redirect_method(prepared_request.method, response.status_code) + + # HACK: requests removes headers/body on redirect unless code was a 307/308. + if new_method == prepared_request.method: + response._real_status_code = response.status_code + response.status_code = 308 + + prepared_request.method = new_method + + # Requests fails to resolve dot segments on absolute redirect locations + # See: https://github.com/yt-dlp/yt-dlp/issues/9020 + prepared_request.url = normalize_url(prepared_request.url) + + def rebuild_auth(self, prepared_request, response): + # HACK: undo status code change from rebuild_method, if applicable. + # rebuild_auth runs after requests would remove headers/body based on status code + if hasattr(response, '_real_status_code'): + response.status_code = response._real_status_code + del response._real_status_code + return super().rebuild_auth(prepared_request, response) + + +class Urllib3LoggingFilter(logging.Filter): + + def filter(self, record): + # Ignore HTTP request messages since HTTPConnection prints those + if record.msg == '%s://%s:%s "%s %s %s" %s %s': + return False + return True + + +class Urllib3LoggingHandler(logging.Handler): + """Redirect urllib3 logs to our logger""" + + def __init__(self, logger, *args, **kwargs): + super().__init__(*args, **kwargs) + self._logger = logger + + def emit(self, record): + try: + msg = self.format(record) + if record.levelno >= logging.ERROR: + self._logger.error(msg) + else: + self._logger.stdout(msg) + + except Exception: + self.handleError(record) + + +@register_rh +class RequestsRH(RequestHandler, InstanceStoreMixin): + + """Requests RequestHandler + https://github.com/psf/requests + """ + _SUPPORTED_URL_SCHEMES = ('http', 'https') + _SUPPORTED_ENCODINGS = tuple(SUPPORTED_ENCODINGS) + _SUPPORTED_PROXY_SCHEMES = ('http', 'https', 'socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY) + RH_NAME = 'requests' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Forward urllib3 debug messages to our logger + logger = logging.getLogger('urllib3') + self.__logging_handler = Urllib3LoggingHandler(logger=self._logger) + self.__logging_handler.setFormatter(logging.Formatter('requests: %(message)s')) + self.__logging_handler.addFilter(Urllib3LoggingFilter()) + logger.addHandler(self.__logging_handler) + # TODO: Use a logger filter to suppress pool reuse warning instead + logger.setLevel(logging.ERROR) + + if self.verbose: + # Setting this globally is not ideal, but is easier than hacking with urllib3. + # It could technically be problematic for scripts embedding yt-dlp. + # However, it is unlikely debug traffic is used in that context in a way this will cause problems. + urllib3.connection.HTTPConnection.debuglevel = 1 + logger.setLevel(logging.DEBUG) + # this is expected if we are using --no-check-certificate + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + def close(self): + self._clear_instances() + # Remove the logging handler that contains a reference to our logger + # See: https://github.com/yt-dlp/yt-dlp/issues/8922 + logging.getLogger('urllib3').removeHandler(self.__logging_handler) + + def _check_extensions(self, extensions): + super()._check_extensions(extensions) + extensions.pop('cookiejar', None) + extensions.pop('timeout', None) + + def _create_instance(self, cookiejar): + session = RequestsSession() + http_adapter = RequestsHTTPAdapter( + ssl_context=self._make_sslcontext(), + source_address=self.source_address, + max_retries=urllib3.util.retry.Retry(False), + ) + session.adapters.clear() + session.headers = requests.models.CaseInsensitiveDict({'Connection': 'keep-alive'}) + session.mount('https://', http_adapter) + session.mount('http://', http_adapter) + session.cookies = cookiejar + session.trust_env = False # no need, we already load proxies from env + return session + + def _send(self, request): + + headers = self._merge_headers(request.headers) + add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) + + max_redirects_exceeded = False + + session = self._get_instance( + cookiejar=request.extensions.get('cookiejar') or self.cookiejar) + + try: + requests_res = session.request( + method=request.method, + url=request.url, + data=request.data, + headers=headers, + timeout=float(request.extensions.get('timeout') or self.timeout), + proxies=request.proxies or self.proxies, + allow_redirects=True, + stream=True + ) + + except requests.exceptions.TooManyRedirects as e: + max_redirects_exceeded = True + requests_res = e.response + + except requests.exceptions.SSLError as e: + if 'CERTIFICATE_VERIFY_FAILED' in str(e): + raise CertificateVerifyError(cause=e) from e + raise SSLError(cause=e) from e + + except requests.exceptions.ProxyError as e: + raise ProxyError(cause=e) from e + + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: + raise TransportError(cause=e) from e + + except urllib3.exceptions.HTTPError as e: + # Catch any urllib3 exceptions that may leak through + raise TransportError(cause=e) from e + + except requests.exceptions.RequestException as e: + # Miscellaneous Requests exceptions. May not necessary be network related e.g. InvalidURL + raise RequestError(cause=e) from e + + res = RequestsResponseAdapter(requests_res) + + if not 200 <= res.status < 300: + raise HTTPError(res, redirect_loop=max_redirects_exceeded) + + return res + + +@register_preference(RequestsRH) +def requests_preference(rh, request): + return 100 + + +# Use our socks proxy implementation with requests to avoid an extra dependency. +class SocksHTTPConnection(urllib3.connection.HTTPConnection): + def __init__(self, _socks_options, *args, **kwargs): # must use _socks_options to pass PoolKey checks + self._proxy_args = _socks_options + super().__init__(*args, **kwargs) + + def _new_conn(self): + try: + return create_connection( + address=(self._proxy_args['addr'], self._proxy_args['port']), + timeout=self.timeout, + source_address=self.source_address, + _create_socket_func=functools.partial( + create_socks_proxy_socket, (self.host, self.port), self._proxy_args)) + except (socket.timeout, TimeoutError) as e: + raise urllib3.exceptions.ConnectTimeoutError( + self, f'Connection to {self.host} timed out. (connect timeout={self.timeout})') from e + except SocksProxyError as e: + raise urllib3.exceptions.ProxyError(str(e), e) from e + except OSError as e: + raise urllib3.exceptions.NewConnectionError( + self, f'Failed to establish a new connection: {e}') from e + + +class SocksHTTPSConnection(SocksHTTPConnection, urllib3.connection.HTTPSConnection): + pass + + +class SocksHTTPConnectionPool(urllib3.HTTPConnectionPool): + ConnectionCls = SocksHTTPConnection + + +class SocksHTTPSConnectionPool(urllib3.HTTPSConnectionPool): + ConnectionCls = SocksHTTPSConnection + + +class SocksProxyManager(urllib3.PoolManager): + + def __init__(self, socks_proxy, username=None, password=None, num_pools=10, headers=None, **connection_pool_kw): + connection_pool_kw['_socks_options'] = make_socks_proxy_opts(socks_proxy) + super().__init__(num_pools, headers, **connection_pool_kw) + self.pool_classes_by_scheme = { + 'http': SocksHTTPConnectionPool, + 'https': SocksHTTPSConnectionPool + } + + +requests.adapters.SOCKSProxyManager = SocksProxyManager diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py new file mode 100644 index 0000000..cb4dae3 --- /dev/null +++ b/yt_dlp/networking/_urllib.py @@ -0,0 +1,422 @@ +from __future__ import annotations + +import functools +import http.client +import io +import ssl +import urllib.error +import urllib.parse +import urllib.request +import urllib.response +import zlib +from urllib.request import ( + DataHandler, + FileHandler, + FTPHandler, + HTTPCookieProcessor, + HTTPDefaultErrorHandler, + HTTPErrorProcessor, + UnknownHandler, +) + +from ._helper import ( + InstanceStoreMixin, + add_accept_encoding_header, + create_connection, + create_socks_proxy_socket, + get_redirect_method, + make_socks_proxy_opts, + select_proxy, +) +from .common import Features, RequestHandler, Response, register_rh +from .exceptions import ( + CertificateVerifyError, + HTTPError, + IncompleteRead, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from ..dependencies import brotli +from ..socks import ProxyError as SocksProxyError +from ..utils import update_url_query +from ..utils.networking import normalize_url + +SUPPORTED_ENCODINGS = ['gzip', 'deflate'] +CONTENT_DECODE_ERRORS = [zlib.error, OSError] + +if brotli: + SUPPORTED_ENCODINGS.append('br') + CONTENT_DECODE_ERRORS.append(brotli.error) + + +def _create_http_connection(http_class, source_address, *args, **kwargs): + hc = http_class(*args, **kwargs) + + if hasattr(hc, '_create_connection'): + hc._create_connection = create_connection + + if source_address is not None: + hc.source_address = (source_address, 0) + + return hc + + +class HTTPHandler(urllib.request.AbstractHTTPHandler): + """Handler for HTTP requests and responses. + + This class, when installed with an OpenerDirector, automatically adds + the standard headers to every HTTP request and handles gzipped, deflated and + brotli responses from web servers. + + Part of this code was copied from: + + http://techknack.net/python-urllib2-handlers/ + + Andrew Rowls, the author of that code, agreed to release it to the + public domain. + """ + + def __init__(self, context=None, source_address=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self._source_address = source_address + self._context = context + + @staticmethod + def _make_conn_class(base, req): + conn_class = base + socks_proxy = req.headers.pop('Ytdl-socks-proxy', None) + if socks_proxy: + conn_class = make_socks_conn_class(conn_class, socks_proxy) + return conn_class + + def http_open(self, req): + conn_class = self._make_conn_class(http.client.HTTPConnection, req) + return self.do_open(functools.partial( + _create_http_connection, conn_class, self._source_address), req) + + def https_open(self, req): + conn_class = self._make_conn_class(http.client.HTTPSConnection, req) + return self.do_open( + functools.partial( + _create_http_connection, conn_class, self._source_address), + req, context=self._context) + + @staticmethod + def deflate(data): + if not data: + return data + try: + return zlib.decompress(data, -zlib.MAX_WBITS) + except zlib.error: + return zlib.decompress(data) + + @staticmethod + def brotli(data): + if not data: + return data + return brotli.decompress(data) + + @staticmethod + def gz(data): + # There may be junk added the end of the file + # We ignore it by only ever decoding a single gzip payload + if not data: + return data + return zlib.decompress(data, wbits=zlib.MAX_WBITS | 16) + + def http_request(self, req): + # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not + # always respected by websites, some tend to give out URLs with non percent-encoded + # non-ASCII characters (see telemb.py, ard.py [#3412]) + # urllib chokes on URLs with non-ASCII characters (see http://bugs.python.org/issue3991) + # To work around aforementioned issue we will replace request's original URL with + # percent-encoded one + # Since redirects are also affected (e.g. http://www.southpark.de/alle-episoden/s18e09) + # the code of this workaround has been moved here from YoutubeDL.urlopen() + url = req.get_full_url() + url_escaped = normalize_url(url) + + # Substitute URL if any change after escaping + if url != url_escaped: + req = update_Request(req, url=url_escaped) + + return super().do_request_(req) + + def http_response(self, req, resp): + old_resp = resp + + # Content-Encoding header lists the encodings in order that they were applied [1]. + # To decompress, we simply do the reverse. + # [1]: https://datatracker.ietf.org/doc/html/rfc9110#name-content-encoding + decoded_response = None + for encoding in (e.strip() for e in reversed(resp.headers.get('Content-encoding', '').split(','))): + if encoding == 'gzip': + decoded_response = self.gz(decoded_response or resp.read()) + elif encoding == 'deflate': + decoded_response = self.deflate(decoded_response or resp.read()) + elif encoding == 'br' and brotli: + decoded_response = self.brotli(decoded_response or resp.read()) + + if decoded_response is not None: + resp = urllib.request.addinfourl(io.BytesIO(decoded_response), old_resp.headers, old_resp.url, old_resp.code) + resp.msg = old_resp.msg + # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see + # https://github.com/ytdl-org/youtube-dl/issues/6457). + if 300 <= resp.code < 400: + location = resp.headers.get('Location') + if location: + # As of RFC 2616 default charset is iso-8859-1 that is respected by Python 3 + location = location.encode('iso-8859-1').decode() + location_escaped = normalize_url(location) + if location != location_escaped: + del resp.headers['Location'] + resp.headers['Location'] = location_escaped + return resp + + https_request = http_request + https_response = http_response + + +def make_socks_conn_class(base_class, socks_proxy): + assert issubclass(base_class, ( + http.client.HTTPConnection, http.client.HTTPSConnection)) + + proxy_args = make_socks_proxy_opts(socks_proxy) + + class SocksConnection(base_class): + _create_connection = create_connection + + def connect(self): + self.sock = create_connection( + (proxy_args['addr'], proxy_args['port']), + timeout=self.timeout, + source_address=self.source_address, + _create_socket_func=functools.partial( + create_socks_proxy_socket, (self.host, self.port), proxy_args)) + if isinstance(self, http.client.HTTPSConnection): + self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host) + + return SocksConnection + + +class RedirectHandler(urllib.request.HTTPRedirectHandler): + """YoutubeDL redirect handler + + The code is based on HTTPRedirectHandler implementation from CPython [1]. + + This redirect handler fixes and improves the logic to better align with RFC7261 + and what browsers tend to do [2][3] + + 1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py + 2. https://datatracker.ietf.org/doc/html/rfc7231 + 3. https://github.com/python/cpython/issues/91306 + """ + + http_error_301 = http_error_303 = http_error_307 = http_error_308 = urllib.request.HTTPRedirectHandler.http_error_302 + + def redirect_request(self, req, fp, code, msg, headers, newurl): + if code not in (301, 302, 303, 307, 308): + raise urllib.error.HTTPError(req.full_url, code, msg, headers, fp) + + new_data = req.data + + # Technically the Cookie header should be in unredirected_hdrs, + # however in practice some may set it in normal headers anyway. + # We will remove it here to prevent any leaks. + remove_headers = ['Cookie'] + + new_method = get_redirect_method(req.get_method(), code) + # only remove payload if method changed (e.g. POST to GET) + if new_method != req.get_method(): + new_data = None + remove_headers.extend(['Content-Length', 'Content-Type']) + + new_headers = {k: v for k, v in req.headers.items() if k.title() not in remove_headers} + + return urllib.request.Request( + newurl, headers=new_headers, origin_req_host=req.origin_req_host, + unverifiable=True, method=new_method, data=new_data) + + +class ProxyHandler(urllib.request.BaseHandler): + handler_order = 100 + + def __init__(self, proxies=None): + self.proxies = proxies + # Set default handlers + for type in ('http', 'https', 'ftp'): + setattr(self, '%s_open' % type, lambda r, meth=self.proxy_open: meth(r)) + + def proxy_open(self, req): + proxy = select_proxy(req.get_full_url(), self.proxies) + if proxy is None: + return + if urllib.parse.urlparse(proxy).scheme.lower() in ('socks4', 'socks4a', 'socks5', 'socks5h'): + req.add_header('Ytdl-socks-proxy', proxy) + # yt-dlp's http/https handlers do wrapping the socket with socks + return None + return urllib.request.ProxyHandler.proxy_open( + self, req, proxy, None) + + +class PUTRequest(urllib.request.Request): + def get_method(self): + return 'PUT' + + +class HEADRequest(urllib.request.Request): + def get_method(self): + return 'HEAD' + + +def update_Request(req, url=None, data=None, headers=None, query=None): + req_headers = req.headers.copy() + req_headers.update(headers or {}) + req_data = data if data is not None else req.data + req_url = update_url_query(url or req.get_full_url(), query) + req_get_method = req.get_method() + if req_get_method == 'HEAD': + req_type = HEADRequest + elif req_get_method == 'PUT': + req_type = PUTRequest + else: + req_type = urllib.request.Request + new_req = req_type( + req_url, data=req_data, headers=req_headers, + origin_req_host=req.origin_req_host, unverifiable=req.unverifiable) + if hasattr(req, 'timeout'): + new_req.timeout = req.timeout + return new_req + + +class UrllibResponseAdapter(Response): + """ + HTTP Response adapter class for urllib addinfourl and http.client.HTTPResponse + """ + + def __init__(self, res: http.client.HTTPResponse | urllib.response.addinfourl): + # addinfourl: In Python 3.9+, .status was introduced and .getcode() was deprecated [1] + # HTTPResponse: .getcode() was deprecated, .status always existed [2] + # 1. https://docs.python.org/3/library/urllib.request.html#urllib.response.addinfourl.getcode + # 2. https://docs.python.org/3.10/library/http.client.html#http.client.HTTPResponse.status + super().__init__( + fp=res, headers=res.headers, url=res.url, + status=getattr(res, 'status', None) or res.getcode(), reason=getattr(res, 'reason', None)) + + def read(self, amt=None): + try: + return self.fp.read(amt) + except Exception as e: + handle_response_read_exceptions(e) + raise e + + +def handle_sslerror(e: ssl.SSLError): + if not isinstance(e, ssl.SSLError): + return + if isinstance(e, ssl.SSLCertVerificationError): + raise CertificateVerifyError(cause=e) from e + raise SSLError(cause=e) from e + + +def handle_response_read_exceptions(e): + if isinstance(e, http.client.IncompleteRead): + raise IncompleteRead(partial=len(e.partial), cause=e, expected=e.expected) from e + elif isinstance(e, ssl.SSLError): + handle_sslerror(e) + elif isinstance(e, (OSError, EOFError, http.client.HTTPException, *CONTENT_DECODE_ERRORS)): + # OSErrors raised here should mostly be network related + raise TransportError(cause=e) from e + + +@register_rh +class UrllibRH(RequestHandler, InstanceStoreMixin): + _SUPPORTED_URL_SCHEMES = ('http', 'https', 'data', 'ftp') + _SUPPORTED_PROXY_SCHEMES = ('http', 'socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_FEATURES = (Features.NO_PROXY, Features.ALL_PROXY) + RH_NAME = 'urllib' + + def __init__(self, *, enable_file_urls: bool = False, **kwargs): + super().__init__(**kwargs) + self.enable_file_urls = enable_file_urls + if self.enable_file_urls: + self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file') + + def _check_extensions(self, extensions): + super()._check_extensions(extensions) + extensions.pop('cookiejar', None) + extensions.pop('timeout', None) + + def _create_instance(self, proxies, cookiejar): + opener = urllib.request.OpenerDirector() + handlers = [ + ProxyHandler(proxies), + HTTPHandler( + debuglevel=int(bool(self.verbose)), + context=self._make_sslcontext(), + source_address=self.source_address), + HTTPCookieProcessor(cookiejar), + DataHandler(), + UnknownHandler(), + HTTPDefaultErrorHandler(), + FTPHandler(), + HTTPErrorProcessor(), + RedirectHandler(), + ] + + if self.enable_file_urls: + handlers.append(FileHandler()) + + for handler in handlers: + opener.add_handler(handler) + + # Delete the default user-agent header, which would otherwise apply in + # cases where our custom HTTP handler doesn't come into play + # (See https://github.com/ytdl-org/youtube-dl/issues/1309 for details) + opener.addheaders = [] + return opener + + def _send(self, request): + headers = self._merge_headers(request.headers) + add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) + urllib_req = urllib.request.Request( + url=request.url, + data=request.data, + headers=dict(headers), + method=request.method + ) + + opener = self._get_instance( + proxies=request.proxies or self.proxies, + cookiejar=request.extensions.get('cookiejar') or self.cookiejar + ) + try: + res = opener.open(urllib_req, timeout=float(request.extensions.get('timeout') or self.timeout)) + except urllib.error.HTTPError as e: + if isinstance(e.fp, (http.client.HTTPResponse, urllib.response.addinfourl)): + # Prevent file object from being closed when urllib.error.HTTPError is destroyed. + e._closer.close_called = True + raise HTTPError(UrllibResponseAdapter(e.fp), redirect_loop='redirect error' in str(e)) from e + raise # unexpected + except urllib.error.URLError as e: + cause = e.reason # NOTE: cause may be a string + + # proxy errors + if 'tunnel connection failed' in str(cause).lower() or isinstance(cause, SocksProxyError): + raise ProxyError(cause=e) from e + + handle_response_read_exceptions(cause) + raise TransportError(cause=e) from e + except (http.client.InvalidURL, ValueError) as e: + # Validation errors + # http.client.HTTPConnection raises ValueError in some validation cases + # such as if request method contains illegal control characters [1] + # 1. https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256 + raise RequestError(cause=e) from e + except Exception as e: + handle_response_read_exceptions(e) + raise # unexpected + + return UrllibResponseAdapter(res) diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py new file mode 100644 index 0000000..1597932 --- /dev/null +++ b/yt_dlp/networking/_websockets.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import io +import logging +import ssl +import sys + +from ._helper import ( + create_connection, + create_socks_proxy_socket, + make_socks_proxy_opts, + select_proxy, +) +from .common import Features, Response, register_rh +from .exceptions import ( + CertificateVerifyError, + HTTPError, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from .websocket import WebSocketRequestHandler, WebSocketResponse +from ..compat import functools +from ..dependencies import websockets +from ..socks import ProxyError as SocksProxyError +from ..utils import int_or_none + +if not websockets: + raise ImportError('websockets is not installed') + +import websockets.version + +websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'))) +if websockets_version < (12, 0): + raise ImportError('Only websockets>=12.0 is supported') + +import websockets.sync.client +from websockets.uri import parse_uri + + +class WebsocketsResponseAdapter(WebSocketResponse): + + def __init__(self, wsw: websockets.sync.client.ClientConnection, url): + super().__init__( + fp=io.BytesIO(wsw.response.body or b''), + url=url, + headers=wsw.response.headers, + status=wsw.response.status_code, + reason=wsw.response.reason_phrase, + ) + self.wsw = wsw + + def close(self): + self.wsw.close() + super().close() + + def send(self, message): + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send + try: + return self.wsw.send(message) + except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: + raise TransportError(cause=e) from e + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except TypeError as e: + raise RequestError(cause=e) from e + + def recv(self): + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv + try: + return self.wsw.recv() + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: + raise TransportError(cause=e) from e + + +@register_rh +class WebsocketsRH(WebSocketRequestHandler): + """ + Websockets request handler + https://websockets.readthedocs.io + https://github.com/python-websockets/websockets + """ + _SUPPORTED_URL_SCHEMES = ('wss', 'ws') + _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) + RH_NAME = 'websockets' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__logging_handlers = {} + for name in ('websockets.client', 'websockets.server'): + logger = logging.getLogger(name) + handler = logging.StreamHandler(stream=sys.stdout) + handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) + self.__logging_handlers[name] = handler + logger.addHandler(handler) + if self.verbose: + logger.setLevel(logging.DEBUG) + + def _check_extensions(self, extensions): + super()._check_extensions(extensions) + extensions.pop('timeout', None) + extensions.pop('cookiejar', None) + + def close(self): + # Remove the logging handler that contains a reference to our logger + # See: https://github.com/yt-dlp/yt-dlp/issues/8922 + for name, handler in self.__logging_handlers.items(): + logging.getLogger(name).removeHandler(handler) + + def _send(self, request): + timeout = float(request.extensions.get('timeout') or self.timeout) + headers = self._merge_headers(request.headers) + if 'cookie' not in headers: + cookiejar = request.extensions.get('cookiejar') or self.cookiejar + cookie_header = cookiejar.get_cookie_header(request.url) + if cookie_header: + headers['cookie'] = cookie_header + + wsuri = parse_uri(request.url) + create_conn_kwargs = { + 'source_address': (self.source_address, 0) if self.source_address else None, + 'timeout': timeout + } + proxy = select_proxy(request.url, request.proxies or self.proxies or {}) + try: + if proxy: + socks_proxy_options = make_socks_proxy_opts(proxy) + sock = create_connection( + address=(socks_proxy_options['addr'], socks_proxy_options['port']), + _create_socket_func=functools.partial( + create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options), + **create_conn_kwargs + ) + else: + sock = create_connection( + address=(wsuri.host, wsuri.port), + **create_conn_kwargs + ) + conn = websockets.sync.client.connect( + sock=sock, + uri=request.url, + additional_headers=headers, + open_timeout=timeout, + user_agent_header=None, + ssl_context=self._make_sslcontext() if wsuri.secure else None, + close_timeout=0, # not ideal, but prevents yt-dlp hanging + ) + return WebsocketsResponseAdapter(conn, url=request.url) + + # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except websockets.exceptions.InvalidURI as e: + raise RequestError(cause=e) from e + except ssl.SSLCertVerificationError as e: + raise CertificateVerifyError(cause=e) from e + except ssl.SSLError as e: + raise SSLError(cause=e) from e + except websockets.exceptions.InvalidStatus as e: + raise HTTPError( + Response( + fp=io.BytesIO(e.response.body), + url=request.url, + headers=e.response.headers, + status=e.response.status_code, + reason=e.response.reason_phrase), + ) from e + except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: + raise TransportError(cause=e) from e diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py new file mode 100644 index 0000000..39442ba --- /dev/null +++ b/yt_dlp/networking/common.py @@ -0,0 +1,565 @@ +from __future__ import annotations + +import abc +import copy +import enum +import functools +import io +import typing +import urllib.parse +import urllib.request +import urllib.response +from collections.abc import Iterable, Mapping +from email.message import Message +from http import HTTPStatus + +from ._helper import make_ssl_context, wrap_request_errors +from .exceptions import ( + NoSupportingHandlers, + RequestError, + TransportError, + UnsupportedRequest, +) +from ..compat.types import NoneType +from ..cookies import YoutubeDLCookieJar +from ..utils import ( + bug_reports_message, + classproperty, + deprecation_warning, + error_to_str, + update_url_query, +) +from ..utils.networking import HTTPHeaderDict, normalize_url + + +def register_preference(*handlers: type[RequestHandler]): + assert all(issubclass(handler, RequestHandler) for handler in handlers) + + def outer(preference: Preference): + @functools.wraps(preference) + def inner(handler, *args, **kwargs): + if not handlers or isinstance(handler, handlers): + return preference(handler, *args, **kwargs) + return 0 + _RH_PREFERENCES.add(inner) + return inner + return outer + + +class RequestDirector: + """RequestDirector class + + Helper class that, when given a request, forward it to a RequestHandler that supports it. + + Preference functions in the form of func(handler, request) -> int + can be registered into the `preferences` set. These are used to sort handlers + in order of preference. + + @param logger: Logger instance. + @param verbose: Print debug request information to stdout. + """ + + def __init__(self, logger, verbose=False): + self.handlers: dict[str, RequestHandler] = {} + self.preferences: set[Preference] = set() + self.logger = logger # TODO(Grub4k): default logger + self.verbose = verbose + + def close(self): + for handler in self.handlers.values(): + handler.close() + self.handlers.clear() + + def add_handler(self, handler: RequestHandler): + """Add a handler. If a handler of the same RH_KEY exists, it will overwrite it""" + assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler' + self.handlers[handler.RH_KEY] = handler + + def _get_handlers(self, request: Request) -> list[RequestHandler]: + """Sorts handlers by preference, given a request""" + preferences = { + rh: sum(pref(rh, request) for pref in self.preferences) + for rh in self.handlers.values() + } + self._print_verbose('Handler preferences for this request: %s' % ', '.join( + f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items())) + return sorted(self.handlers.values(), key=preferences.get, reverse=True) + + def _print_verbose(self, msg): + if self.verbose: + self.logger.stdout(f'director: {msg}') + + def send(self, request: Request) -> Response: + """ + Passes a request onto a suitable RequestHandler + """ + if not self.handlers: + raise RequestError('No request handlers configured') + + assert isinstance(request, Request) + + unexpected_errors = [] + unsupported_errors = [] + for handler in self._get_handlers(request): + self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.') + try: + handler.validate(request) + except UnsupportedRequest as e: + self._print_verbose( + f'"{handler.RH_NAME}" cannot handle this request (reason: {error_to_str(e)})') + unsupported_errors.append(e) + continue + + self._print_verbose(f'Sending request via "{handler.RH_NAME}"') + try: + response = handler.send(request) + except RequestError: + raise + except Exception as e: + self.logger.error( + f'[{handler.RH_NAME}] Unexpected error: {error_to_str(e)}{bug_reports_message()}', + is_error=False) + unexpected_errors.append(e) + continue + + assert isinstance(response, Response) + return response + + raise NoSupportingHandlers(unsupported_errors, unexpected_errors) + + +_REQUEST_HANDLERS = {} + + +def register_rh(handler): + """Register a RequestHandler class""" + assert issubclass(handler, RequestHandler), f'{handler} must be a subclass of RequestHandler' + assert handler.RH_KEY not in _REQUEST_HANDLERS, f'RequestHandler {handler.RH_KEY} already registered' + _REQUEST_HANDLERS[handler.RH_KEY] = handler + return handler + + +class Features(enum.Enum): + ALL_PROXY = enum.auto() + NO_PROXY = enum.auto() + + +class RequestHandler(abc.ABC): + + """Request Handler class + + Request handlers are class that, given a Request, + process the request from start to finish and return a Response. + + Concrete subclasses need to redefine the _send(request) method, + which handles the underlying request logic and returns a Response. + + RH_NAME class variable may contain a display name for the RequestHandler. + By default, this is generated from the class name. + + The concrete request handler MUST have "RH" as the suffix in the class name. + + All exceptions raised by a RequestHandler should be an instance of RequestError. + Any other exception raised will be treated as a handler issue. + + If a Request is not supported by the handler, an UnsupportedRequest + should be raised with a reason. + + By default, some checks are done on the request in _validate() based on the following class variables: + - `_SUPPORTED_URL_SCHEMES`: a tuple of supported url schemes. + Any Request with an url scheme not in this list will raise an UnsupportedRequest. + + - `_SUPPORTED_PROXY_SCHEMES`: a tuple of support proxy url schemes. Any Request that contains + a proxy url with an url scheme not in this list will raise an UnsupportedRequest. + + - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum. + + The above may be set to None to disable the checks. + + Parameters: + @param logger: logger instance + @param headers: HTTP Headers to include when sending requests. + @param cookiejar: Cookiejar to use for requests. + @param timeout: Socket timeout to use when sending requests. + @param proxies: Proxies to use for sending requests. + @param source_address: Client-side IP address to bind to for requests. + @param verbose: Print debug request and traffic information to stdout. + @param prefer_system_certs: Whether to prefer system certificates over other means (e.g. certifi). + @param client_cert: SSL client certificate configuration. + dict with {client_certificate, client_certificate_key, client_certificate_password} + @param verify: Verify SSL certificates + @param legacy_ssl_support: Enable legacy SSL options such as legacy server connect and older cipher support. + + Some configuration options may be available for individual Requests too. In this case, + either the Request configuration option takes precedence or they are merged. + + Requests may have additional optional parameters defined as extensions. + RequestHandler subclasses may choose to support custom extensions. + + If an extension is supported, subclasses should extend _check_extensions(extensions) + to pop and validate the extension. + - Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised. + + The following extensions are defined for RequestHandler: + - `cookiejar`: Cookiejar to use for this request. + - `timeout`: socket timeout to use for this request. + To enable these, add extensions.pop('', None) to _check_extensions + + Apart from the url protocol, proxies dict may contain the following keys: + - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol. + - `no`: comma seperated list of hostnames (optionally with port) to not use a proxy for. + Note: a RequestHandler may not support these, as defined in `_SUPPORTED_FEATURES`. + + """ + + _SUPPORTED_URL_SCHEMES = () + _SUPPORTED_PROXY_SCHEMES = () + _SUPPORTED_FEATURES = () + + def __init__( + self, *, + logger, # TODO(Grub4k): default logger + headers: HTTPHeaderDict = None, + cookiejar: YoutubeDLCookieJar = None, + timeout: float | int | None = None, + proxies: dict = None, + source_address: str = None, + verbose: bool = False, + prefer_system_certs: bool = False, + client_cert: dict[str, str | None] = None, + verify: bool = True, + legacy_ssl_support: bool = False, + **_, + ): + + self._logger = logger + self.headers = headers or {} + self.cookiejar = cookiejar if cookiejar is not None else YoutubeDLCookieJar() + self.timeout = float(timeout or 20) + self.proxies = proxies or {} + self.source_address = source_address + self.verbose = verbose + self.prefer_system_certs = prefer_system_certs + self._client_cert = client_cert or {} + self.verify = verify + self.legacy_ssl_support = legacy_ssl_support + super().__init__() + + def _make_sslcontext(self): + return make_ssl_context( + verify=self.verify, + legacy_support=self.legacy_ssl_support, + use_certifi=not self.prefer_system_certs, + **self._client_cert, + ) + + def _merge_headers(self, request_headers): + return HTTPHeaderDict(self.headers, request_headers) + + def _check_url_scheme(self, request: Request): + scheme = urllib.parse.urlparse(request.url).scheme.lower() + if self._SUPPORTED_URL_SCHEMES is not None and scheme not in self._SUPPORTED_URL_SCHEMES: + raise UnsupportedRequest(f'Unsupported url scheme: "{scheme}"') + return scheme # for further processing + + def _check_proxies(self, proxies): + for proxy_key, proxy_url in proxies.items(): + if proxy_url is None: + continue + if proxy_key == 'no': + if self._SUPPORTED_FEATURES is not None and Features.NO_PROXY not in self._SUPPORTED_FEATURES: + raise UnsupportedRequest('"no" proxy is not supported') + continue + if ( + proxy_key == 'all' + and self._SUPPORTED_FEATURES is not None + and Features.ALL_PROXY not in self._SUPPORTED_FEATURES + ): + raise UnsupportedRequest('"all" proxy is not supported') + + # Unlikely this handler will use this proxy, so ignore. + # This is to allow a case where a proxy may be set for a protocol + # for one handler in which such protocol (and proxy) is not supported by another handler. + if self._SUPPORTED_URL_SCHEMES is not None and proxy_key not in (*self._SUPPORTED_URL_SCHEMES, 'all'): + continue + + if self._SUPPORTED_PROXY_SCHEMES is None: + # Skip proxy scheme checks + continue + + try: + if urllib.request._parse_proxy(proxy_url)[0] is None: + # Scheme-less proxies are not supported + raise UnsupportedRequest(f'Proxy "{proxy_url}" missing scheme') + except ValueError as e: + # parse_proxy may raise on some invalid proxy urls such as "/a/b/c" + raise UnsupportedRequest(f'Invalid proxy url "{proxy_url}": {e}') + + scheme = urllib.parse.urlparse(proxy_url).scheme.lower() + if scheme not in self._SUPPORTED_PROXY_SCHEMES: + raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"') + + def _check_extensions(self, extensions): + """Check extensions for unsupported extensions. Subclasses should extend this.""" + assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType)) + assert isinstance(extensions.get('timeout'), (float, int, NoneType)) + + def _validate(self, request): + self._check_url_scheme(request) + self._check_proxies(request.proxies or self.proxies) + extensions = request.extensions.copy() + self._check_extensions(extensions) + if extensions: + # TODO: add support for optional extensions + raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}') + + @wrap_request_errors + def validate(self, request: Request): + if not isinstance(request, Request): + raise TypeError('Expected an instance of Request') + self._validate(request) + + @wrap_request_errors + def send(self, request: Request) -> Response: + if not isinstance(request, Request): + raise TypeError('Expected an instance of Request') + return self._send(request) + + @abc.abstractmethod + def _send(self, request: Request): + """Handle a request from start to finish. Redefine in subclasses.""" + pass + + def close(self): + pass + + @classproperty + def RH_NAME(cls): + return cls.__name__[:-2] + + @classproperty + def RH_KEY(cls): + assert cls.__name__.endswith('RH'), 'RequestHandler class names must end with "RH"' + return cls.__name__[:-2] + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class Request: + """ + Represents a request to be made. + Partially backwards-compatible with urllib.request.Request. + + @param url: url to send. Will be sanitized. + @param data: payload data to send. Must be bytes, iterable of bytes, a file-like object or None + @param headers: headers to send. + @param proxies: proxy dict mapping of proto:proxy to use for the request and any redirects. + @param query: URL query parameters to update the url with. + @param method: HTTP method to use. If no method specified, will use POST if payload data is present else GET + @param extensions: Dictionary of Request extensions to add, as supported by handlers. + """ + + def __init__( + self, + url: str, + data: RequestData = None, + headers: typing.Mapping = None, + proxies: dict = None, + query: dict = None, + method: str = None, + extensions: dict = None + ): + + self._headers = HTTPHeaderDict() + self._data = None + + if query: + url = update_url_query(url, query) + + self.url = url + self.method = method + if headers: + self.headers = headers + self.data = data # note: must be done after setting headers + self.proxies = proxies or {} + self.extensions = extensions or {} + + @property + def url(self): + return self._url + + @url.setter + def url(self, url): + if not isinstance(url, str): + raise TypeError('url must be a string') + elif url.startswith('//'): + url = 'http:' + url + self._url = normalize_url(url) + + @property + def method(self): + return self._method or ('POST' if self.data is not None else 'GET') + + @method.setter + def method(self, method): + if method is None: + self._method = None + elif isinstance(method, str): + self._method = method.upper() + else: + raise TypeError('method must be a string') + + @property + def data(self): + return self._data + + @data.setter + def data(self, data: RequestData): + # Try catch some common mistakes + if data is not None and ( + not isinstance(data, (bytes, io.IOBase, Iterable)) or isinstance(data, (str, Mapping)) + ): + raise TypeError('data must be bytes, iterable of bytes, or a file-like object') + + if data == self._data and self._data is None: + self.headers.pop('Content-Length', None) + + # https://docs.python.org/3/library/urllib.request.html#urllib.request.Request.data + if data != self._data: + if self._data is not None: + self.headers.pop('Content-Length', None) + self._data = data + + if self._data is None: + self.headers.pop('Content-Type', None) + + if 'Content-Type' not in self.headers and self._data is not None: + self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + + @property + def headers(self) -> HTTPHeaderDict: + return self._headers + + @headers.setter + def headers(self, new_headers: Mapping): + """Replaces headers of the request. If not a HTTPHeaderDict, it will be converted to one.""" + if isinstance(new_headers, HTTPHeaderDict): + self._headers = new_headers + elif isinstance(new_headers, Mapping): + self._headers = HTTPHeaderDict(new_headers) + else: + raise TypeError('headers must be a mapping') + + def update(self, url=None, data=None, headers=None, query=None): + self.data = data if data is not None else self.data + self.headers.update(headers or {}) + self.url = update_url_query(url or self.url, query or {}) + + def copy(self): + return self.__class__( + url=self.url, + headers=copy.deepcopy(self.headers), + proxies=copy.deepcopy(self.proxies), + data=self._data, + extensions=copy.copy(self.extensions), + method=self._method, + ) + + +HEADRequest = functools.partial(Request, method='HEAD') +PUTRequest = functools.partial(Request, method='PUT') + + +class Response(io.IOBase): + """ + Base class for HTTP response adapters. + + By default, it provides a basic wrapper for a file-like response object. + + Interface partially backwards-compatible with addinfourl and http.client.HTTPResponse. + + @param fp: Original, file-like, response. + @param url: URL that this is a response of. + @param headers: response headers. + @param status: Response HTTP status code. Default is 200 OK. + @param reason: HTTP status reason. Will use built-in reasons based on status code if not provided. + """ + + def __init__( + self, + fp: typing.IO, + url: str, + headers: Mapping[str, str], + status: int = 200, + reason: str = None): + + self.fp = fp + self.headers = Message() + for name, value in headers.items(): + self.headers.add_header(name, value) + self.status = status + self.url = url + try: + self.reason = reason or HTTPStatus(status).phrase + except ValueError: + self.reason = None + + def readable(self): + return self.fp.readable() + + def read(self, amt: int = None) -> bytes: + # Expected errors raised here should be of type RequestError or subclasses. + # Subclasses should redefine this method with more precise error handling. + try: + return self.fp.read(amt) + except Exception as e: + raise TransportError(cause=e) from e + + def close(self): + self.fp.close() + return super().close() + + def get_header(self, name, default=None): + """Get header for name. + If there are multiple matching headers, return all seperated by comma.""" + headers = self.headers.get_all(name) + if not headers: + return default + if name.title() == 'Set-Cookie': + # Special case, only get the first one + # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.3-4.1 + return headers[0] + return ', '.join(headers) + + # The following methods are for compatability reasons and are deprecated + @property + def code(self): + deprecation_warning('Response.code is deprecated, use Response.status', stacklevel=2) + return self.status + + def getcode(self): + deprecation_warning('Response.getcode() is deprecated, use Response.status', stacklevel=2) + return self.status + + def geturl(self): + deprecation_warning('Response.geturl() is deprecated, use Response.url', stacklevel=2) + return self.url + + def info(self): + deprecation_warning('Response.info() is deprecated, use Response.headers', stacklevel=2) + return self.headers + + def getheader(self, name, default=None): + deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2) + return self.get_header(name, default) + + +if typing.TYPE_CHECKING: + RequestData = bytes | Iterable[bytes] | typing.IO | None + Preference = typing.Callable[[RequestHandler, Request], int] + +_RH_PREFERENCES: set[Preference] = set() diff --git a/yt_dlp/networking/exceptions.py b/yt_dlp/networking/exceptions.py new file mode 100644 index 0000000..9037f18 --- /dev/null +++ b/yt_dlp/networking/exceptions.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import typing + +from ..utils import YoutubeDLError + +if typing.TYPE_CHECKING: + from .common import RequestHandler, Response + + +class RequestError(YoutubeDLError): + def __init__( + self, + msg: str | None = None, + cause: Exception | str | None = None, + handler: RequestHandler = None + ): + self.handler = handler + self.cause = cause + if not msg and cause: + msg = str(cause) + super().__init__(msg) + + +class UnsupportedRequest(RequestError): + """raised when a handler cannot handle a request""" + pass + + +class NoSupportingHandlers(RequestError): + """raised when no handlers can support a request for various reasons""" + + def __init__(self, unsupported_errors: list[UnsupportedRequest], unexpected_errors: list[Exception]): + self.unsupported_errors = unsupported_errors or [] + self.unexpected_errors = unexpected_errors or [] + + # Print a quick summary of the errors + err_handler_map = {} + for err in unsupported_errors: + err_handler_map.setdefault(err.msg, []).append(err.handler.RH_NAME) + + reason_str = ', '.join([f'{msg} ({", ".join(handlers)})' for msg, handlers in err_handler_map.items()]) + if unexpected_errors: + reason_str = ' + '.join(filter(None, [reason_str, f'{len(unexpected_errors)} unexpected error(s)'])) + + err_str = 'Unable to handle request' + if reason_str: + err_str += f': {reason_str}' + + super().__init__(msg=err_str) + + +class TransportError(RequestError): + """Network related errors""" + + +class HTTPError(RequestError): + def __init__(self, response: Response, redirect_loop=False): + self.response = response + self.status = response.status + self.reason = response.reason + self.redirect_loop = redirect_loop + msg = f'HTTP Error {response.status}: {response.reason}' + if redirect_loop: + msg += ' (redirect loop detected)' + + super().__init__(msg=msg) + + def close(self): + self.response.close() + + def __repr__(self): + return f'' + + +class IncompleteRead(TransportError): + def __init__(self, partial: int, expected: int | None = None, **kwargs): + self.partial = partial + self.expected = expected + msg = f'{partial} bytes read' + if expected is not None: + msg += f', {expected} more expected' + + super().__init__(msg=msg, **kwargs) + + def __repr__(self): + return f'' + + +class SSLError(TransportError): + pass + + +class CertificateVerifyError(SSLError): + """Raised when certificate validated has failed""" + pass + + +class ProxyError(TransportError): + pass + + +network_exceptions = (HTTPError, TransportError) diff --git a/yt_dlp/networking/websocket.py b/yt_dlp/networking/websocket.py new file mode 100644 index 0000000..0e7e73c --- /dev/null +++ b/yt_dlp/networking/websocket.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import abc + +from .common import RequestHandler, Response + + +class WebSocketResponse(Response): + + def send(self, message: bytes | str): + """ + Send a message to the server. + + @param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame. + """ + raise NotImplementedError + + def recv(self): + raise NotImplementedError + + +class WebSocketRequestHandler(RequestHandler, abc.ABC): + pass -- cgit v1.2.3