summaryrefslogtreecommitdiffstats
path: root/src/test/test_weighted_shuffle.cc
blob: 9f92cbdc09519308745c6d274de922df1ea98244 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
// vim: ts=8 sw=2 smarttab

#include "common/weighted_shuffle.h"
#include <array>
#include <map>
#include <numeric>
#include <random>
#include "gtest/gtest.h"

// Statistically checks the first position chosen by weighted_shuffle(): the
// element paired with weight w should end up at the front with probability
// w / sum(weights).
TEST(WeightedShuffle, Basic) {
  const std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'};
  const std::array<int, 5> weights{100, 50, 25, 10, 1};
  // frequency[c][j] counts how often choice c landed at position j.
  std::map<char, std::array<unsigned, 5>> frequency {
    {'a', {0, 0, 0, 0, 0}},
    {'b', {0, 0, 0, 0, 0}},
    {'c', {0, 0, 0, 0, 0}},
    {'d', {0, 0, 0, 0, 0}},
    {'e', {0, 0, 0, 0, 0}}
  };
  const int samples = 10000;
  // Fixed seed: the assertion below is probabilistic, so a nondeterministic
  // seed (e.g. std::random_device) would make the test flaky. Each sample
  // still gets its own generator, seeded from this deterministic sequence.
  std::mt19937 seeder{42};
  for (int i = 0; i < samples; ++i) {
    // Shuffle fresh copies so every sample starts from the original pairing
    // of choices[k] with weights[k]. The check below relies on that pairing,
    // independently of how weighted_shuffle() rearranges the weight range.
    auto shuffled = choices;
    auto shuffled_weights = weights;
    weighted_shuffle(begin(shuffled), end(shuffled),
                     begin(shuffled_weights), end(shuffled_weights),
                     std::mt19937{seeder()});
    for (std::size_t j = 0; j < shuffled.size(); ++j)
      ++frequency.at(shuffled[j])[j];
  }
  // verify that the probability that the nth choice is selected as the first
  // one is the nth weight divided by the sum of all weights
  const auto total_weight = std::accumulate(weights.begin(), weights.end(), 0);
  constexpr float epsilon = 0.02f;
  for (std::size_t i = 0; i < choices.size(); ++i) {
    const auto& f = frequency.at(choices[i]);
    ASSERT_NEAR(float(weights[i]) / total_weight,
                float(f.front()) / samples,
                epsilon);
  }
}