From 627040c6ea17c0d93ab2ce67a733d36b5ba57dcd Mon Sep 17 00:00:00 2001 From: Moritz Sallermann Date: Mon, 16 Oct 2023 21:33:09 +0000 Subject: [PATCH] unit test for some of the sampling functions --- meson.build | 3 +- test/test_sampling.cpp | 128 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 test/test_sampling.cpp diff --git a/meson.build b/meson.build index 058cf0b..cd58a5b 100644 --- a/meson.build +++ b/meson.build @@ -24,7 +24,8 @@ exe = executable('seldon', sources_seldon + 'src/main.cpp', tests = [ ['Test Tarjan', 'test/test_tarjan.cpp'], ['Test DeGroot', 'test/test_deGroot.cpp'], - ['Test Network', 'test/test_network.cpp'] + ['Test Network', 'test/test_network.cpp'], + ['Test Sampling', 'test/test_sampling.cpp'], ] Catch2 = dependency('Catch2', method : 'cmake', modules : ['Catch2::Catch2WithMain', 'Catch2::Catch2']) diff --git a/test/test_sampling.cpp b/test/test_sampling.cpp new file mode 100644 index 0000000..3bc64d6 --- /dev/null +++ b/test/test_sampling.cpp @@ -0,0 +1,128 @@ +#include "util/math.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +double compute_p( size_t k, size_t n ) +{ + if( k == 0 ) + { + return 0.0; + } + else + { + double p = 1.0 / ( double( n ) - 1.0 ); + return p + ( 1.0 - p ) * compute_p( k - 1, n - 1 ); + } +} + +TEST_CASE( "Testing sampling functions" ) +{ + std::random_device rd; + std::mt19937 gen( rd() ); + + SECTION( "draw_unique_k_from_n", "Drawing k numbers out of n" ) + { + + const size_t N_RUNS = 10000; + + const size_t k = 6; + const size_t n = 100; + const size_t ignore_idx = 11; + + std::vector histogram( n, 0 ); // Count how often each element occurs amongst all samples + + std::vector buffer{}; + for( size_t i = 0; i < N_RUNS; i++ ) + { + Seldon::draw_unique_k_from_n( ignore_idx, k, n, buffer, gen ); + for( const auto & n : buffer ) + { + histogram[n]++; + } + } + + // In each run there is a probability of p for each element to be selected + // That means for each histogram bin we have a binomial distribution with p + double p = compute_p( k, n ); + + size_t mean = N_RUNS * p; + // The variance of a binomial distribution is var = n*p*(1-p) + size_t sigma = std::sqrt( N_RUNS * p * ( 1.0 - p ) ); + + INFO( "Binomial distribution parameters" ); + INFO( fmt::format( " p = {}", p ) ); + INFO( fmt::format( " mean = {}", mean ) ); + INFO( fmt::format( " sigma = {}", sigma ) ); + + REQUIRE( histogram[ignore_idx] == 0 ); // The ignore_idx should never be selected + + size_t number_outside_three_sigma = 0; + for( const auto & n : histogram ) + { + if( n == 0 ) + { + continue; + } + INFO( fmt::format( " n = {}", n ) ); + INFO( fmt::format( " mean = {}", mean ) ); + INFO( fmt::format( " sigma = {}", sigma ) ); + + if( std::abs( double( n ) - double( mean ) ) > 3.0 * sigma ) + { + number_outside_three_sigma++; + } + + REQUIRE_THAT( n, Catch::Matchers::WithinAbs( mean, 5 * sigma ) ); + } + + if( number_outside_three_sigma > 0.01 * N_RUNS ) + WARN( fmt::format( + "Many deviations beyond the 3 sigma range. {} out of {}", number_outside_three_sigma, N_RUNS ) ); + } + + SECTION( "weighted_reservior_sampling", "Testing weighted reservoir sampling with A_ExpJ algorithm" ) + { + + const size_t N_RUNS = 10000; + + const size_t k = 6; + const size_t n = 100; + const size_t ignore_idx = 11; + const size_t ignore_idx2 = 29; + + std::vector histogram( n, 0 ); // Count how often each element occurs amongst all samples + + auto weight_callback = []( size_t idx ) { + if( ( idx == ignore_idx ) | ( idx == ignore_idx2 ) ) + { + return 0.0; + } + else + { + std::abs( double( n / 2.0 ) - double( idx ) ); + } + }; + + std::vector buffer{}; + + for( size_t i = 0; i < N_RUNS; i++ ) + { + Seldon::reservoir_sampling_A_ExpJ( k, n, weight_callback, buffer, gen ); + for( const auto & n : buffer ) + { + histogram[n]++; + } + } + + REQUIRE( histogram[ignore_idx] == 0 ); // The ignore_idx should never be selected + REQUIRE( histogram[ignore_idx2] == 0 ); // The ignore_idx should never be selected + + // TODO: histogram and sigma test + } +} \ No newline at end of file