@@ -6,47 +6,10 @@
 package com.vesoft.nebula.algorithm

 import com.vesoft.nebula.algorithm.config.Configs.Argument
-import com.vesoft.nebula.algorithm.config.{
-  AlgoConfig,
-  BetweennessConfig,
-  BfsConfig,
-  CcConfig,
-  CoefficientConfig,
-  Configs,
-  DfsConfig,
-  HanpConfig,
-  JaccardConfig,
-  KCoreConfig,
-  LPAConfig,
-  LouvainConfig,
-  Node2vecConfig,
-  PRConfig,
-  ShortestPathConfig,
-  SparkConfig,
-  DegreeStaticConfig
-}
-import com.vesoft.nebula.algorithm.lib.{
-  BetweennessCentralityAlgo,
-  BfsAlgo,
-  ClosenessAlgo,
-  ClusteringCoefficientAlgo,
-  ConnectedComponentsAlgo,
-  DegreeStaticAlgo,
-  DfsAlgo,
-  GraphTriangleCountAlgo,
-  HanpAlgo,
-  JaccardAlgo,
-  KCoreAlgo,
-  LabelPropagationAlgo,
-  LouvainAlgo,
-  Node2vecAlgo,
-  PageRankAlgo,
-  ShortestPathAlgo,
-  StronglyConnectedComponentsAlgo,
-  TriangleCountAlgo
-}
-import com.vesoft.nebula.algorithm.reader.{CsvReader, JsonReader, NebulaReader}
-import com.vesoft.nebula.algorithm.writer.{CsvWriter, NebulaWriter, TextWriter}
+import com.vesoft.nebula.algorithm.config._
+import com.vesoft.nebula.algorithm.lib._
+import com.vesoft.nebula.algorithm.reader.{CsvReader, DataReader, JsonReader, NebulaReader}
+import com.vesoft.nebula.algorithm.writer.{AlgoWriter, CsvWriter, NebulaWriter, TextWriter}
 import org.apache.commons.math3.ode.UnknownParameterException
 import org.apache.log4j.Logger
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -114,26 +77,8 @@ object Main {
   private[this] def createDataSource(spark: SparkSession,
                                      configs: Configs,
                                      partitionNum: String): DataFrame = {
-    val dataSource = configs.dataSourceSinkEntry.source
-    val dataSet: Dataset[Row] = dataSource.toLowerCase match {
-      case "nebula" => {
-        val reader = new NebulaReader(spark, configs, partitionNum)
-        reader.read()
-      }
-      case "nebula-ngql" => {
-        val reader = new NebulaReader(spark, configs, partitionNum)
-        reader.readNgql()
-      }
-      case "csv" => {
-        val reader = new CsvReader(spark, configs, partitionNum)
-        reader.read()
-      }
-      case "json" => {
-        val reader = new JsonReader(spark, configs, partitionNum)
-        reader.read()
-      }
-    }
-    dataSet
+    val dataSource = DataReader.make(configs)
+    dataSource.read(spark, configs, partitionNum)
   }

   /**
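The reader dispatch removed above is assumed to move behind the new DataReader.make factory, which is not shown in this diff. A minimal sketch of the shape implied by the call site (the trait and factory live in com.vesoft.nebula.algorithm.reader; NebulaNgqlReader is a hypothetical name for whatever now carries the old readNgql() path):

// Sketch only: assumes every reader implements a common
// read(spark, configs, partitionNum) method, matching createDataSource above.
trait DataReader {
  def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
}

object DataReader {
  def make(configs: Configs): DataReader =
    configs.dataSourceSinkEntry.source.toLowerCase match {
      case "nebula"      => new NebulaReader      // plain scan read
      case "nebula-ngql" => new NebulaNgqlReader  // hypothetical: the old readNgql() path
      case "csv"         => new CsvReader
      case "json"        => new JsonReader
      case other         => throw new UnsupportedOperationException(s"unsupported data source $other")
    }
}

This keeps Main free of per-source branching: adding a source means adding a reader and one factory case, without touching createDataSource.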
@@ -149,99 +94,63 @@ object Main {
                                      configs: Configs,
                                      dataSet: DataFrame): DataFrame = {
     val hasWeight = configs.dataSourceSinkEntry.hasWeight
-    val algoResult = {
-      algoName.toLowerCase match {
-        case "pagerank" => {
-          val pageRankConfig = PRConfig.getPRConfig(configs)
-          PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
-        }
-        case "louvain" => {
-          val louvainConfig = LouvainConfig.getLouvainConfig(configs)
-          LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
-        }
-        case "connectedcomponent" => {
-          val ccConfig = CcConfig.getCcConfig(configs)
-          ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
-        }
-        case "labelpropagation" => {
-          val lpaConfig = LPAConfig.getLPAConfig(configs)
-          LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
-        }
-        case "shortestpaths" => {
-          val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
-          ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
-        }
-        case "degreestatic" => {
-          val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
-          DegreeStaticAlgo(spark, dataSet, dsConfig)
-        }
-        case "kcore" => {
-          val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
-          KCoreAlgo(spark, dataSet, kCoreConfig)
-        }
-        case "stronglyconnectedcomponent" => {
-          val ccConfig = CcConfig.getCcConfig(configs)
-          StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
-        }
-        case "betweenness" => {
-          val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
-          BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
-        }
-        case "trianglecount" => {
-          TriangleCountAlgo(spark, dataSet)
-        }
-        case "graphtrianglecount" => {
-          GraphTriangleCountAlgo(spark, dataSet)
-        }
-        case "clusteringcoefficient" => {
-          val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
-          ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
-        }
-        case "closeness" => {
-          ClosenessAlgo(spark, dataSet, hasWeight)
-        }
-        case "hanp" => {
-          val hanpConfig = HanpConfig.getHanpConfig(configs)
-          HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
-        }
-        case "node2vec" => {
-          val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
-          Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
-        }
-        case "bfs" => {
-          val bfsConfig = BfsConfig.getBfsConfig(configs)
-          BfsAlgo(spark, dataSet, bfsConfig)
-        }
-        case "dfs" => {
-          val dfsConfig = DfsConfig.getDfsConfig(configs)
-          DfsAlgo(spark, dataSet, dfsConfig)
-        }
-        case "jaccard" => {
-          val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
-          JaccardAlgo(spark, dataSet, jaccardConfig)
-        }
-        case _ => throw new UnknownParameterException("unknown executeAlgo name.")
-      }
+    AlgorithmType.mapping.getOrElse(algoName.toLowerCase, throw new UnknownParameterException("unknown executeAlgo name.")) match {
+      case AlgorithmType.Bfs =>
+        val bfsConfig = BfsConfig.getBfsConfig(configs)
+        BfsAlgo(spark, dataSet, bfsConfig)
+      case AlgorithmType.Closeness =>
+        ClosenessAlgo(spark, dataSet, hasWeight)
+      case AlgorithmType.ClusteringCoefficient =>
+        val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
+        ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
+      case AlgorithmType.ConnectedComponents =>
+        val ccConfig = CcConfig.getCcConfig(configs)
+        ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
+      case AlgorithmType.DegreeStatic =>
+        val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
+        DegreeStaticAlgo(spark, dataSet, dsConfig)
+      case AlgorithmType.Dfs =>
+        val dfsConfig = DfsConfig.getDfsConfig(configs)
+        DfsAlgo(spark, dataSet, dfsConfig)
+      case AlgorithmType.GraphTriangleCount =>
+        GraphTriangleCountAlgo(spark, dataSet)
+      case AlgorithmType.Hanp =>
+        val hanpConfig = HanpConfig.getHanpConfig(configs)
+        HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
+      case AlgorithmType.Jaccard =>
+        val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
+        JaccardAlgo(spark, dataSet, jaccardConfig)
+      case AlgorithmType.KCore =>
+        val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
+        KCoreAlgo(spark, dataSet, kCoreConfig)
+      case AlgorithmType.LabelPropagation =>
+        val lpaConfig = LPAConfig.getLPAConfig(configs)
+        LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
+      case AlgorithmType.Louvain =>
+        val louvainConfig = LouvainConfig.getLouvainConfig(configs)
+        LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
+      case AlgorithmType.Node2vec =>
+        val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
+        Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
+      case AlgorithmType.PageRank =>
+        val pageRankConfig = PRConfig.getPRConfig(configs)
+        PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
+      case AlgorithmType.ShortestPath =>
+        val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
+        ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
+      case AlgorithmType.StronglyConnectedComponents =>
+        val ccConfig = CcConfig.getCcConfig(configs)
+        StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
+      case AlgorithmType.TriangleCount =>
+        TriangleCountAlgo(spark, dataSet)
+      case AlgorithmType.BetweennessCentrality =>
+        val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
+        BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
     }
-    algoResult
   }

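AlgorithmType and its mapping are introduced elsewhere in this change and are not shown here. A rough sketch of the shape the match above assumes, inferred entirely from the call site (a sealed ADT plus a lookup from the lowercase names accepted by the executeAlgo option):

// Inferred sketch of AlgorithmType; member names follow the cases
// matched in executeAlgorithm above. Abbreviated to three entries.
sealed trait AlgorithmType
object AlgorithmType {
  case object Bfs extends AlgorithmType
  case object PageRank extends AlgorithmType
  case object Louvain extends AlgorithmType
  // ... one case object per algorithm matched above ...

  // Keys are the lowercase names previously matched as raw strings.
  val mapping: Map[String, AlgorithmType] = Map(
    "bfs"      -> Bfs,
    "pagerank" -> PageRank,
    "louvain"  -> Louvain
    // ...
  )
}

Passing a throw expression as the getOrElse default is valid here because the default parameter is by-name and throw has type Nothing, so the exception is only raised when the lookup misses.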
   private[this] def saveAlgoResult(algoResult: DataFrame, configs: Configs): Unit = {
-    val dataSink = configs.dataSourceSinkEntry.sink
-    dataSink.toLowerCase match {
-      case "nebula" => {
-        val writer = new NebulaWriter(algoResult, configs)
-        writer.write()
-      }
-      case "csv" => {
-        val writer = new CsvWriter(algoResult, configs)
-        writer.write()
-      }
-      case "text" => {
-        val writer = new TextWriter(algoResult, configs)
-        writer.write()
-      }
-      case _ => throw new UnsupportedOperationException("unsupported data sink")
-    }
+    val writer = AlgoWriter.make(configs)
+    writer.write(algoResult, configs)
   }
 }
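As with the reader, the sink dispatch now lives behind AlgoWriter.make, defined in com.vesoft.nebula.algorithm.writer. A hedged sketch of the expected shape; only the make and write signatures are implied by the call site in saveAlgoResult:

// Sketch: writers share a write(data, configs) method; make() replaces
// the sink match that was removed from saveAlgoResult above.
trait AlgoWriter {
  def write(data: DataFrame, configs: Configs): Unit
}

object AlgoWriter {
  def make(configs: Configs): AlgoWriter =
    configs.dataSourceSinkEntry.sink.toLowerCase match {
      case "nebula" => new NebulaWriter
      case "csv"    => new CsvWriter
      case "text"   => new TextWriter
      case other    => throw new UnsupportedOperationException(s"unsupported data sink $other")
    }
}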