"""Mock out the ``requests`` library by patching ``HTTPAdapter.send``.

Responses are registered on a ``RequestsMock`` (``add``, ``add_callback``,
``replace``, ``remove``). While the mock is active, a request matching a
registration is answered from it instead of the network; unmatched requests
raise ``ConnectionError`` unless their URL starts with a passthru prefix.
A default instance is created at import time and its public attributes are
re-exported as module-level names (e.g. ``responses.add``).
"""
from __future__ import absolute_import, print_function, division, unicode_literals

import _io
import inspect
import json as json_module
import logging
import re
import six

from collections import namedtuple
from functools import update_wrapper
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from requests.sessions import REDIRECT_STATI
from requests.utils import cookiejar_from_dict

try:
    from collections.abc import Sequence, Sized
except ImportError:
    from collections import Sequence, Sized

try:
    from requests.packages.urllib3.response import HTTPResponse
except ImportError:
    from urllib3.response import HTTPResponse

if six.PY2:
    from urlparse import urlparse, parse_qsl, urlsplit, urlunsplit
    from urllib import quote
else:
    from urllib.parse import urlparse, parse_qsl, urlsplit, urlunsplit, quote

if six.PY2:
    try:
        from six import cStringIO as BufferIO
    except ImportError:
        from six import StringIO as BufferIO
else:
    from io import BytesIO as BufferIO

try:
    from unittest import mock as std_mock
except ImportError:
    import mock as std_mock

try:
    Pattern = re._pattern_type
except AttributeError:
    # Python 3.7
    Pattern = re.Pattern

# Sentinel meaning "argument not supplied" for Response(content_type=...).
UNSET = object()

# One recorded (request, response-or-exception) pair.
Call = namedtuple("Call", ["request", "response"])

# The real HTTPAdapter.send, captured before any patching, used for passthru.
_real_send = HTTPAdapter.send

logger = logging.getLogger("responses")


def _is_string(s):
    """Return True if ``s`` is a native/text string (``six.string_types``)."""
    return isinstance(s, six.string_types)


def _has_unicode(s):
    """Return True if ``s`` contains any non-ASCII character.

    ASCII ends at code point 127, so anything above it -- including U+0080
    itself, which the former ``> 128`` test missed -- counts as non-ASCII.
    """
    return any(ord(char) > 127 for char in s)


def _clean_unicode(url):
    """Return ``url`` with its non-ASCII characters made ASCII-safe.

    Host labels containing non-ASCII characters are punycode-encoded and
    prefixed with ``xn--``. Every remaining non-ASCII character is then
    percent-encoded with ``quote``.
    """
    # Clean up domain names, which use punycode to handle unicode chars
    urllist = list(urlsplit(url))
    netloc = urllist[1]
    if _has_unicode(netloc):
        domains = netloc.split(".")
        for i, d in enumerate(domains):
            if _has_unicode(d):
                domains[i] = "xn--" + d.encode("punycode").decode("ascii")
        urllist[1] = ".".join(domains)
        url = urlunsplit(urllist)

    # Clean up path/query/params, which use url-encoding to handle unicode chars
    # When encoding yields a native string type (Python 2, where str is bytes),
    # work on the UTF-8 bytes so each byte is quoted individually.
    if isinstance(url.encode("utf8"), six.string_types):
        url = url.encode("utf8")
    chars = list(url)
    for i, x in enumerate(chars):
        # > 127 (not > 128): a 0x80 byte / U+0080 must be quoted as well,
        # otherwise it survives raw into the joined result.
        if ord(x) > 127:
            chars[i] = quote(x)

    return "".join(chars)


def _is_redirect(response):
    """Return True if ``response`` is a redirect."""
    try:
        # Prefer the ``is_redirect`` property when this requests version has it.
        return response.is_redirect
    except AttributeError:
        # Otherwise fall back to the condition requests.sessions uses.
        return (
            # use request.sessions conditional
            response.status_code in REDIRECT_STATI
            and "location" in response.headers
        )


