summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-02 08:20:07 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-02 08:20:07 +0000
commit3d2c9fd003c14a4969f383cd5eb0966b7b6a3d7b (patch)
tree96212b1fc6b9515e6bb63a5fc7869cb1da01d36d /tests
downloadtqdm-3d2c9fd003c14a4969f383cd5eb0966b7b6a3d7b.tar.xz
tqdm-3d2c9fd003c14a4969f383cd5eb0966b7b6a3d7b.zip
Adding upstream version 4.64.1.upstream/4.64.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/conftest.py41
-rw-r--r--tests/py37_asyncio.py128
-rw-r--r--tests/tests_asyncio.py11
-rw-r--r--tests/tests_concurrent.py49
-rw-r--r--tests/tests_contrib.py71
-rw-r--r--tests/tests_contrib_logging.py173
-rw-r--r--tests/tests_dask.py20
-rw-r--r--tests/tests_gui.py7
-rw-r--r--tests/tests_itertools.py26
-rw-r--r--tests/tests_keras.py93
-rw-r--r--tests/tests_main.py245
-rw-r--r--tests/tests_notebook.py7
-rw-r--r--tests/tests_pandas.py219
-rw-r--r--tests/tests_perf.py325
-rw-r--r--tests/tests_rich.py10
-rw-r--r--tests/tests_synchronisation.py224
-rw-r--r--tests/tests_tk.py7
-rw-r--r--tests/tests_tqdm.py1996
-rw-r--r--tests/tests_version.py14
20 files changed, 3666 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..6717044
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,41 @@
+"""Shared pytest config."""
+import sys
+
+from pytest import fixture
+
+from tqdm import tqdm
+
+
+@fixture(autouse=True)
+def pretest_posttest():
+ """Fixture for all tests ensuring environment cleanup"""
+ try:
+ sys.setswitchinterval(1)
+ except AttributeError:
+ sys.setcheckinterval(100) # deprecated
+
+ if getattr(tqdm, "_instances", False):
+ n = len(tqdm._instances)
+ if n:
+ tqdm._instances.clear()
+ raise EnvironmentError(
+ "{0} `tqdm` instances still in existence PRE-test".format(n))
+ yield
+ if getattr(tqdm, "_instances", False):
+ n = len(tqdm._instances)
+ if n:
+ tqdm._instances.clear()
+ raise EnvironmentError(
+ "{0} `tqdm` instances still in existence POST-test".format(n))
+
+
+if sys.version_info[0] > 2:
+ @fixture
+ def capsysbin(capsysbinary):
+ """alias for capsysbinary (py3)"""
+ return capsysbinary
+else:
+ @fixture
+ def capsysbin(capsys):
+ """alias for capsys (py2)"""
+ return capsys
diff --git a/tests/py37_asyncio.py b/tests/py37_asyncio.py
new file mode 100644
index 0000000..8bf61e7
--- /dev/null
+++ b/tests/py37_asyncio.py
@@ -0,0 +1,128 @@
+import asyncio
+from functools import partial
+from sys import platform
+from time import time
+
+from tqdm.asyncio import tarange, tqdm_asyncio
+
+from .tests_tqdm import StringIO, closing, mark
+
+tqdm = partial(tqdm_asyncio, miniters=0, mininterval=0)
+trange = partial(tarange, miniters=0, mininterval=0)
+as_completed = partial(tqdm_asyncio.as_completed, miniters=0, mininterval=0)
+gather = partial(tqdm_asyncio.gather, miniters=0, mininterval=0)
+
+
+def count(start=0, step=1):
+ i = start
+ while True:
+ new_start = yield i
+ if new_start is None:
+ i += step
+ else:
+ i = new_start
+
+
+async def acount(*args, **kwargs):
+ for i in count(*args, **kwargs):
+ yield i
+
+
+@mark.asyncio
+async def test_break():
+ """Test asyncio break"""
+ pbar = tqdm(count())
+ async for _ in pbar:
+ break
+ pbar.close()
+
+
+@mark.asyncio
+async def test_generators(capsys):
+ """Test asyncio generators"""
+ with tqdm(count(), desc="counter") as pbar:
+ async for i in pbar:
+ if i >= 8:
+ break
+ _, err = capsys.readouterr()
+ assert '9it' in err
+
+ with tqdm(acount(), desc="async_counter") as pbar:
+ async for i in pbar:
+ if i >= 8:
+ break
+ _, err = capsys.readouterr()
+ assert '9it' in err
+
+
+@mark.asyncio
+async def test_range():
+ """Test asyncio range"""
+ with closing(StringIO()) as our_file:
+ async for _ in tqdm(range(9), desc="range", file=our_file):
+ pass
+ assert '9/9' in our_file.getvalue()
+ our_file.seek(0)
+ our_file.truncate()
+
+ async for _ in trange(9, desc="trange", file=our_file):
+ pass
+ assert '9/9' in our_file.getvalue()
+
+
+@mark.asyncio
+async def test_nested():
+ """Test asyncio nested"""
+ with closing(StringIO()) as our_file:
+ async for _ in tqdm(trange(9, desc="inner", file=our_file),
+ desc="outer", file=our_file):
+ pass
+ assert 'inner: 100%' in our_file.getvalue()
+ assert 'outer: 100%' in our_file.getvalue()
+
+
+@mark.asyncio
+async def test_coroutines():
+ """Test asyncio coroutine.send"""
+ with closing(StringIO()) as our_file:
+ with tqdm(count(), file=our_file) as pbar:
+ async for i in pbar:
+ if i == 9:
+ pbar.send(-10)
+ elif i < 0:
+ assert i == -9
+ break
+ assert '10it' in our_file.getvalue()
+
+
+@mark.slow
+@mark.asyncio
+@mark.parametrize("tol", [0.2 if platform.startswith("darwin") else 0.1])
+async def test_as_completed(capsys, tol):
+ """Test asyncio as_completed"""
+ for retry in range(3):
+ t = time()
+ skew = time() - t
+ for i in as_completed([asyncio.sleep(0.01 * i) for i in range(30, 0, -1)]):
+ await i
+ t = time() - t - 2 * skew
+ try:
+ assert 0.3 * (1 - tol) < t < 0.3 * (1 + tol), t
+ _, err = capsys.readouterr()
+ assert '30/30' in err
+ except AssertionError:
+ if retry == 2:
+ raise
+
+
+async def double(i):
+ return i * 2
+
+
+@mark.asyncio
+async def test_gather(capsys):
+ """Test asyncio gather"""
+ res = await gather(*map(double, range(30)))
+ _, err = capsys.readouterr()
+ assert '30/30' in err
+ assert res == list(range(0, 30 * 2, 2))
diff --git a/tests/tests_asyncio.py b/tests/tests_asyncio.py
new file mode 100644
index 0000000..6f08926
--- /dev/null
+++ b/tests/tests_asyncio.py
@@ -0,0 +1,11 @@
+"""Tests `tqdm.asyncio` on `python>=3.7`."""
+import sys
+
+if sys.version_info[:2] > (3, 6):
+ from .py37_asyncio import * # NOQA, pylint: disable=wildcard-import
+else:
+ from .tests_tqdm import skip
+ try:
+ skip("async not supported", allow_module_level=True)
+ except TypeError:
+ pass
diff --git a/tests/tests_concurrent.py b/tests/tests_concurrent.py
new file mode 100644
index 0000000..5cd439c
--- /dev/null
+++ b/tests/tests_concurrent.py
@@ -0,0 +1,49 @@
+"""
+Tests for `tqdm.contrib.concurrent`.
+"""
+from pytest import warns
+
+from tqdm.contrib.concurrent import process_map, thread_map
+
+from .tests_tqdm import StringIO, TqdmWarning, closing, importorskip, mark, skip
+
+
+def incr(x):
+ """Dummy function"""
+ return x + 1
+
+
+def test_thread_map():
+ """Test contrib.concurrent.thread_map"""
+ with closing(StringIO()) as our_file:
+ a = range(9)
+ b = [i + 1 for i in a]
+ try:
+ assert thread_map(lambda x: x + 1, a, file=our_file) == b
+ except ImportError as err:
+ skip(str(err))
+ assert thread_map(incr, a, file=our_file) == b
+
+
+def test_process_map():
+ """Test contrib.concurrent.process_map"""
+ with closing(StringIO()) as our_file:
+ a = range(9)
+ b = [i + 1 for i in a]
+ try:
+ assert process_map(incr, a, file=our_file) == b
+ except ImportError as err:
+ skip(str(err))
+
+
+@mark.parametrize("iterables,should_warn", [([], False), (['x'], False), ([()], False),
+ (['x', ()], False), (['x' * 1001], True),
+ (['x' * 100, ('x',) * 1001], True)])
+def test_chunksize_warning(iterables, should_warn):
+ """Test contrib.concurrent.process_map chunksize warnings"""
+ patch = importorskip('unittest.mock').patch
+ with patch('tqdm.contrib.concurrent._executor_map'):
+ if should_warn:
+ warns(TqdmWarning, process_map, incr, *iterables)
+ else:
+ process_map(incr, *iterables)
diff --git a/tests/tests_contrib.py b/tests/tests_contrib.py
new file mode 100644
index 0000000..69a1cad
--- /dev/null
+++ b/tests/tests_contrib.py
@@ -0,0 +1,71 @@
+"""
+Tests for `tqdm.contrib`.
+"""
+import sys
+
+import pytest
+
+from tqdm import tqdm
+from tqdm.contrib import tenumerate, tmap, tzip
+
+from .tests_tqdm import StringIO, closing, importorskip
+
+
+def incr(x):
+ """Dummy function"""
+ return x + 1
+
+
+@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
+def test_enumerate(tqdm_kwargs):
+ """Test contrib.tenumerate"""
+ with closing(StringIO()) as our_file:
+ a = range(9)
+ assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a))
+ assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list(
+ enumerate(a, 42)
+ )
+ with closing(StringIO()) as our_file:
+ _ = list(tenumerate(iter(a), file=our_file, **tqdm_kwargs))
+ assert "100%" not in our_file.getvalue()
+ with closing(StringIO()) as our_file:
+ _ = list(tenumerate(iter(a), file=our_file, total=len(a), **tqdm_kwargs))
+ assert "100%" in our_file.getvalue()
+
+
+def test_enumerate_numpy():
+ """Test contrib.tenumerate(numpy.ndarray)"""
+ np = importorskip("numpy")
+ with closing(StringIO()) as our_file:
+ a = np.random.random((42, 7))
+ assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a))
+
+
+@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
+def test_zip(tqdm_kwargs):
+ """Test contrib.tzip"""
+ with closing(StringIO()) as our_file:
+ a = range(9)
+ b = [i + 1 for i in a]
+ if sys.version_info[:1] < (3,):
+ assert tzip(a, b, file=our_file, **tqdm_kwargs) == zip(a, b)
+ else:
+ gen = tzip(a, b, file=our_file, **tqdm_kwargs)
+ assert gen != list(zip(a, b))
+ assert list(gen) == list(zip(a, b))
+
+
+@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
+def test_map(tqdm_kwargs):
+ """Test contrib.tmap"""
+ with closing(StringIO()) as our_file:
+ a = range(9)
+ b = [i + 1 for i in a]
+ if sys.version_info[:1] < (3,):
+ assert tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) == map(
+ incr, a
+ )
+ else:
+ gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs)
+ assert gen != b
+ assert list(gen) == b
diff --git a/tests/tests_contrib_logging.py b/tests/tests_contrib_logging.py
new file mode 100644
index 0000000..6f675dd
--- /dev/null
+++ b/tests/tests_contrib_logging.py
@@ -0,0 +1,173 @@
+# pylint: disable=missing-module-docstring, missing-class-docstring
+# pylint: disable=missing-function-docstring, no-self-use
+from __future__ import absolute_import
+
+import logging
+import logging.handlers
+import sys
+from io import StringIO
+
+import pytest
+
+from tqdm import tqdm
+from tqdm.contrib.logging import _get_first_found_console_logging_handler
+from tqdm.contrib.logging import _TqdmLoggingHandler as TqdmLoggingHandler
+from tqdm.contrib.logging import logging_redirect_tqdm, tqdm_logging_redirect
+
+from .tests_tqdm import importorskip
+
+LOGGER = logging.getLogger(__name__)
+
+TEST_LOGGING_FORMATTER = logging.Formatter()
+
+
+class CustomTqdm(tqdm):
+ messages = []
+
+ @classmethod
+ def write(cls, s, **__): # pylint: disable=arguments-differ
+ CustomTqdm.messages.append(s)
+
+
+class ErrorRaisingTqdm(tqdm):
+ exception_class = RuntimeError
+
+ @classmethod
+ def write(cls, s, **__): # pylint: disable=arguments-differ
+ raise ErrorRaisingTqdm.exception_class('fail fast')
+
+
+class TestTqdmLoggingHandler:
+ def test_should_call_tqdm_write(self):
+ CustomTqdm.messages = []
+ logger = logging.Logger('test')
+ logger.handlers = [TqdmLoggingHandler(CustomTqdm)]
+ logger.info('test')
+ assert CustomTqdm.messages == ['test']
+
+ def test_should_call_handle_error_if_exception_was_thrown(self):
+ patch = importorskip('unittest.mock').patch
+ logger = logging.Logger('test')
+ ErrorRaisingTqdm.exception_class = RuntimeError
+ handler = TqdmLoggingHandler(ErrorRaisingTqdm)
+ logger.handlers = [handler]
+ with patch.object(handler, 'handleError') as mock:
+ logger.info('test')
+ assert mock.called
+
+ @pytest.mark.parametrize('exception_class', [
+ KeyboardInterrupt,
+ SystemExit
+ ])
+ def test_should_not_swallow_certain_exceptions(self, exception_class):
+ logger = logging.Logger('test')
+ ErrorRaisingTqdm.exception_class = exception_class
+ handler = TqdmLoggingHandler(ErrorRaisingTqdm)
+ logger.handlers = [handler]
+ with pytest.raises(exception_class):
+ logger.info('test')
+
+
+class TestGetFirstFoundConsoleLoggingHandler:
+ def test_should_return_none_for_no_handlers(self):
+ assert _get_first_found_console_logging_handler([]) is None
+
+ def test_should_return_none_without_stream_handler(self):
+ handler = logging.handlers.MemoryHandler(capacity=1)
+ assert _get_first_found_console_logging_handler([handler]) is None
+
+ def test_should_return_none_for_stream_handler_not_stdout_or_stderr(self):
+ handler = logging.StreamHandler(StringIO())
+ assert _get_first_found_console_logging_handler([handler]) is None
+
+ def test_should_return_stream_handler_if_stream_is_stdout(self):
+ handler = logging.StreamHandler(sys.stdout)
+ assert _get_first_found_console_logging_handler([handler]) == handler
+
+ def test_should_return_stream_handler_if_stream_is_stderr(self):
+ handler = logging.StreamHandler(sys.stderr)
+ assert _get_first_found_console_logging_handler([handler]) == handler
+
+
+class TestRedirectLoggingToTqdm:
+ def test_should_add_and_remove_tqdm_handler(self):
+ logger = logging.Logger('test')
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert len(logger.handlers) == 1
+ assert isinstance(logger.handlers[0], TqdmLoggingHandler)
+ assert not logger.handlers
+
+ def test_should_remove_and_restore_console_handlers(self):
+ logger = logging.Logger('test')
+ stderr_console_handler = logging.StreamHandler(sys.stderr)
+ stdout_console_handler = logging.StreamHandler(sys.stderr)
+ logger.handlers = [stderr_console_handler, stdout_console_handler]
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert len(logger.handlers) == 1
+ assert isinstance(logger.handlers[0], TqdmLoggingHandler)
+ assert logger.handlers == [stderr_console_handler, stdout_console_handler]
+
+ def test_should_inherit_console_logger_formatter(self):
+ logger = logging.Logger('test')
+ formatter = logging.Formatter('custom: %(message)s')
+ console_handler = logging.StreamHandler(sys.stderr)
+ console_handler.setFormatter(formatter)
+ logger.handlers = [console_handler]
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert logger.handlers[0].formatter == formatter
+
+ def test_should_not_remove_stream_handlers_not_for_stdout_or_stderr(self):
+ logger = logging.Logger('test')
+ stream_handler = logging.StreamHandler(StringIO())
+ logger.addHandler(stream_handler)
+ with logging_redirect_tqdm(loggers=[logger]):
+ assert len(logger.handlers) == 2
+ assert logger.handlers[0] == stream_handler
+ assert isinstance(logger.handlers[1], TqdmLoggingHandler)
+ assert logger.handlers == [stream_handler]
+
+
+class TestTqdmWithLoggingRedirect:
+ def test_should_add_and_remove_handler_from_root_logger_by_default(self):
+ original_handlers = list(logging.root.handlers)
+ with tqdm_logging_redirect(total=1) as pbar:
+ assert isinstance(logging.root.handlers[-1], TqdmLoggingHandler)
+ LOGGER.info('test')
+ pbar.update(1)
+ assert logging.root.handlers == original_handlers
+
+ def test_should_add_and_remove_handler_from_custom_logger(self):
+ logger = logging.Logger('test')
+ with tqdm_logging_redirect(total=1, loggers=[logger]) as pbar:
+ assert len(logger.handlers) == 1
+ assert isinstance(logger.handlers[0], TqdmLoggingHandler)
+ logger.info('test')
+ pbar.update(1)
+ assert not logger.handlers
+
+ def test_should_not_fail_with_logger_without_console_handler(self):
+ logger = logging.Logger('test')
+ logger.handlers = []
+ with tqdm_logging_redirect(total=1, loggers=[logger]):
+ logger.info('test')
+ assert not logger.handlers
+
+ def test_should_format_message(self):
+ logger = logging.Logger('test')
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setFormatter(logging.Formatter(
+ r'prefix:%(message)s'
+ ))
+ logger.handlers = [console_handler]
+ CustomTqdm.messages = []
+ with tqdm_logging_redirect(loggers=[logger], tqdm_class=CustomTqdm):
+ logger.info('test')
+ assert CustomTqdm.messages == ['prefix:test']
+
+ def test_use_root_logger_by_default_and_write_to_custom_tqdm(self):
+ logger = logging.root
+ CustomTqdm.messages = []
+ with tqdm_logging_redirect(total=1, tqdm_class=CustomTqdm) as pbar:
+ assert isinstance(pbar, CustomTqdm)
+ logger.info('test')
+ assert CustomTqdm.messages == ['test']
diff --git a/tests/tests_dask.py b/tests/tests_dask.py
new file mode 100644
index 0000000..8bf4b64
--- /dev/null
+++ b/tests/tests_dask.py
@@ -0,0 +1,20 @@
+from __future__ import division
+
+from time import sleep
+
+from .tests_tqdm import importorskip, mark
+
+pytestmark = mark.slow
+
+
+def test_dask(capsys):
+ """Test tqdm.dask.TqdmCallback"""
+ ProgressBar = importorskip('tqdm.dask').TqdmCallback
+ dask = importorskip('dask')
+
+ schedule = [dask.delayed(sleep)(i / 10) for i in range(5)]
+ with ProgressBar(desc="computing"):
+ dask.compute(schedule)
+ _, err = capsys.readouterr()
+ assert "computing: " in err
+ assert '5/5' in err
diff --git a/tests/tests_gui.py b/tests/tests_gui.py
new file mode 100644
index 0000000..dddd918
--- /dev/null
+++ b/tests/tests_gui.py
@@ -0,0 +1,7 @@
+"""Test `tqdm.gui`."""
+from .tests_tqdm import importorskip
+
+
+def test_gui_import():
+ """Test `tqdm.gui` import"""
+ importorskip('tqdm.gui')
diff --git a/tests/tests_itertools.py b/tests/tests_itertools.py
new file mode 100644
index 0000000..bfb6eb2
--- /dev/null
+++ b/tests/tests_itertools.py
@@ -0,0 +1,26 @@
+"""
+Tests for `tqdm.contrib.itertools`.
+"""
+import itertools as it
+
+from tqdm.contrib.itertools import product
+
+from .tests_tqdm import StringIO, closing
+
+
+class NoLenIter(object):
+ def __init__(self, iterable):
+ self._it = iterable
+
+ def __iter__(self):
+ for i in self._it:
+ yield i
+
+
+def test_product():
+ """Test contrib.itertools.product"""
+ with closing(StringIO()) as our_file:
+ a = range(9)
+ assert list(product(a, a[::-1], file=our_file)) == list(it.product(a, a[::-1]))
+
+ assert list(product(a, NoLenIter(a), file=our_file)) == list(it.product(a, NoLenIter(a)))
diff --git a/tests/tests_keras.py b/tests/tests_keras.py
new file mode 100644
index 0000000..220f946
--- /dev/null
+++ b/tests/tests_keras.py
@@ -0,0 +1,93 @@
+from __future__ import division
+
+from .tests_tqdm import importorskip, mark
+
+pytestmark = mark.slow
+
+
+@mark.filterwarnings("ignore:.*:DeprecationWarning")
+def test_keras(capsys):
+ """Test tqdm.keras.TqdmCallback"""
+ TqdmCallback = importorskip('tqdm.keras').TqdmCallback
+ np = importorskip('numpy')
+ try:
+ import keras as K
+ except ImportError:
+ K = importorskip('tensorflow.keras')
+
+ # 1D autoencoder
+ dtype = np.float32
+ model = K.models.Sequential([
+ K.layers.InputLayer((1, 1), dtype=dtype), K.layers.Conv1D(1, 1)])
+ model.compile("adam", "mse")
+ x = np.random.rand(100, 1, 1).astype(dtype)
+ batch_size = 10
+ batches = len(x) / batch_size
+ epochs = 5
+
+ # just epoch (no batch) progress
+ model.fit(
+ x,
+ x,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[
+ TqdmCallback(
+ epochs,
+ desc="training",
+ data_size=len(x),
+ batch_size=batch_size,
+ verbose=0)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
+ assert "{batches}/{batches}".format(batches=batches) not in res
+
+ # full (epoch and batch) progress
+ model.fit(
+ x,
+ x,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[
+ TqdmCallback(
+ epochs,
+ desc="training",
+ data_size=len(x),
+ batch_size=batch_size,
+ verbose=2)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
+ assert "{batches}/{batches}".format(batches=batches) in res
+
+ # auto-detect epochs and batches
+ model.fit(
+ x,
+ x,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[TqdmCallback(desc="training", verbose=2)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
+ assert "{batches}/{batches}".format(batches=batches) in res
+
+ # continue training (start from epoch != 0)
+ initial_epoch = 3
+ model.fit(
+ x,
+ x,
+ initial_epoch=initial_epoch,
+ epochs=epochs,
+ batch_size=batch_size,
+ verbose=False,
+ callbacks=[TqdmCallback(desc="training", verbose=0,
+ miniters=1, mininterval=0, maxinterval=0)])
+ _, res = capsys.readouterr()
+ assert "training: " in res
+ assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
+ assert "{epochs}/{epochs}".format(epochs=epochs) in res
diff --git a/tests/tests_main.py b/tests/tests_main.py
new file mode 100644
index 0000000..0523cc7
--- /dev/null
+++ b/tests/tests_main.py
@@ -0,0 +1,245 @@
+"""Test CLI usage."""
+import logging
+import subprocess # nosec
+import sys
+from functools import wraps
+from os import linesep
+
+from tqdm.cli import TqdmKeyError, TqdmTypeError, main
+from tqdm.utils import IS_WIN
+
+from .tests_tqdm import BytesIO, _range, closing, mark, raises
+
+
+def restore_sys(func):
+ """Decorates `func(capsysbin)` to save & restore `sys.(stdin|argv)`."""
+ @wraps(func)
+ def inner(capsysbin):
+ """function requiring capsysbin which may alter `sys.(stdin|argv)`"""
+ _SYS = sys.stdin, sys.argv
+ try:
+ res = func(capsysbin)
+ finally:
+ sys.stdin, sys.argv = _SYS
+ return res
+
+ return inner
+
+
+def norm(bytestr):
+ """Normalise line endings."""
+ return bytestr if linesep == "\n" else bytestr.replace(linesep.encode(), b"\n")
+
+
+@mark.slow
+def test_pipes():
+ """Test command line pipes"""
+ ls_out = subprocess.check_output(['ls']) # nosec
+ ls = subprocess.Popen(['ls'], stdout=subprocess.PIPE) # nosec
+ res = subprocess.Popen( # nosec
+ [sys.executable, '-c', 'from tqdm.cli import main; main()'],
+ stdin=ls.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, err = res.communicate()
+ assert ls.poll() == 0
+
+ # actual test:
+ assert norm(ls_out) == norm(out)
+ assert b"it/s" in err
+ assert b"Error" not in err
+
+
+if sys.version_info[:2] >= (3, 8):
+ test_pipes = mark.filterwarnings("ignore:unclosed file:ResourceWarning")(
+ test_pipes)
+
+
+def test_main_import():
+ """Test main CLI import"""
+ N = 123
+ _SYS = sys.stdin, sys.argv
+ # test direct import
+ sys.stdin = [str(i).encode() for i in _range(N)]
+ sys.argv = ['', '--desc', 'Test CLI import',
+ '--ascii', 'True', '--unit_scale', 'True']
+ try:
+ import tqdm.__main__ # NOQA, pylint: disable=unused-variable
+ finally:
+ sys.stdin, sys.argv = _SYS
+
+
+@restore_sys
+def test_main_bytes(capsysbin):
+ """Test CLI --bytes"""
+ N = 123
+
+ # test --delim
+ IN_DATA = '\0'.join(map(str, _range(N))).encode()
+ with closing(BytesIO()) as sys.stdin:
+ sys.stdin.write(IN_DATA)
+ # sys.stdin.write(b'\xff') # TODO
+ sys.stdin.seek(0)
+ main(sys.stderr, ['--desc', 'Test CLI delim', '--ascii', 'True',
+ '--delim', r'\0', '--buf_size', '64'])
+ out, err = capsysbin.readouterr()
+ assert out == IN_DATA
+ assert str(N) + "it" in err.decode("U8")
+
+ # test --bytes
+ IN_DATA = IN_DATA.replace(b'\0', b'\n')
+ with closing(BytesIO()) as sys.stdin:
+ sys.stdin.write(IN_DATA)
+ sys.stdin.seek(0)
+ main(sys.stderr, ['--ascii', '--bytes=True', '--unit_scale', 'False'])
+ out, err = capsysbin.readouterr()
+ assert out == IN_DATA
+ assert str(len(IN_DATA)) + "B" in err.decode("U8")
+
+
+@mark.skipif(sys.version_info[0] == 2, reason="no caplog on py2")
+def test_main_log(capsysbin, caplog):
+ """Test CLI --log"""
+ _SYS = sys.stdin, sys.argv
+ N = 123
+ sys.stdin = [(str(i) + '\n').encode() for i in _range(N)]
+ IN_DATA = b''.join(sys.stdin)
+ try:
+ with caplog.at_level(logging.INFO):
+ main(sys.stderr, ['--log', 'INFO'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA and b"123/123" in err
+ assert not caplog.record_tuples
+ with caplog.at_level(logging.DEBUG):
+ main(sys.stderr, ['--log', 'DEBUG'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA and b"123/123" in err
+ assert caplog.record_tuples
+ finally:
+ sys.stdin, sys.argv = _SYS
+
+
+@restore_sys
+def test_main(capsysbin):
+ """Test misc CLI options"""
+ N = 123
+ sys.stdin = [(str(i) + '\n').encode() for i in _range(N)]
+ IN_DATA = b''.join(sys.stdin)
+
+ # test --tee
+ main(sys.stderr, ['--mininterval', '0', '--miniters', '1'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA and b"123/123" in err
+ assert N <= len(err.split(b"\r")) < N + 5
+
+ len_err = len(err)
+ main(sys.stderr, ['--tee', '--mininterval', '0', '--miniters', '1'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA and b"123/123" in err
+ # spaces to clear intermediate lines could increase length
+ assert len_err + len(norm(out)) <= len(err)
+
+ # test --null
+ main(sys.stderr, ['--null'])
+ out, err = capsysbin.readouterr()
+ assert not out and b"123/123" in err
+
+ # test integer --update
+ main(sys.stderr, ['--update'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+ assert (str(N // 2 * N) + "it").encode() in err, "expected arithmetic sum formula"
+
+ # test integer --update_to
+ main(sys.stderr, ['--update-to'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+ assert (str(N - 1) + "it").encode() in err
+ assert (str(N) + "it").encode() not in err
+
+ with closing(BytesIO()) as sys.stdin:
+ sys.stdin.write(IN_DATA.replace(b'\n', b'D'))
+
+ # test integer --update --delim
+ sys.stdin.seek(0)
+ main(sys.stderr, ['--update', '--delim', 'D'])
+ out, err = capsysbin.readouterr()
+ assert out == IN_DATA.replace(b'\n', b'D')
+ assert (str(N // 2 * N) + "it").encode() in err, "expected arithmetic sum"
+
+ # test integer --update_to --delim
+ sys.stdin.seek(0)
+ main(sys.stderr, ['--update-to', '--delim', 'D'])
+ out, err = capsysbin.readouterr()
+ assert out == IN_DATA.replace(b'\n', b'D')
+ assert (str(N - 1) + "it").encode() in err
+ assert (str(N) + "it").encode() not in err
+
+ # test float --update_to
+ sys.stdin = [(str(i / 2.0) + '\n').encode() for i in _range(N)]
+ IN_DATA = b''.join(sys.stdin)
+ main(sys.stderr, ['--update-to'])
+ out, err = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+ assert (str((N - 1) / 2.0) + "it").encode() in err
+ assert (str(N / 2.0) + "it").encode() not in err
+
+
+@mark.slow
+@mark.skipif(IS_WIN, reason="no manpages on windows")
+def test_manpath(tmp_path):
+ """Test CLI --manpath"""
+ man = tmp_path / "tqdm.1"
+ assert not man.exists()
+ with raises(SystemExit):
+ main(argv=['--manpath', str(tmp_path)])
+ assert man.is_file()
+
+
+@mark.slow
+@mark.skipif(IS_WIN, reason="no completion on windows")
+def test_comppath(tmp_path):
+ """Test CLI --comppath"""
+ man = tmp_path / "tqdm_completion.sh"
+ assert not man.exists()
+ with raises(SystemExit):
+ main(argv=['--comppath', str(tmp_path)])
+ assert man.is_file()
+
+ # check most important options appear
+ script = man.read_text()
+ opts = {'--help', '--desc', '--total', '--leave', '--ncols', '--ascii',
+ '--dynamic_ncols', '--position', '--bytes', '--nrows', '--delim',
+ '--manpath', '--comppath'}
+ assert all(args in script for args in opts)
+
+
+@restore_sys
+def test_exceptions(capsysbin):
+ """Test CLI Exceptions"""
+ N = 123
+ sys.stdin = [str(i) + '\n' for i in _range(N)]
+ IN_DATA = ''.join(sys.stdin).encode()
+
+ with raises(TqdmKeyError, match="bad_arg_u_ment"):
+ main(sys.stderr, argv=['-ascii', '-unit_scale', '--bad_arg_u_ment', 'foo'])
+ out, _ = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+
+ with raises(TqdmTypeError, match="invalid_bool_value"):
+ main(sys.stderr, argv=['-ascii', '-unit_scale', 'invalid_bool_value'])
+ out, _ = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+
+ with raises(TqdmTypeError, match="invalid_int_value"):
+ main(sys.stderr, argv=['-ascii', '--total', 'invalid_int_value'])
+ out, _ = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+
+ with raises(TqdmKeyError, match="Can only have one of --"):
+ main(sys.stderr, argv=['--update', '--update_to'])
+ out, _ = capsysbin.readouterr()
+ assert norm(out) == IN_DATA
+
+ # test SystemExits
+ for i in ('-h', '--help', '-v', '--version'):
+ with raises(SystemExit):
+ main(argv=[i])
diff --git a/tests/tests_notebook.py b/tests/tests_notebook.py
new file mode 100644
index 0000000..004d7e5
--- /dev/null
+++ b/tests/tests_notebook.py
@@ -0,0 +1,7 @@
+from tqdm.notebook import tqdm as tqdm_notebook
+
+
+def test_notebook_disabled_description():
+ """Test that set_description works for disabled tqdm_notebook"""
+ with tqdm_notebook(1, disable=True) as t:
+ t.set_description("description")
diff --git a/tests/tests_pandas.py b/tests/tests_pandas.py
new file mode 100644
index 0000000..334a97c
--- /dev/null
+++ b/tests/tests_pandas.py
@@ -0,0 +1,219 @@
+from tqdm import tqdm
+
+from .tests_tqdm import StringIO, closing, importorskip, mark, skip
+
+pytestmark = mark.slow
+
+random = importorskip('numpy.random')
+rand = random.rand
+randint = random.randint
+pd = importorskip('pandas')
+
+
+def test_pandas_setup():
+ """Test tqdm.pandas()"""
+ with closing(StringIO()) as our_file:
+ tqdm.pandas(file=our_file, leave=True, ascii=True, total=123)
+ series = pd.Series(randint(0, 50, (100,)))
+ series.progress_apply(lambda x: x + 10)
+ res = our_file.getvalue()
+ assert '100/123' in res
+
+
+def test_pandas_rolling_expanding():
+ """Test pandas.(Series|DataFrame).(rolling|expanding)"""
+ with closing(StringIO()) as our_file:
+ tqdm.pandas(file=our_file, leave=True, ascii=True)
+
+ series = pd.Series(randint(0, 50, (123,)))
+ res1 = series.rolling(10).progress_apply(lambda x: 1, raw=True)
+ res2 = series.rolling(10).apply(lambda x: 1, raw=True)
+ assert res1.equals(res2)
+
+ res3 = series.expanding(10).progress_apply(lambda x: 2, raw=True)
+ res4 = series.expanding(10).apply(lambda x: 2, raw=True)
+ assert res3.equals(res4)
+
+ expects = ['114it'] # 123-10+1
+ for exres in expects:
+ our_file.seek(0)
+ if our_file.getvalue().count(exres) < 2:
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
+ exres + " at least twice.", our_file.read()))
+
+
+def test_pandas_series():
+ """Test pandas.Series.progress_apply and .progress_map"""
+ with closing(StringIO()) as our_file:
+ tqdm.pandas(file=our_file, leave=True, ascii=True)
+
+ series = pd.Series(randint(0, 50, (123,)))
+ res1 = series.progress_apply(lambda x: x + 10)
+ res2 = series.apply(lambda x: x + 10)
+ assert res1.equals(res2)
+
+ res3 = series.progress_map(lambda x: x + 10)
+ res4 = series.map(lambda x: x + 10)
+ assert res3.equals(res4)
+
+ expects = ['100%', '123/123']
+ for exres in expects:
+ our_file.seek(0)
+ if our_file.getvalue().count(exres) < 2:
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
+ exres + " at least twice.", our_file.read()))
+
+
+def test_pandas_data_frame():
+ """Test pandas.DataFrame.progress_apply and .progress_applymap"""
+ with closing(StringIO()) as our_file:
+ tqdm.pandas(file=our_file, leave=True, ascii=True)
+ df = pd.DataFrame(randint(0, 50, (100, 200)))
+
+ def task_func(x):
+ return x + 1
+
+ # applymap
+ res1 = df.progress_applymap(task_func)
+ res2 = df.applymap(task_func)
+ assert res1.equals(res2)
+
+ # apply unhashable
+ res1 = []
+ df.progress_apply(res1.extend)
+ assert len(res1) == df.size
+
+ # apply
+ for axis in [0, 1, 'index', 'columns']:
+ res3 = df.progress_apply(task_func, axis=axis)
+ res4 = df.apply(task_func, axis=axis)
+ assert res3.equals(res4)
+
+ our_file.seek(0)
+ if our_file.read().count('100%') < 3:
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
+ '100% at least three times', our_file.read()))
+
+ # apply_map, apply axis=0, apply axis=1
+ expects = ['20000/20000', '200/200', '100/100']
+ for exres in expects:
+ our_file.seek(0)
+ if our_file.getvalue().count(exres) < 1:
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:\n {1}\n".format(
+ exres + " at least once.", our_file.read()))
+
+
+def test_pandas_groupby_apply():
+ """Test pandas.DataFrame.groupby(...).progress_apply"""
+ with closing(StringIO()) as our_file:
+ tqdm.pandas(file=our_file, leave=False, ascii=True)
+
+ df = pd.DataFrame(randint(0, 50, (500, 3)))
+ df.groupby(0).progress_apply(lambda x: None)
+
+ dfs = pd.DataFrame(randint(0, 50, (500, 3)), columns=list('abc'))
+ dfs.groupby(['a']).progress_apply(lambda x: None)
+
+ df2 = df = pd.DataFrame({'a': randint(1, 8, 10000), 'b': rand(10000)})
+ res1 = df2.groupby("a").apply(max)
+ res2 = df2.groupby("a").progress_apply(max)
+ assert res1.equals(res2)
+
+ our_file.seek(0)
+
+ # don't expect final output since no `leave` and
+ # high dynamic `miniters`
+ nexres = '100%|##########|'
+ if nexres in our_file.read():
+ our_file.seek(0)
+ raise AssertionError("\nDid not expect:\n{0}\nIn:{1}\n".format(
+ nexres, our_file.read()))
+
+ with closing(StringIO()) as our_file:
+ tqdm.pandas(file=our_file, leave=True, ascii=True)
+
+ dfs = pd.DataFrame(randint(0, 50, (500, 3)), columns=list('abc'))
+ dfs.loc[0] = [2, 1, 1]
+ dfs['d'] = 100
+
+ expects = ['500/500', '1/1', '4/4', '2/2']
+ dfs.groupby(dfs.index).progress_apply(lambda x: None)
+ dfs.groupby('d').progress_apply(lambda x: None)
+ dfs.groupby(dfs.columns, axis=1).progress_apply(lambda x: None)
+ dfs.groupby([2, 2, 1, 1], axis=1).progress_apply(lambda x: None)
+
+ our_file.seek(0)
+ if our_file.read().count('100%') < 4:
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
+ '100% at least four times', our_file.read()))
+
+ for exres in expects:
+ our_file.seek(0)
+ if our_file.getvalue().count(exres) < 1:
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:\n {1}\n".format(
+ exres + " at least once.", our_file.read()))
+
+
+def test_pandas_leave():
+ """Test pandas with `leave=True`"""
+ with closing(StringIO()) as our_file:
+ df = pd.DataFrame(randint(0, 100, (1000, 6)))
+ tqdm.pandas(file=our_file, leave=True, ascii=True)
+ df.groupby(0).progress_apply(lambda x: None)
+
+ our_file.seek(0)
+
+ exres = '100%|##########| 100/100'
+ if exres not in our_file.read():
+ our_file.seek(0)
+ raise AssertionError("\nExpected:\n{0}\nIn:{1}\n".format(
+ exres, our_file.read()))
+
+
+def test_pandas_apply_args_deprecation():
+ """Test warning info in
+ `pandas.Dataframe(Series).progress_apply(func, *args)`"""
+ try:
+ from tqdm import tqdm_pandas
+ except ImportError as err:
+ skip(str(err))
+
+ with closing(StringIO()) as our_file:
+ tqdm_pandas(tqdm(file=our_file, leave=False, ascii=True, ncols=20))
+ df = pd.DataFrame(randint(0, 50, (500, 3)))
+ df.progress_apply(lambda x: None, 1) # 1 shall cause a warning
+ # Check deprecation message
+ res = our_file.getvalue()
+ assert all(i in res for i in (
+ "TqdmDeprecationWarning", "not supported",
+ "keyword arguments instead"))
+
+
+def test_pandas_deprecation():
+ """Test bar object instance as argument deprecation"""
+ try:
+ from tqdm import tqdm_pandas
+ except ImportError as err:
+ skip(str(err))
+
+ with closing(StringIO()) as our_file:
+ tqdm_pandas(tqdm(file=our_file, leave=False, ascii=True, ncols=20))
+ df = pd.DataFrame(randint(0, 50, (500, 3)))
+ df.groupby(0).progress_apply(lambda x: None)
+ # Check deprecation message
+ assert "TqdmDeprecationWarning" in our_file.getvalue()
+ assert "instead of `tqdm_pandas(tqdm(...))`" in our_file.getvalue()
+
+ with closing(StringIO()) as our_file:
+ tqdm_pandas(tqdm, file=our_file, leave=False, ascii=True, ncols=20)
+ df = pd.DataFrame(randint(0, 50, (500, 3)))
+ df.groupby(0).progress_apply(lambda x: None)
+ # Check deprecation message
+ assert "TqdmDeprecationWarning" in our_file.getvalue()
+ assert "instead of `tqdm_pandas(tqdm, ...)`" in our_file.getvalue()
diff --git a/tests/tests_perf.py b/tests/tests_perf.py
new file mode 100644
index 0000000..552a169
--- /dev/null
+++ b/tests/tests_perf.py
@@ -0,0 +1,325 @@
+from __future__ import division, print_function
+
+import sys
+from contextlib import contextmanager
+from functools import wraps
+from time import sleep, time
+
+# Use relative/cpu timer to have reliable timings when there is a sudden load
+try:
+ from time import process_time
+except ImportError:
+ from time import clock
+ process_time = clock
+
+from tqdm import tqdm, trange
+
+from .tests_tqdm import _range, importorskip, mark, patch_lock, skip
+
+pytestmark = mark.slow
+
+
+def cpu_sleep(t):
+ """Sleep the given amount of cpu time"""
+ start = process_time()
+ while (process_time() - start) < t:
+ pass
+
+
+def checkCpuTime(sleeptime=0.2):
+ """Check if cpu time works correctly"""
+ if checkCpuTime.passed:
+ return True
+ # First test that sleeping does not consume cputime
+ start1 = process_time()
+ sleep(sleeptime)
+ t1 = process_time() - start1
+
+ # secondly check by comparing to cpusleep (where we actually do something)
+ start2 = process_time()
+ cpu_sleep(sleeptime)
+ t2 = process_time() - start2
+
+ if abs(t1) < 0.0001 and t1 < t2 / 10:
+ checkCpuTime.passed = True
+ return True
+ skip("cpu time not reliable on this machine")
+
+
+checkCpuTime.passed = False
+
+
+@contextmanager
+def relative_timer():
+ """yields a context timer function which stops ticking on exit"""
+ start = process_time()
+
+ def elapser():
+ return process_time() - start
+
+ yield lambda: elapser()
+ spent = elapser()
+
+ def elapser(): # NOQA
+ return spent
+
+
+def retry_on_except(n=3, check_cpu_time=True):
+ """decroator for retrying `n` times before raising Exceptions"""
+ def wrapper(func):
+ """actual decorator"""
+ @wraps(func)
+ def test_inner(*args, **kwargs):
+ """may skip if `check_cpu_time` fails"""
+ for i in range(1, n + 1):
+ try:
+ if check_cpu_time:
+ checkCpuTime()
+ func(*args, **kwargs)
+ except Exception:
+ if i >= n:
+ raise
+ else:
+ return
+ return test_inner
+ return wrapper
+
+
+def simple_progress(iterable=None, total=None, file=sys.stdout, desc='',
+ leave=False, miniters=1, mininterval=0.1, width=60):
+ """Simple progress bar reproducing tqdm's major features"""
+ n = [0] # use a closure
+ start_t = [time()]
+ last_n = [0]
+ last_t = [0]
+ if iterable is not None:
+ total = len(iterable)
+
+ def format_interval(t):
+ mins, s = divmod(int(t), 60)
+ h, m = divmod(mins, 60)
+ if h:
+ return '{0:d}:{1:02d}:{2:02d}'.format(h, m, s)
+ else:
+ return '{0:02d}:{1:02d}'.format(m, s)
+
+ def update_and_print(i=1):
+ n[0] += i
+ if (n[0] - last_n[0]) >= miniters:
+ last_n[0] = n[0]
+
+ if (time() - last_t[0]) >= mininterval:
+ last_t[0] = time() # last_t[0] == current time
+
+ spent = last_t[0] - start_t[0]
+ spent_fmt = format_interval(spent)
+ rate = n[0] / spent if spent > 0 else 0
+ rate_fmt = "%.2fs/it" % (1.0 / rate) if 0.0 < rate < 1.0 else "%.2fit/s" % rate
+
+ frac = n[0] / total
+ percentage = int(frac * 100)
+ eta = (total - n[0]) / rate if rate > 0 else 0
+ eta_fmt = format_interval(eta)
+
+ # full_bar = "#" * int(frac * width)
+ barfill = " " * int((1.0 - frac) * width)
+ bar_length, frac_bar_length = divmod(int(frac * width * 10), 10)
+ full_bar = '#' * bar_length
+ frac_bar = chr(48 + frac_bar_length) if frac_bar_length else ' '
+
+ file.write("\r%s %i%%|%s%s%s| %i/%i [%s<%s, %s]" %
+ (desc, percentage, full_bar, frac_bar, barfill, n[0],
+ total, spent_fmt, eta_fmt, rate_fmt))
+
+ if n[0] == total and leave:
+ file.write("\n")
+ file.flush()
+
+ def update_and_yield():
+ for elt in iterable:
+ yield elt
+ update_and_print()
+
+ update_and_print(0)
+ if iterable is not None:
+ return update_and_yield()
+ else:
+ return update_and_print
+
+
+def assert_performance(thresh, name_left, time_left, name_right, time_right):
+ """raises if time_left > thresh * time_right"""
+ if time_left > thresh * time_right:
+ raise ValueError(
+ ('{name[0]}: {time[0]:f}, '
+ '{name[1]}: {time[1]:f}, '
+ 'ratio {ratio:f} > {thresh:f}').format(
+ name=(name_left, name_right),
+ time=(time_left, time_right),
+ ratio=time_left / time_right, thresh=thresh))
+
+
+@retry_on_except()
+def test_iter_basic_overhead():
+ """Test overhead of iteration based tqdm"""
+ total = int(1e6)
+
+ a = 0
+ with trange(total) as t:
+ with relative_timer() as time_tqdm:
+ for i in t:
+ a += i
+ assert a == (total ** 2 - total) / 2.0
+
+ a = 0
+ with relative_timer() as time_bench:
+ for i in _range(total):
+ a += i
+ sys.stdout.write(str(a))
+
+ assert_performance(3, 'trange', time_tqdm(), 'range', time_bench())
+
+
+@retry_on_except()
+def test_manual_basic_overhead():
+ """Test overhead of manual tqdm"""
+ total = int(1e6)
+
+ with tqdm(total=total * 10, leave=True) as t:
+ a = 0
+ with relative_timer() as time_tqdm:
+ for i in _range(total):
+ a += i
+ t.update(10)
+
+ a = 0
+ with relative_timer() as time_bench:
+ for i in _range(total):
+ a += i
+ sys.stdout.write(str(a))
+
+ assert_performance(5, 'tqdm', time_tqdm(), 'range', time_bench())
+
+
+def worker(total, blocking=True):
+ def incr_bar(x):
+ for _ in trange(total, lock_args=None if blocking else (False,),
+ miniters=1, mininterval=0, maxinterval=0):
+ pass
+ return x + 1
+ return incr_bar
+
+
+@retry_on_except()
+@patch_lock(thread=True)
+def test_lock_args():
+ """Test overhead of nonblocking threads"""
+ ThreadPoolExecutor = importorskip('concurrent.futures').ThreadPoolExecutor
+
+ total = 16
+ subtotal = 10000
+
+ with ThreadPoolExecutor() as pool:
+ sys.stderr.write('block ... ')
+ sys.stderr.flush()
+ with relative_timer() as time_tqdm:
+ res = list(pool.map(worker(subtotal, True), range(total)))
+ assert sum(res) == sum(range(total)) + total
+ sys.stderr.write('noblock ... ')
+ sys.stderr.flush()
+ with relative_timer() as time_noblock:
+ res = list(pool.map(worker(subtotal, False), range(total)))
+ assert sum(res) == sum(range(total)) + total
+
+ assert_performance(0.5, 'noblock', time_noblock(), 'tqdm', time_tqdm())
+
+
+@retry_on_except(10)
+def test_iter_overhead_hard():
+ """Test overhead of iteration based tqdm (hard)"""
+ total = int(1e5)
+
+ a = 0
+ with trange(total, leave=True, miniters=1,
+ mininterval=0, maxinterval=0) as t:
+ with relative_timer() as time_tqdm:
+ for i in t:
+ a += i
+ assert a == (total ** 2 - total) / 2.0
+
+ a = 0
+ with relative_timer() as time_bench:
+ for i in _range(total):
+ a += i
+ sys.stdout.write(("%i" % a) * 40)
+
+ assert_performance(130, 'trange', time_tqdm(), 'range', time_bench())
+
+
+@retry_on_except(10)
+def test_manual_overhead_hard():
+ """Test overhead of manual tqdm (hard)"""
+ total = int(1e5)
+
+ with tqdm(total=total * 10, leave=True, miniters=1,
+ mininterval=0, maxinterval=0) as t:
+ a = 0
+ with relative_timer() as time_tqdm:
+ for i in _range(total):
+ a += i
+ t.update(10)
+
+ a = 0
+ with relative_timer() as time_bench:
+ for i in _range(total):
+ a += i
+ sys.stdout.write(("%i" % a) * 40)
+
+ assert_performance(130, 'tqdm', time_tqdm(), 'range', time_bench())
+
+
+@retry_on_except(10)
+def test_iter_overhead_simplebar_hard():
+ """Test overhead of iteration based tqdm vs simple progress bar (hard)"""
+ total = int(1e4)
+
+ a = 0
+ with trange(total, leave=True, miniters=1,
+ mininterval=0, maxinterval=0) as t:
+ with relative_timer() as time_tqdm:
+ for i in t:
+ a += i
+ assert a == (total ** 2 - total) / 2.0
+
+ a = 0
+ s = simple_progress(_range(total), leave=True,
+ miniters=1, mininterval=0)
+ with relative_timer() as time_bench:
+ for i in s:
+ a += i
+
+ assert_performance(10, 'trange', time_tqdm(), 'simple_progress', time_bench())
+
+
+@retry_on_except(10)
+def test_manual_overhead_simplebar_hard():
+ """Test overhead of manual tqdm vs simple progress bar (hard)"""
+ total = int(1e4)
+
+ with tqdm(total=total * 10, leave=True, miniters=1,
+ mininterval=0, maxinterval=0) as t:
+ a = 0
+ with relative_timer() as time_tqdm:
+ for i in _range(total):
+ a += i
+ t.update(10)
+
+ simplebar_update = simple_progress(total=total * 10, leave=True,
+ miniters=1, mininterval=0)
+ a = 0
+ with relative_timer() as time_bench:
+ for i in _range(total):
+ a += i
+ simplebar_update(10)
+
+ assert_performance(10, 'tqdm', time_tqdm(), 'simple_progress', time_bench())
diff --git a/tests/tests_rich.py b/tests/tests_rich.py
new file mode 100644
index 0000000..c75e246
--- /dev/null
+++ b/tests/tests_rich.py
@@ -0,0 +1,10 @@
+"""Test `tqdm.rich`."""
+import sys
+
+from .tests_tqdm import importorskip, mark
+
+
+@mark.skipif(sys.version_info[:3] < (3, 6, 1), reason="`rich` needs py>=3.6.1")
+def test_rich_import():
+ """Test `tqdm.rich` import"""
+ importorskip('tqdm.rich')
diff --git a/tests/tests_synchronisation.py b/tests/tests_synchronisation.py
new file mode 100644
index 0000000..7ee55fb
--- /dev/null
+++ b/tests/tests_synchronisation.py
@@ -0,0 +1,224 @@
+from __future__ import division
+
+import sys
+from functools import wraps
+from threading import Event
+from time import sleep, time
+
+from tqdm import TMonitor, tqdm, trange
+
+from .tests_perf import retry_on_except
+from .tests_tqdm import StringIO, closing, importorskip, patch_lock, skip
+
+
+class Time(object):
+ """Fake time class class providing an offset"""
+ offset = 0
+
+ @classmethod
+ def reset(cls):
+ """zeroes internal offset"""
+ cls.offset = 0
+
+ @classmethod
+ def time(cls):
+ """time.time() + offset"""
+ return time() + cls.offset
+
+ @staticmethod
+ def sleep(dur):
+ """identical to time.sleep()"""
+ sleep(dur)
+
+ @classmethod
+ def fake_sleep(cls, dur):
+ """adds `dur` to internal offset"""
+ cls.offset += dur
+ sleep(0.000001) # sleep to allow interrupt (instead of pass)
+
+
+def FakeEvent():
+ """patched `threading.Event` where `wait()` uses `Time.fake_sleep()`"""
+ event = Event() # not a class in py2 so can't inherit
+
+ def wait(timeout=None):
+ """uses Time.fake_sleep"""
+ if timeout is not None:
+ Time.fake_sleep(timeout)
+ return event.is_set()
+
+ event.wait = wait
+ return event
+
+
+def patch_sleep(func):
+ """Temporarily makes TMonitor use Time.fake_sleep"""
+ @wraps(func)
+ def inner(*args, **kwargs):
+ """restores TMonitor on completion regardless of Exceptions"""
+ TMonitor._test["time"] = Time.time
+ TMonitor._test["Event"] = FakeEvent
+ if tqdm.monitor:
+ assert not tqdm.monitor.get_instances()
+ tqdm.monitor.exit()
+ del tqdm.monitor
+ tqdm.monitor = None
+ try:
+ return func(*args, **kwargs)
+ finally:
+ # Check that class var monitor is deleted if no instance left
+ tqdm.monitor_interval = 10
+ if tqdm.monitor:
+ assert not tqdm.monitor.get_instances()
+ tqdm.monitor.exit()
+ del tqdm.monitor
+ tqdm.monitor = None
+ TMonitor._test.pop("Event")
+ TMonitor._test.pop("time")
+
+ return inner
+
+
+def cpu_timify(t, timer=Time):
+ """Force tqdm to use the specified timer instead of system-wide time"""
+ t._time = timer.time
+ t._sleep = timer.fake_sleep
+ t.start_t = t.last_print_t = t._time()
+ return timer
+
+
+class FakeTqdm(object):
+ _instances = set()
+ get_lock = tqdm.get_lock
+
+
+def incr(x):
+ return x + 1
+
+
+def incr_bar(x):
+ with closing(StringIO()) as our_file:
+ for _ in trange(x, lock_args=(False,), file=our_file):
+ pass
+ return incr(x)
+
+
+@patch_sleep
+def test_monitor_thread():
+ """Test dummy monitoring thread"""
+ monitor = TMonitor(FakeTqdm, 10)
+ # Test if alive, then killed
+ assert monitor.report()
+ monitor.exit()
+ assert not monitor.report()
+ assert not monitor.is_alive()
+ del monitor
+
+
+@patch_sleep
+def test_monitoring_and_cleanup():
+ """Test for stalled tqdm instance and monitor deletion"""
+ # Note: should fix miniters for these tests, else with dynamic_miniters
+ # it's too complicated to handle with monitoring update and maxinterval...
+ maxinterval = tqdm.monitor_interval
+ assert maxinterval == 10
+ total = 1000
+
+ with closing(StringIO()) as our_file:
+ with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
+ maxinterval=maxinterval) as t:
+ cpu_timify(t, Time)
+ # Do a lot of iterations in a small timeframe
+ # (smaller than monitor interval)
+ Time.fake_sleep(maxinterval / 10) # monitor won't wake up
+ t.update(500)
+ # check that our fixed miniters is still there
+ assert t.miniters <= 500 # TODO: should really be == 500
+ # Then do 1 it after monitor interval, so that monitor kicks in
+ Time.fake_sleep(maxinterval)
+ t.update(1)
+ # Wait for the monitor to get out of sleep's loop and update tqdm.
+ timeend = Time.time()
+ while not (t.monitor.woken >= timeend and t.miniters == 1):
+ Time.fake_sleep(1) # Force awake up if it woken too soon
+ assert t.miniters == 1 # check that monitor corrected miniters
+ # Note: at this point, there may be a race condition: monitor saved
+ # current woken time but Time.sleep() happen just before monitor
+ # sleep. To fix that, either sleep here or increase time in a loop
+ # to ensure that monitor wakes up at some point.
+
+ # Try again but already at miniters = 1 so nothing will be done
+ Time.fake_sleep(maxinterval)
+ t.update(2)
+ timeend = Time.time()
+ while t.monitor.woken < timeend:
+ Time.fake_sleep(1) # Force awake if it woken too soon
+ # Wait for the monitor to get out of sleep's loop and update
+ # tqdm
+ assert t.miniters == 1 # check that monitor corrected miniters
+
+
+@patch_sleep
+def test_monitoring_multi():
+ """Test on multiple bars, one not needing miniters adjustment"""
+ # Note: should fix miniters for these tests, else with dynamic_miniters
+ # it's too complicated to handle with monitoring update and maxinterval...
+ maxinterval = tqdm.monitor_interval
+ assert maxinterval == 10
+ total = 1000
+
+ with closing(StringIO()) as our_file:
+ with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
+ maxinterval=maxinterval) as t1:
+ # Set high maxinterval for t2 so monitor does not need to adjust it
+ with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
+ maxinterval=1E5) as t2:
+ cpu_timify(t1, Time)
+ cpu_timify(t2, Time)
+ # Do a lot of iterations in a small timeframe
+ Time.fake_sleep(maxinterval / 10)
+ t1.update(500)
+ t2.update(500)
+ assert t1.miniters <= 500 # TODO: should really be == 500
+ assert t2.miniters == 500
+ # Then do 1 it after monitor interval, so that monitor kicks in
+ Time.fake_sleep(maxinterval)
+ t1.update(1)
+ t2.update(1)
+ # Wait for the monitor to get out of sleep and update tqdm
+ timeend = Time.time()
+ while not (t1.monitor.woken >= timeend and t1.miniters == 1):
+ Time.fake_sleep(1)
+ assert t1.miniters == 1 # check that monitor corrected miniters
+ assert t2.miniters == 500 # check that t2 was not adjusted
+
+
+def test_imap():
+ """Test multiprocessing.Pool"""
+ try:
+ from multiprocessing import Pool
+ except ImportError as err:
+ skip(str(err))
+
+ pool = Pool()
+ res = list(tqdm(pool.imap(incr, range(100)), disable=True))
+ pool.close()
+ assert res[-1] == 100
+
+
+# py2: locks won't propagate to incr_bar so may cause `AttributeError`
+@retry_on_except(n=3 if sys.version_info < (3,) else 1, check_cpu_time=False)
+@patch_lock(thread=True)
+def test_threadpool():
+ """Test concurrent.futures.ThreadPoolExecutor"""
+ ThreadPoolExecutor = importorskip('concurrent.futures').ThreadPoolExecutor
+
+ with ThreadPoolExecutor(8) as pool:
+ try:
+ res = list(tqdm(pool.map(incr_bar, range(100)), disable=True))
+ except AttributeError:
+ if sys.version_info < (3,):
+ skip("not supported on py2")
+ else:
+ raise
+ assert sum(res) == sum(range(1, 101))
diff --git a/tests/tests_tk.py b/tests/tests_tk.py
new file mode 100644
index 0000000..9aa645c
--- /dev/null
+++ b/tests/tests_tk.py
@@ -0,0 +1,7 @@
+"""Test `tqdm.tk`."""
+from .tests_tqdm import importorskip
+
+
+def test_tk_import():
+ """Test `tqdm.tk` import"""
+ importorskip('tqdm.tk')
diff --git a/tests/tests_tqdm.py b/tests/tests_tqdm.py
new file mode 100644
index 0000000..bba457a
--- /dev/null
+++ b/tests/tests_tqdm.py
@@ -0,0 +1,1996 @@
+# -*- coding: utf-8 -*-
+# Advice: use repr(our_file.read()) to print the full output of tqdm
+# (else '\r' will replace the previous lines and you'll see only the latest.
+from __future__ import print_function
+
+import csv
+import os
+import re
+import sys
+from contextlib import contextmanager
+from functools import wraps
+from warnings import catch_warnings, simplefilter
+
+from pytest import importorskip, mark, raises, skip
+
+from tqdm import TqdmDeprecationWarning, TqdmWarning, tqdm, trange
+from tqdm.contrib import DummyTqdmFile
+from tqdm.std import EMA, Bar
+
+try:
+ from StringIO import StringIO
+except ImportError:
+ from io import StringIO
+
+from io import IOBase # to support unicode strings
+from io import BytesIO
+
+
+class DeprecationError(Exception):
+ pass
+
+
+# Ensure we can use `with closing(...) as ... :` syntax
+if getattr(StringIO, '__exit__', False) and getattr(StringIO, '__enter__', False):
+ def closing(arg):
+ return arg
+else:
+ from contextlib import closing
+
+try:
+ _range = xrange
+except NameError:
+ _range = range
+
+try:
+ _unicode = unicode
+except NameError:
+ _unicode = str
+
+nt_and_no_colorama = False
+if os.name == 'nt':
+ try:
+ import colorama # NOQA
+ except ImportError:
+ nt_and_no_colorama = True
+
+# Regex definitions
+# List of control characters
+CTRLCHR = [r'\r', r'\n', r'\x1b\[A'] # Need to escape [ for regex
+# Regular expressions compilation
+RE_rate = re.compile(r'[^\d](\d[.\d]+)it/s')
+RE_ctrlchr = re.compile("(%s)" % '|'.join(CTRLCHR)) # Match control chars
+RE_ctrlchr_excl = re.compile('|'.join(CTRLCHR)) # Match and exclude ctrl chars
+RE_pos = re.compile(r'([\r\n]+((pos\d+) bar:\s+\d+%|\s{3,6})?[^\r\n]*)')
+
+
+def pos_line_diff(res_list, expected_list, raise_nonempty=True):
+ """
+ Return differences between two bar output lists.
+ To be used with `RE_pos`
+ """
+ res = [(r, e) for r, e in zip(res_list, expected_list)
+ for pos in [len(e) - len(e.lstrip('\n'))] # bar position
+ if r != e # simple comparison
+ if not r.startswith(e) # start matches
+ or not (
+ # move up at end (maybe less due to closing bars)
+ any(r.endswith(end + i * '\x1b[A') for i in range(pos + 1)
+ for end in [
+ ']', # bar
+ ' ']) # cleared
+ or '100%' in r # completed bar
+ or r == '\n') # final bar
+ or r[(-1 - pos) * len('\x1b[A'):] == '\x1b[A'] # too many moves up
+ if raise_nonempty and (res or len(res_list) != len(expected_list)):
+ if len(res_list) < len(expected_list):
+ res.extend([(None, e) for e in expected_list[len(res_list):]])
+ elif len(res_list) > len(expected_list):
+ res.extend([(r, None) for r in res_list[len(expected_list):]])
+ raise AssertionError(
+ "Got => Expected\n" + '\n'.join('%r => %r' % i for i in res))
+ return res
+
+
+class DiscreteTimer(object):
+ """Virtual discrete time manager, to precisely control time for tests"""
+ def __init__(self):
+ self.t = 0.0
+
+ def sleep(self, t):
+ """Sleep = increment the time counter (almost no CPU used)"""
+ self.t += t
+
+ def time(self):
+ """Get the current time"""
+ return self.t
+
+
+def cpu_timify(t, timer=None):
+ """Force tqdm to use the specified timer instead of system-wide time()"""
+ if timer is None:
+ timer = DiscreteTimer()
+ t._time = timer.time
+ t._sleep = timer.sleep
+ t.start_t = t.last_print_t = t._time()
+ return timer
+
+
+class UnicodeIO(IOBase):
+ """Unicode version of StringIO"""
+ def __init__(self, *args, **kwargs):
+ super(UnicodeIO, self).__init__(*args, **kwargs)
+ self.encoding = 'U8' # io.StringIO supports unicode, but no encoding
+ self.text = ''
+ self.cursor = 0
+
+ def __len__(self):
+ return len(self.text)
+
+ def seek(self, offset):
+ self.cursor = offset
+
+ def tell(self):
+ return self.cursor
+
+ def write(self, s):
+ self.text = self.text[:self.cursor] + s + self.text[self.cursor + len(s):]
+ self.cursor += len(s)
+
+ def read(self, n=-1):
+ _cur = self.cursor
+ self.cursor = len(self) if n < 0 else min(_cur + n, len(self))
+ return self.text[_cur:self.cursor]
+
+ def getvalue(self):
+ return self.text
+
+
+def get_bar(all_bars, i=None):
+ """Get a specific update from a whole bar traceback"""
+ # Split according to any used control characters
+ bars_split = RE_ctrlchr_excl.split(all_bars)
+ bars_split = list(filter(None, bars_split)) # filter out empty splits
+ return bars_split if i is None else bars_split[i]
+
+
+def progressbar_rate(bar_str):
+ return float(RE_rate.search(bar_str).group(1))
+
+
+def squash_ctrlchars(s):
+ """Apply control characters in a string just like a terminal display"""
+ curline = 0
+ lines = [''] # state of fake terminal
+ for nextctrl in filter(None, RE_ctrlchr.split(s)):
+ # apply control chars
+ if nextctrl == '\r':
+ # go to line beginning (simplified here: just empty the string)
+ lines[curline] = ''
+ elif nextctrl == '\n':
+ if curline >= len(lines) - 1:
+ # wrap-around creates newline
+ lines.append('')
+ # move cursor down
+ curline += 1
+ elif nextctrl == '\x1b[A':
+ # move cursor up
+ if curline > 0:
+ curline -= 1
+ else:
+ raise ValueError("Cannot go further up")
+ else:
+ # print message on current line
+ lines[curline] += nextctrl
+ return lines
+
+
+def test_format_interval():
+ """Test time interval format"""
+ format_interval = tqdm.format_interval
+
+ assert format_interval(60) == '01:00'
+ assert format_interval(6160) == '1:42:40'
+ assert format_interval(238113) == '66:08:33'
+
+
+def test_format_num():
+ """Test number format"""
+ format_num = tqdm.format_num
+
+ assert float(format_num(1337)) == 1337
+ assert format_num(int(1e6)) == '1e+6'
+ assert format_num(1239876) == '1' '239' '876'
+
+
+def test_format_meter():
+ """Test statistics and progress bar formatting"""
+ try:
+ unich = unichr
+ except NameError:
+ unich = chr
+
+ format_meter = tqdm.format_meter
+
+ assert format_meter(0, 1000, 13) == " 0%| | 0/1000 [00:13<?, ?it/s]"
+ # If not implementing any changes to _tqdm.py, set prefix='desc'
+ # or else ": : " will be in output, so assertion should change
+ assert format_meter(0, 1000, 13, ncols=68, prefix='desc: ') == (
+ "desc: 0%| | 0/1000 [00:13<?, ?it/s]")
+ assert format_meter(231, 1000, 392) == (" 23%|" + unich(0x2588) * 2 + unich(0x258e) +
+ " | 231/1000 [06:32<21:44, 1.70s/it]")
+ assert format_meter(10000, 1000, 13) == "10000it [00:13, 769.23it/s]"
+ assert format_meter(231, 1000, 392, ncols=56, ascii=True) == " 23%|" + '#' * 3 + '6' + (
+ " | 231/1000 [06:32<21:44, 1.70s/it]")
+ assert format_meter(100000, 1000, 13, unit_scale=True,
+ unit='iB') == "100kiB [00:13, 7.69kiB/s]"
+ assert format_meter(100, 1000, 12, ncols=0,
+ rate=7.33) == " 10% 100/1000 [00:12<02:02, 7.33it/s]"
+ # ncols is small, l_bar is too large
+ # l_bar gets chopped
+ # no bar
+ # no r_bar
+ # 10/12 stars since ncols is 10
+ assert format_meter(
+ 0, 1000, 13, ncols=10,
+ bar_format="************{bar:10}$$$$$$$$$$") == "**********"
+ # n_cols allows for l_bar and some of bar
+ # l_bar displays
+ # bar gets chopped
+ # no r_bar
+ # all 12 stars and 8/10 bar parts
+ assert format_meter(
+ 0, 1000, 13, ncols=20,
+ bar_format="************{bar:10}$$$$$$$$$$") == "************ "
+ # n_cols allows for l_bar, bar, and some of r_bar
+ # l_bar displays
+ # bar displays
+ # r_bar gets chopped
+ # all 12 stars and 10 bar parts, but only 8/10 dollar signs
+ assert format_meter(
+ 0, 1000, 13, ncols=30,
+ bar_format="************{bar:10}$$$$$$$$$$") == "************ $$$$$$$$"
+ # trim left ANSI; escape is before trim zone
+ # we only know it has ANSI codes, so we append an END code anyway
+ assert format_meter(
+ 0, 1000, 13, ncols=10, bar_format="*****\033[22m****\033[0m***{bar:10}$$$$$$$$$$"
+ ) == "*****\033[22m****\033[0m*\033[0m"
+ # trim left ANSI; escape is at trim zone
+ assert format_meter(
+ 0, 1000, 13, ncols=10,
+ bar_format="*****\033[22m*****\033[0m**{bar:10}$$$$$$$$$$") == "*****\033[22m*****\033[0m"
+ # trim left ANSI; escape is after trim zone
+ assert format_meter(
+ 0, 1000, 13, ncols=10,
+ bar_format="*****\033[22m******\033[0m*{bar:10}$$$$$$$$$$") == "*****\033[22m*****\033[0m"
+ # Check that bar_format correctly adapts {bar} size to the rest
+ assert format_meter(
+ 20, 100, 12, ncols=13, rate=8.1,
+ bar_format=r'{l_bar}{bar}|{n_fmt}/{total_fmt}') == " 20%|" + unich(0x258f) + "|20/100"
+ assert format_meter(
+ 20, 100, 12, ncols=14, rate=8.1,
+ bar_format=r'{l_bar}{bar}|{n_fmt}/{total_fmt}') == " 20%|" + unich(0x258d) + " |20/100"
+ # Check wide characters
+ if sys.version_info >= (3,):
+ assert format_meter(0, 1000, 13, ncols=68, prefix='fullwidth: ') == (
+ "fullwidth: 0%| | 0/1000 [00:13<?, ?it/s]")
+ assert format_meter(0, 1000, 13, ncols=68, prefix='ニッポン [ニッポン]: ') == (
+ "ニッポン [ニッポン]: 0%| | 0/1000 [00:13<?, ?it/s]")
+ # Check that bar_format can print only {bar} or just one side
+ assert format_meter(20, 100, 12, ncols=2, rate=8.1,
+ bar_format=r'{bar}') == unich(0x258d) + " "
+ assert format_meter(20, 100, 12, ncols=7, rate=8.1,
+ bar_format=r'{l_bar}{bar}') == " 20%|" + unich(0x258d) + " "
+ assert format_meter(20, 100, 12, ncols=6, rate=8.1,
+ bar_format=r'{bar}|test') == unich(0x258f) + "|test"
+
+
+def test_ansi_escape_codes():
+ """Test stripping of ANSI escape codes"""
+ ansi = {'BOLD': '\033[1m', 'RED': '\033[91m', 'END': '\033[0m'}
+ desc_raw = '{BOLD}{RED}Colored{END} description'
+ ncols = 123
+
+ desc_stripped = desc_raw.format(BOLD='', RED='', END='')
+ meter = tqdm.format_meter(0, 100, 0, ncols=ncols, prefix=desc_stripped)
+ assert len(meter) == ncols
+
+ desc = desc_raw.format(**ansi)
+ meter = tqdm.format_meter(0, 100, 0, ncols=ncols, prefix=desc)
+ # `format_meter` inserts an extra END for safety
+ ansi_len = len(desc) - len(desc_stripped) + len(ansi['END'])
+ assert len(meter) == ncols + ansi_len
+
+
+def test_si_format():
+ """Test SI unit prefixes"""
+ format_meter = tqdm.format_meter
+
+ assert '9.00 ' in format_meter(1, 9, 1, unit_scale=True, unit='B')
+ assert '99.0 ' in format_meter(1, 99, 1, unit_scale=True)
+ assert '999 ' in format_meter(1, 999, 1, unit_scale=True)
+ assert '9.99k ' in format_meter(1, 9994, 1, unit_scale=True)
+ assert '10.0k ' in format_meter(1, 9999, 1, unit_scale=True)
+ assert '99.5k ' in format_meter(1, 99499, 1, unit_scale=True)
+ assert '100k ' in format_meter(1, 99999, 1, unit_scale=True)
+ assert '1.00M ' in format_meter(1, 999999, 1, unit_scale=True)
+ assert '1.00G ' in format_meter(1, 999999999, 1, unit_scale=True)
+ assert '1.00T ' in format_meter(1, 999999999999, 1, unit_scale=True)
+ assert '1.00P ' in format_meter(1, 999999999999999, 1, unit_scale=True)
+ assert '1.00E ' in format_meter(1, 999999999999999999, 1, unit_scale=True)
+ assert '1.00Z ' in format_meter(1, 999999999999999999999, 1, unit_scale=True)
+ assert '1.0Y ' in format_meter(1, 999999999999999999999999, 1, unit_scale=True)
+ assert '10.0Y ' in format_meter(1, 9999999999999999999999999, 1, unit_scale=True)
+ assert '100.0Y ' in format_meter(1, 99999999999999999999999999, 1, unit_scale=True)
+ assert '1000.0Y ' in format_meter(1, 999999999999999999999999999, 1,
+ unit_scale=True)
+
+
+def test_bar_formatspec():
+ """Test Bar.__format__ spec"""
+ assert "{0:5a}".format(Bar(0.3)) == "#5 "
+ assert "{0:2}".format(Bar(0.5, charset=" .oO0")) == "0 "
+ assert "{0:2a}".format(Bar(0.5, charset=" .oO0")) == "# "
+ assert "{0:-6a}".format(Bar(0.5, 10)) == '## '
+ assert "{0:2b}".format(Bar(0.5, 10)) == ' '
+
+
+def test_all_defaults():
+ """Test default kwargs"""
+ with closing(UnicodeIO()) as our_file:
+ with tqdm(range(10), file=our_file) as progressbar:
+ assert len(progressbar) == 10
+ for _ in progressbar:
+ pass
+ # restore stdout/stderr output for `nosetest` interface
+ # try:
+ # sys.stderr.write('\x1b[A')
+ # except:
+ # pass
+ sys.stderr.write('\rTest default kwargs ... ')
+
+
+class WriteTypeChecker(BytesIO):
+ """File-like to assert the expected type is written"""
+ def __init__(self, expected_type):
+ super(WriteTypeChecker, self).__init__()
+ self.expected_type = expected_type
+
+ def write(self, s):
+ assert isinstance(s, self.expected_type)
+
+
+def test_native_string_io_for_default_file():
+ """Native strings written to unspecified files"""
+ stderr = sys.stderr
+ try:
+ sys.stderr = WriteTypeChecker(expected_type=type(''))
+ for _ in tqdm(range(3)):
+ pass
+ sys.stderr.encoding = None # py2 behaviour
+ for _ in tqdm(range(3)):
+ pass
+ finally:
+ sys.stderr = stderr
+
+
+def test_unicode_string_io_for_specified_file():
+ """Unicode strings written to specified files"""
+ for _ in tqdm(range(3), file=WriteTypeChecker(expected_type=type(u''))):
+ pass
+
+
+def test_write_bytes():
+ """Test write_bytes argument with and without `file`"""
+ # specified file (and bytes)
+ for _ in tqdm(range(3), file=WriteTypeChecker(expected_type=type(b'')),
+ write_bytes=True):
+ pass
+ # unspecified file (and unicode)
+ stderr = sys.stderr
+ try:
+ sys.stderr = WriteTypeChecker(expected_type=type(u''))
+ for _ in tqdm(range(3), write_bytes=False):
+ pass
+ finally:
+ sys.stderr = stderr
+
+
+def test_iterate_over_csv_rows():
+ """Test csv iterator"""
+ # Create a test csv pseudo file
+ with closing(StringIO()) as test_csv_file:
+ writer = csv.writer(test_csv_file)
+ for _ in _range(3):
+ writer.writerow(['test'] * 3)
+ test_csv_file.seek(0)
+
+ # Test that nothing fails if we iterate over rows
+ reader = csv.DictReader(test_csv_file, fieldnames=('row1', 'row2', 'row3'))
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(reader, file=our_file):
+ pass
+
+
+def test_file_output():
+ """Test output to arbitrary file-like objects"""
+ with closing(StringIO()) as our_file:
+ for i in tqdm(_range(3), file=our_file):
+ if i == 1:
+ our_file.seek(0)
+ assert '0/3' in our_file.read()
+
+
+def test_leave_option():
+ """Test `leave=True` always prints info about the last iteration"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, leave=True):
+ pass
+ res = our_file.getvalue()
+ assert '| 3/3 ' in res
+ assert '\n' == res[-1] # not '\r'
+
+ with closing(StringIO()) as our_file2:
+ for _ in tqdm(_range(3), file=our_file2, leave=False):
+ pass
+ assert '| 3/3 ' not in our_file2.getvalue()
+
+
+def test_trange():
+ """Test trange"""
+ with closing(StringIO()) as our_file:
+ for _ in trange(3, file=our_file, leave=True):
+ pass
+ assert '| 3/3 ' in our_file.getvalue()
+
+ with closing(StringIO()) as our_file2:
+ for _ in trange(3, file=our_file2, leave=False):
+ pass
+ assert '| 3/3 ' not in our_file2.getvalue()
+
+
+def test_min_interval():
+ """Test mininterval"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, mininterval=1e-10):
+ pass
+ assert " 0%| | 0/3 [00:00<" in our_file.getvalue()
+
+
+def test_max_interval():
+ """Test maxinterval"""
+ total = 100
+ bigstep = 10
+ smallstep = 5
+
+ # Test without maxinterval
+ timer = DiscreteTimer()
+ with closing(StringIO()) as our_file:
+ with closing(StringIO()) as our_file2:
+ # with maxinterval but higher than loop sleep time
+ t = tqdm(total=total, file=our_file, miniters=None, mininterval=0,
+ smoothing=1, maxinterval=1e-2)
+ cpu_timify(t, timer)
+
+ # without maxinterval
+ t2 = tqdm(total=total, file=our_file2, miniters=None, mininterval=0,
+ smoothing=1, maxinterval=None)
+ cpu_timify(t2, timer)
+
+ assert t.dynamic_miniters
+ assert t2.dynamic_miniters
+
+ # Increase 10 iterations at once
+ t.update(bigstep)
+ t2.update(bigstep)
+ # The next iterations should not trigger maxinterval (step 10)
+ for _ in _range(4):
+ t.update(smallstep)
+ t2.update(smallstep)
+ timer.sleep(1e-5)
+ t.close() # because PyPy doesn't gc immediately
+ t2.close() # as above
+
+ assert "25%" not in our_file2.getvalue()
+ assert "25%" not in our_file.getvalue()
+
+ # Test with maxinterval effect
+ timer = DiscreteTimer()
+ with closing(StringIO()) as our_file:
+ with tqdm(total=total, file=our_file, miniters=None, mininterval=0,
+ smoothing=1, maxinterval=1e-4) as t:
+ cpu_timify(t, timer)
+
+ # Increase 10 iterations at once
+ t.update(bigstep)
+ # The next iterations should trigger maxinterval (step 5)
+ for _ in _range(4):
+ t.update(smallstep)
+ timer.sleep(1e-2)
+
+ assert "25%" in our_file.getvalue()
+
+ # Test iteration based tqdm with maxinterval effect
+ timer = DiscreteTimer()
+ with closing(StringIO()) as our_file:
+ with tqdm(_range(total), file=our_file, miniters=None,
+ mininterval=1e-5, smoothing=1, maxinterval=1e-4) as t2:
+ cpu_timify(t2, timer)
+
+ for i in t2:
+ if i >= (bigstep - 1) and ((i - (bigstep - 1)) % smallstep) == 0:
+ timer.sleep(1e-2)
+ if i >= 3 * bigstep:
+ break
+
+ assert "15%" in our_file.getvalue()
+
+ # Test different behavior with and without mininterval
+ timer = DiscreteTimer()
+ total = 1000
+ mininterval = 0.1
+ maxinterval = 10
+ with closing(StringIO()) as our_file:
+ with tqdm(total=total, file=our_file, miniters=None, smoothing=1,
+ mininterval=mininterval, maxinterval=maxinterval) as tm1:
+ with tqdm(total=total, file=our_file, miniters=None, smoothing=1,
+ mininterval=0, maxinterval=maxinterval) as tm2:
+
+ cpu_timify(tm1, timer)
+ cpu_timify(tm2, timer)
+
+ # Fast iterations, check if dynamic_miniters triggers
+ timer.sleep(mininterval) # to force update for t1
+ tm1.update(total / 2)
+ tm2.update(total / 2)
+ assert int(tm1.miniters) == tm2.miniters == total / 2
+
+ # Slow iterations, check different miniters if mininterval
+ timer.sleep(maxinterval * 2)
+ tm1.update(total / 2)
+ tm2.update(total / 2)
+ res = [tm1.miniters, tm2.miniters]
+ assert res == [(total / 2) * mininterval / (maxinterval * 2),
+ (total / 2) * maxinterval / (maxinterval * 2)]
+
+ # Same with iterable based tqdm
+ timer1 = DiscreteTimer() # need 2 timers for each bar because zip not work
+ timer2 = DiscreteTimer()
+ total = 100
+ mininterval = 0.1
+ maxinterval = 10
+ with closing(StringIO()) as our_file:
+ t1 = tqdm(_range(total), file=our_file, miniters=None, smoothing=1,
+ mininterval=mininterval, maxinterval=maxinterval)
+ t2 = tqdm(_range(total), file=our_file, miniters=None, smoothing=1,
+ mininterval=0, maxinterval=maxinterval)
+
+ cpu_timify(t1, timer1)
+ cpu_timify(t2, timer2)
+
+ for i in t1:
+ if i == ((total / 2) - 2):
+ timer1.sleep(mininterval)
+ if i == (total - 1):
+ timer1.sleep(maxinterval * 2)
+
+ for i in t2:
+ if i == ((total / 2) - 2):
+ timer2.sleep(mininterval)
+ if i == (total - 1):
+ timer2.sleep(maxinterval * 2)
+
+ assert t1.miniters == 0.255
+ assert t2.miniters == 0.5
+
+ t1.close()
+ t2.close()
+
+
+def test_delay():
+ """Test delay"""
+ timer = DiscreteTimer()
+ with closing(StringIO()) as our_file:
+ t = tqdm(total=2, file=our_file, leave=True, delay=3)
+ cpu_timify(t, timer)
+ timer.sleep(2)
+ t.update(1)
+ assert not our_file.getvalue()
+ timer.sleep(2)
+ t.update(1)
+ assert our_file.getvalue()
+ t.close()
+
+
+def test_min_iters():
+ """Test miniters"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, leave=True, mininterval=0, miniters=2):
+ pass
+
+ out = our_file.getvalue()
+ assert '| 0/3 ' in out
+ assert '| 1/3 ' not in out
+ assert '| 2/3 ' in out
+ assert '| 3/3 ' in out
+
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, leave=True, mininterval=0, miniters=1):
+ pass
+
+ out = our_file.getvalue()
+ assert '| 0/3 ' in out
+ assert '| 1/3 ' in out
+ assert '| 2/3 ' in out
+ assert '| 3/3 ' in out
+
+
+def test_dynamic_min_iters():
+ """Test purely dynamic miniters (and manual updates and __del__)"""
+ with closing(StringIO()) as our_file:
+ total = 10
+ t = tqdm(total=total, file=our_file, miniters=None, mininterval=0, smoothing=1)
+
+ t.update()
+ # Increase 3 iterations
+ t.update(3)
+ # The next two iterations should be skipped because of dynamic_miniters
+ t.update()
+ t.update()
+ # The third iteration should be displayed
+ t.update()
+
+ out = our_file.getvalue()
+ assert t.dynamic_miniters
+ t.__del__() # simulate immediate del gc
+
+ assert ' 0%| | 0/10 [00:00<' in out
+ assert '40%' in out
+ assert '50%' not in out
+ assert '60%' not in out
+ assert '70%' in out
+
+ # Check with smoothing=0, miniters should be set to max update seen so far
+ with closing(StringIO()) as our_file:
+ total = 10
+ t = tqdm(total=total, file=our_file, miniters=None, mininterval=0, smoothing=0)
+
+ t.update()
+ t.update(2)
+ t.update(5) # this should be stored as miniters
+ t.update(1)
+
+ out = our_file.getvalue()
+ assert all(i in out for i in ("0/10", "1/10", "3/10"))
+ assert "2/10" not in out
+ assert t.dynamic_miniters and not t.smoothing
+ assert t.miniters == 5
+ t.close()
+
+ # Check iterable based tqdm
+ with closing(StringIO()) as our_file:
+ t = tqdm(_range(10), file=our_file, miniters=None, mininterval=None,
+ smoothing=0.5)
+ for _ in t:
+ pass
+ assert t.dynamic_miniters
+
+ # No smoothing
+ with closing(StringIO()) as our_file:
+ t = tqdm(_range(10), file=our_file, miniters=None, mininterval=None,
+ smoothing=0)
+ for _ in t:
+ pass
+ assert t.dynamic_miniters
+
+ # No dynamic_miniters (miniters is fixed manually)
+ with closing(StringIO()) as our_file:
+ t = tqdm(_range(10), file=our_file, miniters=1, mininterval=None)
+ for _ in t:
+ pass
+ assert not t.dynamic_miniters
+
+
+def test_big_min_interval():
+ """Test large mininterval"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(2), file=our_file, mininterval=1E10):
+ pass
+ assert '50%' not in our_file.getvalue()
+
+ with closing(StringIO()) as our_file:
+ with tqdm(_range(2), file=our_file, mininterval=1E10) as t:
+ t.update()
+ t.update()
+ assert '50%' not in our_file.getvalue()
+
+
+def test_smoothed_dynamic_min_iters():
+ """Test smoothed dynamic miniters"""
+ timer = DiscreteTimer()
+
+ with closing(StringIO()) as our_file:
+ with tqdm(total=100, file=our_file, miniters=None, mininterval=1,
+ smoothing=0.5, maxinterval=0) as t:
+ cpu_timify(t, timer)
+
+ # Increase 10 iterations at once
+ timer.sleep(1)
+ t.update(10)
+ # The next iterations should be partially skipped
+ for _ in _range(2):
+ timer.sleep(1)
+ t.update(4)
+ for _ in _range(20):
+ timer.sleep(1)
+ t.update()
+
+ assert t.dynamic_miniters
+ out = our_file.getvalue()
+ assert ' 0%| | 0/100 [00:00<' in out
+ assert '20%' in out
+ assert '23%' not in out
+ assert '25%' in out
+ assert '26%' not in out
+ assert '28%' in out
+
+
+def test_smoothed_dynamic_min_iters_with_min_interval():
+ """Test smoothed dynamic miniters with mininterval"""
+ timer = DiscreteTimer()
+
+ # In this test, `miniters` should gradually decline
+ total = 100
+
+ with closing(StringIO()) as our_file:
+ # Test manual updating tqdm
+ with tqdm(total=total, file=our_file, miniters=None, mininterval=1e-3,
+ smoothing=1, maxinterval=0) as t:
+ cpu_timify(t, timer)
+
+ t.update(10)
+ timer.sleep(1e-2)
+ for _ in _range(4):
+ t.update()
+ timer.sleep(1e-2)
+ out = our_file.getvalue()
+ assert t.dynamic_miniters
+
+ with closing(StringIO()) as our_file:
+ # Test iteration-based tqdm
+ with tqdm(_range(total), file=our_file, miniters=None,
+ mininterval=0.01, smoothing=1, maxinterval=0) as t2:
+ cpu_timify(t2, timer)
+
+ for i in t2:
+ if i >= 10:
+ timer.sleep(0.1)
+ if i >= 14:
+ break
+ out2 = our_file.getvalue()
+
+ assert t.dynamic_miniters
+ assert ' 0%| | 0/100 [00:00<' in out
+ assert '11%' in out and '11%' in out2
+ # assert '12%' not in out and '12%' in out2
+ assert '13%' in out and '13%' in out2
+ assert '14%' in out and '14%' in out2
+
+
+@mark.slow
+def test_rlock_creation():
+ """Test that importing tqdm does not create multiprocessing objects."""
+ mp = importorskip('multiprocessing')
+ if not hasattr(mp, 'get_context'):
+ skip("missing multiprocessing.get_context")
+
+ # Use 'spawn' instead of 'fork' so that the process does not inherit any
+ # globals that have been constructed by running other tests
+ ctx = mp.get_context('spawn')
+ with ctx.Pool(1) as pool:
+ # The pool will propagate the error if the target method fails
+ pool.apply(_rlock_creation_target)
+
+
+def _rlock_creation_target():
+ """Check that the RLock has not been constructed."""
+ import multiprocessing as mp
+ patch = importorskip('unittest.mock').patch
+
+ # Patch the RLock class/method but use the original implementation
+ with patch('multiprocessing.RLock', wraps=mp.RLock) as rlock_mock:
+ # Importing the module should not create a lock
+ from tqdm import tqdm
+ assert rlock_mock.call_count == 0
+ # Creating a progress bar should initialize the lock
+ with closing(StringIO()) as our_file:
+ with tqdm(file=our_file) as _: # NOQA
+ pass
+ assert rlock_mock.call_count == 1
+ # Creating a progress bar again should reuse the lock
+ with closing(StringIO()) as our_file:
+ with tqdm(file=our_file) as _: # NOQA
+ pass
+ assert rlock_mock.call_count == 1
+
+
+def test_disable():
+ """Test disable"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, disable=True):
+ pass
+ assert our_file.getvalue() == ''
+
+ with closing(StringIO()) as our_file:
+ progressbar = tqdm(total=3, file=our_file, miniters=1, disable=True)
+ progressbar.update(3)
+ progressbar.close()
+ assert our_file.getvalue() == ''
+
+
+def test_infinite_total():
+ """Test treatment of infinite total"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, total=float("inf")):
+ pass
+
+
+def test_nototal():
+ """Test unknown total length"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(iter(range(10)), file=our_file, unit_scale=10):
+ pass
+ assert "100it" in our_file.getvalue()
+
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(iter(range(10)), file=our_file,
+ bar_format="{l_bar}{bar}{r_bar}"):
+ pass
+ assert "10/?" in our_file.getvalue()
+
+
+def test_unit():
+ """Test SI unit prefix"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), file=our_file, miniters=1, unit="bytes"):
+ pass
+ assert 'bytes/s' in our_file.getvalue()
+
+
+def test_ascii():
+ """Test ascii/unicode bar"""
+ # Test ascii autodetection
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file, ascii=None) as t:
+ assert t.ascii # TODO: this may fail in the future
+
+ # Test ascii bar
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(3), total=15, file=our_file, miniters=1,
+ mininterval=0, ascii=True):
+ pass
+ res = our_file.getvalue().strip("\r").split("\r")
+ assert '7%|6' in res[1]
+ assert '13%|#3' in res[2]
+ assert '20%|##' in res[3]
+
+ # Test unicode bar
+ with closing(UnicodeIO()) as our_file:
+ with tqdm(total=15, file=our_file, ascii=False, mininterval=0) as t:
+ for _ in _range(3):
+ t.update()
+ res = our_file.getvalue().strip("\r").split("\r")
+ assert u"7%|\u258b" in res[1]
+ assert u"13%|\u2588\u258e" in res[2]
+ assert u"20%|\u2588\u2588" in res[3]
+
+ # Test custom bar
+ for bars in [" .oO0", " #"]:
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(len(bars) - 1), file=our_file, miniters=1,
+ mininterval=0, ascii=bars, ncols=27):
+ pass
+ res = our_file.getvalue().strip("\r").split("\r")
+ for b, line in zip(bars, res):
+ assert '|' + b + '|' in line
+
+
+def test_update():
+ """Test manual creation and updates"""
+ res = None
+ with closing(StringIO()) as our_file:
+ with tqdm(total=2, file=our_file, miniters=1, mininterval=0) as progressbar:
+ assert len(progressbar) == 2
+ progressbar.update(2)
+ assert '| 2/2' in our_file.getvalue()
+ progressbar.desc = 'dynamically notify of 4 increments in total'
+ progressbar.total = 4
+ progressbar.update(-1)
+ progressbar.update(2)
+ res = our_file.getvalue()
+ assert '| 3/4 ' in res
+ assert 'dynamically notify of 4 increments in total' in res
+
+
+def test_close():
+ """Test manual creation and closure and n_instances"""
+
+ # With `leave` option
+ with closing(StringIO()) as our_file:
+ progressbar = tqdm(total=3, file=our_file, miniters=10)
+ progressbar.update(3)
+ assert '| 3/3 ' not in our_file.getvalue() # Should be blank
+ assert len(tqdm._instances) == 1
+ progressbar.close()
+ assert len(tqdm._instances) == 0
+ assert '| 3/3 ' in our_file.getvalue()
+
+ # Without `leave` option
+ with closing(StringIO()) as our_file:
+ progressbar = tqdm(total=3, file=our_file, miniters=10, leave=False)
+ progressbar.update(3)
+ progressbar.close()
+ assert '| 3/3 ' not in our_file.getvalue() # Should be blank
+
+ # With all updates
+ with closing(StringIO()) as our_file:
+ assert len(tqdm._instances) == 0
+ with tqdm(total=3, file=our_file, miniters=0, mininterval=0,
+ leave=True) as progressbar:
+ assert len(tqdm._instances) == 1
+ progressbar.update(3)
+ res = our_file.getvalue()
+ assert '| 3/3 ' in res # Should be blank
+ assert '\n' not in res
+ # close() called
+ assert len(tqdm._instances) == 0
+
+ exres = res.rsplit(', ', 1)[0]
+ res = our_file.getvalue()
+ assert res[-1] == '\n'
+ if not res.startswith(exres):
+ raise AssertionError("\n<<< Expected:\n{0}\n>>> Got:\n{1}\n===".format(
+ exres + ', ...it/s]\n', our_file.getvalue()))
+
+ # Closing after the output stream has closed
+ with closing(StringIO()) as our_file:
+ t = tqdm(total=2, file=our_file)
+ t.update()
+ t.update()
+ t.close()
+
+
+def test_ema():
+ """Test exponential weighted average"""
+ ema = EMA(0.01)
+ assert round(ema(10), 2) == 10
+ assert round(ema(1), 2) == 5.48
+ assert round(ema(), 2) == 5.48
+ assert round(ema(1), 2) == 3.97
+ assert round(ema(1), 2) == 3.22
+
+
+def test_smoothing():
+ """Test exponential weighted average smoothing"""
+ timer = DiscreteTimer()
+
+ # -- Test disabling smoothing
+ with closing(StringIO()) as our_file:
+ with tqdm(_range(3), file=our_file, smoothing=None, leave=True) as t:
+ cpu_timify(t, timer)
+
+ for _ in t:
+ pass
+ assert '| 3/3 ' in our_file.getvalue()
+
+ # -- Test smoothing
+ # 1st case: no smoothing (only use average)
+ with closing(StringIO()) as our_file2:
+ with closing(StringIO()) as our_file:
+ t = tqdm(_range(3), file=our_file2, smoothing=None, leave=True,
+ miniters=1, mininterval=0)
+ cpu_timify(t, timer)
+
+ with tqdm(_range(3), file=our_file, smoothing=None, leave=True,
+ miniters=1, mininterval=0) as t2:
+ cpu_timify(t2, timer)
+
+ for i in t2:
+ # Sleep more for first iteration and
+ # see how quickly rate is updated
+ if i == 0:
+ timer.sleep(0.01)
+ else:
+ # Need to sleep in all iterations
+ # to calculate smoothed rate
+ # (else delta_t is 0!)
+ timer.sleep(0.001)
+ t.update()
+ n_old = len(tqdm._instances)
+ t.close()
+ assert len(tqdm._instances) == n_old - 1
+ # Get result for iter-based bar
+ a = progressbar_rate(get_bar(our_file.getvalue(), 3))
+ # Get result for manually updated bar
+ a2 = progressbar_rate(get_bar(our_file2.getvalue(), 3))
+
+ # 2nd case: use max smoothing (= instant rate)
+ with closing(StringIO()) as our_file2:
+ with closing(StringIO()) as our_file:
+ t = tqdm(_range(3), file=our_file2, smoothing=1, leave=True,
+ miniters=1, mininterval=0)
+ cpu_timify(t, timer)
+
+ with tqdm(_range(3), file=our_file, smoothing=1, leave=True,
+ miniters=1, mininterval=0) as t2:
+ cpu_timify(t2, timer)
+
+ for i in t2:
+ if i == 0:
+ timer.sleep(0.01)
+ else:
+ timer.sleep(0.001)
+ t.update()
+ t.close()
+ # Get result for iter-based bar
+ b = progressbar_rate(get_bar(our_file.getvalue(), 3))
+ # Get result for manually updated bar
+ b2 = progressbar_rate(get_bar(our_file2.getvalue(), 3))
+
+ # 3rd case: use medium smoothing
+ with closing(StringIO()) as our_file2:
+ with closing(StringIO()) as our_file:
+ t = tqdm(_range(3), file=our_file2, smoothing=0.5, leave=True,
+ miniters=1, mininterval=0)
+ cpu_timify(t, timer)
+
+ t2 = tqdm(_range(3), file=our_file, smoothing=0.5, leave=True,
+ miniters=1, mininterval=0)
+ cpu_timify(t2, timer)
+
+ for i in t2:
+ if i == 0:
+ timer.sleep(0.01)
+ else:
+ timer.sleep(0.001)
+ t.update()
+ t2.close()
+ t.close()
+ # Get result for iter-based bar
+ c = progressbar_rate(get_bar(our_file.getvalue(), 3))
+ # Get result for manually updated bar
+ c2 = progressbar_rate(get_bar(our_file2.getvalue(), 3))
+
+ # Check that medium smoothing's rate is between no and max smoothing rates
+ assert a <= c <= b
+ assert a2 <= c2 <= b2
+
+
+@mark.skipif(nt_and_no_colorama, reason="Windows without colorama")
+def test_deprecated_nested():
+ """Test nested progress bars"""
+ # TODO: test degradation on windows without colorama?
+
+ # Artificially test nested loop printing
+ # Without leave
+ our_file = StringIO()
+ try:
+ tqdm(total=2, file=our_file, nested=True)
+ except TqdmDeprecationWarning:
+ if """`nested` is deprecated and automated.
+Use `position` instead for manual control.""" not in our_file.getvalue():
+ raise
+ else:
+ raise DeprecationError("Should not allow nested kwarg")
+
+
+def test_bar_format():
+ """Test custom bar formatting"""
+ with closing(StringIO()) as our_file:
+ bar_format = ('{l_bar}{bar}|{n_fmt}/{total_fmt}-{n}/{total}'
+ '{percentage}{rate}{rate_fmt}{elapsed}{remaining}')
+ for _ in trange(2, file=our_file, leave=True, bar_format=bar_format):
+ pass
+ out = our_file.getvalue()
+ assert "\r 0%| |0/2-0/20.0None?it/s00:00?\r" in out
+
+ # Test unicode string auto conversion
+ with closing(StringIO()) as our_file:
+ bar_format = r'hello world'
+ with tqdm(ascii=False, bar_format=bar_format, file=our_file) as t:
+ assert isinstance(t.bar_format, _unicode)
+
+
+def test_custom_format():
+ """Test adding additional derived format arguments"""
+ class TqdmExtraFormat(tqdm):
+ """Provides a `total_time` format parameter"""
+ @property
+ def format_dict(self):
+ d = super(TqdmExtraFormat, self).format_dict
+ total_time = d["elapsed"] * (d["total"] or 0) / max(d["n"], 1)
+ d.update(total_time=self.format_interval(total_time) + " in total")
+ return d
+
+ with closing(StringIO()) as our_file:
+ for _ in TqdmExtraFormat(
+ range(10), file=our_file,
+ bar_format="{total_time}: {percentage:.0f}%|{bar}{r_bar}"):
+ pass
+ assert "00:00 in total" in our_file.getvalue()
+
+
+def test_eta(capsys):
+ """Test eta bar_format"""
+ from datetime import datetime as dt
+ for _ in trange(999, miniters=1, mininterval=0, leave=True,
+ bar_format='{l_bar}{eta:%Y-%m-%d}'):
+ pass
+ _, err = capsys.readouterr()
+ assert "\r100%|{eta:%Y-%m-%d}\n".format(eta=dt.now()) in err
+
+
+def test_unpause():
+ """Test unpause"""
+ timer = DiscreteTimer()
+ with closing(StringIO()) as our_file:
+ t = trange(10, file=our_file, leave=True, mininterval=0)
+ cpu_timify(t, timer)
+ timer.sleep(0.01)
+ t.update()
+ timer.sleep(0.01)
+ t.update()
+ timer.sleep(0.1) # longer wait time
+ t.unpause()
+ timer.sleep(0.01)
+ t.update()
+ timer.sleep(0.01)
+ t.update()
+ t.close()
+ r_before = progressbar_rate(get_bar(our_file.getvalue(), 2))
+ r_after = progressbar_rate(get_bar(our_file.getvalue(), 3))
+ assert r_before == r_after
+
+
+def test_disabled_unpause(capsys):
+ """Test disabled unpause"""
+ with tqdm(total=10, disable=True) as t:
+ t.update()
+ t.unpause()
+ t.update()
+ print(t)
+ out, err = capsys.readouterr()
+ assert not err
+ assert out == ' 0%| | 0/10 [00:00<?, ?it/s]\n'
+
+
+def test_reset():
+ """Test resetting a bar for re-use"""
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file,
+ miniters=1, mininterval=0, maxinterval=0) as t:
+ t.update(9)
+ t.reset()
+ t.update()
+ t.reset(total=12)
+ t.update(10)
+ assert '| 1/10' in our_file.getvalue()
+ assert '| 10/12' in our_file.getvalue()
+
+
+def test_disabled_reset(capsys):
+ """Test disabled reset"""
+ with tqdm(total=10, disable=True) as t:
+ t.update(9)
+ t.reset()
+ t.update()
+ t.reset(total=12)
+ t.update(10)
+ print(t)
+ out, err = capsys.readouterr()
+ assert not err
+ assert out == ' 0%| | 0/12 [00:00<?, ?it/s]\n'
+
+
+@mark.skipif(nt_and_no_colorama, reason="Windows without colorama")
+def test_position():
+ """Test positioned progress bars"""
+ # Artificially test nested loop printing
+ # Without leave
+ our_file = StringIO()
+ kwargs = {'file': our_file, 'miniters': 1, 'mininterval': 0, 'maxinterval': 0}
+ t = tqdm(total=2, desc='pos2 bar', leave=False, position=2, **kwargs)
+ t.update()
+ t.close()
+ out = our_file.getvalue()
+ res = [m[0] for m in RE_pos.findall(out)]
+ exres = ['\n\n\rpos2 bar: 0%',
+ '\n\n\rpos2 bar: 50%',
+ '\n\n\r ']
+
+ pos_line_diff(res, exres)
+
+ # Test iteration-based tqdm positioning
+ our_file = StringIO()
+ kwargs["file"] = our_file
+ for _ in trange(2, desc='pos0 bar', position=0, **kwargs):
+ for _ in trange(2, desc='pos1 bar', position=1, **kwargs):
+ for _ in trange(2, desc='pos2 bar', position=2, **kwargs):
+ pass
+ out = our_file.getvalue()
+ res = [m[0] for m in RE_pos.findall(out)]
+ exres = ['\rpos0 bar: 0%',
+ '\n\rpos1 bar: 0%',
+ '\n\n\rpos2 bar: 0%',
+ '\n\n\rpos2 bar: 50%',
+ '\n\n\rpos2 bar: 100%',
+ '\rpos2 bar: 100%',
+ '\n\n\rpos1 bar: 50%',
+ '\n\n\rpos2 bar: 0%',
+ '\n\n\rpos2 bar: 50%',
+ '\n\n\rpos2 bar: 100%',
+ '\rpos2 bar: 100%',
+ '\n\n\rpos1 bar: 100%',
+ '\rpos1 bar: 100%',
+ '\n\rpos0 bar: 50%',
+ '\n\rpos1 bar: 0%',
+ '\n\n\rpos2 bar: 0%',
+ '\n\n\rpos2 bar: 50%',
+ '\n\n\rpos2 bar: 100%',
+ '\rpos2 bar: 100%',
+ '\n\n\rpos1 bar: 50%',
+ '\n\n\rpos2 bar: 0%',
+ '\n\n\rpos2 bar: 50%',
+ '\n\n\rpos2 bar: 100%',
+ '\rpos2 bar: 100%',
+ '\n\n\rpos1 bar: 100%',
+ '\rpos1 bar: 100%',
+ '\n\rpos0 bar: 100%',
+ '\rpos0 bar: 100%',
+ '\n']
+ pos_line_diff(res, exres)
+
+ # Test manual tqdm positioning
+ our_file = StringIO()
+ kwargs["file"] = our_file
+ kwargs["total"] = 2
+ t1 = tqdm(desc='pos0 bar', position=0, **kwargs)
+ t2 = tqdm(desc='pos1 bar', position=1, **kwargs)
+ t3 = tqdm(desc='pos2 bar', position=2, **kwargs)
+ for _ in _range(2):
+ t1.update()
+ t3.update()
+ t2.update()
+ out = our_file.getvalue()
+ res = [m[0] for m in RE_pos.findall(out)]
+ exres = ['\rpos0 bar: 0%',
+ '\n\rpos1 bar: 0%',
+ '\n\n\rpos2 bar: 0%',
+ '\rpos0 bar: 50%',
+ '\n\n\rpos2 bar: 50%',
+ '\n\rpos1 bar: 50%',
+ '\rpos0 bar: 100%',
+ '\n\n\rpos2 bar: 100%',
+ '\n\rpos1 bar: 100%']
+ pos_line_diff(res, exres)
+ t1.close()
+ t2.close()
+ t3.close()
+
+ # Test auto repositioning of bars when a bar is prematurely closed
+ # tqdm._instances.clear() # reset number of instances
+ with closing(StringIO()) as our_file:
+ t1 = tqdm(total=10, file=our_file, desc='1.pos0 bar', mininterval=0)
+ t2 = tqdm(total=10, file=our_file, desc='2.pos1 bar', mininterval=0)
+ t3 = tqdm(total=10, file=our_file, desc='3.pos2 bar', mininterval=0)
+ res = [m[0] for m in RE_pos.findall(our_file.getvalue())]
+ exres = ['\r1.pos0 bar: 0%',
+ '\n\r2.pos1 bar: 0%',
+ '\n\n\r3.pos2 bar: 0%']
+ pos_line_diff(res, exres)
+
+ t2.close()
+ t4 = tqdm(total=10, file=our_file, desc='4.pos2 bar', mininterval=0)
+ t1.update(1)
+ t3.update(1)
+ t4.update(1)
+ res = [m[0] for m in RE_pos.findall(our_file.getvalue())]
+ exres = ['\r1.pos0 bar: 0%',
+ '\n\r2.pos1 bar: 0%',
+ '\n\n\r3.pos2 bar: 0%',
+ '\r2.pos1 bar: 0%',
+ '\n\n\r4.pos2 bar: 0%',
+ '\r1.pos0 bar: 10%',
+ '\n\n\r3.pos2 bar: 10%',
+ '\n\r4.pos2 bar: 10%']
+ pos_line_diff(res, exres)
+ t4.close()
+ t3.close()
+ t1.close()
+
+
+def test_set_description():
+ """Test set description"""
+ with closing(StringIO()) as our_file:
+ with tqdm(desc='Hello', file=our_file) as t:
+ assert t.desc == 'Hello'
+ t.set_description_str('World')
+ assert t.desc == 'World'
+ t.set_description()
+ assert t.desc == ''
+ t.set_description('Bye')
+ assert t.desc == 'Bye: '
+ assert "World" in our_file.getvalue()
+
+ # without refresh
+ with closing(StringIO()) as our_file:
+ with tqdm(desc='Hello', file=our_file) as t:
+ assert t.desc == 'Hello'
+ t.set_description_str('World', False)
+ assert t.desc == 'World'
+ t.set_description(None, False)
+ assert t.desc == ''
+ assert "World" not in our_file.getvalue()
+
+ # unicode
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file) as t:
+ t.set_description(u"\xe1\xe9\xed\xf3\xfa")
+
+
+def test_deprecated_gui():
+ """Test internal GUI properties"""
+ # Check: StatusPrinter iff gui is disabled
+ with closing(StringIO()) as our_file:
+ t = tqdm(total=2, gui=True, file=our_file, miniters=1, mininterval=0)
+ assert not hasattr(t, "sp")
+ try:
+ t.update(1)
+ except TqdmDeprecationWarning as e:
+ if (
+ 'Please use `tqdm.gui.tqdm(...)` instead of `tqdm(..., gui=True)`'
+ not in our_file.getvalue()
+ ):
+ raise e
+ else:
+ raise DeprecationError('Should not allow manual gui=True without'
+ ' overriding __iter__() and update()')
+ finally:
+ t._instances.clear()
+ # t.close()
+ # len(tqdm._instances) += 1 # undo the close() decrement
+
+ t = tqdm(_range(3), gui=True, file=our_file, miniters=1, mininterval=0)
+ try:
+ for _ in t:
+ pass
+ except TqdmDeprecationWarning as e:
+ if (
+ 'Please use `tqdm.gui.tqdm(...)` instead of `tqdm(..., gui=True)`'
+ not in our_file.getvalue()
+ ):
+ raise e
+ else:
+ raise DeprecationError('Should not allow manual gui=True without'
+ ' overriding __iter__() and update()')
+ finally:
+ t._instances.clear()
+ # t.close()
+ # len(tqdm._instances) += 1 # undo the close() decrement
+
+ with tqdm(total=1, gui=False, file=our_file) as t:
+ assert hasattr(t, "sp")
+
+
+def test_cmp():
+ """Test comparison functions"""
+ with closing(StringIO()) as our_file:
+ t0 = tqdm(total=10, file=our_file)
+ t1 = tqdm(total=10, file=our_file)
+ t2 = tqdm(total=10, file=our_file)
+
+ assert t0 < t1
+ assert t2 >= t0
+ assert t0 <= t2
+
+ t3 = tqdm(total=10, file=our_file)
+ t4 = tqdm(total=10, file=our_file)
+ t5 = tqdm(total=10, file=our_file)
+ t5.close()
+ t6 = tqdm(total=10, file=our_file)
+
+ assert t3 != t4
+ assert t3 > t2
+ assert t5 == t6
+ t6.close()
+ t4.close()
+ t3.close()
+ t2.close()
+ t1.close()
+ t0.close()
+
+
+def test_repr():
+ """Test representation"""
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, ascii=True, file=our_file) as t:
+ assert str(t) == ' 0%| | 0/10 [00:00<?, ?it/s]'
+
+
+def test_clear():
+ """Test clearing bar display"""
+ with closing(StringIO()) as our_file:
+ t1 = tqdm(total=10, file=our_file, desc='pos0 bar', bar_format='{l_bar}')
+ t2 = trange(10, file=our_file, desc='pos1 bar', bar_format='{l_bar}')
+ before = squash_ctrlchars(our_file.getvalue())
+ t2.clear()
+ t1.clear()
+ after = squash_ctrlchars(our_file.getvalue())
+ t1.close()
+ t2.close()
+ assert before == ['pos0 bar: 0%|', 'pos1 bar: 0%|']
+ assert after == ['', '']
+
+
+def test_clear_disabled():
+ """Test disabled clear"""
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file, desc='pos0 bar', disable=True,
+ bar_format='{l_bar}') as t:
+ t.clear()
+ assert our_file.getvalue() == ''
+
+
+def test_refresh():
+ """Test refresh bar display"""
+ with closing(StringIO()) as our_file:
+ t1 = tqdm(total=10, file=our_file, desc='pos0 bar',
+ bar_format='{l_bar}', mininterval=999, miniters=999)
+ t2 = tqdm(total=10, file=our_file, desc='pos1 bar',
+ bar_format='{l_bar}', mininterval=999, miniters=999)
+ t1.update()
+ t2.update()
+ before = squash_ctrlchars(our_file.getvalue())
+ t1.refresh()
+ t2.refresh()
+ after = squash_ctrlchars(our_file.getvalue())
+ t1.close()
+ t2.close()
+
+ # Check that refreshing indeed forced the display to use realtime state
+ assert before == [u'pos0 bar: 0%|', u'pos1 bar: 0%|']
+ assert after == [u'pos0 bar: 10%|', u'pos1 bar: 10%|']
+
+
+def test_disabled_repr(capsys):
+ """Test disabled repr"""
+ with tqdm(total=10, disable=True) as t:
+ str(t)
+ t.update()
+ print(t)
+ out, err = capsys.readouterr()
+ assert not err
+ assert out == ' 0%| | 0/10 [00:00<?, ?it/s]\n'
+
+
+def test_disabled_refresh():
+ """Test disabled refresh"""
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file, desc='pos0 bar', disable=True,
+ bar_format='{l_bar}', mininterval=999, miniters=999) as t:
+ t.update()
+ t.refresh()
+
+ assert our_file.getvalue() == ''
+
+
+def test_write():
+ """Test write messages"""
+ s = "Hello world"
+ with closing(StringIO()) as our_file:
+ # Change format to keep only left part w/o bar and it/s rate
+ t1 = tqdm(total=10, file=our_file, desc='pos0 bar',
+ bar_format='{l_bar}', mininterval=0, miniters=1)
+ t2 = trange(10, file=our_file, desc='pos1 bar', bar_format='{l_bar}',
+ mininterval=0, miniters=1)
+ t3 = tqdm(total=10, file=our_file, desc='pos2 bar',
+ bar_format='{l_bar}', mininterval=0, miniters=1)
+ t1.update()
+ t2.update()
+ t3.update()
+ before = our_file.getvalue()
+
+ # Write msg and see if bars are correctly redrawn below the msg
+ t1.write(s, file=our_file) # call as an instance method
+ tqdm.write(s, file=our_file) # call as a class method
+ after = our_file.getvalue()
+
+ t1.close()
+ t2.close()
+ t3.close()
+
+ before_squashed = squash_ctrlchars(before)
+ after_squashed = squash_ctrlchars(after)
+
+ assert after_squashed == [s, s] + before_squashed
+
+ # Check that no bar clearing if different file
+ with closing(StringIO()) as our_file_bar:
+ with closing(StringIO()) as our_file_write:
+ t1 = tqdm(total=10, file=our_file_bar, desc='pos0 bar',
+ bar_format='{l_bar}', mininterval=0, miniters=1)
+
+ t1.update()
+ before_bar = our_file_bar.getvalue()
+
+ tqdm.write(s, file=our_file_write)
+
+ after_bar = our_file_bar.getvalue()
+ t1.close()
+
+ assert before_bar == after_bar
+
+ # Test stdout/stderr anti-mixup strategy
+ # Backup stdout/stderr
+ stde = sys.stderr
+ stdo = sys.stdout
+ # Mock stdout/stderr
+ with closing(StringIO()) as our_stderr:
+ with closing(StringIO()) as our_stdout:
+ sys.stderr = our_stderr
+ sys.stdout = our_stdout
+ t1 = tqdm(total=10, file=sys.stderr, desc='pos0 bar',
+ bar_format='{l_bar}', mininterval=0, miniters=1)
+
+ t1.update()
+ before_err = sys.stderr.getvalue()
+ before_out = sys.stdout.getvalue()
+
+ tqdm.write(s, file=sys.stdout)
+ after_err = sys.stderr.getvalue()
+ after_out = sys.stdout.getvalue()
+
+ t1.close()
+
+ assert before_err == '\rpos0 bar: 0%|\rpos0 bar: 10%|'
+ assert before_out == ''
+ after_err_res = [m[0] for m in RE_pos.findall(after_err)]
+ exres = ['\rpos0 bar: 0%|',
+ '\rpos0 bar: 10%|',
+ '\r ',
+ '\r\rpos0 bar: 10%|']
+ pos_line_diff(after_err_res, exres)
+ assert after_out == s + '\n'
+ # Restore stdout and stderr
+ sys.stderr = stde
+ sys.stdout = stdo
+
+
+def test_len():
+ """Test advance len (numpy array shape)"""
+ np = importorskip('numpy')
+ with closing(StringIO()) as f:
+ with tqdm(np.zeros((3, 4)), file=f) as t:
+ assert len(t) == 3
+
+
+def test_autodisable_disable():
+ """Test autodisable will disable on non-TTY"""
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, disable=None, file=our_file) as t:
+ t.update(3)
+ assert our_file.getvalue() == ''
+
+
+def test_autodisable_enable():
+ """Test autodisable will not disable on TTY"""
+ with closing(StringIO()) as our_file:
+ our_file.isatty = lambda: True
+ with tqdm(total=10, disable=None, file=our_file) as t:
+ t.update()
+ assert our_file.getvalue() != ''
+
+
+def test_deprecation_exception():
+ def test_TqdmDeprecationWarning():
+ with closing(StringIO()) as our_file:
+ raise (TqdmDeprecationWarning('Test!', fp_write=getattr(
+ our_file, 'write', sys.stderr.write)))
+
+ def test_TqdmDeprecationWarning_nofpwrite():
+ raise TqdmDeprecationWarning('Test!', fp_write=None)
+
+ raises(TqdmDeprecationWarning, test_TqdmDeprecationWarning)
+ raises(Exception, test_TqdmDeprecationWarning_nofpwrite)
+
+
+def test_postfix():
+ """Test postfix"""
+ postfix = {'float': 0.321034, 'gen': 543, 'str': 'h', 'lst': [2]}
+ postfix_order = (('w', 'w'), ('a', 0)) # no need for OrderedDict
+ expected = ['float=0.321', 'gen=543', 'lst=[2]', 'str=h']
+ expected_order = ['w=w', 'a=0', 'float=0.321', 'gen=543', 'lst=[2]', 'str=h']
+
+ # Test postfix set at init
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file, desc='pos0 bar',
+ bar_format='{r_bar}', postfix=postfix) as t1:
+ t1.refresh()
+ out = our_file.getvalue()
+
+ # Test postfix set after init
+ with closing(StringIO()) as our_file:
+ with trange(10, file=our_file, desc='pos1 bar', bar_format='{r_bar}',
+ postfix=None) as t2:
+ t2.set_postfix(**postfix)
+ t2.refresh()
+ out2 = our_file.getvalue()
+
+ # Order of items in dict may change, so need a loop to check per item
+ for res in expected:
+ assert res in out
+ assert res in out2
+
+ # Test postfix (with ordered dict and no refresh) set after init
+ with closing(StringIO()) as our_file:
+ with trange(10, file=our_file, desc='pos2 bar', bar_format='{r_bar}',
+ postfix=None) as t3:
+ t3.set_postfix(postfix_order, False, **postfix)
+ t3.refresh() # explicit external refresh
+ out3 = our_file.getvalue()
+
+ out3 = out3[1:-1].split(', ')[3:]
+ assert out3 == expected_order
+
+ # Test postfix (with ordered dict and refresh) set after init
+ with closing(StringIO()) as our_file:
+ with trange(10, file=our_file, desc='pos2 bar',
+ bar_format='{r_bar}', postfix=None) as t4:
+ t4.set_postfix(postfix_order, True, **postfix)
+ t4.refresh() # double refresh
+ out4 = our_file.getvalue()
+
+ assert out4.count('\r') > out3.count('\r')
+ assert out4.count(", ".join(expected_order)) == 2
+
+ # Test setting postfix string directly
+ with closing(StringIO()) as our_file:
+ with trange(10, file=our_file, desc='pos2 bar', bar_format='{r_bar}',
+ postfix=None) as t5:
+ t5.set_postfix_str("Hello", False)
+ t5.set_postfix_str("World")
+ out5 = our_file.getvalue()
+
+ assert "Hello" not in out5
+ out5 = out5[1:-1].split(', ')[3:]
+ assert out5 == ["World"]
+
+
+def test_postfix_direct():
+ """Test directly assigning non-str objects to postfix"""
+ with closing(StringIO()) as our_file:
+ with tqdm(total=10, file=our_file, miniters=1, mininterval=0,
+ bar_format="{postfix[0][name]} {postfix[1]:>5.2f}",
+ postfix=[{'name': "foo"}, 42]) as t:
+ for i in range(10):
+ if i % 2:
+ t.postfix[0]["name"] = "abcdefghij"[i]
+ else:
+ t.postfix[1] = i
+ t.update()
+ res = our_file.getvalue()
+ assert "f 6.00" in res
+ assert "h 6.00" in res
+ assert "h 8.00" in res
+ assert "j 8.00" in res
+
+
+@contextmanager
+def std_out_err_redirect_tqdm(tqdm_file=sys.stderr):
+ orig_out_err = sys.stdout, sys.stderr
+ try:
+ sys.stdout = sys.stderr = DummyTqdmFile(tqdm_file)
+ yield orig_out_err[0]
+ # Relay exceptions
+ except Exception as exc:
+ raise exc
+ # Always restore sys.stdout/err if necessary
+ finally:
+ sys.stdout, sys.stderr = orig_out_err
+
+
+def test_file_redirection():
+ """Test redirection of output"""
+ with closing(StringIO()) as our_file:
+ # Redirect stdout to tqdm.write()
+ with std_out_err_redirect_tqdm(tqdm_file=our_file):
+ with tqdm(total=3) as pbar:
+ print("Such fun")
+ pbar.update(1)
+ print("Such", "fun")
+ pbar.update(1)
+ print("Such ", end="")
+ print("fun")
+ pbar.update(1)
+ res = our_file.getvalue()
+ assert res.count("Such fun\n") == 3
+ assert "0/3" in res
+ assert "3/3" in res
+
+
+def test_external_write():
+ """Test external write mode"""
+ with closing(StringIO()) as our_file:
+ # Redirect stdout to tqdm.write()
+ for _ in trange(3, file=our_file):
+ del tqdm._lock # classmethod should be able to recreate lock
+ with tqdm.external_write_mode(file=our_file):
+ our_file.write("Such fun\n")
+ res = our_file.getvalue()
+ assert res.count("Such fun\n") == 3
+ assert "0/3" in res
+ assert "3/3" in res
+
+
+def test_unit_scale():
+ """Test numeric `unit_scale`"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(9), unit_scale=9, file=our_file,
+ miniters=1, mininterval=0):
+ pass
+ out = our_file.getvalue()
+ assert '81/81' in out
+
+
+def patch_lock(thread=True):
+ """decorator replacing tqdm's lock with vanilla threading/multiprocessing"""
+ try:
+ if thread:
+ from threading import RLock
+ else:
+ from multiprocessing import RLock
+ lock = RLock()
+ except (ImportError, OSError) as err:
+ skip(str(err))
+
+ def outer(func):
+ """actual decorator"""
+ @wraps(func)
+ def inner(*args, **kwargs):
+ """set & reset lock even if exceptions occur"""
+ default_lock = tqdm.get_lock()
+ try:
+ tqdm.set_lock(lock)
+ return func(*args, **kwargs)
+ finally:
+ tqdm.set_lock(default_lock)
+ return inner
+ return outer
+
+
+@patch_lock(thread=False)
+def test_threading():
+ """Test multiprocess/thread-realted features"""
+ pass # TODO: test interleaved output #445
+
+
+def test_bool():
+ """Test boolean cast"""
+ def internal(our_file, disable):
+ kwargs = {'file': our_file, 'disable': disable}
+ with trange(10, **kwargs) as t:
+ assert t
+ with trange(0, **kwargs) as t:
+ assert not t
+ with tqdm(total=10, **kwargs) as t:
+ assert bool(t)
+ with tqdm(total=0, **kwargs) as t:
+ assert not bool(t)
+ with tqdm([], **kwargs) as t:
+ assert not t
+ with tqdm([0], **kwargs) as t:
+ assert t
+ with tqdm(iter([]), **kwargs) as t:
+ assert t
+ with tqdm(iter([1, 2, 3]), **kwargs) as t:
+ assert t
+ with tqdm(**kwargs) as t:
+ try:
+ print(bool(t))
+ except TypeError:
+ pass
+ else:
+ raise TypeError("Expected bool(tqdm()) to fail")
+
+ # test with and without disable
+ with closing(StringIO()) as our_file:
+ internal(our_file, False)
+ internal(our_file, True)
+
+
+def backendCheck(module):
+ """Test tqdm-like module fallback"""
+ tn = module.tqdm
+ tr = module.trange
+
+ with closing(StringIO()) as our_file:
+ with tn(total=10, file=our_file) as t:
+ assert len(t) == 10
+ with tr(1337) as t:
+ assert len(t) == 1337
+
+
+def test_auto():
+ """Test auto fallback"""
+ from tqdm import auto, autonotebook
+ backendCheck(autonotebook)
+ backendCheck(auto)
+
+
+def test_wrapattr():
+ """Test wrapping file-like objects"""
+ data = "a twenty-char string"
+
+ with closing(StringIO()) as our_file:
+ with closing(StringIO()) as writer:
+ with tqdm.wrapattr(writer, "write", file=our_file, bytes=True) as wrap:
+ wrap.write(data)
+ res = writer.getvalue()
+ assert data == res
+ res = our_file.getvalue()
+ assert '%.1fB [' % len(data) in res
+
+ with closing(StringIO()) as our_file:
+ with closing(StringIO()) as writer:
+ with tqdm.wrapattr(writer, "write", file=our_file, bytes=False) as wrap:
+ wrap.write(data)
+ res = our_file.getvalue()
+ assert '%dit [' % len(data) in res
+
+
+def test_float_progress():
+ """Test float totals"""
+ with closing(StringIO()) as our_file:
+ with trange(10, total=9.6, file=our_file) as t:
+ with catch_warnings(record=True) as w:
+ simplefilter("always", category=TqdmWarning)
+ for i in t:
+ if i < 9:
+ assert not w
+ assert w
+ assert "clamping frac" in str(w[-1].message)
+
+
+def test_screen_shape():
+ """Test screen shape"""
+ # ncols
+ with closing(StringIO()) as our_file:
+ with trange(10, file=our_file, ncols=50) as t:
+ list(t)
+
+ res = our_file.getvalue()
+ assert all(len(i) == 50 for i in get_bar(res))
+
+ # no second/third bar, leave=False
+ with closing(StringIO()) as our_file:
+ kwargs = {'file': our_file, 'ncols': 50, 'nrows': 2, 'miniters': 0,
+ 'mininterval': 0, 'leave': False}
+ with trange(10, desc="one", **kwargs) as t1:
+ with trange(10, desc="two", **kwargs) as t2:
+ with trange(10, desc="three", **kwargs) as t3:
+ list(t3)
+ list(t2)
+ list(t1)
+
+ res = our_file.getvalue()
+ assert "one" in res
+ assert "two" not in res
+ assert "three" not in res
+ assert "\n\n" not in res
+ assert "more hidden" in res
+ # double-check ncols
+ assert all(len(i) == 50 for i in get_bar(res)
+ if i.strip() and "more hidden" not in i)
+
+ # all bars, leave=True
+ with closing(StringIO()) as our_file:
+ kwargs = {'file': our_file, 'ncols': 50, 'nrows': 2,
+ 'miniters': 0, 'mininterval': 0}
+ with trange(10, desc="one", **kwargs) as t1:
+ with trange(10, desc="two", **kwargs) as t2:
+ assert "two" not in our_file.getvalue()
+ with trange(10, desc="three", **kwargs) as t3:
+ assert "three" not in our_file.getvalue()
+ list(t3)
+ list(t2)
+ list(t1)
+
+ res = our_file.getvalue()
+ assert "one" in res
+ assert "two" in res
+ assert "three" in res
+ assert "\n\n" not in res
+ assert "more hidden" in res
+ # double-check ncols
+ assert all(len(i) == 50 for i in get_bar(res)
+ if i.strip() and "more hidden" not in i)
+
+ # second bar becomes first, leave=False
+ with closing(StringIO()) as our_file:
+ kwargs = {'file': our_file, 'ncols': 50, 'nrows': 2, 'miniters': 0,
+ 'mininterval': 0, 'leave': False}
+ t1 = tqdm(total=10, desc="one", **kwargs)
+ with tqdm(total=10, desc="two", **kwargs) as t2:
+ t1.update()
+ t2.update()
+ t1.close()
+ res = our_file.getvalue()
+ assert "one" in res
+ assert "two" not in res
+ assert "more hidden" in res
+ t2.update()
+
+ res = our_file.getvalue()
+ assert "two" in res
+
+
+def test_initial():
+ """Test `initial`"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(9), initial=10, total=19, file=our_file,
+ miniters=1, mininterval=0):
+ pass
+ out = our_file.getvalue()
+ assert '10/19' in out
+ assert '19/19' in out
+
+
+def test_colour():
+ """Test `colour`"""
+ with closing(StringIO()) as our_file:
+ for _ in tqdm(_range(9), file=our_file, colour="#beefed"):
+ pass
+ out = our_file.getvalue()
+ assert '\x1b[38;2;%d;%d;%dm' % (0xbe, 0xef, 0xed) in out
+
+ with catch_warnings(record=True) as w:
+ simplefilter("always", category=TqdmWarning)
+ with tqdm(total=1, file=our_file, colour="charm") as t:
+ assert w
+ t.update()
+ assert "Unknown colour" in str(w[-1].message)
+
+ with closing(StringIO()) as our_file2:
+ for _ in tqdm(_range(9), file=our_file2, colour="blue"):
+ pass
+ out = our_file2.getvalue()
+ assert '\x1b[34m' in out
+
+
+def test_closed():
+ """Test writing to closed file"""
+ with closing(StringIO()) as our_file:
+ for i in trange(9, file=our_file, miniters=1, mininterval=0):
+ if i == 5:
+ our_file.close()
+
+
+def test_reversed(capsys):
+ """Test reversed()"""
+ for _ in reversed(tqdm(_range(9))):
+ pass
+ out, err = capsys.readouterr()
+ assert not out
+ assert ' 0%' in err
+ assert '100%' in err
+
+
+def test_contains(capsys):
+ """Test __contains__ doesn't iterate"""
+ with tqdm(list(range(9))) as t:
+ assert 9 not in t
+ assert all(i in t for i in _range(9))
+ out, err = capsys.readouterr()
+ assert not out
+ assert ' 0%' in err
+ assert '100%' not in err
diff --git a/tests/tests_version.py b/tests/tests_version.py
new file mode 100644
index 0000000..495c797
--- /dev/null
+++ b/tests/tests_version.py
@@ -0,0 +1,14 @@
+"""Test `tqdm.__version__`."""
+import re
+from ast import literal_eval
+
+
+def test_version():
+ """Test version string"""
+ from tqdm import __version__
+ version_parts = re.split('[.-]', __version__)
+ if __version__ != "UNKNOWN":
+ assert 3 <= len(version_parts), "must have at least Major.minor.patch"
+ assert all(
+ isinstance(literal_eval(i), int) for i in version_parts[:3]
+ ), "Version Major.minor.patch must be 3 integers"