Skip to content

Commit 6660781

Browse files
committed
introducing algorithm machinery to manage running algorithms and prodding progress tracker the right way
1 parent 04e32cf commit 6660781

File tree

15 files changed

+341
-49
lines changed

15 files changed

+341
-49
lines changed

algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java

+12-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.neo4j.gds.collections.ha.HugeLongArray;
3232
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
3333
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34+
import org.neo4j.gds.termination.TerminationFlag;
3435

3536
import java.util.ArrayList;
3637
import java.util.Collection;
@@ -93,7 +94,8 @@ public static BFS create(
9394
Aggregator aggregatorFunction,
9495
Concurrency concurrency,
9596
ProgressTracker progressTracker,
96-
long maximumDepth
97+
long maximumDepth,
98+
TerminationFlag terminationFlag
9799
) {
98100
return create(
99101
graph,
@@ -103,7 +105,8 @@ public static BFS create(
103105
concurrency,
104106
progressTracker,
105107
DEFAULT_DELTA,
106-
maximumDepth
108+
maximumDepth,
109+
terminationFlag
107110
);
108111
}
109112

@@ -115,7 +118,8 @@ static BFS create(
115118
Concurrency concurrency,
116119
ProgressTracker progressTracker,
117120
int delta,
118-
long maximumDepth
121+
long maximumDepth,
122+
TerminationFlag terminationFlag
119123
) {
120124

121125
var nodeCount = graph.nodeCount();
@@ -135,7 +139,8 @@ static BFS create(
135139
concurrency,
136140
progressTracker,
137141
delta,
138-
maximumDepth
142+
maximumDepth,
143+
terminationFlag
139144
);
140145
}
141146

@@ -150,7 +155,8 @@ private BFS(
150155
Concurrency concurrency,
151156
ProgressTracker progressTracker,
152157
int delta,
153-
long maximumDepth
158+
long maximumDepth,
159+
TerminationFlag terminationFlag
154160
) {
155161
super(progressTracker);
156162
this.graph = graph;
@@ -163,6 +169,7 @@ private BFS(
163169
this.traversedNodes = traversedNodes;
164170
this.weights = weights;
165171
this.visited = visited;
172+
this.terminationFlag = terminationFlag;
166173
}
167174

168175
@Override

algo/src/main/java/org/neo4j/gds/paths/traverse/BfsAlgorithmFactory.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.neo4j.gds.api.Graph;
2424
import org.neo4j.gds.mem.MemoryEstimation;
2525
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
26+
import org.neo4j.gds.termination.TerminationFlag;
2627

2728
import java.util.List;
2829
import java.util.stream.Collectors;
@@ -59,7 +60,8 @@ public BFS build(Graph graph, CONFIG configuration, ProgressTracker progressTrac
5960
aggregatorFunction,
6061
configuration.concurrency(),
6162
progressTracker,
62-
configuration.maxDepth()
63+
configuration.maxDepth(),
64+
TerminationFlag.RUNNING_TRUE
6365
);
6466
}
6567

algo/src/test/java/org/neo4j/gds/paths/traverse/BFSComplexTreeTest.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.extension.GdlGraph;
2929
import org.neo4j.gds.extension.Inject;
3030
import org.neo4j.gds.extension.TestGraph;
31+
import org.neo4j.gds.termination.TerminationFlag;
3132

3233
import java.util.List;
3334
import java.util.stream.Stream;
@@ -113,7 +114,8 @@ void testBfsToTargetOut(int concurrency, int delta) {
113114
new Concurrency(concurrency),
114115
ProgressTracker.NULL_TRACKER,
115116
delta,
116-
BFS.ALL_DEPTHS_ALLOWED
117+
BFS.ALL_DEPTHS_ALLOWED,
118+
TerminationFlag.RUNNING_TRUE
117119
).compute().toArray();
118120

119121
assertThat(nodes)

algo/src/test/java/org/neo4j/gds/paths/traverse/BFSOnBiggerGraphTest.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.extension.GdlGraph;
2929
import org.neo4j.gds.extension.Inject;
3030
import org.neo4j.gds.extension.TestGraph;
31+
import org.neo4j.gds.termination.TerminationFlag;
3132

3233
import java.util.stream.Stream;
3334

@@ -112,7 +113,8 @@ void testBfsToTargetOut(int concurrency, int delta) {
112113
new Concurrency(concurrency),
113114
ProgressTracker.NULL_TRACKER,
114115
delta,
115-
BFS.ALL_DEPTHS_ALLOWED
116+
BFS.ALL_DEPTHS_ALLOWED,
117+
TerminationFlag.RUNNING_TRUE
116118
).compute().toArray();
117119

118120
assertThat(nodes)

algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTest.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.neo4j.gds.extension.Inject;
3434
import org.neo4j.gds.extension.TestGraph;
3535
import org.neo4j.gds.paths.traverse.ExitPredicate.Result;
36+
import org.neo4j.gds.termination.TerminationFlag;
3637

3738
import java.util.stream.Stream;
3839

@@ -107,7 +108,8 @@ void testBfsToTargetOut(int concurrency) {
107108
(s, t, w) -> 1.,
108109
new Concurrency(concurrency),
109110
ProgressTracker.NULL_TRACKER,
110-
BFS.ALL_DEPTHS_ALLOWED
111+
BFS.ALL_DEPTHS_ALLOWED,
112+
TerminationFlag.RUNNING_TRUE
111113
).compute().toArray();
112114

113115
// algorithms return mapped ids
@@ -133,7 +135,8 @@ void testBfsToTargetIn(int concurrency) {
133135
Aggregator.NO_AGGREGATION,
134136
new Concurrency(concurrency),
135137
ProgressTracker.NULL_TRACKER,
136-
BFS.ALL_DEPTHS_ALLOWED
138+
BFS.ALL_DEPTHS_ALLOWED,
139+
TerminationFlag.RUNNING_TRUE
137140
).compute().toArray();
138141
assertEquals(7, nodes.length);
139142
}
@@ -156,7 +159,8 @@ void testBfsMaxDepthOut(int concurrency) {
156159
(s, t, w) -> w + 1.,
157160
new Concurrency(concurrency),
158161
ProgressTracker.NULL_TRACKER,
159-
maxHops - 1
162+
maxHops - 1,
163+
TerminationFlag.RUNNING_TRUE
160164
).compute().toArray();
161165