def _cookies_from_headers(headers):
    """Build a cookie jar from the ``set-cookie`` entry of ``headers``.

    Raises KeyError when ``headers`` has no ``set-cookie`` entry.
    """
    try:
        import http.cookies as cookies

        resp_cookie = cookies.SimpleCookie()
        resp_cookie.load(headers["set-cookie"])

        cookies_dict = {name: v.value for name, v in resp_cookie.items()}
    except ImportError:
        # Python 2: fall back to the third-party ``cookies`` package.
        from cookies import Cookies

        resp_cookies = Cookies.from_request(headers["set-cookie"])
        cookies_dict = {v.name: v.value for _, v in resp_cookies.items()}
    return cookiejar_from_dict(cookies_dict)


_wrapper_template = """\
def wrapper%(wrapper_args)s:
    with responses:
        return func%(func_args)s
"""


def _format_call_args(signature):
    """Render ``signature`` as a call argument list forwarding each parameter.

    For example ``(a, b=1, *args, c, **kw)`` becomes ``(a, b, *args, c=c, **kw)``.
    ``str(signature)`` cannot be reused for this. It emits the bare ``*`` and
    ``/`` markers, which are a SyntaxError inside a call, and it would pass
    keyword-only parameters positionally.
    """
    parts = []
    for param in signature.parameters.values():
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            parts.append("*" + param.name)
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            parts.append("**" + param.name)
        elif param.kind == inspect.Parameter.KEYWORD_ONLY:
            parts.append("{0}={0}".format(param.name))
        else:
            parts.append(param.name)
    return "(" + ", ".join(parts) + ")"


def get_wrapped(func, responses):
    """Return a wrapper for ``func`` that runs each call inside ``with responses``.

    The wrapper is generated from source so that its signature mirrors
    ``func``'s. ``update_wrapper`` then copies ``func``'s metadata onto it.
    """
    if six.PY2:
        args, a, kw, defaults = inspect.getargspec(func)
        wrapper_args = inspect.formatargspec(args, a, kw, defaults)

        # Preserve the argspec for the wrapped function so that testing
        # tools such as pytest can continue to use their fixture injection.
        if hasattr(func, "__self__"):
            args = args[1:]  # Omit 'self'
        func_args = inspect.formatargspec(args, a, kw, None)
    else:
        signature = inspect.signature(func)
        signature = signature.replace(return_annotation=inspect.Signature.empty)
        # If the function is wrapped, switch to *args, **kwargs for the parameters
        # as we can't rely on the signature to give us the arguments the function will
        # be called with. For example unittest.mock.patch uses required args that are
        # not actually passed to the function when invoked.
        if hasattr(func, "__wrapped__"):
            wrapper_params = [
                inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
                inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD),
            ]
        else:
            # Annotations are dropped so the generated source does not need
            # the annotation objects to be in scope.
            wrapper_params = [
                param.replace(annotation=inspect.Parameter.empty)
                for param in signature.parameters.values()
            ]
        signature = signature.replace(parameters=wrapper_params)

        wrapper_args = str(signature)
        func_args = _format_call_args(signature)

    evaldict = {"func": func, "responses": responses}
    six.exec_(
        _wrapper_template % {"wrapper_args": wrapper_args, "func_args": func_args},
        evaldict,
    )
    wrapper = evaldict["wrapper"]
    update_wrapper(wrapper, func)
    return wrapper


class CallList(Sequence, Sized):
    """Read-only sequence of recorded ``Call(request, response)`` pairs."""

    def __init__(self):
        self._calls = []

    def __iter__(self):
        return iter(self._calls)

    def __len__(self):
        return len(self._calls)

    def __getitem__(self, idx):
        return self._calls[idx]

    def add(self, request, response):
        """Record one request together with its response (or exception)."""
        self._calls.append(Call(request, response))

    def reset(self):
        """Forget all recorded calls."""
        self._calls = []


def _ensure_url_default_path(url):
    """Give string URLs with an empty path the path ``/``.

    Non-string URLs (e.g. compiled regex patterns) are returned unchanged.
    """
    if _is_string(url):
        url_parts = list(urlsplit(url))
        if url_parts[2] == "":
            url_parts[2] = "/"
        url = urlunsplit(url_parts)
    return url


def _handle_body(body):
    """Wrap ``body`` in a file-like object suitable for ``HTTPResponse``.

    Text is UTF-8 encoded first. An ``_io.BufferedReader`` (an open binary
    file) is returned unchanged.
    """
    if isinstance(body, six.text_type):
        body = body.encode("utf-8")
    if isinstance(body, _io.BufferedReader):
        return body

    return BufferIO(body)


# Sentinel: "match_querystring was not passed", so it is derived from the URL.
_unspecified = object()


