Skip to content

Commit 903c801

Browse files
committed
unit test for some of the sampling functions
1 parent 1cc9b41 commit 903c801

File tree

2 files changed

+130
-1
lines changed

2 files changed

+130
-1
lines changed

meson.build

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ exe = executable('seldon', sources_seldon + 'src/main.cpp',
2424
tests = [
2525
['Test Tarjan', 'test/test_tarjan.cpp'],
2626
['Test DeGroot', 'test/test_deGroot.cpp'],
27-
['Test Network', 'test/test_network.cpp']
27+
['Test Network', 'test/test_network.cpp'],
28+
['Test Sampling', 'test/test_sampling.cpp'],
2829
]
2930

3031
Catch2 = dependency('Catch2', method : 'cmake', modules : ['Catch2::Catch2WithMain', 'Catch2::Catch2'])

test/test_sampling.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#include "util/math.hpp"
2+
#include <fmt/format.h>
3+
#include <algorithm>
4+
#include <catch2/catch_test_macros.hpp>
5+
#include <catch2/matchers/catch_matchers_floating_point.hpp>
6+
#include <cstddef>
7+
#include <random>
8+
#include <set>
9+
#include <vector>
10+
11+
double compute_p( size_t k, size_t n )
12+
{
13+
if( k == 0 )
14+
{
15+
return 0.0;
16+
}
17+
else
18+
{
19+
double p = 1.0 / ( double( n ) - 1.0 );
20+
return p + ( 1.0 - p ) * compute_p( k - 1, n - 1 );
21+
}
22+
}
23+
24+
TEST_CASE( "Testing sampling functions" )
25+
{
26+
std::random_device rd;
27+
std::mt19937 gen( rd() );
28+
29+
SECTION( "draw_unique_k_from_n", "Drawing k numbers out of n" )
30+
{
31+
32+
const size_t N_RUNS = 10000;
33+
34+
const size_t k = 6;
35+
const size_t n = 100;
36+
const size_t ignore_idx = 11;
37+
38+
std::vector<size_t> histogram( n, 0 ); // Count how often each element occurs amongst all samples
39+
40+
std::vector<size_t> buffer{};
41+
for( size_t i = 0; i < N_RUNS; i++ )
42+
{
43+
Seldon::draw_unique_k_from_n( ignore_idx, k, n, buffer, gen );
44+
for( const auto & n : buffer )
45+
{
46+
histogram[n]++;
47+
}
48+
}
49+
50+
// In each run there is a probability of p for each element to be selected
51+
// That means for each histogram bin we have a binomial distribution with p
52+
double p = compute_p( k, n );
53+
54+
size_t mean = N_RUNS * p;
55+
// The variance of a binomial distribution is var = n*p*(1-p)
56+
size_t sigma = std::sqrt( N_RUNS * p * ( 1.0 - p ) );
57+
58+
INFO( "Binomial distribution parameters" );
59+
INFO( fmt::format( " p = {}", p ) );
60+
INFO( fmt::format( " mean = {}", mean ) );
61+
INFO( fmt::format( " sigma = {}", sigma ) );
62+
63+
REQUIRE( histogram[ignore_idx] == 0 ); // The ignore_idx should never be selected
64+
65+
size_t number_outside_three_sigma = 0;
66+
for( const auto & n : histogram )
67+
{
68+
if( n == 0 )
69+
{
70+
continue;
71+
}
72+
INFO( fmt::format( " n = {}", n ) );
73+
INFO( fmt::format( " mean = {}", mean ) );
74+
INFO( fmt::format( " sigma = {}", sigma ) );
75+
76+
if( std::abs( double( n ) - double( mean ) ) > 3.0 * sigma )
77+
{
78+
number_outside_three_sigma++;
79+
}
80+
81+
REQUIRE_THAT( n, Catch::Matchers::WithinAbs( mean, 5 * sigma ) );
82+
}
83+
84+
if( number_outside_three_sigma > 0.01 * N_RUNS )
85+
WARN( fmt::format(
86+
"Many deviations beyond the 3 sigma range. {} out of {}", number_outside_three_sigma, N_RUNS ) );
87+
}
88+
89+
SECTION( "weighted_reservior_sampling", "Testing weighted reservoir sampling with A_ExpJ algorithm" )
90+
{
91+
92+
const size_t N_RUNS = 10000;
93+
94+
const size_t k = 6;
95+
const size_t n = 100;
96+
const size_t ignore_idx = 11;
97+
const size_t ignore_idx2 = 29;
98+
99+
std::vector<size_t> histogram( n, 0 ); // Count how often each element occurs amongst all samples
100+
101+
auto weight_callback = []( size_t idx ) {
102+
if( ( idx == ignore_idx ) | ( idx == ignore_idx2 ) )
103+
{
104+
return 0.0;
105+
}
106+
else
107+
{
108+
std::abs( double( n / 2.0 ) - double( idx ) );
109+
}
110+
};
111+
112+
std::vector<size_t> buffer{};
113+
114+
for( size_t i = 0; i < N_RUNS; i++ )
115+
{
116+
Seldon::reservoir_sampling_A_ExpJ( k, n, weight_callback, buffer, gen );
117+
for( const auto & n : buffer )
118+
{
119+
histogram[n]++;
120+
}
121+
}
122+
123+
REQUIRE( histogram[ignore_idx] == 0 ); // The ignore_idx should never be selected
124+
REQUIRE( histogram[ignore_idx2] == 0 ); // The ignore_idx should never be selected
125+
126+
// TODO: histogram and sigma test
127+
}
128+
}

0 commit comments

Comments
 (0)