diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala index 5a1908e..7c863be 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala @@ -11,26 +11,22 @@ case class SparkConfig(spark: SparkSession, partitionNum: Int) object SparkConfig { - var spark: SparkSession = _ - - var partitionNum: String = _ - def getSpark(configs: Configs, defaultAppName: String = "algorithm"): SparkConfig = { val sparkConfigs = configs.sparkConfig.map val session = SparkSession.builder .appName(defaultAppName) .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - for (key <- sparkConfigs.keySet) { - session.config(key, sparkConfigs(key)) + sparkConfigs.foreach { case (key, value) => + session.config(key, value) } - partitionNum = sparkConfigs.getOrElse("spark.app.partitionNum", "0") + val partitionNum = sparkConfigs.getOrElse("spark.app.partitionNum", "0") val spark = session.getOrCreate() validate(spark.version, "2.4.*") SparkConfig(spark, partitionNum.toInt) } - def validate(sparkVersion: String, supportedVersions: String*): Unit = { + private def validate(sparkVersion: String, supportedVersions: String*): Unit = { if (sparkVersion != "UNKNOWN" && !supportedVersions.exists(sparkVersion.matches)) { throw new RuntimeException( s"""Your current spark version ${sparkVersion} is not supported by the current NebulaGraph Algorithm.