diff options
Diffstat (limited to 'third_party/python/aiohttp/aiohttp/cookiejar.py')
-rw-r--r-- | third_party/python/aiohttp/aiohttp/cookiejar.py | 382 |
1 files changed, 382 insertions, 0 deletions
diff --git a/third_party/python/aiohttp/aiohttp/cookiejar.py b/third_party/python/aiohttp/aiohttp/cookiejar.py new file mode 100644 index 0000000000..b6b59d6289 --- /dev/null +++ b/third_party/python/aiohttp/aiohttp/cookiejar.py @@ -0,0 +1,382 @@ +import asyncio +import datetime +import os # noqa +import pathlib +import pickle +import re +from collections import defaultdict +from http.cookies import BaseCookie, Morsel, SimpleCookie +from typing import ( # noqa + DefaultDict, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + Set, + Tuple, + Union, + cast, +) + +from yarl import URL + +from .abc import AbstractCookieJar +from .helpers import is_ip_address, next_whole_second +from .typedefs import LooseCookies, PathLike + +__all__ = ("CookieJar", "DummyCookieJar") + + +CookieItem = Union[str, "Morsel[str]"] + + +class CookieJar(AbstractCookieJar): + """Implements cookie storage adhering to RFC 6265.""" + + DATE_TOKENS_RE = re.compile( + r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*" + r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)" + ) + + DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})") + + DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})") + + DATE_MONTH_RE = re.compile( + "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)", + re.I, + ) + + DATE_YEAR_RE = re.compile(r"(\d{2,4})") + + MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc) + + MAX_32BIT_TIME = datetime.datetime.utcfromtimestamp(2 ** 31 - 1) + + def __init__( + self, + *, + unsafe: bool = False, + quote_cookie: bool = True, + loop: Optional[asyncio.AbstractEventLoop] = None + ) -> None: + super().__init__(loop=loop) + self._cookies = defaultdict( + SimpleCookie + ) # type: DefaultDict[str, SimpleCookie[str]] + self._host_only_cookies = set() # type: Set[Tuple[str, str]] + self._unsafe = unsafe + self._quote_cookie = quote_cookie + self._next_expiration = next_whole_second() + self._expirations = {} # type: Dict[Tuple[str, str], datetime.datetime] + # #4515: datetime.max may not be representable on 32-bit platforms + self._max_time = self.MAX_TIME + try: + self._max_time.timestamp() + except OverflowError: + self._max_time = self.MAX_32BIT_TIME + + def save(self, file_path: PathLike) -> None: + file_path = pathlib.Path(file_path) + with file_path.open(mode="wb") as f: + pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL) + + def load(self, file_path: PathLike) -> None: + file_path = pathlib.Path(file_path) + with file_path.open(mode="rb") as f: + self._cookies = pickle.load(f) + + def clear(self) -> None: + self._cookies.clear() + self._host_only_cookies.clear() + self._next_expiration = next_whole_second() + self._expirations.clear() + + def __iter__(self) -> "Iterator[Morsel[str]]": + self._do_expiration() + for val in self._cookies.values(): + yield from val.values() + + def __len__(self) -> int: + return sum(1 for i in self) + + def _do_expiration(self) -> None: + now = datetime.datetime.now(datetime.timezone.utc) + if self._next_expiration > now: + return + if not self._expirations: + return + next_expiration = self._max_time + to_del = [] + cookies = self._cookies + expirations = self._expirations + for (domain, name), when in expirations.items(): + if when <= now: + cookies[domain].pop(name, None) + to_del.append((domain, name)) + self._host_only_cookies.discard((domain, name)) + else: + next_expiration = min(next_expiration, when) + for key in to_del: + del expirations[key] + + try: + self._next_expiration = next_expiration.replace( + microsecond=0 + ) + datetime.timedelta(seconds=1) + except OverflowError: + self._next_expiration = self._max_time + + def _expire_cookie(self, when: datetime.datetime, domain: str, name: str) -> None: + self._next_expiration = min(self._next_expiration, when) + self._expirations[(domain, name)] = when + + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: + """Update cookies.""" + hostname = response_url.raw_host + + if not self._unsafe and is_ip_address(hostname): + # Don't accept cookies from IPs + return + + if isinstance(cookies, Mapping): + cookies = cookies.items() + + for name, cookie in cookies: + if not isinstance(cookie, Morsel): + tmp = SimpleCookie() # type: SimpleCookie[str] + tmp[name] = cookie # type: ignore + cookie = tmp[name] + + domain = cookie["domain"] + + # ignore domains with trailing dots + if domain.endswith("."): + domain = "" + del cookie["domain"] + + if not domain and hostname is not None: + # Set the cookie's domain to the response hostname + # and set its host-only-flag + self._host_only_cookies.add((hostname, name)) + domain = cookie["domain"] = hostname + + if domain.startswith("."): + # Remove leading dot + domain = domain[1:] + cookie["domain"] = domain + + if hostname and not self._is_domain_match(domain, hostname): + # Setting cookies for different domains is not allowed + continue + + path = cookie["path"] + if not path or not path.startswith("/"): + # Set the cookie's path to the response path + path = response_url.path + if not path.startswith("/"): + path = "/" + else: + # Cut everything from the last slash to the end + path = "/" + path[1 : path.rfind("/")] + cookie["path"] = path + + max_age = cookie["max-age"] + if max_age: + try: + delta_seconds = int(max_age) + try: + max_age_expiration = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(seconds=delta_seconds) + except OverflowError: + max_age_expiration = self._max_time + self._expire_cookie(max_age_expiration, domain, name) + except ValueError: + cookie["max-age"] = "" + + else: + expires = cookie["expires"] + if expires: + expire_time = self._parse_date(expires) + if expire_time: + self._expire_cookie(expire_time, domain, name) + else: + cookie["expires"] = "" + + self._cookies[domain][name] = cookie + + self._do_expiration() + + def filter_cookies( + self, request_url: URL = URL() + ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]: + """Returns this jar's cookies filtered by their attributes.""" + self._do_expiration() + request_url = URL(request_url) + filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = ( + SimpleCookie() if self._quote_cookie else BaseCookie() + ) + hostname = request_url.raw_host or "" + is_not_secure = request_url.scheme not in ("https", "wss") + + for cookie in self: + name = cookie.key + domain = cookie["domain"] + + # Send shared cookies + if not domain: + filtered[name] = cookie.value + continue + + if not self._unsafe and is_ip_address(hostname): + continue + + if (domain, name) in self._host_only_cookies: + if domain != hostname: + continue + elif not self._is_domain_match(domain, hostname): + continue + + if not self._is_path_match(request_url.path, cookie["path"]): + continue + + if is_not_secure and cookie["secure"]: + continue + + # It's critical we use the Morsel so the coded_value + # (based on cookie version) is preserved + mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) + mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) + filtered[name] = mrsl_val + + return filtered + + @staticmethod + def _is_domain_match(domain: str, hostname: str) -> bool: + """Implements domain matching adhering to RFC 6265.""" + if hostname == domain: + return True + + if not hostname.endswith(domain): + return False + + non_matching = hostname[: -len(domain)] + + if not non_matching.endswith("."): + return False + + return not is_ip_address(hostname) + + @staticmethod + def _is_path_match(req_path: str, cookie_path: str) -> bool: + """Implements path matching adhering to RFC 6265.""" + if not req_path.startswith("/"): + req_path = "/" + + if req_path == cookie_path: + return True + + if not req_path.startswith(cookie_path): + return False + + if cookie_path.endswith("/"): + return True + + non_matching = req_path[len(cookie_path) :] + + return non_matching.startswith("/") + + @classmethod + def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]: + """Implements date string parsing adhering to RFC 6265.""" + if not date_str: + return None + + found_time = False + found_day = False + found_month = False + found_year = False + + hour = minute = second = 0 + day = 0 + month = 0 + year = 0 + + for token_match in cls.DATE_TOKENS_RE.finditer(date_str): + + token = token_match.group("token") + + if not found_time: + time_match = cls.DATE_HMS_TIME_RE.match(token) + if time_match: + found_time = True + hour, minute, second = [int(s) for s in time_match.groups()] + continue + + if not found_day: + day_match = cls.DATE_DAY_OF_MONTH_RE.match(token) + if day_match: + found_day = True + day = int(day_match.group()) + continue + + if not found_month: + month_match = cls.DATE_MONTH_RE.match(token) + if month_match: + found_month = True + assert month_match.lastindex is not None + month = month_match.lastindex + continue + + if not found_year: + year_match = cls.DATE_YEAR_RE.match(token) + if year_match: + found_year = True + year = int(year_match.group()) + + if 70 <= year <= 99: + year += 1900 + elif 0 <= year <= 69: + year += 2000 + + if False in (found_day, found_month, found_year, found_time): + return None + + if not 1 <= day <= 31: + return None + + if year < 1601 or hour > 23 or minute > 59 or second > 59: + return None + + return datetime.datetime( + year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc + ) + + +class DummyCookieJar(AbstractCookieJar): + """Implements a dummy cookie storage. + + It can be used with the ClientSession when no cookie processing is needed. + + """ + + def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(loop=loop) + + def __iter__(self) -> "Iterator[Morsel[str]]": + while False: + yield None + + def __len__(self) -> int: + return 0 + + def clear(self) -> None: + pass + + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: + pass + + def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": + return SimpleCookie() |