Skip to content

Commit 545993d

Browse files
committed
Add an efficient reservoir sampling aggregator
This aggregator uses Li's "Algorithm L", a simple yet efficient sampling method, with modifications to support a monoidal setting. A JMH benchmark was added for both this and the old priority-queue algorithm. In a single-threaded benchmark on an Intel Core i9-10885H, this algorithm can outperform the old one by an order of magnitude or more, depending on the parameters. Because of this, the new algorithm was made the default for Aggregator.reservoirSample(). Unit tests were added for both algorithms. These are probabilistic and are expected to fail on roughly 0.1% of runs, per test case (the p-value threshold is set to 0.001). Optimized overloads of the aggregation methods append/appendAll were added that operate on IndexedSeqs. These have efficient random access and allow us to skip over items without examining each one, so sublinear runtime can be achieved.
1 parent 464917d commit 545993d

File tree

8 files changed

+721
-6
lines changed

8 files changed

+721
-6
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package com.twitter.algebird.benchmark
2+
3+
import com.twitter.algebird.mutable.ReservoirSamplingToListAggregator
4+
import com.twitter.algebird.{Aggregator, Preparer}
5+
import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, State}
6+
import org.openjdk.jmh.infra.Blackhole
7+
8+
import scala.util.Random
9+
10+
object ReservoirSamplingBenchmark {

  /** Benchmark parameters: the input size and the fraction of it to sample. */
  @State(Scope.Benchmark)
  class BenchmarkState {
    @Param(Array("100", "10000", "1000000"))
    var collectionSize: Int = 0

    @Param(Array("0.001", "0.01", "0.1"))
    var sampleRate: Double = 0.0

    // Reservoir capacity derived from the rate; ceil so it is never zero for
    // a non-empty input.
    def samples: Int = math.ceil(sampleRate * collectionSize).toInt
  }

  // A single shared RNG: the benchmarks measure sampling, not RNG setup.
  val rng = new Random()
  implicit val randomSupplier: () => Random = () => rng
}
25+
26+
class ReservoirSamplingBenchmark {
  import ReservoirSamplingBenchmark._

  // Baseline: the old priority-queue sampler — attach a uniform random score
  // to every element and keep the `count` smallest scores.
  private def prioQueueSampler[T](count: Int) =
    Preparer[T]
      .map(rng.nextDouble() -> _)
      .monoidAggregate(Aggregator.sortByTake(count)(_._1))
      .andThenPresent(_.map(_._2))

  @Benchmark
  def timeAlgorithmL(state: BenchmarkState, bh: Blackhole): Unit = {
    val sampler = new ReservoirSamplingToListAggregator[Int](state.samples)
    bh.consume(sampler(0 until state.collectionSize))
  }

  // NOTE(review): method name misspells "Queue"; kept as-is because JMH
  // reports benchmarks by method name and renaming would break result
  // continuity across runs.
  @Benchmark
  def timePriorityQeueue(state: BenchmarkState, bh: Blackhole): Unit = {
    val sampler = prioQueueSampler[Int](state.samples)
    bh.consume(sampler(0 until state.collectionSize))
  }
}

algebird-core/src/main/scala/com/twitter/algebird/Aggregator.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.twitter.algebird
22

