summary refs log tree commit diff stats
path: root/tqdm/keras.py
blob: cce9467c51a95388aaa502d1da9a42f3ebf0af24 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from copy import copy
from functools import partial

from .auto import tqdm as tqdm_auto

# Prefer the standalone `keras` package; if importing it fails, fall back to
# the `keras` module shipped inside `tensorflow`.
try:
    import keras
except (ImportError, AttributeError) as e:
    try:
        from tensorflow import keras
    except ImportError:
        # Neither import worked: surface the error raised by `import keras`.
        raise e
__author__ = {"github.com/": ["casperdcl"]}
# Public API of this module.
__all__ = ['TqdmCallback']


class TqdmCallback(keras.callbacks.Callback):
    """Keras callback for epoch and batch progress."""

    @staticmethod
    def _batch_delta(logs):
        """Return the increment for one batch: `logs['size']` if present, else 1."""
        # `logs` may be `None` (it is the default of the generated callbacks),
        # so guard before calling `.get`.
        return (logs or {}).get('size', 1)

    @staticmethod
    def bar2callback(bar, pop=None, delta=(lambda logs: 1)):
        """
        Build an `on_*_end(index, logs)`-style hook which advances `bar`.

        Parameters
        ----------
        bar  : object
            Progress bar providing `set_postfix` and `update`
            (e.g. a `tqdm` instance).
        pop  : iterable of str, optional
            Keys removed from a copy of `logs` before the remainder is
            shown as the bar's postfix.
        delta  : callable, optional
            Maps `logs` to the increment passed to `bar.update`
            [default: always 1].

        Returns
        -------
        callback  : function
            `callback(_, logs=None)`; the first argument is ignored.
        """
        def callback(_, logs=None):
            n = delta(logs)
            if logs:
                if pop:
                    # Work on a copy so the caller's `logs` is left intact.
                    logs = copy(logs)
                    for key in pop:
                        logs.pop(key, None)
                bar.set_postfix(logs, refresh=False)
            bar.update(n)

        return callback

    def _bind_batch_bar(self, bar):
        """Store `bar` as `self.batch_bar` and point `on_batch_end` at it."""
        self.batch_bar = bar
        self.on_batch_end = self.bar2callback(
            bar, pop=['batch', 'size'], delta=self._batch_delta)

    def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
                 tqdm_class=tqdm_auto, **tqdm_kwargs):
        """
        Parameters
        ----------
        epochs  : int, optional
            Initial total of the epoch bar. Replaced in `on_train_begin`
            if the training parameters report a different value.
        data_size  : int, optional
            Number of training pairs.
        batch_size  : int, optional
            Number of training pairs per batch.
        verbose  : int
            0: epoch, 1: batch (transient), 2: batch. [default: 1].
            Any other truthy value raises `KeyError` in `on_epoch_begin`.
            Unless both `data_size` and `batch_size` are given, the
            fallback batch total (`self.batches`) is `None`.
        tqdm_class  : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs  : optional
            Any other arguments used for all bars.
        """
        # Initialise the base `Callback` state. The explicit two-argument
        # form of `super` is used as it is valid on every Python version.
        super(TqdmCallback, self).__init__()
        if tqdm_kwargs:
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class
        self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
        self.on_epoch_end = self.bar2callback(self.epoch_bar)
        if data_size and batch_size:
            # Ceiling division: a final partial batch still counts as one.
            self.batches = (data_size + batch_size - 1) // batch_size
        else:
            self.batches = None
        self.verbose = verbose
        if verbose == 1:
            # One transient batch bar, reset at the start of every epoch.
            self._bind_batch_bar(
                tqdm_class(total=self.batches, unit='batch', leave=False))

    def on_train_begin(self, *_, **__):
        """
        Adopt the epoch count found in `self.params` (`'epochs'`, else
        `'nb_epoch'`) when it differs from the epoch bar's current total.
        """
        params = self.params.get
        auto_total = params('epochs', params('nb_epoch', None))
        if auto_total is not None and auto_total != self.epoch_bar.total:
            self.epoch_bar.reset(total=auto_total)

    def on_epoch_begin(self, epoch, *_, **__):
        """
        Synchronise the epoch bar with `epoch` and prepare the batch bar.

        Raises
        ------
        KeyError
            If `self.verbose` is truthy but neither 1 nor 2.
        """
        ebar = self.epoch_bar
        if ebar.n < epoch:
            # The bar is behind the reported epoch: jump forward to it.
            ebar.n = ebar.last_print_n = ebar.initial = epoch
        if not self.verbose:
            return

        params = self.params.get
        # First available of 'samples', 'nb_sample', 'steps'; if that is
        # missing or falsy, fall back to the total computed in `__init__`.
        total = params('samples', params(
            'nb_sample', params('steps', None))) or self.batches
        if self.verbose == 2:
            # A fresh, persistent (`leave=True`) bar for every epoch.
            if hasattr(self, 'batch_bar'):
                self.batch_bar.close()
            self._bind_batch_bar(self.tqdm_class(
                total=total, unit='batch', leave=True,
                unit_scale=1 / (params('batch_size', 1) or 1)))
        elif self.verbose == 1:
            # Reuse the single transient bar created in `__init__`.
            self.batch_bar.unit_scale = 1 / (params('batch_size', 1) or 1)
            self.batch_bar.reset(total=total)
        else:
            raise KeyError('Unknown verbosity')

    def on_train_end(self, *_, **__):
        """Close the batch bar (if one exists) and then the epoch bar."""
        if hasattr(self, 'batch_bar'):
            self.batch_bar.close()
        self.epoch_bar.close()

    def display(self):
        """
        Displays in the current cell in Notebooks.

        Does nothing when the epoch bar has no `container` attribute.
        """
        container = getattr(self.epoch_bar, 'container', None)
        if container is None:
            return
        from .notebook import display
        display(container)
        batch_bar = getattr(self, 'batch_bar', None)
        if batch_bar is not None:
            display(batch_bar.container)

    # NOTE(review): the three hooks below presumably exist because
    # `on_batch_end` is bound per instance rather than overridden on the
    # class, so Keras has to be told that batch-level hooks are in use --
    # confirm against the Keras version being targeted.
    @staticmethod
    def _implements_train_batch_hooks():
        """Report that training batch hooks are implemented."""
        return True

    @staticmethod
    def _implements_test_batch_hooks():
        """Report that test batch hooks are implemented."""
        return True

    @staticmethod
    def _implements_predict_batch_hooks():
        """Report that predict batch hooks are implemented."""
        return True