diff options
Diffstat (limited to 'yt_dlp/networking/_helper.py')
-rw-r--r-- | yt_dlp/networking/_helper.py | 283 |
1 files changed, 283 insertions, 0 deletions
diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py new file mode 100644 index 0000000..d79dd79 --- /dev/null +++ b/yt_dlp/networking/_helper.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import contextlib +import functools +import socket +import ssl +import sys +import typing +import urllib.parse +import urllib.request + +from .exceptions import RequestError, UnsupportedRequest +from ..dependencies import certifi +from ..socks import ProxyType, sockssocket +from ..utils import format_field, traverse_obj + +if typing.TYPE_CHECKING: + from collections.abc import Iterable + + from ..utils.networking import HTTPHeaderDict + + +def ssl_load_certs(context: ssl.SSLContext, use_certifi=True): + if certifi and use_certifi: + context.load_verify_locations(cafile=certifi.where()) + else: + try: + context.load_default_certs() + # Work around the issue in load_default_certs when there are bad certificates. See: + # https://github.com/yt-dlp/yt-dlp/issues/1060, + # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312 + except ssl.SSLError: + # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151 + if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'): + for storename in ('CA', 'ROOT'): + ssl_load_windows_store_certs(context, storename) + context.set_default_verify_paths() + + +def ssl_load_windows_store_certs(ssl_context, storename): + # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py + try: + certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename) + if encoding == 'x509_asn' and ( + trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)] + except PermissionError: + return + for cert in certs: + with contextlib.suppress(ssl.SSLError): + ssl_context.load_verify_locations(cadata=cert) + + +def make_socks_proxy_opts(socks_proxy): + url_components = urllib.parse.urlparse(socks_proxy) + if url_components.scheme.lower() == 'socks5': + socks_type = ProxyType.SOCKS5 + rdns = False + elif url_components.scheme.lower() == 'socks5h': + socks_type = ProxyType.SOCKS5 + rdns = True + elif url_components.scheme.lower() == 'socks4': + socks_type = ProxyType.SOCKS4 + rdns = False + elif url_components.scheme.lower() == 'socks4a': + socks_type = ProxyType.SOCKS4A + rdns = True + else: + raise ValueError(f'Unknown SOCKS proxy version: {url_components.scheme.lower()}') + + def unquote_if_non_empty(s): + if not s: + return s + return urllib.parse.unquote_plus(s) + return { + 'proxytype': socks_type, + 'addr': url_components.hostname, + 'port': url_components.port or 1080, + 'rdns': rdns, + 'username': unquote_if_non_empty(url_components.username), + 'password': unquote_if_non_empty(url_components.password), + } + + +def select_proxy(url, proxies): + """Unified proxy selector for all backends""" + url_components = urllib.parse.urlparse(url) + if 'no' in proxies: + hostport = url_components.hostname + format_field(url_components.port, None, ':%s') + if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}): + return + elif urllib.request.proxy_bypass(hostport): # check system settings + return + + return traverse_obj(proxies, url_components.scheme or 'http', 'all') + + +def get_redirect_method(method, status): + """Unified redirect method handling""" + + # A 303 must either use GET or HEAD for subsequent request + # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4 + if status == 303 and method != 'HEAD': + method = 'GET' + # 301 and 302 redirects are commonly turned into a GET from a POST + # for subsequent requests by browsers, so we'll do the same. + # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2 + # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3 + if status in (301, 302) and method == 'POST': + method = 'GET' + return method + + +def make_ssl_context( + verify=True, + client_certificate=None, + client_certificate_key=None, + client_certificate_password=None, + legacy_support=False, + use_certifi=True, +): + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = verify + context.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE + + # Some servers may reject requests if ALPN extension is not sent. See: + # https://github.com/python/cpython/issues/85140 + # https://github.com/yt-dlp/yt-dlp/issues/3878 + with contextlib.suppress(NotImplementedError): + context.set_alpn_protocols(['http/1.1']) + if verify: + ssl_load_certs(context, use_certifi) + + if legacy_support: + context.options |= 4 # SSL_OP_LEGACY_SERVER_CONNECT + context.set_ciphers('DEFAULT') # compat + + elif ssl.OPENSSL_VERSION_INFO >= (1, 1, 1) and not ssl.OPENSSL_VERSION.startswith('LibreSSL'): + # Use the default SSL ciphers and minimum TLS version settings from Python 3.10 [1]. + # This is to ensure consistent behavior across Python versions and libraries, and help avoid fingerprinting + # in some situations [2][3]. + # Python 3.10 only supports OpenSSL 1.1.1+ [4]. Because this change is likely + # untested on older versions, we only apply this to OpenSSL 1.1.1+ to be safe. + # LibreSSL is excluded until further investigation due to cipher support issues [5][6]. + # 1. https://github.com/python/cpython/commit/e983252b516edb15d4338b0a47631b59ef1e2536 + # 2. https://github.com/yt-dlp/yt-dlp/issues/4627 + # 3. https://github.com/yt-dlp/yt-dlp/pull/5294 + # 4. https://peps.python.org/pep-0644/ + # 5. https://peps.python.org/pep-0644/#libressl-support + # 6. https://github.com/yt-dlp/yt-dlp/commit/5b9f253fa0aee996cf1ed30185d4b502e00609c4#commitcomment-89054368 + context.set_ciphers( + '@SECLEVEL=2:ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES:DHE+AES:!aNULL:!eNULL:!aDSS:!SHA1:!AESCCM') + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + if client_certificate: + try: + context.load_cert_chain( + client_certificate, keyfile=client_certificate_key, + password=client_certificate_password) + except ssl.SSLError: + raise RequestError('Unable to load client certificate') + + if getattr(context, 'post_handshake_auth', None) is not None: + context.post_handshake_auth = True + return context + + +class InstanceStoreMixin: + def __init__(self, **kwargs): + self.__instances = [] + super().__init__(**kwargs) # So that both MRO works + + @staticmethod + def _create_instance(**kwargs): + raise NotImplementedError + + def _get_instance(self, **kwargs): + for key, instance in self.__instances: + if key == kwargs: + return instance + + instance = self._create_instance(**kwargs) + self.__instances.append((kwargs, instance)) + return instance + + def _close_instance(self, instance): + if callable(getattr(instance, 'close', None)): + instance.close() + + def _clear_instances(self): + for _, instance in self.__instances: + self._close_instance(instance) + self.__instances.clear() + + +def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]): + if 'Accept-Encoding' not in headers: + headers['Accept-Encoding'] = ', '.join(supported_encodings) or 'identity' + + +def wrap_request_errors(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except UnsupportedRequest as e: + if e.handler is None: + e.handler = self + raise + return wrapper + + +def _socket_connect(ip_addr, timeout, source_address): + af, socktype, proto, canonname, sa = ip_addr + sock = socket.socket(af, socktype, proto) + try: + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + except OSError: + sock.close() + raise + + +def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address): + af, socktype, proto, canonname, sa = proxy_ip_addr + sock = sockssocket(af, socktype, proto) + try: + connect_proxy_args = proxy_args.copy() + connect_proxy_args.update({'addr': sa[0], 'port': sa[1]}) + sock.setproxy(**connect_proxy_args) + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721 + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(dest_addr) + return sock + except OSError: + sock.close() + raise + + +def create_connection( + address, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, + *, + _create_socket_func=_socket_connect +): + # Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6. + # This filters the addresses based on the given source_address. + # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810 + host, port = address + ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM) + if not ip_addrs: + raise OSError('getaddrinfo returns an empty list') + if source_address is not None: + af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6 + ip_addrs = [addr for addr in ip_addrs if addr[0] == af] + if not ip_addrs: + raise OSError( + f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. ' + f'Can\'t use "{source_address[0]}" as source address') + + err = None + for ip_addr in ip_addrs: + try: + sock = _create_socket_func(ip_addr, timeout, source_address) + # Explicitly break __traceback__ reference cycle + # https://bugs.python.org/issue36820 + err = None + return sock + except OSError as e: + err = e + + try: + raise err + finally: + # Explicitly break __traceback__ reference cycle + # https://bugs.python.org/issue36820 + err = None |