summaryrefslogtreecommitdiffstats
path: root/third_party/python/responses/responses.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/python/responses/responses.py')
-rw-r--r--third_party/python/responses/responses.py653
1 files changed, 653 insertions, 0 deletions
diff --git a/third_party/python/responses/responses.py b/third_party/python/responses/responses.py
new file mode 100644
index 0000000000..9de936805c
--- /dev/null
+++ b/third_party/python/responses/responses.py
@@ -0,0 +1,653 @@
+from __future__ import absolute_import, print_function, division, unicode_literals
+
+import _io
+import inspect
+import json as json_module
+import logging
+import re
+import six
+
+from collections import namedtuple
+from functools import update_wrapper
+from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError
+from requests.sessions import REDIRECT_STATI
+from requests.utils import cookiejar_from_dict
+
+try:
+ from collections.abc import Sequence, Sized
+except ImportError:
+ from collections import Sequence, Sized
+
+try:
+ from requests.packages.urllib3.response import HTTPResponse
+except ImportError:
+ from urllib3.response import HTTPResponse
+
+if six.PY2:
+ from urlparse import urlparse, parse_qsl, urlsplit, urlunsplit
+ from urllib import quote
+else:
+ from urllib.parse import urlparse, parse_qsl, urlsplit, urlunsplit, quote
+
+if six.PY2:
+ try:
+ from six import cStringIO as BufferIO
+ except ImportError:
+ from six import StringIO as BufferIO
+else:
+ from io import BytesIO as BufferIO
+
+try:
+ from unittest import mock as std_mock
+except ImportError:
+ import mock as std_mock
+
+try:
+ Pattern = re._pattern_type
+except AttributeError:
+ # Python 3.7
+ Pattern = re.Pattern
+
+UNSET = object()
+
+Call = namedtuple("Call", ["request", "response"])
+
+_real_send = HTTPAdapter.send
+
+logger = logging.getLogger("responses")
+
+
+def _is_string(s):
+ return isinstance(s, six.string_types)
+
+
+def _has_unicode(s):
+ return any(ord(char) > 128 for char in s)
+
+
+def _clean_unicode(url):
+ # Clean up domain names, which use punycode to handle unicode chars
+ urllist = list(urlsplit(url))
+ netloc = urllist[1]
+ if _has_unicode(netloc):
+ domains = netloc.split(".")
+ for i, d in enumerate(domains):
+ if _has_unicode(d):
+ d = "xn--" + d.encode("punycode").decode("ascii")
+ domains[i] = d
+ urllist[1] = ".".join(domains)
+ url = urlunsplit(urllist)
+
+ # Clean up path/query/params, which use url-encoding to handle unicode chars
+ if isinstance(url.encode("utf8"), six.string_types):
+ url = url.encode("utf8")
+ chars = list(url)
+ for i, x in enumerate(chars):
+ if ord(x) > 128:
+ chars[i] = quote(x)
+
+ return "".join(chars)
+
+
+def _is_redirect(response):
+ try:
+ # 2.0.0 <= requests <= 2.2
+ return response.is_redirect
+
+ except AttributeError:
+ # requests > 2.2
+ return (
+ # use request.sessions conditional
+ response.status_code in REDIRECT_STATI
+ and "location" in response.headers
+ )
+
+
+def _cookies_from_headers(headers):
+ try:
+ import http.cookies as cookies
+
+ resp_cookie = cookies.SimpleCookie()
+ resp_cookie.load(headers["set-cookie"])
+
+ cookies_dict = {name: v.value for name, v in resp_cookie.items()}
+ except ImportError:
+ from cookies import Cookies
+
+ resp_cookies = Cookies.from_request(headers["set-cookie"])
+ cookies_dict = {v.name: v.value for _, v in resp_cookies.items()}
+ return cookiejar_from_dict(cookies_dict)
+
+
+_wrapper_template = """\
+def wrapper%(wrapper_args)s:
+ with responses:
+ return func%(func_args)s
+"""
+
+
+def get_wrapped(func, responses):
+ if six.PY2:
+ args, a, kw, defaults = inspect.getargspec(func)
+ wrapper_args = inspect.formatargspec(args, a, kw, defaults)
+
+ # Preserve the argspec for the wrapped function so that testing
+ # tools such as pytest can continue to use their fixture injection.
+ if hasattr(func, "__self__"):
+ args = args[1:] # Omit 'self'
+ func_args = inspect.formatargspec(args, a, kw, None)
+ else:
+ signature = inspect.signature(func)
+ signature = signature.replace(return_annotation=inspect.Signature.empty)
+ # If the function is wrapped, switch to *args, **kwargs for the parameters
+ # as we can't rely on the signature to give us the arguments the function will
+ # be called with. For example unittest.mock.patch uses required args that are
+ # not actually passed to the function when invoked.
+ if hasattr(func, "__wrapped__"):
+ wrapper_params = [
+ inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
+ inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD),
+ ]
+ else:
+ wrapper_params = [
+ param.replace(annotation=inspect.Parameter.empty)
+ for param in signature.parameters.values()
+ ]
+ signature = signature.replace(parameters=wrapper_params)
+
+ wrapper_args = str(signature)
+ params_without_defaults = [
+ param.replace(
+ annotation=inspect.Parameter.empty, default=inspect.Parameter.empty
+ )
+ for param in signature.parameters.values()
+ ]
+ signature = signature.replace(parameters=params_without_defaults)
+ func_args = str(signature)
+
+ evaldict = {"func": func, "responses": responses}
+ six.exec_(
+ _wrapper_template % {"wrapper_args": wrapper_args, "func_args": func_args},
+ evaldict,
+ )
+ wrapper = evaldict["wrapper"]
+ update_wrapper(wrapper, func)
+ return wrapper
+
+
+class CallList(Sequence, Sized):
+ def __init__(self):
+ self._calls = []
+
+ def __iter__(self):
+ return iter(self._calls)
+
+ def __len__(self):
+ return len(self._calls)
+
+ def __getitem__(self, idx):
+ return self._calls[idx]
+
+ def add(self, request, response):
+ self._calls.append(Call(request, response))
+
+ def reset(self):
+ self._calls = []
+
+
+def _ensure_url_default_path(url):
+ if _is_string(url):
+ url_parts = list(urlsplit(url))
+ if url_parts[2] == "":
+ url_parts[2] = "/"
+ url = urlunsplit(url_parts)
+ return url
+
+
+def _handle_body(body):
+ if isinstance(body, six.text_type):
+ body = body.encode("utf-8")
+ if isinstance(body, _io.BufferedReader):
+ return body
+
+ return BufferIO(body)
+
+
+_unspecified = object()
+
+
+class BaseResponse(object):
+ content_type = None
+ headers = None
+
+ stream = False
+
+ def __init__(self, method, url, match_querystring=_unspecified):
+ self.method = method
+ # ensure the url has a default path set if the url is a string
+ self.url = _ensure_url_default_path(url)
+ self.match_querystring = self._should_match_querystring(match_querystring)
+ self.call_count = 0
+
+ def __eq__(self, other):
+ if not isinstance(other, BaseResponse):
+ return False
+
+ if self.method != other.method:
+ return False
+
+ # Can't simply do a equality check on the objects directly here since __eq__ isn't
+ # implemented for regex. It might seem to work as regex is using a cache to return
+ # the same regex instances, but it doesn't in all cases.
+ self_url = self.url.pattern if isinstance(self.url, Pattern) else self.url
+ other_url = other.url.pattern if isinstance(other.url, Pattern) else other.url
+
+ return self_url == other_url
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _url_matches_strict(self, url, other):
+ url_parsed = urlparse(url)
+ other_parsed = urlparse(other)
+
+ if url_parsed[:3] != other_parsed[:3]:
+ return False
+
+ url_qsl = sorted(parse_qsl(url_parsed.query))
+ other_qsl = sorted(parse_qsl(other_parsed.query))
+
+ if len(url_qsl) != len(other_qsl):
+ return False
+
+ for (a_k, a_v), (b_k, b_v) in zip(url_qsl, other_qsl):
+ if a_k != b_k:
+ return False
+
+ if a_v != b_v:
+ return False
+
+ return True
+
+ def _should_match_querystring(self, match_querystring_argument):
+ if match_querystring_argument is not _unspecified:
+ return match_querystring_argument
+
+ if isinstance(self.url, Pattern):
+ # the old default from <= 0.9.0
+ return False
+
+ return bool(urlparse(self.url).query)
+
+ def _url_matches(self, url, other, match_querystring=False):
+ if _is_string(url):
+ if _has_unicode(url):
+ url = _clean_unicode(url)
+ if not isinstance(other, six.text_type):
+ other = other.encode("ascii").decode("utf8")
+ if match_querystring:
+ return self._url_matches_strict(url, other)
+
+ else:
+ url_without_qs = url.split("?", 1)[0]
+ other_without_qs = other.split("?", 1)[0]
+ return url_without_qs == other_without_qs
+
+ elif isinstance(url, Pattern) and url.match(other):
+ return True
+
+ else:
+ return False
+
+ def get_headers(self):
+ headers = {}
+ if self.content_type is not None:
+ headers["Content-Type"] = self.content_type
+ if self.headers:
+ headers.update(self.headers)
+ return headers
+
+ def get_response(self, request):
+ raise NotImplementedError
+
+ def matches(self, request):
+ if request.method != self.method:
+ return False
+
+ if not self._url_matches(self.url, request.url, self.match_querystring):
+ return False
+
+ return True
+
+
+class Response(BaseResponse):
+ def __init__(
+ self,
+ method,
+ url,
+ body="",
+ json=None,
+ status=200,
+ headers=None,
+ stream=False,
+ content_type=UNSET,
+ **kwargs
+ ):
+ # if we were passed a `json` argument,
+ # override the body and content_type
+ if json is not None:
+ assert not body
+ body = json_module.dumps(json)
+ if content_type is UNSET:
+ content_type = "application/json"
+
+ if content_type is UNSET:
+ content_type = "text/plain"
+
+ # body must be bytes
+ if isinstance(body, six.text_type):
+ body = body.encode("utf-8")
+
+ self.body = body
+ self.status = status
+ self.headers = headers
+ self.stream = stream
+ self.content_type = content_type
+ super(Response, self).__init__(method, url, **kwargs)
+
+ def get_response(self, request):
+ if self.body and isinstance(self.body, Exception):
+ raise self.body
+
+ headers = self.get_headers()
+ status = self.status
+ body = _handle_body(self.body)
+
+ return HTTPResponse(
+ status=status,
+ reason=six.moves.http_client.responses.get(status),
+ body=body,
+ headers=headers,
+ preload_content=False,
+ )
+
+
+class CallbackResponse(BaseResponse):
+ def __init__(
+ self, method, url, callback, stream=False, content_type="text/plain", **kwargs
+ ):
+ self.callback = callback
+ self.stream = stream
+ self.content_type = content_type
+ super(CallbackResponse, self).__init__(method, url, **kwargs)
+
+ def get_response(self, request):
+ headers = self.get_headers()
+
+ result = self.callback(request)
+ if isinstance(result, Exception):
+ raise result
+
+ status, r_headers, body = result
+ if isinstance(body, Exception):
+ raise body
+
+ body = _handle_body(body)
+ headers.update(r_headers)
+
+ return HTTPResponse(
+ status=status,
+ reason=six.moves.http_client.responses.get(status),
+ body=body,
+ headers=headers,
+ preload_content=False,
+ )
+
+
+class RequestsMock(object):
+ DELETE = "DELETE"
+ GET = "GET"
+ HEAD = "HEAD"
+ OPTIONS = "OPTIONS"
+ PATCH = "PATCH"
+ POST = "POST"
+ PUT = "PUT"
+ response_callback = None
+
+ def __init__(
+ self,
+ assert_all_requests_are_fired=True,
+ response_callback=None,
+ passthru_prefixes=(),
+ target="requests.adapters.HTTPAdapter.send",
+ ):
+ self._calls = CallList()
+ self.reset()
+ self.assert_all_requests_are_fired = assert_all_requests_are_fired
+ self.response_callback = response_callback
+ self.passthru_prefixes = tuple(passthru_prefixes)
+ self.target = target
+
+ def reset(self):
+ self._matches = []
+ self._calls.reset()
+
+ def add(
+ self,
+ method=None, # method or ``Response``
+ url=None,
+ body="",
+ adding_headers=None,
+ *args,
+ **kwargs
+ ):
+ """
+ A basic request:
+
+ >>> responses.add(responses.GET, 'http://example.com')
+
+ You can also directly pass an object which implements the
+ ``BaseResponse`` interface:
+
+ >>> responses.add(Response(...))
+
+ A JSON payload:
+
+ >>> responses.add(
+ >>> method='GET',
+ >>> url='http://example.com',
+ >>> json={'foo': 'bar'},
+ >>> )
+
+ Custom headers:
+
+ >>> responses.add(
+ >>> method='GET',
+ >>> url='http://example.com',
+ >>> headers={'X-Header': 'foo'},
+ >>> )
+
+
+ Strict query string matching:
+
+ >>> responses.add(
+ >>> method='GET',
+ >>> url='http://example.com?foo=bar',
+ >>> match_querystring=True
+ >>> )
+ """
+ if isinstance(method, BaseResponse):
+ self._matches.append(method)
+ return
+
+ if adding_headers is not None:
+ kwargs.setdefault("headers", adding_headers)
+
+ self._matches.append(Response(method=method, url=url, body=body, **kwargs))
+
+ def add_passthru(self, prefix):
+ """
+ Register a URL prefix to passthru any non-matching mock requests to.
+
+ For example, to allow any request to 'https://example.com', but require
+ mocks for the remainder, you would add the prefix as so:
+
+ >>> responses.add_passthru('https://example.com')
+ """
+ if _has_unicode(prefix):
+ prefix = _clean_unicode(prefix)
+ self.passthru_prefixes += (prefix,)
+
+ def remove(self, method_or_response=None, url=None):
+ """
+ Removes a response previously added using ``add()``, identified
+ either by a response object inheriting ``BaseResponse`` or
+ ``method`` and ``url``. Removes all matching responses.
+
+ >>> response.add(responses.GET, 'http://example.org')
+ >>> response.remove(responses.GET, 'http://example.org')
+ """
+ if isinstance(method_or_response, BaseResponse):
+ response = method_or_response
+ else:
+ response = BaseResponse(method=method_or_response, url=url)
+
+ while response in self._matches:
+ self._matches.remove(response)
+
+ def replace(self, method_or_response=None, url=None, body="", *args, **kwargs):
+ """
+ Replaces a response previously added using ``add()``. The signature
+ is identical to ``add()``. The response is identified using ``method``
+ and ``url``, and the first matching response is replaced.
+
+ >>> responses.add(responses.GET, 'http://example.org', json={'data': 1})
+ >>> responses.replace(responses.GET, 'http://example.org', json={'data': 2})
+ """
+ if isinstance(method_or_response, BaseResponse):
+ response = method_or_response
+ else:
+ response = Response(method=method_or_response, url=url, body=body, **kwargs)
+
+ index = self._matches.index(response)
+ self._matches[index] = response
+
+ def add_callback(
+ self, method, url, callback, match_querystring=False, content_type="text/plain"
+ ):
+ # ensure the url has a default path set if the url is a string
+ # url = _ensure_url_default_path(url, match_querystring)
+
+ self._matches.append(
+ CallbackResponse(
+ url=url,
+ method=method,
+ callback=callback,
+ content_type=content_type,
+ match_querystring=match_querystring,
+ )
+ )
+
+ @property
+ def calls(self):
+ return self._calls
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ success = type is None
+ self.stop(allow_assert=success)
+ self.reset()
+ return success
+
+ def activate(self, func):
+ return get_wrapped(func, self)
+
+ def _find_match(self, request):
+ found = None
+ found_match = None
+ for i, match in enumerate(self._matches):
+ if match.matches(request):
+ if found is None:
+ found = i
+ found_match = match
+ else:
+ # Multiple matches found. Remove & return the first match.
+ return self._matches.pop(found)
+
+ return found_match
+
+ def _on_request(self, adapter, request, **kwargs):
+ match = self._find_match(request)
+ resp_callback = self.response_callback
+
+ if match is None:
+ if request.url.startswith(self.passthru_prefixes):
+ logger.info("request.allowed-passthru", extra={"url": request.url})
+ return _real_send(adapter, request, **kwargs)
+
+ error_msg = (
+ "Connection refused by Responses: {0} {1} doesn't "
+ "match Responses Mock".format(request.method, request.url)
+ )
+ response = ConnectionError(error_msg)
+ response.request = request
+
+ self._calls.add(request, response)
+ response = resp_callback(response) if resp_callback else response
+ raise response
+
+ try:
+ response = adapter.build_response(request, match.get_response(request))
+ except Exception as response:
+ match.call_count += 1
+ self._calls.add(request, response)
+ response = resp_callback(response) if resp_callback else response
+ raise
+
+ if not match.stream:
+ response.content # NOQA
+
+ try:
+ response.cookies = _cookies_from_headers(response.headers)
+ except (KeyError, TypeError):
+ pass
+
+ response = resp_callback(response) if resp_callback else response
+ match.call_count += 1
+ self._calls.add(request, response)
+ return response
+
+ def start(self):
+ def unbound_on_send(adapter, request, *a, **kwargs):
+ return self._on_request(adapter, request, *a, **kwargs)
+
+ self._patcher = std_mock.patch(target=self.target, new=unbound_on_send)
+ self._patcher.start()
+
+ def stop(self, allow_assert=True):
+ self._patcher.stop()
+ if not self.assert_all_requests_are_fired:
+ return
+
+ if not allow_assert:
+ return
+
+ not_called = [m for m in self._matches if m.call_count == 0]
+ if not_called:
+ raise AssertionError(
+ "Not all requests have been executed {0!r}".format(
+ [(match.method, match.url) for match in not_called]
+ )
+ )
+
+
+# expose default mock namespace
+mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False)
+__all__ = ["CallbackResponse", "Response", "RequestsMock"]
+for __attr in (a for a in dir(_default_mock) if not a.startswith("_")):
+ __all__.append(__attr)
+ globals()[__attr] = getattr(_default_mock, __attr)