diff options
Diffstat (limited to 'third_party/python/aiohttp/aiohttp/pytest_plugin.py')
-rw-r--r-- | third_party/python/aiohttp/aiohttp/pytest_plugin.py | 380 |
1 files changed, 380 insertions, 0 deletions
diff --git a/third_party/python/aiohttp/aiohttp/pytest_plugin.py b/third_party/python/aiohttp/aiohttp/pytest_plugin.py new file mode 100644 index 0000000000..5204293410 --- /dev/null +++ b/third_party/python/aiohttp/aiohttp/pytest_plugin.py @@ -0,0 +1,380 @@ +import asyncio +import contextlib +import warnings +from collections.abc import Callable + +import pytest + +from aiohttp.helpers import PY_37, isasyncgenfunction +from aiohttp.web import Application + +from .test_utils import ( + BaseTestServer, + RawTestServer, + TestClient, + TestServer, + loop_context, + setup_test_loop, + teardown_test_loop, + unused_port as _unused_port, +) + +try: + import uvloop +except ImportError: # pragma: no cover + uvloop = None + +try: + import tokio +except ImportError: # pragma: no cover + tokio = None + + +def pytest_addoption(parser): # type: ignore + parser.addoption( + "--aiohttp-fast", + action="store_true", + default=False, + help="run tests faster by disabling extra checks", + ) + parser.addoption( + "--aiohttp-loop", + action="store", + default="pyloop", + help="run tests with specific loop: pyloop, uvloop, tokio or all", + ) + parser.addoption( + "--aiohttp-enable-loop-debug", + action="store_true", + default=False, + help="enable event loop debug mode", + ) + + +def pytest_fixture_setup(fixturedef): # type: ignore + """ + Allow fixtures to be coroutines. Run coroutine fixtures in an event loop. + """ + func = fixturedef.func + + if isasyncgenfunction(func): + # async generator fixture + is_async_gen = True + elif asyncio.iscoroutinefunction(func): + # regular async fixture + is_async_gen = False + else: + # not an async fixture, nothing to do + return + + strip_request = False + if "request" not in fixturedef.argnames: + fixturedef.argnames += ("request",) + strip_request = True + + def wrapper(*args, **kwargs): # type: ignore + request = kwargs["request"] + if strip_request: + del kwargs["request"] + + # if neither the fixture nor the test use the 'loop' fixture, + # 'getfixturevalue' will fail because the test is not parameterized + # (this can be removed someday if 'loop' is no longer parameterized) + if "loop" not in request.fixturenames: + raise Exception( + "Asynchronous fixtures must depend on the 'loop' fixture or " + "be used in tests depending from it." + ) + + _loop = request.getfixturevalue("loop") + + if is_async_gen: + # for async generators, we need to advance the generator once, + # then advance it again in a finalizer + gen = func(*args, **kwargs) + + def finalizer(): # type: ignore + try: + return _loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + pass + + request.addfinalizer(finalizer) + return _loop.run_until_complete(gen.__anext__()) + else: + return _loop.run_until_complete(func(*args, **kwargs)) + + fixturedef.func = wrapper + + +@pytest.fixture +def fast(request): # type: ignore + """--fast config option""" + return request.config.getoption("--aiohttp-fast") + + +@pytest.fixture +def loop_debug(request): # type: ignore + """--enable-loop-debug config option""" + return request.config.getoption("--aiohttp-enable-loop-debug") + + +@contextlib.contextmanager +def _runtime_warning_context(): # type: ignore + """ + Context manager which checks for RuntimeWarnings, specifically to + avoid "coroutine 'X' was never awaited" warnings being missed. + + If RuntimeWarnings occur in the context a RuntimeError is raised. + """ + with warnings.catch_warnings(record=True) as _warnings: + yield + rw = [ + "{w.filename}:{w.lineno}:{w.message}".format(w=w) + for w in _warnings + if w.category == RuntimeWarning + ] + if rw: + raise RuntimeError( + "{} Runtime Warning{},\n{}".format( + len(rw), "" if len(rw) == 1 else "s", "\n".join(rw) + ) + ) + + +@contextlib.contextmanager +def _passthrough_loop_context(loop, fast=False): # type: ignore + """ + setups and tears down a loop unless one is passed in via the loop + argument when it's passed straight through. + """ + if loop: + # loop already exists, pass it straight through + yield loop + else: + # this shadows loop_context's standard behavior + loop = setup_test_loop() + yield loop + teardown_test_loop(loop, fast=fast) + + +def pytest_pycollect_makeitem(collector, name, obj): # type: ignore + """ + Fix pytest collecting for coroutines. + """ + if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj): + return list(collector._genfunctions(name, obj)) + + +def pytest_pyfunc_call(pyfuncitem): # type: ignore + """ + Run coroutines in an event loop instead of a normal function call. + """ + fast = pyfuncitem.config.getoption("--aiohttp-fast") + if asyncio.iscoroutinefunction(pyfuncitem.function): + existing_loop = pyfuncitem.funcargs.get( + "proactor_loop" + ) or pyfuncitem.funcargs.get("loop", None) + with _runtime_warning_context(): + with _passthrough_loop_context(existing_loop, fast=fast) as _loop: + testargs = { + arg: pyfuncitem.funcargs[arg] + for arg in pyfuncitem._fixtureinfo.argnames + } + _loop.run_until_complete(pyfuncitem.obj(**testargs)) + + return True + + +def pytest_generate_tests(metafunc): # type: ignore + if "loop_factory" not in metafunc.fixturenames: + return + + loops = metafunc.config.option.aiohttp_loop + avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy} + + if uvloop is not None: # pragma: no cover + avail_factories["uvloop"] = uvloop.EventLoopPolicy + + if tokio is not None: # pragma: no cover + avail_factories["tokio"] = tokio.EventLoopPolicy + + if loops == "all": + loops = "pyloop,uvloop?,tokio?" + + factories = {} # type: ignore + for name in loops.split(","): + required = not name.endswith("?") + name = name.strip(" ?") + if name not in avail_factories: # pragma: no cover + if required: + raise ValueError( + "Unknown loop '%s', available loops: %s" + % (name, list(factories.keys())) + ) + else: + continue + factories[name] = avail_factories[name] + metafunc.parametrize( + "loop_factory", list(factories.values()), ids=list(factories.keys()) + ) + + +@pytest.fixture +def loop(loop_factory, fast, loop_debug): # type: ignore + """Return an instance of the event loop.""" + policy = loop_factory() + asyncio.set_event_loop_policy(policy) + with loop_context(fast=fast) as _loop: + if loop_debug: + _loop.set_debug(True) # pragma: no cover + asyncio.set_event_loop(_loop) + yield _loop + + +@pytest.fixture +def proactor_loop(): # type: ignore + if not PY_37: + policy = asyncio.get_event_loop_policy() + policy._loop_factory = asyncio.ProactorEventLoop # type: ignore + else: + policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore + asyncio.set_event_loop_policy(policy) + + with loop_context(policy.new_event_loop) as _loop: + asyncio.set_event_loop(_loop) + yield _loop + + +@pytest.fixture +def unused_port(aiohttp_unused_port): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_unused_port fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_unused_port + + +@pytest.fixture +def aiohttp_unused_port(): # type: ignore + """Return a port that is unused on the current host.""" + return _unused_port + + +@pytest.fixture +def aiohttp_server(loop): # type: ignore + """Factory to create a TestServer instance, given an app. + + aiohttp_server(app, **kwargs) + """ + servers = [] + + async def go(app, *, port=None, **kwargs): # type: ignore + server = TestServer(app, port=port) + await server.start_server(loop=loop, **kwargs) + servers.append(server) + return server + + yield go + + async def finalize(): # type: ignore + while servers: + await servers.pop().close() + + loop.run_until_complete(finalize()) + + +@pytest.fixture +def test_server(aiohttp_server): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_server fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_server + + +@pytest.fixture +def aiohttp_raw_server(loop): # type: ignore + """Factory to create a RawTestServer instance, given a web handler. + + aiohttp_raw_server(handler, **kwargs) + """ + servers = [] + + async def go(handler, *, port=None, **kwargs): # type: ignore + server = RawTestServer(handler, port=port) + await server.start_server(loop=loop, **kwargs) + servers.append(server) + return server + + yield go + + async def finalize(): # type: ignore + while servers: + await servers.pop().close() + + loop.run_until_complete(finalize()) + + +@pytest.fixture +def raw_test_server(aiohttp_raw_server): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_raw_server fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_raw_server + + +@pytest.fixture +def aiohttp_client(loop): # type: ignore + """Factory to create a TestClient instance. + + aiohttp_client(app, **kwargs) + aiohttp_client(server, **kwargs) + aiohttp_client(raw_server, **kwargs) + """ + clients = [] + + async def go(__param, *args, server_kwargs=None, **kwargs): # type: ignore + + if isinstance(__param, Callable) and not isinstance( # type: ignore + __param, (Application, BaseTestServer) + ): + __param = __param(loop, *args, **kwargs) + kwargs = {} + else: + assert not args, "args should be empty" + + if isinstance(__param, Application): + server_kwargs = server_kwargs or {} + server = TestServer(__param, loop=loop, **server_kwargs) + client = TestClient(server, loop=loop, **kwargs) + elif isinstance(__param, BaseTestServer): + client = TestClient(__param, loop=loop, **kwargs) + else: + raise ValueError("Unknown argument type: %r" % type(__param)) + + await client.start_server() + clients.append(client) + return client + + yield go + + async def finalize(): # type: ignore + while clients: + await clients.pop().close() + + loop.run_until_complete(finalize()) + + +@pytest.fixture +def test_client(aiohttp_client): # type: ignore # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_client fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_client |