summaryrefslogtreecommitdiffstats
path: root/src/prompt_toolkit/eventloop/async_context_manager.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/prompt_toolkit/eventloop/async_context_manager.py')
-rw-r--r--src/prompt_toolkit/eventloop/async_context_manager.py132
1 files changed, 132 insertions, 0 deletions
diff --git a/src/prompt_toolkit/eventloop/async_context_manager.py b/src/prompt_toolkit/eventloop/async_context_manager.py
new file mode 100644
index 0000000..3914616
--- /dev/null
+++ b/src/prompt_toolkit/eventloop/async_context_manager.py
@@ -0,0 +1,132 @@
+"""
+@asynccontextmanager code, copied from Python 3.7's contextlib.
+For usage in Python 3.6.
+Types have been added to this file, just enough to make Mypy happy.
+"""
+# mypy: allow-untyped-defs
+import abc
+from functools import wraps
+from typing import AsyncContextManager, AsyncIterator, Callable, TypeVar
+
+import _collections_abc
+
+__all__ = ["asynccontextmanager"]
+
+
+class AbstractAsyncContextManager(abc.ABC):
+
+ """An abstract base class for asynchronous context managers."""
+
+ async def __aenter__(self):
+ """Return `self` upon entering the runtime context."""
+ return self
+
+ @abc.abstractmethod
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ """Raise any exception triggered within the runtime context."""
+ return None
+
+ @classmethod
+ def __subclasshook__(cls, C):
+ if cls is AbstractAsyncContextManager:
+ return _collections_abc._check_methods(C, "__aenter__", "__aexit__") # type: ignore
+ return NotImplemented
+
+
+class _GeneratorContextManagerBase:
+ """Shared functionality for @contextmanager and @asynccontextmanager."""
+
+ def __init__(self, func, args, kwds):
+ self.gen = func(*args, **kwds)
+ self.func, self.args, self.kwds = func, args, kwds
+ # Issue 19330: ensure context manager instances have good docstrings
+ doc = getattr(func, "__doc__", None)
+ if doc is None:
+ doc = type(self).__doc__
+ self.__doc__ = doc
+ # Unfortunately, this still doesn't provide good help output when
+ # inspecting the created context manager instances, since pydoc
+ # currently bypasses the instance docstring and shows the docstring
+ # for the class instead.
+ # See http://bugs.python.org/issue19404 for more details.
+
+
+class _AsyncGeneratorContextManager(
+ _GeneratorContextManagerBase, AbstractAsyncContextManager
+):
+ """Helper for @asynccontextmanager."""
+
+ async def __aenter__(self):
+ try:
+ return await self.gen.__anext__()
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ async def __aexit__(self, typ, value, traceback):
+ if typ is None:
+ try:
+ await self.gen.__anext__()
+ except StopAsyncIteration:
+ return
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ value = typ()
+ # See _GeneratorContextManager.__exit__ for comments on subtleties
+ # in this implementation
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ raise RuntimeError("generator didn't stop after athrow()")
+ except StopAsyncIteration as exc:
+ return exc is not value
+ except RuntimeError as exc:
+ if exc is value:
+ return False
+ # Avoid suppressing if a StopIteration exception
+ # was passed to throw() and later wrapped into a RuntimeError
+ # (see PEP 479 for sync generators; async generators also
+ # have this behavior). But do this only if the exception wrapped
+ # by the RuntimeError is actully Stop(Async)Iteration (see
+ # issue29692).
+ if isinstance(value, (StopIteration, StopAsyncIteration)):
+ if exc.__cause__ is value:
+ return False
+ raise
+ except BaseException as exc:
+ if exc is not value:
+ raise
+
+
+_T = TypeVar("_T")
+
+
+def asynccontextmanager(
+ func: Callable[..., AsyncIterator[_T]]
+) -> Callable[..., AsyncContextManager[_T]]:
+ """@asynccontextmanager decorator.
+ Typical usage:
+ @asynccontextmanager
+ async def some_async_generator(<arguments>):
+ <setup>
+ try:
+ yield <value>
+ finally:
+ <cleanup>
+ This makes this:
+ async with some_async_generator(<arguments>) as <variable>:
+ <body>
+ equivalent to this:
+ <setup>
+ try:
+ <variable> = <value>
+ <body>
+ finally:
+ <cleanup>
+ """
+
+ @wraps(func)
+ def helper(*args, **kwds):
+ return _AsyncGeneratorContextManager(func, args, kwds) # type: ignore
+
+ return helper