summaryrefslogtreecommitdiffstats
path: root/tests/tests_contrib.py
blob: 69a1cada992768f4b8d147d27fb05cc475fe9a8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Tests for `tqdm.contrib`.
"""
import sys

import pytest

from tqdm import tqdm
from tqdm.contrib import tenumerate, tmap, tzip

from .tests_tqdm import StringIO, closing, importorskip


def incr(x):
    """Return ``x`` plus one; a trivial callable for exercising ``tmap``."""
    result = x + 1
    return result


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_enumerate(tqdm_kwargs):
    """Test contrib.tenumerate against the builtin enumerate.

    Checks the yielded (index, item) pairs with and without a ``start``
    argument, then checks the rendered output for a bare iterator (which has
    no ``len()``): "100%" must appear only when ``total`` is passed.
    """
    seq = range(9)

    with closing(StringIO()) as out:
        # () -> default start; (42,) -> explicit start offset.
        for extra_args in ((), (42,)):
            got = list(tenumerate(seq, *extra_args, file=out, **tqdm_kwargs))
            assert got == list(enumerate(seq, *extra_args))

    def rendered(**extra_kwargs):
        # Drain tenumerate over iter(seq) and return everything it wrote.
        with closing(StringIO()) as out:
            kwargs = dict(tqdm_kwargs, **extra_kwargs)
            list(tenumerate(iter(seq), file=out, **kwargs))
            return out.getvalue()

    assert "100%" not in rendered()
    assert "100%" in rendered(total=len(seq))


def test_enumerate_numpy():
    """Test contrib.tenumerate on a 2-D numpy.ndarray.

    The result is compared with ``np.ndenumerate``; the test is skipped when
    numpy is not installed.
    """
    np = importorskip("numpy")
    arr = np.random.random((42, 7))
    with closing(StringIO()) as out:
        expected = list(np.ndenumerate(arr))
        assert list(tenumerate(arr, file=out)) == expected


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_zip(tqdm_kwargs):
    """Test contrib.tzip pairs items exactly like the builtin zip."""
    left = range(9)
    right = [x + 1 for x in left]
    with closing(StringIO()) as out:
        zipped = tzip(left, right, file=out, **tqdm_kwargs)
        if sys.version_info[:1] >= (3,):
            # On Python 3 the returned object must not itself be the list;
            # it only matches zip's output once consumed.
            expected = list(zip(left, right))
            assert zipped != expected
            assert list(zipped) == expected
        else:
            assert zipped == zip(left, right)


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_map(tqdm_kwargs):
    """Test contrib.tmap matches the builtin map.

    The module-level ``incr`` helper is used both for the call under test and
    for the expected values, so the two sides always apply the same function.
    """
    with closing(StringIO()) as our_file:
        a = range(9)
        b = [incr(i) for i in a]
        if sys.version_info[:1] < (3,):
            assert tmap(incr, a, file=our_file, **tqdm_kwargs) == map(incr, a)
        else:
            gen = tmap(incr, a, file=our_file, **tqdm_kwargs)
            # The returned object is not the list itself; it only matches
            # the expected values once consumed.
            assert gen != b
            assert list(gen) == b