From 3381e68d56b66562ba839e0beac5f1bb418e9759 Mon Sep 17 00:00:00 2001 From: NickEdwards7502 Date: Thu, 19 Sep 2024 14:23:06 +1000 Subject: [PATCH] STYLE: Format with scalamft (#237) --- .../csiro/variantspark/api/GetRFModel.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala b/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala index a656a798..a4a22f95 100644 --- a/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala +++ b/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala @@ -1,22 +1,24 @@ package au.csiro.variantspark.api -import au.csiro.variantspark.algo.{ - RandomForest, - RandomForestModel, - RandomForestParams -} +import au.csiro.variantspark.algo.{RandomForest, RandomForestModel, RandomForestParams} import au.csiro.variantspark.input.{FeatureSource, LabelSource} /** Passes a trained random forest model back to the python wrapper */ object RFModelTrainer { - def trainModel( - featureSource: FeatureSource, - labelSource: LabelSource, - params: RandomForestParams, - nTrees: Int, - rfBatchSize: Int - ): RandomForestModel = { + + /** Trains a random forest model with provided data and parameters + * + * @param featureSource: FeatureSource object containing training X + * @param labelSource: LabelSource object containing training y + * @param params: Random forest hyperparameters (passed to model on initialisation) + * @param nTrees: Number of trees to compute (passed to model during training) + * @param rfBatchSize: Number of trees per batch (passed to model during training) + * + * @return Trained random forest model + */ + def trainModel(featureSource: FeatureSource, labelSource: LabelSource, + params: RandomForestParams, nTrees: Int, rfBatchSize: Int): RandomForestModel = { val labels = labelSource.getLabels(featureSource.sampleNames) lazy val inputData = featureSource.features.zipWithIndex.cache() val rf = new RandomForest(params)