From 3d2c9fd003c14a4969f383cd5eb0966b7b6a3d7b Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sat, 2 Mar 2024 09:20:07 +0100
Subject: Adding upstream version 4.64.1.

Signed-off-by: Daniel Baumann
---
 tests/__init__.py              |    0
 tests/conftest.py              |   41 +
 tests/py37_asyncio.py          |  128 +++
 tests/tests_asyncio.py         |   11 +
 tests/tests_concurrent.py      |   49 +
 tests/tests_contrib.py         |   71 ++
 tests/tests_contrib_logging.py |  173 ++++
 tests/tests_dask.py            |   20 +
 tests/tests_gui.py             |    7 +
 tests/tests_itertools.py       |   26 +
 tests/tests_keras.py           |   93 ++
 tests/tests_main.py            |  245 +++++
 tests/tests_notebook.py        |    7 +
 tests/tests_pandas.py          |  219 +++++
 tests/tests_perf.py            |  325 +++++++
 tests/tests_rich.py            |   10 +
 tests/tests_synchronisation.py |  224 +++++
 tests/tests_tk.py              |    7 +
 tests/tests_tqdm.py            | 1996 ++++++++++++++++++++++++++++++++++++++++
 tests/tests_version.py         |   14 +
 20 files changed, 3666 insertions(+)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/py37_asyncio.py
 create mode 100644 tests/tests_asyncio.py
 create mode 100644 tests/tests_concurrent.py
 create mode 100644 tests/tests_contrib.py
 create mode 100644 tests/tests_contrib_logging.py
 create mode 100644 tests/tests_dask.py
 create mode 100644 tests/tests_gui.py
 create mode 100644 tests/tests_itertools.py
 create mode 100644 tests/tests_keras.py
 create mode 100644 tests/tests_main.py
 create mode 100644 tests/tests_notebook.py
 create mode 100644 tests/tests_pandas.py
 create mode 100644 tests/tests_perf.py
 create mode 100644 tests/tests_rich.py
 create mode 100644 tests/tests_synchronisation.py
 create mode 100644 tests/tests_tk.py
 create mode 100644 tests/tests_tqdm.py
 create mode 100644 tests/tests_version.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..6717044
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,41 @@
+"""Shared pytest config."""
+import sys
+
+from pytest import fixture
+
+from tqdm import tqdm
+
+
+@fixture(autouse=True)
+def pretest_posttest():
+    """Fixture for all tests ensuring environment cleanup"""
+    try:
+        sys.setswitchinterval(1)
+    except AttributeError:
+        sys.setcheckinterval(100)  # deprecated
+
+    if getattr(tqdm, "_instances", False):
+        n = len(tqdm._instances)
+        if n:
+            tqdm._instances.clear()
+            raise EnvironmentError(
+                "{0} `tqdm` instances still in existence PRE-test".format(n))
+    yield
+    if getattr(tqdm, "_instances", False):
+        n = len(tqdm._instances)
+        if n:
+            tqdm._instances.clear()
+            raise EnvironmentError(
+                "{0} `tqdm` instances still in existence POST-test".format(n))
+
+
+if sys.version_info[0] > 2:
+    @fixture
+    def capsysbin(capsysbinary):
+        """alias for capsysbinary (py3)"""
+        return capsysbinary
+else:
+    @fixture
+    def capsysbin(capsys):
+        """alias for capsys (py2)"""
+        return capsys
diff --git a/tests/py37_asyncio.py b/tests/py37_asyncio.py
new file mode 100644
index 0000000..8bf61e7
--- /dev/null
+++ b/tests/py37_asyncio.py
@@ -0,0 +1,128 @@
+import asyncio
+from functools import partial
+from sys import platform
+from time import time
+
+from tqdm.asyncio import tarange, tqdm_asyncio
+
+from .tests_tqdm import StringIO, closing, mark
+
+tqdm = partial(tqdm_asyncio, miniters=0, mininterval=0)
+trange = partial(tarange, miniters=0, mininterval=0)
+as_completed = partial(tqdm_asyncio.as_completed, miniters=0, mininterval=0)
+gather = partial(tqdm_asyncio.gather, miniters=0, mininterval=0)
+
+
+def count(start=0, 
step=1): + i = start + while True: + new_start = yield i + if new_start is None: + i += step + else: + i = new_start + + +async def acount(*args, **kwargs): + for i in count(*args, **kwargs): + yield i + + +@mark.asyncio +async def test_break(): + """Test asyncio break""" + pbar = tqdm(count()) + async for _ in pbar: + break + pbar.close() + + +@mark.asyncio +async def test_generators(capsys): + """Test asyncio generators""" + with tqdm(count(), desc="counter") as pbar: + async for i in pbar: + if i >= 8: + break + _, err = capsys.readouterr() + assert '9it' in err + + with tqdm(acount(), desc="async_counter") as pbar: + async for i in pbar: + if i >= 8: + break + _, err = capsys.readouterr() + assert '9it' in err + + +@mark.asyncio +async def test_range(): + """Test asyncio range""" + with closing(StringIO()) as our_file: + async for _ in tqdm(range(9), desc="range", file=our_file): + pass + assert '9/9' in our_file.getvalue() + our_file.seek(0) + our_file.truncate() + + async for _ in trange(9, desc="trange", file=our_file): + pass + assert '9/9' in our_file.getvalue() + + +@mark.asyncio +async def test_nested(): + """Test asyncio nested""" + with closing(StringIO()) as our_file: + async for _ in tqdm(trange(9, desc="inner", file=our_file), + desc="outer", file=our_file): + pass + assert 'inner: 100%' in our_file.getvalue() + assert 'outer: 100%' in our_file.getvalue() + + +@mark.asyncio +async def test_coroutines(): + """Test asyncio coroutine.send""" + with closing(StringIO()) as our_file: + with tqdm(count(), file=our_file) as pbar: + async for i in pbar: + if i == 9: + pbar.send(-10) + elif i < 0: + assert i == -9 + break + assert '10it' in our_file.getvalue() + + +@mark.slow +@mark.asyncio +@mark.parametrize("tol", [0.2 if platform.startswith("darwin") else 0.1]) +async def test_as_completed(capsys, tol): + """Test asyncio as_completed""" + for retry in range(3): + t = time() + skew = time() - t + for i in as_completed([asyncio.sleep(0.01 * i) for i in range(30, 0, -1)]): + await i + t = time() - t - 2 * skew + try: + assert 0.3 * (1 - tol) < t < 0.3 * (1 + tol), t + _, err = capsys.readouterr() + assert '30/30' in err + except AssertionError: + if retry == 2: + raise + + +async def double(i): + return i * 2 + + +@mark.asyncio +async def test_gather(capsys): + """Test asyncio gather""" + res = await gather(*map(double, range(30))) + _, err = capsys.readouterr() + assert '30/30' in err + assert res == list(range(0, 30 * 2, 2)) diff --git a/tests/tests_asyncio.py b/tests/tests_asyncio.py new file mode 100644 index 0000000..6f08926 --- /dev/null +++ b/tests/tests_asyncio.py @@ -0,0 +1,11 @@ +"""Tests `tqdm.asyncio` on `python>=3.7`.""" +import sys + +if sys.version_info[:2] > (3, 6): + from .py37_asyncio import * # NOQA, pylint: disable=wildcard-import +else: + from .tests_tqdm import skip + try: + skip("async not supported", allow_module_level=True) + except TypeError: + pass diff --git a/tests/tests_concurrent.py b/tests/tests_concurrent.py new file mode 100644 index 0000000..5cd439c --- /dev/null +++ b/tests/tests_concurrent.py @@ -0,0 +1,49 @@ +""" +Tests for `tqdm.contrib.concurrent`. 
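+
+A minimal usage sketch (illustrative only, not part of the upstream tests;
+it assumes nothing beyond the public `thread_map` helper exercised below,
+which follows the builtin `map` shape but returns a list and shows a bar):
+
+    from tqdm.contrib.concurrent import thread_map
+    # maps the function over the iterable in a thread pool, with progress
+    assert thread_map(lambda x: x + 1, range(9)) == list(range(1, 10))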
+""" +from pytest import warns + +from tqdm.contrib.concurrent import process_map, thread_map + +from .tests_tqdm import StringIO, TqdmWarning, closing, importorskip, mark, skip + + +def incr(x): + """Dummy function""" + return x + 1 + + +def test_thread_map(): + """Test contrib.concurrent.thread_map""" + with closing(StringIO()) as our_file: + a = range(9) + b = [i + 1 for i in a] + try: + assert thread_map(lambda x: x + 1, a, file=our_file) == b + except ImportError as err: + skip(str(err)) + assert thread_map(incr, a, file=our_file) == b + + +def test_process_map(): + """Test contrib.concurrent.process_map""" + with closing(StringIO()) as our_file: + a = range(9) + b = [i + 1 for i in a] + try: + assert process_map(incr, a, file=our_file) == b + except ImportError as err: + skip(str(err)) + + +@mark.parametrize("iterables,should_warn", [([], False), (['x'], False), ([()], False), + (['x', ()], False), (['x' * 1001], True), + (['x' * 100, ('x',) * 1001], True)]) +def test_chunksize_warning(iterables, should_warn): + """Test contrib.concurrent.process_map chunksize warnings""" + patch = importorskip('unittest.mock').patch + with patch('tqdm.contrib.concurrent._executor_map'): + if should_warn: + warns(TqdmWarning, process_map, incr, *iterables) + else: + process_map(incr, *iterables) diff --git a/tests/tests_contrib.py b/tests/tests_contrib.py new file mode 100644 index 0000000..69a1cad --- /dev/null +++ b/tests/tests_contrib.py @@ -0,0 +1,71 @@ +""" +Tests for `tqdm.contrib`. +""" +import sys + +import pytest + +from tqdm import tqdm +from tqdm.contrib import tenumerate, tmap, tzip + +from .tests_tqdm import StringIO, closing, importorskip + + +def incr(x): + """Dummy function""" + return x + 1 + + +@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}]) +def test_enumerate(tqdm_kwargs): + """Test contrib.tenumerate""" + with closing(StringIO()) as our_file: + a = range(9) + assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a)) + assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list( + enumerate(a, 42) + ) + with closing(StringIO()) as our_file: + _ = list(tenumerate(iter(a), file=our_file, **tqdm_kwargs)) + assert "100%" not in our_file.getvalue() + with closing(StringIO()) as our_file: + _ = list(tenumerate(iter(a), file=our_file, total=len(a), **tqdm_kwargs)) + assert "100%" in our_file.getvalue() + + +def test_enumerate_numpy(): + """Test contrib.tenumerate(numpy.ndarray)""" + np = importorskip("numpy") + with closing(StringIO()) as our_file: + a = np.random.random((42, 7)) + assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a)) + + +@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}]) +def test_zip(tqdm_kwargs): + """Test contrib.tzip""" + with closing(StringIO()) as our_file: + a = range(9) + b = [i + 1 for i in a] + if sys.version_info[:1] < (3,): + assert tzip(a, b, file=our_file, **tqdm_kwargs) == zip(a, b) + else: + gen = tzip(a, b, file=our_file, **tqdm_kwargs) + assert gen != list(zip(a, b)) + assert list(gen) == list(zip(a, b)) + + +@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}]) +def test_map(tqdm_kwargs): + """Test contrib.tmap""" + with closing(StringIO()) as our_file: + a = range(9) + b = [i + 1 for i in a] + if sys.version_info[:1] < (3,): + assert tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) == map( + incr, a + ) + else: + gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) + assert gen != b + assert list(gen) == b diff --git 
a/tests/tests_contrib_logging.py b/tests/tests_contrib_logging.py new file mode 100644 index 0000000..6f675dd --- /dev/null +++ b/tests/tests_contrib_logging.py @@ -0,0 +1,173 @@ +# pylint: disable=missing-module-docstring, missing-class-docstring +# pylint: disable=missing-function-docstring, no-self-use +from __future__ import absolute_import + +import logging +import logging.handlers +import sys +from io import StringIO + +import pytest + +from tqdm import tqdm +from tqdm.contrib.logging import _get_first_found_console_logging_handler +from tqdm.contrib.logging import _TqdmLoggingHandler as TqdmLoggingHandler +from tqdm.contrib.logging import logging_redirect_tqdm, tqdm_logging_redirect + +from .tests_tqdm import importorskip + +LOGGER = logging.getLogger(__name__) + +TEST_LOGGING_FORMATTER = logging.Formatter() + + +class CustomTqdm(tqdm): + messages = [] + + @classmethod + def write(cls, s, **__): # pylint: disable=arguments-differ + CustomTqdm.messages.append(s) + + +class ErrorRaisingTqdm(tqdm): + exception_class = RuntimeError + + @classmethod + def write(cls, s, **__): # pylint: disable=arguments-differ + raise ErrorRaisingTqdm.exception_class('fail fast') + + +class TestTqdmLoggingHandler: + def test_should_call_tqdm_write(self): + CustomTqdm.messages = [] + logger = logging.Logger('test') + logger.handlers = [TqdmLoggingHandler(CustomTqdm)] + logger.info('test') + assert CustomTqdm.messages == ['test'] + + def test_should_call_handle_error_if_exception_was_thrown(self): + patch = importorskip('unittest.mock').patch + logger = logging.Logger('test') + ErrorRaisingTqdm.exception_class = RuntimeError + handler = TqdmLoggingHandler(ErrorRaisingTqdm) + logger.handlers = [handler] + with patch.object(handler, 'handleError') as mock: + logger.info('test') + assert mock.called + + @pytest.mark.parametrize('exception_class', [ + KeyboardInterrupt, + SystemExit + ]) + def test_should_not_swallow_certain_exceptions(self, exception_class): + logger = logging.Logger('test') + ErrorRaisingTqdm.exception_class = exception_class + handler = TqdmLoggingHandler(ErrorRaisingTqdm) + logger.handlers = [handler] + with pytest.raises(exception_class): + logger.info('test') + + +class TestGetFirstFoundConsoleLoggingHandler: + def test_should_return_none_for_no_handlers(self): + assert _get_first_found_console_logging_handler([]) is None + + def test_should_return_none_without_stream_handler(self): + handler = logging.handlers.MemoryHandler(capacity=1) + assert _get_first_found_console_logging_handler([handler]) is None + + def test_should_return_none_for_stream_handler_not_stdout_or_stderr(self): + handler = logging.StreamHandler(StringIO()) + assert _get_first_found_console_logging_handler([handler]) is None + + def test_should_return_stream_handler_if_stream_is_stdout(self): + handler = logging.StreamHandler(sys.stdout) + assert _get_first_found_console_logging_handler([handler]) == handler + + def test_should_return_stream_handler_if_stream_is_stderr(self): + handler = logging.StreamHandler(sys.stderr) + assert _get_first_found_console_logging_handler([handler]) == handler + + +class TestRedirectLoggingToTqdm: + def test_should_add_and_remove_tqdm_handler(self): + logger = logging.Logger('test') + with logging_redirect_tqdm(loggers=[logger]): + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], TqdmLoggingHandler) + assert not logger.handlers + + def test_should_remove_and_restore_console_handlers(self): + logger = logging.Logger('test') + stderr_console_handler = 
logging.StreamHandler(sys.stderr) + stdout_console_handler = logging.StreamHandler(sys.stderr) + logger.handlers = [stderr_console_handler, stdout_console_handler] + with logging_redirect_tqdm(loggers=[logger]): + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], TqdmLoggingHandler) + assert logger.handlers == [stderr_console_handler, stdout_console_handler] + + def test_should_inherit_console_logger_formatter(self): + logger = logging.Logger('test') + formatter = logging.Formatter('custom: %(message)s') + console_handler = logging.StreamHandler(sys.stderr) + console_handler.setFormatter(formatter) + logger.handlers = [console_handler] + with logging_redirect_tqdm(loggers=[logger]): + assert logger.handlers[0].formatter == formatter + + def test_should_not_remove_stream_handlers_not_for_stdout_or_stderr(self): + logger = logging.Logger('test') + stream_handler = logging.StreamHandler(StringIO()) + logger.addHandler(stream_handler) + with logging_redirect_tqdm(loggers=[logger]): + assert len(logger.handlers) == 2 + assert logger.handlers[0] == stream_handler + assert isinstance(logger.handlers[1], TqdmLoggingHandler) + assert logger.handlers == [stream_handler] + + +class TestTqdmWithLoggingRedirect: + def test_should_add_and_remove_handler_from_root_logger_by_default(self): + original_handlers = list(logging.root.handlers) + with tqdm_logging_redirect(total=1) as pbar: + assert isinstance(logging.root.handlers[-1], TqdmLoggingHandler) + LOGGER.info('test') + pbar.update(1) + assert logging.root.handlers == original_handlers + + def test_should_add_and_remove_handler_from_custom_logger(self): + logger = logging.Logger('test') + with tqdm_logging_redirect(total=1, loggers=[logger]) as pbar: + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], TqdmLoggingHandler) + logger.info('test') + pbar.update(1) + assert not logger.handlers + + def test_should_not_fail_with_logger_without_console_handler(self): + logger = logging.Logger('test') + logger.handlers = [] + with tqdm_logging_redirect(total=1, loggers=[logger]): + logger.info('test') + assert not logger.handlers + + def test_should_format_message(self): + logger = logging.Logger('test') + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(logging.Formatter( + r'prefix:%(message)s' + )) + logger.handlers = [console_handler] + CustomTqdm.messages = [] + with tqdm_logging_redirect(loggers=[logger], tqdm_class=CustomTqdm): + logger.info('test') + assert CustomTqdm.messages == ['prefix:test'] + + def test_use_root_logger_by_default_and_write_to_custom_tqdm(self): + logger = logging.root + CustomTqdm.messages = [] + with tqdm_logging_redirect(total=1, tqdm_class=CustomTqdm) as pbar: + assert isinstance(pbar, CustomTqdm) + logger.info('test') + assert CustomTqdm.messages == ['test'] diff --git a/tests/tests_dask.py b/tests/tests_dask.py new file mode 100644 index 0000000..8bf4b64 --- /dev/null +++ b/tests/tests_dask.py @@ -0,0 +1,20 @@ +from __future__ import division + +from time import sleep + +from .tests_tqdm import importorskip, mark + +pytestmark = mark.slow + + +def test_dask(capsys): + """Test tqdm.dask.TqdmCallback""" + ProgressBar = importorskip('tqdm.dask').TqdmCallback + dask = importorskip('dask') + + schedule = [dask.delayed(sleep)(i / 10) for i in range(5)] + with ProgressBar(desc="computing"): + dask.compute(schedule) + _, err = capsys.readouterr() + assert "computing: " in err + assert '5/5' in err diff --git a/tests/tests_gui.py b/tests/tests_gui.py new 
file mode 100644 index 0000000..dddd918 --- /dev/null +++ b/tests/tests_gui.py @@ -0,0 +1,7 @@ +"""Test `tqdm.gui`.""" +from .tests_tqdm import importorskip + + +def test_gui_import(): + """Test `tqdm.gui` import""" + importorskip('tqdm.gui') diff --git a/tests/tests_itertools.py b/tests/tests_itertools.py new file mode 100644 index 0000000..bfb6eb2 --- /dev/null +++ b/tests/tests_itertools.py @@ -0,0 +1,26 @@ +""" +Tests for `tqdm.contrib.itertools`. +""" +import itertools as it + +from tqdm.contrib.itertools import product + +from .tests_tqdm import StringIO, closing + + +class NoLenIter(object): + def __init__(self, iterable): + self._it = iterable + + def __iter__(self): + for i in self._it: + yield i + + +def test_product(): + """Test contrib.itertools.product""" + with closing(StringIO()) as our_file: + a = range(9) + assert list(product(a, a[::-1], file=our_file)) == list(it.product(a, a[::-1])) + + assert list(product(a, NoLenIter(a), file=our_file)) == list(it.product(a, NoLenIter(a))) diff --git a/tests/tests_keras.py b/tests/tests_keras.py new file mode 100644 index 0000000..220f946 --- /dev/null +++ b/tests/tests_keras.py @@ -0,0 +1,93 @@ +from __future__ import division + +from .tests_tqdm import importorskip, mark + +pytestmark = mark.slow + + +@mark.filterwarnings("ignore:.*:DeprecationWarning") +def test_keras(capsys): + """Test tqdm.keras.TqdmCallback""" + TqdmCallback = importorskip('tqdm.keras').TqdmCallback + np = importorskip('numpy') + try: + import keras as K + except ImportError: + K = importorskip('tensorflow.keras') + + # 1D autoencoder + dtype = np.float32 + model = K.models.Sequential([ + K.layers.InputLayer((1, 1), dtype=dtype), K.layers.Conv1D(1, 1)]) + model.compile("adam", "mse") + x = np.random.rand(100, 1, 1).astype(dtype) + batch_size = 10 + batches = len(x) / batch_size + epochs = 5 + + # just epoch (no batch) progress + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[ + TqdmCallback( + epochs, + desc="training", + data_size=len(x), + batch_size=batch_size, + verbose=0)]) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) not in res + + # full (epoch and batch) progress + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[ + TqdmCallback( + epochs, + desc="training", + data_size=len(x), + batch_size=batch_size, + verbose=2)]) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) in res + + # auto-detect epochs and batches + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[TqdmCallback(desc="training", verbose=2)]) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) in res + + # continue training (start from epoch != 0) + initial_epoch = 3 + model.fit( + x, + x, + initial_epoch=initial_epoch, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[TqdmCallback(desc="training", verbose=0, + miniters=1, mininterval=0, maxinterval=0)]) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res diff --git a/tests/tests_main.py 
b/tests/tests_main.py new file mode 100644 index 0000000..0523cc7 --- /dev/null +++ b/tests/tests_main.py @@ -0,0 +1,245 @@ +"""Test CLI usage.""" +import logging +import subprocess # nosec +import sys +from functools import wraps +from os import linesep + +from tqdm.cli import TqdmKeyError, TqdmTypeError, main +from tqdm.utils import IS_WIN + +from .tests_tqdm import BytesIO, _range, closing, mark, raises + + +def restore_sys(func): + """Decorates `func(capsysbin)` to save & restore `sys.(stdin|argv)`.""" + @wraps(func) + def inner(capsysbin): + """function requiring capsysbin which may alter `sys.(stdin|argv)`""" + _SYS = sys.stdin, sys.argv + try: + res = func(capsysbin) + finally: + sys.stdin, sys.argv = _SYS + return res + + return inner + + +def norm(bytestr): + """Normalise line endings.""" + return bytestr if linesep == "\n" else bytestr.replace(linesep.encode(), b"\n") + + +@mark.slow +def test_pipes(): + """Test command line pipes""" + ls_out = subprocess.check_output(['ls']) # nosec + ls = subprocess.Popen(['ls'], stdout=subprocess.PIPE) # nosec + res = subprocess.Popen( # nosec + [sys.executable, '-c', 'from tqdm.cli import main; main()'], + stdin=ls.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = res.communicate() + assert ls.poll() == 0 + + # actual test: + assert norm(ls_out) == norm(out) + assert b"it/s" in err + assert b"Error" not in err + + +if sys.version_info[:2] >= (3, 8): + test_pipes = mark.filterwarnings("ignore:unclosed file:ResourceWarning")( + test_pipes) + + +def test_main_import(): + """Test main CLI import""" + N = 123 + _SYS = sys.stdin, sys.argv + # test direct import + sys.stdin = [str(i).encode() for i in _range(N)] + sys.argv = ['', '--desc', 'Test CLI import', + '--ascii', 'True', '--unit_scale', 'True'] + try: + import tqdm.__main__ # NOQA, pylint: disable=unused-variable + finally: + sys.stdin, sys.argv = _SYS + + +@restore_sys +def test_main_bytes(capsysbin): + """Test CLI --bytes""" + N = 123 + + # test --delim + IN_DATA = '\0'.join(map(str, _range(N))).encode() + with closing(BytesIO()) as sys.stdin: + sys.stdin.write(IN_DATA) + # sys.stdin.write(b'\xff') # TODO + sys.stdin.seek(0) + main(sys.stderr, ['--desc', 'Test CLI delim', '--ascii', 'True', + '--delim', r'\0', '--buf_size', '64']) + out, err = capsysbin.readouterr() + assert out == IN_DATA + assert str(N) + "it" in err.decode("U8") + + # test --bytes + IN_DATA = IN_DATA.replace(b'\0', b'\n') + with closing(BytesIO()) as sys.stdin: + sys.stdin.write(IN_DATA) + sys.stdin.seek(0) + main(sys.stderr, ['--ascii', '--bytes=True', '--unit_scale', 'False']) + out, err = capsysbin.readouterr() + assert out == IN_DATA + assert str(len(IN_DATA)) + "B" in err.decode("U8") + + +@mark.skipif(sys.version_info[0] == 2, reason="no caplog on py2") +def test_main_log(capsysbin, caplog): + """Test CLI --log""" + _SYS = sys.stdin, sys.argv + N = 123 + sys.stdin = [(str(i) + '\n').encode() for i in _range(N)] + IN_DATA = b''.join(sys.stdin) + try: + with caplog.at_level(logging.INFO): + main(sys.stderr, ['--log', 'INFO']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA and b"123/123" in err + assert not caplog.record_tuples + with caplog.at_level(logging.DEBUG): + main(sys.stderr, ['--log', 'DEBUG']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA and b"123/123" in err + assert caplog.record_tuples + finally: + sys.stdin, sys.argv = _SYS + + +@restore_sys +def test_main(capsysbin): + """Test misc CLI options""" + N = 123 + sys.stdin = [(str(i) + 
'\n').encode() for i in _range(N)] + IN_DATA = b''.join(sys.stdin) + + # test --tee + main(sys.stderr, ['--mininterval', '0', '--miniters', '1']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA and b"123/123" in err + assert N <= len(err.split(b"\r")) < N + 5 + + len_err = len(err) + main(sys.stderr, ['--tee', '--mininterval', '0', '--miniters', '1']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA and b"123/123" in err + # spaces to clear intermediate lines could increase length + assert len_err + len(norm(out)) <= len(err) + + # test --null + main(sys.stderr, ['--null']) + out, err = capsysbin.readouterr() + assert not out and b"123/123" in err + + # test integer --update + main(sys.stderr, ['--update']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA + assert (str(N // 2 * N) + "it").encode() in err, "expected arithmetic sum formula" + + # test integer --update_to + main(sys.stderr, ['--update-to']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA + assert (str(N - 1) + "it").encode() in err + assert (str(N) + "it").encode() not in err + + with closing(BytesIO()) as sys.stdin: + sys.stdin.write(IN_DATA.replace(b'\n', b'D')) + + # test integer --update --delim + sys.stdin.seek(0) + main(sys.stderr, ['--update', '--delim', 'D']) + out, err = capsysbin.readouterr() + assert out == IN_DATA.replace(b'\n', b'D') + assert (str(N // 2 * N) + "it").encode() in err, "expected arithmetic sum" + + # test integer --update_to --delim + sys.stdin.seek(0) + main(sys.stderr, ['--update-to', '--delim', 'D']) + out, err = capsysbin.readouterr() + assert out == IN_DATA.replace(b'\n', b'D') + assert (str(N - 1) + "it").encode() in err + assert (str(N) + "it").encode() not in err + + # test float --update_to + sys.stdin = [(str(i / 2.0) + '\n').encode() for i in _range(N)] + IN_DATA = b''.join(sys.stdin) + main(sys.stderr, ['--update-to']) + out, err = capsysbin.readouterr() + assert norm(out) == IN_DATA + assert (str((N - 1) / 2.0) + "it").encode() in err + assert (str(N / 2.0) + "it").encode() not in err + + +@mark.slow +@mark.skipif(IS_WIN, reason="no manpages on windows") +def test_manpath(tmp_path): + """Test CLI --manpath""" + man = tmp_path / "tqdm.1" + assert not man.exists() + with raises(SystemExit): + main(argv=['--manpath', str(tmp_path)]) + assert man.is_file() + + +@mark.slow +@mark.skipif(IS_WIN, reason="no completion on windows") +def test_comppath(tmp_path): + """Test CLI --comppath""" + man = tmp_path / "tqdm_completion.sh" + assert not man.exists() + with raises(SystemExit): + main(argv=['--comppath', str(tmp_path)]) + assert man.is_file() + + # check most important options appear + script = man.read_text() + opts = {'--help', '--desc', '--total', '--leave', '--ncols', '--ascii', + '--dynamic_ncols', '--position', '--bytes', '--nrows', '--delim', + '--manpath', '--comppath'} + assert all(args in script for args in opts) + + +@restore_sys +def test_exceptions(capsysbin): + """Test CLI Exceptions""" + N = 123 + sys.stdin = [str(i) + '\n' for i in _range(N)] + IN_DATA = ''.join(sys.stdin).encode() + + with raises(TqdmKeyError, match="bad_arg_u_ment"): + main(sys.stderr, argv=['-ascii', '-unit_scale', '--bad_arg_u_ment', 'foo']) + out, _ = capsysbin.readouterr() + assert norm(out) == IN_DATA + + with raises(TqdmTypeError, match="invalid_bool_value"): + main(sys.stderr, argv=['-ascii', '-unit_scale', 'invalid_bool_value']) + out, _ = capsysbin.readouterr() + assert norm(out) == IN_DATA + + with raises(TqdmTypeError, 
match="invalid_int_value"): + main(sys.stderr, argv=['-ascii', '--total', 'invalid_int_value']) + out, _ = capsysbin.readouterr() + assert norm(out) == IN_DATA + + with raises(TqdmKeyError, match="Can only have one of --"): + main(sys.stderr, argv=['--update', '--update_to']) + out, _ = capsysbin.readouterr() + assert norm(out) == IN_DATA + + # test SystemExits + for i in ('-h', '--help', '-v', '--version'): + with raises(SystemExit): + main(argv=[i]) diff --git a/tests/tests_notebook.py b/tests/tests_notebook.py new file mode 100644 index 0000000..004d7e5 --- /dev/null +++ b/tests/tests_notebook.py @@ -0,0 +1,7 @@ +from tqdm.notebook import tqdm as tqdm_notebook + + +def test_notebook_disabled_description(): + """Test that set_description works for disabled tqdm_notebook""" + with tqdm_notebook(1, disable=True) as t: + t.set_description("description") diff --git a/tests/tests_pandas.py b/tests/tests_pandas.py new file mode 100644 index 0000000..334a97c --- /dev/null +++ b/tests/tests_pandas.py @@ -0,0 +1,219 @@ +from tqdm import tqdm + +from .tests_tqdm import StringIO, closing, importorskip, mark, skip + +pytestmark = mark.slow + +random = importorskip('numpy.random') +rand = random.rand +randint = random.randint +pd = importorskip('pandas') + + +def test_pandas_setup(): + """Test tqdm.pandas()""" + with closing(StringIO()) as our_file: + tqdm.pandas(file=our_file, leave=True, ascii=True, total=123) + series = pd.Series(randint(0, 50, (100,))) + series.progress_apply(lambda x: x + 10) + res = our_file.getvalue() + assert '100/123' in res + + +def test_pandas_rolling_expanding(): + """Test pandas.(Series|DataFrame).(rolling|expanding)""" + with closing(StringIO()) as our_file: + tqdm.pandas(file=our_file, leave=True, ascii=True) + + series = pd.Series(randint(0, 50, (123,))) + res1 = series.rolling(10).progress_apply(lambda x: 1, raw=True) + res2 = series.rolling(10).apply(lambda x: 1, raw=True) + assert res1.equals(res2) + + res3 = series.expanding(10).progress_apply(lambda x: 2, raw=True) + res4 = series.expanding(10).apply(lambda x: 2, raw=True) + assert res3.equals(res4) + + expects = ['114it'] # 123-10+1 + for exres in expects: + our_file.seek(0) + if our_file.getvalue().count(exres) < 2: + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format( + exres + " at least twice.", our_file.read())) + + +def test_pandas_series(): + """Test pandas.Series.progress_apply and .progress_map""" + with closing(StringIO()) as our_file: + tqdm.pandas(file=our_file, leave=True, ascii=True) + + series = pd.Series(randint(0, 50, (123,))) + res1 = series.progress_apply(lambda x: x + 10) + res2 = series.apply(lambda x: x + 10) + assert res1.equals(res2) + + res3 = series.progress_map(lambda x: x + 10) + res4 = series.map(lambda x: x + 10) + assert res3.equals(res4) + + expects = ['100%', '123/123'] + for exres in expects: + our_file.seek(0) + if our_file.getvalue().count(exres) < 2: + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format( + exres + " at least twice.", our_file.read())) + + +def test_pandas_data_frame(): + """Test pandas.DataFrame.progress_apply and .progress_applymap""" + with closing(StringIO()) as our_file: + tqdm.pandas(file=our_file, leave=True, ascii=True) + df = pd.DataFrame(randint(0, 50, (100, 200))) + + def task_func(x): + return x + 1 + + # applymap + res1 = df.progress_applymap(task_func) + res2 = df.applymap(task_func) + assert res1.equals(res2) + + # apply unhashable + res1 = [] + df.progress_apply(res1.extend) + assert 
len(res1) == df.size + + # apply + for axis in [0, 1, 'index', 'columns']: + res3 = df.progress_apply(task_func, axis=axis) + res4 = df.apply(task_func, axis=axis) + assert res3.equals(res4) + + our_file.seek(0) + if our_file.read().count('100%') < 3: + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format( + '100% at least three times', our_file.read())) + + # apply_map, apply axis=0, apply axis=1 + expects = ['20000/20000', '200/200', '100/100'] + for exres in expects: + our_file.seek(0) + if our_file.getvalue().count(exres) < 1: + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:\n {1}\n".format( + exres + " at least once.", our_file.read())) + + +def test_pandas_groupby_apply(): + """Test pandas.DataFrame.groupby(...).progress_apply""" + with closing(StringIO()) as our_file: + tqdm.pandas(file=our_file, leave=False, ascii=True) + + df = pd.DataFrame(randint(0, 50, (500, 3))) + df.groupby(0).progress_apply(lambda x: None) + + dfs = pd.DataFrame(randint(0, 50, (500, 3)), columns=list('abc')) + dfs.groupby(['a']).progress_apply(lambda x: None) + + df2 = df = pd.DataFrame({'a': randint(1, 8, 10000), 'b': rand(10000)}) + res1 = df2.groupby("a").apply(max) + res2 = df2.groupby("a").progress_apply(max) + assert res1.equals(res2) + + our_file.seek(0) + + # don't expect final output since no `leave` and + # high dynamic `miniters` + nexres = '100%|##########|' + if nexres in our_file.read(): + our_file.seek(0) + raise AssertionError("\nDid not expect:\n{0}\nIn:{1}\n".format( + nexres, our_file.read())) + + with closing(StringIO()) as our_file: + tqdm.pandas(file=our_file, leave=True, ascii=True) + + dfs = pd.DataFrame(randint(0, 50, (500, 3)), columns=list('abc')) + dfs.loc[0] = [2, 1, 1] + dfs['d'] = 100 + + expects = ['500/500', '1/1', '4/4', '2/2'] + dfs.groupby(dfs.index).progress_apply(lambda x: None) + dfs.groupby('d').progress_apply(lambda x: None) + dfs.groupby(dfs.columns, axis=1).progress_apply(lambda x: None) + dfs.groupby([2, 2, 1, 1], axis=1).progress_apply(lambda x: None) + + our_file.seek(0) + if our_file.read().count('100%') < 4: + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format( + '100% at least four times', our_file.read())) + + for exres in expects: + our_file.seek(0) + if our_file.getvalue().count(exres) < 1: + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:\n {1}\n".format( + exres + " at least once.", our_file.read())) + + +def test_pandas_leave(): + """Test pandas with `leave=True`""" + with closing(StringIO()) as our_file: + df = pd.DataFrame(randint(0, 100, (1000, 6))) + tqdm.pandas(file=our_file, leave=True, ascii=True) + df.groupby(0).progress_apply(lambda x: None) + + our_file.seek(0) + + exres = '100%|##########| 100/100' + if exres not in our_file.read(): + our_file.seek(0) + raise AssertionError("\nExpected:\n{0}\nIn:{1}\n".format( + exres, our_file.read())) + + +def test_pandas_apply_args_deprecation(): + """Test warning info in + `pandas.Dataframe(Series).progress_apply(func, *args)`""" + try: + from tqdm import tqdm_pandas + except ImportError as err: + skip(str(err)) + + with closing(StringIO()) as our_file: + tqdm_pandas(tqdm(file=our_file, leave=False, ascii=True, ncols=20)) + df = pd.DataFrame(randint(0, 50, (500, 3))) + df.progress_apply(lambda x: None, 1) # 1 shall cause a warning + # Check deprecation message + res = our_file.getvalue() + assert all(i in res for i in ( + "TqdmDeprecationWarning", "not supported", + "keyword arguments instead")) + + +def 
test_pandas_deprecation():
+    """Test bar object instance as argument deprecation"""
+    try:
+        from tqdm import tqdm_pandas
+    except ImportError as err:
+        skip(str(err))
+
+    with closing(StringIO()) as our_file:
+        tqdm_pandas(tqdm(file=our_file, leave=False, ascii=True, ncols=20))
+        df = pd.DataFrame(randint(0, 50, (500, 3)))
+        df.groupby(0).progress_apply(lambda x: None)
+        # Check deprecation message
+        assert "TqdmDeprecationWarning" in our_file.getvalue()
+        assert "instead of `tqdm_pandas(tqdm(...))`" in our_file.getvalue()
+
+    with closing(StringIO()) as our_file:
+        tqdm_pandas(tqdm, file=our_file, leave=False, ascii=True, ncols=20)
+        df = pd.DataFrame(randint(0, 50, (500, 3)))
+        df.groupby(0).progress_apply(lambda x: None)
+        # Check deprecation message
+        assert "TqdmDeprecationWarning" in our_file.getvalue()
+        assert "instead of `tqdm_pandas(tqdm, ...)`" in our_file.getvalue()
diff --git a/tests/tests_perf.py b/tests/tests_perf.py
new file mode 100644
index 0000000..552a169
--- /dev/null
+++ b/tests/tests_perf.py
@@ -0,0 +1,325 @@
+from __future__ import division, print_function
+
+import sys
+from contextlib import contextmanager
+from functools import wraps
+from time import sleep, time
+
+# Use relative/cpu timer to have reliable timings when there is a sudden load
+try:
+    from time import process_time
+except ImportError:
+    from time import clock
+    process_time = clock
+
+from tqdm import tqdm, trange
+
+from .tests_tqdm import _range, importorskip, mark, patch_lock, skip
+
+pytestmark = mark.slow
+
+
+def cpu_sleep(t):
+    """Sleep the given amount of cpu time"""
+    start = process_time()
+    while (process_time() - start) < t:
+        pass
+
+
+def checkCpuTime(sleeptime=0.2):
+    """Check if cpu time works correctly"""
+    if checkCpuTime.passed:
+        return True
+    # First test that sleeping does not consume cputime
+    start1 = process_time()
+    sleep(sleeptime)
+    t1 = process_time() - start1
+
+    # Secondly, check by comparing to cpu_sleep (where we actually do something)
+    start2 = process_time()
+    cpu_sleep(sleeptime)
+    t2 = process_time() - start2
+
+    if abs(t1) < 0.0001 and t1 < t2 / 10:
+        checkCpuTime.passed = True
+        return True
+    skip("cpu time not reliable on this machine")
+
+
+checkCpuTime.passed = False
+
+
+@contextmanager
+def relative_timer():
+    """yields a context timer function which stops ticking on exit"""
+    start = process_time()
+
+    def elapser():
+        return process_time() - start
+
+    yield lambda: elapser()
+    spent = elapser()
+
+    def elapser():  # NOQA
+        return spent
+
+
+def retry_on_except(n=3, check_cpu_time=True):
+    """decorator for retrying `n` times before raising Exceptions"""
+    def wrapper(func):
+        """actual decorator"""
+        @wraps(func)
+        def test_inner(*args, **kwargs):
+            """may skip if `check_cpu_time` fails"""
+            for i in range(1, n + 1):
+                try:
+                    if check_cpu_time:
+                        checkCpuTime()
+                    func(*args, **kwargs)
+                except Exception:
+                    if i >= n:
+                        raise
+                else:
+                    return
+        return test_inner
+    return wrapper
+
+
+def simple_progress(iterable=None, total=None, file=sys.stdout, desc='',
+                    leave=False, miniters=1, mininterval=0.1, width=60):
+    """Simple progress bar reproducing tqdm's major features"""
+    n = [0]  # use a closure
+    start_t = [time()]
+    last_n = [0]
+    last_t = [0]
+    if iterable is not None:
+        total = len(iterable)
+
+    def format_interval(t):
+        mins, s = divmod(int(t), 60)
+        h, m = divmod(mins, 60)
+        if h:
+            return '{0:d}:{1:02d}:{2:02d}'.format(h, m, s)
+        else:
+            return '{0:02d}:{1:02d}'.format(m, s)
+
+    def update_and_print(i=1):
+        n[0] += i
+        if (n[0] - last_n[0]) >= 
miniters:
+            last_n[0] = n[0]
+
+            if (time() - last_t[0]) >= mininterval:
+                last_t[0] = time()  # last_t[0] == current time
+
+                spent = last_t[0] - start_t[0]
+                spent_fmt = format_interval(spent)
+                rate = n[0] / spent if spent > 0 else 0
+                rate_fmt = "%.2fs/it" % (1.0 / rate) if 0.0 < rate < 1.0 else "%.2fit/s" % rate
+
+                frac = n[0] / total
+                percentage = int(frac * 100)
+                eta = (total - n[0]) / rate if rate > 0 else 0
+                eta_fmt = format_interval(eta)
+
+                # full_bar = "#" * int(frac * width)
+                barfill = " " * int((1.0 - frac) * width)
+                bar_length, frac_bar_length = divmod(int(frac * width * 10), 10)
+                full_bar = '#' * bar_length
+                frac_bar = chr(48 + frac_bar_length) if frac_bar_length else ' '
+
+                file.write("\r%s %i%%|%s%s%s| %i/%i [%s<%s, %s]" %
+                           (desc, percentage, full_bar, frac_bar, barfill, n[0],
+                            total, spent_fmt, eta_fmt, rate_fmt))
+
+                if n[0] == total and leave:
+                    file.write("\n")
+                file.flush()
+
+    def update_and_yield():
+        for elt in iterable:
+            yield elt
+            update_and_print()
+
+    update_and_print(0)
+    if iterable is not None:
+        return update_and_yield()
+    else:
+        return update_and_print
+
+
+def assert_performance(thresh, name_left, time_left, name_right, time_right):
+    """raises if time_left > thresh * time_right"""
+    if time_left > thresh * time_right:
+        raise ValueError(
+            ('{name[0]}: {time[0]:f}, '
+             '{name[1]}: {time[1]:f}, '
+             'ratio {ratio:f} > {thresh:f}').format(
+                 name=(name_left, name_right),
+                 time=(time_left, time_right),
+                 ratio=time_left / time_right, thresh=thresh))
+
+
+@retry_on_except()
+def test_iter_basic_overhead():
+    """Test overhead of iteration based tqdm"""
+    total = int(1e6)
+
+    a = 0
+    with trange(total) as t:
+        with relative_timer() as time_tqdm:
+            for i in t:
+                a += i
+    assert a == (total ** 2 - total) / 2.0
+
+    a = 0
+    with relative_timer() as time_bench:
+        for i in _range(total):
+            a += i
+            sys.stdout.write(str(a))
+
+    assert_performance(3, 'trange', time_tqdm(), 'range', time_bench())
+
+
+@retry_on_except()
+def test_manual_basic_overhead():
+    """Test overhead of manual tqdm"""
+    total = int(1e6)
+
+    with tqdm(total=total * 10, leave=True) as t:
+        a = 0
+        with relative_timer() as time_tqdm:
+            for i in _range(total):
+                a += i
+                t.update(10)
+
+    a = 0
+    with relative_timer() as time_bench:
+        for i in _range(total):
+            a += i
+            sys.stdout.write(str(a))
+
+    assert_performance(5, 'tqdm', time_tqdm(), 'range', time_bench())
+
+
+def worker(total, blocking=True):
+    def incr_bar(x):
+        for _ in trange(total, lock_args=None if blocking else (False,),
+                        miniters=1, mininterval=0, maxinterval=0):
+            pass
+        return x + 1
+    return incr_bar
+
+
+@retry_on_except()
+@patch_lock(thread=True)
+def test_lock_args():
+    """Test overhead of nonblocking threads"""
+    ThreadPoolExecutor = importorskip('concurrent.futures').ThreadPoolExecutor
+
+    total = 16
+    subtotal = 10000
+
+    with ThreadPoolExecutor() as pool:
+        sys.stderr.write('block ... ')
+        sys.stderr.flush()
+        with relative_timer() as time_tqdm:
+            res = list(pool.map(worker(subtotal, True), range(total)))
+        assert sum(res) == sum(range(total)) + total
+        sys.stderr.write('noblock ... 
')
+        sys.stderr.flush()
+        with relative_timer() as time_noblock:
+            res = list(pool.map(worker(subtotal, False), range(total)))
+        assert sum(res) == sum(range(total)) + total
+
+    assert_performance(0.5, 'noblock', time_noblock(), 'tqdm', time_tqdm())
+
+
+@retry_on_except(10)
+def test_iter_overhead_hard():
+    """Test overhead of iteration based tqdm (hard)"""
+    total = int(1e5)
+
+    a = 0
+    with trange(total, leave=True, miniters=1,
+                mininterval=0, maxinterval=0) as t:
+        with relative_timer() as time_tqdm:
+            for i in t:
+                a += i
+    assert a == (total ** 2 - total) / 2.0
+
+    a = 0
+    with relative_timer() as time_bench:
+        for i in _range(total):
+            a += i
+            sys.stdout.write(("%i" % a) * 40)
+
+    assert_performance(130, 'trange', time_tqdm(), 'range', time_bench())
+
+
+@retry_on_except(10)
+def test_manual_overhead_hard():
+    """Test overhead of manual tqdm (hard)"""
+    total = int(1e5)
+
+    with tqdm(total=total * 10, leave=True, miniters=1,
+              mininterval=0, maxinterval=0) as t:
+        a = 0
+        with relative_timer() as time_tqdm:
+            for i in _range(total):
+                a += i
+                t.update(10)
+
+    a = 0
+    with relative_timer() as time_bench:
+        for i in _range(total):
+            a += i
+            sys.stdout.write(("%i" % a) * 40)
+
+    assert_performance(130, 'tqdm', time_tqdm(), 'range', time_bench())
+
+
+@retry_on_except(10)
+def test_iter_overhead_simplebar_hard():
+    """Test overhead of iteration based tqdm vs simple progress bar (hard)"""
+    total = int(1e4)
+
+    a = 0
+    with trange(total, leave=True, miniters=1,
+                mininterval=0, maxinterval=0) as t:
+        with relative_timer() as time_tqdm:
+            for i in t:
+                a += i
+    assert a == (total ** 2 - total) / 2.0
+
+    a = 0
+    s = simple_progress(_range(total), leave=True,
+                        miniters=1, mininterval=0)
+    with relative_timer() as time_bench:
+        for i in s:
+            a += i
+
+    assert_performance(10, 'trange', time_tqdm(), 'simple_progress', time_bench())
+
+
+@retry_on_except(10)
+def test_manual_overhead_simplebar_hard():
+    """Test overhead of manual tqdm vs simple progress bar (hard)"""
+    total = int(1e4)
+
+    with tqdm(total=total * 10, leave=True, miniters=1,
+              mininterval=0, maxinterval=0) as t:
+        a = 0
+        with relative_timer() as time_tqdm:
+            for i in _range(total):
+                a += i
+                t.update(10)
+
+    simplebar_update = simple_progress(total=total * 10, leave=True,
+                                       miniters=1, mininterval=0)
+    a = 0
+    with relative_timer() as time_bench:
+        for i in _range(total):
+            a += i
+            simplebar_update(10)
+
+    assert_performance(10, 'tqdm', time_tqdm(), 'simple_progress', time_bench())
diff --git a/tests/tests_rich.py b/tests/tests_rich.py
new file mode 100644
index 0000000..c75e246
--- /dev/null
+++ b/tests/tests_rich.py
@@ -0,0 +1,10 @@
+"""Test `tqdm.rich`."""
+import sys
+
+from .tests_tqdm import importorskip, mark
+
+
+@mark.skipif(sys.version_info[:3] < (3, 6, 1), reason="`rich` needs py>=3.6.1")
+def test_rich_import():
+    """Test `tqdm.rich` import"""
+    importorskip('tqdm.rich')
diff --git a/tests/tests_synchronisation.py b/tests/tests_synchronisation.py
new file mode 100644
index 0000000..7ee55fb
--- /dev/null
+++ b/tests/tests_synchronisation.py
@@ -0,0 +1,224 @@
+from __future__ import division
+
+import sys
+from functools import wraps
+from threading import Event
+from time import sleep, time
+
+from tqdm import TMonitor, tqdm, trange
+
+from .tests_perf import retry_on_except
+from .tests_tqdm import StringIO, closing, importorskip, patch_lock, skip
+
+
+class Time(object):
+    """Fake time class providing an offset"""
+    offset = 0
+
+    @classmethod
+    def reset(cls):
+        """zeroes internal offset"""
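+        # NOTE: `Time` is a fake clock: `fake_sleep()` below advances `offset`
+        # instead of blocking, so monitor/timing behaviour can be driven
+        # deterministically (see `cpu_timify()` and `patch_sleep()` further down)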
+        cls.offset = 0
+
+    @classmethod
+    def time(cls):
+        """time.time() + offset"""
+        return time() + cls.offset
+
+    @staticmethod
+    def sleep(dur):
+        """identical to time.sleep()"""
+        sleep(dur)
+
+    @classmethod
+    def fake_sleep(cls, dur):
+        """adds `dur` to internal offset"""
+        cls.offset += dur
+        sleep(0.000001)  # sleep to allow interrupt (instead of pass)
+
+
+def FakeEvent():
+    """patched `threading.Event` where `wait()` uses `Time.fake_sleep()`"""
+    event = Event()  # not a class in py2 so can't inherit
+
+    def wait(timeout=None):
+        """uses Time.fake_sleep"""
+        if timeout is not None:
+            Time.fake_sleep(timeout)
+        return event.is_set()
+
+    event.wait = wait
+    return event
+
+
+def patch_sleep(func):
+    """Temporarily makes TMonitor use Time.fake_sleep"""
+    @wraps(func)
+    def inner(*args, **kwargs):
+        """restores TMonitor on completion regardless of Exceptions"""
+        TMonitor._test["time"] = Time.time
+        TMonitor._test["Event"] = FakeEvent
+        if tqdm.monitor:
+            assert not tqdm.monitor.get_instances()
+            tqdm.monitor.exit()
+            del tqdm.monitor
+            tqdm.monitor = None
+        try:
+            return func(*args, **kwargs)
+        finally:
+            # Check that class var monitor is deleted if no instance left
+            tqdm.monitor_interval = 10
+            if tqdm.monitor:
+                assert not tqdm.monitor.get_instances()
+                tqdm.monitor.exit()
+                del tqdm.monitor
+                tqdm.monitor = None
+            TMonitor._test.pop("Event")
+            TMonitor._test.pop("time")
+
+    return inner
+
+
+def cpu_timify(t, timer=Time):
+    """Force tqdm to use the specified timer instead of system-wide time"""
+    t._time = timer.time
+    t._sleep = timer.fake_sleep
+    t.start_t = t.last_print_t = t._time()
+    return timer
+
+
+class FakeTqdm(object):
+    _instances = set()
+    get_lock = tqdm.get_lock
+
+
+def incr(x):
+    return x + 1
+
+
+def incr_bar(x):
+    with closing(StringIO()) as our_file:
+        for _ in trange(x, lock_args=(False,), file=our_file):
+            pass
+    return incr(x)
+
+
+@patch_sleep
+def test_monitor_thread():
+    """Test dummy monitoring thread"""
+    monitor = TMonitor(FakeTqdm, 10)
+    # Test if alive, then killed
+    assert monitor.report()
+    monitor.exit()
+    assert not monitor.report()
+    assert not monitor.is_alive()
+    del monitor
+
+
+@patch_sleep
+def test_monitoring_and_cleanup():
+    """Test for stalled tqdm instance and monitor deletion"""
+    # Note: should fix miniters for these tests, else with dynamic_miniters
+    # it's too complicated to handle with monitoring update and maxinterval...
+    maxinterval = tqdm.monitor_interval
+    assert maxinterval == 10
+    total = 1000
+
+    with closing(StringIO()) as our_file:
+        with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
+                  maxinterval=maxinterval) as t:
+            cpu_timify(t, Time)
+            # Do a lot of iterations in a small timeframe
+            # (smaller than monitor interval)
+            Time.fake_sleep(maxinterval / 10)  # monitor won't wake up
+            t.update(500)
+            # check that our fixed miniters is still there
+            assert t.miniters <= 500  # TODO: should really be == 500
+            # Then do 1 it after monitor interval, so that monitor kicks in
+            Time.fake_sleep(maxinterval)
+            t.update(1)
+            # Wait for the monitor to get out of sleep's loop and update tqdm.
+            timeend = Time.time()
+            while not (t.monitor.woken >= timeend and t.miniters == 1):
+                Time.fake_sleep(1)  # Force wake-up if it woke too soon
+            assert t.miniters == 1  # check that monitor corrected miniters
+            # Note: at this point, there may be a race condition: monitor saved
+            # current woken time but Time.sleep() happens just before monitor
+            # sleep. To fix that, either sleep here or increase time in a loop
+            # to ensure that monitor wakes up at some point.
+
+            # Try again but already at miniters = 1 so nothing will be done
+            Time.fake_sleep(maxinterval)
+            t.update(2)
+            timeend = Time.time()
+            while t.monitor.woken < timeend:
+                Time.fake_sleep(1)  # Force wake-up if it woke too soon
+                # Wait for the monitor to get out of sleep's loop and update
+                # tqdm
+            assert t.miniters == 1  # check that monitor corrected miniters
+
+
+@patch_sleep
+def test_monitoring_multi():
+    """Test on multiple bars, one not needing miniters adjustment"""
+    # Note: should fix miniters for these tests, else with dynamic_miniters
+    # it's too complicated to handle with monitoring update and maxinterval...
+    maxinterval = tqdm.monitor_interval
+    assert maxinterval == 10
+    total = 1000
+
+    with closing(StringIO()) as our_file:
+        with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
+                  maxinterval=maxinterval) as t1:
+            # Set high maxinterval for t2 so monitor does not need to adjust it
+            with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
+                      maxinterval=1E5) as t2:
+                cpu_timify(t1, Time)
+                cpu_timify(t2, Time)
+                # Do a lot of iterations in a small timeframe
+                Time.fake_sleep(maxinterval / 10)
+                t1.update(500)
+                t2.update(500)
+                assert t1.miniters <= 500  # TODO: should really be == 500
+                assert t2.miniters == 500
+                # Then do 1 it after monitor interval, so that monitor kicks in
+                Time.fake_sleep(maxinterval)
+                t1.update(1)
+                t2.update(1)
+                # Wait for the monitor to get out of sleep and update tqdm
+                timeend = Time.time()
+                while not (t1.monitor.woken >= timeend and t1.miniters == 1):
+                    Time.fake_sleep(1)
+                assert t1.miniters == 1  # check that monitor corrected miniters
+                assert t2.miniters == 500  # check that t2 was not adjusted
+
+
+def test_imap():
+    """Test multiprocessing.Pool"""
+    try:
+        from multiprocessing import Pool
+    except ImportError as err:
+        skip(str(err))
+
+    pool = Pool()
+    res = list(tqdm(pool.imap(incr, range(100)), disable=True))
+    pool.close()
+    assert res[-1] == 100
+
+
+# py2: locks won't propagate to incr_bar so may cause `AttributeError`
+@retry_on_except(n=3 if sys.version_info < (3,) else 1, check_cpu_time=False)
+@patch_lock(thread=True)
+def test_threadpool():
+    """Test concurrent.futures.ThreadPoolExecutor"""
+    ThreadPoolExecutor = importorskip('concurrent.futures').ThreadPoolExecutor
+
+    with ThreadPoolExecutor(8) as pool:
+        try:
+            res = list(tqdm(pool.map(incr_bar, range(100)), disable=True))
+        except AttributeError:
+            if sys.version_info < (3,):
+                skip("not supported on py2")
+            else:
+                raise
+    assert sum(res) == sum(range(1, 101))
diff --git a/tests/tests_tk.py b/tests/tests_tk.py
new file mode 100644
index 0000000..9aa645c
--- /dev/null
+++ b/tests/tests_tk.py
@@ -0,0 +1,7 @@
+"""Test `tqdm.tk`."""
+from .tests_tqdm import importorskip
+
+
+def test_tk_import():
+    """Test `tqdm.tk` import"""
+    importorskip('tqdm.tk')
diff --git a/tests/tests_tqdm.py b/tests/tests_tqdm.py
new file mode 100644
index 0000000..bba457a
--- /dev/null
+++ b/tests/tests_tqdm.py
@@ -0,0 +1,1996 @@
+# -*- coding: utf-8 -*-
+# Advice: use repr(our_file.read()) to print the full output of tqdm
+# (else '\r' will replace the previous lines and you'll see only the latest.)
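+# For example, printing a buffer containing '1it\r2it' shows only "2it" (the
+# '\r' returns the cursor to the start of the line), whereas its repr() shows
+# the full '1it\r2it' with control characters intact; `squash_ctrlchars` below
+# emulates this terminal behaviour for assertions.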
+from __future__ import print_function + +import csv +import os +import re +import sys +from contextlib import contextmanager +from functools import wraps +from warnings import catch_warnings, simplefilter + +from pytest import importorskip, mark, raises, skip + +from tqdm import TqdmDeprecationWarning, TqdmWarning, tqdm, trange +from tqdm.contrib import DummyTqdmFile +from tqdm.std import EMA, Bar + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO + +from io import IOBase # to support unicode strings +from io import BytesIO + + +class DeprecationError(Exception): + pass + + +# Ensure we can use `with closing(...) as ... :` syntax +if getattr(StringIO, '__exit__', False) and getattr(StringIO, '__enter__', False): + def closing(arg): + return arg +else: + from contextlib import closing + +try: + _range = xrange +except NameError: + _range = range + +try: + _unicode = unicode +except NameError: + _unicode = str + +nt_and_no_colorama = False +if os.name == 'nt': + try: + import colorama # NOQA + except ImportError: + nt_and_no_colorama = True + +# Regex definitions +# List of control characters +CTRLCHR = [r'\r', r'\n', r'\x1b\[A'] # Need to escape [ for regex +# Regular expressions compilation +RE_rate = re.compile(r'[^\d](\d[.\d]+)it/s') +RE_ctrlchr = re.compile("(%s)" % '|'.join(CTRLCHR)) # Match control chars +RE_ctrlchr_excl = re.compile('|'.join(CTRLCHR)) # Match and exclude ctrl chars +RE_pos = re.compile(r'([\r\n]+((pos\d+) bar:\s+\d+%|\s{3,6})?[^\r\n]*)') + + +def pos_line_diff(res_list, expected_list, raise_nonempty=True): + """ + Return differences between two bar output lists. + To be used with `RE_pos` + """ + res = [(r, e) for r, e in zip(res_list, expected_list) + for pos in [len(e) - len(e.lstrip('\n'))] # bar position + if r != e # simple comparison + if not r.startswith(e) # start matches + or not ( + # move up at end (maybe less due to closing bars) + any(r.endswith(end + i * '\x1b[A') for i in range(pos + 1) + for end in [ + ']', # bar + ' ']) # cleared + or '100%' in r # completed bar + or r == '\n') # final bar + or r[(-1 - pos) * len('\x1b[A'):] == '\x1b[A'] # too many moves up + if raise_nonempty and (res or len(res_list) != len(expected_list)): + if len(res_list) < len(expected_list): + res.extend([(None, e) for e in expected_list[len(res_list):]]) + elif len(res_list) > len(expected_list): + res.extend([(r, None) for r in res_list[len(expected_list):]]) + raise AssertionError( + "Got => Expected\n" + '\n'.join('%r => %r' % i for i in res)) + return res + + +class DiscreteTimer(object): + """Virtual discrete time manager, to precisely control time for tests""" + def __init__(self): + self.t = 0.0 + + def sleep(self, t): + """Sleep = increment the time counter (almost no CPU used)""" + self.t += t + + def time(self): + """Get the current time""" + return self.t + + +def cpu_timify(t, timer=None): + """Force tqdm to use the specified timer instead of system-wide time()""" + if timer is None: + timer = DiscreteTimer() + t._time = timer.time + t._sleep = timer.sleep + t.start_t = t.last_print_t = t._time() + return timer + + +class UnicodeIO(IOBase): + """Unicode version of StringIO""" + def __init__(self, *args, **kwargs): + super(UnicodeIO, self).__init__(*args, **kwargs) + self.encoding = 'U8' # io.StringIO supports unicode, but no encoding + self.text = '' + self.cursor = 0 + + def __len__(self): + return len(self.text) + + def seek(self, offset): + self.cursor = offset + + def tell(self): + return self.cursor + + def write(self, 
s): + self.text = self.text[:self.cursor] + s + self.text[self.cursor + len(s):] + self.cursor += len(s) + + def read(self, n=-1): + _cur = self.cursor + self.cursor = len(self) if n < 0 else min(_cur + n, len(self)) + return self.text[_cur:self.cursor] + + def getvalue(self): + return self.text + + +def get_bar(all_bars, i=None): + """Get a specific update from a whole bar traceback""" + # Split according to any used control characters + bars_split = RE_ctrlchr_excl.split(all_bars) + bars_split = list(filter(None, bars_split)) # filter out empty splits + return bars_split if i is None else bars_split[i] + + +def progressbar_rate(bar_str): + return float(RE_rate.search(bar_str).group(1)) + + +def squash_ctrlchars(s): + """Apply control characters in a string just like a terminal display""" + curline = 0 + lines = [''] # state of fake terminal + for nextctrl in filter(None, RE_ctrlchr.split(s)): + # apply control chars + if nextctrl == '\r': + # go to line beginning (simplified here: just empty the string) + lines[curline] = '' + elif nextctrl == '\n': + if curline >= len(lines) - 1: + # wrap-around creates newline + lines.append('') + # move cursor down + curline += 1 + elif nextctrl == '\x1b[A': + # move cursor up + if curline > 0: + curline -= 1 + else: + raise ValueError("Cannot go further up") + else: + # print message on current line + lines[curline] += nextctrl + return lines + + +def test_format_interval(): + """Test time interval format""" + format_interval = tqdm.format_interval + + assert format_interval(60) == '01:00' + assert format_interval(6160) == '1:42:40' + assert format_interval(238113) == '66:08:33' + + +def test_format_num(): + """Test number format""" + format_num = tqdm.format_num + + assert float(format_num(1337)) == 1337 + assert format_num(int(1e6)) == '1e+6' + assert format_num(1239876) == '1' '239' '876' + + +def test_format_meter(): + """Test statistics and progress bar formatting""" + try: + unich = unichr + except NameError: + unich = chr + + format_meter = tqdm.format_meter + + assert format_meter(0, 1000, 13) == " 0%| | 0/1000 [00:13= (3,): + assert format_meter(0, 1000, 13, ncols=68, prefix='fullwidth: ') == ( + "fullwidth: 0%| | 0/1000 [00:13= (bigstep - 1) and ((i - (bigstep - 1)) % smallstep) == 0: + timer.sleep(1e-2) + if i >= 3 * bigstep: + break + + assert "15%" in our_file.getvalue() + + # Test different behavior with and without mininterval + timer = DiscreteTimer() + total = 1000 + mininterval = 0.1 + maxinterval = 10 + with closing(StringIO()) as our_file: + with tqdm(total=total, file=our_file, miniters=None, smoothing=1, + mininterval=mininterval, maxinterval=maxinterval) as tm1: + with tqdm(total=total, file=our_file, miniters=None, smoothing=1, + mininterval=0, maxinterval=maxinterval) as tm2: + + cpu_timify(tm1, timer) + cpu_timify(tm2, timer) + + # Fast iterations, check if dynamic_miniters triggers + timer.sleep(mininterval) # to force update for t1 + tm1.update(total / 2) + tm2.update(total / 2) + assert int(tm1.miniters) == tm2.miniters == total / 2 + + # Slow iterations, check different miniters if mininterval + timer.sleep(maxinterval * 2) + tm1.update(total / 2) + tm2.update(total / 2) + res = [tm1.miniters, tm2.miniters] + assert res == [(total / 2) * mininterval / (maxinterval * 2), + (total / 2) * maxinterval / (maxinterval * 2)] + + # Same with iterable based tqdm + timer1 = DiscreteTimer() # need 2 timers for each bar because zip not work + timer2 = DiscreteTimer() + total = 100 + mininterval = 0.1 + maxinterval = 10 + with 
closing(StringIO()) as our_file: + t1 = tqdm(_range(total), file=our_file, miniters=None, smoothing=1, + mininterval=mininterval, maxinterval=maxinterval) + t2 = tqdm(_range(total), file=our_file, miniters=None, smoothing=1, + mininterval=0, maxinterval=maxinterval) + + cpu_timify(t1, timer1) + cpu_timify(t2, timer2) + + for i in t1: + if i == ((total / 2) - 2): + timer1.sleep(mininterval) + if i == (total - 1): + timer1.sleep(maxinterval * 2) + + for i in t2: + if i == ((total / 2) - 2): + timer2.sleep(mininterval) + if i == (total - 1): + timer2.sleep(maxinterval * 2) + + assert t1.miniters == 0.255 + assert t2.miniters == 0.5 + + t1.close() + t2.close() + + +def test_delay(): + """Test delay""" + timer = DiscreteTimer() + with closing(StringIO()) as our_file: + t = tqdm(total=2, file=our_file, leave=True, delay=3) + cpu_timify(t, timer) + timer.sleep(2) + t.update(1) + assert not our_file.getvalue() + timer.sleep(2) + t.update(1) + assert our_file.getvalue() + t.close() + + +def test_min_iters(): + """Test miniters""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(3), file=our_file, leave=True, mininterval=0, miniters=2): + pass + + out = our_file.getvalue() + assert '| 0/3 ' in out + assert '| 1/3 ' not in out + assert '| 2/3 ' in out + assert '| 3/3 ' in out + + with closing(StringIO()) as our_file: + for _ in tqdm(_range(3), file=our_file, leave=True, mininterval=0, miniters=1): + pass + + out = our_file.getvalue() + assert '| 0/3 ' in out + assert '| 1/3 ' in out + assert '| 2/3 ' in out + assert '| 3/3 ' in out + + +def test_dynamic_min_iters(): + """Test purely dynamic miniters (and manual updates and __del__)""" + with closing(StringIO()) as our_file: + total = 10 + t = tqdm(total=total, file=our_file, miniters=None, mininterval=0, smoothing=1) + + t.update() + # Increase 3 iterations + t.update(3) + # The next two iterations should be skipped because of dynamic_miniters + t.update() + t.update() + # The third iteration should be displayed + t.update() + + out = our_file.getvalue() + assert t.dynamic_miniters + t.__del__() # simulate immediate del gc + + assert ' 0%| | 0/10 [00:00<' in out + assert '40%' in out + assert '50%' not in out + assert '60%' not in out + assert '70%' in out + + # Check with smoothing=0, miniters should be set to max update seen so far + with closing(StringIO()) as our_file: + total = 10 + t = tqdm(total=total, file=our_file, miniters=None, mininterval=0, smoothing=0) + + t.update() + t.update(2) + t.update(5) # this should be stored as miniters + t.update(1) + + out = our_file.getvalue() + assert all(i in out for i in ("0/10", "1/10", "3/10")) + assert "2/10" not in out + assert t.dynamic_miniters and not t.smoothing + assert t.miniters == 5 + t.close() + + # Check iterable based tqdm + with closing(StringIO()) as our_file: + t = tqdm(_range(10), file=our_file, miniters=None, mininterval=None, + smoothing=0.5) + for _ in t: + pass + assert t.dynamic_miniters + + # No smoothing + with closing(StringIO()) as our_file: + t = tqdm(_range(10), file=our_file, miniters=None, mininterval=None, + smoothing=0) + for _ in t: + pass + assert t.dynamic_miniters + + # No dynamic_miniters (miniters is fixed manually) + with closing(StringIO()) as our_file: + t = tqdm(_range(10), file=our_file, miniters=1, mininterval=None) + for _ in t: + pass + assert not t.dynamic_miniters + + +def test_big_min_interval(): + """Test large mininterval""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(2), file=our_file, mininterval=1E10): + pass + 
assert '50%' not in our_file.getvalue() + + with closing(StringIO()) as our_file: + with tqdm(_range(2), file=our_file, mininterval=1E10) as t: + t.update() + t.update() + assert '50%' not in our_file.getvalue() + + +def test_smoothed_dynamic_min_iters(): + """Test smoothed dynamic miniters""" + timer = DiscreteTimer() + + with closing(StringIO()) as our_file: + with tqdm(total=100, file=our_file, miniters=None, mininterval=1, + smoothing=0.5, maxinterval=0) as t: + cpu_timify(t, timer) + + # Increase 10 iterations at once + timer.sleep(1) + t.update(10) + # The next iterations should be partially skipped + for _ in _range(2): + timer.sleep(1) + t.update(4) + for _ in _range(20): + timer.sleep(1) + t.update() + + assert t.dynamic_miniters + out = our_file.getvalue() + assert ' 0%| | 0/100 [00:00<' in out + assert '20%' in out + assert '23%' not in out + assert '25%' in out + assert '26%' not in out + assert '28%' in out + + +def test_smoothed_dynamic_min_iters_with_min_interval(): + """Test smoothed dynamic miniters with mininterval""" + timer = DiscreteTimer() + + # In this test, `miniters` should gradually decline + total = 100 + + with closing(StringIO()) as our_file: + # Test manual updating tqdm + with tqdm(total=total, file=our_file, miniters=None, mininterval=1e-3, + smoothing=1, maxinterval=0) as t: + cpu_timify(t, timer) + + t.update(10) + timer.sleep(1e-2) + for _ in _range(4): + t.update() + timer.sleep(1e-2) + out = our_file.getvalue() + assert t.dynamic_miniters + + with closing(StringIO()) as our_file: + # Test iteration-based tqdm + with tqdm(_range(total), file=our_file, miniters=None, + mininterval=0.01, smoothing=1, maxinterval=0) as t2: + cpu_timify(t2, timer) + + for i in t2: + if i >= 10: + timer.sleep(0.1) + if i >= 14: + break + out2 = our_file.getvalue() + + assert t.dynamic_miniters + assert ' 0%| | 0/100 [00:00<' in out + assert '11%' in out and '11%' in out2 + # assert '12%' not in out and '12%' in out2 + assert '13%' in out and '13%' in out2 + assert '14%' in out and '14%' in out2 + + +@mark.slow +def test_rlock_creation(): + """Test that importing tqdm does not create multiprocessing objects.""" + mp = importorskip('multiprocessing') + if not hasattr(mp, 'get_context'): + skip("missing multiprocessing.get_context") + + # Use 'spawn' instead of 'fork' so that the process does not inherit any + # globals that have been constructed by running other tests + ctx = mp.get_context('spawn') + with ctx.Pool(1) as pool: + # The pool will propagate the error if the target method fails + pool.apply(_rlock_creation_target) + + +def _rlock_creation_target(): + """Check that the RLock has not been constructed.""" + import multiprocessing as mp + patch = importorskip('unittest.mock').patch + + # Patch the RLock class/method but use the original implementation + with patch('multiprocessing.RLock', wraps=mp.RLock) as rlock_mock: + # Importing the module should not create a lock + from tqdm import tqdm + assert rlock_mock.call_count == 0 + # Creating a progress bar should initialize the lock + with closing(StringIO()) as our_file: + with tqdm(file=our_file) as _: # NOQA + pass + assert rlock_mock.call_count == 1 + # Creating a progress bar again should reuse the lock + with closing(StringIO()) as our_file: + with tqdm(file=our_file) as _: # NOQA + pass + assert rlock_mock.call_count == 1 + + +def test_disable(): + """Test disable""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(3), file=our_file, disable=True): + pass + assert our_file.getvalue() == '' + + with 
closing(StringIO()) as our_file: + progressbar = tqdm(total=3, file=our_file, miniters=1, disable=True) + progressbar.update(3) + progressbar.close() + assert our_file.getvalue() == '' + + +def test_infinite_total(): + """Test treatment of infinite total""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(3), file=our_file, total=float("inf")): + pass + + +def test_nototal(): + """Test unknown total length""" + with closing(StringIO()) as our_file: + for _ in tqdm(iter(range(10)), file=our_file, unit_scale=10): + pass + assert "100it" in our_file.getvalue() + + with closing(StringIO()) as our_file: + for _ in tqdm(iter(range(10)), file=our_file, + bar_format="{l_bar}{bar}{r_bar}"): + pass + assert "10/?" in our_file.getvalue() + + +def test_unit(): + """Test SI unit prefix""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(3), file=our_file, miniters=1, unit="bytes"): + pass + assert 'bytes/s' in our_file.getvalue() + + +def test_ascii(): + """Test ascii/unicode bar""" + # Test ascii autodetection + with closing(StringIO()) as our_file: + with tqdm(total=10, file=our_file, ascii=None) as t: + assert t.ascii # TODO: this may fail in the future + + # Test ascii bar + with closing(StringIO()) as our_file: + for _ in tqdm(_range(3), total=15, file=our_file, miniters=1, + mininterval=0, ascii=True): + pass + res = our_file.getvalue().strip("\r").split("\r") + assert '7%|6' in res[1] + assert '13%|#3' in res[2] + assert '20%|##' in res[3] + + # Test unicode bar + with closing(UnicodeIO()) as our_file: + with tqdm(total=15, file=our_file, ascii=False, mininterval=0) as t: + for _ in _range(3): + t.update() + res = our_file.getvalue().strip("\r").split("\r") + assert u"7%|\u258b" in res[1] + assert u"13%|\u2588\u258e" in res[2] + assert u"20%|\u2588\u2588" in res[3] + + # Test custom bar + for bars in [" .oO0", " #"]: + with closing(StringIO()) as our_file: + for _ in tqdm(_range(len(bars) - 1), file=our_file, miniters=1, + mininterval=0, ascii=bars, ncols=27): + pass + res = our_file.getvalue().strip("\r").split("\r") + for b, line in zip(bars, res): + assert '|' + b + '|' in line + + +def test_update(): + """Test manual creation and updates""" + res = None + with closing(StringIO()) as our_file: + with tqdm(total=2, file=our_file, miniters=1, mininterval=0) as progressbar: + assert len(progressbar) == 2 + progressbar.update(2) + assert '| 2/2' in our_file.getvalue() + progressbar.desc = 'dynamically notify of 4 increments in total' + progressbar.total = 4 + progressbar.update(-1) + progressbar.update(2) + res = our_file.getvalue() + assert '| 3/4 ' in res + assert 'dynamically notify of 4 increments in total' in res + + +def test_close(): + """Test manual creation and closure and n_instances""" + + # With `leave` option + with closing(StringIO()) as our_file: + progressbar = tqdm(total=3, file=our_file, miniters=10) + progressbar.update(3) + assert '| 3/3 ' not in our_file.getvalue() # Should be blank + assert len(tqdm._instances) == 1 + progressbar.close() + assert len(tqdm._instances) == 0 + assert '| 3/3 ' in our_file.getvalue() + + # Without `leave` option + with closing(StringIO()) as our_file: + progressbar = tqdm(total=3, file=our_file, miniters=10, leave=False) + progressbar.update(3) + progressbar.close() + assert '| 3/3 ' not in our_file.getvalue() # Should be blank + + # With all updates + with closing(StringIO()) as our_file: + assert len(tqdm._instances) == 0 + with tqdm(total=3, file=our_file, miniters=0, mininterval=0, + leave=True) as progressbar: + 
assert len(tqdm._instances) == 1 + progressbar.update(3) + res = our_file.getvalue() + assert '| 3/3 ' in res # Should be blank + assert '\n' not in res + # close() called + assert len(tqdm._instances) == 0 + + exres = res.rsplit(', ', 1)[0] + res = our_file.getvalue() + assert res[-1] == '\n' + if not res.startswith(exres): + raise AssertionError("\n<<< Expected:\n{0}\n>>> Got:\n{1}\n===".format( + exres + ', ...it/s]\n', our_file.getvalue())) + + # Closing after the output stream has closed + with closing(StringIO()) as our_file: + t = tqdm(total=2, file=our_file) + t.update() + t.update() + t.close() + + +def test_ema(): + """Test exponential weighted average""" + ema = EMA(0.01) + assert round(ema(10), 2) == 10 + assert round(ema(1), 2) == 5.48 + assert round(ema(), 2) == 5.48 + assert round(ema(1), 2) == 3.97 + assert round(ema(1), 2) == 3.22 + + +def test_smoothing(): + """Test exponential weighted average smoothing""" + timer = DiscreteTimer() + + # -- Test disabling smoothing + with closing(StringIO()) as our_file: + with tqdm(_range(3), file=our_file, smoothing=None, leave=True) as t: + cpu_timify(t, timer) + + for _ in t: + pass + assert '| 3/3 ' in our_file.getvalue() + + # -- Test smoothing + # 1st case: no smoothing (only use average) + with closing(StringIO()) as our_file2: + with closing(StringIO()) as our_file: + t = tqdm(_range(3), file=our_file2, smoothing=None, leave=True, + miniters=1, mininterval=0) + cpu_timify(t, timer) + + with tqdm(_range(3), file=our_file, smoothing=None, leave=True, + miniters=1, mininterval=0) as t2: + cpu_timify(t2, timer) + + for i in t2: + # Sleep more for first iteration and + # see how quickly rate is updated + if i == 0: + timer.sleep(0.01) + else: + # Need to sleep in all iterations + # to calculate smoothed rate + # (else delta_t is 0!) 
+ timer.sleep(0.001) + t.update() + n_old = len(tqdm._instances) + t.close() + assert len(tqdm._instances) == n_old - 1 + # Get result for iter-based bar + a = progressbar_rate(get_bar(our_file.getvalue(), 3)) + # Get result for manually updated bar + a2 = progressbar_rate(get_bar(our_file2.getvalue(), 3)) + + # 2nd case: use max smoothing (= instant rate) + with closing(StringIO()) as our_file2: + with closing(StringIO()) as our_file: + t = tqdm(_range(3), file=our_file2, smoothing=1, leave=True, + miniters=1, mininterval=0) + cpu_timify(t, timer) + + with tqdm(_range(3), file=our_file, smoothing=1, leave=True, + miniters=1, mininterval=0) as t2: + cpu_timify(t2, timer) + + for i in t2: + if i == 0: + timer.sleep(0.01) + else: + timer.sleep(0.001) + t.update() + t.close() + # Get result for iter-based bar + b = progressbar_rate(get_bar(our_file.getvalue(), 3)) + # Get result for manually updated bar + b2 = progressbar_rate(get_bar(our_file2.getvalue(), 3)) + + # 3rd case: use medium smoothing + with closing(StringIO()) as our_file2: + with closing(StringIO()) as our_file: + t = tqdm(_range(3), file=our_file2, smoothing=0.5, leave=True, + miniters=1, mininterval=0) + cpu_timify(t, timer) + + t2 = tqdm(_range(3), file=our_file, smoothing=0.5, leave=True, + miniters=1, mininterval=0) + cpu_timify(t2, timer) + + for i in t2: + if i == 0: + timer.sleep(0.01) + else: + timer.sleep(0.001) + t.update() + t2.close() + t.close() + # Get result for iter-based bar + c = progressbar_rate(get_bar(our_file.getvalue(), 3)) + # Get result for manually updated bar + c2 = progressbar_rate(get_bar(our_file2.getvalue(), 3)) + + # Check that medium smoothing's rate is between no and max smoothing rates + assert a <= c <= b + assert a2 <= c2 <= b2 + + +@mark.skipif(nt_and_no_colorama, reason="Windows without colorama") +def test_deprecated_nested(): + """Test nested progress bars""" + # TODO: test degradation on windows without colorama? + + # Artificially test nested loop printing + # Without leave + our_file = StringIO() + try: + tqdm(total=2, file=our_file, nested=True) + except TqdmDeprecationWarning: + if """`nested` is deprecated and automated. 
+Use `position` instead for manual control.""" not in our_file.getvalue():
+            raise
+    else:
+        raise DeprecationError("Should not allow nested kwarg")
+
+
+def test_bar_format():
+    """Test custom bar formatting"""
+    with closing(StringIO()) as our_file:
+        bar_format = ('{l_bar}{bar}|{n_fmt}/{total_fmt}-{n}/{total}'
+                      '{percentage}{rate}{rate_fmt}{elapsed}{remaining}')
+        for _ in trange(2, file=our_file, leave=True, bar_format=bar_format):
+            pass
+        out = our_file.getvalue()
+        assert "\r  0%|          |0/2-0/20.0None?it/s00:00?\r" in out
+
+    # Test unicode string auto conversion
+    with closing(StringIO()) as our_file:
+        bar_format = r'hello world'
+        with tqdm(ascii=False, bar_format=bar_format, file=our_file) as t:
+            assert isinstance(t.bar_format, _unicode)
+
+
+def test_custom_format():
+    """Test adding additional derived format arguments"""
+    class TqdmExtraFormat(tqdm):
+        """Provides a `total_time` format parameter"""
+        @property
+        def format_dict(self):
+            d = super(TqdmExtraFormat, self).format_dict
+            total_time = d["elapsed"] * (d["total"] or 0) / max(d["n"], 1)
+            d.update(total_time=self.format_interval(total_time) + " in total")
+            return d
+
+    with closing(StringIO()) as our_file:
+        for _ in TqdmExtraFormat(
+                range(10), file=our_file,
+                bar_format="{total_time}: {percentage:.0f}%|{bar}{r_bar}"):
+            pass
+        assert "00:00 in total" in our_file.getvalue()
+
+
+def test_eta(capsys):
+    """Test eta bar_format"""
+    from datetime import datetime as dt
+    for _ in trange(999, miniters=1, mininterval=0, leave=True,
+                    bar_format='{l_bar}{eta:%Y-%m-%d}'):
+        pass
+    _, err = capsys.readouterr()
+    assert "\r100%|{eta:%Y-%m-%d}\n".format(eta=dt.now()) in err
+
+
+def test_unpause():
+    """Test unpause"""
+    timer = DiscreteTimer()
+    with closing(StringIO()) as our_file:
+        t = trange(10, file=our_file, leave=True, mininterval=0)
+        cpu_timify(t, timer)
+        timer.sleep(0.01)
+        t.update()
+        timer.sleep(0.01)
+        t.update()
+        timer.sleep(0.1)  # longer wait time
+        t.unpause()
+        timer.sleep(0.01)
+        t.update()
+        timer.sleep(0.01)
+        t.update()
+        t.close()
+        r_before = progressbar_rate(get_bar(our_file.getvalue(), 2))
+        r_after = progressbar_rate(get_bar(our_file.getvalue(), 3))
+    assert r_before == r_after
+
+
+def test_disabled_unpause(capsys):
+    """Test disabled unpause"""
+    with tqdm(total=10, disable=True) as t:
+        t.update()
+        t.unpause()
+        t.update()
+        print(t)
+    out, err = capsys.readouterr()
+    assert not err
+    assert out == ' 0%|          | 0/10 [00:00<?, ?it/s]\n'
+
+
+def test_compare():
+    """Test comparison functions (bars are ordered by screen position)"""
+    with closing(StringIO()) as our_file:
+        t0 = tqdm(total=10, file=our_file)
+        t1 = tqdm(total=10, file=our_file)
+        t2 = tqdm(total=10, file=our_file)
+
+        assert t1 >= t0
+        assert t0 <= t2
+
+        t3 = tqdm(total=10, file=our_file)
+        t4 = tqdm(total=10, file=our_file)
+        t5 = tqdm(total=10, file=our_file)
+        t5.close()
+        t6 = tqdm(total=10, file=our_file)
+
+        assert t3 != t4
+        assert t3 > t2
+        assert t5 == t6
+        t6.close()
+        t4.close()
+        t3.close()
+        t2.close()
+        t1.close()
+        t0.close()
+
+
+def test_repr():
+    """Test representation"""
+    with closing(StringIO()) as our_file:
+        with tqdm(total=10, ascii=True, file=our_file) as t:
+            assert str(t) == '  0%|          | 0/10 [00:00<?, ?it/s]'
+
+
+def test_postfix():
+    """Test postfix"""
+    postfix = {'float': 0.321034, 'gen': 543, 'str': 'h', 'lst': [2]}
+    postfix_order = (('w', 'w'), ('a', 0))  # no need for an OrderedDict
+    expected = ['float=0.321', 'gen=543', 'lst=[2]', 'str=h']
+    expected_order = ['w=w', 'a=0', 'float=0.321', 'gen=543', 'lst=[2]',
+                      'str=h']
+
+    # Test postfix set at init
+    with closing(StringIO()) as our_file:
+        with tqdm(total=10, file=our_file, desc='pos0 bar',
+                  bar_format='{r_bar}', postfix=postfix) as t1:
+            t1.refresh()
+            out = our_file.getvalue()
+
+    # Test postfix set after init
+    with closing(StringIO()) as our_file:
+        with trange(10, file=our_file, desc='pos1 bar', bar_format='{r_bar}',
+                    postfix=None) as t2:
+            t2.set_postfix(**postfix)
+            t2.refresh()
+            out2 = our_file.getvalue()
+
+    # Order of items in dict may change, so need a loop to check per item
+    for res in expected:
+        assert res in out
+        assert res in out2
+
+    # Test postfix (with ordered dict and no refresh) set after init
+    with closing(StringIO()) as our_file:
+        with trange(10, file=our_file, desc='pos2 bar', bar_format='{r_bar}',
+                    postfix=None) as t3:
+            t3.set_postfix(postfix_order, False, **postfix)
+            t3.refresh()  # explicit external refresh
+            out3 = our_file.getvalue()
+
+    out3 = out3[1:-1].split(', ')[3:]
+    assert out3 == expected_order
+
+    # Test postfix (with ordered dict and refresh) set after init
+    with closing(StringIO()) as our_file:
+        with trange(10, file=our_file, desc='pos2 bar', bar_format='{r_bar}',
+                    postfix=None) as t4:
+            t4.set_postfix(postfix_order, True, **postfix)
+            t4.refresh()
+            out4 = our_file.getvalue()
+
+    # Test that the postfix is actually refreshed
+    assert out4.count('\r') > out3.count('\r')
+    assert out4.count(", ".join(expected_order)) == 2
+
+    # Test setting postfix string directly
+    with closing(StringIO()) as our_file:
+        with trange(10, file=our_file, desc='pos2 bar', bar_format='{r_bar}',
+                    postfix=None) as t5:
+            t5.set_postfix_str("Hello", False)
+            t5.set_postfix_str("World")
+            out5 = our_file.getvalue()
+
+    assert "Hello" not in out5
+    out5 = out5[1:-1].split(', ')[3:]
+    assert out5 == ["World"]
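+
+
+# (A rough usage sketch of the postfix API exercised above, with
+#  illustrative values that are not part of the suite's assertions:
+#      with tqdm(total=10) as bar:
+#          bar.set_postfix(loss=0.25, acc=0.98)  # sorted key=value pairs
+#          bar.set_postfix_str("custom status")  # raw string replaces all
+#  `set_postfix` formats and sorts keyword arguments; `set_postfix_str`
+#  assigns the trailing text verbatim.)
+def test_postfix_direct():
+    """Test directly assigning non-str objects to postfix"""
+    with closing(StringIO()) as our_file:
+        with tqdm(total=10, 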
file=our_file, miniters=1, mininterval=0, + bar_format="{postfix[0][name]} {postfix[1]:>5.2f}", + postfix=[{'name': "foo"}, 42]) as t: + for i in range(10): + if i % 2: + t.postfix[0]["name"] = "abcdefghij"[i] + else: + t.postfix[1] = i + t.update() + res = our_file.getvalue() + assert "f 6.00" in res + assert "h 6.00" in res + assert "h 8.00" in res + assert "j 8.00" in res + + +@contextmanager +def std_out_err_redirect_tqdm(tqdm_file=sys.stderr): + orig_out_err = sys.stdout, sys.stderr + try: + sys.stdout = sys.stderr = DummyTqdmFile(tqdm_file) + yield orig_out_err[0] + # Relay exceptions + except Exception as exc: + raise exc + # Always restore sys.stdout/err if necessary + finally: + sys.stdout, sys.stderr = orig_out_err + + +def test_file_redirection(): + """Test redirection of output""" + with closing(StringIO()) as our_file: + # Redirect stdout to tqdm.write() + with std_out_err_redirect_tqdm(tqdm_file=our_file): + with tqdm(total=3) as pbar: + print("Such fun") + pbar.update(1) + print("Such", "fun") + pbar.update(1) + print("Such ", end="") + print("fun") + pbar.update(1) + res = our_file.getvalue() + assert res.count("Such fun\n") == 3 + assert "0/3" in res + assert "3/3" in res + + +def test_external_write(): + """Test external write mode""" + with closing(StringIO()) as our_file: + # Redirect stdout to tqdm.write() + for _ in trange(3, file=our_file): + del tqdm._lock # classmethod should be able to recreate lock + with tqdm.external_write_mode(file=our_file): + our_file.write("Such fun\n") + res = our_file.getvalue() + assert res.count("Such fun\n") == 3 + assert "0/3" in res + assert "3/3" in res + + +def test_unit_scale(): + """Test numeric `unit_scale`""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(9), unit_scale=9, file=our_file, + miniters=1, mininterval=0): + pass + out = our_file.getvalue() + assert '81/81' in out + + +def patch_lock(thread=True): + """decorator replacing tqdm's lock with vanilla threading/multiprocessing""" + try: + if thread: + from threading import RLock + else: + from multiprocessing import RLock + lock = RLock() + except (ImportError, OSError) as err: + skip(str(err)) + + def outer(func): + """actual decorator""" + @wraps(func) + def inner(*args, **kwargs): + """set & reset lock even if exceptions occur""" + default_lock = tqdm.get_lock() + try: + tqdm.set_lock(lock) + return func(*args, **kwargs) + finally: + tqdm.set_lock(default_lock) + return inner + return outer + + +@patch_lock(thread=False) +def test_threading(): + """Test multiprocess/thread-realted features""" + pass # TODO: test interleaved output #445 + + +def test_bool(): + """Test boolean cast""" + def internal(our_file, disable): + kwargs = {'file': our_file, 'disable': disable} + with trange(10, **kwargs) as t: + assert t + with trange(0, **kwargs) as t: + assert not t + with tqdm(total=10, **kwargs) as t: + assert bool(t) + with tqdm(total=0, **kwargs) as t: + assert not bool(t) + with tqdm([], **kwargs) as t: + assert not t + with tqdm([0], **kwargs) as t: + assert t + with tqdm(iter([]), **kwargs) as t: + assert t + with tqdm(iter([1, 2, 3]), **kwargs) as t: + assert t + with tqdm(**kwargs) as t: + try: + print(bool(t)) + except TypeError: + pass + else: + raise TypeError("Expected bool(tqdm()) to fail") + + # test with and without disable + with closing(StringIO()) as our_file: + internal(our_file, False) + internal(our_file, True) + + +def backendCheck(module): + """Test tqdm-like module fallback""" + tn = module.tqdm + tr = module.trange + + with 
closing(StringIO()) as our_file: + with tn(total=10, file=our_file) as t: + assert len(t) == 10 + with tr(1337) as t: + assert len(t) == 1337 + + +def test_auto(): + """Test auto fallback""" + from tqdm import auto, autonotebook + backendCheck(autonotebook) + backendCheck(auto) + + +def test_wrapattr(): + """Test wrapping file-like objects""" + data = "a twenty-char string" + + with closing(StringIO()) as our_file: + with closing(StringIO()) as writer: + with tqdm.wrapattr(writer, "write", file=our_file, bytes=True) as wrap: + wrap.write(data) + res = writer.getvalue() + assert data == res + res = our_file.getvalue() + assert '%.1fB [' % len(data) in res + + with closing(StringIO()) as our_file: + with closing(StringIO()) as writer: + with tqdm.wrapattr(writer, "write", file=our_file, bytes=False) as wrap: + wrap.write(data) + res = our_file.getvalue() + assert '%dit [' % len(data) in res + + +def test_float_progress(): + """Test float totals""" + with closing(StringIO()) as our_file: + with trange(10, total=9.6, file=our_file) as t: + with catch_warnings(record=True) as w: + simplefilter("always", category=TqdmWarning) + for i in t: + if i < 9: + assert not w + assert w + assert "clamping frac" in str(w[-1].message) + + +def test_screen_shape(): + """Test screen shape""" + # ncols + with closing(StringIO()) as our_file: + with trange(10, file=our_file, ncols=50) as t: + list(t) + + res = our_file.getvalue() + assert all(len(i) == 50 for i in get_bar(res)) + + # no second/third bar, leave=False + with closing(StringIO()) as our_file: + kwargs = {'file': our_file, 'ncols': 50, 'nrows': 2, 'miniters': 0, + 'mininterval': 0, 'leave': False} + with trange(10, desc="one", **kwargs) as t1: + with trange(10, desc="two", **kwargs) as t2: + with trange(10, desc="three", **kwargs) as t3: + list(t3) + list(t2) + list(t1) + + res = our_file.getvalue() + assert "one" in res + assert "two" not in res + assert "three" not in res + assert "\n\n" not in res + assert "more hidden" in res + # double-check ncols + assert all(len(i) == 50 for i in get_bar(res) + if i.strip() and "more hidden" not in i) + + # all bars, leave=True + with closing(StringIO()) as our_file: + kwargs = {'file': our_file, 'ncols': 50, 'nrows': 2, + 'miniters': 0, 'mininterval': 0} + with trange(10, desc="one", **kwargs) as t1: + with trange(10, desc="two", **kwargs) as t2: + assert "two" not in our_file.getvalue() + with trange(10, desc="three", **kwargs) as t3: + assert "three" not in our_file.getvalue() + list(t3) + list(t2) + list(t1) + + res = our_file.getvalue() + assert "one" in res + assert "two" in res + assert "three" in res + assert "\n\n" not in res + assert "more hidden" in res + # double-check ncols + assert all(len(i) == 50 for i in get_bar(res) + if i.strip() and "more hidden" not in i) + + # second bar becomes first, leave=False + with closing(StringIO()) as our_file: + kwargs = {'file': our_file, 'ncols': 50, 'nrows': 2, 'miniters': 0, + 'mininterval': 0, 'leave': False} + t1 = tqdm(total=10, desc="one", **kwargs) + with tqdm(total=10, desc="two", **kwargs) as t2: + t1.update() + t2.update() + t1.close() + res = our_file.getvalue() + assert "one" in res + assert "two" not in res + assert "more hidden" in res + t2.update() + + res = our_file.getvalue() + assert "two" in res + + +def test_initial(): + """Test `initial`""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(9), initial=10, total=19, file=our_file, + miniters=1, mininterval=0): + pass + out = our_file.getvalue() + assert '10/19' in out + assert 
'19/19' in out + + +def test_colour(): + """Test `colour`""" + with closing(StringIO()) as our_file: + for _ in tqdm(_range(9), file=our_file, colour="#beefed"): + pass + out = our_file.getvalue() + assert '\x1b[38;2;%d;%d;%dm' % (0xbe, 0xef, 0xed) in out + + with catch_warnings(record=True) as w: + simplefilter("always", category=TqdmWarning) + with tqdm(total=1, file=our_file, colour="charm") as t: + assert w + t.update() + assert "Unknown colour" in str(w[-1].message) + + with closing(StringIO()) as our_file2: + for _ in tqdm(_range(9), file=our_file2, colour="blue"): + pass + out = our_file2.getvalue() + assert '\x1b[34m' in out + + +def test_closed(): + """Test writing to closed file""" + with closing(StringIO()) as our_file: + for i in trange(9, file=our_file, miniters=1, mininterval=0): + if i == 5: + our_file.close() + + +def test_reversed(capsys): + """Test reversed()""" + for _ in reversed(tqdm(_range(9))): + pass + out, err = capsys.readouterr() + assert not out + assert ' 0%' in err + assert '100%' in err + + +def test_contains(capsys): + """Test __contains__ doesn't iterate""" + with tqdm(list(range(9))) as t: + assert 9 not in t + assert all(i in t for i in _range(9)) + out, err = capsys.readouterr() + assert not out + assert ' 0%' in err + assert '100%' not in err diff --git a/tests/tests_version.py b/tests/tests_version.py new file mode 100644 index 0000000..495c797 --- /dev/null +++ b/tests/tests_version.py @@ -0,0 +1,14 @@ +"""Test `tqdm.__version__`.""" +import re +from ast import literal_eval + + +def test_version(): + """Test version string""" + from tqdm import __version__ + version_parts = re.split('[.-]', __version__) + if __version__ != "UNKNOWN": + assert 3 <= len(version_parts), "must have at least Major.minor.patch" + assert all( + isinstance(literal_eval(i), int) for i in version_parts[:3] + ), "Version Major.minor.patch must be 3 integers" -- cgit v1.2.3