diff options
Diffstat (limited to '')
-rw-r--r-- | src/test/test_weighted_shuffle.cc | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/src/test/test_weighted_shuffle.cc b/src/test/test_weighted_shuffle.cc new file mode 100644 index 000000000..9f92cbdc0 --- /dev/null +++ b/src/test/test_weighted_shuffle.cc @@ -0,0 +1,39 @@ +// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*- +// vim: ts=8 sw=2 smarttab + +#include "common/weighted_shuffle.h" +#include <array> +#include <map> +#include "gtest/gtest.h" + +TEST(WeightedShuffle, Basic) { + std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'}; + std::array<int, 5> weights{100, 50, 25, 10, 1}; + std::map<char, std::array<unsigned, 5>> frequency { + {'a', {0, 0, 0, 0, 0}}, + {'b', {0, 0, 0, 0, 0}}, + {'c', {0, 0, 0, 0, 0}}, + {'d', {0, 0, 0, 0, 0}}, + {'e', {0, 0, 0, 0, 0}} + }; // count each element appearing in each position + const int samples = 10000; + std::random_device rd; + for (auto i = 0; i < samples; i++) { + weighted_shuffle(begin(choices), end(choices), + begin(weights), end(weights), + std::mt19937{rd()}); + for (size_t j = 0; j < choices.size(); ++j) + ++frequency[choices[j]][j]; + } + // verify that the probability that the nth choice is selected as the first + // one is the nth weight divided by the sum of all weights + const auto total_weight = std::accumulate(weights.begin(), weights.end(), 0); + constexpr float epsilon = 0.02; + for (unsigned i = 0; i < choices.size(); i++) { + const auto& f = frequency[choices[i]]; + const auto& w = weights[i]; + ASSERT_NEAR(float(w) / total_weight, + float(f.front()) / samples, + epsilon); + } +} |