diff --git a/Core/src/test/java/org/tribuo/util/UtilTest.java b/Core/src/test/java/org/tribuo/util/UtilTest.java index 97dbeea95..2bd1457a7 100644 --- a/Core/src/test/java/org/tribuo/util/UtilTest.java +++ b/Core/src/test/java/org/tribuo/util/UtilTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,15 +17,22 @@ package org.tribuo.util; +import com.oracle.labs.mlrg.olcut.util.MutableLong; import com.oracle.labs.mlrg.olcut.util.Pair; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.SplittableRandom; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -80,4 +87,43 @@ public void testAUC() { assertEquals(0.5,output,DELTA); } + @Test + public void testSampleFromCDF() { + double[] pmf = new double[]{0.1,0.2,0.0,0.3,0.0,0.0,0.4,0.0}; + double[] cdf = Util.generateCDF(pmf); + + double[] expectedCDF = new double[]{0.1,0.3,0.3,0.6,0.6,0.6,1.0,1.0}; + + assertArrayEquals(expectedCDF,cdf,1e-10); + + SplittableRandom rng = new SplittableRandom(1235L); + + Map counter = new HashMap<>(); + + final int numSamples = 10000; + for (int i = 0; i < numSamples; i++) { + int curSample = Util.sampleFromCDF(cdf,rng); + MutableLong l = counter.computeIfAbsent(curSample, k -> new MutableLong()); + l.increment(); + } + + assertNotNull(counter.get(0)); + assertNotNull(counter.get(1)); + assertNull(counter.get(2)); + assertNotNull(counter.get(3)); + assertNull(counter.get(4)); + assertNull(counter.get(5)); + assertNotNull(counter.get(6)); + assertNull(counter.get(7)); + + double total = 0; + for (Map.Entry e : counter.entrySet()) { + total += e.getValue().longValue(); + } + assertEquals(numSamples,total); + assertEquals(counter.get(0).longValue()/total,0.1,1e-1); + assertEquals(counter.get(1).longValue()/total,0.2,1e-1); + assertEquals(counter.get(3).longValue()/total,0.3,1e-1); + assertEquals(counter.get(6).longValue()/total,0.4,1e-1); + } }