summaryrefslogtreecommitdiffstats
path: root/src/test/test_weighted_shuffle.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/test_weighted_shuffle.cc')
-rw-r--r--src/test/test_weighted_shuffle.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/test/test_weighted_shuffle.cc b/src/test/test_weighted_shuffle.cc
index 9f92cbdc0..efc1cdeb7 100644
--- a/src/test/test_weighted_shuffle.cc
+++ b/src/test/test_weighted_shuffle.cc
@@ -37,3 +37,55 @@ TEST(WeightedShuffle, Basic) {
epsilon);
}
}
+
+TEST(WeightedShuffle, ZeroedWeights) {
+ std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'};
+ std::array<int, 5> weights{0, 0, 0, 0, 0};
+ std::map<char, std::array<unsigned, 5>> frequency {
+ {'a', {0, 0, 0, 0, 0}},
+ {'b', {0, 0, 0, 0, 0}},
+ {'c', {0, 0, 0, 0, 0}},
+ {'d', {0, 0, 0, 0, 0}},
+ {'e', {0, 0, 0, 0, 0}}
+ }; // count each element appearing in each position
+ const int samples = 10000;
+ std::random_device rd;
+ for (auto i = 0; i < samples; i++) {
+ weighted_shuffle(begin(choices), end(choices),
+ begin(weights), end(weights),
+ std::mt19937{rd()});
+ for (size_t j = 0; j < choices.size(); ++j)
+ ++frequency[choices[j]][j];
+ }
+
+ for (char ch : choices) {
+ // all samples on the diagonal
+ ASSERT_EQ(std::accumulate(begin(frequency[ch]), end(frequency[ch]), 0),
+ samples);
+ ASSERT_EQ(frequency[ch][ch-'a'], samples);
+ }
+}
+
+TEST(WeightedShuffle, SingleNonZeroWeight) {
+ std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'};
+ std::array<int, 5> weights{0, 42, 0, 0, 0};
+ std::map<char, std::array<unsigned, 5>> frequency {
+ {'a', {0, 0, 0, 0, 0}},
+ {'b', {0, 0, 0, 0, 0}},
+ {'c', {0, 0, 0, 0, 0}},
+ {'d', {0, 0, 0, 0, 0}},
+ {'e', {0, 0, 0, 0, 0}}
+ }; // count each element appearing in each position
+ const int samples = 10000;
+ std::random_device rd;
+ for (auto i = 0; i < samples; i++) {
+ weighted_shuffle(begin(choices), end(choices),
+ begin(weights), end(weights),
+ std::mt19937{rd()});
+ for (size_t j = 0; j < choices.size(); ++j)
+ ++frequency[choices[j]][j];
+ }
+
+ // 'b' is always first
+ ASSERT_EQ(frequency['b'][0], samples);
+}