162166
assertThat(nodes).isEqualTo(
@@ -172,7 +176,8 @@ void testBfsOnLoopGraph(int concurrency) {
172176
Aggregator.NO_AGGREGATION,
173177
new Concurrency(concurrency),
174178
ProgressTracker.NULL_TRACKER,
175-
BFS.ALL_DEPTHS_ALLOWED
179+
BFS.ALL_DEPTHS_ALLOWED,
180+
TerminationFlag.RUNNING_TRUE
176181
).compute();
177182
}
178183

@@ -189,7 +194,8 @@ void shouldLogProgress(int concurrency) {
189194
Aggregator.NO_AGGREGATION,
190195
new Concurrency(concurrency),
191196
progressTracker,
192-
BFS.ALL_DEPTHS_ALLOWED
197+
BFS.ALL_DEPTHS_ALLOWED,
198+
TerminationFlag.RUNNING_TRUE
193199
).compute();
194200
var messagesInOrder = testLog.getMessages(INFO);
195201

algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTridentGraphTest.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.neo4j.gds.extension.GdlGraph;
3030
import org.neo4j.gds.extension.Inject;
3131
import org.neo4j.gds.extension.TestGraph;
32+
import org.neo4j.gds.termination.TerminationFlag;
3233

3334
import java.util.Arrays;
3435
import java.util.List;
@@ -96,7 +97,8 @@ void testBfsToTargetOut(int concurrency, int delta) {
9697
new Concurrency(concurrency),
9798
ProgressTracker.NULL_TRACKER,
9899
delta,
99-
BFS.ALL_DEPTHS_ALLOWED
100+
BFS.ALL_DEPTHS_ALLOWED,
101+
TerminationFlag.RUNNING_TRUE
100102
).compute().toArray();
101103

102104
assertThat(nodes)

applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.applications.algorithms.centrality;
2121

2222
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
2324
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
2425
import org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking;
2526
import org.neo4j.gds.betweenness.BetweennessCentrality;
@@ -41,6 +42,8 @@
4142
import org.neo4j.gds.termination.TerminationFlag;
4243

4344
public class CentralityAlgorithms {
45+
private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery();
46+
4447
private final ProgressTrackerCreator progressTrackerCreator;
4548
private final TerminationFlag terminationFlag;
4649

@@ -79,7 +82,7 @@ BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentral
7982
terminationFlag
8083
);
8184

82-
return algorithm.compute();
85+
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
8386
}
8487

