[SPARK-52191] [ML] [CONNECT] Remove Java deserializer in model local path loader #50922

Open · wants to merge 9 commits into base: master
@@ -302,7 +302,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
- ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts)
+ ReadWriteUtils.saveArray(
+ dataPath, nodeData.toArray, sparkSession, NodeData.serializeData, numDataParts
+ )
}
}
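Across the models touched by this PR the shape of the change is the same: the local save/load path stops going through Java object (de)serialization and instead threads an explicit handler pair into ReadWriteUtils, a serializer of type (Data, DataOutputStream) => Unit and a deserializer of type DataInputStream => Data. The following is a self-contained round-trip sketch of that pattern using a made-up Data case class, not any of the Spark model payloads or the actual ReadWriteUtils helpers:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Hypothetical payload, standing in for the per-model Data case classes in this PR.
case class Data(intercept: Double, coefficients: Array[Double])

// Explicit writer: every field is written in a fixed, documented order.
def serializeData(data: Data, dos: DataOutputStream): Unit = {
  dos.writeDouble(data.intercept)
  dos.writeInt(data.coefficients.length)
  data.coefficients.foreach(dos.writeDouble)
}

// Matching reader: consumes fields in exactly the order the writer produced them.
def deserializeData(dis: DataInputStream): Data = {
  val intercept = dis.readDouble()
  val n = dis.readInt()
  Data(intercept, Array.fill(n)(dis.readDouble()))
}

// Round trip through an in-memory buffer.
val bos = new ByteArrayOutputStream()
serializeData(Data(1.5, Array(0.1, 0.2)), new DataOutputStream(bos))
val restored = deserializeData(new DataInputStream(new ByteArrayInputStream(bos.toByteArray)))
assert(restored.intercept == 1.5 && restored.coefficients.sameElements(Array(0.1, 0.2)))
```

A format written this way does not depend on Java serialization internals, which is presumably the point of removing the Java deserializer from the local path loader.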

@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

import java.io.{DataInputStream, DataOutputStream}

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
@@ -351,6 +353,21 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
factors: Matrix
)

private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
dos.writeDouble(data.intercept)
serializeVector(data.linear, dos)
serializeMatrix(data.factors, dos)
}

private[ml] def deserializeData(dis: DataInputStream): Data = {
import ReadWriteUtils._
val intercept = dis.readDouble()
val linear = deserializeVector(dis)
val factors = deserializeMatrix(dis)
Data(intercept, linear, factors)
}

@Since("3.0.0")
override def read: MLReader[FMClassificationModel] = new FMClassificationModelReader

@@ -365,7 +382,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}

@@ -377,7 +394,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString

- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new FMClassificationModel(
metadata.uid, data.intercept, data.linear, data.factors
)
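The handlers above delegate Vector and Matrix fields to shared helpers imported from ReadWriteUtils (serializeVector, serializeMatrix and their readers). As a rough illustration of what an explicit Vector codec of this kind involves, and not the actual Spark implementation or wire layout, a dense/sparse-tagged encoding could look like this:

```scala
import java.io.{DataInputStream, DataOutputStream, IOException}

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

// Illustrative Vector codec: a one-byte tag distinguishes dense from sparse,
// followed by length-prefixed primitive arrays. Spark's ReadWriteUtils helpers
// may use a different layout.
def writeVector(v: Vector, dos: DataOutputStream): Unit = v match {
  case dv: DenseVector =>
    dos.writeByte(0)
    dos.writeInt(dv.size)
    dv.values.foreach(dos.writeDouble)
  case sv: SparseVector =>
    dos.writeByte(1)
    dos.writeInt(sv.size)
    dos.writeInt(sv.indices.length)
    sv.indices.foreach(dos.writeInt)
    sv.values.foreach(dos.writeDouble)
}

def readVector(dis: DataInputStream): Vector = dis.readByte().toInt match {
  case 0 =>
    val size = dis.readInt()
    Vectors.dense(Array.fill(size)(dis.readDouble()))
  case 1 =>
    val size = dis.readInt()
    val nnz = dis.readInt()
    val indices = Array.fill(nnz)(dis.readInt())
    val values = Array.fill(nnz)(dis.readDouble())
    Vectors.sparse(size, indices, values)
  case other =>
    throw new IOException(s"Unknown vector tag: $other")
}
```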
@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

import java.io.{DataInputStream, DataOutputStream}

import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
@@ -449,6 +451,19 @@ class LinearSVCModel private[classification] (
object LinearSVCModel extends MLReadable[LinearSVCModel] {
private[ml] case class Data(coefficients: Vector, intercept: Double)

private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
serializeVector(data.coefficients, dos)
dos.writeDouble(data.intercept)
}

private[ml] def deserializeData(dis: DataInputStream): Data = {
import ReadWriteUtils._
val coefficients = deserializeVector(dis)
val intercept = dis.readDouble()
Data(coefficients, intercept)
}

@Since("2.2.0")
override def read: MLReader[LinearSVCModel] = new LinearSVCReader

@@ -465,7 +480,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}

@@ -477,7 +492,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
override def load(path: String): LinearSVCModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new LinearSVCModel(metadata.uid, data.coefficients, data.intercept)
metadata.getAndSetParams(model)
model
@@ -17,6 +17,7 @@

package org.apache.spark.ml.classification

import java.io.{DataInputStream, DataOutputStream}
import java.util.Locale

import scala.collection.mutable
@@ -1325,6 +1326,25 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
coefficientMatrix: Matrix,
isMultinomial: Boolean)

private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
dos.writeInt(data.numClasses)
dos.writeInt(data.numFeatures)
serializeVector(data.interceptVector, dos)
serializeMatrix(data.coefficientMatrix, dos)
dos.writeBoolean(data.isMultinomial)
}

private[ml] def deserializeData(dis: DataInputStream): Data = {
import ReadWriteUtils._
val numClasses = dis.readInt()
val numFeatures = dis.readInt()
val interceptVector = deserializeVector(dis)
val coefficientMatrix = deserializeMatrix(dis)
val isMultinomial = dis.readBoolean()
Data(numClasses, numFeatures, interceptVector, coefficientMatrix, isMultinomial)
}

@Since("1.6.0")
override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader

@@ -1343,7 +1363,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}

@@ -1372,7 +1392,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
interceptVector, numClasses, isMultinomial = false)
} else {
// 2.1+
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new LogisticRegressionModel(metadata.uid, data.coefficientMatrix, data.interceptVector,
data.numClasses, data.isMultinomial)
}
@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

import java.io.{DataInputStream, DataOutputStream}

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
@@ -370,6 +372,17 @@ object MultilayerPerceptronClassificationModel
extends MLReadable[MultilayerPerceptronClassificationModel] {
private[ml] case class Data(weights: Vector)

private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
serializeVector(data.weights, dos)
}

private[ml] def deserializeData(dis: DataInputStream): Data = {
import ReadWriteUtils._
val weights = deserializeVector(dis)
Data(weights)
}

@Since("2.0.0")
override def read: MLReader[MultilayerPerceptronClassificationModel] =
new MultilayerPerceptronClassificationModelReader
@@ -388,7 +401,7 @@ object MultilayerPerceptronClassificationModel
// Save model data: weights
val data = Data(instance.weights)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}

@@ -411,7 +424,7 @@ object MultilayerPerceptronClassificationModel
val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights)
model.set("layers", layers)
} else {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new MultilayerPerceptronClassificationModel(metadata.uid, data.weights)
}
metadata.getAndSetParams(model)
@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

import java.io.{DataInputStream, DataOutputStream}

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats

@@ -600,6 +602,21 @@ class NaiveBayesModel private[ml] (
object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
private[ml] case class Data(pi: Vector, theta: Matrix, sigma: Matrix)

private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
serializeVector(data.pi, dos)
serializeMatrix(data.theta, dos)
serializeMatrix(data.sigma, dos)
}

private[ml] def deserializeData(dis: DataInputStream): Data = {
import ReadWriteUtils._
val pi = deserializeVector(dis)
val theta = deserializeMatrix(dis)
val sigma = deserializeMatrix(dis)
Data(pi, theta, sigma)
}

@Since("1.6.0")
override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader

@@ -623,7 +640,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
}

val data = Data(instance.pi, instance.theta, instance.sigma)
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}

@@ -647,7 +664,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
.head()
new NaiveBayesModel(metadata.uid, pi, theta, Matrices.zeros(0, 0))
} else {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new NaiveBayesModel(metadata.uid, data.pi, data.theta, data.sigma)
}
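NaiveBayesModel's payload is a Vector plus two Matrix fields, all routed through serializeMatrix / deserializeMatrix. A deliberately simplified sketch of an explicit Matrix codec, covering only the non-transposed dense case (the real helpers also have to handle SparseMatrix and the isTransposed flag), might be:

```scala
import java.io.{DataInputStream, DataOutputStream, IOException}

import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix}

// Simplified illustration of an explicit Matrix codec; not the actual Spark layout.
def writeDenseMatrix(m: Matrix, dos: DataOutputStream): Unit = m match {
  case dm: DenseMatrix if !dm.isTransposed =>
    dos.writeInt(dm.numRows)
    dos.writeInt(dm.numCols)
    dm.values.foreach(dos.writeDouble)       // column-major values
  case other =>
    throw new IOException(s"Unsupported matrix type in this sketch: ${other.getClass}")
}

def readDenseMatrix(dis: DataInputStream): Matrix = {
  val numRows = dis.readInt()
  val numCols = dis.readInt()
  val values = Array.fill(numRows * numCols)(dis.readDouble())
  Matrices.dense(numRows, numCols, values)
}
```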

@@ -17,6 +17,8 @@

package org.apache.spark.ml.clustering

import java.io.{DataInputStream, DataOutputStream}

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
@@ -229,6 +231,25 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
sigmas: Array[OldMatrix]
)

private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
serializeDoubleArray(data.weights, dos)
serializeGenericArray[OldVector](data.mus, dos, (v, dos) => serializeVector(v.asML, dos))
serializeGenericArray[OldMatrix](data.sigmas, dos, (v, dos) => serializeMatrix(v.asML, dos))
}

private[ml] def deserializeData(dis: DataInputStream): Data = {
import ReadWriteUtils._
val weights = deserializeDoubleArray(dis)
val mus = deserializeGenericArray[OldVector](
dis, dis => OldVectors.fromML(deserializeVector(dis))
)
val sigmas = deserializeGenericArray[OldMatrix](
dis, dis => OldMatrices.fromML(deserializeMatrix(dis))
)
Data(weights, mus, sigmas)
}

@Since("2.0.0")
override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader

@@ -249,7 +270,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
val data = Data(weights, mus, sigmas)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}

@@ -264,7 +285,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val dataPath = new Path(path, "data").toString

val data = if (ReadWriteUtils.localSavingModeState.get()) {
- ReadWriteUtils.loadObjectFromLocal(dataPath)
+ ReadWriteUtils.loadObjectFromLocal(dataPath, deserializeData)
} else {
val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
Data(
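GaussianMixtureModel stores arrays of vectors and matrices, so its handlers pass per-element codecs into serializeGenericArray / deserializeGenericArray. A plausible shape for such length-prefixed generic helpers, sketched here under assumed signatures rather than taken from the Spark source, is:

```scala
import java.io.{DataInputStream, DataOutputStream}

import scala.reflect.ClassTag

// Hypothetical generic array codec: a length prefix followed by each element,
// written and read with caller-supplied per-element functions.
def serializeGenericArray[T](arr: Array[T], dos: DataOutputStream,
    writeOne: (T, DataOutputStream) => Unit): Unit = {
  dos.writeInt(arr.length)
  arr.foreach(x => writeOne(x, dos))
}

def deserializeGenericArray[T: ClassTag](dis: DataInputStream,
    readOne: DataInputStream => T): Array[T] = {
  val n = dis.readInt()
  Array.fill(n)(readOne(dis))
}
```

With helpers of this shape, the lambdas in the diff above only need to convert between the old mllib types and the new ml linalg types (asML / fromML) around the element codec.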
24 changes: 22 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,6 +17,8 @@

package org.apache.spark.ml.clustering

import java.io.{DataInputStream, DataOutputStream}

import scala.collection.mutable

import org.apache.hadoop.fs.Path
@@ -215,6 +217,20 @@ class KMeansModel private[ml] (
/** Helper class for storing model data */
private[ml] case class ClusterData(clusterIdx: Int, clusterCenter: Vector)

private[ml] object ClusterData {
private[ml] def serializeData(data: ClusterData, dos: DataOutputStream): Unit = {
import ReadWriteUtils._
dos.writeInt(data.clusterIdx)
serializeVector(data.clusterCenter, dos)
}

private[ml] def deserializeData(dis: DataInputStream): ClusterData = {
import ReadWriteUtils._
val clusterIdx = dis.readInt()
val clusterCenter = deserializeVector(dis)
ClusterData(clusterIdx, clusterCenter)
}
}

/** A writer for KMeans that handles the "internal" (or default) format */
private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
@@ -233,7 +249,9 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi
ClusterData(idx, center)
}
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveArray[ClusterData](dataPath, data, sparkSession)
+ ReadWriteUtils.saveArray[ClusterData](
+ dataPath, data, sparkSession, ClusterData.serializeData
+ )
}
}

@@ -281,7 +299,9 @@ object KMeansModel extends MLReadable[KMeansModel] {
val dataPath = new Path(path, "data").toString

val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
- val data = ReadWriteUtils.loadArray[ClusterData](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadArray[ClusterData](
+ dataPath, sparkSession, ClusterData.deserializeData
+ )
data.sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
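Finally, the local saving mode seen in the GaussianMixtureModel reader (ReadWriteUtils.localSavingModeState, loadObjectFromLocal) is the local path loader named in the PR title, where the Java deserializer previously sat; after this PR the caller hands in deserializeData instead. A minimal sketch of how a local-path writer/reader pair could apply such handlers, with assumed names and signatures rather than the actual Spark utilities, is:

```scala
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream,
  FileInputStream, FileOutputStream}

// Hypothetical local-path I/O in the spirit of ReadWriteUtils.saveObject /
// loadObjectFromLocal: the supplied handler does all of the encoding, so no
// Java object (de)serialization is involved.
def saveObjectToLocal[T](path: String, obj: T,
    serialize: (T, DataOutputStream) => Unit): Unit = {
  val dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(path)))
  try serialize(obj, dos) finally dos.close()
}

def loadObjectFromLocal[T](path: String,
    deserialize: DataInputStream => T): T = {
  val dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
  try deserialize(dis) finally dis.close()
}
```

A caller would then pass the model-specific handlers from the diff above, for example saveObjectToLocal(dataPath, data, serializeData).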