diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 1290465ea7c8..db861904c339 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -439,14 +439,6 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath, nativeJsonModelPath)) - // test default "deprecated" - val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath - model.write.save(modelUbjPath) - val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath - model.nativeBooster.saveModel(nativeDeprecatedModelPath) - assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath, - nativeDeprecatedModelPath)) - // json file should be indifferent with ubj file val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath model.write.option("format", "json").save(modelJsonPath) @@ -454,6 +446,14 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS model.nativeBooster.saveModel(nativeUbjModelPath) assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath, nativeUbjModelPath)) + + // test default "deprecated" + val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath + model.write.option("format", "deprecated").save(modelUbjPath) + val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath + model.nativeBooster.saveModel(nativeDeprecatedModelPath) + assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath, + nativeDeprecatedModelPath)) } test("native json model file should store feature_name and feature_type") { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 1bdea7a827bd..ea6f063a1054 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -333,14 +333,6 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath, nativeJsonModelPath)) - // test default "deprecated" - val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath - model.write.save(modelUbjPath) - val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath - model.nativeBooster.saveModel(nativeDeprecatedModelPath) - assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath, - nativeDeprecatedModelPath)) - // json file should be indifferent with ubj file val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath model.write.option("format", "json").save(modelJsonPath) @@ -348,6 +340,14 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu model.nativeBooster.saveModel(nativeUbjModelPath) assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath, nativeUbjModelPath)) + + // test default "deprecated" + val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath + model.write.option("format", "deprecated").save(modelUbjPath) + val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath + model.nativeBooster.saveModel(nativeDeprecatedModelPath) + assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath, + nativeDeprecatedModelPath)) } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 23b8b1a80d4c..a9b34f2b07a9 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -15,7 +15,9 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.*; +import java.io.IOException; +import java.io.OutputStream; +import java.io.Serializable; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -33,8 +35,8 @@ /** * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model */ -public class Booster implements Serializable, KryoSerializable { - public static final String DEFAULT_FORMAT = "deprecated"; +public class Booster implements Serializable, KryoSerializable, AutoCloseable { + public static final String DEFAULT_FORMAT = "ubj"; private static final Log logger = LogFactory.getLog(Booster.class); // handle to the booster. private long handle = 0; @@ -815,6 +817,11 @@ protected void finalize() throws Throwable { dispose(); } + @Override + public void close() throws XGBoostError { + dispose(); + } + public synchronized void dispose() { if (handle != 0L) { XGBoostJNI.XGBoosterFree(handle); diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index d53c003a416b..cec6fabbb808 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -17,18 +17,20 @@ import java.io.*; import java.util.*; +import java.util.stream.Collectors; -import junit.framework.TestCase; import org.junit.Test; +import static org.junit.Assert.*; + /** * test cases for Booster * * @author hzx */ public class BoosterImplTest { - private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm"; - private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm"; + private final String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm"; + private final String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm"; public static class EvalError implements IEvaluation { @Override @@ -83,20 +85,21 @@ private Booster trainBooster(DMatrix trainMat, DMatrix testMat) throws XGBoostEr } @Test - public void testBoosterBasic() throws XGBoostError, IOException { + public void testBoosterBasic() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); + try (Booster booster = trainBooster(trainMat, testMat)) { - //predict raw output - float[][] predicts = booster.predict(testMat, true, 0); + //predict raw output + float[][] predicts = booster.predict(testMat, true, 0); - //eval - IEvaluation eval = new EvalError(); - //error must be less than 0.1 - TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f); + //eval + IEvaluation eval = new EvalError(); + //error must be less than 0.1 + assertTrue(eval.eval(predicts, testMat) < 0.1f); + } } @Test @@ -105,22 +108,24 @@ public void saveLoadModelWithPath() throws XGBoostError, IOException { DMatrix testMat = new DMatrix(this.test_uri); IEvaluation eval = new EvalError(); - Booster booster = trainBooster(trainMat, testMat); - // save and load - File temp = File.createTempFile("temp", "model"); - temp.deleteOnExit(); - booster.saveModel(temp.getAbsolutePath()); - - Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); - assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj"))); - assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json"))); - assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated"))); - float[][] predicts2 = bst2.predict(testMat, true, 0); - TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f); + try (Booster booster = trainBooster(trainMat, testMat)) { + // save and load + File temp = File.createTempFile("temp", "model"); + temp.deleteOnExit(); + booster.saveModel(temp.getAbsolutePath()); + + try (Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath())) { + for (String format: Arrays.asList("ubj", "json", "deprecated")) { + assertArrayEquals(bst2.toByteArray(format), booster.toByteArray(format)); + } + float[][] predicts2 = bst2.predict(testMat, true, 0); + assertTrue(eval.eval(predicts2, testMat) < 0.1f); + } + } } @Test - public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException { + public void saveLoadModelWithFeaturesWithPath() throws Exception { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); IEvaluation eval = new EvalError(); @@ -136,39 +141,40 @@ public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException trainMat.setFeatureTypes(featureTypes); testMat.setFeatureTypes(featureTypes); - Booster booster = trainBooster(trainMat, testMat); - // save and load, only json format save and load feature_name and feature_type - File temp = File.createTempFile("temp", ".json"); - temp.deleteOnExit(); - booster.saveModel(temp.getAbsolutePath()); - - String modelString = new String(booster.toByteArray("json")); - System.out.println(modelString); - - Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); - assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj"))); - assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json"))); - assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated"))); - float[][] predicts2 = bst2.predict(testMat, true, 0); - TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f); + try(Booster booster = trainBooster(trainMat, testMat)) { + // save and load, only json format save and load feature_name and feature_type + File temp = File.createTempFile("temp", ".json"); + temp.deleteOnExit(); + booster.saveModel(temp.getAbsolutePath()); + + try (Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath())) { + for (String format: Arrays.asList("ubj", "json", "deprecated")) { + assertArrayEquals(bst2.toByteArray(format), booster.toByteArray(format)); + } + float[][] predicts2 = bst2.predict(testMat, true, 0); + assertTrue(eval.eval(predicts2, testMat) < 0.1f); + } + } } @Test - public void saveLoadModelWithStream() throws XGBoostError, IOException { + public void saveLoadModelWithStream() throws Exception { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); - - ByteArrayOutputStream output = new ByteArrayOutputStream(); - booster.saveModel(output); - IEvaluation eval = new EvalError(); - Booster loadedBooster = XGBoost.loadModel(new ByteArrayInputStream(output.toByteArray())); - float originalPredictError = eval.eval(booster.predict(testMat, true), testMat); - TestCase.assertTrue("originalPredictErr:" + originalPredictError, - originalPredictError < 0.1f); - float loadedPredictError = eval.eval(loadedBooster.predict(testMat, true), testMat); - TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f); + try (Booster booster = trainBooster(trainMat, testMat)) { + + ByteArrayOutputStream output = new ByteArrayOutputStream(); + booster.saveModel(output); + IEvaluation eval = new EvalError(); + try (Booster loadedBooster = XGBoost.loadModel(new ByteArrayInputStream(output.toByteArray()))) { + float originalPredictError = eval.eval(booster.predict(testMat, true), testMat); + assertTrue("originalPredictErr:" + originalPredictError, + originalPredictError < 0.1f); + float loadedPredictError = eval.eval(loadedBooster.predict(testMat, true), testMat); + assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f); + } + } } private static class IncreasingEval implements IEvaluation { @@ -199,9 +205,9 @@ public void testDescendMetricsWithBoundaryCondition() { for (int itr = 0; itr < totalIterations; itr++) { boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration); if (itr == totalIterations - 1) { - TestCase.assertTrue(es); + assertTrue(es); } else { - TestCase.assertFalse(es); + assertFalse(es); } } } @@ -224,7 +230,7 @@ public void testEarlyStoppingForMultipleMetrics() { for (int i = 0; i < totalIterations; i++) { bestIteration = i; boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration); - TestCase.assertFalse(es); + assertFalse(es); } // when we have multiple datasets, only the last one was used to determinate early stop @@ -235,7 +241,7 @@ public void testEarlyStoppingForMultipleMetrics() { for (int i = 0; i < totalIterations; i++) { bestIteration = i; boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration); - TestCase.assertFalse(es); + assertFalse(es); } // Now assign metric values to the last dataset. @@ -248,9 +254,9 @@ public void testEarlyStoppingForMultipleMetrics() { // if any metrics off, we need to stop boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration); if (i >= earlyStoppingRound) { - TestCase.assertTrue(es); + assertTrue(es); } else { - TestCase.assertFalse(es); + assertFalse(es); } } } @@ -267,14 +273,14 @@ public void testDescendMetrics() { int bestIteration = 0; boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); - TestCase.assertTrue(es); + assertTrue(es); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = totalIterations - i; } bestIteration = totalIterations - 1; es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); - TestCase.assertFalse(es); + assertFalse(es); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = totalIterations - i; @@ -285,7 +291,7 @@ public void testDescendMetrics() { bestIteration = 4; es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); - TestCase.assertTrue(es); + assertTrue(es); } @Test @@ -302,9 +308,9 @@ public void testAscendMetricsWithBoundaryCondition() { for (int itr = 0; itr < totalIterations; itr++) { boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration); if (itr == totalIterations - 1) { - TestCase.assertTrue(es); + assertTrue(es); } else { - TestCase.assertFalse(es); + assertFalse(es); } } } @@ -321,14 +327,14 @@ public void testAscendMetrics() { int bestIteration = 0; boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); - TestCase.assertTrue(es); + assertTrue(es); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } bestIteration = totalIterations - 1; es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); - TestCase.assertFalse(es); + assertFalse(es); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; @@ -339,11 +345,11 @@ public void testAscendMetrics() { bestIteration = 4; es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); - TestCase.assertTrue(es); + assertTrue(es); } @Test - public void testBoosterEarlyStop() throws XGBoostError, IOException { + public void testBoosterEarlyStop() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { @@ -361,25 +367,27 @@ public void testBoosterEarlyStop() throws XGBoostError, IOException { final int round = 10; int earlyStoppingRound = 2; float[][] metrics = new float[watches.size()][round]; - XGBoost.train(trainMat, paramMap, round, watches, metrics, null, new IncreasingEval(), - earlyStoppingRound); + try (Booster b = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, new IncreasingEval(), + earlyStoppingRound)) { + // Not needed. + } // Make sure we've stopped early. for (int w = 0; w < watches.size(); w++) { for (int r = 0; r <= earlyStoppingRound; r++) { - TestCase.assertFalse(0.0f == metrics[w][r]); + assertNotEquals(0.0f, metrics[w][r], 0.0); } } for (int w = 0; w < watches.size(); w++) { for (int r = earlyStoppingRound + 1; r < round; r++) { - TestCase.assertEquals(0.0f, metrics[w][r]); + assertEquals(0.0f, metrics[w][r], 0.0); } } } @Test - public void testEarlyStoppingAttributes() throws XGBoostError, IOException { + public void testEarlyStoppingAttributes() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { @@ -397,30 +405,31 @@ public void testEarlyStoppingAttributes() throws XGBoostError, IOException { int earlyStoppingRound = 4; float[][] metrics = new float[watches.size()][round]; - Booster booster = XGBoost.train(trainMat, paramMap, round, - watches, metrics, null, null, earlyStoppingRound); - - int bestIter = Integer.valueOf(booster.getAttr("best_iteration")); - float bestScore = Float.valueOf(booster.getAttr("best_score")); - TestCase.assertEquals(bestIter, round - 1); - TestCase.assertEquals(bestScore, metrics[watches.size() - 1][round - 1]); + try (Booster booster = XGBoost.train(trainMat, paramMap, round, + watches, metrics, null, null, earlyStoppingRound)) { + int bestIter = Integer.parseInt(booster.getAttr("best_iteration")); + float bestScore = Float.parseFloat(booster.getAttr("best_score")); + assertEquals(bestIter, round - 1); + assertEquals(bestScore, metrics[watches.size() - 1][round - 1], 0.0); + } } private void testWithQuantileHisto(DMatrix trainingSet, Map watches, int round, Map paramMap, float threshold) throws XGBoostError { float[][] metrics = new float[watches.size()][round]; - Booster booster = XGBoost.train(trainingSet, paramMap, round, watches, - metrics, null, null, 0); - for (int i = 0; i < metrics.length; i++) - for (int j = 1; j < metrics[i].length; j++) { - TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1] || - Math.abs(metrics[i][j] - metrics[i][j - 1]) < 0.1); - } - for (int i = 0; i < metrics.length; i++) - for (int j = 0; j < metrics[i].length; j++) { - TestCase.assertTrue(metrics[i][j] >= threshold); + try(Booster booster = XGBoost.train(trainingSet, paramMap, round, watches, + metrics, null, null, 0)) { + for (float[] metric : metrics) + for (int j = 1; j < metric.length; j++) { + assertTrue(metric[j] >= metric[j - 1] || + Math.abs(metric[j] - metric[j - 1]) < 0.1); + } + for (float[] metric : metrics) { + for (float v : metric) { + assertTrue(v >= threshold); + } } - booster.dispose(); + } } @Test @@ -490,15 +499,19 @@ public void testDumpModelJson() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); - String[] dump = booster.getModelDump("", false, "json"); - TestCase.assertEquals(" { \"nodeid\":", dump[0].substring(0, 13)); + String[] dump; + try (Booster booster = trainBooster(trainMat, testMat)) { + dump = booster.getModelDump("", false, "json"); + assertEquals(" { \"nodeid\":", dump[0].substring(0, 13)); - // test with specified feature names - String[] featureNames = new String[126]; - for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; - dump = booster.getModelDump(featureNames, false, "json"); - TestCase.assertTrue(dump[0].contains("test_feature_name_")); + // test with specified feature names + String[] featureNames = new String[126]; + for (int i = 0; i < 126; i++) { + featureNames[i] = "test_feature_name_" + i; + } + dump = booster.getModelDump(featureNames, false, "json"); + } + assertTrue(dump[0].contains("test_feature_name_")); } @Test @@ -506,59 +519,39 @@ public void testGetFeatureScore() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); + Map scoreMap; String[] featureNames = new String[126]; - for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; - Map scoreMap = booster.getFeatureScore(featureNames); - for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); - } - - @Test - public void testGetFeatureImportanceGain() throws XGBoostError { - DMatrix trainMat = new DMatrix(this.train_uri); - DMatrix testMat = new DMatrix(this.test_uri); - - Booster booster = trainBooster(trainMat, testMat); - String[] featureNames = new String[126]; - for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; - Map scoreMap = booster.getScore(featureNames, "gain"); - for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); - } - - @Test - public void testGetFeatureImportanceTotalGain() throws XGBoostError { - DMatrix trainMat = new DMatrix(this.train_uri); - DMatrix testMat = new DMatrix(this.test_uri); - - Booster booster = trainBooster(trainMat, testMat); - String[] featureNames = new String[126]; - for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; - Map scoreMap = booster.getScore(featureNames, "total_gain"); - for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); - } - - @Test - public void testGetFeatureImportanceCover() throws XGBoostError { - DMatrix trainMat = new DMatrix(this.train_uri); - DMatrix testMat = new DMatrix(this.test_uri); + try (Booster booster = trainBooster(trainMat, testMat)) { + for (int i = 0; i < 126; i++) { + featureNames[i] = "test_feature_name_" + i; + } + scoreMap = booster.getFeatureScore(featureNames); + } - Booster booster = trainBooster(trainMat, testMat); - String[] featureNames = new String[126]; - for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; - Map scoreMap = booster.getScore(featureNames, "cover"); - for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); + for (String fName: scoreMap.keySet()) { + assertTrue(fName.startsWith("test_feature_name_")); + assertTrue(Arrays.stream(featureNames).anyMatch(v -> v.equalsIgnoreCase(fName))); + } } @Test - public void testGetFeatureImportanceTotalCover() throws XGBoostError { + public void testGetFeatureImportance() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); - String[] featureNames = new String[126]; - for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; - Map scoreMap = booster.getScore(featureNames, "total_cover"); - for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); + try (Booster booster = trainBooster(trainMat, testMat)) { + String[] featureNames = new String[126]; + for (int i = 0; i < 126; i++) { + featureNames[i] = "test_feature_name_" + i; + } + for (String importanceType: Arrays.asList("gain", "total_gain", "cover", "total_cover")) { + Map scoreMap = booster.getScore(featureNames, importanceType); + for (String fName: scoreMap.keySet()) { + assertTrue(fName.startsWith("test_feature_name_")); + assertTrue(Arrays.stream(featureNames).anyMatch(v -> v.equalsIgnoreCase(fName))); + } + } + } } @Test @@ -675,8 +668,8 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException { round = 4; Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster); float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat); - TestCase.assertTrue(booster1error == booster2error); - TestCase.assertTrue(tempBoosterError > booster2error); + assertEquals(booster1error, booster2error, 0.0); + assertTrue(tempBoosterError > booster2error); } /** @@ -689,24 +682,26 @@ public void testSetAndGetAttrs() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); - booster.setAttr("testKey1", "testValue1"); - TestCase.assertEquals(booster.getAttr("testKey1"), "testValue1"); - booster.setAttr("testKey1", "testValue2"); - TestCase.assertEquals(booster.getAttr("testKey1"), "testValue2"); - - booster.setAttrs(new HashMap(){{ - put("aa", "AA"); - put("bb", "BB"); - put("cc", "CC"); - }}); - - Map attr = booster.getAttrs(); - TestCase.assertEquals(attr.size(), 6); - TestCase.assertEquals(attr.get("testKey1"), "testValue2"); - TestCase.assertEquals(attr.get("aa"), "AA"); - TestCase.assertEquals(attr.get("bb"), "BB"); - TestCase.assertEquals(attr.get("cc"), "CC"); + Map attr; + try (Booster booster = trainBooster(trainMat, testMat)) { + booster.setAttr("testKey1", "testValue1"); + assertEquals(booster.getAttr("testKey1"), "testValue1"); + booster.setAttr("testKey1", "testValue2"); + assertEquals(booster.getAttr("testKey1"), "testValue2"); + + booster.setAttrs(new HashMap() {{ + put("aa", "AA"); + put("bb", "BB"); + put("cc", "CC"); + }}); + + attr = booster.getAttrs(); + } + assertEquals(attr.size(), 6); + assertEquals(attr.get("testKey1"), "testValue2"); + assertEquals(attr.get("aa"), "AA"); + assertEquals(attr.get("bb"), "BB"); + assertEquals(attr.get("cc"), "CC"); } /** @@ -719,7 +714,8 @@ public void testGetNumFeature() throws XGBoostError { DMatrix trainMat = new DMatrix(this.train_uri); DMatrix testMat = new DMatrix(this.test_uri); - Booster booster = trainBooster(trainMat, testMat); - TestCase.assertEquals(booster.getNumFeature(), 126); + try (Booster booster = trainBooster(trainMat, testMat)) { + assertEquals(booster.getNumFeature(), 126); + } } }