diff options
Diffstat (limited to 'tests/grpc')
-rw-r--r-- | tests/grpc/__init__.py | 0 | ||||
-rw-r--r-- | tests/grpc/test_grpclib_client.py | 298 | ||||
-rw-r--r-- | tests/grpc/test_stream_stream.py | 99 | ||||
-rw-r--r-- | tests/grpc/thing_service.py | 85 |
4 files changed, 482 insertions, 0 deletions
diff --git a/tests/grpc/__init__.py b/tests/grpc/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/grpc/__init__.py diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py new file mode 100644 index 0000000..d36e4a5 --- /dev/null +++ b/tests/grpc/test_grpclib_client.py @@ -0,0 +1,298 @@ +import asyncio +import sys +import uuid + +import grpclib +import grpclib.client +import grpclib.metadata +import grpclib.server +import pytest +from grpclib.testing import ChannelFor + +from aristaproto.grpc.util.async_channel import AsyncChannel +from tests.output_aristaproto.service import ( + DoThingRequest, + DoThingResponse, + GetThingRequest, + TestStub as ThingServiceClient, +) + +from .thing_service import ThingService + + +async def _test_client(client: ThingServiceClient, name="clean room", **kwargs): + response = await client.do_thing(DoThingRequest(name=name), **kwargs) + assert response.names == [name] + + +def _assert_request_meta_received(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be received serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be received serverside" + + return server_side_test + + +@pytest.fixture +def handler_trailer_only_unauthenticated(): + async def handler(stream: grpclib.server.Stream): + await stream.recv_message() + await stream.send_initial_metadata() + await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED) + + return handler + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_client(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_trailer_only_error_unary_unary( + mocker, handler_trailer_only_unauthenticated +): + service = ThingService() + mocker.patch.object( + service, + "do_thing", + side_effect=handler_trailer_only_unauthenticated, + autospec=True, + ) + async with ChannelFor([service]) as channel: + with pytest.raises(grpclib.exceptions.GRPCError) as e: + await ThingServiceClient(channel).do_thing(DoThingRequest(name="something")) + assert e.value.status == grpclib.Status.UNAUTHENTICATED + + +@pytest.mark.asyncio +async def test_trailer_only_error_stream_unary( + mocker, handler_trailer_only_unauthenticated +): + service = ThingService() + mocker.patch.object( + service, + "do_many_things", + side_effect=handler_trailer_only_unauthenticated, + autospec=True, + ) + async with ChannelFor([service]) as channel: + with pytest.raises(grpclib.exceptions.GRPCError) as e: + await ThingServiceClient(channel).do_many_things( + do_thing_request_iterator=[DoThingRequest(name="something")] + ) + await _test_client(ThingServiceClient(channel)) + assert e.value.status == grpclib.Status.UNAUTHENTICATED + + +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="async mock spy does works for python3.8+" +) +async def test_service_call_mutable_defaults(mocker): + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + spy = mocker.spy(client, "_unary_unary") + await _test_client(client) + comments = spy.call_args_list[-1].args[1].comments + await _test_client(client) + assert spy.call_args_list[-1].args[1].comments is not comments + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] + ) as channel: + await _test_client( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] + ) as channel: + await _test_client( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata), + ) + ] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("overrides_gen",), + [ + (lambda: dict(timeout=10),), + (lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),), + (lambda: dict(metadata={"authorization": str(uuid.uuid4())}),), + (lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),), + ], +) +async def test_service_call_high_level_with_overrides(mocker, overrides_gen): + overrides = overrides_gen() + request_spy = mocker.spy(grpclib.client.Channel, "request") + name = str(uuid.uuid4()) + defaults = dict( + timeout=99, + deadline=grpclib.metadata.Deadline.from_timeout(99), + metadata={"authorization": name}, + ) + + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_received( + deadline=grpclib.metadata.Deadline.from_timeout( + overrides.get("timeout", 99) + ), + metadata=overrides.get("metadata", defaults.get("metadata")), + ) + ) + ] + ) as channel: + client = ThingServiceClient(channel, **defaults) + await _test_client(client, name=name, **overrides) + assert request_spy.call_count == 1 + + # for python <3.8 request_spy.call_args.kwargs do not work + _, request_spy_call_kwargs = request_spy.call_args_list[0] + + # ensure all overrides were successful + for key, value in overrides.items(): + assert key in request_spy_call_kwargs + assert request_spy_call_kwargs[key] == value + + # ensure default values were retained + for key in set(defaults.keys()) - set(overrides.keys()): + assert key in request_spy_call_kwargs + assert request_spy_call_kwargs[key] == defaults[key] + + +@pytest.mark.asyncio +async def test_async_gen_for_unary_stream_request(): + thing_name = "my milkshakes" + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + expected_versions = [5, 4, 3, 2, 1] + async for response in client.get_thing_versions( + GetThingRequest(name=thing_name) + ): + assert response.name == thing_name + assert response.version == expected_versions.pop() + + +@pytest.mark.asyncio +async def test_async_gen_for_stream_stream_request(): + some_things = ["cake", "cricket", "coral reef"] + more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"] + expected_things = (*some_things, *more_things) + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + # Use an AsyncChannel to decouple sending and recieving, it'll send some_things + # immediately and we'll use it to send more_things later, after recieving some + # results + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) + response_index = 0 + async for response in client.get_different_things(request_chan): + assert response.name == expected_things[response_index] + assert response.version == response_index + 1 + response_index += 1 + if more_things: + # Send some more requests as we receive responses to be sure coordination of + # send/receive events doesn't matter + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests + else: + # No more things to send make sure channel is closed + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't receive all expected responses" + + +@pytest.mark.asyncio +async def test_stream_unary_with_empty_iterable(): + things = [] # empty + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + requests = [DoThingRequest(name) for name in things] + response = await client.do_many_things(requests) + assert len(response.names) == 0 + + +@pytest.mark.asyncio +async def test_stream_stream_with_empty_iterable(): + things = [] # empty + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + requests = [GetThingRequest(name) for name in things] + responses = [ + response async for response in client.get_different_things(requests) + ] + assert len(responses) == 0 diff --git a/tests/grpc/test_stream_stream.py b/tests/grpc/test_stream_stream.py new file mode 100644 index 0000000..d4b27e5 --- /dev/null +++ b/tests/grpc/test_stream_stream.py @@ -0,0 +1,99 @@ +import asyncio +from dataclasses import dataclass +from typing import AsyncIterator + +import pytest + +import aristaproto +from aristaproto.grpc.util.async_channel import AsyncChannel + + +@dataclass +class Message(aristaproto.Message): + body: str = aristaproto.string_field(1) + + +@pytest.fixture +def expected_responses(): + return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] + + +class ClientStub: + async def connect(self, requests: AsyncIterator): + await asyncio.sleep(0.1) + async for request in requests: + await asyncio.sleep(0.1) + yield request + await asyncio.sleep(0.1) + yield Message("Done") + + +async def to_list(generator: AsyncIterator): + return [value async for value in generator] + + +@pytest.fixture +def client(): + # channel = Channel(host='127.0.0.1', port=50051) + # return ClientStub(channel) + return ClientStub() + + +@pytest.mark.asyncio +async def test_send_from_before_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_after_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + responses = client.connect(requests) + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_close_manually_immediately(client, expected_responses): + requests = AsyncChannel() + responses = client.connect(requests) + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + requests.close() + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + responses = client.connect(requests) + requests.close() + + assert await to_list(responses) == expected_responses diff --git a/tests/grpc/thing_service.py b/tests/grpc/thing_service.py new file mode 100644 index 0000000..5b00cbe --- /dev/null +++ b/tests/grpc/thing_service.py @@ -0,0 +1,85 @@ +from typing import Dict + +import grpclib +import grpclib.server + +from tests.output_aristaproto.service import ( + DoThingRequest, + DoThingResponse, + GetThingRequest, + GetThingResponse, +) + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def do_thing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def do_many_things( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name async for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def get_thing_versions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request.name, version=version_num) + ) + + async def get_different_things( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Respond to each input item immediately + response_num = 0 + async for request in stream: + response_num += 1 + await stream.send_message( + GetThingResponse(name=request.name, version=response_num) + ) + + def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.do_thing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.do_many_things, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.get_thing_versions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.get_different_things, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } |