1
+ #include " util/math.hpp"
2
+ #include < fmt/format.h>
3
+ #include < algorithm>
4
+ #include < catch2/catch_test_macros.hpp>
5
+ #include < catch2/matchers/catch_matchers_floating_point.hpp>
6
+ #include < cstddef>
7
+ #include < random>
8
+ #include < set>
9
+ #include < vector>
10
+
11
+ double compute_p ( size_t k, size_t n )
12
+ {
13
+ if ( k == 0 )
14
+ {
15
+ return 0.0 ;
16
+ }
17
+ else
18
+ {
19
+ double p = 1.0 / ( double ( n ) - 1.0 );
20
+ return p + ( 1.0 - p ) * compute_p ( k - 1 , n - 1 );
21
+ }
22
+ }
23
+
24
+ TEST_CASE ( " Testing sampling functions" )
25
+ {
26
+ std::random_device rd;
27
+ std::mt19937 gen ( rd () );
28
+
29
+ SECTION ( " draw_unique_k_from_n" , " Drawing k numbers out of n" )
30
+ {
31
+
32
+ const size_t N_RUNS = 10000 ;
33
+
34
+ const size_t k = 6 ;
35
+ const size_t n = 100 ;
36
+ const size_t ignore_idx = 11 ;
37
+
38
+ std::vector<size_t > histogram ( n, 0 ); // Count how often each element occurs amongst all samples
39
+
40
+ std::vector<size_t > buffer{};
41
+ for ( size_t i = 0 ; i < N_RUNS; i++ )
42
+ {
43
+ Seldon::draw_unique_k_from_n ( ignore_idx, k, n, buffer, gen );
44
+ for ( const auto & n : buffer )
45
+ {
46
+ histogram[n]++;
47
+ }
48
+ }
49
+
50
+ // In each run there is a probability of p for each element to be selected
51
+ // That means for each histogram bin we have a binomial distribution with p
52
+ double p = compute_p ( k, n );
53
+
54
+ size_t mean = N_RUNS * p;
55
+ // The variance of a binomial distribution is var = n*p*(1-p)
56
+ size_t sigma = std::sqrt ( N_RUNS * p * ( 1.0 - p ) );
57
+
58
+ INFO ( " Binomial distribution parameters" );
59
+ INFO ( fmt::format ( " p = {}" , p ) );
60
+ INFO ( fmt::format ( " mean = {}" , mean ) );
61
+ INFO ( fmt::format ( " sigma = {}" , sigma ) );
62
+
63
+ REQUIRE ( histogram[ignore_idx] == 0 ); // The ignore_idx should never be selected
64
+
65
+ size_t number_outside_three_sigma = 0 ;
66
+ for ( const auto & n : histogram )
67
+ {
68
+ if ( n == 0 )
69
+ {
70
+ continue ;
71
+ }
72
+ INFO ( fmt::format ( " n = {}" , n ) );
73
+ INFO ( fmt::format ( " mean = {}" , mean ) );
74
+ INFO ( fmt::format ( " sigma = {}" , sigma ) );
75
+
76
+ if ( std::abs ( double ( n ) - double ( mean ) ) > 3.0 * sigma )
77
+ {
78
+ number_outside_three_sigma++;
79
+ }
80
+
81
+ REQUIRE_THAT ( n, Catch::Matchers::WithinAbs ( mean, 5 * sigma ) );
82
+ }
83
+
84
+ if ( number_outside_three_sigma > 0.01 * N_RUNS )
85
+ WARN ( fmt::format (
86
+ " Many deviations beyond the 3 sigma range. {} out of {}" , number_outside_three_sigma, N_RUNS ) );
87
+ }
88
+
89
+ SECTION ( " weighted_reservior_sampling" , " Testing weighted reservoir sampling with A_ExpJ algorithm" )
90
+ {
91
+
92
+ const size_t N_RUNS = 10000 ;
93
+
94
+ const size_t k = 6 ;
95
+ const size_t n = 100 ;
96
+ const size_t ignore_idx = 11 ;
97
+ const size_t ignore_idx2 = 29 ;
98
+
99
+ std::vector<size_t > histogram ( n, 0 ); // Count how often each element occurs amongst all samples
100
+
101
+ auto weight_callback = []( size_t idx ) {
102
+ if ( ( idx == ignore_idx ) | ( idx == ignore_idx2 ) )
103
+ {
104
+ return 0.0 ;
105
+ }
106
+ else
107
+ {
108
+ std::abs ( double ( n / 2.0 ) - double ( idx ) );
109
+ }
110
+ };
111
+
112
+ std::vector<size_t > buffer{};
113
+
114
+ for ( size_t i = 0 ; i < N_RUNS; i++ )
115
+ {
116
+ Seldon::reservoir_sampling_A_ExpJ ( k, n, weight_callback, buffer, gen );
117
+ for ( const auto & n : buffer )
118
+ {
119
+ histogram[n]++;
120
+ }
121
+ }
122
+
123
+ REQUIRE ( histogram[ignore_idx] == 0 ); // The ignore_idx should never be selected
124
+ REQUIRE ( histogram[ignore_idx2] == 0 ); // The ignore_idx should never be selected
125
+
126
+ // TODO: histogram and sigma test
127
+ }
128
+ }
0 commit comments