Diffstat (limited to '')
-rw-r--r-- | yt_dlp/utils/_utils.py | 311 |
1 file changed, 202 insertions, 109 deletions
diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py
index 42803bb..b5e1e29 100644
--- a/yt_dlp/utils/_utils.py
+++ b/yt_dlp/utils/_utils.py
@@ -53,7 +53,7 @@ from ..compat import (
 )
 from ..dependencies import xattr
 
-__name__ = __name__.rsplit('.', 1)[0]  # Pretend to be the parent module
+__name__ = __name__.rsplit('.', 1)[0]  # noqa: A001: Pretend to be the parent module
 
 # This is not clearly defined otherwise
 compiled_regex_type = type(re.compile(''))
@@ -90,7 +90,7 @@ TIMEZONE_NAMES = {
     'EST': -5, 'EDT': -4,  # Eastern
     'CST': -6, 'CDT': -5,  # Central
     'MST': -7, 'MDT': -6,  # Mountain
-    'PST': -8, 'PDT': -7  # Pacific
+    'PST': -8, 'PDT': -7,  # Pacific
 }
 
 # needed for sanitizing filenames in restricted mode
@@ -215,7 +215,7 @@ def write_json_file(obj, fn):
 def find_xpath_attr(node, xpath, key, val=None):
     """ Find the xpath xpath[@key=val] """
     assert re.match(r'^[a-zA-Z_-]+$', key)
-    expr = xpath + ('[@%s]' % key if val is None else f"[@{key}='{val}']")
+    expr = xpath + (f'[@{key}]' if val is None else f"[@{key}='{val}']")
     return node.find(expr)
 
 # On python2.6 the xml.etree.ElementTree.Element methods don't support
@@ -230,7 +230,7 @@ def xpath_with_ns(path, ns_map):
             replaced.append(c[0])
         else:
             ns, tag = c
-            replaced.append('{%s}%s' % (ns_map[ns], tag))
+            replaced.append(f'{{{ns_map[ns]}}}{tag}')
     return '/'.join(replaced)
 
 
@@ -251,7 +251,7 @@ def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
             return default
         elif fatal:
             name = xpath if name is None else name
-            raise ExtractorError('Could not find XML element %s' % name)
+            raise ExtractorError(f'Could not find XML element {name}')
         else:
             return None
     return n
@@ -266,7 +266,7 @@ def xpath_text(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
             return default
         elif fatal:
             name = xpath if name is None else name
-            raise ExtractorError('Could not find XML element\'s text %s' % name)
+            raise ExtractorError(f'Could not find XML element\'s text {name}')
         else:
             return None
     return n.text
@@ -279,7 +279,7 @@ def xpath_attr(node, xpath, key, name=None, fatal=False, default=NO_DEFAULT):
             return default
         elif fatal:
             name = f'{xpath}[@{key}]' if name is None else name
-            raise ExtractorError('Could not find XML attribute %s' % name)
+            raise ExtractorError(f'Could not find XML attribute {name}')
         else:
             return None
     return n.attrib[key]
@@ -320,14 +320,14 @@ def get_element_html_by_attribute(attribute, value, html, **kargs):
 def get_elements_by_class(class_name, html, **kargs):
     """Return the content of all tags with the specified class in the passed HTML document as a list"""
     return get_elements_by_attribute(
-        'class', r'[^\'"]*(?<=[\'"\s])%s(?=[\'"\s])[^\'"]*' % re.escape(class_name),
+        'class', rf'[^\'"]*(?<=[\'"\s]){re.escape(class_name)}(?=[\'"\s])[^\'"]*',
         html, escape_value=False)
 
 
 def get_elements_html_by_class(class_name, html):
     """Return the html of all tags with the specified class in the passed HTML document as a list"""
     return get_elements_html_by_attribute(
-        'class', r'[^\'"]*(?<=[\'"\s])%s(?=[\'"\s])[^\'"]*' % re.escape(class_name),
+        'class', rf'[^\'"]*(?<=[\'"\s]){re.escape(class_name)}(?=[\'"\s])[^\'"]*',
         html, escape_value=False)
 
 
@@ -364,7 +364,7 @@ def get_elements_text_and_html_by_attribute(attribute, value, html, *, tag=r'[\w
         yield (
             unescapeHTML(re.sub(
                 r'^(?P<q>["\'])(?P<content>.*)(?P=q)$', r'\g<content>', content, flags=re.DOTALL)),
-            whole
+            whole,
         )
 
 
@@ -407,7 +407,7 @@ class HTMLBreakOnClosingTagParser(html.parser.HTMLParser):
             else:
                 raise compat_HTMLParseError(f'matching opening tag for closing {tag} tag not found')
             if not self.tagstack:
-                raise self.HTMLBreakOnClosingTagException()
+                raise self.HTMLBreakOnClosingTagException
 
 
 # XXX: This should be far less strict
@@ -587,7 +587,7 @@ def sanitize_open(filename, open_mode):
                     # FIXME: An exclusive lock also locks the file from being read.
                     # Since windows locks are mandatory, don't lock the file on windows (for now).
                     # Ref: https://github.com/yt-dlp/yt-dlp/issues/3124
-                    raise LockingUnsupportedError()
+                    raise LockingUnsupportedError
                 stream = locked_file(filename, open_mode, block=False).__enter__()
             except OSError:
                 stream = open(filename, open_mode)
@@ -717,9 +717,9 @@ def extract_basic_auth(url):
         return url, None
     url = urllib.parse.urlunsplit(parts._replace(netloc=(
         parts.hostname if parts.port is None
-        else '%s:%d' % (parts.hostname, parts.port))))
+        else f'{parts.hostname}:{parts.port}')))
     auth_payload = base64.b64encode(
-        ('%s:%s' % (parts.username, parts.password or '')).encode())
+        ('{}:{}'.format(parts.username, parts.password or '')).encode())
     return url, f'Basic {auth_payload.decode()}'
 
 
@@ -758,7 +758,7 @@ def _htmlentity_transform(entity_with_semicolon):
         numstr = mobj.group(1)
         if numstr.startswith('x'):
             base = 16
-            numstr = '0%s' % numstr
+            numstr = f'0{numstr}'
         else:
             base = 10
         # See https://github.com/ytdl-org/youtube-dl/issues/7518
@@ -766,7 +766,7 @@
             return chr(int(numstr, base))
 
     # Unknown entity in name, return its literal representation
-    return '&%s;' % entity
+    return f'&{entity};'
 
 
 def unescapeHTML(s):
@@ -970,7 +970,7 @@ class ExtractorError(YoutubeDLError):
 class UnsupportedError(ExtractorError):
     def __init__(self, url):
         super().__init__(
-            'Unsupported URL: %s' % url, expected=True)
+            f'Unsupported URL: {url}', expected=True)
         self.url = url
 
 
@@ -1367,7 +1367,7 @@ class DateRange:
         else:
             self.end = dt.datetime.max.date()
         if self.start > self.end:
-            raise ValueError('Date range: "%s" , the start date must be before the end date' % self)
+            raise ValueError(f'Date range: "{self}" , the start date must be before the end date')
 
     @classmethod
     def day(cls, day):
@@ -1400,7 +1400,7 @@ def system_identifier():
     with contextlib.suppress(OSError):  # We may not have access to the executable
         libc_ver = platform.libc_ver()
 
-    return 'Python %s (%s %s %s) - %s (%s%s)' % (
+    return 'Python {} ({} {} {}) - {} ({}{})'.format(
         platform.python_version(),
         python_implementation,
         platform.machine(),
@@ -1413,7 +1413,7 @@
 
 @functools.cache
 def get_windows_version():
-    ''' Get Windows version. returns () if it's not running on Windows '''
+    """ Get Windows version. returns () if it's not running on Windows """
     if compat_os_name == 'nt':
         return version_tuple(platform.win32_ver()[1])
     else:
@@ -1505,7 +1505,7 @@ if sys.platform == 'win32':
         ctypes.wintypes.DWORD,      # dwReserved
         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockLow
         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockHigh
-        ctypes.POINTER(OVERLAPPED)  # Overlapped
+        ctypes.POINTER(OVERLAPPED),  # Overlapped
     ]
     LockFileEx.restype = ctypes.wintypes.BOOL
     UnlockFileEx = kernel32.UnlockFileEx
@@ -1514,7 +1514,7 @@
         ctypes.wintypes.DWORD,      # dwReserved
         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockLow
         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockHigh
-        ctypes.POINTER(OVERLAPPED)  # Overlapped
+        ctypes.POINTER(OVERLAPPED),  # Overlapped
     ]
     UnlockFileEx.restype = ctypes.wintypes.BOOL
     whole_low = 0xffffffff
@@ -1537,7 +1537,7 @@
         assert f._lock_file_overlapped_p
         handle = msvcrt.get_osfhandle(f.fileno())
         if not UnlockFileEx(handle, 0, whole_low, whole_high, f._lock_file_overlapped_p):
-            raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
+            raise OSError(f'Unlocking file failed: {ctypes.FormatError()!r}')
 
 else:
     try:
@@ -1564,10 +1564,10 @@ else:
     except ImportError:
 
         def _lock_file(f, exclusive, block):
-            raise LockingUnsupportedError()
+            raise LockingUnsupportedError
 
         def _unlock_file(f):
-            raise LockingUnsupportedError()
+            raise LockingUnsupportedError
 
 
 class locked_file:
@@ -1926,7 +1926,7 @@ def remove_end(s, end):
 def remove_quotes(s):
     if s is None or len(s) < 2:
         return s
-    for quote in ('"', "'", ):
+    for quote in ('"', "'"):
         if s[0] == quote and s[-1] == quote:
             return s[1:-1]
     return s
@@ -2085,26 +2085,27 @@ def parse_duration(s):
         (days, 86400), (hours, 3600), (mins, 60), (secs, 1), (ms, 1)))
 
 
-def prepend_extension(filename, ext, expected_real_ext=None):
+def _change_extension(prepend, filename, ext, expected_real_ext=None):
     name, real_ext = os.path.splitext(filename)
-    return (
-        f'{name}.{ext}{real_ext}'
-        if not expected_real_ext or real_ext[1:] == expected_real_ext
-        else f'{filename}.{ext}')
-
-
-def replace_extension(filename, ext, expected_real_ext=None):
-    name, real_ext = os.path.splitext(filename)
-    return '{}.{}'.format(
-        name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
-        ext)
+
+    if not expected_real_ext or real_ext[1:] == expected_real_ext:
+        filename = name
+        if prepend and real_ext:
+            _UnsafeExtensionError.sanitize_extension(ext, prepend=True)
+            return f'{filename}.{ext}{real_ext}'
+
+    return f'{filename}.{_UnsafeExtensionError.sanitize_extension(ext)}'
+
+
+prepend_extension = functools.partial(_change_extension, True)
+replace_extension = functools.partial(_change_extension, False)
 
 
 def check_executable(exe, args=[]):
     """ Checks if the given binary is installed somewhere in PATH, and returns its name.
     args can be a list of arguments for a short output (like -version) """
     try:
-        Popen.run([exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        Popen.run([exe, *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     except OSError:
         return False
     return exe
@@ -2115,7 +2116,7 @@ def _get_exe_version_output(exe, args):
     # STDIN should be redirected too. On UNIX-like systems, ffmpeg triggers
     # SIGTTOU if yt-dlp is run in the background.
     # See https://github.com/ytdl-org/youtube-dl/issues/955#issuecomment-209789656
-    stdout, _, ret = Popen.run([encodeArgument(exe)] + args, text=True,
+    stdout, _, ret = Popen.run([encodeArgument(exe), *args], text=True,
                                stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     if ret:
         return None
@@ -2161,7 +2162,7 @@ class LazyList(collections.abc.Sequence):
     """Lazy immutable list from an iterable
     Note that slices of a LazyList are lists and not LazyList"""
 
-    class IndexError(IndexError):
+    class IndexError(IndexError):  # noqa: A001
         pass
 
     def __init__(self, iterable, *, reverse=False, _cache=None):
@@ -2248,7 +2249,7 @@
 
 
 class PagedList:
-    class IndexError(IndexError):
+    class IndexError(IndexError):  # noqa: A001
         pass
 
     def __len__(self):
@@ -2282,7 +2283,7 @@
             raise TypeError('indices must be non-negative integers')
         entries = self.getslice(idx, idx + 1)
         if not entries:
-            raise self.IndexError()
+            raise self.IndexError
         return entries[0]
 
     def __bool__(self):
@@ -2443,7 +2444,7 @@ class PlaylistEntries:
             except IndexError:
                 entry = self.MissingEntry
                 if not self.is_incomplete:
-                    raise self.IndexError()
+                    raise self.IndexError
             if entry is self.MissingEntry:
                 raise EntryNotInPlaylist(f'Entry {i + 1} cannot be found')
             return entry
@@ -2452,7 +2453,7 @@
             try:
                 return type(self.ydl)._handle_extraction_exceptions(lambda _, i: self._entries[i])(self.ydl, i)
             except (LazyList.IndexError, PagedList.IndexError):
-                raise self.IndexError()
+                raise self.IndexError
         return get_entry
 
     def __getitem__(self, idx):
@@ -2488,7 +2489,7 @@
     def __len__(self):
         return len(tuple(self[:]))
 
-    class IndexError(IndexError):
+    class IndexError(IndexError):  # noqa: A001
        pass
 
 
@@ -2550,7 +2551,7 @@ def update_url(url, *, query_update=None, **kwargs):
         assert 'query' not in kwargs, 'query_update and query cannot be specified at the same time'
         kwargs['query'] = urllib.parse.urlencode({
             **urllib.parse.parse_qs(url.query),
-            **query_update
+            **query_update,
         }, True)
     return urllib.parse.urlunparse(url._replace(**kwargs))
 
@@ -2560,7 +2561,7 @@ def update_url_query(url, query):
 
 
 def _multipart_encode_impl(data, boundary):
-    content_type = 'multipart/form-data; boundary=%s' % boundary
+    content_type = f'multipart/form-data; boundary={boundary}'
 
     out = b''
     for k, v in data.items():
@@ -2582,7 +2583,7 @@ def _multipart_encode_impl(data, boundary):
 
 
 def multipart_encode(data, boundary=None):
-    '''
+    """
     Encode a dict to RFC 7578-compliant form-data
 
     data:
@@ -2593,7 +2594,7 @@
         a random boundary is generated.
 
     Reference: https://tools.ietf.org/html/rfc7578
-    '''
+    """
     has_specified_boundary = boundary is not None
 
     while True:
@@ -2688,7 +2689,7 @@ def parse_age_limit(s):
     s = s.upper()
     if s in US_RATINGS:
         return US_RATINGS[s]
-    m = re.match(r'^TV[_-]?(%s)$' % '|'.join(k[3:] for k in TV_PARENTAL_GUIDELINES), s)
+    m = re.match(r'^TV[_-]?({})$'.format('|'.join(k[3:] for k in TV_PARENTAL_GUIDELINES)), s)
     if m:
         return TV_PARENTAL_GUIDELINES['TV-' + m.group(1)]
     return None
@@ -2736,7 +2737,7 @@ def js_to_json(code, vars={}, *, strict=False):
             return v
         elif v in ('undefined', 'void 0'):
             return 'null'
-        elif v.startswith('/*') or v.startswith('//') or v.startswith('!') or v == ',':
+        elif v.startswith(('/*', '//', '!')) or v == ',':
             return ''
 
         if v[0] in STRING_QUOTES:
@@ -3079,7 +3080,7 @@ def urlhandle_detect_ext(url_handle, default=NO_DEFAULT):
 
 
 def encode_data_uri(data, mime_type):
-    return 'data:%s;base64,%s' % (mime_type, base64.b64encode(data).decode('ascii'))
+    return 'data:{};base64,{}'.format(mime_type, base64.b64encode(data).decode('ascii'))
 
 
 def age_restricted(content_limit, age_limit):
@@ -3144,18 +3145,18 @@ def render_table(header_row, data, delim=False, extra_gap=0, hide_empty=False):
     def get_max_lens(table):
         return [max(width(str(v)) for v in col) for col in zip(*table)]
 
-    def filter_using_list(row, filterArray):
-        return [col for take, col in itertools.zip_longest(filterArray, row, fillvalue=True) if take]
+    def filter_using_list(row, filter_array):
+        return [col for take, col in itertools.zip_longest(filter_array, row, fillvalue=True) if take]
 
     max_lens = get_max_lens(data) if hide_empty else []
     header_row = filter_using_list(header_row, max_lens)
     data = [filter_using_list(row, max_lens) for row in data]
 
-    table = [header_row] + data
+    table = [header_row, *data]
     max_lens = get_max_lens(table)
     extra_gap += 1
     if delim:
-        table = [header_row, [delim * (ml + extra_gap) for ml in max_lens]] + data
+        table = [header_row, [delim * (ml + extra_gap) for ml in max_lens], *data]
         table[1][-1] = table[1][-1][:-extra_gap * len(delim)]  # Remove extra_gap from end of delimiter
     for row in table:
         for pos, text in enumerate(map(str, row)):
@@ -3163,8 +3164,7 @@
                 row[pos] = text.replace('\t', ' ' * (max_lens[pos] - width(text))) + ' ' * extra_gap
             else:
                 row[pos] = text + ' ' * (max_lens[pos] - width(text) + extra_gap)
-    ret = '\n'.join(''.join(row).rstrip() for row in table)
-    return ret
+    return '\n'.join(''.join(row).rstrip() for row in table)
 
 
 def _match_one(filter_part, dct, incomplete):
@@ -3191,12 +3191,12 @@
     operator_rex = re.compile(r'''(?x)
         (?P<key>[a-z_]+)
-        \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
+        \s*(?P<negation>!\s*)?(?P<op>{})(?P<none_inclusive>\s*\?)?\s*
         (?:
             (?P<quote>["\'])(?P<quotedstrval>.+?)(?P=quote)|
             (?P<strval>.+?)
         )
-        ''' % '|'.join(map(re.escape, COMPARISON_OPERATORS.keys())))
+        '''.format('|'.join(map(re.escape, COMPARISON_OPERATORS.keys()))))
     m = operator_rex.fullmatch(filter_part.strip())
     if m:
         m = m.groupdict()
@@ -3207,7 +3207,7 @@
             op = unnegated_op
         comparison_value = m['quotedstrval'] or m['strval'] or m['intval']
         if m['quote']:
-            comparison_value = comparison_value.replace(r'\%s' % m['quote'], m['quote'])
+            comparison_value = comparison_value.replace(r'\{}'.format(m['quote']), m['quote'])
         actual_value = dct.get(m['key'])
         numeric_comparison = None
         if isinstance(actual_value, (int, float)):
@@ -3224,7 +3224,7 @@
             if numeric_comparison is None:
                 numeric_comparison = parse_duration(comparison_value)
         if numeric_comparison is not None and m['op'] in STRING_OPERATORS:
-            raise ValueError('Operator %s only supports string values!' % m['op'])
+            raise ValueError('Operator {} only supports string values!'.format(m['op']))
         if actual_value is None:
             return is_incomplete(m['key']) or m['none_inclusive']
         return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison)
@@ -3234,8 +3234,8 @@
         '!': lambda v: (v is False) if isinstance(v, bool) else (v is None),
     }
     operator_rex = re.compile(r'''(?x)
-        (?P<op>%s)\s*(?P<key>[a-z_]+)
-        ''' % '|'.join(map(re.escape, UNARY_OPERATORS.keys())))
+        (?P<op>{})\s*(?P<key>[a-z_]+)
+        '''.format('|'.join(map(re.escape, UNARY_OPERATORS.keys()))))
     m = operator_rex.fullmatch(filter_part.strip())
     if m:
         op = UNARY_OPERATORS[m.group('op')]
@@ -3244,7 +3244,7 @@
             return True
         return op(actual_value)
 
-    raise ValueError('Invalid filter part %r' % filter_part)
+    raise ValueError(f'Invalid filter part {filter_part!r}')
 
 
 def match_str(filter_str, dct, incomplete=False):
@@ -3351,10 +3351,10 @@ def ass_subtitles_timecode(seconds):
 
 
 def dfxp2srt(dfxp_data):
-    '''
+    """
    @param dfxp_data A bytes-like object containing DFXP data
    @returns A unicode object containing converted SRT data
-    '''
+    """
     LEGACY_NAMESPACES = (
         (b'http://www.w3.org/ns/ttml', [
             b'http://www.w3.org/2004/11/ttaf1',
@@ -3372,7 +3372,7 @@
         'fontSize',
         'fontStyle',
         'fontWeight',
-        'textDecoration'
+        'textDecoration',
     ]
 
     _x = functools.partial(xpath_with_ns, ns_map={
@@ -3410,11 +3410,11 @@
                     if self._applied_styles and self._applied_styles[-1].get(k) == v:
                         continue
                     if k == 'color':
-                        font += ' color="%s"' % v
+                        font += f' color="{v}"'
                     elif k == 'fontSize':
-                        font += ' size="%s"' % v
+                        font += f' size="{v}"'
                     elif k == 'fontFamily':
-                        font += ' face="%s"' % v
+                        font += f' face="{v}"'
                     elif k == 'fontWeight' and v == 'bold':
                         self._out += '<b>'
                         unclosed_elements.append('b')
@@ -3438,7 +3438,7 @@
             if tag not in (_x('ttml:br'), 'br'):
                 unclosed_elements = self._unclosed_elements.pop()
                 for element in reversed(unclosed_elements):
-                    self._out += '</%s>' % element
+                    self._out += f'</{element}>'
                 if unclosed_elements and self._applied_styles:
                     self._applied_styles.pop()
 
@@ -4349,7 +4349,7 @@ def bytes_to_long(s):
 
 
 def ohdave_rsa_encrypt(data, exponent, modulus):
-    '''
+    """
     Implement OHDave's RSA algorithm. See http://www.ohdave.com/rsa/
 
     Input:
@@ -4358,11 +4358,11 @@
     Output: hex string of encrypted data
 
     Limitation: supports one block encryption only
-    '''
+    """
 
     payload = int(binascii.hexlify(data[::-1]), 16)
     encrypted = pow(payload, exponent, modulus)
-    return '%x' % encrypted
+    return f'{encrypted:x}'
 
 
 def pkcs1pad(data, length):
@@ -4377,7 +4377,7 @@
         raise ValueError('Input data too long for PKCS#1 padding')
 
     pseudo_random = [random.randint(0, 254) for _ in range(length - len(data) - 3)]
-    return [0, 2] + pseudo_random + [0] + data
+    return [0, 2, *pseudo_random, 0, *data]
 
 
 def _base_n_table(n, table):
@@ -4710,16 +4710,14 @@ def jwt_encode_hs256(payload_data, key, headers={}):
     payload_b64 = base64.b64encode(json.dumps(payload_data).encode())
     h = hmac.new(key.encode(), header_b64 + b'.' + payload_b64, hashlib.sha256)
     signature_b64 = base64.b64encode(h.digest())
-    token = header_b64 + b'.' + payload_b64 + b'.' + signature_b64
-    return token
+    return header_b64 + b'.' + payload_b64 + b'.' + signature_b64
 
 
 # can be extended in future to verify the signature and parse header and return the algorithm used if it's not HS256
 def jwt_decode_hs256(jwt):
     header_b64, payload_b64, signature_b64 = jwt.split('.')
     # add trailing ='s that may have been stripped, superfluous ='s are ignored
-    payload_data = json.loads(base64.urlsafe_b64decode(f'{payload_b64}==='))
-    return payload_data
+    return json.loads(base64.urlsafe_b64decode(f'{payload_b64}==='))
 
 
 WINDOWS_VT_MODE = False if compat_os_name == 'nt' else None
@@ -4797,7 +4795,7 @@ def scale_thumbnails_to_max_format_width(formats, thumbnails, url_width_re):
     """
     _keys = ('width', 'height')
     max_dimensions = max(
-        (tuple(format.get(k) or 0 for k in _keys) for format in formats),
+        (tuple(fmt.get(k) or 0 for k in _keys) for fmt in formats),
         default=(0, 0))
     if not max_dimensions[0]:
         return thumbnails
@@ -5040,6 +5038,101 @@ MEDIA_EXTENSIONS.audio += MEDIA_EXTENSIONS.common_audio
 KNOWN_EXTENSIONS = (*MEDIA_EXTENSIONS.video, *MEDIA_EXTENSIONS.audio, *MEDIA_EXTENSIONS.manifests)
 
 
+class _UnsafeExtensionError(Exception):
+    """
+    Mitigation exception for uncommon/malicious file extensions
+    This should be caught in YoutubeDL.py alongside a warning
+
+    Ref: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4j
+    """
+    ALLOWED_EXTENSIONS = frozenset([
+        # internal
+        'description',
+        'json',
+        'meta',
+        'orig',
+        'part',
+        'temp',
+        'uncut',
+        'unknown_video',
+        'ytdl',
+
+        # video
+        *MEDIA_EXTENSIONS.video,
+        'avif',
+        'ismv',
+        'm2ts',
+        'm4s',
+        'mng',
+        'mpeg',
+        'qt',
+        'swf',
+        'ts',
+        'vp9',
+        'wvm',
+
+        # audio
+        *MEDIA_EXTENSIONS.audio,
+        'isma',
+        'mid',
+        'mpga',
+        'ra',
+
+        # image
+        *MEDIA_EXTENSIONS.thumbnails,
+        'bmp',
+        'gif',
+        'heic',
+        'ico',
+        'jng',
+        'jpeg',
+        'jxl',
+        'svg',
+        'tif',
+        'wbmp',
+
+        # subtitle
+        *MEDIA_EXTENSIONS.subtitles,
+        'dfxp',
+        'fs',
+        'ismt',
+        'sami',
+        'scc',
+        'ssa',
+        'tt',
+        'ttml',
+
+        # others
+        *MEDIA_EXTENSIONS.manifests,
+        *MEDIA_EXTENSIONS.storyboards,
+        'desktop',
+        'ism',
+        'm3u',
+        'sbv',
+        'url',
+        'webloc',
+        'xml',
+    ])
+
+    def __init__(self, extension, /):
+        super().__init__(f'unsafe file extension: {extension!r}')
+        self.extension = extension
+
+    @classmethod
+    def sanitize_extension(cls, extension, /, *, prepend=False):
+        if '/' in extension or '\\' in extension:
+            raise cls(extension)
+
+        if not prepend:
+            _, _, last = extension.rpartition('.')
+            if last == 'bin':
+                extension = last = 'unknown_video'
+            if last.lower() not in cls.ALLOWED_EXTENSIONS:
+                raise cls(extension)
+
+        return extension
+
+
 class RetryManager:
     """Usage:
         for retry in RetryManager(...):
@@ -5193,7 +5286,7 @@ class FormatSorter:
                        'function': lambda it: next(filter(None, it), None)},
         'ext': {'type': 'combined', 'field': ('vext', 'aext')},
         'res': {'type': 'multiple', 'field': ('height', 'width'),
-                'function': lambda it: (lambda l: min(l) if l else 0)(tuple(filter(None, it)))},
+                'function': lambda it: min(filter(None, it), default=0)},
 
         # Actual field names
         'format_id': {'type': 'alias', 'field': 'id'},
@@ -5241,21 +5334,21 @@
                 self.ydl.deprecated_feature(f'Using arbitrary fields ({field}) for format sorting is '
                                             'deprecated and may be removed in a future version')
             self.settings[field] = {}
-        propObj = self.settings[field]
-        if key not in propObj:
-            type = propObj.get('type')
+        prop_obj = self.settings[field]
+        if key not in prop_obj:
+            type_ = prop_obj.get('type')
             if key == 'field':
-                default = 'preference' if type == 'extractor' else (field,) if type in ('combined', 'multiple') else field
+                default = 'preference' if type_ == 'extractor' else (field,) if type_ in ('combined', 'multiple') else field
             elif key == 'convert':
-                default = 'order' if type == 'ordered' else 'float_string' if field else 'ignore'
+                default = 'order' if type_ == 'ordered' else 'float_string' if field else 'ignore'
             else:
-                default = {'type': 'field', 'visible': True, 'order': [], 'not_in_list': (None,)}.get(key, None)
-            propObj[key] = default
-        return propObj[key]
+                default = {'type': 'field', 'visible': True, 'order': [], 'not_in_list': (None,)}.get(key)
+            prop_obj[key] = default
+        return prop_obj[key]
 
-    def _resolve_field_value(self, field, value, convertNone=False):
+    def _resolve_field_value(self, field, value, convert_none=False):
         if value is None:
-            if not convertNone:
+            if not convert_none:
                 return None
         else:
             value = value.lower()
@@ -5317,7 +5410,7 @@
         for item in sort_list:
             match = re.match(self.regex, item)
             if match is None:
-                raise ExtractorError('Invalid format sort string "%s" given by extractor' % item)
+                raise ExtractorError(f'Invalid format sort string "{item}" given by extractor')
             field = match.group('field')
             if field is None:
                 continue
@@ -5345,31 +5438,31 @@
     def print_verbose_info(self, write_debug):
         if self._sort_user:
-            write_debug('Sort order given by user: %s' % ', '.join(self._sort_user))
+            write_debug('Sort order given by user: {}'.format(', '.join(self._sort_user)))
         if self._sort_extractor:
-            write_debug('Sort order given by extractor: %s' % ', '.join(self._sort_extractor))
-        write_debug('Formats sorted by: %s' % ', '.join(['%s%s%s' % (
+            write_debug('Sort order given by extractor: {}'.format(', '.join(self._sort_extractor)))
+        write_debug('Formats sorted by: {}'.format(', '.join(['{}{}{}'.format(
             '+' if self._get_field_setting(field, 'reverse') else '', field,
-            '%s%s(%s)' % ('~' if self._get_field_setting(field, 'closest') else ':',
-                          self._get_field_setting(field, 'limit_text'),
-                          self._get_field_setting(field, 'limit'))
+            '{}{}({})'.format('~' if self._get_field_setting(field, 'closest') else ':',
+                              self._get_field_setting(field, 'limit_text'),
+                              self._get_field_setting(field, 'limit'))
             if self._get_field_setting(field, 'limit_text') is not None else '')
-            for field in self._order if self._get_field_setting(field, 'visible')]))
+            for field in self._order if self._get_field_setting(field, 'visible')])))
 
-    def _calculate_field_preference_from_value(self, format, field, type, value):
+    def _calculate_field_preference_from_value(self, format_, field, type_, value):
         reverse = self._get_field_setting(field, 'reverse')
         closest = self._get_field_setting(field, 'closest')
         limit = self._get_field_setting(field, 'limit')
-        if type == 'extractor':
+        if type_ == 'extractor':
             maximum = self._get_field_setting(field, 'max')
             if value is None or (maximum is not None and value >= maximum):
                 value = -1
-        elif type == 'boolean':
+        elif type_ == 'boolean':
             in_list = self._get_field_setting(field, 'in_list')
             not_in_list = self._get_field_setting(field, 'not_in_list')
             value = 0 if ((in_list is None or value in in_list) and (not_in_list is None or value not in not_in_list)) else -1
-        elif type == 'ordered':
+        elif type_ == 'ordered':
             value = self._resolve_field_value(field, value, True)
 
         # try to convert to number
@@ -5385,17 +5478,17 @@
                 else (0, -value, 0) if limit is None or (reverse and value == limit) or value > limit
                 else (-1, value, 0))
 
-    def _calculate_field_preference(self, format, field):
-        type = self._get_field_setting(field, 'type')  # extractor, boolean, ordered, field, multiple
-        get_value = lambda f: format.get(self._get_field_setting(f, 'field'))
-        if type == 'multiple':
-            type = 'field'  # Only 'field' is allowed in multiple for now
+    def _calculate_field_preference(self, format_, field):
+        type_ = self._get_field_setting(field, 'type')  # extractor, boolean, ordered, field, multiple
+        get_value = lambda f: format_.get(self._get_field_setting(f, 'field'))
+        if type_ == 'multiple':
+            type_ = 'field'  # Only 'field' is allowed in multiple for now
             actual_fields = self._get_field_setting(field, 'field')
 
             value = self._get_field_setting(field, 'function')(get_value(f) for f in actual_fields)
         else:
            value = get_value(field)
 
-        return self._calculate_field_preference_from_value(format, field, type, value)
+        return self._calculate_field_preference_from_value(format_, field, type_, value)
 
     def calculate_preference(self, format):
         # Determine missing protocol
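The following is a minimal sketch (not part of the patch) of how the extension helpers changed above behave, assuming a yt-dlp build that already contains this diff; the import paths and sample filenames are assumptions for illustration only:

    # Hypothetical usage of the refactored helpers from this diff (assumed import paths).
    from yt_dlp.utils._utils import _UnsafeExtensionError, prepend_extension, replace_extension

    # prepend keeps the real extension and skips the allow-list check (prepend=True path)
    print(prepend_extension('clip.mp4', 'temp'))        # clip.temp.mp4

    # replace validates the new extension against ALLOWED_EXTENSIONS
    print(replace_extension('clip.mp4', 'mkv', 'mp4'))  # clip.mkv

    try:
        replace_extension('clip.mp4', 'exe')            # 'exe' is not in ALLOWED_EXTENSIONS
    except _UnsafeExtensionError as error:
        print(error)                                    # unsafe file extension: 'exe'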