Commit fa696b0

Single threaded K-Means training no longer uses a ForkJoinPool (#197)
* Stop using a FJP for single threaded k-means training.
* Fix a bug in weighted k-means calculations.
* Add a note about FJP requiring the modifyThread permission to KMeansTrainer's javadoc.
* Add a custom ForkJoinWorkerThreadFactory to KMeansTrainer and KNNModel to make them work with custom security managers.
* Tighten the check so the custom thread factory is only used when there is a security manager.
1 parent ac9802e commit fa696b0

2 files changed (+139, -94 lines)

Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java

Lines changed: 118 additions & 91 deletions
@@ -34,6 +34,8 @@
 import org.tribuo.provenance.impl.TrainerProvenanceImpl;
 import org.tribuo.util.Util;
 
+import java.security.AccessController;
+import java.security.PrivilegedAction;
 import java.time.OffsetDateTime;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -45,7 +47,9 @@
 import java.util.SplittableRandom;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 import java.util.stream.IntStream;
@@ -63,6 +67,10 @@
  * of threads used in the training step. The thread pool is local to an invocation of train,
  * so there can be multiple concurrent trainings.
  * <p>
+ * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase
+ * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
+ * {@link java.lang.SecurityManager}.
+ * <p>
  * See:
  * <pre>
  * J. Friedman, T. Hastie, &amp; R. Tibshirani.
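As a rough illustration of the privileges named in the javadoc note above: under a SecurityManager they would typically be supplied through a policy grant along these lines (the codeBase path here is hypothetical, not part of the commit):

    grant codeBase "file:/path/to/tribuo-clustering-kmeans.jar" {
        permission java.lang.RuntimePermission "modifyThread";
        permission java.lang.RuntimePermission "modifyThreadGroup";
    };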
@@ -80,6 +88,9 @@
 public class KMeansTrainer implements Trainer<ClusterID> {
     private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());
 
+    // Thread factory for the FJP, to allow use with OpenSearch's SecureSM
+    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
+
     /**
      * Possible distance functions.
      */
@@ -138,8 +149,7 @@ public enum Initialisation {
     /**
      * for olcut.
      */
-    private KMeansTrainer() {
-    }
+    private KMeansTrainer() { }
 
     /**
      * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
@@ -194,7 +204,17 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
         }
         ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
 
-        ForkJoinPool fjp = new ForkJoinPool(numThreads);
+        boolean parallel = numThreads > 1;
+        ForkJoinPool fjp;
+        if (parallel) {
+            if (System.getSecurityManager() == null) {
+                fjp = new ForkJoinPool(numThreads);
+            } else {
+                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
+            }
+        } else {
+            fjp = null;
+        }
 
         int[] oldCentre = new int[examples.size()];
         SparseVector[] data = new SparseVector[examples.size()];
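The hunk above only constructs a pool for multi-threaded runs, using a null pool as the marker for single-threaded execution, and swaps in the custom thread factory only when a security manager is installed. A minimal standalone sketch of the same selection logic (the class and method names here are illustrative, not Tribuo API):

    import java.util.concurrent.ForkJoinPool;

    final class PoolSelection {
        // Returns null for single-threaded execution, mirroring the diff's convention.
        static ForkJoinPool selectPool(int numThreads, ForkJoinPool.ForkJoinWorkerThreadFactory factory) {
            if (numThreads <= 1) {
                return null; // run on the caller's thread instead
            }
            // Without a security manager the default worker thread factory is fine.
            if (System.getSecurityManager() == null) {
                return new ForkJoinPool(numThreads);
            }
            // A custom factory creates workers that pass the security manager's checks.
            return new ForkJoinPool(numThreads, factory, null, false);
        }
    }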
@@ -213,62 +233,65 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
                 centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
                 break;
             case PLUSPLUS:
-                centroidVectors = initialisePlusPlusCentroids(centroids, data, featureMap, localRNG, distanceType);
+                centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distanceType);
                 break;
             default:
                 throw new IllegalStateException("Unknown initialisation" + initialisationType);
         }
 
         Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
         for (int i = 0; i < centroids; i++) {
-            clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<>()));
+            clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
         }
 
+        AtomicInteger changeCounter = new AtomicInteger(0);
+        Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
+            double minDist = Double.POSITIVE_INFINITY;
+            int clusterID = -1;
+            int id = e.idx;
+            SGDVector vector = e.vector;
+            for (int j = 0; j < centroids; j++) {
+                DenseVector cluster = centroidVectors[j];
+                double distance = getDistance(cluster, vector, distanceType);
+                if (distance < minDist) {
+                    minDist = distance;
+                    clusterID = j;
+                }
+            }
+
+            clusterAssignments.get(clusterID).add(id);
+            if (oldCentre[id] != clusterID) {
+                // Changed the centroid of this vector.
+                oldCentre[id] = clusterID;
+                changeCounter.incrementAndGet();
+            }
+        };
+
         boolean converged = false;
 
         for (int i = 0; (i < iterations) && !converged; i++) {
-            //logger.log(Level.INFO,"Beginning iteration " + i);
-            AtomicInteger changeCounter = new AtomicInteger(0);
+            logger.log(Level.FINE,"Beginning iteration " + i);
+            changeCounter.set(0);
 
             for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                 e.getValue().clear();
             }
 
             // E step
-            Stream<SparseVector> vecStream = Arrays.stream(data);
+            Stream<SGDVector> vecStream = Arrays.stream(data);
             Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
-            Stream<IntAndVector> eStream;
-            if (numThreads > 1) {
-                eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel());
+            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
+            if (parallel) {
+                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
+                try {
+                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
+                } catch (InterruptedException | ExecutionException e) {
+                    throw new RuntimeException("Parallel execution failed", e);
+                }
             } else {
-                eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
-            }
-            try {
-                fjp.submit(() -> eStream.forEach((IntAndVector e) -> {
-                    double minDist = Double.POSITIVE_INFINITY;
-                    int clusterID = -1;
-                    int id = e.idx;
-                    SparseVector vector = e.vector;
-                    for (int j = 0; j < centroids; j++) {
-                        DenseVector cluster = centroidVectors[j];
-                        double distance = getDistance(cluster, vector, distanceType);
-                        if (distance < minDist) {
-                            minDist = distance;
-                            clusterID = j;
-                        }
-                    }
-
-                    clusterAssignments.get(clusterID).add(id);
-                    if (oldCentre[id] != clusterID) {
-                        // Changed the centroid of this vector.
-                        oldCentre[id] = clusterID;
-                        changeCounter.incrementAndGet();
-                    }
-                })).get();
-            } catch (InterruptedException | ExecutionException e) {
-                throw new RuntimeException("Parallel execution failed", e);
+                zipStream.forEach(eStepFunc);
             }
-            //logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " words updated.");
+            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
 
             mStep(fjp, centroidVectors, clusterAssignments, data, weights);
 
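The `fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get()` idiom above relies on a well-known ForkJoinPool behaviour: a parallel stream started from inside a pool task executes on that pool's workers rather than on the common pool, which is how the code bounds the parallelism to numThreads. A self-contained sketch of the idiom (the class name and the 4-thread/1000-element values are illustrative):

    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ForkJoinPool;
    import java.util.stream.IntStream;

    final class BoundedParallelStream {
        public static void main(String[] args) throws InterruptedException, ExecutionException {
            ForkJoinPool fjp = new ForkJoinPool(4);
            try {
                // The parallel forEach runs on fjp's 4 workers, not the common pool.
                fjp.submit(() -> IntStream.range(0, 1_000).parallel()
                        .forEach(BoundedParallelStream::process)).get();
            } finally {
                fjp.shutdown();
            }
        }

        private static void process(int i) {
            // Placeholder for per-element work, e.g. the E-step distance computation.
        }
    }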
@@ -333,18 +356,15 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF
      *
      * @param centroids The number of centroids to create.
      * @param data The dataset of {@link SparseVector} to use.
-     * @param featureMap The feature map to use for centroid sampling.
      * @param rng The RNG to use.
      * @return A {@link DenseVector} array of centroids.
      */
-    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
-                                                             ImmutableFeatureMap featureMap, SplittableRandom rng,
+    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, SplittableRandom rng,
                                                              Distance distanceType) {
         if (centroids > data.length) {
             throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
         }
 
-        int numFeatures = featureMap.size();
         double[] minDistancePerVector = new double[data.length];
         Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);
 
@@ -353,7 +373,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
         DenseVector[] centroidVectors = new DenseVector[centroids];
 
         // set first centroid randomly from the data
-        centroidVectors[0] = getRandomCentroidFromData(data, numFeatures, rng);
+        centroidVectors[0] = getRandomCentroidFromData(data, rng);
 
         // Set each uninitialised centroid remaining
         for (int i = 1; i < centroids; i++) {
@@ -362,8 +382,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
             // go through every vector and see if the min distance to the
             // newest centroid is smaller than previous min distance for vec
             for (int j = 0; j < data.length; j++) {
-                SparseVector curVec = data[j];
-                double tempDistance = getDistance(prevCentroid, curVec, distanceType);
+                double tempDistance = getDistance(prevCentroid, data[j], distanceType);
                 minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
             }
 
@@ -382,7 +401,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
             // sample from probabilities to get the new centroid from data
             double[] cdf = Util.generateCDF(probabilities);
             int idx = Util.sampleFromCDF(cdf, rng);
-            centroidVectors[i] = sparseToDense(data[idx], numFeatures);
+            centroidVectors[i] = data[idx].densify();
         }
         return centroidVectors;
     }
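The k-means++ step above draws the next centroid index with probability proportional to each point's squared distance from its nearest existing centroid. A generic sketch of CDF-based sampling in that style (this is an assumption-level illustration, not Tribuo's actual `Util.generateCDF`/`Util.sampleFromCDF` implementation, whose details may differ):

    import java.util.SplittableRandom;

    final class CdfSampling {
        // Builds a normalised cumulative distribution from non-negative weights.
        static double[] generateCDF(double[] probabilities) {
            double[] cdf = new double[probabilities.length];
            double sum = 0.0;
            for (int i = 0; i < probabilities.length; i++) {
                sum += probabilities[i];
                cdf[i] = sum;
            }
            for (int i = 0; i < cdf.length; i++) {
                cdf[i] /= sum; // last entry becomes 1.0
            }
            return cdf;
        }

        // Returns the first index whose cumulative mass covers a uniform draw.
        static int sampleFromCDF(double[] cdf, SplittableRandom rng) {
            double u = rng.nextDouble();
            for (int i = 0; i < cdf.length; i++) {
                if (u <= cdf[i]) {
                    return i;
                }
            }
            return cdf.length - 1; // guard against floating point rounding
        }
    }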
@@ -391,39 +410,22 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
      * Randomly select a piece of data as the starting centroid.
      *
      * @param data The dataset of {@link SparseVector} to use.
-     * @param numFeatures The number of features.
      * @param rng The RNG to use.
      * @return A {@link DenseVector} representing a centroid.
      */
-    private static DenseVector getRandomCentroidFromData(SparseVector[] data,
-                                                         int numFeatures, SplittableRandom rng) {
-        int rand_idx = rng.nextInt(data.length);
-        return sparseToDense(data[rand_idx], numFeatures);
-    }
-
-    /**
-     * Create a {@link DenseVector} from the data contained in a
-     * {@link SparseVector}.
-     *
-     * @param vec The {@link SparseVector} to be transformed.
-     * @param numFeatures The number of features.
-     * @return A {@link DenseVector} containing the information from vec.
-     */
-    private static DenseVector sparseToDense(SparseVector vec, int numFeatures) {
-        DenseVector dense = new DenseVector(numFeatures);
-        dense.intersectAndAddInPlace(vec);
-        return dense;
+    private static DenseVector getRandomCentroidFromData(SparseVector[] data, SplittableRandom rng) {
+        int randIdx = rng.nextInt(data.length);
+        return data[randIdx].densify();
     }
 
     /**
-     *
+     * Compute the distance between the two vectors.
      * @param cluster A {@link DenseVector} representing a centroid.
      * @param vector A {@link SGDVector} representing an example.
      * @param distanceType The distance metric to employ.
      * @return A double representing the distance from vector to centroid.
      */
-    private static double getDistance(DenseVector cluster, SGDVector vector,
-                                      Distance distanceType) {
+    private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
         double distance;
         switch (distanceType) {
             case EUCLIDEAN:
@@ -441,30 +443,41 @@ private static double getDistance(DenseVector cluster, SGDVector vector,
         return distance;
     }
 
+    /**
+     * Runs the mStep, writing to the {@code centroidVectors} array.
+     * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
+     *            If the fjp is null then the computation is executed sequentially on the main thread.
+     * @param centroidVectors The centroid vectors to write out.
+     * @param clusterAssignments The current cluster assignments.
+     * @param data The data points.
+     * @param weights The example weights.
+     */
     protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
         // M step
-        Stream<Entry<Integer, List<Integer>>> mStream;
-        if (numThreads > 1) {
-            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
+        Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
+            DenseVector newCentroid = centroidVectors[e.getKey()];
+            newCentroid.fill(0.0);
+
+            double weightSum = 0.0;
+            for (Integer idx : e.getValue()) {
+                newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
+                weightSum += weights[idx];
+            }
+            if (weightSum != 0.0) {
+                newCentroid.scaleInPlace(1.0 / weightSum);
+            }
+        };
+
+        Stream<Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
+        if (fjp != null) {
+            Stream<Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
+            try {
+                fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get();
+            } catch (InterruptedException | ExecutionException e) {
+                throw new RuntimeException("Parallel execution failed", e);
+            }
         } else {
-            mStream = clusterAssignments.entrySet().stream();
-        }
-        try {
-            fjp.submit(() -> mStream.forEach((e) -> {
-                DenseVector newCentroid = centroidVectors[e.getKey()];
-                newCentroid.fill(0.0);
-
-                int counter = 0;
-                for (Integer idx : e.getValue()) {
-                    newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
-                    counter++;
-                }
-                if (counter > 0) {
-                    newCentroid.scaleInPlace(1.0 / counter);
-                }
-            })).get();
-        } catch (InterruptedException | ExecutionException e) {
-            throw new RuntimeException("Parallel execution failed", e);
+            mStream.forEach(mStepFunc);
         }
     }
 
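The rewritten mStep above is the weighted k-means fix from the commit message: each centroid is now the weighted mean of its assigned points, dividing the weighted sum by the total weight rather than by the point count. A small numeric check of the difference (values chosen for illustration):

    final class WeightedMeanCheck {
        public static void main(String[] args) {
            double[] points = {1.0, 3.0};
            double[] weights = {1.0, 3.0};

            double weightedSum = 0.0, weightSum = 0.0;
            for (int i = 0; i < points.length; i++) {
                weightedSum += points[i] * weights[i];
                weightSum += weights[i];
            }
            // Old behaviour: weighted contributions divided by the count.
            System.out.println(weightedSum / points.length); // 5.0, outside the data's range entirely
            // Fixed behaviour: divide by the total weight.
            System.out.println(weightedSum / weightSum);     // 2.5, the weighted mean
        }
    }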
@@ -483,11 +496,25 @@ public TrainerProvenance getProvenance() {
      */
     static class IntAndVector {
         final int idx;
-        final SparseVector vector;
+        final SGDVector vector;
 
-        public IntAndVector(int idx, SparseVector vector) {
+        /**
+         * Constructs an index and vector tuple.
+         * @param idx The index.
+         * @param vector The vector.
+         */
+        public IntAndVector(int idx, SGDVector vector) {
             this.idx = idx;
             this.vector = vector;
         }
     }
+
+    /**
+     * Used to allow FJPs to work with OpenSearch's SecureSM.
+     */
+    private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
+        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});
+        }
+    }
 }
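For context on the new factory at the end of the diff: wrapping thread creation in `AccessController.doPrivileged` means the worker thread is constructed under the Tribuo codebase's own privileges rather than those of whatever application code happened to trigger pool growth, which is what lets the pool operate under restrictive security managers such as OpenSearch's SecureSM. A hedged sketch of the same factory shape and how it plugs into a pool (names are illustrative):

    import java.util.concurrent.ForkJoinPool;
    import java.util.concurrent.ForkJoinWorkerThread;

    final class PlainWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
        @Override
        public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            // The anonymous subclass is needed because ForkJoinWorkerThread's constructor is protected.
            return new ForkJoinWorkerThread(pool) {};
        }
    }

    // Usage: ForkJoinPool pool = new ForkJoinPool(4, new PlainWorkerThreadFactory(), null, false);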
