import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

+import java.security.AccessController;
+import java.security.PrivilegedAction;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;

import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;

 * of threads used in the training step. The thread pool is local to an invocation of train,
 * so there can be multiple concurrent trainings.
 * <p>
+ * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase
+ * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
+ * {@link java.lang.SecurityManager}.
+ * <p>
 * See:
 * <pre>
 * J. Friedman, T. Hastie, & R. Tibshirani.
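
For readers deploying under a SecurityManager, the two privileges named in the new javadoc are standard RuntimePermissions; a policy grant along the following lines would supply them (the codeBase path here is hypothetical, adjust it to the real jar location):

    // Hypothetical policy file grant; the codeBase path is an assumption.
    grant codeBase "file:/path/to/tribuo-clustering-kmeans.jar" {
        permission java.lang.RuntimePermission "modifyThread";
        permission java.lang.RuntimePermission "modifyThreadGroup";
    };
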
public class KMeansTrainer implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

+    // Thread factory for the FJP, to allow use with OpenSearch's SecureSM
+    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
+
    /**
     * Possible distance functions.
     */
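
As a usage sketch, the thread factory only matters when a trainer is constructed with more than one thread; a parallel run might look like the following (the constructor argument order is assumed from this class's constructors, and `dataset` stands in for a real Dataset<ClusterID>):

    // Hypothetical usage: 10 centroids, 100 iterations, Euclidean distance,
    // 4 training threads, fixed RNG seed. Argument order is an assumption.
    KMeansTrainer trainer = new KMeansTrainer(10, 100, KMeansTrainer.Distance.EUCLIDEAN, 4, 1L);
    KMeansModel model = trainer.train(dataset);
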
@@ -138,8 +149,7 @@ public enum Initialisation {
    /**
     * for olcut.
     */
-    private KMeansTrainer() {
-    }
+    private KMeansTrainer() { }

    /**
     * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
@@ -194,7 +204,17 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();

-        ForkJoinPool fjp = new ForkJoinPool(numThreads);
+        boolean parallel = numThreads > 1;
+        ForkJoinPool fjp;
+        if (parallel) {
+            if (System.getSecurityManager() == null) {
+                fjp = new ForkJoinPool(numThreads);
+            } else {
+                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
+            }
+        } else {
+            fjp = null;
+        }

        int[] oldCentre = new int[examples.size()];
        SparseVector[] data = new SparseVector[examples.size()];
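
The four-argument constructor used in the SecurityManager branch is the JDK's own; supplying a factory whose threads are constructed inside doPrivileged (see the CustomForkJoinWorkerThreadFactory at the end of this diff) keeps thread creation from failing permission checks under a restrictive SecurityManager such as OpenSearch's SecureSM:

    // JDK signature (for reference, not Tribuo code):
    // ForkJoinPool(int parallelism,
    //              ForkJoinPool.ForkJoinWorkerThreadFactory factory,
    //              Thread.UncaughtExceptionHandler handler,
    //              boolean asyncMode)
    ForkJoinPool pool = new ForkJoinPool(4, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false);
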
@@ -213,62 +233,65 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
                centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
                break;
            case PLUSPLUS:
-                centroidVectors = initialisePlusPlusCentroids(centroids, data, featureMap, localRNG, distanceType);
+                centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distanceType);
                break;
            default:
                throw new IllegalStateException("Unknown initialisation" + initialisationType);
        }

        Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
        for (int i = 0; i < centroids; i++) {
-            clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<>()));
+            clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
        }

+        AtomicInteger changeCounter = new AtomicInteger(0);
+        Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
+            double minDist = Double.POSITIVE_INFINITY;
+            int clusterID = -1;
+            int id = e.idx;
+            SGDVector vector = e.vector;
+            for (int j = 0; j < centroids; j++) {
+                DenseVector cluster = centroidVectors[j];
+                double distance = getDistance(cluster, vector, distanceType);
+                if (distance < minDist) {
+                    minDist = distance;
+                    clusterID = j;
+                }
+            }
+
+            clusterAssignments.get(clusterID).add(id);
+            if (oldCentre[id] != clusterID) {
+                // Changed the centroid of this vector.
+                oldCentre[id] = clusterID;
+                changeCounter.incrementAndGet();
+            }
+        };
+
        boolean converged = false;

        for (int i = 0; (i < iterations) && !converged; i++) {
-            // logger.log(Level.INFO, "Beginning iteration " + i);
-            AtomicInteger changeCounter = new AtomicInteger(0);
+            logger.log(Level.FINE, "Beginning iteration " + i);
+            changeCounter.set(0);

            for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                e.getValue().clear();
            }

            // E step
-            Stream<SparseVector> vecStream = Arrays.stream(data);
+            Stream<SGDVector> vecStream = Arrays.stream(data);
            Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
-            Stream<IntAndVector> eStream;
-            if (numThreads > 1) {
-                eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel());
+            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
+            if (parallel) {
+                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
+                try {
+                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
+                } catch (InterruptedException | ExecutionException e) {
+                    throw new RuntimeException("Parallel execution failed", e);
+                }
            } else {
-                eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
-            }
-            try {
-                fjp.submit(() -> eStream.forEach((IntAndVector e) -> {
-                    double minDist = Double.POSITIVE_INFINITY;
-                    int clusterID = -1;
-                    int id = e.idx;
-                    SparseVector vector = e.vector;
-                    for (int j = 0; j < centroids; j++) {
-                        DenseVector cluster = centroidVectors[j];
-                        double distance = getDistance(cluster, vector, distanceType);
-                        if (distance < minDist) {
-                            minDist = distance;
-                            clusterID = j;
-                        }
-                    }
-
-                    clusterAssignments.get(clusterID).add(id);
-                    if (oldCentre[id] != clusterID) {
-                        // Changed the centroid of this vector.
-                        oldCentre[id] = clusterID;
-                        changeCounter.incrementAndGet();
-                    }
-                })).get();
-            } catch (InterruptedException | ExecutionException e) {
-                throw new RuntimeException("Parallel execution failed", e);
+                zipStream.forEach(eStepFunc);
            }
-            // logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " words updated.");
+            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");

            mStep(fjp, centroidVectors, clusterAssignments, data, weights);

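
The submit-and-get idiom in the E step above is the standard trick for confining a parallel stream to a specific ForkJoinPool: because the terminal operation starts inside a task already running in fjp, the stream's work is executed by that pool's workers rather than the common pool. A self-contained sketch of the idiom (the printing body is purely illustrative):

    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ForkJoinPool;
    import java.util.stream.IntStream;

    public class ParallelStreamInPool {
        public static void main(String[] args) {
            ForkJoinPool fjp = new ForkJoinPool(4);
            try {
                // The forEach runs on fjp's workers, not the common pool.
                fjp.submit(() -> IntStream.range(0, 16).parallel()
                        .forEach(i -> System.out.println(Thread.currentThread().getName() + " -> " + i))).get();
            } catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException("Parallel execution failed", e);
            } finally {
                fjp.shutdown();
            }
        }
    }
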
@@ -333,18 +356,15 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF
     *
     * @param centroids The number of centroids to create.
     * @param data The dataset of {@link SparseVector} to use.
-     * @param featureMap The feature map to use for centroid sampling.
     * @param rng The RNG to use.
     * @return A {@link DenseVector} array of centroids.
     */
-    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
-                                                             ImmutableFeatureMap featureMap, SplittableRandom rng,
+    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, SplittableRandom rng,
                                                              Distance distanceType) {
        if (centroids > data.length) {
            throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
        }

-        int numFeatures = featureMap.size();
        double[] minDistancePerVector = new double[data.length];
        Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);

@@ -353,7 +373,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
        DenseVector[] centroidVectors = new DenseVector[centroids];

        // set first centroid randomly from the data
-        centroidVectors[0] = getRandomCentroidFromData(data, numFeatures, rng);
+        centroidVectors[0] = getRandomCentroidFromData(data, rng);

        // Set each uninitialised centroid remaining
        for (int i = 1; i < centroids; i++) {
@@ -362,8 +382,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
            // go through every vector and see if the min distance to the
            // newest centroid is smaller than previous min distance for vec
            for (int j = 0; j < data.length; j++) {
-                SparseVector curVec = data[j];
-                double tempDistance = getDistance(prevCentroid, curVec, distanceType);
+                double tempDistance = getDistance(prevCentroid, data[j], distanceType);
                minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
            }

@@ -382,7 +401,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
            // sample from probabilities to get the new centroid from data
            double[] cdf = Util.generateCDF(probabilities);
            int idx = Util.sampleFromCDF(cdf, rng);
-            centroidVectors[i] = sparseToDense(data[idx], numFeatures);
+            centroidVectors[i] = data[idx].densify();
        }
        return centroidVectors;
    }
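
Util.generateCDF and Util.sampleFromCDF together perform weighted sampling: index i is drawn with probability proportional to probabilities[i], which is exactly the D^2 weighting of k-means++. A minimal standalone sketch of the same idea (not Tribuo's implementation):

    import java.util.Arrays;
    import java.util.SplittableRandom;

    public class CdfSampling {
        // Draw an index with probability proportional to weights[i].
        static int sampleIndex(double[] weights, SplittableRandom rng) {
            double[] cdf = new double[weights.length];
            double sum = 0.0;
            for (int i = 0; i < weights.length; i++) {
                sum += weights[i];
                cdf[i] = sum;
            }
            double u = rng.nextDouble() * sum;
            int idx = Arrays.binarySearch(cdf, u);
            // binarySearch returns -(insertionPoint) - 1 when u falls between entries.
            return idx >= 0 ? idx : -(idx + 1);
        }

        public static void main(String[] args) {
            // Squared distances act as the k-means++ sampling weights.
            double[] sqDistances = {0.0, 4.0, 1.0, 9.0};
            System.out.println(sampleIndex(sqDistances, new SplittableRandom(42)));
        }
    }
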
@@ -391,39 +410,22 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
     * Randomly select a piece of data as the starting centroid.
     *
     * @param data The dataset of {@link SparseVector} to use.
-     * @param numFeatures The number of features.
     * @param rng The RNG to use.
     * @return A {@link DenseVector} representing a centroid.
     */
-    private static DenseVector getRandomCentroidFromData(SparseVector[] data,
-                                                         int numFeatures, SplittableRandom rng) {
-        int rand_idx = rng.nextInt(data.length);
-        return sparseToDense(data[rand_idx], numFeatures);
-    }
-
-    /**
-     * Create a {@link DenseVector} from the data contained in a
-     * {@link SparseVector}.
-     *
-     * @param vec The {@link SparseVector} to be transformed.
-     * @param numFeatures The number of features.
-     * @return A {@link DenseVector} containing the information from vec.
-     */
-    private static DenseVector sparseToDense(SparseVector vec, int numFeatures) {
-        DenseVector dense = new DenseVector(numFeatures);
-        dense.intersectAndAddInPlace(vec);
-        return dense;
+    private static DenseVector getRandomCentroidFromData(SparseVector[] data, SplittableRandom rng) {
+        int randIdx = rng.nextInt(data.length);
+        return data[randIdx].densify();
    }

    /**
-     *
+     * Compute the distance between the two vectors.
     * @param cluster A {@link DenseVector} representing a centroid.
     * @param vector A {@link SGDVector} representing an example.
     * @param distanceType The distance metric to employ.
     * @return A double representing the distance from vector to centroid.
     */
-    private static double getDistance(DenseVector cluster, SGDVector vector,
-                                      Distance distanceType) {
+    private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
        double distance;
        switch (distanceType) {
            case EUCLIDEAN:
@@ -441,30 +443,41 @@ private static double getDistance(DenseVector cluster, SGDVector vector,
        return distance;
    }

+    /**
+     * Runs the mStep, writing to the {@code centroidVectors} array.
+     * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
+     *            If the fjp is null then the computation is executed sequentially on the main thread.
+     * @param centroidVectors The centroid vectors to write out.
+     * @param clusterAssignments The current cluster assignments.
+     * @param data The data points.
+     * @param weights The example weights.
+     */
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
        // M step
-        Stream<Entry<Integer, List<Integer>>> mStream;
-        if (numThreads > 1) {
-            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
+        Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
+            DenseVector newCentroid = centroidVectors[e.getKey()];
+            newCentroid.fill(0.0);
+
+            double weightSum = 0.0;
+            for (Integer idx : e.getValue()) {
+                newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
+                weightSum += weights[idx];
+            }
+            if (weightSum != 0.0) {
+                newCentroid.scaleInPlace(1.0 / weightSum);
+            }
+        };
+
+        Stream<Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
+        if (fjp != null) {
+            Stream<Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
+            try {
+                fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get();
+            } catch (InterruptedException | ExecutionException e) {
+                throw new RuntimeException("Parallel execution failed", e);
+            }
        } else {
-            mStream = clusterAssignments.entrySet().stream();
-        }
-        try {
-            fjp.submit(() -> mStream.forEach((e) -> {
-                DenseVector newCentroid = centroidVectors[e.getKey()];
-                newCentroid.fill(0.0);
-
-                int counter = 0;
-                for (Integer idx : e.getValue()) {
-                    newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
-                    counter++;
-                }
-                if (counter > 0) {
-                    newCentroid.scaleInPlace(1.0 / counter);
-                }
-            })).get();
-        } catch (InterruptedException | ExecutionException e) {
-            throw new RuntimeException("Parallel execution failed", e);
+            mStream.forEach(mStepFunc);
        }
    }

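
Note the behavioural change folded into this hunk: the old code divided each centroid by a plain count of assigned points, while mStepFunc accumulates the example weights and divides by their sum, so each centroid is now the weighted mean (sum_i w_i * x_i) / (sum_i w_i) of its assigned examples. A small self-contained sketch of that update, with plain arrays in place of Tribuo's vector classes:

    import java.util.List;

    public class WeightedCentroid {
        // Weighted mean of the assigned points: (sum_i w_i * x_i) / (sum_i w_i).
        static double[] weightedCentroid(double[][] points, double[] weights, List<Integer> assigned, int dim) {
            double[] centroid = new double[dim];
            double weightSum = 0.0;
            for (int idx : assigned) {
                for (int d = 0; d < dim; d++) {
                    centroid[d] += weights[idx] * points[idx][d];
                }
                weightSum += weights[idx];
            }
            if (weightSum != 0.0) { // an empty cluster keeps an all-zero centroid
                for (int d = 0; d < dim; d++) {
                    centroid[d] /= weightSum;
                }
            }
            return centroid;
        }

        public static void main(String[] args) {
            double[][] points = {{0.0, 0.0}, {2.0, 2.0}};
            double[] weights = {1.0, 3.0};
            // The weighted mean is pulled towards the heavier point: prints [1.5, 1.5].
            System.out.println(java.util.Arrays.toString(weightedCentroid(points, weights, List.of(0, 1), 2)));
        }
    }
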
@@ -483,11 +496,25 @@ public TrainerProvenance getProvenance() {
     */
    static class IntAndVector {
        final int idx;
-        final SparseVector vector;
+        final SGDVector vector;

-        public IntAndVector(int idx, SparseVector vector) {
+        /**
+         * Constructs an index and vector tuple.
+         * @param idx The index.
+         * @param vector The vector.
+         */
+        public IntAndVector(int idx, SGDVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }
+
+    /**
+     * Used to allow FJPs to work with OpenSearch's SecureSM.
+     */
+    private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
+        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});
+        }
+    }
}
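
One subtlety in the factory above: ForkJoinWorkerThread's constructor is protected, which is why newThread instantiates an anonymous subclass (the trailing {}) rather than the class itself, and the doPrivileged wrapper ensures the permission checks triggered by thread creation run with Tribuo's own privileges instead of the caller's.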