summaryrefslogtreecommitdiffstats
path: root/yt_dlp/networking/impersonate.py
blob: ca66180c707db3badbec9b58bf1bfacaad78fb7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from __future__ import annotations

import re
from abc import ABC
from dataclasses import dataclass
from typing import Any

from .common import RequestHandler, register_preference
from .exceptions import UnsupportedRequest
from ..compat.types import NoneType
from ..utils import classproperty, join_nonempty
from ..utils.networking import std_headers


@dataclass(order=True, frozen=True)
class ImpersonateTarget:
    """
    A browser-impersonation target, e.g. ``chrome-110:windows-10``.

    Parameters:
    @param client: the client (browser) to impersonate
    @param version: the client version to impersonate; requires `client`
    @param os: the client OS to impersonate
    @param os_version: the client OS version to impersonate; requires `os`

    Note: a field left as None acts as a wildcard and matches any value.
    """
    client: str | None = None
    version: str | None = None
    os: str | None = None
    os_version: str | None = None

    def __post_init__(self):
        # A version only makes sense alongside the thing it versions.
        if not self.client and self.version:
            raise ValueError('client is required if version is set')
        if not self.os and self.os_version:
            raise ValueError('os is required if os_version is set')

    def __contains__(self, target: 'ImpersonateTarget'):
        """
        Return True if `target` is compatible with this target.

        Fields are compared pairwise in declaration order. A pair conflicts only
        when both sides are set and differ, so a None on either side matches
        anything. Objects that are not an ImpersonateTarget never match.
        """
        if not isinstance(target, ImpersonateTarget):
            return False
        for field_name in ('client', 'version', 'os', 'os_version'):
            ours = getattr(self, field_name)
            theirs = getattr(target, field_name)
            if ours is not None and theirs is not None and ours != theirs:
                return False
        return True

    def __str__(self):
        # Render as "client-version:os-os_version", omitting empty parts and
        # dropping the trailing ':' when no OS information is present.
        client_part = join_nonempty(self.client, self.version)
        os_part = join_nonempty(self.os, self.os_version)
        return f'{client_part}:{os_part}'.rstrip(':')

    @classmethod
    def from_str(cls, target: str):
        """
        Parse a string of the form ``[client[-version]][:[os[-os_version]]]``.

        Any part may be omitted; omitted parts become None (wildcards).

        @raises ValueError: if `target` does not match the expected format.
        """
        match = re.fullmatch(r'(?:(?P<client>[^:-]+)(?:-(?P<version>[^:-]+))?)?(?::(?:(?P<os>[^:-]+)(?:-(?P<os_version>[^:-]+))?)?)?', target)
        if match is None:
            raise ValueError(f'Invalid impersonate target "{target}"')
        parts = match.groupdict()
        return cls(**parts)


class ImpersonateRequestHandler(RequestHandler, ABC):
    """
    Base class for request handlers that support browser impersonation.

    This provides a method for checking the validity of the impersonate extension,
    which can be used in _check_extensions.

    Impersonate targets consist of a client, version, os and os_version.
    See the ImpersonateTarget class for more details.

    The following may be defined:
     - `_SUPPORTED_IMPERSONATE_TARGET_MAP`: a dict mapping supported targets to a handler-specific object.
                Any Request with an impersonate target that does not match an entry will raise an UnsupportedRequest.
                Set to None (or leave empty) to disable this check.
                Note: Entries are in order of preference

    Parameters:
    @param impersonate: the default impersonate target to use for requests.
                        Set to None to disable impersonation.
    """
    _SUPPORTED_IMPERSONATE_TARGET_MAP: dict[ImpersonateTarget, Any] | None = {}

    def __init__(self, *, impersonate: ImpersonateTarget | None = None, **kwargs):
        super().__init__(**kwargs)
        # Default target, used when a request carries no 'impersonate' extension.
        self.impersonate = impersonate

    def _check_impersonate_target(self, target: ImpersonateTarget | None):
        """
        Raise UnsupportedRequest if `target` cannot be resolved to a supported target.

        Nothing is checked when `target` is None or the handler declares no
        supported targets.
        """
        assert isinstance(target, (ImpersonateTarget, NoneType))
        if target is None or not self.supported_targets:
            return
        if not self.is_supported_target(target):
            raise UnsupportedRequest(f'Unsupported impersonate target: {target}')

    def _check_extensions(self, extensions):
        """Run the parent extension checks, then validate any 'impersonate' extension."""
        super()._check_extensions(extensions)
        if 'impersonate' in extensions:
            self._check_impersonate_target(extensions.get('impersonate'))

    def _validate(self, request):
        """Run the parent request validation, then validate the handler's default target."""
        super()._validate(request)
        self._check_impersonate_target(self.impersonate)

    def _resolve_target(self, target: ImpersonateTarget | None) -> ImpersonateTarget | None:
        """
        Resolve a target to a supported target.

        Returns the first entry of `supported_targets` (i.e. in order of preference)
        for which `target in supported_target` holds, or None when `target` is None
        or nothing matches.
        """
        if target is None:
            return None
        for supported_target in self.supported_targets:
            if target in supported_target:
                if self.verbose:
                    self._logger.stdout(
                        f'{self.RH_NAME}: resolved impersonate target {target} to {supported_target}')
                return supported_target
        return None

    @classproperty
    def supported_targets(cls) -> tuple[ImpersonateTarget, ...]:
        """
        The supported impersonate targets, in order of preference.

        Empty when `_SUPPORTED_IMPERSONATE_TARGET_MAP` is empty or None.
        """
        # `or {}`: a map set to None (documented as disabling the target check)
        # must yield no targets rather than raise AttributeError on `.keys()`.
        return tuple((cls._SUPPORTED_IMPERSONATE_TARGET_MAP or {}).keys())

    def is_supported_target(self, target: ImpersonateTarget) -> bool:
        """Return True if `target` resolves to one of the supported targets."""
        assert isinstance(target, ImpersonateTarget)
        return self._resolve_target(target) is not None

    def _get_request_target(self, request) -> ImpersonateTarget | None:
        """
        Get the resolved target for the request.

        The request's 'impersonate' extension takes precedence over the
        handler's default; the result is None if neither resolves.
        """
        return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)

    def _get_impersonate_headers(self, request):
        """
        Return the merged request headers, adjusted for impersonation.

        When the request resolves to an impersonate target, any header whose
        value still equals the std_headers default is removed. Headers the
        caller set to a different value are kept.
        """
        headers = self._merge_headers(request.headers)
        if self._get_request_target(request) is not None:
            # remove all headers present in std_headers (only if unchanged)
            # todo: change this to not depend on std_headers
            for k, v in std_headers.items():
                if headers.get(k) == v:
                    headers.pop(k)
        return headers


@register_preference(ImpersonateRequestHandler)
def impersonate_preference(rh, request):
    """
    Score an impersonation-capable handler for `request`.

    Returns 1000 when the request carries a truthy 'impersonate' extension or
    the handler has a default impersonate target, otherwise 0.
    (Presumably a higher score makes the handler more preferred; confirm
    against register_preference.)
    """
    wants_impersonation = request.extensions.get('impersonate') or rh.impersonate
    return 1000 if wants_impersonation else 0