Skip to content

Commit 1508ea6

Browse files
author
Moritz Sallermann
committed
implemented weighted reservoir sampling in util/math.hpp
1 parent 334c7d3 commit 1508ea6

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

include/util/math.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22
#include <algorithm>
33
#include <cstddef>
4+
#include <queue>
45
#include <random>
6+
#include <utility>
57
#include <vector>
68

79
namespace Seldon
@@ -53,4 +55,50 @@ inline void draw_unique_k_from_n(
5355
std::sample( SequenceGenerator( 0, ignore_idx ), SequenceGenerator( n, ignore_idx ), buffer.begin(), k, gen );
5456
}
5557

58+
/**
 * Weighted reservoir sampling without replacement. The name refers to the
 * "A-ExpJ" (exponential jumps) algorithm of Efraimidis and Spirakis.
 *
 * Draws up to k distinct indices from [0, n); indices with a larger weight are
 * more likely to be selected. Every candidate i is associated with a key
 * u^(1/weight(i)), u uniform in (0, 1), and the k candidates with the largest
 * keys are kept. Instead of drawing a key for every candidate, the algorithm
 * draws how much cumulative weight may be skipped before the next replacement.
 *
 * @param k      Number of indices to draw.
 * @param n      Size of the index range to draw from.
 * @param weight Callable invoked as weight( idx ); its result is used as a double.
 *               NOTE(review): weights are presumably expected to be strictly
 *               positive (1.0 / weight( idx ) is evaluated) - confirm with callers.
 * @param buffer Output. Resized to min( k, n ) and filled with the selected
 *               indices in order of ascending key. Left empty if k == 0 or n == 0.
 * @param mt     Random number generator; its state is advanced.
 */
template<typename WeightCallbackT>
void reservoir_sampling_A_ExpJ(
    size_t k, size_t n, WeightCallbackT weight, std::vector<std::size_t> & buffer, std::mt19937 & mt )
{
    // Nothing can be sampled. This also guards every H.top() below against an empty heap.
    if( k == 0 || n == 0 )
    {
        buffer.clear();
        return;
    }

    std::uniform_real_distribution<double> distribution( 0.0, 1.0 );

    using QueueItemT = std::pair<size_t, double>; // ( index, key )

    // Min-heap on the key: H.top() is always the weakest member of the reservoir
    auto compare = []( const QueueItemT & item1, const QueueItemT & item2 ) { return item1.second > item2.second; };
    // The comparator is passed explicitly, because closure types are not default-constructible before C++20
    std::priority_queue<QueueItemT, std::vector<QueueItemT>, decltype( compare )> H( compare );

    // Fill the reservoir with the first min( k, n ) candidates
    size_t idx = 0;
    while( idx < n && H.size() < k )
    {
        double r = std::pow( distribution( mt ), 1.0 / weight( idx ) );
        H.push( { idx, r } );
        idx++;
    }

    // X is the amount of weight that can still be skipped before the next replacement
    auto X = std::log( distribution( mt ) ) / std::log( H.top().second );
    while( idx < n )
    {
        auto w = weight( idx );
        X -= w;
        if( X <= 0 )
        {
            // The new key has to exceed the current minimum key, hence it is drawn from [t, 1]
            auto t                     = std::pow( H.top().second, w );
            auto uniform_from_t_to_one = distribution( mt ) * ( 1.0 - t ) + t; // Random number in interval [t, 1.0]
            auto r                     = std::pow( uniform_from_t_to_one, 1.0 / w );
            H.pop();
            H.push( { idx, r } );
            X = std::log( distribution( mt ) ) / std::log( H.top().second );
        }
        idx++;
    }

    // The heap holds min( k, n ) items; iterate over that count and not over k
    buffer.resize( H.size() );

    for( size_t i = 0; i < buffer.size(); i++ )
    {
        buffer[i] = H.top().first;
        H.pop();
    }
}
103+
56104
} // namespace Seldon

0 commit comments

Comments
 (0)