diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-07-29 09:40:12 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-07-29 09:40:12 +0000 |
commit | 14b40ec77a4bf8605789cc3aff0eb87625510a41 (patch) | |
tree | 4064d27144d6deaabfcd96df01bd996baa8b51a0 /src/aristaproto/grpc/grpclib_client.py | |
parent | Initial commit. (diff) | |
download | python-aristaproto-14b40ec77a4bf8605789cc3aff0eb87625510a41.tar.xz python-aristaproto-14b40ec77a4bf8605789cc3aff0eb87625510a41.zip |
Adding upstream version 1.2+20240521.upstream/1.2+20240521upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/aristaproto/grpc/grpclib_client.py')
-rw-r--r-- | src/aristaproto/grpc/grpclib_client.py | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/src/aristaproto/grpc/grpclib_client.py b/src/aristaproto/grpc/grpclib_client.py new file mode 100644 index 0000000..b19e806 --- /dev/null +++ b/src/aristaproto/grpc/grpclib_client.py @@ -0,0 +1,177 @@ +import asyncio +from abc import ABC +from typing import ( + TYPE_CHECKING, + AsyncIterable, + AsyncIterator, + Collection, + Iterable, + Mapping, + Optional, + Tuple, + Type, + Union, +) + +import grpclib.const + + +if TYPE_CHECKING: + from grpclib.client import Channel + from grpclib.metadata import Deadline + + from .._types import ( + ST, + IProtoMessage, + Message, + T, + ) + + +Value = Union[str, bytes] +MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] +MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] + + +class ServiceStub(ABC): + """ + Base class for async gRPC clients. + """ + + def __init__( + self, + channel: "Channel", + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[MetadataLike] = None, + ) -> None: + self.channel = channel + self.timeout = timeout + self.deadline = deadline + self.metadata = metadata + + def __resolve_request_kwargs( + self, + timeout: Optional[float], + deadline: Optional["Deadline"], + metadata: Optional[MetadataLike], + ): + return { + "timeout": self.timeout if timeout is None else timeout, + "deadline": self.deadline if deadline is None else deadline, + "metadata": self.metadata if metadata is None else metadata, + } + + async def _unary_unary( + self, + route: str, + request: "IProtoMessage", + response_type: Type["T"], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[MetadataLike] = None, + ) -> "T": + """Make a unary request and return the response.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_UNARY, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _unary_stream( + self, + route: str, + request: "IProtoMessage", + response_type: Type["T"], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[MetadataLike] = None, + ) -> AsyncIterator["T"]: + """Make a unary request and return the stream response iterator.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_STREAM, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + async for message in stream: + yield message + + async def _stream_unary( + self, + route: str, + request_iterator: MessageSource, + request_type: Type["IProtoMessage"], + response_type: Type["T"], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[MetadataLike] = None, + ) -> "T": + """Make a stream request and return the response.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.STREAM_UNARY, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_request() + await self._send_messages(stream, request_iterator) + response = await stream.recv_message() + assert response is not None + return response + + async def _stream_stream( + self, + route: str, + request_iterator: MessageSource, + request_type: Type["IProtoMessage"], + response_type: Type["T"], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[MetadataLike] = None, + ) -> AsyncIterator["T"]: + """ + Make a stream request and return an AsyncIterator to iterate over response + messages. + """ + async with self.channel.request( + route, + grpclib.const.Cardinality.STREAM_STREAM, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_request() + sending_task = asyncio.ensure_future( + self._send_messages(stream, request_iterator) + ) + try: + async for response in stream: + yield response + except: + sending_task.cancel() + raise + + @staticmethod + async def _send_messages(stream, messages: MessageSource): + if isinstance(messages, AsyncIterable): + async for message in messages: + await stream.send_message(message) + else: + for message in messages: + await stream.send_message(message) + await stream.end() |