From 0fc736f3ffcb29e6188e34523bbf76fde8769e2b Mon Sep 17 00:00:00 2001
From: NickEdwards7502
Date: Wed, 11 Sep 2024 15:37:08 +1000
Subject: [PATCH] DEV: Created scala function that trains a forest and passes back to python context (#237)

---
 .../csiro/variantspark/api/GetRFModel.scala   | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 src/main/scala/au/csiro/variantspark/api/GetRFModel.scala

diff --git a/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala b/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala
new file mode 100644
index 00000000..a656a798
--- /dev/null
+++ b/src/main/scala/au/csiro/variantspark/api/GetRFModel.scala
@@ -0,0 +1,42 @@
+package au.csiro.variantspark.api
+
+import au.csiro.variantspark.algo.{
+  RandomForest,
+  RandomForestModel,
+  RandomForestParams
+}
+import au.csiro.variantspark.input.{FeatureSource, LabelSource}
+
+/** Trains a random forest and passes the trained model back to the caller
+  * (the python wrapper).
+  */
+object RFModelTrainer {
+
+  /** Trains a random forest on the features and labels of the given sources.
+    *
+    * @param featureSource provides the features and the sample names used to look up labels
+    * @param labelSource   provides the labels for the sample names of `featureSource`
+    * @param params        parameters the [[RandomForest]] is constructed with
+    * @param nTrees        forwarded to `RandomForest.batchTrain` (presumably the tree count)
+    * @param rfBatchSize   forwarded to `RandomForest.batchTrain` (presumably trees per batch)
+    * @return the model returned by `RandomForest.batchTrain`
+    */
+  def trainModel(
+      featureSource: FeatureSource,
+      labelSource: LabelSource,
+      params: RandomForestParams,
+      nTrees: Int,
+      rfBatchSize: Int
+  ): RandomForestModel = {
+    val labels = labelSource.getLabels(featureSource.sampleNames)
+    // Plain val: the data is used immediately, so deferring it with `lazy` buys nothing.
+    val inputData = featureSource.features.zipWithIndex.cache()
+    try {
+      new RandomForest(params).batchTrain(inputData, labels, nTrees, rfBatchSize)
+    } finally {
+      // Release the cached blocks once training is done (or has failed). unpersist only
+      // drops the cache; the RDD stays valid and would simply be recomputed if reused.
+      inputData.unpersist()
+    }
+  }
+}