class BaseResponse(object):
    """A registered mock response matched by HTTP method and URL.

    ``url`` may be a string or a compiled regex. Subclasses implement
    ``get_response``.
    """

    content_type = None
    headers = None

    stream = False

    def __init__(self, method, url, match_querystring=_unspecified):
        self.method = method
        # ensure the url has a default path set if the url is a string
        self.url = _ensure_url_default_path(url)
        self.match_querystring = self._should_match_querystring(match_querystring)
        self.call_count = 0

    def __eq__(self, other):
        if not isinstance(other, BaseResponse):
            return False

        if self.method != other.method:
            return False

        # Can't simply do a equality check on the objects directly here since __eq__ isn't
        # implemented for regex. It might seem to work as regex is using a cache to return
        # the same regex instances, but it doesn't in all cases.
        self_url = self.url.pattern if isinstance(self.url, Pattern) else self.url
        other_url = other.url.pattern if isinstance(other.url, Pattern) else other.url

        return self_url == other_url

    def __ne__(self, other):
        return not self.__eq__(other)

    def _url_matches_strict(self, url, other):
        """Compare scheme, netloc, path and the (order-insensitive) query pairs."""
        url_parsed = urlparse(url)
        other_parsed = urlparse(other)

        if url_parsed[:3] != other_parsed[:3]:
            return False

        # Sorted (key, value) lists compare equal exactly when they have the
        # same length and the same pairs position by position.
        url_qsl = sorted(parse_qsl(url_parsed.query))
        other_qsl = sorted(parse_qsl(other_parsed.query))
        return url_qsl == other_qsl

    def _should_match_querystring(self, match_querystring_argument):
        """Resolve whether query strings take part in matching.

        An explicit argument wins. Otherwise regex URLs default to False, and
        string URLs default to True only when they carry a query string.
        """
        if match_querystring_argument is not _unspecified:
            return match_querystring_argument

        if isinstance(self.url, Pattern):
            # the old default from <= 0.9.0
            return False

        return bool(urlparse(self.url).query)

    def _url_matches(self, url, other, match_querystring=False):
        """Return True if the registered ``url`` matches the request URL ``other``."""
        if _is_string(url):
            if _has_unicode(url):
                url = _clean_unicode(url)
                if not isinstance(other, six.text_type):
                    other = other.encode("ascii").decode("utf8")
            if match_querystring:
                return self._url_matches_strict(url, other)
            else:
                url_without_qs = url.split("?", 1)[0]
                other_without_qs = other.split("?", 1)[0]
                return url_without_qs == other_without_qs
        elif isinstance(url, Pattern) and url.match(other):
            return True
        else:
            return False

    def get_headers(self):
        """Return response headers: Content-Type (if set) overlaid by ``headers``."""
        headers = {}
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.headers:
            headers.update(self.headers)
        return headers

    def get_response(self, request):
        raise NotImplementedError

    def matches(self, request):
        """Return True if ``request`` has this method and a matching URL."""
        if request.method != self.method:
            return False

        if not self._url_matches(self.url, request.url, self.match_querystring):
            return False

        return True


class Response(BaseResponse):
    """A static mock response.

    ``json``, if given, is serialized into the body. ``content_type`` then
    defaults to ``application/json``; without ``json`` it defaults to
    ``text/plain``. ``body`` may be an Exception instance, which
    ``get_response`` raises instead of returning a response.
    """

    def __init__(
        self,
        method,
        url,
        body="",
        json=None,
        status=200,
        headers=None,
        stream=False,
        content_type=UNSET,
        **kwargs
    ):
        # if we were passed a `json` argument,
        # override the body and content_type
        if json is not None:
            # NOTE(review): assert is stripped under -O; kept because callers
            # may rely on AssertionError here.
            assert not body
            body = json_module.dumps(json)
            if content_type is UNSET:
                content_type = "application/json"

        if content_type is UNSET:
            content_type = "text/plain"

        # body must be bytes
        if isinstance(body, six.text_type):
            body = body.encode("utf-8")

        self.body = body
        self.status = status
        self.headers = headers
        self.stream = stream
        self.content_type = content_type
        super(Response, self).__init__(method, url, **kwargs)

    def get_response(self, request):
        if self.body and isinstance(self.body, Exception):
            raise self.body

        headers = self.get_headers()
        status = self.status
        body = _handle_body(self.body)

        return HTTPResponse(
            status=status,
            reason=six.moves.http_client.responses.get(status),
            body=body,
            headers=headers,
            preload_content=False,
        )