3+
import com.twitter.algebird.mutable.{Reservoir, ReservoirSamplingToListAggregator}
4+
35
import java.util.PriorityQueue
46
import scala.collection.compat._
57
import scala.collection.generic.CanBuildFrom
@@ -286,12 +288,9 @@ object Aggregator extends java.io.Serializable {
286288
/**
 * Sample `count` elements uniformly at random from the aggregated stream,
 * using reservoir sampling ("Algorithm L"). The seed makes the aggregation
 * deterministic for a given input order.
 *
 * @param count the number of elements to sample
 * @param seed  seed for the random generator
 */
def reservoirSample[T](
    count: Int,
    seed: Int = DefaultSeed
): MonoidAggregator[T, Reservoir[T], Seq[T]] = {
  // One RNG per aggregator instance, shared by all of its operations.
  val rng = new scala.util.Random(seed)
  new ReservoirSamplingToListAggregator[T](count)(() => rng)
}
296295

297296
/**
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
package com.twitter.algebird.mutable
2+
3+
import com.twitter.algebird.{Monoid, MonoidAggregator}
4+
5+
import scala.collection.mutable
6+
import scala.util.Random
7+
8+
/**
 * A reservoir of the currently sampled items.
 *
 * @param capacity
 *   the reservoir capacity
 * @tparam T
 *   the element type
 */
sealed abstract class Reservoir[T](val capacity: Int) {
  require(capacity > 0, "reservoir size must be positive")

  def size: Int
  def isEmpty: Boolean = size == 0
  def isFull: Boolean = size == capacity

  // Threshold for accepting an element once the reservoir is full.
  // Invariant: the maximum score among the sampled elements is w and the
  // remaining scores are distributed as U[0, w]. Scores are never stored —
  // only their distribution is tracked and sampled from.
  // (w = 1 while the reservoir still has room.)
  def w: Double = 1

  private[algebird] def toSeq: mutable.IndexedSeq[T]

  /**
   * Add an element to the reservoir. If the reservoir is full then the element replaces a random element
   * in the reservoir, and the threshold <pre>w</pre> is updated.
   *
   * When adding multiple elements, [[append]] should be used to take advantage of exponential jumps.
   *
   * @param x
   *   the element to add
   * @param rng
   *   the random source
   */
  private[algebird] def accept(x: T, rng: Random): Reservoir[T]

  // Number of items to skip before accepting the next one; geometrically
  // distributed with success probability w / prior. prior is 1 when filling a
  // single reservoir; when merging it is the donor reservoir's threshold, in
  // which case we require w < prior.
  private def nextAcceptTime(rng: Random, prior: Double): Int =
    (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt

  /**
   * Add multiple elements to the reservoir.
   *
   * @param xs
   *   the elements to add
   * @param rng
   *   the random source
   * @param prior
   *   the threshold of the elements being added, such that the added element's value is distributed as
   *   <pre>U[0, prior]</pre>
   * @return
   *   this reservoir
   */
  def append(xs: TraversableOnce[T], rng: Random, prior: Double = 1.0): Reservoir[T] =
    xs match {
      // Indexed sequences allow skipping by index, without touching skipped
      // elements; everything else falls back to the streaming overload.
      case seq: IndexedSeq[T] => Reservoir.append(this, seq, rng, prior)
      case _                  => Reservoir.append(this, xs, rng, prior)
    }
}
71+
72+
// The zero reservoir: nothing sampled yet.
private[algebird] case class Empty[T](k: Int) extends Reservoir[T](k) {
  override val size: Int = 0
  override val isEmpty: Boolean = true
  override val isFull: Boolean = false

  override def toSeq: mutable.IndexedSeq[T] = mutable.IndexedSeq()

  // The first element always enters the sample. With capacity 1 the reservoir
  // becomes full right away, so a threshold must be drawn for it.
  override def accept(x: T, rng: Random): Reservoir[T] =
    Singleton(k, x, if (capacity == 1) rng.nextDouble else 1)

  override def toString: String = s"Empty($capacity)"
}
84+
85+
// A reservoir holding exactly one element (though possibly of larger
// capacity), optimizing the common case of sampling a single element.
private[algebird] case class Singleton[T](k: Int, x: T, w1: Double) extends Reservoir[T](k) {
  override val size: Int = 1
  override val isEmpty: Boolean = false
  override val isFull: Boolean = capacity == 1
  override val w: Double = w1

  override def toSeq: mutable.IndexedSeq[T] = mutable.IndexedSeq(x)

  override def accept(y: T, rng: Random): Reservoir[T] =
    if (isFull) {
      // Capacity 1: the new element replaces the old one and the threshold
      // shrinks multiplicatively (the Algorithm L update with k = 1).
      Singleton(k, y, w * rng.nextDouble)
    } else {
      // Grow into a buffer-backed reservoir holding both elements.
      new ArrayReservoir(k).accept(x, rng).accept(y, rng)
    }

  override def toString: String = s"Singleton($capacity, $w, $x)"
}
103+
104+
// A reservoir backed by a mutable buffer - will mutate by aggregation!
private[algebird] class ArrayReservoir[T](val k: Int) extends Reservoir[T](k) {
  private val reservoir: mutable.ArrayBuffer[T] = new mutable.ArrayBuffer(k)
  private var w1: Double = 1
  // Cached 1/capacity, used in the threshold update below.
  private val kInv: Double = 1d / capacity

  override def size: Int = reservoir.size
  override def w: Double = w1
  override def toSeq: mutable.IndexedSeq[T] = reservoir

  override def accept(x: T, rng: Random): Reservoir[T] = {
    if (isFull) reservoir(rng.nextInt(capacity)) = x
    else reservoir += x
    // Once full (including the insert that just filled it), shrink the
    // acceptance threshold per Algorithm L: w *= u^(1/k), u ~ U[0, 1].
    if (isFull) w1 *= Math.pow(rng.nextDouble, kInv)
    this
  }

  override def toString: String = s"ArrayReservoir($capacity, $w, ${reservoir.toList})"
}
128+
129+
object Reservoir {
  def apply[T](capacity: Int): Reservoir[T] = Empty(capacity)

  // Streaming append: elements must be visited one by one, so skips are
  // realized by counting elements down rather than by jumping indexes.
  private[algebird] def append[T](
      self: Reservoir[T],
      xs: TraversableOnce[T],
      rng: Random,
      prior: Double
  ): Reservoir[T] = {
    var res = self
    var skip = if (res.isFull) res.nextAcceptTime(rng, prior) else 0
    xs.foreach { x =>
      if (!res.isFull) {
        // Fill phase: every element is taken until the reservoir is full.
        res = res.accept(x, rng)
        if (res.isFull) skip = res.nextAcceptTime(rng, prior)
      } else if (skip > 0) {
        skip -= 1
      } else {
        res = res.accept(x, rng)
        skip = res.nextAcceptTime(rng, prior)
      }
    }
    res
  }

  /**
   * Add multiple elements to the reservoir. This overload is optimized for indexed sequences, where we can
   * skip over multiple indexes without accessing the elements.
   *
   * @param xs
   *   the elements to add
   * @param rng
   *   the random source
   * @param prior
   *   the threshold of the elements being added, such that the added element's value is distributed as
   *   <pre>U[0, prior]</pre>
   * @return
   *   this reservoir
   */
  private[algebird] def append[T](
      self: Reservoir[T],
      xs: IndexedSeq[T],
      rng: Random,
      prior: Double
  ): Reservoir[T] = {
    var res = self
    val end = xs.size
    // Fill phase: take elements until either the input or the capacity runs out.
    var i = end.min(res.capacity - res.size)
    var j = 0
    while (j < i) {
      res = res.accept(xs(j), rng)
      j += 1
    }

    // Either the reservoir is now full, or we consumed the whole input.
    assert(res.isFull || i == end)
    while (i >= 0 && i < end) {
      // Jump directly to the next accepted index.
      i += res.nextAcceptTime(rng, prior)
      // The addition can overflow, in which case i < 0 and the loop ends.
      if (i >= 0 && i < end) {
        // Element enters the reservoir.
        res = res.accept(xs(i), rng)
        i += 1
      }
    }
    res
  }
}
197+
198+
/**
 * This is the "Algorithm L" reservoir sampling algorithm [1], with modifications to act as a monoid by
 * merging reservoirs.
 *
 * [1] Kim-Hung Li, "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))", 1994
 *
 * @tparam T
 *   the item type
 */
class ReservoirMonoid[T](val capacity: Int)(implicit val randomSupplier: () => Random)
    extends Monoid[Reservoir[T]] {

  override def zero: Reservoir[T] = Reservoir(capacity)
  override def isNonZero(r: Reservoir[T]): Boolean = !r.isEmpty

  /**
   * Merge two reservoirs. NOTE: This mutates one or both of the reservoirs. They should not be used after
   * this operation, except when using the return value for further aggregation.
   */
  override def plus(left: Reservoir[T], right: Reservoir[T]): Reservoir[T] =
    if (left.isEmpty) right
    else if (left.size + right.size <= left.capacity) {
      // Everything fits below capacity, so simply pour right into left.
      left.append(right.toSeq, randomSupplier())
    } else {
      // Pull from the reservoir with the larger threshold (s2) into the one
      // with the smaller (s1), so donated elements are only ever downsampled.
      val (s1, s2) = if (left.w < right.w) (left, right) else (right, left)
      val xs = s2.toSeq
      val rng = randomSupplier()
      if (s2.isFull) {
        assert(s1.isFull)
        // The highest score in s2 is exactly s2.w and the rest are U[0, s2.w].
        // Since s1.w < s2.w, the max-score element can never enter s1 and is
        // dropped unconditionally; its position is uniformly distributed, so
        // overwrite a random slot with the head and append the tail, whose
        // scores are U[0, s2.w], with prior s2.w.
        val i = rng.nextInt(xs.size)
        xs(i) = xs.head
        s1.append(xs.drop(1), rng, s2.w)
      } else {
        // s2 is not full, so its elements carry the default prior of 1.
        s1.append(xs, rng)
      }
    }
}
239+
240+
/**
 * An aggregator that uses reservoir sampling to sample k elements from a stream of items. Because the
 * reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done
 * by [[ReservoirSamplingToListAggregator]].
 *
 * @param capacity
 *   the number of elements to sample
 * @param randomSupplier
 *   the random generator
 * @tparam T
 *   the item type
 * @tparam C
 *   the result type
 */
abstract class ReservoirSamplingAggregator[T, +C](capacity: Int)(implicit val randomSupplier: () => Random)
    extends MonoidAggregator[T, Reservoir[T], C] {
  override val monoid: ReservoirMonoid[T] = new ReservoirMonoid(capacity)

  override def prepare(x: T): Reservoir[T] = Reservoir(capacity).accept(x, randomSupplier())

  // Bypass the default prepare-then-reduce path: feed the input straight into
  // a single reservoir so the optimized skip logic is used.
  override def apply(xs: TraversableOnce[T]): C = present(agg(xs))

  override def applyOption(inputs: TraversableOnce[T]): Option[C] =
    if (inputs.isEmpty) None else Some(apply(inputs))

  override def append(r: Reservoir[T], t: T): Reservoir[T] = r.append(Seq(t), randomSupplier())

  override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] =
    r.append(xs, randomSupplier())

  override def appendAll(xs: TraversableOnce[T]): Reservoir[T] = agg(xs)

  // Aggregate the whole input into a fresh (empty) reservoir.
  private def agg(xs: TraversableOnce[T]): Reservoir[T] =
    appendAll(monoid.zero, xs)
}
274+
275+
class ReservoirSamplingToListAggregator[T](capacity: Int)(implicit randomSupplier: () => Random)
    extends ReservoirSamplingAggregator[T, List[T]](capacity)(randomSupplier) {

  // Shuffle so the output order carries no information about arrival order,
  // and copy out of the mutable reservoir into an immutable list.
  override def present(r: Reservoir[T]): List[T] =
    randomSupplier().shuffle(r.toSeq).toList

  // Wrap rather than compose, so the optimized append/appendAll overloads of
  // this aggregator keep being used after andThenPresent.
  override def andThenPresent[D](f: List[T] => D): MonoidAggregator[T, Reservoir[T], D] =
    new AndThenPresent(this, f)
}
283+
284+
/**
 * Monoid that implements [[andThenPresent]] without ruining the optimized behavior of the aggregator.
 */
private[algebird] class AndThenPresent[-A, B, C, +D](val agg: MonoidAggregator[A, B, C], f: C => D)
    extends MonoidAggregator[A, B, D] {
  override val monoid: Monoid[B] = agg.monoid
  override def prepare(a: A): B = agg.prepare(a)
  override def present(b: B): D = f(agg.present(b))

  // Delegate wholesale so the wrapped aggregator's optimized overloads are
  // used; only the final presentation is post-processed with f.
  override def apply(xs: TraversableOnce[A]): D = f(agg(xs))
  override def applyOption(xs: TraversableOnce[A]): Option[D] = agg.applyOption(xs).map(f)
  override def append(b: B, a: A): B = agg.append(b, a)
  override def appendAll(b: B, as: TraversableOnce[A]): B = agg.appendAll(b, as)
  override def appendAll(as: TraversableOnce[A]): B = agg.appendAll(as)
}

0 commit comments

Comments
 (0)