Skip to content

Commit

Permalink
fix injection of termination flag into pathfinding algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed May 23, 2024
1 parent d799f37 commit 04e32cf
Show file tree
Hide file tree
Showing 23 changed files with 162 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.gds.api.RelationshipIterator;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Arrays;
import java.util.concurrent.BlockingQueue;
Expand Down Expand Up @@ -63,7 +64,7 @@ public class WeightedAllShortestPaths extends MSBFSASPAlgorithm {

private volatile boolean outputStreamOpen;

public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Concurrency concurrency) {
public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Concurrency concurrency, TerminationFlag terminationFlag) {
super(ProgressTracker.NULL_TRACKER);
if (!graph.hasRelationshipProperty()) {
throw new UnsupportedOperationException("WeightedAllShortestPaths is not supported on graphs without a weight property");
Expand All @@ -74,6 +75,7 @@ public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Co
this.executorService = executorService;
this.concurrency = concurrency;
this.counter = new AtomicInteger();
this.terminationFlag = terminationFlag;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,16 @@ public class DagLongestPath extends Algorithm<PathFindingResult> {
public DagLongestPath(
Graph graph,
ProgressTracker progressTracker,
Concurrency concurrency
Concurrency concurrency,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
this.nodeCount = graph.nodeCount();
this.concurrency = concurrency;
this.inDegrees = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(this.concurrency));
this.parentsAndDistances = TentativeDistances.distanceAndPredecessors(nodeCount, concurrency, -Double.MIN_VALUE, (a, b) -> Double.compare(a, b) < 0);
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;

Expand All @@ -33,7 +34,8 @@ public DagLongestPath build(Graph graph, DagLongestPathBaseConfig configuration,
return new DagLongestPath(
graph,
progressTracker,
configuration.concurrency()
configuration.concurrency(),
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public TopologicalSort(
Graph graph,
ProgressTracker progressTracker,
Concurrency concurrency,
boolean computeMaxDistanceFromSource

boolean computeMaxDistanceFromSource,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand All @@ -80,6 +80,7 @@ public TopologicalSort(
? Optional.of(HugeAtomicDoubleArray.of(nodeCount, ParallelDoublePageCreator.passThrough(this.concurrency)))
: Optional.empty();
this.result = new TopologicalSortResult(nodeCount, longestPathDistances);
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;

Expand All @@ -34,7 +35,8 @@ public TopologicalSort build(Graph graph, TopologicalSortBaseConfig configuratio
graph,
progressTracker,
configuration.concurrency(),
configuration.computeMaxDistanceFromSource()
configuration.computeMaxDistanceFromSource(),
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue;
import org.neo4j.gds.spanningtree.Prim;
import org.neo4j.gds.spanningtree.SpanningTree;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.function.DoubleUnaryOperator;

Expand All @@ -53,14 +54,17 @@ public KSpanningTree(
DoubleUnaryOperator minMax,
long startNodeId,
long k,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
this.minMax = minMax;
this.startNodeId = startNodeId;

this.k = k;

this.terminationFlag = terminationFlag;
}

@Override
Expand All @@ -70,10 +74,10 @@ public SpanningTree compute() {
graph,
minMax,
startNodeId,
progressTracker
progressTracker,
terminationFlag
);

prim.setTerminationFlag(getTerminationFlag());
SpanningTree spanningTree = prim.compute();

var outputTree = growApproach(spanningTree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

public class KSpanningTreeAlgorithmFactory<CONFIG extends KSpanningTreeBaseConfig> extends GraphAlgorithmFactory<KSpanningTree, CONFIG> {

Expand All @@ -37,7 +38,8 @@ public KSpanningTree build(Graph graph, KSpanningTreeParameters parameters, Prog
parameters.objective(),
graph.toMappedNodeId(parameters.sourceNode()),
parameters.k(),
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
5 changes: 4 additions & 1 deletion algo/src/main/java/org/neo4j/gds/spanningtree/Prim.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.function.DoubleUnaryOperator;

Expand Down Expand Up @@ -53,12 +54,14 @@ public Prim(
Graph graph,
DoubleUnaryOperator minMax,
long startNodeId,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
this.minMax = minMax;
this.startNodeId = startNodeId;
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

public class SpanningTreeAlgorithmFactory<CONFIG extends SpanningTreeBaseConfig> extends GraphAlgorithmFactory<Prim, CONFIG> {

Expand All @@ -37,7 +38,8 @@ public Prim build(Graph graph, SpanningTreeParameters parameters, ProgressTracke
graph,
parameters.objective(),
graph.toMappedNodeId(parameters.sourceNode()),
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.paths.PathResult;
import org.neo4j.gds.paths.dijkstra.PathFindingResult;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -59,7 +60,8 @@ public ShortestPathsSteinerAlgorithm(
Concurrency concurrency,
boolean applyRerouting,
ExecutorService executorService,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand All @@ -73,6 +75,7 @@ public ShortestPathsSteinerAlgorithm(
this.binSizeThreshold = SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD;
this.examinationQueue = createExaminationQueue(graph, applyRerouting, terminals.size());
this.indexQueue = new LongAdder();
this.terminationFlag = terminationFlag;
}

@TestOnly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.ArrayList;
import java.util.stream.Collectors;
Expand All @@ -48,7 +49,8 @@ public ShortestPathsSteinerAlgorithm build(
parameters.concurrency(),
parameters.applyRerouting(),
DefaultPool.INSTANCE,
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
10 changes: 7 additions & 3 deletions algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ public static RandomWalk create(
int walkBufferSize,
Optional<Long> randomSeed,
ProgressTracker progressTracker,
ExecutorService executorService
ExecutorService executorService,
TerminationFlag terminationFlag
) {
if (graph.hasRelationshipProperty()) {
EmbeddingUtils.validateRelationshipWeightPropertyValue(
Expand All @@ -84,7 +85,8 @@ public static RandomWalk create(
sourceNodes,
walkBufferSize,
randomSeed,
progressTracker
progressTracker,
terminationFlag
);
}

Expand All @@ -96,7 +98,8 @@ private RandomWalk(
List<Long> sourceNodes,
int walkBufferSize,
Optional<Long> maybeRandomSeed,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.concurrency = concurrency;
Expand All @@ -107,6 +110,7 @@ private RandomWalk(
this.walkParameters = walkParameters;
this.sourceNodes = sourceNodes;
this.randomSeed = maybeRandomSeed.orElseGet(() -> new Random().nextLong());
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.degree.DegreeCentralityFactory;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.ArrayList;

Expand All @@ -50,7 +51,8 @@ public RandomWalk build(
configuration.walkBufferSize(),
configuration.randomSeed(),
progressTracker,
DefaultPool.INSTANCE
DefaultPool.INSTANCE,
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.gdl.GdlFactory;
import org.neo4j.gds.termination.TerminationFlag;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
Expand Down Expand Up @@ -100,7 +101,7 @@ void testResults() {

TriConsumer<Long, Long, Double> mock = mock(TriConsumer.class);

new WeightedAllShortestPaths(graph, DefaultPool.INSTANCE, new Concurrency(4))
new WeightedAllShortestPaths(graph, DefaultPool.INSTANCE, new Concurrency(4), TerminationFlag.RUNNING_TRUE)
.compute()
.forEach(r -> {
assertNotEquals(Double.POSITIVE_INFINITY, r.distance);
Expand All @@ -125,7 +126,7 @@ void shouldThrowIfGraphHasNoRelationshipProperty() {
var gdlGraph = GdlFactory.of("(a)-[:r]->(b)").build().getUnion();

UnsupportedOperationException exception = assertThrows(UnsupportedOperationException.class, () -> {
new WeightedAllShortestPaths(gdlGraph, DefaultPool.INSTANCE, new Concurrency(4));
new WeightedAllShortestPaths(gdlGraph, DefaultPool.INSTANCE, new Concurrency(4), TerminationFlag.RUNNING_TRUE);
});

assertTrue(exception.getMessage().contains("not supported"));
Expand Down
Loading

0 comments on commit 04e32cf

Please sign in to comment.