diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala index 748db13..4894db0 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala @@ -128,7 +128,7 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores }).collect() val isolationForestModel = copyValues( - new IsolationForestModel(uid, isolationTrees, numSamples).setParent(this)) + new IsolationForestModel(uid, isolationTrees, numSamples, numFeatures).setParent(this)) // Determine and set the model threshold based upon the specified contamination and // contaminationError parameters. diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala index 9d425dc..30c8c29 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala @@ -16,16 +16,24 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} * * @param uid The immutable unique ID for the model. * @param isolationTrees The array of isolation tree models that compose the isolation forest. + * @param numSamples The number of samples used to train each tree. + * @param numFeatures The user-specified number of features used to train each isolation tree. For certain edge cases, + * a given isolation tree may not have any nodes using some of these features, e.g., a shallow tree + * where the number of features in the training data exceeds the number of nodes in the tree. */ class IsolationForestModel( override val uid: String, val isolationTrees: Array[IsolationTree], - private val numSamples: Int) + private val numSamples: Int, + private val numFeatures: Int) extends Model[IsolationForestModel] with IsolationForestParams with MLWritable { require(numSamples > 0, s"parameter numSamples must be >0, but given invalid value ${numSamples}") final def getNumSamples: Int = numSamples + require(numFeatures > 0, s"parameter numFeatures must be >0, but given invalid value ${numFeatures}") + final def getNumFeatures: Int = numFeatures + // The outlierScoreThreshold needs to be a mutable variable because it is not known when an // IsolationForestModel instance is created. private var outlierScoreThreshold: Double = -1 @@ -40,7 +48,7 @@ class IsolationForestModel( override def copy(extra: ParamMap): IsolationForestModel = { - val isolationForestCopy = new IsolationForestModel(uid, isolationTrees, numSamples) + val isolationForestCopy = new IsolationForestModel(uid, isolationTrees, numSamples, numFeatures) .setParent(this.parent) isolationForestCopy.setOutlierScoreThreshold(outlierScoreThreshold) copyValues(isolationForestCopy, extra) diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModelReadWrite.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModelReadWrite.scala index 49ef480..0e2a687 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModelReadWrite.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModelReadWrite.scala @@ -49,6 +49,7 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi implicit val format = DefaultFormats val (metadata, treesData) = loadImpl(path, sparkSession) val numSamples = (metadata.metadata \ "numSamples").extract[Int] + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val outlierScoreThreshold = (metadata.metadata \ "outlierScoreThreshold").extract[Double] val trees = treesData.map { @@ -56,7 +57,7 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi case externalNode: ExternalNode => new IsolationTree(externalNode.asInstanceOf[ExternalNode]) } - val model = new IsolationForestModel(metadata.uid, trees, numSamples) + val model = new IsolationForestModel(metadata.uid, trees, numSamples, numFeatures) metadata.setParams(model) model.setOutlierScoreThreshold(outlierScoreThreshold) @@ -237,7 +238,8 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi val extraMetadata: JObject = ("outlierScoreThreshold", instance.getOutlierScoreThreshold) ~ - ("numSamples", instance.getNumSamples) + ("numSamples", instance.getNumSamples) ~ + ("numFeatures", instance.getNumFeatures) saveImplHelper(path, sparkSession, extraMetadata) } @@ -246,7 +248,7 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi * * @param path The path on disk used to save the ensemble model. * @param spark The SparkSession instance to use. - * @param extraMetadata Metadata such as outlierScoreThreshold and numSamples. + * @param extraMetadata Metadata such as outlierScoreThreshold, numSamples, and numFeatures. */ private def saveImplHelper(path: String, spark: SparkSession, extraMetadata: JObject): Unit = { diff --git a/isolation-forest/src/test/resources/savedIsolationForestModel/metadata/part-00000 b/isolation-forest/src/test/resources/savedIsolationForestModel/metadata/part-00000 index bce5d8a..e69ebb9 100644 --- a/isolation-forest/src/test/resources/savedIsolationForestModel/metadata/part-00000 +++ b/isolation-forest/src/test/resources/savedIsolationForestModel/metadata/part-00000 @@ -1 +1 @@ -{"class":"com.linkedin.relevance.isolationforest.IsolationForestModel","timestamp":1544084998332,"sparkVersion":"2.3.0.89","uid":"isolation-forest_746c9083c2c1","paramMap":{"predictionCol":"predictedLabel","maxFeatures":1.0,"scoreCol":"outlierScore","maxSamples":256.0,"randomSeed":1,"bootstrap":false,"contamination":0.02,"featuresCol":"features","numEstimators":100},"outlierScoreThreshold":0.6015323679815825,"numSamples":256} +{"class":"com.linkedin.relevance.isolationforest.IsolationForestModel","timestamp":1544084998332,"sparkVersion":"2.3.0.89","uid":"isolation-forest_746c9083c2c1","paramMap":{"predictionCol":"predictedLabel","maxFeatures":1.0,"scoreCol":"outlierScore","maxSamples":256.0,"randomSeed":1,"bootstrap":false,"contamination":0.02,"featuresCol":"features","numEstimators":100},"outlierScoreThreshold":0.6015323679815825,"numSamples":256,"numFeatures":6} diff --git a/isolation-forest/src/test/scala/com/linkedin/relevance/isolationforest/IsolationForestModelWriteReadTest.scala b/isolation-forest/src/test/scala/com/linkedin/relevance/isolationforest/IsolationForestModelWriteReadTest.scala index bb068ad..3127712 100644 --- a/isolation-forest/src/test/scala/com/linkedin/relevance/isolationforest/IsolationForestModelWriteReadTest.scala +++ b/isolation-forest/src/test/scala/com/linkedin/relevance/isolationforest/IsolationForestModelWriteReadTest.scala @@ -47,6 +47,7 @@ class IsolationForestModelWriteReadTest extends Logging { isolationForestModel1.extractParamMap.toString, isolationForestModel2.extractParamMap.toString) Assert.assertEquals(isolationForestModel1.getNumSamples, isolationForestModel2.getNumSamples) + Assert.assertEquals(isolationForestModel1.getNumFeatures, isolationForestModel2.getNumFeatures) Assert.assertEquals( isolationForestModel1.getOutlierScoreThreshold, isolationForestModel2.getOutlierScoreThreshold) @@ -110,6 +111,7 @@ class IsolationForestModelWriteReadTest extends Logging { isolationForestModel1.extractParamMap.toString, isolationForestModel2.extractParamMap.toString) Assert.assertEquals(isolationForestModel1.getNumSamples, isolationForestModel2.getNumSamples) + Assert.assertEquals(isolationForestModel1.getNumFeatures, isolationForestModel2.getNumFeatures) Assert.assertEquals( isolationForestModel1.getOutlierScoreThreshold, isolationForestModel2.getOutlierScoreThreshold) @@ -207,7 +209,7 @@ class IsolationForestModelWriteReadTest extends Logging { val spark = getSparkSession // Create an isolation forest model with no isolation trees - val isolationForestModel1 = new IsolationForestModel("testUid", Array(), 1) + val isolationForestModel1 = new IsolationForestModel("testUid", Array(), numSamples = 1, numFeatures = 2) isolationForestModel1.setOutlierScoreThreshold(0.5) // Write the trained model to disk and then read it back from disk @@ -221,6 +223,7 @@ class IsolationForestModelWriteReadTest extends Logging { isolationForestModel1.extractParamMap.toString, isolationForestModel2.extractParamMap.toString) Assert.assertEquals(isolationForestModel1.getNumSamples, isolationForestModel2.getNumSamples) + Assert.assertEquals(isolationForestModel1.getNumFeatures, isolationForestModel2.getNumFeatures) Assert.assertEquals( isolationForestModel1.getOutlierScoreThreshold, isolationForestModel2.getOutlierScoreThreshold) diff --git a/version.properties b/version.properties index 27558af..0af5cf0 100644 --- a/version.properties +++ b/version.properties @@ -1,3 +1,3 @@ # Version of the produced binaries. # The version is inferred by shipkit-auto-version Gradle plugin (https://github.com/shipkit/shipkit-auto-version). -version=2.0.* +version=3.0.*