Skip to content

Commit 727ac7c

Browse files
committed
Create ml-facade-api module
1 parent 723fd93 commit 727ac7c

File tree

29 files changed

+240
-132
lines changed

29 files changed

+240
-132
lines changed

algo/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies {
3939
// TODO: move the factories out
4040
implementation project(':centrality-configs')
4141
implementation project(':community-configs')
42+
implementation project(':ml-configs')
4243
implementation project(':node-embeddings-configs')
4344
implementation project(':path-finding-configs')
4445
implementation project(':similarity-configs')
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.algorithms.machinelearning;
21+
22+
/**
 * Static helper that converts a {@link KGEPredictBaseConfig} into the
 * {@link KGEPredictParameters} value consumed by the KGE prediction code.
 *
 * <p>The class only hosts static methods and cannot be instantiated.
 */
public final class KGEPredictConfigTransformer {
23+
24+
25+
/** Private constructor: this is a utility holder, never instantiated. */
private KGEPredictConfigTransformer() {}
26+
27+
/**
 * Builds a {@link KGEPredictParameters} from the given configuration.
 *
 * <p>Each constructor argument is read directly from the accessor of the same
 * name on {@code config}, in this order: concurrency, source node filter,
 * target node filter, relationship types filter, relationship type embedding,
 * node embedding property, scoring function and top-K. No value is validated
 * or transformed here.
 *
 * @param config the configuration to read from; it is dereferenced without a
 *               null check, so passing {@code null} results in a
 *               {@link NullPointerException}
 * @return a new parameters object carrying the values listed above
 */
public static KGEPredictParameters toParameters(KGEPredictBaseConfig config) {
28+
return new KGEPredictParameters(
29+
config.concurrency(),
30+
config.sourceNodeFilter(),
31+
config.targetNodeFilter(),
32+
config.relationshipTypesFilter(),
33+
config.relationshipTypeEmbedding(),
34+
config.nodeEmbeddingProperty(),
35+
config.scoringFunction(),
36+
config.topK()
37+
);
38+
}
39+
}

applications/algorithms/machine-learning/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies {
1414
implementation project(":logging")
1515
implementation project(":memory-usage")
1616
implementation project(":ml-algo")
17+
implementation project(":ml-configs")
1718
implementation project(":progress-tracking")
1819
implementation project(":similarity-configs")
1920
implementation project(":termination")

applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.carrotsearch.hppc.BitSet;
2323
import org.neo4j.gds.algorithms.machinelearning.KGEPredictBaseConfig;
2424
import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult;
25+
import org.neo4j.gds.algorithms.machinelearning.KGEPredictConfigTransformer;
2526
import org.neo4j.gds.algorithms.machinelearning.TopKMapComputer;
2627
import org.neo4j.gds.api.Graph;
2728
import org.neo4j.gds.api.GraphStore;
@@ -53,7 +54,7 @@ KGEPredictResult kge(Graph graph, KGEPredictBaseConfig configuration) {
5354

5455
var sourceNodes = new BitSet(graph.nodeCount());
5556
var targetNodes = new BitSet(graph.nodeCount());
56-
var parameters = configuration.toParameters();
57+
var parameters = KGEPredictConfigTransformer.toParameters(configuration);
5758
var sourceNodeFilter = parameters.sourceNodeFilter().toNodeFilter(graph);
5859
var targetNodeFilter = parameters.targetNodeFilter().toNodeFilter(graph);
5960
graph.forEachNode(node -> {

applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithmsEstimationModeBusinessFacade.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.neo4j.gds.exceptions.MemoryEstimationNotImplementedException;
2323
import org.neo4j.gds.mem.MemoryEstimation;
24+
import org.neo4j.gds.ml.splitting.SplitRelationshipConfigTransformer;
2425
import org.neo4j.gds.ml.splitting.SplitRelationshipsEstimateDefinition;
2526
import org.neo4j.gds.ml.splitting.SplitRelationshipsMutateConfig;
2627

@@ -30,6 +31,6 @@ public MemoryEstimation kge() {
3031
}
3132

3233
public MemoryEstimation splitRelationships(SplitRelationshipsMutateConfig configuration) {
33-
return new SplitRelationshipsEstimateDefinition(configuration.toMemoryEstimateParameters()).memoryEstimation();
34+
return new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(configuration)).memoryEstimation();
3435
}
3536
}

ml/ml-algo/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ dependencies {
2020
}
2121
}
2222

23+
implementation project(':ml-configs')
24+
2325
implementation project(':algo-common')
2426
implementation project(':annotations')
2527
implementation project(':config-api')
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.splitting;
21+
22+
/**
 * Static helper that converts a {@link SplitRelationshipsBaseConfig} into the
 * parameter objects used by the split-relationships code: one for memory
 * estimation and one for running the algorithm.
 *
 * <p>The class only hosts static methods and cannot be instantiated.
 *
 * <p>NOTE(review): the class name uses the singular "SplitRelationship" while
 * the config and parameter types it works with use the plural
 * "SplitRelationships" - confirm whether the difference is intentional.
 */
public final class SplitRelationshipConfigTransformer {
23+
24+
/** Private constructor: this is a utility holder, never instantiated. */
private SplitRelationshipConfigTransformer() {}
25+
26+
/**
 * Builds the parameters needed for memory estimation.
 *
 * <p>Reads, in constructor-argument order: whether a relationship weight
 * property is set, the relationship types, the negative sampling ratio and
 * the holdout fraction. Values are passed through unchanged.
 *
 * @param config the configuration to read from; it is dereferenced without a
 *               null check, so passing {@code null} results in a
 *               {@link NullPointerException}
 * @return a new {@link SplitRelationshipsEstimateParameters} carrying those values
 */
public static SplitRelationshipsEstimateParameters toMemoryEstimateParameters(SplitRelationshipsBaseConfig config) {
27+
return new SplitRelationshipsEstimateParameters(
28+
config.hasRelationshipWeightProperty(),
29+
config.relationshipTypes(),
30+
config.negativeSamplingRatio(),
31+
config.holdoutFraction()
32+
);
33+
}
34+
35+
/**
 * Builds the parameters needed to run the split itself.
 *
 * <p>Reads, in constructor-argument order: concurrency, holdout relationship
 * type, remaining relationship type, holdout fraction, relationship weight
 * property, negative sampling ratio and random seed. Values are passed
 * through unchanged.
 *
 * @param config the configuration to read from; it is dereferenced without a
 *               null check, so passing {@code null} results in a
 *               {@link NullPointerException}
 * @return a new {@link SplitRelationshipsParameters} carrying those values
 */
public static SplitRelationshipsParameters toParameters(SplitRelationshipsBaseConfig config) {
36+
return new SplitRelationshipsParameters(
37+
config.concurrency(),
38+
config.holdoutRelationshipType(),
39+
config.remainingRelationshipType(),
40+
config.holdoutFraction(),
41+
config.relationshipWeightProperty(),
42+
config.negativeSamplingRatio(),
43+
config.randomSeed()
44+
);
45+
}
46+
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,14 @@ public final class SplitRelationships extends Algorithm<EdgeSplitter.SplitResult
4343
private final IdMap targetNodes;
4444
private final SplitRelationshipsParameters parameters;
4545

46-
private SplitRelationships(Graph graph, Graph masterGraph,
47-
IdMap rootNodes,
48-
IdMap sourceNodes, IdMap targetNodes, SplitRelationshipsParameters parameters) {
46+
private SplitRelationships(
47+
Graph graph,
48+
Graph masterGraph,
49+
IdMap rootNodes,
50+
IdMap sourceNodes,
51+
IdMap targetNodes,
52+
SplitRelationshipsParameters parameters
53+
) {
4954
super(ProgressTracker.NULL_TRACKER);
5055
this.graph = graph;
5156
this.masterGraph = masterGraph;
@@ -68,40 +73,40 @@ public static SplitRelationships of(GraphStore graphStore, SplitRelationshipsBas
6873
IdMap sourceNodes = graphStore.getGraph(sourceLabels);
6974
IdMap targetNodes = graphStore.getGraph(targetLabels);
7075

71-
return new SplitRelationships(graph, masterGraph, graphStore.nodes(), sourceNodes, targetNodes, config.toParameters());
76+
return new SplitRelationships(graph,
77+
masterGraph,
78+
graphStore.nodes(),
79+
sourceNodes,
80+
targetNodes,
81+
SplitRelationshipConfigTransformer.toParameters(config)
82+
);
7283
}
7384

7485
@Override
7586
public EdgeSplitter.SplitResult compute() {
7687
boolean isUndirected = graph.schema().isUndirected();
77-
var splitter = isUndirected
78-
? new UndirectedEdgeSplitter(
79-
parameters.randomSeed(),
88+
var splitter = isUndirected ? new UndirectedEdgeSplitter(parameters.randomSeed(),
8089
rootNodes,
8190
sourceNodes,
8291
targetNodes,
8392
parameters.holdoutRelationshipType(),
8493
parameters.remainingRelationshipType(),
8594
parameters.concurrency()
86-
)
87-
: new DirectedEdgeSplitter(
88-
parameters.randomSeed(),
89-
rootNodes,
90-
sourceNodes,
91-
targetNodes,
92-
parameters.holdoutRelationshipType(),
93-
parameters.remainingRelationshipType(),
94-
parameters.concurrency()
95-
);
96-
97-
var splitResult = splitter.splitPositiveExamples(
98-
graph,
95+
) : new DirectedEdgeSplitter(parameters.randomSeed(),
96+
rootNodes,
97+
sourceNodes,
98+
targetNodes,
99+
parameters.holdoutRelationshipType(),
100+
parameters.remainingRelationshipType(),
101+
parameters.concurrency()
102+
);
103+
104+
var splitResult = splitter.splitPositiveExamples(graph,
99105
parameters.holdoutFraction(),
100106
parameters.relationshipWeightProperty()
101107
);
102108

103-
NegativeSampler negativeSampler = new RandomNegativeSampler(
104-
masterGraph,
109+
NegativeSampler negativeSampler = new RandomNegativeSampler(masterGraph,
105110
(long) (splitResult.selectedRelCount() * parameters.negativeSamplingRatio()),
106111
//SplitRelationshipsProc does not add negative samples to holdout set
107112
0,

ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/SplitRelationshipsTest.java

Lines changed: 30 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -46,30 +46,7 @@
4646
class SplitRelationshipsTest {
4747

4848
@GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 42)
49-
static final String GRAPH_WITH_OFFSET =
50-
" CREATE" +
51-
" (a: Node)," +
52-
" (b: Node)," +
53-
" (c: Node)," +
54-
" (d: Node)," +
55-
" (e: Node)," +
56-
" (f: Node)," +
57-
" (g: Node)," +
58-
" (a)-[:REL]->(b)," +
59-
" (a)-[:REL]->(c)," +
60-
" (a)-[:REL]->(d)," +
61-
" (a)-[:REL]->(e)," +
62-
" (a)-[:REL]->(f)," +
63-
" (a)-[:REL]->(g)," +
64-
" (b)-[:REL]->(c)," +
65-
" (b)-[:REL]->(d)," +
66-
" (b)-[:REL]->(e)," +
67-
" (b)-[:REL]->(f)," +
68-
" (b)-[:REL2]->(g)," +
69-
" (c)-[:REL2]->(d)," +
70-
" (c)-[:REL2]->(e)," +
71-
" (c)-[:REL2]->(f)," +
72-
" (c)-[:REL2]->(g)";
49+
static final String GRAPH_WITH_OFFSET = " CREATE" + " (a: Node)," + " (b: Node)," + " (c: Node)," + " (d: Node)," + " (e: Node)," + " (f: Node)," + " (g: Node)," + " (a)-[:REL]->(b)," + " (a)-[:REL]->(c)," + " (a)-[:REL]->(d)," + " (a)-[:REL]->(e)," + " (a)-[:REL]->(f)," + " (a)-[:REL]->(g)," + " (b)-[:REL]->(c)," + " (b)-[:REL]->(d)," + " (b)-[:REL]->(e)," + " (b)-[:REL]->(f)," + " (b)-[:REL2]->(g)," + " (c)-[:REL2]->(d)," + " (c)-[:REL2]->(e)," + " (c)-[:REL2]->(f)," + " (c)-[:REL2]->(g)";
7350

7451
@Inject
7552
private GraphStore graphStore;
@@ -96,16 +73,11 @@ void computeWithOffset() {
9673
@Test
9774
void estimate() {
9875
var graphDimensions = GraphDimensions.of(1, 10_000);
99-
var config = SplitRelationshipsMutateConfigImpl
100-
.builder()
101-
.negativeSamplingRatio(1.0)
102-
.holdoutRelationshipType("HOLDOUT")
103-
.remainingRelationshipType("REST")
104-
.holdoutFraction(0.3)
105-
.relationshipTypes(List.of(ElementProjection.PROJECT_ALL))
106-
.build();
76+
var config = SplitRelationshipsMutateConfigImpl.builder().negativeSamplingRatio(1.0).holdoutRelationshipType(
77+
"HOLDOUT").remainingRelationshipType("REST").holdoutFraction(0.3).relationshipTypes(List.of(
78+
ElementProjection.PROJECT_ALL)).build();
10779

108-
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
80+
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
10981
.memoryEstimation()
11082
.estimate(graphDimensions, config.concurrency());
11183

@@ -127,27 +99,24 @@ public static Stream<Arguments> withTypesParams() {
12799
void estimateWithTypes(List<String> relTypes, MemoryRange expectedMemory) {
128100
var nodeCount = 100;
129101
var relationshipCounts = Map.of(
130-
RelationshipType.of("TYPE1"), 10L,
131-
RelationshipType.of("TYPE2"), 20L,
132-
RelationshipType.of("TYPE3"), 30L
102+
RelationshipType.of("TYPE1"),
103+
10L,
104+
RelationshipType.of("TYPE2"),
105+
20L,
106+
RelationshipType.of("TYPE3"),
107+
30L
133108
);
134109

135-
var graphDimensions = ImmutableGraphDimensions.builder()
136-
.nodeCount(nodeCount)
137-
.relationshipCounts(relationshipCounts)
138-
.relCountUpperBound(relationshipCounts.values().stream().mapToLong(Long::longValue).sum())
139-
.build();
110+
var graphDimensions = ImmutableGraphDimensions.builder().nodeCount(nodeCount).relationshipCounts(
111+
relationshipCounts).relCountUpperBound(relationshipCounts.values()
112+
.stream()
113+
.mapToLong(Long::longValue)
114+
.sum()).build();
140115

141-
var config = SplitRelationshipsMutateConfigImpl
142-
.builder()
143-
.negativeSamplingRatio(1.0)
144-
.holdoutRelationshipType("HOLDOUT")
145-
.remainingRelationshipType("REST")
146-
.holdoutFraction(0.3)
147-
.relationshipTypes(relTypes)
148-
.build();
116+
var config = SplitRelationshipsMutateConfigImpl.builder().negativeSamplingRatio(1.0).holdoutRelationshipType(
117+
"HOLDOUT").remainingRelationshipType("REST").holdoutFraction(0.3).relationshipTypes(relTypes).build();
149118

150-
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
119+
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
151120
.memoryEstimation()
152121
.estimate(graphDimensions, config.concurrency());
153122

@@ -156,23 +125,18 @@ void estimateWithTypes(List<String> relTypes, MemoryRange expectedMemory) {
156125

157126
@Test
158127
void estimateIndependentOfNodeCount() {
159-
var config = SplitRelationshipsMutateConfigImpl
160-
.builder()
161-
.negativeSamplingRatio(1.0)
162-
.holdoutRelationshipType("HOLDOUT")
163-
.remainingRelationshipType("REST")
164-
.holdoutFraction(0.3)
165-
.relationshipTypes(List.of(ElementProjection.PROJECT_ALL))
166-
.build();
128+
var config = SplitRelationshipsMutateConfigImpl.builder().negativeSamplingRatio(1.0).holdoutRelationshipType(
129+
"HOLDOUT").remainingRelationshipType("REST").holdoutFraction(0.3).relationshipTypes(List.of(
130+
ElementProjection.PROJECT_ALL)).build();
167131

168132
var graphDimensions = GraphDimensions.of(1, 10_000);
169-
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
133+
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
170134
.memoryEstimation()
171135
.estimate(graphDimensions, config.concurrency());
172136
assertMemoryRange(actualEstimate.memoryUsage(), MemoryRange.of(160_000, 208_000));
173137

174138
graphDimensions = GraphDimensions.of(100_000, 10_000);
175-
actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
139+
actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
176140
.memoryEstimation()
177141
.estimate(graphDimensions, config.concurrency());
178142
assertMemoryRange(actualEstimate.memoryUsage(), MemoryRange.of(160_000, 208_000));
@@ -182,22 +146,21 @@ void estimateIndependentOfNodeCount() {
182146
void estimateDifferentSamplingRatios() {
183147
var graphDimensions = GraphDimensions.of(1, 10_000);
184148

185-
var configBuilder = SplitRelationshipsMutateConfigImpl
186-
.builder()
149+
var configBuilder = SplitRelationshipsMutateConfigImpl.builder()
187150
.holdoutRelationshipType("HOLDOUT")
188151
.remainingRelationshipType("REST")
189152
.holdoutFraction(0.3)
190153
.relationshipTypes(List.of(ElementProjection.PROJECT_ALL));
191154

192155
var config = configBuilder.negativeSamplingRatio(1.0).build();
193-
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
156+
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
194157
.memoryEstimation()
195158
.estimate(graphDimensions, config.concurrency());
196159

197160
assertMemoryRange(actualEstimate.memoryUsage(), MemoryRange.of(160_000, 208_000));
198161

199162
config = configBuilder.negativeSamplingRatio(2.0).build();
200-
actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
163+
actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
201164
.memoryEstimation()
202165
.estimate(graphDimensions, config.concurrency());
203166

@@ -208,21 +171,20 @@ void estimateDifferentSamplingRatios() {
208171
void estimateDifferentHoldoutFractions() {
209172
var graphDimensions = GraphDimensions.of(1, 10_000);
210173

211-
var configBuilder = SplitRelationshipsMutateConfigImpl
212-
.builder()
174+
var configBuilder = SplitRelationshipsMutateConfigImpl.builder()
213175
.holdoutRelationshipType("HOLDOUT")
214176
.remainingRelationshipType("REST")
215177
.negativeSamplingRatio(1.0)
216178
.relationshipTypes(List.of(ElementProjection.PROJECT_ALL));
217179

218180
var config = configBuilder.holdoutFraction(0.3).build();
219-
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
181+
MemoryTree actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
220182
.memoryEstimation()
221183
.estimate(graphDimensions, config.concurrency());
222184
assertMemoryRange(actualEstimate.memoryUsage(), MemoryRange.of(160_000, 208_000));
223185

224186
config = configBuilder.holdoutFraction(0.1).build();
225-
actualEstimate = new SplitRelationshipsEstimateDefinition(config.toMemoryEstimateParameters())
187+
actualEstimate = new SplitRelationshipsEstimateDefinition(SplitRelationshipConfigTransformer.toMemoryEstimateParameters(config))
226188
.memoryEstimation()
227189
.estimate(graphDimensions, config.concurrency());
228190

0 commit comments

Comments
 (0)