Appends prediction columns to transform schema
Marcus-Rosti committed Nov 27, 2024
1 parent 42a1f1a · commit f61607b
Showing 1 changed file with 11 additions and 2 deletions.
@@ -7,7 +7,7 @@
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.ml.Estimator
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{HashPartitioner, TaskContext}

@@ -200,7 +200,16 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores
     require(schema($(featuresCol)).dataType == VectorType,
       s"Input column ${$(featuresCol)} is not of required type ${VectorType}")
 
-    val outputFields = schema.fields
+    val outputFields: Array[StructField] = schema.fields ++ Array(
+      StructField(
+        name = s"$predictionCol",
+        dataType = DoubleType
+      ),
+      StructField(
+        name = s"$scoreCol",
+        dataType = DoubleType
+      )
+    )
 
     StructType(outputFields)
   }
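For context, a minimal sketch of what this change enables: transformSchema now advertises the prediction and score columns before any fitting or transforming happens, so Spark ML Pipeline schema validation can see them up front. The no-arg constructor, the example field name "features", and the printed output are assumptions for illustration, not taken from the diff.

// Illustrative sketch only; assumes the IsolationForest estimator from this file,
// Spark ML's public transformSchema contract, and a default (no-arg) constructor.
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types.{StructField, StructType}

// A hypothetical input schema with a single vector-valued features column.
val inputSchema = StructType(Seq(StructField("features", VectorType)))

val isolationForest = new IsolationForest()
val outputSchema = isolationForest.transformSchema(inputSchema)

// Before this commit, outputSchema matched inputSchema exactly; after it, two
// DoubleType fields (the prediction and score columns) are appended, so
// downstream pipeline stages can validate against them before fit()/transform().
outputSchema.fields.foreach(f => println(s"${f.name}: ${f.dataType}"))

The actual names of the two appended columns come from the estimator's predictionCol and scoreCol params.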
