diff options
Diffstat (limited to 'src/prompt_toolkit/shortcuts/progress_bar')
-rw-r--r-- | src/prompt_toolkit/shortcuts/progress_bar/__init__.py | 31 | ||||
-rw-r--r-- | src/prompt_toolkit/shortcuts/progress_bar/base.py | 464 | ||||
-rw-r--r-- | src/prompt_toolkit/shortcuts/progress_bar/formatters.py | 436 |
3 files changed, 931 insertions, 0 deletions
diff --git a/src/prompt_toolkit/shortcuts/progress_bar/__init__.py b/src/prompt_toolkit/shortcuts/progress_bar/__init__.py new file mode 100644 index 0000000..7d0fbb5 --- /dev/null +++ b/src/prompt_toolkit/shortcuts/progress_bar/__init__.py @@ -0,0 +1,31 @@ +from .base import ProgressBar, ProgressBarCounter +from .formatters import ( + Bar, + Formatter, + IterationsPerSecond, + Label, + Percentage, + Progress, + Rainbow, + SpinningWheel, + Text, + TimeElapsed, + TimeLeft, +) + +__all__ = [ + "ProgressBar", + "ProgressBarCounter", + # Formatters. + "Formatter", + "Text", + "Label", + "Percentage", + "Bar", + "Progress", + "TimeElapsed", + "TimeLeft", + "IterationsPerSecond", + "SpinningWheel", + "Rainbow", +] diff --git a/src/prompt_toolkit/shortcuts/progress_bar/base.py b/src/prompt_toolkit/shortcuts/progress_bar/base.py new file mode 100644 index 0000000..d790466 --- /dev/null +++ b/src/prompt_toolkit/shortcuts/progress_bar/base.py @@ -0,0 +1,464 @@ +""" +Progress bar implementation on top of prompt_toolkit. + +:: + + with ProgressBar(...) as pb: + for item in pb(data): + ... +""" +import datetime +import functools +import os +import signal +import threading +import traceback +from asyncio import new_event_loop, set_event_loop +from typing import ( + Callable, + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + Sized, + TextIO, + TypeVar, + cast, +) + +from prompt_toolkit.application import Application +from prompt_toolkit.application.current import get_app_session +from prompt_toolkit.eventloop import get_event_loop +from prompt_toolkit.filters import Condition, is_done, renderer_height_is_known +from prompt_toolkit.formatted_text import ( + AnyFormattedText, + StyleAndTextTuples, + to_formatted_text, +) +from prompt_toolkit.input import Input +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.key_binding.key_processor import KeyPressEvent +from prompt_toolkit.layout import ( + ConditionalContainer, + FormattedTextControl, + HSplit, + Layout, + VSplit, + Window, +) +from prompt_toolkit.layout.controls import UIContent, UIControl +from prompt_toolkit.layout.dimension import AnyDimension, D +from prompt_toolkit.output import ColorDepth, Output +from prompt_toolkit.styles import BaseStyle + +from .formatters import Formatter, create_default_formatters + +try: + import contextvars +except ImportError: + from prompt_toolkit.eventloop import dummy_contextvars + + contextvars = dummy_contextvars # type: ignore + + +__all__ = ["ProgressBar"] + +E = KeyPressEvent + +_SIGWINCH = getattr(signal, "SIGWINCH", None) + + +def create_key_bindings(cancel_callback: Optional[Callable[[], None]]) -> KeyBindings: + """ + Key bindings handled by the progress bar. + (The main thread is not supposed to handle any key bindings.) + """ + kb = KeyBindings() + + @kb.add("c-l") + def _clear(event: E) -> None: + event.app.renderer.clear() + + if cancel_callback is not None: + + @kb.add("c-c") + def _interrupt(event: E) -> None: + "Kill the 'body' of the progress bar, but only if we run from the main thread." + assert cancel_callback is not None + cancel_callback() + + return kb + + +_T = TypeVar("_T") + + +class ProgressBar: + """ + Progress bar context manager. + + Usage :: + + with ProgressBar(...) as pb: + for item in pb(data): + ... + + :param title: Text to be displayed above the progress bars. This can be a + callable or formatted text as well. + :param formatters: List of :class:`.Formatter` instances. + :param bottom_toolbar: Text to be displayed in the bottom toolbar. This + can be a callable or formatted text. + :param style: :class:`prompt_toolkit.styles.BaseStyle` instance. + :param key_bindings: :class:`.KeyBindings` instance. + :param cancel_callback: Callback function that's called when control-c is + pressed by the user. This can be used for instance to start "proper" + cancellation if the wrapped code supports it. + :param file: The file object used for rendering, by default `sys.stderr` is used. + + :param color_depth: `prompt_toolkit` `ColorDepth` instance. + :param output: :class:`~prompt_toolkit.output.Output` instance. + :param input: :class:`~prompt_toolkit.input.Input` instance. + """ + + def __init__( + self, + title: AnyFormattedText = None, + formatters: Optional[Sequence[Formatter]] = None, + bottom_toolbar: AnyFormattedText = None, + style: Optional[BaseStyle] = None, + key_bindings: Optional[KeyBindings] = None, + cancel_callback: Optional[Callable[[], None]] = None, + file: Optional[TextIO] = None, + color_depth: Optional[ColorDepth] = None, + output: Optional[Output] = None, + input: Optional[Input] = None, + ) -> None: + + self.title = title + self.formatters = formatters or create_default_formatters() + self.bottom_toolbar = bottom_toolbar + self.counters: List[ProgressBarCounter[object]] = [] + self.style = style + self.key_bindings = key_bindings + self.cancel_callback = cancel_callback + + # If no `cancel_callback` was given, and we're creating the progress + # bar from the main thread. Cancel by sending a `KeyboardInterrupt` to + # the main thread. + if ( + self.cancel_callback is None + and threading.currentThread() == threading.main_thread() + ): + + def keyboard_interrupt_to_main_thread() -> None: + os.kill(os.getpid(), signal.SIGINT) + + self.cancel_callback = keyboard_interrupt_to_main_thread + + # Note that we use __stderr__ as default error output, because that + # works best with `patch_stdout`. + self.color_depth = color_depth + self.output = output or get_app_session().output + self.input = input or get_app_session().input + + self._thread: Optional[threading.Thread] = None + + self._app_loop = new_event_loop() + self._has_sigwinch = False + self._app_started = threading.Event() + + def __enter__(self) -> "ProgressBar": + # Create UI Application. + title_toolbar = ConditionalContainer( + Window( + FormattedTextControl(lambda: self.title), + height=1, + style="class:progressbar,title", + ), + filter=Condition(lambda: self.title is not None), + ) + + bottom_toolbar = ConditionalContainer( + Window( + FormattedTextControl( + lambda: self.bottom_toolbar, style="class:bottom-toolbar.text" + ), + style="class:bottom-toolbar", + height=1, + ), + filter=~is_done + & renderer_height_is_known + & Condition(lambda: self.bottom_toolbar is not None), + ) + + def width_for_formatter(formatter: Formatter) -> AnyDimension: + # Needs to be passed as callable (partial) to the 'width' + # parameter, because we want to call it on every resize. + return formatter.get_width(progress_bar=self) + + progress_controls = [ + Window( + content=_ProgressControl(self, f, self.cancel_callback), + width=functools.partial(width_for_formatter, f), + ) + for f in self.formatters + ] + + self.app: Application[None] = Application( + min_redraw_interval=0.05, + layout=Layout( + HSplit( + [ + title_toolbar, + VSplit( + progress_controls, + height=lambda: D( + preferred=len(self.counters), max=len(self.counters) + ), + ), + Window(), + bottom_toolbar, + ] + ) + ), + style=self.style, + key_bindings=self.key_bindings, + refresh_interval=0.3, + color_depth=self.color_depth, + output=self.output, + input=self.input, + ) + + # Run application in different thread. + def run() -> None: + set_event_loop(self._app_loop) + try: + self.app.run(pre_run=self._app_started.set) + except BaseException as e: + traceback.print_exc() + print(e) + + ctx: contextvars.Context = contextvars.copy_context() + + self._thread = threading.Thread(target=ctx.run, args=(run,)) + self._thread.start() + + return self + + def __exit__(self, *a: object) -> None: + # Wait for the app to be started. Make sure we don't quit earlier, + # otherwise `self.app.exit` won't terminate the app because + # `self.app.future` has not yet been set. + self._app_started.wait() + + # Quit UI application. + if self.app.is_running: + self._app_loop.call_soon_threadsafe(self.app.exit) + + if self._thread is not None: + self._thread.join() + self._app_loop.close() + + def __call__( + self, + data: Optional[Iterable[_T]] = None, + label: AnyFormattedText = "", + remove_when_done: bool = False, + total: Optional[int] = None, + ) -> "ProgressBarCounter[_T]": + """ + Start a new counter. + + :param label: Title text or description for this progress. (This can be + formatted text as well). + :param remove_when_done: When `True`, hide this progress bar. + :param total: Specify the maximum value if it can't be calculated by + calling ``len``. + """ + counter = ProgressBarCounter( + self, data, label=label, remove_when_done=remove_when_done, total=total + ) + self.counters.append(counter) + return counter + + def invalidate(self) -> None: + self.app.invalidate() + + +class _ProgressControl(UIControl): + """ + User control for the progress bar. + """ + + def __init__( + self, + progress_bar: ProgressBar, + formatter: Formatter, + cancel_callback: Optional[Callable[[], None]], + ) -> None: + self.progress_bar = progress_bar + self.formatter = formatter + self._key_bindings = create_key_bindings(cancel_callback) + + def create_content(self, width: int, height: int) -> UIContent: + items: List[StyleAndTextTuples] = [] + + for pr in self.progress_bar.counters: + try: + text = self.formatter.format(self.progress_bar, pr, width) + except BaseException: + traceback.print_exc() + text = "ERROR" + + items.append(to_formatted_text(text)) + + def get_line(i: int) -> StyleAndTextTuples: + return items[i] + + return UIContent(get_line=get_line, line_count=len(items), show_cursor=False) + + def is_focusable(self) -> bool: + return True # Make sure that the key bindings work. + + def get_key_bindings(self) -> KeyBindings: + return self._key_bindings + + +_CounterItem = TypeVar("_CounterItem", covariant=True) + + +class ProgressBarCounter(Generic[_CounterItem]): + """ + An individual counter (A progress bar can have multiple counters). + """ + + def __init__( + self, + progress_bar: ProgressBar, + data: Optional[Iterable[_CounterItem]] = None, + label: AnyFormattedText = "", + remove_when_done: bool = False, + total: Optional[int] = None, + ) -> None: + + self.start_time = datetime.datetime.now() + self.stop_time: Optional[datetime.datetime] = None + self.progress_bar = progress_bar + self.data = data + self.items_completed = 0 + self.label = label + self.remove_when_done = remove_when_done + self._done = False + self.total: Optional[int] + + if total is None: + try: + self.total = len(cast(Sized, data)) + except TypeError: + self.total = None # We don't know the total length. + else: + self.total = total + + def __iter__(self) -> Iterator[_CounterItem]: + if self.data is not None: + try: + for item in self.data: + yield item + self.item_completed() + + # Only done if we iterate to the very end. + self.done = True + finally: + # Ensure counter has stopped even if we did not iterate to the + # end (e.g. break or exceptions). + self.stopped = True + else: + raise NotImplementedError("No data defined to iterate over.") + + def item_completed(self) -> None: + """ + Start handling the next item. + + (Can be called manually in case we don't have a collection to loop through.) + """ + self.items_completed += 1 + self.progress_bar.invalidate() + + @property + def done(self) -> bool: + """Whether a counter has been completed. + + Done counter have been stopped (see stopped) and removed depending on + remove_when_done value. + + Contrast this with stopped. A stopped counter may be terminated before + 100% completion. A done counter has reached its 100% completion. + """ + return self._done + + @done.setter + def done(self, value: bool) -> None: + self._done = value + self.stopped = value + + if value and self.remove_when_done: + self.progress_bar.counters.remove(self) + + @property + def stopped(self) -> bool: + """Whether a counter has been stopped. + + Stopped counters no longer have increasing time_elapsed. This distinction is + also used to prevent the Bar formatter with unknown totals from continuing to run. + + A stopped counter (but not done) can be used to signal that a given counter has + encountered an error but allows other counters to continue + (e.g. download X of Y failed). Given how only done counters are removed + (see remove_when_done) this can help aggregate failures from a large number of + successes. + + Contrast this with done. A done counter has reached its 100% completion. + A stopped counter may be terminated before 100% completion. + """ + return self.stop_time is not None + + @stopped.setter + def stopped(self, value: bool) -> None: + if value: + # This counter has not already been stopped. + if not self.stop_time: + self.stop_time = datetime.datetime.now() + else: + # Clearing any previously set stop_time. + self.stop_time = None + + @property + def percentage(self) -> float: + if self.total is None: + return 0 + else: + return self.items_completed * 100 / max(self.total, 1) + + @property + def time_elapsed(self) -> datetime.timedelta: + """ + Return how much time has been elapsed since the start. + """ + if self.stop_time is None: + return datetime.datetime.now() - self.start_time + else: + return self.stop_time - self.start_time + + @property + def time_left(self) -> Optional[datetime.timedelta]: + """ + Timedelta representing the time left. + """ + if self.total is None or not self.percentage: + return None + elif self.done or self.stopped: + return datetime.timedelta(0) + else: + return self.time_elapsed * (100 - self.percentage) / self.percentage diff --git a/src/prompt_toolkit/shortcuts/progress_bar/formatters.py b/src/prompt_toolkit/shortcuts/progress_bar/formatters.py new file mode 100644 index 0000000..1383d7a --- /dev/null +++ b/src/prompt_toolkit/shortcuts/progress_bar/formatters.py @@ -0,0 +1,436 @@ +""" +Formatter classes for the progress bar. +Each progress bar consists of a list of these formatters. +""" +import datetime +import time +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, List, Tuple + +from prompt_toolkit.formatted_text import ( + HTML, + AnyFormattedText, + StyleAndTextTuples, + to_formatted_text, +) +from prompt_toolkit.formatted_text.utils import fragment_list_width +from prompt_toolkit.layout.dimension import AnyDimension, D +from prompt_toolkit.layout.utils import explode_text_fragments +from prompt_toolkit.utils import get_cwidth + +if TYPE_CHECKING: + from .base import ProgressBar, ProgressBarCounter + +__all__ = [ + "Formatter", + "Text", + "Label", + "Percentage", + "Bar", + "Progress", + "TimeElapsed", + "TimeLeft", + "IterationsPerSecond", + "SpinningWheel", + "Rainbow", + "create_default_formatters", +] + + +class Formatter(metaclass=ABCMeta): + """ + Base class for any formatter. + """ + + @abstractmethod + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + pass + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + return D() + + +class Text(Formatter): + """ + Display plain text. + """ + + def __init__(self, text: AnyFormattedText, style: str = "") -> None: + self.text = to_formatted_text(text, style=style) + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + return self.text + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + return fragment_list_width(self.text) + + +class Label(Formatter): + """ + Display the name of the current task. + + :param width: If a `width` is given, use this width. Scroll the text if it + doesn't fit in this width. + :param suffix: String suffix to be added after the task name, e.g. ': '. + If no task name was given, no suffix will be added. + """ + + def __init__(self, width: AnyDimension = None, suffix: str = "") -> None: + self.width = width + self.suffix = suffix + + def _add_suffix(self, label: AnyFormattedText) -> StyleAndTextTuples: + label = to_formatted_text(label, style="class:label") + return label + [("", self.suffix)] + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + label = self._add_suffix(progress.label) + cwidth = fragment_list_width(label) + + if cwidth > width: + # It doesn't fit -> scroll task name. + label = explode_text_fragments(label) + max_scroll = cwidth - width + current_scroll = int(time.time() * 3 % max_scroll) + label = label[current_scroll:] + + return label + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + if self.width: + return self.width + + all_labels = [self._add_suffix(c.label) for c in progress_bar.counters] + if all_labels: + max_widths = max(fragment_list_width(l) for l in all_labels) + return D(preferred=max_widths, max=max_widths) + else: + return D() + + +class Percentage(Formatter): + """ + Display the progress as a percentage. + """ + + template = "<percentage>{percentage:>5}%</percentage>" + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + return HTML(self.template).format(percentage=round(progress.percentage, 1)) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + return D.exact(6) + + +class Bar(Formatter): + """ + Display the progress bar itself. + """ + + template = "<bar>{start}<bar-a>{bar_a}</bar-a><bar-b>{bar_b}</bar-b><bar-c>{bar_c}</bar-c>{end}</bar>" + + def __init__( + self, + start: str = "[", + end: str = "]", + sym_a: str = "=", + sym_b: str = ">", + sym_c: str = " ", + unknown: str = "#", + ) -> None: + + assert len(sym_a) == 1 and get_cwidth(sym_a) == 1 + assert len(sym_c) == 1 and get_cwidth(sym_c) == 1 + + self.start = start + self.end = end + self.sym_a = sym_a + self.sym_b = sym_b + self.sym_c = sym_c + self.unknown = unknown + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + if progress.done or progress.total or progress.stopped: + sym_a, sym_b, sym_c = self.sym_a, self.sym_b, self.sym_c + + # Compute pb_a based on done, total, or stopped states. + if progress.done: + # 100% completed irrelevant of how much was actually marked as completed. + percent = 1.0 + else: + # Show percentage completed. + percent = progress.percentage / 100 + else: + # Total is unknown and bar is still running. + sym_a, sym_b, sym_c = self.sym_c, self.unknown, self.sym_c + + # Compute percent based on the time. + percent = time.time() * 20 % 100 / 100 + + # Subtract left, sym_b, and right. + width -= get_cwidth(self.start + sym_b + self.end) + + # Scale percent by width + pb_a = int(percent * width) + bar_a = sym_a * pb_a + bar_b = sym_b + bar_c = sym_c * (width - pb_a) + + return HTML(self.template).format( + start=self.start, end=self.end, bar_a=bar_a, bar_b=bar_b, bar_c=bar_c + ) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + return D(min=9) + + +class Progress(Formatter): + """ + Display the progress as text. E.g. "8/20" + """ + + template = "<current>{current:>3}</current>/<total>{total:>3}</total>" + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + return HTML(self.template).format( + current=progress.items_completed, total=progress.total or "?" + ) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + all_lengths = [ + len("{:>3}".format(c.total or "?")) for c in progress_bar.counters + ] + all_lengths.append(1) + return D.exact(max(all_lengths) * 2 + 1) + + +def _format_timedelta(timedelta: datetime.timedelta) -> str: + """ + Return hh:mm:ss, or mm:ss if the amount of hours is zero. + """ + result = f"{timedelta}".split(".")[0] + if result.startswith("0:"): + result = result[2:] + return result + + +class TimeElapsed(Formatter): + """ + Display the elapsed time. + """ + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + text = _format_timedelta(progress.time_elapsed).rjust(width) + return HTML("<time-elapsed>{time_elapsed}</time-elapsed>").format( + time_elapsed=text + ) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + all_values = [ + len(_format_timedelta(c.time_elapsed)) for c in progress_bar.counters + ] + if all_values: + return max(all_values) + return 0 + + +class TimeLeft(Formatter): + """ + Display the time left. + """ + + template = "<time-left>{time_left}</time-left>" + unknown = "?:??:??" + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + time_left = progress.time_left + if time_left is not None: + formatted_time_left = _format_timedelta(time_left) + else: + formatted_time_left = self.unknown + + return HTML(self.template).format(time_left=formatted_time_left.rjust(width)) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + all_values = [ + len(_format_timedelta(c.time_left)) if c.time_left is not None else 7 + for c in progress_bar.counters + ] + if all_values: + return max(all_values) + return 0 + + +class IterationsPerSecond(Formatter): + """ + Display the iterations per second. + """ + + template = ( + "<iterations-per-second>{iterations_per_second:.2f}</iterations-per-second>" + ) + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + value = progress.items_completed / progress.time_elapsed.total_seconds() + return HTML(self.template.format(iterations_per_second=value)) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + all_values = [ + len(f"{c.items_completed / c.time_elapsed.total_seconds():.2f}") + for c in progress_bar.counters + ] + if all_values: + return max(all_values) + return 0 + + +class SpinningWheel(Formatter): + """ + Display a spinning wheel. + """ + + characters = r"/-\|" + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + index = int(time.time() * 3) % len(self.characters) + return HTML("<spinning-wheel>{0}</spinning-wheel>").format( + self.characters[index] + ) + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + return D.exact(1) + + +def _hue_to_rgb(hue: float) -> Tuple[int, int, int]: + """ + Take hue between 0 and 1, return (r, g, b). + """ + i = int(hue * 6.0) + f = (hue * 6.0) - i + + q = int(255 * (1.0 - f)) + t = int(255 * (1.0 - (1.0 - f))) + + i %= 6 + + return [ + (255, t, 0), + (q, 255, 0), + (0, 255, t), + (0, q, 255), + (t, 0, 255), + (255, 0, q), + ][i] + + +class Rainbow(Formatter): + """ + For the fun. Add rainbow colors to any of the other formatters. + """ + + colors = ["#%.2x%.2x%.2x" % _hue_to_rgb(h / 100.0) for h in range(0, 100)] + + def __init__(self, formatter: Formatter) -> None: + self.formatter = formatter + + def format( + self, + progress_bar: "ProgressBar", + progress: "ProgressBarCounter[object]", + width: int, + ) -> AnyFormattedText: + + # Get formatted text from nested formatter, and explode it in + # text/style tuples. + result = self.formatter.format(progress_bar, progress, width) + result = explode_text_fragments(to_formatted_text(result)) + + # Insert colors. + result2: StyleAndTextTuples = [] + shift = int(time.time() * 3) % len(self.colors) + + for i, (style, text, *_) in enumerate(result): + result2.append( + (style + " " + self.colors[(i + shift) % len(self.colors)], text) + ) + return result2 + + def get_width(self, progress_bar: "ProgressBar") -> AnyDimension: + return self.formatter.get_width(progress_bar) + + +def create_default_formatters() -> List[Formatter]: + """ + Return the list of default formatters. + """ + return [ + Label(), + Text(" "), + Percentage(), + Text(" "), + Bar(), + Text(" "), + Progress(), + Text(" "), + Text("eta [", style="class:time-left"), + TimeLeft(), + Text("]", style="class:time-left"), + Text(" "), + ] |