class CallbackResponse(BaseResponse):
    """A mock response computed per request by ``callback(request)``.

    The callback returns ``(status, headers, body)`` or an Exception to raise.
    ``body`` may itself be an Exception, and ``headers`` may be None or empty.
    """

    def __init__(
        self, method, url, callback, stream=False, content_type="text/plain", **kwargs
    ):
        self.callback = callback
        self.stream = stream
        self.content_type = content_type
        super(CallbackResponse, self).__init__(method, url, **kwargs)

    def get_response(self, request):
        headers = self.get_headers()

        result = self.callback(request)
        if isinstance(result, Exception):
            raise result

        status, r_headers, body = result
        if isinstance(body, Exception):
            raise body

        body = _handle_body(body)
        # Tolerate callbacks that return None for headers; empty values were
        # already a no-op for update().
        if r_headers:
            headers.update(r_headers)

        return HTTPResponse(
            status=status,
            reason=six.moves.http_client.responses.get(status),
            body=body,
            headers=headers,
            preload_content=False,
        )


class RequestsMock(object):
    """Registry of mock responses plus the patch that routes requests to it."""

    DELETE = "DELETE"
    GET = "GET"
    HEAD = "HEAD"
    OPTIONS = "OPTIONS"
    PATCH = "PATCH"
    POST = "POST"
    PUT = "PUT"
    response_callback = None

    def __init__(
        self,
        assert_all_requests_are_fired=True,
        response_callback=None,
        passthru_prefixes=(),
        target="requests.adapters.HTTPAdapter.send",
    ):
        self._calls = CallList()
        self.reset()
        self.assert_all_requests_are_fired = assert_all_requests_are_fired
        self.response_callback = response_callback
        self.passthru_prefixes = tuple(passthru_prefixes)
        self.target = target

    def reset(self):
        """Drop all registered responses and recorded calls."""
        self._matches = []
        self._calls.reset()

    def add(
        self,
        method=None,  # method or ``Response``
        url=None,
        body="",
        adding_headers=None,
        *args,
        **kwargs
    ):
        """
        A basic request:

        >>> responses.add(responses.GET, 'http://example.com')

        You can also directly pass an object which implements the
        ``BaseResponse`` interface:

        >>> responses.add(Response(...))

        A JSON payload:

        >>> responses.add(
        >>>     method='GET',
        >>>     url='http://example.com',
        >>>     json={'foo': 'bar'},
        >>> )

        Custom headers:

        >>> responses.add(
        >>>     method='GET',
        >>>     url='http://example.com',
        >>>     headers={'X-Header': 'foo'},
        >>> )


        Strict query string matching:

        >>> responses.add(
        >>>     method='GET',
        >>>     url='http://example.com?foo=bar',
        >>>     match_querystring=True
        >>> )
        """
        if isinstance(method, BaseResponse):
            self._matches.append(method)
            return

        if adding_headers is not None:
            kwargs.setdefault("headers", adding_headers)

        self._matches.append(Response(method=method, url=url, body=body, **kwargs))

    def add_passthru(self, prefix):
        """
        Register a URL prefix to passthru any non-matching mock requests to.

        For example, to allow any request to 'https://example.com', but require
        mocks for the remainder, you would add the prefix as so:

        >>> responses.add_passthru('https://example.com')
        """
        if _has_unicode(prefix):
            prefix = _clean_unicode(prefix)
        self.passthru_prefixes += (prefix,)

    def remove(self, method_or_response=None, url=None):
        """
        Removes a response previously added using ``add()``, identified
        either by a response object inheriting ``BaseResponse`` or
        ``method`` and ``url``. Removes all matching responses.

        >>> responses.add(responses.GET, 'http://example.org')
        >>> responses.remove(responses.GET, 'http://example.org')
        """
        if isinstance(method_or_response, BaseResponse):
            response = method_or_response
        else:
            response = BaseResponse(method=method_or_response, url=url)

        while response in self._matches:
            self._matches.remove(response)

    def replace(self, method_or_response=None, url=None, body="", *args, **kwargs):
        """
        Replaces a response previously added using ``add()``. The signature
        is identical to ``add()``. The response is identified using ``method``
        and ``url``, and the first matching response is replaced.

        Raises ValueError if no registered response matches.

        >>> responses.add(responses.GET, 'http://example.org', json={'data': 1})
        >>> responses.replace(responses.GET, 'http://example.org', json={'data': 2})
        """
        if isinstance(method_or_response, BaseResponse):
            response = method_or_response
        else:
            response = Response(method=method_or_response, url=url, body=body, **kwargs)

        index = self._matches.index(response)
        self._matches[index] = response

    def add_callback(
        self, method, url, callback, match_querystring=False, content_type="text/plain"
    ):
        """Register a ``CallbackResponse`` computed by ``callback(request)``."""
        self._matches.append(
            CallbackResponse(
                url=url,
                method=method,
                callback=callback,
                content_type=content_type,
                match_querystring=match_querystring,
            )
        )

    @property
    def calls(self):
        """The ``CallList`` of requests seen while active."""
        return self._calls

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        success = type is None
        try:
            self.stop(allow_assert=success)
        finally:
            # Reset even when stop() raises its "not all requests" assertion,
            # so stale registrations do not leak into the next use.
            self.reset()
        return success

    def activate(self, func):
        """Decorator: run ``func`` with this mock active (see ``get_wrapped``)."""
        return get_wrapped(func, self)

    def _find_match(self, request):
        """Return the registered response matching ``request``, or None.

        When several registrations match, the first one is popped from the
        registry, so a repeated identical request reaches the next one.
        """
        found = None
        found_match = None
        for i, match in enumerate(self._matches):
            if match.matches(request):
                if found is None:
                    found = i
                    found_match = match
                else:
                    # Multiple matches found.  Remove & return the first match.
                    return self._matches.pop(found)

        return found_match

    def _on_request(self, adapter, request, **kwargs):
        """Stand-in for ``HTTPAdapter.send`` while the mock is active."""
        match = self._find_match(request)
        resp_callback = self.response_callback

        if match is None:
            if request.url.startswith(self.passthru_prefixes):
                logger.info("request.allowed-passthru", extra={"url": request.url})
                return _real_send(adapter, request, **kwargs)

            error_msg = (
                "Connection refused by Responses: {0} {1} doesn't "
                "match Responses Mock".format(request.method, request.url)
            )
            response = ConnectionError(error_msg)
            response.request = request

            self._calls.add(request, response)
            response = resp_callback(response) if resp_callback else response
            raise response

        try:
            response = adapter.build_response(request, match.get_response(request))
        except Exception as exc:
            match.call_count += 1
            self._calls.add(request, exc)
            if resp_callback:
                # The callback sees the exception, but its return value is
                # discarded: the original exception is re-raised below.
                resp_callback(exc)
            raise

        if not match.stream:
            # Touch .content so the body is read now rather than lazily.
            response.content  # NOQA

        try:
            response.cookies = _cookies_from_headers(response.headers)
        except (KeyError, TypeError):
            # No (usable) set-cookie header: leave the cookies as built.
            pass

        response = resp_callback(response) if resp_callback else response
        match.call_count += 1
        self._calls.add(request, response)
        return response

    def start(self):
        """Patch ``self.target`` so requests are routed to ``_on_request``."""

        def unbound_on_send(adapter, request, *a, **kwargs):
            return self._on_request(adapter, request, *a, **kwargs)

        self._patcher = std_mock.patch(target=self.target, new=unbound_on_send)
        self._patcher.start()

    def stop(self, allow_assert=True):
        """Undo the patch and, if enabled, assert every registration was used.

        Raises AssertionError listing ``(method, url)`` of unused responses
        when ``assert_all_requests_are_fired`` and ``allow_assert`` are true.
        """
        self._patcher.stop()
        if not self.assert_all_requests_are_fired:
            return
        if not allow_assert:
            return

        not_called = [m for m in self._matches if m.call_count == 0]
        if not_called:
            raise AssertionError(
                "Not all requests have been executed {0!r}".format(
                    [(match.method, match.url) for match in not_called]
                )
            )


# expose default mock namespace
mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False)
__all__ = ["CallbackResponse", "Response", "RequestsMock"]
for __attr in (a for a in dir(_default_mock) if not a.startswith("_")):
    __all__.append(__attr)
    globals()[__attr] = getattr(_default_mock, __attr)