summaryrefslogtreecommitdiffstats
path: root/tests/grpc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/grpc')
-rw-r--r--tests/grpc/__init__.py0
-rw-r--r--tests/grpc/test_grpclib_client.py298
-rw-r--r--tests/grpc/test_stream_stream.py99
-rw-r--r--tests/grpc/thing_service.py85
4 files changed, 482 insertions, 0 deletions
diff --git a/tests/grpc/__init__.py b/tests/grpc/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/grpc/__init__.py
diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py
new file mode 100644
index 0000000..d36e4a5
--- /dev/null
+++ b/tests/grpc/test_grpclib_client.py
@@ -0,0 +1,298 @@
+import asyncio
+import sys
+import uuid
+
+import grpclib
+import grpclib.client
+import grpclib.metadata
+import grpclib.server
+import pytest
+from grpclib.testing import ChannelFor
+
+from aristaproto.grpc.util.async_channel import AsyncChannel
+from tests.output_aristaproto.service import (
+ DoThingRequest,
+ DoThingResponse,
+ GetThingRequest,
+ TestStub as ThingServiceClient,
+)
+
+from .thing_service import ThingService
+
+
+async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
+ response = await client.do_thing(DoThingRequest(name=name), **kwargs)
+ assert response.names == [name]
+
+
+def _assert_request_meta_received(deadline, metadata):
+ def server_side_test(stream):
+ assert stream.deadline._timestamp == pytest.approx(
+ deadline._timestamp, 1
+ ), "The provided deadline should be received serverside"
+ assert (
+ stream.metadata["authorization"] == metadata["authorization"]
+ ), "The provided authorization metadata should be received serverside"
+
+ return server_side_test
+
+
+@pytest.fixture
+def handler_trailer_only_unauthenticated():
+ async def handler(stream: grpclib.server.Stream):
+ await stream.recv_message()
+ await stream.send_initial_metadata()
+ await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
+
+ return handler
+
+
+@pytest.mark.asyncio
+async def test_simple_service_call():
+ async with ChannelFor([ThingService()]) as channel:
+ await _test_client(ThingServiceClient(channel))
+
+
+@pytest.mark.asyncio
+async def test_trailer_only_error_unary_unary(
+ mocker, handler_trailer_only_unauthenticated
+):
+ service = ThingService()
+ mocker.patch.object(
+ service,
+ "do_thing",
+ side_effect=handler_trailer_only_unauthenticated,
+ autospec=True,
+ )
+ async with ChannelFor([service]) as channel:
+ with pytest.raises(grpclib.exceptions.GRPCError) as e:
+ await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
+ assert e.value.status == grpclib.Status.UNAUTHENTICATED
+
+
+@pytest.mark.asyncio
+async def test_trailer_only_error_stream_unary(
+ mocker, handler_trailer_only_unauthenticated
+):
+ service = ThingService()
+ mocker.patch.object(
+ service,
+ "do_many_things",
+ side_effect=handler_trailer_only_unauthenticated,
+ autospec=True,
+ )
+ async with ChannelFor([service]) as channel:
+ with pytest.raises(grpclib.exceptions.GRPCError) as e:
+ await ThingServiceClient(channel).do_many_things(
+ do_thing_request_iterator=[DoThingRequest(name="something")]
+ )
+ await _test_client(ThingServiceClient(channel))
+ assert e.value.status == grpclib.Status.UNAUTHENTICATED
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
+)
+async def test_service_call_mutable_defaults(mocker):
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ spy = mocker.spy(client, "_unary_unary")
+ await _test_client(client)
+ comments = spy.call_args_list[-1].args[1].comments
+ await _test_client(client)
+ assert spy.call_args_list[-1].args[1].comments is not comments
+
+
+@pytest.mark.asyncio
+async def test_service_call_with_upfront_request_params():
+ # Setting deadline
+ deadline = grpclib.metadata.Deadline.from_timeout(22)
+ metadata = {"authorization": "12345"}
+ async with ChannelFor(
+ [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
+ ) as channel:
+ await _test_client(
+ ThingServiceClient(channel, deadline=deadline, metadata=metadata)
+ )
+
+ # Setting timeout
+ timeout = 99
+ deadline = grpclib.metadata.Deadline.from_timeout(timeout)
+ metadata = {"authorization": "12345"}
+ async with ChannelFor(
+ [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
+ ) as channel:
+ await _test_client(
+ ThingServiceClient(channel, timeout=timeout, metadata=metadata)
+ )
+
+
+@pytest.mark.asyncio
+async def test_service_call_lower_level_with_overrides():
+ THING_TO_DO = "get milk"
+
+ # Setting deadline
+ deadline = grpclib.metadata.Deadline.from_timeout(22)
+ metadata = {"authorization": "12345"}
+ kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
+ kwarg_metadata = {"authorization": "12345"}
+ async with ChannelFor(
+ [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
+ ) as channel:
+ client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
+ response = await client._unary_unary(
+ "/service.Test/DoThing",
+ DoThingRequest(THING_TO_DO),
+ DoThingResponse,
+ deadline=kwarg_deadline,
+ metadata=kwarg_metadata,
+ )
+ assert response.names == [THING_TO_DO]
+
+ # Setting timeout
+ timeout = 99
+ deadline = grpclib.metadata.Deadline.from_timeout(timeout)
+ metadata = {"authorization": "12345"}
+ kwarg_timeout = 9000
+ kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
+ kwarg_metadata = {"authorization": "09876"}
+ async with ChannelFor(
+ [
+ ThingService(
+ test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
+ )
+ ]
+ ) as channel:
+ client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
+ response = await client._unary_unary(
+ "/service.Test/DoThing",
+ DoThingRequest(THING_TO_DO),
+ DoThingResponse,
+ timeout=kwarg_timeout,
+ metadata=kwarg_metadata,
+ )
+ assert response.names == [THING_TO_DO]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("overrides_gen",),
+ [
+ (lambda: dict(timeout=10),),
+ (lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
+ (lambda: dict(metadata={"authorization": str(uuid.uuid4())}),),
+ (lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
+ ],
+)
+async def test_service_call_high_level_with_overrides(mocker, overrides_gen):
+ overrides = overrides_gen()
+ request_spy = mocker.spy(grpclib.client.Channel, "request")
+ name = str(uuid.uuid4())
+ defaults = dict(
+ timeout=99,
+ deadline=grpclib.metadata.Deadline.from_timeout(99),
+ metadata={"authorization": name},
+ )
+
+ async with ChannelFor(
+ [
+ ThingService(
+ test_hook=_assert_request_meta_received(
+ deadline=grpclib.metadata.Deadline.from_timeout(
+ overrides.get("timeout", 99)
+ ),
+ metadata=overrides.get("metadata", defaults.get("metadata")),
+ )
+ )
+ ]
+ ) as channel:
+ client = ThingServiceClient(channel, **defaults)
+ await _test_client(client, name=name, **overrides)
+ assert request_spy.call_count == 1
+
+ # for python <3.8 request_spy.call_args.kwargs do not work
+ _, request_spy_call_kwargs = request_spy.call_args_list[0]
+
+ # ensure all overrides were successful
+ for key, value in overrides.items():
+ assert key in request_spy_call_kwargs
+ assert request_spy_call_kwargs[key] == value
+
+ # ensure default values were retained
+ for key in set(defaults.keys()) - set(overrides.keys()):
+ assert key in request_spy_call_kwargs
+ assert request_spy_call_kwargs[key] == defaults[key]
+
+
+@pytest.mark.asyncio
+async def test_async_gen_for_unary_stream_request():
+ thing_name = "my milkshakes"
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ expected_versions = [5, 4, 3, 2, 1]
+ async for response in client.get_thing_versions(
+ GetThingRequest(name=thing_name)
+ ):
+ assert response.name == thing_name
+ assert response.version == expected_versions.pop()
+
+
+@pytest.mark.asyncio
+async def test_async_gen_for_stream_stream_request():
+ some_things = ["cake", "cricket", "coral reef"]
+ more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
+ expected_things = (*some_things, *more_things)
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ # Use an AsyncChannel to decouple sending and recieving, it'll send some_things
+ # immediately and we'll use it to send more_things later, after recieving some
+ # results
+ request_chan = AsyncChannel()
+ send_initial_requests = asyncio.ensure_future(
+ request_chan.send_from(GetThingRequest(name) for name in some_things)
+ )
+ response_index = 0
+ async for response in client.get_different_things(request_chan):
+ assert response.name == expected_things[response_index]
+ assert response.version == response_index + 1
+ response_index += 1
+ if more_things:
+ # Send some more requests as we receive responses to be sure coordination of
+ # send/receive events doesn't matter
+ await request_chan.send(GetThingRequest(more_things.pop(0)))
+ elif not send_initial_requests.done():
+ # Make sure the sending task it completed
+ await send_initial_requests
+ else:
+ # No more things to send make sure channel is closed
+ request_chan.close()
+ assert response_index == len(
+ expected_things
+ ), "Didn't receive all expected responses"
+
+
+@pytest.mark.asyncio
+async def test_stream_unary_with_empty_iterable():
+ things = [] # empty
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ requests = [DoThingRequest(name) for name in things]
+ response = await client.do_many_things(requests)
+ assert len(response.names) == 0
+
+
+@pytest.mark.asyncio
+async def test_stream_stream_with_empty_iterable():
+ things = [] # empty
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ requests = [GetThingRequest(name) for name in things]
+ responses = [
+ response async for response in client.get_different_things(requests)
+ ]
+ assert len(responses) == 0
diff --git a/tests/grpc/test_stream_stream.py b/tests/grpc/test_stream_stream.py
new file mode 100644
index 0000000..d4b27e5
--- /dev/null
+++ b/tests/grpc/test_stream_stream.py
@@ -0,0 +1,99 @@
+import asyncio
+from dataclasses import dataclass
+from typing import AsyncIterator
+
+import pytest
+
+import aristaproto
+from aristaproto.grpc.util.async_channel import AsyncChannel
+
+
+@dataclass
+class Message(aristaproto.Message):
+ body: str = aristaproto.string_field(1)
+
+
+@pytest.fixture
+def expected_responses():
+ return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
+
+
+class ClientStub:
+ async def connect(self, requests: AsyncIterator):
+ await asyncio.sleep(0.1)
+ async for request in requests:
+ await asyncio.sleep(0.1)
+ yield request
+ await asyncio.sleep(0.1)
+ yield Message("Done")
+
+
+async def to_list(generator: AsyncIterator):
+ return [value async for value in generator]
+
+
+@pytest.fixture
+def client():
+ # channel = Channel(host='127.0.0.1', port=50051)
+ # return ClientStub(channel)
+ return ClientStub()
+
+
+@pytest.mark.asyncio
+async def test_send_from_before_connect_and_close_automatically(
+ client, expected_responses
+):
+ requests = AsyncChannel()
+ await requests.send_from(
+ [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
+ )
+ responses = client.connect(requests)
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_from_after_connect_and_close_automatically(
+ client, expected_responses
+):
+ requests = AsyncChannel()
+ responses = client.connect(requests)
+ await requests.send_from(
+ [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
+ )
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_from_close_manually_immediately(client, expected_responses):
+ requests = AsyncChannel()
+ responses = client.connect(requests)
+ await requests.send_from(
+ [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
+ )
+ requests.close()
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_individually_and_close_before_connect(client, expected_responses):
+ requests = AsyncChannel()
+ await requests.send(Message(body="Hello world 1"))
+ await requests.send(Message(body="Hello world 2"))
+ requests.close()
+ responses = client.connect(requests)
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_individually_and_close_after_connect(client, expected_responses):
+ requests = AsyncChannel()
+ await requests.send(Message(body="Hello world 1"))
+ await requests.send(Message(body="Hello world 2"))
+ responses = client.connect(requests)
+ requests.close()
+
+ assert await to_list(responses) == expected_responses
diff --git a/tests/grpc/thing_service.py b/tests/grpc/thing_service.py
new file mode 100644
index 0000000..5b00cbe
--- /dev/null
+++ b/tests/grpc/thing_service.py
@@ -0,0 +1,85 @@
+from typing import Dict
+
+import grpclib
+import grpclib.server
+
+from tests.output_aristaproto.service import (
+ DoThingRequest,
+ DoThingResponse,
+ GetThingRequest,
+ GetThingResponse,
+)
+
+
+class ThingService:
+ def __init__(self, test_hook=None):
+ # This lets us pass assertions to the servicer ;)
+ self.test_hook = test_hook
+
+ async def do_thing(
+ self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
+ ):
+ request = await stream.recv_message()
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ await stream.send_message(DoThingResponse([request.name]))
+
+ async def do_many_things(
+ self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
+ ):
+ thing_names = [request.name async for request in stream]
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ await stream.send_message(DoThingResponse(thing_names))
+
+ async def get_thing_versions(
+ self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
+ ):
+ request = await stream.recv_message()
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ for version_num in range(1, 6):
+ await stream.send_message(
+ GetThingResponse(name=request.name, version=version_num)
+ )
+
+ async def get_different_things(
+ self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
+ ):
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ # Respond to each input item immediately
+ response_num = 0
+ async for request in stream:
+ response_num += 1
+ await stream.send_message(
+ GetThingResponse(name=request.name, version=response_num)
+ )
+
+ def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
+ return {
+ "/service.Test/DoThing": grpclib.const.Handler(
+ self.do_thing,
+ grpclib.const.Cardinality.UNARY_UNARY,
+ DoThingRequest,
+ DoThingResponse,
+ ),
+ "/service.Test/DoManyThings": grpclib.const.Handler(
+ self.do_many_things,
+ grpclib.const.Cardinality.STREAM_UNARY,
+ DoThingRequest,
+ DoThingResponse,
+ ),
+ "/service.Test/GetThingVersions": grpclib.const.Handler(
+ self.get_thing_versions,
+ grpclib.const.Cardinality.UNARY_STREAM,
+ GetThingRequest,
+ GetThingResponse,
+ ),
+ "/service.Test/GetDifferentThings": grpclib.const.Handler(
+ self.get_different_things,
+ grpclib.const.Cardinality.STREAM_STREAM,
+ GetThingRequest,
+ GetThingResponse,
+ ),
+ }