from functools import partial

from dask.callbacks import Callback

from .auto import tqdm as tqdm_auto

__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['TqdmCallback']


class TqdmCallback(Callback):
    """Dask callback for task progress."""

    def __init__(self, start=None, pretask=None, tqdm_class=tqdm_auto,
                 **tqdm_kwargs):
        """
        Parameters
        ----------
        start  : optional
            Forwarded unchanged to `dask.callbacks.Callback.__init__`.
        pretask  : optional
            Forwarded unchanged to `dask.callbacks.Callback.__init__`.
        tqdm_class  : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs  : optional
            Any other arguments used for all bars.
        """
        super().__init__(start=start, pretask=pretask)
        if tqdm_kwargs:
            # Pre-bind the shared keyword arguments so that every bar created
            # through `self.tqdm_class(...)` receives them.
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class

    def _start_state(self, _, state):
        """Create the bar, sized to the number of tasks found in `state`."""
        # The total is the combined size of all four task collections.
        self.pbar = self.tqdm_class(total=sum(
            len(state[k]) for k in ['ready', 'waiting', 'running', 'finished']))

    def _posttask(self, *_, **__):
        """Advance the bar by one step; all arguments are ignored."""
        self.pbar.update()

    def _finish(self, *_, **__):
        """Close the bar; all arguments are ignored."""
        self.pbar.close()

    def display(self):
        """Displays in the current cell in Notebooks.

        Does nothing when no bar has been created yet, or when the bar has
        no `container` attribute.
        """
        # The bar is stored as `self.pbar` (see `_start_state`); it only
        # exists once a computation has started, hence the nested lookup.
        pbar = getattr(self, 'pbar', None)
        container = getattr(pbar, 'container', None)
        if container is None:
            return
        # Imported lazily so that the notebook module is only loaded when
        # there is actually a container to show.
        from .notebook import display
        display(container)