summaryrefslogtreecommitdiffstats
path: root/src/aristaproto/grpc/grpclib_client.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-07-29 09:40:12 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-07-29 09:40:12 +0000
commit14b40ec77a4bf8605789cc3aff0eb87625510a41 (patch)
tree4064d27144d6deaabfcd96df01bd996baa8b51a0 /src/aristaproto/grpc/grpclib_client.py
parentInitial commit. (diff)
downloadpython-aristaproto-14b40ec77a4bf8605789cc3aff0eb87625510a41.tar.xz
python-aristaproto-14b40ec77a4bf8605789cc3aff0eb87625510a41.zip
Adding upstream version 1.2+20240521.upstream/1.2+20240521upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/aristaproto/grpc/grpclib_client.py')
-rw-r--r--src/aristaproto/grpc/grpclib_client.py177
1 files changed, 177 insertions, 0 deletions
diff --git a/src/aristaproto/grpc/grpclib_client.py b/src/aristaproto/grpc/grpclib_client.py
new file mode 100644
index 0000000..b19e806
--- /dev/null
+++ b/src/aristaproto/grpc/grpclib_client.py
@@ -0,0 +1,177 @@
+import asyncio
+from abc import ABC
+from typing import (
+ TYPE_CHECKING,
+ AsyncIterable,
+ AsyncIterator,
+ Collection,
+ Iterable,
+ Mapping,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+
+import grpclib.const
+
+
+if TYPE_CHECKING:
+ from grpclib.client import Channel
+ from grpclib.metadata import Deadline
+
+ from .._types import (
+ ST,
+ IProtoMessage,
+ Message,
+ T,
+ )
+
+
+Value = Union[str, bytes]
+MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
+MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
+
+
+class ServiceStub(ABC):
+ """
+ Base class for async gRPC clients.
+ """
+
+ def __init__(
+ self,
+ channel: "Channel",
+ *,
+ timeout: Optional[float] = None,
+ deadline: Optional["Deadline"] = None,
+ metadata: Optional[MetadataLike] = None,
+ ) -> None:
+ self.channel = channel
+ self.timeout = timeout
+ self.deadline = deadline
+ self.metadata = metadata
+
+ def __resolve_request_kwargs(
+ self,
+ timeout: Optional[float],
+ deadline: Optional["Deadline"],
+ metadata: Optional[MetadataLike],
+ ):
+ return {
+ "timeout": self.timeout if timeout is None else timeout,
+ "deadline": self.deadline if deadline is None else deadline,
+ "metadata": self.metadata if metadata is None else metadata,
+ }
+
+ async def _unary_unary(
+ self,
+ route: str,
+ request: "IProtoMessage",
+ response_type: Type["T"],
+ *,
+ timeout: Optional[float] = None,
+ deadline: Optional["Deadline"] = None,
+ metadata: Optional[MetadataLike] = None,
+ ) -> "T":
+ """Make a unary request and return the response."""
+ async with self.channel.request(
+ route,
+ grpclib.const.Cardinality.UNARY_UNARY,
+ type(request),
+ response_type,
+ **self.__resolve_request_kwargs(timeout, deadline, metadata),
+ ) as stream:
+ await stream.send_message(request, end=True)
+ response = await stream.recv_message()
+ assert response is not None
+ return response
+
+ async def _unary_stream(
+ self,
+ route: str,
+ request: "IProtoMessage",
+ response_type: Type["T"],
+ *,
+ timeout: Optional[float] = None,
+ deadline: Optional["Deadline"] = None,
+ metadata: Optional[MetadataLike] = None,
+ ) -> AsyncIterator["T"]:
+ """Make a unary request and return the stream response iterator."""
+ async with self.channel.request(
+ route,
+ grpclib.const.Cardinality.UNARY_STREAM,
+ type(request),
+ response_type,
+ **self.__resolve_request_kwargs(timeout, deadline, metadata),
+ ) as stream:
+ await stream.send_message(request, end=True)
+ async for message in stream:
+ yield message
+
+ async def _stream_unary(
+ self,
+ route: str,
+ request_iterator: MessageSource,
+ request_type: Type["IProtoMessage"],
+ response_type: Type["T"],
+ *,
+ timeout: Optional[float] = None,
+ deadline: Optional["Deadline"] = None,
+ metadata: Optional[MetadataLike] = None,
+ ) -> "T":
+ """Make a stream request and return the response."""
+ async with self.channel.request(
+ route,
+ grpclib.const.Cardinality.STREAM_UNARY,
+ request_type,
+ response_type,
+ **self.__resolve_request_kwargs(timeout, deadline, metadata),
+ ) as stream:
+ await stream.send_request()
+ await self._send_messages(stream, request_iterator)
+ response = await stream.recv_message()
+ assert response is not None
+ return response
+
+ async def _stream_stream(
+ self,
+ route: str,
+ request_iterator: MessageSource,
+ request_type: Type["IProtoMessage"],
+ response_type: Type["T"],
+ *,
+ timeout: Optional[float] = None,
+ deadline: Optional["Deadline"] = None,
+ metadata: Optional[MetadataLike] = None,
+ ) -> AsyncIterator["T"]:
+ """
+ Make a stream request and return an AsyncIterator to iterate over response
+ messages.
+ """
+ async with self.channel.request(
+ route,
+ grpclib.const.Cardinality.STREAM_STREAM,
+ request_type,
+ response_type,
+ **self.__resolve_request_kwargs(timeout, deadline, metadata),
+ ) as stream:
+ await stream.send_request()
+ sending_task = asyncio.ensure_future(
+ self._send_messages(stream, request_iterator)
+ )
+ try:
+ async for response in stream:
+ yield response
+ except:
+ sending_task.cancel()
+ raise
+
+ @staticmethod
+ async def _send_messages(stream, messages: MessageSource):
+ if isinstance(messages, AsyncIterable):
+ async for message in messages:
+ await stream.send_message(message)
+ else:
+ for message in messages:
+ await stream.send_message(message)
+ await stream.end()