Commit 3af1e8a

optimize
1 parent 3fc5955 commit 3af1e8a

File tree

6 files changed: +238 −174 lines

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala

Lines changed: 60 additions & 151 deletions
@@ -6,47 +6,10 @@
 package com.vesoft.nebula.algorithm
 
 import com.vesoft.nebula.algorithm.config.Configs.Argument
-import com.vesoft.nebula.algorithm.config.{
-  AlgoConfig,
-  BetweennessConfig,
-  BfsConfig,
-  CcConfig,
-  CoefficientConfig,
-  Configs,
-  DfsConfig,
-  HanpConfig,
-  JaccardConfig,
-  KCoreConfig,
-  LPAConfig,
-  LouvainConfig,
-  Node2vecConfig,
-  PRConfig,
-  ShortestPathConfig,
-  SparkConfig,
-  DegreeStaticConfig
-}
-import com.vesoft.nebula.algorithm.lib.{
-  BetweennessCentralityAlgo,
-  BfsAlgo,
-  ClosenessAlgo,
-  ClusteringCoefficientAlgo,
-  ConnectedComponentsAlgo,
-  DegreeStaticAlgo,
-  DfsAlgo,
-  GraphTriangleCountAlgo,
-  HanpAlgo,
-  JaccardAlgo,
-  KCoreAlgo,
-  LabelPropagationAlgo,
-  LouvainAlgo,
-  Node2vecAlgo,
-  PageRankAlgo,
-  ShortestPathAlgo,
-  StronglyConnectedComponentsAlgo,
-  TriangleCountAlgo
-}
-import com.vesoft.nebula.algorithm.reader.{CsvReader, JsonReader, NebulaReader}
-import com.vesoft.nebula.algorithm.writer.{CsvWriter, NebulaWriter, TextWriter}
+import com.vesoft.nebula.algorithm.config._
+import com.vesoft.nebula.algorithm.lib._
+import com.vesoft.nebula.algorithm.reader.{CsvReader, DataReader, JsonReader, NebulaReader}
+import com.vesoft.nebula.algorithm.writer.{AlgoWriter, CsvWriter, NebulaWriter, TextWriter}
 import org.apache.commons.math3.ode.UnknownParameterException
 import org.apache.log4j.Logger
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -114,26 +77,8 @@ object Main {
   private[this] def createDataSource(spark: SparkSession,
                                      configs: Configs,
                                      partitionNum: String): DataFrame = {
-    val dataSource = configs.dataSourceSinkEntry.source
-    val dataSet: Dataset[Row] = dataSource.toLowerCase match {
-      case "nebula" => {
-        val reader = new NebulaReader(spark, configs, partitionNum)
-        reader.read()
-      }
-      case "nebula-ngql" => {
-        val reader = new NebulaReader(spark, configs, partitionNum)
-        reader.readNgql()
-      }
-      case "csv" => {
-        val reader = new CsvReader(spark, configs, partitionNum)
-        reader.read()
-      }
-      case "json" => {
-        val reader = new JsonReader(spark, configs, partitionNum)
-        reader.read()
-      }
-    }
-    dataSet
+    val dataSource = DataReader.make(configs)
+    dataSource.read(spark, configs, partitionNum)
   }
 
   /**
@@ -149,99 +94,63 @@ object Main {
                                      configs: Configs,
                                      dataSet: DataFrame): DataFrame = {
     val hasWeight = configs.dataSourceSinkEntry.hasWeight
-    val algoResult = {
-      algoName.toLowerCase match {
-        case "pagerank" => {
-          val pageRankConfig = PRConfig.getPRConfig(configs)
-          PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
-        }
-        case "louvain" => {
-          val louvainConfig = LouvainConfig.getLouvainConfig(configs)
-          LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
-        }
-        case "connectedcomponent" => {
-          val ccConfig = CcConfig.getCcConfig(configs)
-          ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
-        }
-        case "labelpropagation" => {
-          val lpaConfig = LPAConfig.getLPAConfig(configs)
-          LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
-        }
-        case "shortestpaths" => {
-          val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
-          ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
-        }
-        case "degreestatic" => {
-          val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
-          DegreeStaticAlgo(spark, dataSet, dsConfig)
-        }
-        case "kcore" => {
-          val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
-          KCoreAlgo(spark, dataSet, kCoreConfig)
-        }
-        case "stronglyconnectedcomponent" => {
-          val ccConfig = CcConfig.getCcConfig(configs)
-          StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
-        }
-        case "betweenness" => {
-          val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
-          BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
-        }
-        case "trianglecount" => {
-          TriangleCountAlgo(spark, dataSet)
-        }
-        case "graphtrianglecount" => {
-          GraphTriangleCountAlgo(spark, dataSet)
-        }
-        case "clusteringcoefficient" => {
-          val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
-          ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
-        }
-        case "closeness" => {
-          ClosenessAlgo(spark, dataSet, hasWeight)
-        }
-        case "hanp" => {
-          val hanpConfig = HanpConfig.getHanpConfig(configs)
-          HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
-        }
-        case "node2vec" => {
-          val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
-          Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
-        }
-        case "bfs" => {
-          val bfsConfig = BfsConfig.getBfsConfig(configs)
-          BfsAlgo(spark, dataSet, bfsConfig)
-        }
-        case "dfs" => {
-          val dfsConfig = DfsConfig.getDfsConfig(configs)
-          DfsAlgo(spark, dataSet, dfsConfig)
-        }
-        case "jaccard" => {
-          val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
-          JaccardAlgo(spark, dataSet, jaccardConfig)
-        }
-        case _ => throw new UnknownParameterException("unknown executeAlgo name.")
-      }
+    AlgorithmType.mapping.getOrElse(algoName.toLowerCase, throw new UnknownParameterException("unknown executeAlgo name.")) match {
+      case AlgorithmType.Bfs =>
+        val bfsConfig = BfsConfig.getBfsConfig(configs)
+        BfsAlgo(spark, dataSet, bfsConfig)
+      case AlgorithmType.Closeness =>
+        ClosenessAlgo(spark, dataSet, hasWeight)
+      case AlgorithmType.ClusteringCoefficient =>
+        val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
+        ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
+      case AlgorithmType.ConnectedComponents =>
+        val ccConfig = CcConfig.getCcConfig(configs)
+        ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
+      case AlgorithmType.DegreeStatic =>
+        val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
+        DegreeStaticAlgo(spark, dataSet, dsConfig)
+      case AlgorithmType.Dfs =>
+        val dfsConfig = DfsConfig.getDfsConfig(configs)
+        DfsAlgo(spark, dataSet, dfsConfig)
+      case AlgorithmType.GraphTriangleCount =>
+        GraphTriangleCountAlgo(spark, dataSet)
+      case AlgorithmType.Hanp =>
+        val hanpConfig = HanpConfig.getHanpConfig(configs)
+        HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
+      case AlgorithmType.Jaccard =>
+        val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
+        JaccardAlgo(spark, dataSet, jaccardConfig)
+      case AlgorithmType.KCore =>
+        val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
+        KCoreAlgo(spark, dataSet, kCoreConfig)
+      case AlgorithmType.LabelPropagation =>
+        val lpaConfig = LPAConfig.getLPAConfig(configs)
+        LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
+      case AlgorithmType.Louvain =>
+        val louvainConfig = LouvainConfig.getLouvainConfig(configs)
+        LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
+      case AlgorithmType.Node2vec =>
+        val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
+        Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
+      case AlgorithmType.PageRank =>
+        val pageRankConfig = PRConfig.getPRConfig(configs)
+        PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
+      case AlgorithmType.ShortestPath =>
+        val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
+        ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
+      case AlgorithmType.StronglyConnectedComponents =>
+        val ccConfig = CcConfig.getCcConfig(configs)
+        StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
+      case AlgorithmType.TriangleCount =>
+        TriangleCountAlgo(spark, dataSet)
+      case AlgorithmType.BetweennessCentrality =>
+        val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
+        BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
     }
-    algoResult
   }
 
   private[this] def saveAlgoResult(algoResult: DataFrame, configs: Configs): Unit = {
-    val dataSink = configs.dataSourceSinkEntry.sink
-    dataSink.toLowerCase match {
-      case "nebula" => {
-        val writer = new NebulaWriter(algoResult, configs)
-        writer.write()
-      }
-      case "csv" => {
-        val writer = new CsvWriter(algoResult, configs)
-        writer.write()
-      }
-      case "text" => {
-        val writer = new TextWriter(algoResult, configs)
-        writer.write()
-      }
-      case _ => throw new UnsupportedOperationException("unsupported data sink")
-    }
+    val writer = AlgoWriter.make(configs)
+    writer.write(algoResult, configs)
   }
 }
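The refactored dispatch is easiest to see in isolation. Below is a minimal sketch, not part of this commit (DispatchCheck and resolve are illustrative names), of how an algorithm name now resolves through the committed AlgorithmType.mapping before the match runs:

import com.vesoft.nebula.algorithm.lib.AlgorithmType
import org.apache.commons.math3.ode.UnknownParameterException

object DispatchCheck {
  // Same resolution executeAlgorithm performs: normalize the name,
  // then look up the sealed-trait case; unknown names throw, as before.
  def resolve(algoName: String): AlgorithmType =
    AlgorithmType.mapping.getOrElse(
      algoName.toLowerCase,
      throw new UnknownParameterException("unknown executeAlgo name."))

  def main(args: Array[String]): Unit = {
    println(resolve("PageRank").stringify) // prints "pagerank"
  }
}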
nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/AlgorithmType.scala

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+package com.vesoft.nebula.algorithm.lib
+
+/**
+ *
+ * @author 梦境迷离
+ * @version 1.0,2023/9/12
+ */
+sealed trait AlgorithmType {
+  self =>
+  def stringify: String = self match {
+    case AlgorithmType.Bfs => "bfs"
+    case AlgorithmType.Closeness => "closeness"
+    case AlgorithmType.ClusteringCoefficient => "clusteringcoefficient"
+    case AlgorithmType.ConnectedComponents => "connectedcomponent"
+    case AlgorithmType.DegreeStatic => "degreestatic"
+    case AlgorithmType.Dfs => "dfs"
+    case AlgorithmType.GraphTriangleCount => "graphtrianglecount"
+    case AlgorithmType.Hanp => "hanp"
+    case AlgorithmType.Jaccard => "jaccard"
+    case AlgorithmType.KCore => "kcore"
+    case AlgorithmType.LabelPropagation => "labelpropagation"
+    case AlgorithmType.Louvain => "louvain"
+    case AlgorithmType.Node2vec => "node2vec"
+    case AlgorithmType.PageRank => "pagerank"
+    case AlgorithmType.ShortestPath => "shortestpaths"
+    case AlgorithmType.StronglyConnectedComponents => "stronglyconnectedcomponent"
+    case AlgorithmType.TriangleCount => "trianglecount"
+    case AlgorithmType.BetweennessCentrality => "betweenness"
+  }
+}
+object AlgorithmType {
+  lazy val mapping: Map[String, AlgorithmType] = Map(
+    Bfs.stringify -> Bfs,
+    Closeness.stringify -> Closeness,
+    ClusteringCoefficient.stringify -> ClusteringCoefficient,
+    ConnectedComponents.stringify -> ConnectedComponents,
+    DegreeStatic.stringify -> DegreeStatic,
+    GraphTriangleCount.stringify -> GraphTriangleCount,
+    Hanp.stringify -> Hanp,
+    Jaccard.stringify -> Jaccard,
+    KCore.stringify -> KCore,
+    LabelPropagation.stringify -> LabelPropagation,
+    Louvain.stringify -> Louvain,
+    Node2vec.stringify -> Node2vec,
+    PageRank.stringify -> PageRank,
+    ShortestPath.stringify -> ShortestPath,
+    StronglyConnectedComponents.stringify -> StronglyConnectedComponents,
+    TriangleCount.stringify -> TriangleCount,
+    BetweennessCentrality.stringify -> BetweennessCentrality
+  )
+  object BetweennessCentrality extends AlgorithmType
+  object Bfs extends AlgorithmType
+  object Closeness extends AlgorithmType
+  object ClusteringCoefficient extends AlgorithmType
+  object ConnectedComponents extends AlgorithmType
+  object DegreeStatic extends AlgorithmType
+  object Dfs extends AlgorithmType
+  object GraphTriangleCount extends AlgorithmType
+  object Hanp extends AlgorithmType
+  object Jaccard extends AlgorithmType
+  object KCore extends AlgorithmType
+  object LabelPropagation extends AlgorithmType
+  object Louvain extends AlgorithmType
+  object Node2vec extends AlgorithmType
+  object PageRank extends AlgorithmType
+  object ShortestPath extends AlgorithmType
+  object StronglyConnectedComponents extends AlgorithmType
+  object TriangleCount extends AlgorithmType
+}
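Note that mapping as committed has no entry for Dfs, even though stringify handles it: a lookup for "dfs" falls through to the getOrElse fallback in Main.executeAlgorithm and throws. A sketch, not in this commit (AlgorithmTypeMapping and all are illustrative names), of deriving the map from one authoritative list so a case cannot silently drop out:

import com.vesoft.nebula.algorithm.lib.AlgorithmType
import com.vesoft.nebula.algorithm.lib.AlgorithmType._

object AlgorithmTypeMapping {
  // One authoritative list of the sealed trait's 18 cases, including Dfs.
  val all: Seq[AlgorithmType] = Seq(
    BetweennessCentrality, Bfs, Closeness, ClusteringCoefficient,
    ConnectedComponents, DegreeStatic, Dfs, GraphTriangleCount, Hanp,
    Jaccard, KCore, LabelPropagation, Louvain, Node2vec, PageRank,
    ShortestPath, StronglyConnectedComponents, TriangleCount
  )

  // stringify is total over the sealed trait, so every case gets a key.
  val mapping: Map[String, AlgorithmType] =
    all.map(t => t.stringify -> t).toMap
}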

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala

Lines changed: 34 additions & 15 deletions
@@ -12,13 +12,27 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 
 import scala.collection.mutable.ListBuffer
 
-abstract class DataReader(spark: SparkSession, configs: Configs) {
-  def read(): DataFrame
+abstract class DataReader {
+  val tpe: ReaderType
+  def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
+}
+object DataReader {
+  def make(configs: Configs): DataReader = {
+    ReaderType.mapping
+      .get(configs.dataSourceSinkEntry.source.toLowerCase)
+      .collect {
+        case ReaderType.json => new JsonReader
+        case ReaderType.nebulaNgql => new NebulaNgqlReader
+        case ReaderType.nebula => new NebulaReader
+        case ReaderType.csv => new CsvReader
+      }
+      .getOrElse(throw new UnsupportedOperationException("unsupported reader"))
+  }
 }
 
-class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String)
-    extends DataReader(spark, configs) {
-  override def read(): DataFrame = {
+class NebulaReader extends DataReader {
+  override val tpe: ReaderType = ReaderType.nebula
+  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
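ReaderType itself is defined in one of the three changed files not rendered on this page. Judging from the usage above, it presumably mirrors AlgorithmType; a sketch under that assumption, with the source-name strings taken from the old createDataSource match:

// Hypothetical sketch of the ReaderType companion referenced above; the
// real definition is in a changed file not shown here. Presumably a
// sealed trait plus a name -> case map, like AlgorithmType.
sealed trait ReaderType {
  def stringify: String = this match {
    case ReaderType.json => "json"
    case ReaderType.nebulaNgql => "nebula-ngql"
    case ReaderType.nebula => "nebula"
    case ReaderType.csv => "csv"
  }
}
object ReaderType {
  lazy val mapping: Map[String, ReaderType] = Map(
    json.stringify -> json,
    nebulaNgql.stringify -> nebulaNgql,
    nebula.stringify -> nebula,
    csv.stringify -> csv
  )
  object json extends ReaderType
  object nebulaNgql extends ReaderType
  object nebula extends ReaderType
  object csv extends ReaderType
}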
@@ -66,7 +80,12 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String)
     dataset
   }
 
-  def readNgql(): DataFrame = {
+}
+final class NebulaNgqlReader extends NebulaReader {
+
+  override val tpe: ReaderType = ReaderType.nebulaNgql
+
+  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress
     val space = configs.nebulaConfig.readConfigEntry.space
@@ -113,11 +132,12 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String)
     }
     dataset
   }
+
 }
 
-class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String)
-    extends DataReader(spark, configs) {
-  override def read(): DataFrame = {
+final class CsvReader extends DataReader {
+  override val tpe: ReaderType = ReaderType.csv
+  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
     val delimiter = configs.localConfigEntry.delimiter
     val header = configs.localConfigEntry.header
     val localPath = configs.localConfigEntry.filePath
@@ -132,7 +152,7 @@ class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String)
     val weight = configs.localConfigEntry.weight
     val src = configs.localConfigEntry.srcId
     val dst = configs.localConfigEntry.dstId
-    if (configs.dataSourceSinkEntry.hasWeight && weight != null && !weight.trim.isEmpty) {
+    if (configs.dataSourceSinkEntry.hasWeight && weight != null && weight.trim.nonEmpty) {
       data.select(src, dst, weight)
     } else {
       data.select(src, dst)
@@ -143,18 +163,17 @@ class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String)
       data
     }
   }
-
-class JsonReader(spark: SparkSession, configs: Configs, partitionNum: String)
-    extends DataReader(spark, configs) {
-  override def read(): DataFrame = {
+final class JsonReader extends DataReader {
+  override val tpe: ReaderType = ReaderType.json
+  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
     val localPath = configs.localConfigEntry.filePath
     val data = spark.read.json(localPath)
     val partition = partitionNum.toInt
 
     val weight = configs.localConfigEntry.weight
     val src = configs.localConfigEntry.srcId
     val dst = configs.localConfigEntry.dstId
-    if (configs.dataSourceSinkEntry.hasWeight && weight != null && !weight.trim.isEmpty) {
+    if (configs.dataSourceSinkEntry.hasWeight && weight != null && weight.trim.nonEmpty) {
       data.select(src, dst, weight)
     } else {
       data.select(src, dst)
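AlgoWriter is likewise among the changed files not shown on this page. From its call sites in Main.saveAlgoResult (AlgoWriter.make(configs), then write(algoResult, configs)) and the sink names the old match handled, a rough sketch; the no-argument writer constructors and the writers extending AlgoWriter are assumptions, not confirmed by this page:

import com.vesoft.nebula.algorithm.config.Configs
import org.apache.spark.sql.DataFrame

// Hypothetical sketch mirroring DataReader.make over the sinks the old
// code supported: "nebula", "csv", "text". Assumes NebulaWriter,
// CsvWriter, and TextWriter now take no constructor arguments and
// receive the result and configs at write time.
abstract class AlgoWriter {
  def write(data: DataFrame, configs: Configs): Unit
}
object AlgoWriter {
  def make(configs: Configs): AlgoWriter =
    configs.dataSourceSinkEntry.sink.toLowerCase match {
      case "nebula" => new NebulaWriter
      case "csv" => new CsvWriter
      case "text" => new TextWriter
      case _ => throw new UnsupportedOperationException("unsupported writer")
    }
}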
