from __future__ import annotations import itertools import pytest from prompt_toolkit.utils import take_using_weights def test_using_weights(): def take(generator, count): return list(itertools.islice(generator, 0, count)) # Check distribution. data = take(take_using_weights(["A", "B", "C"], [5, 10, 20]), 35) assert data.count("A") == 5 assert data.count("B") == 10 assert data.count("C") == 20 assert data == [ "A", "B", "C", "C", "B", "C", "C", "A", "B", "C", "C", "B", "C", "C", "A", "B", "C", "C", "B", "C", "C", "A", "B", "C", "C", "B", "C", "C", "A", "B", "C", "C", "B", "C", "C", ] # Another order. data = take(take_using_weights(["A", "B", "C"], [20, 10, 5]), 35) assert data.count("A") == 20 assert data.count("B") == 10 assert data.count("C") == 5 # Bigger numbers. data = take(take_using_weights(["A", "B", "C"], [20, 10, 5]), 70) assert data.count("A") == 40 assert data.count("B") == 20 assert data.count("C") == 10 # Negative numbers. data = take(take_using_weights(["A", "B", "C"], [-20, 10, 0]), 70) assert data.count("A") == 0 assert data.count("B") == 70 assert data.count("C") == 0 # All zero-weight items. with pytest.raises(ValueError): take(take_using_weights(["A", "B", "C"], [0, 0, 0]), 70)