summaryrefslogtreecommitdiffstats
path: root/yt_dlp/utils/_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'yt_dlp/utils/_utils.py')
-rw-r--r--yt_dlp/utils/_utils.py311
1 files changed, 202 insertions, 109 deletions
diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py
index 42803bb..b5e1e29 100644
--- a/yt_dlp/utils/_utils.py
+++ b/yt_dlp/utils/_utils.py
@@ -53,7 +53,7 @@ from ..compat import (
)
from ..dependencies import xattr
-__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module
+__name__ = __name__.rsplit('.', 1)[0] # noqa: A001: Pretend to be the parent module
# This is not clearly defined otherwise
compiled_regex_type = type(re.compile(''))
@@ -90,7 +90,7 @@ TIMEZONE_NAMES = {
'EST': -5, 'EDT': -4, # Eastern
'CST': -6, 'CDT': -5, # Central
'MST': -7, 'MDT': -6, # Mountain
- 'PST': -8, 'PDT': -7 # Pacific
+ 'PST': -8, 'PDT': -7, # Pacific
}
# needed for sanitizing filenames in restricted mode
@@ -215,7 +215,7 @@ def write_json_file(obj, fn):
def find_xpath_attr(node, xpath, key, val=None):
""" Find the xpath xpath[@key=val] """
assert re.match(r'^[a-zA-Z_-]+$', key)
- expr = xpath + ('[@%s]' % key if val is None else f"[@{key}='{val}']")
+ expr = xpath + (f'[@{key}]' if val is None else f"[@{key}='{val}']")
return node.find(expr)
# On python2.6 the xml.etree.ElementTree.Element methods don't support
@@ -230,7 +230,7 @@ def xpath_with_ns(path, ns_map):
replaced.append(c[0])
else:
ns, tag = c
- replaced.append('{%s}%s' % (ns_map[ns], tag))
+ replaced.append(f'{{{ns_map[ns]}}}{tag}')
return '/'.join(replaced)
@@ -251,7 +251,7 @@ def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
return default
elif fatal:
name = xpath if name is None else name
- raise ExtractorError('Could not find XML element %s' % name)
+ raise ExtractorError(f'Could not find XML element {name}')
else:
return None
return n
@@ -266,7 +266,7 @@ def xpath_text(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
return default
elif fatal:
name = xpath if name is None else name
- raise ExtractorError('Could not find XML element\'s text %s' % name)
+ raise ExtractorError(f'Could not find XML element\'s text {name}')
else:
return None
return n.text
@@ -279,7 +279,7 @@ def xpath_attr(node, xpath, key, name=None, fatal=False, default=NO_DEFAULT):
return default
elif fatal:
name = f'{xpath}[@{key}]' if name is None else name
- raise ExtractorError('Could not find XML attribute %s' % name)
+ raise ExtractorError(f'Could not find XML attribute {name}')
else:
return None
return n.attrib[key]
@@ -320,14 +320,14 @@ def get_element_html_by_attribute(attribute, value, html, **kargs):
def get_elements_by_class(class_name, html, **kargs):
"""Return the content of all tags with the specified class in the passed HTML document as a list"""
return get_elements_by_attribute(
- 'class', r'[^\'"]*(?<=[\'"\s])%s(?=[\'"\s])[^\'"]*' % re.escape(class_name),
+ 'class', rf'[^\'"]*(?<=[\'"\s]){re.escape(class_name)}(?=[\'"\s])[^\'"]*',
html, escape_value=False)
def get_elements_html_by_class(class_name, html):
"""Return the html of all tags with the specified class in the passed HTML document as a list"""
return get_elements_html_by_attribute(
- 'class', r'[^\'"]*(?<=[\'"\s])%s(?=[\'"\s])[^\'"]*' % re.escape(class_name),
+ 'class', rf'[^\'"]*(?<=[\'"\s]){re.escape(class_name)}(?=[\'"\s])[^\'"]*',
html, escape_value=False)
@@ -364,7 +364,7 @@ def get_elements_text_and_html_by_attribute(attribute, value, html, *, tag=r'[\w
yield (
unescapeHTML(re.sub(r'^(?P<q>["\'])(?P<content>.*)(?P=q)$', r'\g<content>', content, flags=re.DOTALL)),
- whole
+ whole,
)
@@ -407,7 +407,7 @@ class HTMLBreakOnClosingTagParser(html.parser.HTMLParser):
else:
raise compat_HTMLParseError(f'matching opening tag for closing {tag} tag not found')
if not self.tagstack:
- raise self.HTMLBreakOnClosingTagException()
+ raise self.HTMLBreakOnClosingTagException
# XXX: This should be far less strict
@@ -587,7 +587,7 @@ def sanitize_open(filename, open_mode):
# FIXME: An exclusive lock also locks the file from being read.
# Since windows locks are mandatory, don't lock the file on windows (for now).
# Ref: https://github.com/yt-dlp/yt-dlp/issues/3124
- raise LockingUnsupportedError()
+ raise LockingUnsupportedError
stream = locked_file(filename, open_mode, block=False).__enter__()
except OSError:
stream = open(filename, open_mode)
@@ -717,9 +717,9 @@ def extract_basic_auth(url):
return url, None
url = urllib.parse.urlunsplit(parts._replace(netloc=(
parts.hostname if parts.port is None
- else '%s:%d' % (parts.hostname, parts.port))))
+ else f'{parts.hostname}:{parts.port}')))
auth_payload = base64.b64encode(
- ('%s:%s' % (parts.username, parts.password or '')).encode())
+ ('{}:{}'.format(parts.username, parts.password or '')).encode())
return url, f'Basic {auth_payload.decode()}'
@@ -758,7 +758,7 @@ def _htmlentity_transform(entity_with_semicolon):
numstr = mobj.group(1)
if numstr.startswith('x'):
base = 16
- numstr = '0%s' % numstr
+ numstr = f'0{numstr}'
else:
base = 10
# See https://github.com/ytdl-org/youtube-dl/issues/7518
@@ -766,7 +766,7 @@ def _htmlentity_transform(entity_with_semicolon):
return chr(int(numstr, base))
# Unknown entity in name, return its literal representation
- return '&%s;' % entity
+ return f'&{entity};'
def unescapeHTML(s):
@@ -970,7 +970,7 @@ class ExtractorError(YoutubeDLError):
class UnsupportedError(ExtractorError):
def __init__(self, url):
super().__init__(
- 'Unsupported URL: %s' % url, expected=True)
+ f'Unsupported URL: {url}', expected=True)
self.url = url
@@ -1367,7 +1367,7 @@ class DateRange:
else:
self.end = dt.datetime.max.date()
if self.start > self.end:
- raise ValueError('Date range: "%s" , the start date must be before the end date' % self)
+ raise ValueError(f'Date range: "{self}" , the start date must be before the end date')
@classmethod
def day(cls, day):
@@ -1400,7 +1400,7 @@ def system_identifier():
with contextlib.suppress(OSError): # We may not have access to the executable
libc_ver = platform.libc_ver()
- return 'Python %s (%s %s %s) - %s (%s%s)' % (
+ return 'Python {} ({} {} {}) - {} ({}{})'.format(
platform.python_version(),
python_implementation,
platform.machine(),
@@ -1413,7 +1413,7 @@ def system_identifier():
@functools.cache
def get_windows_version():
- ''' Get Windows version. returns () if it's not running on Windows '''
+ """ Get Windows version. returns () if it's not running on Windows """
if compat_os_name == 'nt':
return version_tuple(platform.win32_ver()[1])
else:
@@ -1505,7 +1505,7 @@ if sys.platform == 'win32':
ctypes.wintypes.DWORD, # dwReserved
ctypes.wintypes.DWORD, # nNumberOfBytesToLockLow
ctypes.wintypes.DWORD, # nNumberOfBytesToLockHigh
- ctypes.POINTER(OVERLAPPED) # Overlapped
+ ctypes.POINTER(OVERLAPPED), # Overlapped
]
LockFileEx.restype = ctypes.wintypes.BOOL
UnlockFileEx = kernel32.UnlockFileEx
@@ -1514,7 +1514,7 @@ if sys.platform == 'win32':
ctypes.wintypes.DWORD, # dwReserved
ctypes.wintypes.DWORD, # nNumberOfBytesToLockLow
ctypes.wintypes.DWORD, # nNumberOfBytesToLockHigh
- ctypes.POINTER(OVERLAPPED) # Overlapped
+ ctypes.POINTER(OVERLAPPED), # Overlapped
]
UnlockFileEx.restype = ctypes.wintypes.BOOL
whole_low = 0xffffffff
@@ -1537,7 +1537,7 @@ if sys.platform == 'win32':
assert f._lock_file_overlapped_p
handle = msvcrt.get_osfhandle(f.fileno())
if not UnlockFileEx(handle, 0, whole_low, whole_high, f._lock_file_overlapped_p):
- raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
+ raise OSError(f'Unlocking file failed: {ctypes.FormatError()!r}')
else:
try:
@@ -1564,10 +1564,10 @@ else:
except ImportError:
def _lock_file(f, exclusive, block):
- raise LockingUnsupportedError()
+ raise LockingUnsupportedError
def _unlock_file(f):
- raise LockingUnsupportedError()
+ raise LockingUnsupportedError
class locked_file:
@@ -1926,7 +1926,7 @@ def remove_end(s, end):
def remove_quotes(s):
if s is None or len(s) < 2:
return s
- for quote in ('"', "'", ):
+ for quote in ('"', "'"):
if s[0] == quote and s[-1] == quote:
return s[1:-1]
return s
@@ -2085,26 +2085,27 @@ def parse_duration(s):
(days, 86400), (hours, 3600), (mins, 60), (secs, 1), (ms, 1)))
-def prepend_extension(filename, ext, expected_real_ext=None):
+def _change_extension(prepend, filename, ext, expected_real_ext=None):
name, real_ext = os.path.splitext(filename)
- return (
- f'{name}.{ext}{real_ext}'
- if not expected_real_ext or real_ext[1:] == expected_real_ext
- else f'{filename}.{ext}')
+ if not expected_real_ext or real_ext[1:] == expected_real_ext:
+ filename = name
+ if prepend and real_ext:
+ _UnsafeExtensionError.sanitize_extension(ext, prepend=True)
+ return f'{filename}.{ext}{real_ext}'
-def replace_extension(filename, ext, expected_real_ext=None):
- name, real_ext = os.path.splitext(filename)
- return '{}.{}'.format(
- name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
- ext)
+ return f'{filename}.{_UnsafeExtensionError.sanitize_extension(ext)}'
+
+
+prepend_extension = functools.partial(_change_extension, True)
+replace_extension = functools.partial(_change_extension, False)
def check_executable(exe, args=[]):
""" Checks if the given binary is installed somewhere in PATH, and returns its name.
args can be a list of arguments for a short output (like -version) """
try:
- Popen.run([exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ Popen.run([exe, *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except OSError:
return False
return exe
@@ -2115,7 +2116,7 @@ def _get_exe_version_output(exe, args):
# STDIN should be redirected too. On UNIX-like systems, ffmpeg triggers
# SIGTTOU if yt-dlp is run in the background.
# See https://github.com/ytdl-org/youtube-dl/issues/955#issuecomment-209789656
- stdout, _, ret = Popen.run([encodeArgument(exe)] + args, text=True,
+ stdout, _, ret = Popen.run([encodeArgument(exe), *args], text=True,
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if ret:
return None
@@ -2161,7 +2162,7 @@ class LazyList(collections.abc.Sequence):
"""Lazy immutable list from an iterable
Note that slices of a LazyList are lists and not LazyList"""
- class IndexError(IndexError):
+ class IndexError(IndexError): # noqa: A001
pass
def __init__(self, iterable, *, reverse=False, _cache=None):
@@ -2248,7 +2249,7 @@ class LazyList(collections.abc.Sequence):
class PagedList:
- class IndexError(IndexError):
+ class IndexError(IndexError): # noqa: A001
pass
def __len__(self):
@@ -2282,7 +2283,7 @@ class PagedList:
raise TypeError('indices must be non-negative integers')
entries = self.getslice(idx, idx + 1)
if not entries:
- raise self.IndexError()
+ raise self.IndexError
return entries[0]
def __bool__(self):
@@ -2443,7 +2444,7 @@ class PlaylistEntries:
except IndexError:
entry = self.MissingEntry
if not self.is_incomplete:
- raise self.IndexError()
+ raise self.IndexError
if entry is self.MissingEntry:
raise EntryNotInPlaylist(f'Entry {i + 1} cannot be found')
return entry
@@ -2452,7 +2453,7 @@ class PlaylistEntries:
try:
return type(self.ydl)._handle_extraction_exceptions(lambda _, i: self._entries[i])(self.ydl, i)
except (LazyList.IndexError, PagedList.IndexError):
- raise self.IndexError()
+ raise self.IndexError
return get_entry
def __getitem__(self, idx):
@@ -2488,7 +2489,7 @@ class PlaylistEntries:
def __len__(self):
return len(tuple(self[:]))
- class IndexError(IndexError):
+ class IndexError(IndexError): # noqa: A001
pass
@@ -2550,7 +2551,7 @@ def update_url(url, *, query_update=None, **kwargs):
assert 'query' not in kwargs, 'query_update and query cannot be specified at the same time'
kwargs['query'] = urllib.parse.urlencode({
**urllib.parse.parse_qs(url.query),
- **query_update
+ **query_update,
}, True)
return urllib.parse.urlunparse(url._replace(**kwargs))
@@ -2560,7 +2561,7 @@ def update_url_query(url, query):
def _multipart_encode_impl(data, boundary):
- content_type = 'multipart/form-data; boundary=%s' % boundary
+ content_type = f'multipart/form-data; boundary={boundary}'
out = b''
for k, v in data.items():
@@ -2582,7 +2583,7 @@ def _multipart_encode_impl(data, boundary):
def multipart_encode(data, boundary=None):
- '''
+ """
Encode a dict to RFC 7578-compliant form-data
data:
@@ -2593,7 +2594,7 @@ def multipart_encode(data, boundary=None):
a random boundary is generated.
Reference: https://tools.ietf.org/html/rfc7578
- '''
+ """
has_specified_boundary = boundary is not None
while True:
@@ -2688,7 +2689,7 @@ def parse_age_limit(s):
s = s.upper()
if s in US_RATINGS:
return US_RATINGS[s]
- m = re.match(r'^TV[_-]?(%s)$' % '|'.join(k[3:] for k in TV_PARENTAL_GUIDELINES), s)
+ m = re.match(r'^TV[_-]?({})$'.format('|'.join(k[3:] for k in TV_PARENTAL_GUIDELINES)), s)
if m:
return TV_PARENTAL_GUIDELINES['TV-' + m.group(1)]
return None
@@ -2736,7 +2737,7 @@ def js_to_json(code, vars={}, *, strict=False):
return v
elif v in ('undefined', 'void 0'):
return 'null'
- elif v.startswith('/*') or v.startswith('//') or v.startswith('!') or v == ',':
+ elif v.startswith(('/*', '//', '!')) or v == ',':
return ''
if v[0] in STRING_QUOTES:
@@ -3079,7 +3080,7 @@ def urlhandle_detect_ext(url_handle, default=NO_DEFAULT):
def encode_data_uri(data, mime_type):
- return 'data:%s;base64,%s' % (mime_type, base64.b64encode(data).decode('ascii'))
+ return 'data:{};base64,{}'.format(mime_type, base64.b64encode(data).decode('ascii'))
def age_restricted(content_limit, age_limit):
@@ -3144,18 +3145,18 @@ def render_table(header_row, data, delim=False, extra_gap=0, hide_empty=False):
def get_max_lens(table):
return [max(width(str(v)) for v in col) for col in zip(*table)]
- def filter_using_list(row, filterArray):
- return [col for take, col in itertools.zip_longest(filterArray, row, fillvalue=True) if take]
+ def filter_using_list(row, filter_array):
+ return [col for take, col in itertools.zip_longest(filter_array, row, fillvalue=True) if take]
max_lens = get_max_lens(data) if hide_empty else []
header_row = filter_using_list(header_row, max_lens)
data = [filter_using_list(row, max_lens) for row in data]
- table = [header_row] + data
+ table = [header_row, *data]
max_lens = get_max_lens(table)
extra_gap += 1
if delim:
- table = [header_row, [delim * (ml + extra_gap) for ml in max_lens]] + data
+ table = [header_row, [delim * (ml + extra_gap) for ml in max_lens], *data]
table[1][-1] = table[1][-1][:-extra_gap * len(delim)] # Remove extra_gap from end of delimiter
for row in table:
for pos, text in enumerate(map(str, row)):
@@ -3163,8 +3164,7 @@ def render_table(header_row, data, delim=False, extra_gap=0, hide_empty=False):
row[pos] = text.replace('\t', ' ' * (max_lens[pos] - width(text))) + ' ' * extra_gap
else:
row[pos] = text + ' ' * (max_lens[pos] - width(text) + extra_gap)
- ret = '\n'.join(''.join(row).rstrip() for row in table)
- return ret
+ return '\n'.join(''.join(row).rstrip() for row in table)
def _match_one(filter_part, dct, incomplete):
@@ -3191,12 +3191,12 @@ def _match_one(filter_part, dct, incomplete):
operator_rex = re.compile(r'''(?x)
(?P<key>[a-z_]+)
- \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
+ \s*(?P<negation>!\s*)?(?P<op>{})(?P<none_inclusive>\s*\?)?\s*
(?:
(?P<quote>["\'])(?P<quotedstrval>.+?)(?P=quote)|
(?P<strval>.+?)
)
- ''' % '|'.join(map(re.escape, COMPARISON_OPERATORS.keys())))
+ '''.format('|'.join(map(re.escape, COMPARISON_OPERATORS.keys()))))
m = operator_rex.fullmatch(filter_part.strip())
if m:
m = m.groupdict()
@@ -3207,7 +3207,7 @@ def _match_one(filter_part, dct, incomplete):
op = unnegated_op
comparison_value = m['quotedstrval'] or m['strval'] or m['intval']
if m['quote']:
- comparison_value = comparison_value.replace(r'\%s' % m['quote'], m['quote'])
+ comparison_value = comparison_value.replace(r'\{}'.format(m['quote']), m['quote'])
actual_value = dct.get(m['key'])
numeric_comparison = None
if isinstance(actual_value, (int, float)):
@@ -3224,7 +3224,7 @@ def _match_one(filter_part, dct, incomplete):
if numeric_comparison is None:
numeric_comparison = parse_duration(comparison_value)
if numeric_comparison is not None and m['op'] in STRING_OPERATORS:
- raise ValueError('Operator %s only supports string values!' % m['op'])
+ raise ValueError('Operator {} only supports string values!'.format(m['op']))
if actual_value is None:
return is_incomplete(m['key']) or m['none_inclusive']
return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison)
@@ -3234,8 +3234,8 @@ def _match_one(filter_part, dct, incomplete):
'!': lambda v: (v is False) if isinstance(v, bool) else (v is None),
}
operator_rex = re.compile(r'''(?x)
- (?P<op>%s)\s*(?P<key>[a-z_]+)
- ''' % '|'.join(map(re.escape, UNARY_OPERATORS.keys())))
+ (?P<op>{})\s*(?P<key>[a-z_]+)
+ '''.format('|'.join(map(re.escape, UNARY_OPERATORS.keys()))))
m = operator_rex.fullmatch(filter_part.strip())
if m:
op = UNARY_OPERATORS[m.group('op')]
@@ -3244,7 +3244,7 @@ def _match_one(filter_part, dct, incomplete):
return True
return op(actual_value)
- raise ValueError('Invalid filter part %r' % filter_part)
+ raise ValueError(f'Invalid filter part {filter_part!r}')
def match_str(filter_str, dct, incomplete=False):
@@ -3351,10 +3351,10 @@ def ass_subtitles_timecode(seconds):
def dfxp2srt(dfxp_data):
- '''
+ """
@param dfxp_data A bytes-like object containing DFXP data
@returns A unicode object containing converted SRT data
- '''
+ """
LEGACY_NAMESPACES = (
(b'http://www.w3.org/ns/ttml', [
b'http://www.w3.org/2004/11/ttaf1',
@@ -3372,7 +3372,7 @@ def dfxp2srt(dfxp_data):
'fontSize',
'fontStyle',
'fontWeight',
- 'textDecoration'
+ 'textDecoration',
]
_x = functools.partial(xpath_with_ns, ns_map={
@@ -3410,11 +3410,11 @@ def dfxp2srt(dfxp_data):
if self._applied_styles and self._applied_styles[-1].get(k) == v:
continue
if k == 'color':
- font += ' color="%s"' % v
+ font += f' color="{v}"'
elif k == 'fontSize':
- font += ' size="%s"' % v
+ font += f' size="{v}"'
elif k == 'fontFamily':
- font += ' face="%s"' % v
+ font += f' face="{v}"'
elif k == 'fontWeight' and v == 'bold':
self._out += '<b>'
unclosed_elements.append('b')
@@ -3438,7 +3438,7 @@ def dfxp2srt(dfxp_data):
if tag not in (_x('ttml:br'), 'br'):
unclosed_elements = self._unclosed_elements.pop()
for element in reversed(unclosed_elements):
- self._out += '</%s>' % element
+ self._out += f'</{element}>'
if unclosed_elements and self._applied_styles:
self._applied_styles.pop()
@@ -4349,7 +4349,7 @@ def bytes_to_long(s):
def ohdave_rsa_encrypt(data, exponent, modulus):
- '''
+ """
Implement OHDave's RSA algorithm. See http://www.ohdave.com/rsa/
Input:
@@ -4358,11 +4358,11 @@ def ohdave_rsa_encrypt(data, exponent, modulus):
Output: hex string of encrypted data
Limitation: supports one block encryption only
- '''
+ """
payload = int(binascii.hexlify(data[::-1]), 16)
encrypted = pow(payload, exponent, modulus)
- return '%x' % encrypted
+ return f'{encrypted:x}'
def pkcs1pad(data, length):
@@ -4377,7 +4377,7 @@ def pkcs1pad(data, length):
raise ValueError('Input data too long for PKCS#1 padding')
pseudo_random = [random.randint(0, 254) for _ in range(length - len(data) - 3)]
- return [0, 2] + pseudo_random + [0] + data
+ return [0, 2, *pseudo_random, 0, *data]
def _base_n_table(n, table):
@@ -4710,16 +4710,14 @@ def jwt_encode_hs256(payload_data, key, headers={}):
payload_b64 = base64.b64encode(json.dumps(payload_data).encode())
h = hmac.new(key.encode(), header_b64 + b'.' + payload_b64, hashlib.sha256)
signature_b64 = base64.b64encode(h.digest())
- token = header_b64 + b'.' + payload_b64 + b'.' + signature_b64
- return token
+ return header_b64 + b'.' + payload_b64 + b'.' + signature_b64
# can be extended in future to verify the signature and parse header and return the algorithm used if it's not HS256
def jwt_decode_hs256(jwt):
header_b64, payload_b64, signature_b64 = jwt.split('.')
# add trailing ='s that may have been stripped, superfluous ='s are ignored
- payload_data = json.loads(base64.urlsafe_b64decode(f'{payload_b64}==='))
- return payload_data
+ return json.loads(base64.urlsafe_b64decode(f'{payload_b64}==='))
WINDOWS_VT_MODE = False if compat_os_name == 'nt' else None
@@ -4797,7 +4795,7 @@ def scale_thumbnails_to_max_format_width(formats, thumbnails, url_width_re):
"""
_keys = ('width', 'height')
max_dimensions = max(
- (tuple(format.get(k) or 0 for k in _keys) for format in formats),
+ (tuple(fmt.get(k) or 0 for k in _keys) for fmt in formats),
default=(0, 0))
if not max_dimensions[0]:
return thumbnails
@@ -5040,6 +5038,101 @@ MEDIA_EXTENSIONS.audio += MEDIA_EXTENSIONS.common_audio
KNOWN_EXTENSIONS = (*MEDIA_EXTENSIONS.video, *MEDIA_EXTENSIONS.audio, *MEDIA_EXTENSIONS.manifests)
+class _UnsafeExtensionError(Exception):
+ """
+ Mitigation exception for uncommon/malicious file extensions
+ This should be caught in YoutubeDL.py alongside a warning
+
+ Ref: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4j
+ """
+ ALLOWED_EXTENSIONS = frozenset([
+ # internal
+ 'description',
+ 'json',
+ 'meta',
+ 'orig',
+ 'part',
+ 'temp',
+ 'uncut',
+ 'unknown_video',
+ 'ytdl',
+
+ # video
+ *MEDIA_EXTENSIONS.video,
+ 'avif',
+ 'ismv',
+ 'm2ts',
+ 'm4s',
+ 'mng',
+ 'mpeg',
+ 'qt',
+ 'swf',
+ 'ts',
+ 'vp9',
+ 'wvm',
+
+ # audio
+ *MEDIA_EXTENSIONS.audio,
+ 'isma',
+ 'mid',
+ 'mpga',
+ 'ra',
+
+ # image
+ *MEDIA_EXTENSIONS.thumbnails,
+ 'bmp',
+ 'gif',
+ 'heic',
+ 'ico',
+ 'jng',
+ 'jpeg',
+ 'jxl',
+ 'svg',
+ 'tif',
+ 'wbmp',
+
+ # subtitle
+ *MEDIA_EXTENSIONS.subtitles,
+ 'dfxp',
+ 'fs',
+ 'ismt',
+ 'sami',
+ 'scc',
+ 'ssa',
+ 'tt',
+ 'ttml',
+
+ # others
+ *MEDIA_EXTENSIONS.manifests,
+ *MEDIA_EXTENSIONS.storyboards,
+ 'desktop',
+ 'ism',
+ 'm3u',
+ 'sbv',
+ 'url',
+ 'webloc',
+ 'xml',
+ ])
+
+ def __init__(self, extension, /):
+ super().__init__(f'unsafe file extension: {extension!r}')
+ self.extension = extension
+
+ @classmethod
+ def sanitize_extension(cls, extension, /, *, prepend=False):
+ if '/' in extension or '\\' in extension:
+ raise cls(extension)
+
+ if not prepend:
+ _, _, last = extension.rpartition('.')
+ if last == 'bin':
+ extension = last = 'unknown_video'
+ if last.lower() not in cls.ALLOWED_EXTENSIONS:
+ raise cls(extension)
+
+ return extension
+
+
class RetryManager:
"""Usage:
for retry in RetryManager(...):
@@ -5193,7 +5286,7 @@ class FormatSorter:
'function': lambda it: next(filter(None, it), None)},
'ext': {'type': 'combined', 'field': ('vext', 'aext')},
'res': {'type': 'multiple', 'field': ('height', 'width'),
- 'function': lambda it: (lambda l: min(l) if l else 0)(tuple(filter(None, it)))},
+ 'function': lambda it: min(filter(None, it), default=0)},
# Actual field names
'format_id': {'type': 'alias', 'field': 'id'},
@@ -5241,21 +5334,21 @@ class FormatSorter:
self.ydl.deprecated_feature(f'Using arbitrary fields ({field}) for format sorting is '
'deprecated and may be removed in a future version')
self.settings[field] = {}
- propObj = self.settings[field]
- if key not in propObj:
- type = propObj.get('type')
+ prop_obj = self.settings[field]
+ if key not in prop_obj:
+ type_ = prop_obj.get('type')
if key == 'field':
- default = 'preference' if type == 'extractor' else (field,) if type in ('combined', 'multiple') else field
+ default = 'preference' if type_ == 'extractor' else (field,) if type_ in ('combined', 'multiple') else field
elif key == 'convert':
- default = 'order' if type == 'ordered' else 'float_string' if field else 'ignore'
+ default = 'order' if type_ == 'ordered' else 'float_string' if field else 'ignore'
else:
- default = {'type': 'field', 'visible': True, 'order': [], 'not_in_list': (None,)}.get(key, None)
- propObj[key] = default
- return propObj[key]
+ default = {'type': 'field', 'visible': True, 'order': [], 'not_in_list': (None,)}.get(key)
+ prop_obj[key] = default
+ return prop_obj[key]
- def _resolve_field_value(self, field, value, convertNone=False):
+ def _resolve_field_value(self, field, value, convert_none=False):
if value is None:
- if not convertNone:
+ if not convert_none:
return None
else:
value = value.lower()
@@ -5317,7 +5410,7 @@ class FormatSorter:
for item in sort_list:
match = re.match(self.regex, item)
if match is None:
- raise ExtractorError('Invalid format sort string "%s" given by extractor' % item)
+ raise ExtractorError(f'Invalid format sort string "{item}" given by extractor')
field = match.group('field')
if field is None:
continue
@@ -5345,31 +5438,31 @@ class FormatSorter:
def print_verbose_info(self, write_debug):
if self._sort_user:
- write_debug('Sort order given by user: %s' % ', '.join(self._sort_user))
+ write_debug('Sort order given by user: {}'.format(', '.join(self._sort_user)))
if self._sort_extractor:
- write_debug('Sort order given by extractor: %s' % ', '.join(self._sort_extractor))
- write_debug('Formats sorted by: %s' % ', '.join(['%s%s%s' % (
+ write_debug('Sort order given by extractor: {}'.format(', '.join(self._sort_extractor)))
+ write_debug('Formats sorted by: {}'.format(', '.join(['{}{}{}'.format(
'+' if self._get_field_setting(field, 'reverse') else '', field,
- '%s%s(%s)' % ('~' if self._get_field_setting(field, 'closest') else ':',
- self._get_field_setting(field, 'limit_text'),
- self._get_field_setting(field, 'limit'))
+ '{}{}({})'.format('~' if self._get_field_setting(field, 'closest') else ':',
+ self._get_field_setting(field, 'limit_text'),
+ self._get_field_setting(field, 'limit'))
if self._get_field_setting(field, 'limit_text') is not None else '')
- for field in self._order if self._get_field_setting(field, 'visible')]))
+ for field in self._order if self._get_field_setting(field, 'visible')])))
- def _calculate_field_preference_from_value(self, format, field, type, value):
+ def _calculate_field_preference_from_value(self, format_, field, type_, value):
reverse = self._get_field_setting(field, 'reverse')
closest = self._get_field_setting(field, 'closest')
limit = self._get_field_setting(field, 'limit')
- if type == 'extractor':
+ if type_ == 'extractor':
maximum = self._get_field_setting(field, 'max')
if value is None or (maximum is not None and value >= maximum):
value = -1
- elif type == 'boolean':
+ elif type_ == 'boolean':
in_list = self._get_field_setting(field, 'in_list')
not_in_list = self._get_field_setting(field, 'not_in_list')
value = 0 if ((in_list is None or value in in_list) and (not_in_list is None or value not in not_in_list)) else -1
- elif type == 'ordered':
+ elif type_ == 'ordered':
value = self._resolve_field_value(field, value, True)
# try to convert to number
@@ -5385,17 +5478,17 @@ class FormatSorter:
else (0, -value, 0) if limit is None or (reverse and value == limit) or value > limit
else (-1, value, 0))
- def _calculate_field_preference(self, format, field):
- type = self._get_field_setting(field, 'type') # extractor, boolean, ordered, field, multiple
- get_value = lambda f: format.get(self._get_field_setting(f, 'field'))
- if type == 'multiple':
- type = 'field' # Only 'field' is allowed in multiple for now
+ def _calculate_field_preference(self, format_, field):
+ type_ = self._get_field_setting(field, 'type') # extractor, boolean, ordered, field, multiple
+ get_value = lambda f: format_.get(self._get_field_setting(f, 'field'))
+ if type_ == 'multiple':
+ type_ = 'field' # Only 'field' is allowed in multiple for now
actual_fields = self._get_field_setting(field, 'field')
value = self._get_field_setting(field, 'function')(get_value(f) for f in actual_fields)
else:
value = get_value(field)
- return self._calculate_field_preference_from_value(format, field, type, value)
+ return self._calculate_field_preference_from_value(format_, field, type_, value)
def calculate_preference(self, format):
# Determine missing protocol