summaryrefslogtreecommitdiffstats
path: root/tests/tests_keras.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tests_keras.py')
-rw-r--r--tests/tests_keras.py93
1 files changed, 93 insertions, 0 deletions
diff --git a/tests/tests_keras.py b/tests/tests_keras.py
new file mode 100644
index 0000000..220f946
--- /dev/null
+++ b/tests/tests_keras.py
@@ -0,0 +1,93 @@
+from __future__ import division
+
+from .tests_tqdm import importorskip, mark
+
+pytestmark = mark.slow
+
+
+@mark.filterwarnings("ignore:.*:DeprecationWarning")
+def test_keras(capsys):
+ """Test tqdm.keras.TqdmCallback"""
+ TqdmCallback = importorskip('tqdm.keras').TqdmCallback
+ np = importorskip('numpy')
+ try:
+ import keras as K
+ except ImportError:
+ K = importorskip('tensorflow.keras')
+
+ # 1D autoencoder
+ dtype = np.float32
+ model = K.models.Sequential([
+ K.layers.InputLayer((1, 1), dtype=dtype), K.layers.Conv1D(1, 1)])
+ model.compile("adam", "mse")
+ x = np.random.rand(100, 1, 1).astype(dtype)
+ batch_size = 10
+ batches = len(x) / batch_size
+ epochs = 5
+
+ # just epoch (no batch) progress
+ model.fit(
+ x,
+ x,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[
+ TqdmCallback(
+ epochs,
+ desc="training",
+ data_size=len(x),
+ batch_size=batch_size,
+ verbose=0)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
+ assert "{batches}/{batches}".format(batches=batches) not in res
+
+ # full (epoch and batch) progress
+ model.fit(
+ x,
+ x,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[
+ TqdmCallback(
+ epochs,
+ desc="training",
+ data_size=len(x),
+ batch_size=batch_size,
+ verbose=2)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
+ assert "{batches}/{batches}".format(batches=batches) in res
+
+ # auto-detect epochs and batches
+ model.fit(
+ x,
+ x,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[TqdmCallback(desc="training", verbose=2)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
+ assert "{batches}/{batches}".format(batches=batches) in res
+
+ # continue training (start from epoch != 0)
+ initial_epoch = 3
+ model.fit(
+ x,
+ x,
+ initial_epoch=initial_epoch,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[TqdmCallback(desc="training", verbose=0,
+ miniters=1, mininterval=0, maxinterval=0)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res