summaryrefslogtreecommitdiffstats
path: root/tqdm/asyncio.py
diff options
context:
space:
mode:
Diffstat (limited to 'tqdm/asyncio.py')
-rw-r--r--tqdm/asyncio.py93
1 files changed, 93 insertions, 0 deletions
diff --git a/tqdm/asyncio.py b/tqdm/asyncio.py
new file mode 100644
index 0000000..97c5f88
--- /dev/null
+++ b/tqdm/asyncio.py
@@ -0,0 +1,93 @@
+"""
+Asynchronous progressbar decorator for iterators.
+Includes a default `range` iterator printing to `stderr`.
+
+Usage:
+>>> from tqdm.asyncio import trange, tqdm
+>>> async for i in trange(10):
+... ...
+"""
+import asyncio
+from sys import version_info
+
+from .std import tqdm as std_tqdm
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
+
+
+class tqdm_asyncio(std_tqdm):
+ """
+ Asynchronous-friendly version of tqdm (Python 3.6+).
+ """
+ def __init__(self, iterable=None, *args, **kwargs):
+ super(tqdm_asyncio, self).__init__(iterable, *args, **kwargs)
+ self.iterable_awaitable = False
+ if iterable is not None:
+ if hasattr(iterable, "__anext__"):
+ self.iterable_next = iterable.__anext__
+ self.iterable_awaitable = True
+ elif hasattr(iterable, "__next__"):
+ self.iterable_next = iterable.__next__
+ else:
+ self.iterable_iterator = iter(iterable)
+ self.iterable_next = self.iterable_iterator.__next__
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ if self.iterable_awaitable:
+ res = await self.iterable_next()
+ else:
+ res = self.iterable_next()
+ self.update()
+ return res
+ except StopIteration:
+ self.close()
+ raise StopAsyncIteration
+ except BaseException:
+ self.close()
+ raise
+
+ def send(self, *args, **kwargs):
+ return self.iterable.send(*args, **kwargs)
+
+ @classmethod
+ def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
+ """
+ Wrapper for `asyncio.as_completed`.
+ """
+ if total is None:
+ total = len(fs)
+ kwargs = {}
+ if version_info[:2] < (3, 10):
+ kwargs['loop'] = loop
+ yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
+ total=total, **tqdm_kwargs)
+
+ @classmethod
+ async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
+ """
+ Wrapper for `asyncio.gather`.
+ """
+ async def wrap_awaitable(i, f):
+ return i, await f
+
+ ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
+ res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
+ total=total, **tqdm_kwargs)]
+ return [i for _, i in sorted(res)]
+
+
+def tarange(*args, **kwargs):
+ """
+ A shortcut for `tqdm.asyncio.tqdm(range(*args), **kwargs)`.
+ """
+ return tqdm_asyncio(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_asyncio
+trange = tarange