diff options
Diffstat (limited to 'yt_dlp/utils')
-rw-r--r-- | yt_dlp/utils/_utils.py | 129 | ||||
-rw-r--r-- | yt_dlp/utils/traversal.py | 37 |
2 files changed, 110 insertions, 56 deletions
diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index 9efeb6a..e3e80f3 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -5,7 +5,7 @@ import codecs import collections import collections.abc import contextlib -import datetime +import datetime as dt import email.header import email.utils import errno @@ -50,7 +50,6 @@ from ..compat import ( compat_expanduser, compat_HTMLParseError, compat_os_name, - compat_shlex_quote, ) from ..dependencies import xattr @@ -836,9 +835,11 @@ class Popen(subprocess.Popen): if shell and compat_os_name == 'nt' and kwargs.get('executable') is None: if not isinstance(args, str): - args = ' '.join(compat_shlex_quote(a) for a in args) + args = shell_quote(args, shell=True) shell = False - args = f'{self.__comspec()} /Q /S /D /V:OFF /C "{args}"' + # Set variable for `cmd.exe` newline escaping (see `utils.shell_quote`) + env['='] = '"^\n\n"' + args = f'{self.__comspec()} /Q /S /D /V:OFF /E:ON /C "{args}"' super().__init__(args, *remaining, env=env, shell=shell, **kwargs, startupinfo=self._startupinfo) @@ -1150,14 +1151,14 @@ def extract_timezone(date_str): timezone = TIMEZONE_NAMES.get(m and m.group('tz').strip()) if timezone is not None: date_str = date_str[:-len(m.group('tz'))] - timezone = datetime.timedelta(hours=timezone or 0) + timezone = dt.timedelta(hours=timezone or 0) else: date_str = date_str[:-len(m.group('tz'))] if not m.group('sign'): - timezone = datetime.timedelta() + timezone = dt.timedelta() else: sign = 1 if m.group('sign') == '+' else -1 - timezone = datetime.timedelta( + timezone = dt.timedelta( hours=sign * int(m.group('hours')), minutes=sign * int(m.group('minutes'))) return timezone, date_str @@ -1176,8 +1177,8 @@ def parse_iso8601(date_str, delimiter='T', timezone=None): with contextlib.suppress(ValueError): date_format = f'%Y-%m-%d{delimiter}%H:%M:%S' - dt = datetime.datetime.strptime(date_str, date_format) - timezone - return calendar.timegm(dt.timetuple()) + dt_ = dt.datetime.strptime(date_str, date_format) - timezone + return calendar.timegm(dt_.timetuple()) def date_formats(day_first=True): @@ -1198,12 +1199,12 @@ def unified_strdate(date_str, day_first=True): for expression in date_formats(day_first): with contextlib.suppress(ValueError): - upload_date = datetime.datetime.strptime(date_str, expression).strftime('%Y%m%d') + upload_date = dt.datetime.strptime(date_str, expression).strftime('%Y%m%d') if upload_date is None: timetuple = email.utils.parsedate_tz(date_str) if timetuple: with contextlib.suppress(ValueError): - upload_date = datetime.datetime(*timetuple[:6]).strftime('%Y%m%d') + upload_date = dt.datetime(*timetuple[:6]).strftime('%Y%m%d') if upload_date is not None: return str(upload_date) @@ -1233,8 +1234,8 @@ def unified_timestamp(date_str, day_first=True): for expression in date_formats(day_first): with contextlib.suppress(ValueError): - dt = datetime.datetime.strptime(date_str, expression) - timezone + datetime.timedelta(hours=pm_delta) - return calendar.timegm(dt.timetuple()) + dt_ = dt.datetime.strptime(date_str, expression) - timezone + dt.timedelta(hours=pm_delta) + return calendar.timegm(dt_.timetuple()) timetuple = email.utils.parsedate_tz(date_str) if timetuple: @@ -1272,11 +1273,11 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'): if precision == 'auto': auto_precision = True precision = 'microsecond' - today = datetime_round(datetime.datetime.now(datetime.timezone.utc), precision) + today = datetime_round(dt.datetime.now(dt.timezone.utc), precision) if date_str in ('now', 'today'): return today if date_str == 'yesterday': - return today - datetime.timedelta(days=1) + return today - dt.timedelta(days=1) match = re.match( r'(?P<start>.+)(?P<sign>[+-])(?P<time>\d+)(?P<unit>microsecond|second|minute|hour|day|week|month|year)s?', date_str) @@ -1291,13 +1292,13 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'): if unit == 'week': unit = 'day' time *= 7 - delta = datetime.timedelta(**{unit + 's': time}) + delta = dt.timedelta(**{unit + 's': time}) new_date = start_time + delta if auto_precision: return datetime_round(new_date, unit) return new_date - return datetime_round(datetime.datetime.strptime(date_str, format), precision) + return datetime_round(dt.datetime.strptime(date_str, format), precision) def date_from_str(date_str, format='%Y%m%d', strict=False): @@ -1312,21 +1313,21 @@ def date_from_str(date_str, format='%Y%m%d', strict=False): return datetime_from_str(date_str, precision='microsecond', format=format).date() -def datetime_add_months(dt, months): +def datetime_add_months(dt_, months): """Increment/Decrement a datetime object by months.""" - month = dt.month + months - 1 - year = dt.year + month // 12 + month = dt_.month + months - 1 + year = dt_.year + month // 12 month = month % 12 + 1 - day = min(dt.day, calendar.monthrange(year, month)[1]) - return dt.replace(year, month, day) + day = min(dt_.day, calendar.monthrange(year, month)[1]) + return dt_.replace(year, month, day) -def datetime_round(dt, precision='day'): +def datetime_round(dt_, precision='day'): """ Round a datetime object's time to a specific precision """ if precision == 'microsecond': - return dt + return dt_ unit_seconds = { 'day': 86400, @@ -1335,8 +1336,8 @@ def datetime_round(dt, precision='day'): 'second': 1, } roundto = lambda x, n: ((x + n / 2) // n) * n - timestamp = roundto(calendar.timegm(dt.timetuple()), unit_seconds[precision]) - return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc) + timestamp = roundto(calendar.timegm(dt_.timetuple()), unit_seconds[precision]) + return dt.datetime.fromtimestamp(timestamp, dt.timezone.utc) def hyphenate_date(date_str): @@ -1357,11 +1358,11 @@ class DateRange: if start is not None: self.start = date_from_str(start, strict=True) else: - self.start = datetime.datetime.min.date() + self.start = dt.datetime.min.date() if end is not None: self.end = date_from_str(end, strict=True) else: - self.end = datetime.datetime.max.date() + self.end = dt.datetime.max.date() if self.start > self.end: raise ValueError('Date range: "%s" , the start date must be before the end date' % self) @@ -1372,7 +1373,7 @@ class DateRange: def __contains__(self, date): """Check if the date is in the range""" - if not isinstance(date, datetime.date): + if not isinstance(date, dt.date): date = date_from_str(date) return self.start <= date <= self.end @@ -1637,15 +1638,38 @@ def get_filesystem_encoding(): return encoding if encoding is not None else 'utf-8' -def shell_quote(args): - quoted_args = [] - encoding = get_filesystem_encoding() - for a in args: - if isinstance(a, bytes): - # We may get a filename encoded with 'encodeFilename' - a = a.decode(encoding) - quoted_args.append(compat_shlex_quote(a)) - return ' '.join(quoted_args) +_WINDOWS_QUOTE_TRANS = str.maketrans({'"': '\\"', '\\': '\\\\'}) +_CMD_QUOTE_TRANS = str.maketrans({ + # Keep quotes balanced by replacing them with `""` instead of `\\"` + '"': '""', + # Requires a variable `=` containing `"^\n\n"` (set in `utils.Popen`) + # `=` should be unique since variables containing `=` cannot be set using cmd + '\n': '%=%', + # While we are only required to escape backslashes immediately before quotes, + # we instead escape all of 'em anyways to be consistent + '\\': '\\\\', + # Use zero length variable replacement so `%` doesn't get expanded + # `cd` is always set as long as extensions are enabled (`/E:ON` in `utils.Popen`) + '%': '%%cd:~,%', +}) + + +def shell_quote(args, *, shell=False): + args = list(variadic(args)) + if any(isinstance(item, bytes) for item in args): + deprecation_warning('Passing bytes to utils.shell_quote is deprecated') + encoding = get_filesystem_encoding() + for index, item in enumerate(args): + if isinstance(item, bytes): + args[index] = item.decode(encoding) + + if compat_os_name != 'nt': + return shlex.join(args) + + trans = _CMD_QUOTE_TRANS if shell else _WINDOWS_QUOTE_TRANS + return ' '.join( + s if re.fullmatch(r'[\w#$*\-+./:?@\\]+', s, re.ASCII) else s.translate(trans).join('""') + for s in args) def smuggle_url(url, data): @@ -1996,12 +2020,12 @@ def strftime_or_none(timestamp, date_format='%Y%m%d', default=None): if isinstance(timestamp, (int, float)): # unix timestamp # Using naive datetime here can break timestamp() in Windows # Ref: https://github.com/yt-dlp/yt-dlp/issues/5185, https://github.com/python/cpython/issues/94414 - # Also, datetime.datetime.fromtimestamp breaks for negative timestamps + # Also, dt.datetime.fromtimestamp breaks for negative timestamps # Ref: https://github.com/yt-dlp/yt-dlp/issues/6706#issuecomment-1496842642 - datetime_object = (datetime.datetime.fromtimestamp(0, datetime.timezone.utc) - + datetime.timedelta(seconds=timestamp)) + datetime_object = (dt.datetime.fromtimestamp(0, dt.timezone.utc) + + dt.timedelta(seconds=timestamp)) elif isinstance(timestamp, str): # assume YYYYMMDD - datetime_object = datetime.datetime.strptime(timestamp, '%Y%m%d') + datetime_object = dt.datetime.strptime(timestamp, '%Y%m%d') date_format = re.sub( # Support %s on windows r'(?<!%)(%%)*%s', rf'\g<1>{int(datetime_object.timestamp())}', date_format) return datetime_object.strftime(date_format) @@ -2849,7 +2873,7 @@ def ytdl_is_updateable(): def args_to_str(args): # Get a short string representation for a subprocess command - return ' '.join(compat_shlex_quote(a) for a in args) + return shell_quote(args) def error_to_str(err): @@ -4490,10 +4514,10 @@ def write_xattr(path, key, value): def random_birthday(year_field, month_field, day_field): - start_date = datetime.date(1950, 1, 1) - end_date = datetime.date(1995, 12, 31) + start_date = dt.date(1950, 1, 1) + end_date = dt.date(1995, 12, 31) offset = random.randint(0, (end_date - start_date).days) - random_date = start_date + datetime.timedelta(offset) + random_date = start_date + dt.timedelta(offset) return { year_field: str(random_date.year), month_field: str(random_date.month), @@ -4672,7 +4696,7 @@ def time_seconds(**kwargs): """ Returns TZ-aware time in seconds since the epoch (1970-01-01T00:00:00Z) """ - return time.time() + datetime.timedelta(**kwargs).total_seconds() + return time.time() + dt.timedelta(**kwargs).total_seconds() # create a JSON Web Signature (jws) with HS256 algorithm @@ -5415,6 +5439,17 @@ class FormatSorter: return tuple(self._calculate_field_preference(format, field) for field in self._order) +def filesize_from_tbr(tbr, duration): + """ + @param tbr: Total bitrate in kbps (1000 bits/sec) + @param duration: Duration in seconds + @returns Filesize in bytes + """ + if tbr is None or duration is None: + return None + return int(duration * tbr * (1000 / 8)) + + # XXX: Temporary class _YDLLogger: def __init__(self, ydl=None): diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 8938f4c..96eb2ed 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -1,5 +1,6 @@ import collections.abc import contextlib +import http.cookies import inspect import itertools import re @@ -28,7 +29,8 @@ def traverse_obj( Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. + Supported values for traversal are `Mapping`, `Iterable`, `re.Match`, + `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -36,8 +38,8 @@ def traverse_obj( The keys in the path can be one of: - `None`: Return the current object. - `set`: Requires the only item in the set to be a type or function, - like `{type}`/`{func}`. If a `type`, returns only values - of this type. If a function, returns `func(obj)`. + like `{type}`/`{type, type, ...}/`{func}`. If a `type`, return only + values of this type. If a function, returns `func(obj)`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. @@ -48,8 +50,10 @@ def traverse_obj( For `Iterable`s, `key` is the index of the value. For `re.Match`es, `key` is the group number (0 = full match) as well as additionally any group names, if given. - - `dict` Transform the current object and return a matching dict. + - `dict`: Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + - `any`-builtin: Take the first matching object and return it, resetting branching. + - `all`-builtin: Take all matching objects and return them as a list, resetting branching. `tuple`, `list`, and `dict` all support nested paths and branches. @@ -102,10 +106,10 @@ def traverse_obj( result = obj elif isinstance(key, set): - assert len(key) == 1, 'Set should only be used to wrap a single item' item = next(iter(key)) - if isinstance(item, type): - if isinstance(obj, item): + if len(key) > 1 or isinstance(item, type): + assert all(isinstance(item, type) for item in key) + if isinstance(obj, tuple(key)): result = obj else: result = try_call(item, args=(obj,)) @@ -117,6 +121,8 @@ def traverse_obj( elif key is ...: branching = True + if isinstance(obj, http.cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) if isinstance(obj, collections.abc.Mapping): result = obj.values() elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): @@ -131,6 +137,8 @@ def traverse_obj( elif callable(key): branching = True + if isinstance(obj, http.cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) if isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): @@ -157,6 +165,8 @@ def traverse_obj( } or None elif isinstance(obj, collections.abc.Mapping): + if isinstance(obj, http.cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else next((v for k, v in obj.items() if casefold(k) == key), None)) @@ -179,7 +189,7 @@ def traverse_obj( elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str): xpath, _, special = key.rpartition('/') - if not special.startswith('@') and special != 'text()': + if not special.startswith('@') and not special.endswith('()'): xpath = key special = None @@ -198,7 +208,7 @@ def traverse_obj( return try_call(element.attrib.get, args=(special[1:],)) if special == 'text()': return element.text - assert False, f'apply_specials is missing case for {special!r}' + raise SyntaxError(f'apply_specials is missing case for {special!r}') if xpath: result = list(map(apply_specials, obj.iterfind(xpath))) @@ -228,6 +238,15 @@ def traverse_obj( if not casesense and isinstance(key, str): key = key.casefold() + if key in (any, all): + has_branched = False + filtered_objs = (obj for obj in objs if obj not in (None, {})) + if key is any: + objs = (next(filtered_objs, None),) + else: + objs = (list(filtered_objs),) + continue + if __debug__ and callable(key): # Verify function signature inspect.signature(key).bind(None, None) |