Skip to content

Commit

Permalink
introducing algorithm machinery to manage running algorithms and prod…
Browse files Browse the repository at this point in the history
…ding progress tracker the right way
  • Loading branch information
lassewesth committed May 23, 2024
1 parent 04e32cf commit 6660781
Show file tree
Hide file tree
Showing 15 changed files with 341 additions and 49 deletions.
17 changes: 12 additions & 5 deletions algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -93,7 +94,8 @@ public static BFS create(
Aggregator aggregatorFunction,
Concurrency concurrency,
ProgressTracker progressTracker,
long maximumDepth
long maximumDepth,
TerminationFlag terminationFlag
) {
return create(
graph,
Expand All @@ -103,7 +105,8 @@ public static BFS create(
concurrency,
progressTracker,
DEFAULT_DELTA,
maximumDepth
maximumDepth,
terminationFlag
);
}

Expand All @@ -115,7 +118,8 @@ static BFS create(
Concurrency concurrency,
ProgressTracker progressTracker,
int delta,
long maximumDepth
long maximumDepth,
TerminationFlag terminationFlag
) {

var nodeCount = graph.nodeCount();
Expand All @@ -135,7 +139,8 @@ static BFS create(
concurrency,
progressTracker,
delta,
maximumDepth
maximumDepth,
terminationFlag
);
}

Expand All @@ -150,7 +155,8 @@ private BFS(
Concurrency concurrency,
ProgressTracker progressTracker,
int delta,
long maximumDepth
long maximumDepth,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand All @@ -163,6 +169,7 @@ private BFS(
this.traversedNodes = traversedNodes;
this.weights = weights;
this.visited = visited;
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -59,7 +60,8 @@ public BFS build(Graph graph, CONFIG configuration, ProgressTracker progressTrac
aggregatorFunction,
configuration.concurrency(),
progressTracker,
configuration.maxDepth()
configuration.maxDepth(),
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.stream.Stream;
Expand Down Expand Up @@ -113,7 +114,8 @@ void testBfsToTargetOut(int concurrency, int delta) {
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
delta,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute().toArray();

assertThat(nodes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.stream.Stream;

Expand Down Expand Up @@ -112,7 +113,8 @@ void testBfsToTargetOut(int concurrency, int delta) {
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
delta,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute().toArray();

assertThat(nodes)
Expand Down
16 changes: 11 additions & 5 deletions algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.paths.traverse.ExitPredicate.Result;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.stream.Stream;

Expand Down Expand Up @@ -107,7 +108,8 @@ void testBfsToTargetOut(int concurrency) {
(s, t, w) -> 1.,
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute().toArray();

// algorithms return mapped ids
Expand All @@ -133,7 +135,8 @@ void testBfsToTargetIn(int concurrency) {
Aggregator.NO_AGGREGATION,
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute().toArray();
assertEquals(7, nodes.length);
}
Expand All @@ -156,7 +159,8 @@ void testBfsMaxDepthOut(int concurrency) {
(s, t, w) -> w + 1.,
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
maxHops - 1
maxHops - 1,
TerminationFlag.RUNNING_TRUE
).compute().toArray();

assertThat(nodes).isEqualTo(
Expand All @@ -172,7 +176,8 @@ void testBfsOnLoopGraph(int concurrency) {
Aggregator.NO_AGGREGATION,
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute();
}

Expand All @@ -189,7 +194,8 @@ void shouldLogProgress(int concurrency) {
Aggregator.NO_AGGREGATION,
new Concurrency(concurrency),
progressTracker,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute();
var messagesInOrder = testLog.getMessages(INFO);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -96,7 +97,8 @@ void testBfsToTargetOut(int concurrency, int delta) {
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER,
delta,
BFS.ALL_DEPTHS_ALLOWED
BFS.ALL_DEPTHS_ALLOWED,
TerminationFlag.RUNNING_TRUE
).compute().toArray();

assertThat(nodes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.gds.applications.algorithms.centrality;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking;
import org.neo4j.gds.betweenness.BetweennessCentrality;
Expand All @@ -41,6 +42,8 @@
import org.neo4j.gds.termination.TerminationFlag;

public class CentralityAlgorithms {
private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery();

private final ProgressTrackerCreator progressTrackerCreator;
private final TerminationFlag terminationFlag;

Expand Down Expand Up @@ -79,7 +82,7 @@ BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentral
terminationFlag
);

return algorithm.compute();
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
}

ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBaseConfig configuration) {
Expand All @@ -103,7 +106,7 @@ ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBa
progressTracker
);

return algorithm.compute();
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
}

DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig configuration) {
Expand All @@ -122,6 +125,6 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf
progressTracker
);

return algorithm.compute();
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.machinery;

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

/**
* I wish this did not exist quite like this; it is where we encapsulate running an algorithm,
* managing termination, and handling (progress tracker) resources.
* Somehow I wish that was encapsulated more naturally, but as you can hear from this use of language,
* the design has not crystallized yet.
* At least nothing here is tied to termination flag.
*/
public class AlgorithmMachinery {
/**
* Runs algorithm.
* Optionally releases progress tracker.
* Exceptionally marks progress tracker state as failed.
*
* @return algorithm result, or an error in the form of an exception
*/
public <RESULT> RESULT runAlgorithmsAndManageProgressTracker(
Algorithm<RESULT> algorithm,
ProgressTracker progressTracker,
boolean shouldReleaseProgressTracker
) {
try {
return algorithm.compute();
} catch (Exception e) {
progressTracker.endSubTaskWithFailure();
throw e;
} finally {
if (shouldReleaseProgressTracker) progressTracker.release();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.machinery;

import org.junit.jupiter.api.Test;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;

class AlgorithmMachineryTest {
@Test
void shouldRunAlgorithm() {
var algorithmMachinery = new AlgorithmMachinery();

var progressTracker = mock(ProgressTracker.class);
var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker(
new RegurgitatingAlgorithm("Hello, world!"),
progressTracker,
false
);

assertThat(result).isEqualTo("Hello, world!");

verifyNoInteractions(progressTracker);
}

@Test
void shouldReleaseProgressTrackerWhenAsked() {
var algorithmMachinery = new AlgorithmMachinery();

var progressTracker = mock(ProgressTracker.class);
var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker(
new RegurgitatingAlgorithm("Dodgers win world series!"),
progressTracker,
true
);

assertThat(result).isEqualTo("Dodgers win world series!");

verify(progressTracker).release();
}

@Test
void shouldMarkProgressTracker() {
var algorithmMachinery = new AlgorithmMachinery();

var progressTracker = mock(ProgressTracker.class);
var exception = new RuntimeException("Whoops!");
try {
algorithmMachinery.runAlgorithmsAndManageProgressTracker(
new FailingAlgorithm(exception),
progressTracker,
false
);
fail();
} catch (Exception e) {
assertThat(e).hasMessage("Whoops!");
}

verify(progressTracker).endSubTaskWithFailure();
}

@Test
void shouldMarkProgressTrackerAndReleaseIt() {
var algorithmMachinery = new AlgorithmMachinery();

var progressTracker = mock(ProgressTracker.class);
var exception = new RuntimeException("Yeah, no...");
try {
algorithmMachinery.runAlgorithmsAndManageProgressTracker(
new FailingAlgorithm(exception),
progressTracker,
true
);
fail();
} catch (Exception e) {
assertThat(e).hasMessage("Yeah, no...");
}

verify(progressTracker).endSubTaskWithFailure();
verify(progressTracker).release();
}
}
Loading

0 comments on commit 6660781

Please sign in to comment.