8588
ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBaseConfig configuration) {
@@ -103,7 +106,7 @@ ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBa
103106
progressTracker
104107
);
105108

106-
return algorithm.compute();
109+
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
107110
}
108111

109112
DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig configuration) {
@@ -122,6 +125,6 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf
122125
progressTracker
123126
);
124127

125-
return algorithm.compute();
128+
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
126129
}
127130
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.applications.algorithms.machinery;
21+
22+
import org.neo4j.gds.Algorithm;
23+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
24+
25+
/**
26+
* I wish this did not exist quite like this; it is where we encapsulate running an algorithm,
27+
* managing termination, and handling (progress tracker) resources.
28+
* Somehow I wish that was encapsulated more naturally, but as you can hear from this use of language,
29+
* the design has not crystallized yet.
30+
* At least nothing here is tied to termination flag.
31+
*/
32+
public class AlgorithmMachinery {
33+
/**
34+
* Runs algorithm.
35+
* Optionally releases progress tracker.
36+
* Exceptionally marks progress tracker state as failed.
37+
*
38+
* @return algorithm result, or an error in the form of an exception
39+
*/
40+
public <RESULT> RESULT runAlgorithmsAndManageProgressTracker(
41+
Algorithm<RESULT> algorithm,
42+
ProgressTracker progressTracker,
43+
boolean shouldReleaseProgressTracker
44+
) {
45+
try {
46+
return algorithm.compute();
47+
} catch (Exception e) {
48+
progressTracker.endSubTaskWithFailure();
49+
throw e;
50+
} finally {
51+
if (shouldReleaseProgressTracker) progressTracker.release();
52+
}
53+
}
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.applications.algorithms.machinery;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
24+
25+
import static org.assertj.core.api.Assertions.assertThat;
26+
import static org.junit.jupiter.api.Assertions.fail;
27+
import static org.mockito.Mockito.mock;
28+
import static org.mockito.Mockito.verify;
29+
import static org.mockito.Mockito.verifyNoInteractions;
30+
31+
class AlgorithmMachineryTest {
32+
@Test
33+
void shouldRunAlgorithm() {
34+
var algorithmMachinery = new AlgorithmMachinery();
35+
36+
var progressTracker = mock(ProgressTracker.class);
37+
var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker(
38+
new RegurgitatingAlgorithm("Hello, world!"),
39+
progressTracker,
40+
false
41+
);
42+
43+
assertThat(result).isEqualTo("Hello, world!");
44+
45+
verifyNoInteractions(progressTracker);
46+
}
47+
48+
@Test
49+
void shouldReleaseProgressTrackerWhenAsked() {
50+
var algorithmMachinery = new AlgorithmMachinery();
51+
52+
var progressTracker = mock(ProgressTracker.class);
53+
var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker(
54+
new RegurgitatingAlgorithm("Dodgers win world series!"),
55+
progressTracker,
56+
true
57+
);
58+
59+
assertThat(result).isEqualTo("Dodgers win world series!");
60+
61+
verify(progressTracker).release();
62+
}
63+
64+
@Test
65+
void shouldMarkProgressTracker() {
66+
var algorithmMachinery = new AlgorithmMachinery();
67+
68+
var progressTracker = mock(ProgressTracker.class);
69+
var exception = new RuntimeException("Whoops!");
70+
try {
71+
algorithmMachinery.runAlgorithmsAndManageProgressTracker(
72+
new FailingAlgorithm(exception),
73+
progressTracker,
74+
false
75+
);
76+
fail();
77+
} catch (Exception e) {
78+
assertThat(e).hasMessage("Whoops!");
79+
}
80+
81+
verify(progressTracker).endSubTaskWithFailure();
82+
}
83+
84+
@Test
85+
void shouldMarkProgressTrackerAndReleaseIt() {
86+
var algorithmMachinery = new AlgorithmMachinery();
87+
88+
var progressTracker = mock(ProgressTracker.class);
89+
var exception = new RuntimeException("Yeah, no...");
90+
try {
91+
algorithmMachinery.runAlgorithmsAndManageProgressTracker(
92+
new FailingAlgorithm(exception),
93+
progressTracker,
94+
true
95+
);
96+
fail();
97+
} catch (Exception e) {
98+
assertThat(e).hasMessage("Yeah, no...");
99+
}
100+
101+
verify(progressTracker).endSubTaskWithFailure();
102+
verify(progressTracker).release();
103+
}
104+
}

0 commit comments

Comments
 (0)