diff --git a/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java b/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java index 74a664cfef..5d4bff2885 100644 --- a/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java +++ b/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java @@ -23,6 +23,7 @@ import org.neo4j.gds.api.RelationshipIterator; import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.termination.TerminationFlag; import java.util.Arrays; import java.util.concurrent.BlockingQueue; @@ -63,7 +64,7 @@ public class WeightedAllShortestPaths extends MSBFSASPAlgorithm { private volatile boolean outputStreamOpen; - public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Concurrency concurrency) { + public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Concurrency concurrency, TerminationFlag terminationFlag) { super(ProgressTracker.NULL_TRACKER); if (!graph.hasRelationshipProperty()) { throw new UnsupportedOperationException("WeightedAllShortestPaths is not supported on graphs without a weight property"); @@ -74,6 +75,7 @@ public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Co this.executorService = executorService; this.concurrency = concurrency; this.counter = new AtomicInteger(); + this.terminationFlag = terminationFlag; } /** diff --git a/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPath.java b/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPath.java index 93e7aecac8..1ad019963b 100644 --- a/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPath.java +++ b/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPath.java @@ -68,7 +68,8 @@ public class DagLongestPath extends Algorithm { public DagLongestPath( Graph graph, ProgressTracker progressTracker, - Concurrency concurrency + Concurrency concurrency, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; @@ -76,6 +77,7 @@ public DagLongestPath( this.concurrency = concurrency; this.inDegrees = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(this.concurrency)); this.parentsAndDistances = TentativeDistances.distanceAndPredecessors(nodeCount, concurrency, -Double.MIN_VALUE, (a, b) -> Double.compare(a, b) < 0); + this.terminationFlag = terminationFlag; } @Override diff --git a/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPathFactory.java b/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPathFactory.java index 47f3e1425e..d40e2887b9 100644 --- a/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPathFactory.java +++ b/algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPathFactory.java @@ -24,6 +24,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; @@ -33,7 +34,8 @@ public DagLongestPath build(Graph graph, DagLongestPathBaseConfig configuration, return new DagLongestPath( graph, progressTracker, - configuration.concurrency() + configuration.concurrency(), + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSort.java b/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSort.java index 618eff4989..d52d00f59b 100644 --- a/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSort.java +++ b/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSort.java @@ -68,8 +68,8 @@ public TopologicalSort( Graph graph, ProgressTracker progressTracker, Concurrency concurrency, - boolean computeMaxDistanceFromSource - + boolean computeMaxDistanceFromSource, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; @@ -80,6 +80,7 @@ public TopologicalSort( ? Optional.of(HugeAtomicDoubleArray.of(nodeCount, ParallelDoublePageCreator.passThrough(this.concurrency))) : Optional.empty(); this.result = new TopologicalSortResult(nodeCount, longestPathDistances); + this.terminationFlag = terminationFlag; } @Override diff --git a/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortFactory.java b/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortFactory.java index 793d0c9978..58bfc1cbc2 100644 --- a/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortFactory.java +++ b/algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortFactory.java @@ -24,6 +24,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; @@ -34,7 +35,8 @@ public TopologicalSort build(Graph graph, TopologicalSortBaseConfig configuratio graph, progressTracker, configuration.concurrency(), - configuration.computeMaxDistanceFromSource() + configuration.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTree.java b/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTree.java index 4643ed0182..232f9cb414 100644 --- a/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTree.java +++ b/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTree.java @@ -29,6 +29,7 @@ import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue; import org.neo4j.gds.spanningtree.Prim; import org.neo4j.gds.spanningtree.SpanningTree; +import org.neo4j.gds.termination.TerminationFlag; import java.util.function.DoubleUnaryOperator; @@ -53,7 +54,8 @@ public KSpanningTree( DoubleUnaryOperator minMax, long startNodeId, long k, - ProgressTracker progressTracker + ProgressTracker progressTracker, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; @@ -61,6 +63,8 @@ public KSpanningTree( this.startNodeId = startNodeId; this.k = k; + + this.terminationFlag = terminationFlag; } @Override @@ -70,10 +74,10 @@ public SpanningTree compute() { graph, minMax, startNodeId, - progressTracker + progressTracker, + terminationFlag ); - prim.setTerminationFlag(getTerminationFlag()); SpanningTree spanningTree = prim.compute(); var outputTree = growApproach(spanningTree); diff --git a/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTreeAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTreeAlgorithmFactory.java index 71c58d4b3c..5b203c6f44 100644 --- a/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTreeAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/kspanningtree/KSpanningTreeAlgorithmFactory.java @@ -24,6 +24,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.termination.TerminationFlag; public class KSpanningTreeAlgorithmFactory extends GraphAlgorithmFactory { @@ -37,7 +38,8 @@ public KSpanningTree build(Graph graph, KSpanningTreeParameters parameters, Prog parameters.objective(), graph.toMappedNodeId(parameters.sourceNode()), parameters.k(), - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/main/java/org/neo4j/gds/spanningtree/Prim.java b/algo/src/main/java/org/neo4j/gds/spanningtree/Prim.java index b9343b17e3..056722d54a 100644 --- a/algo/src/main/java/org/neo4j/gds/spanningtree/Prim.java +++ b/algo/src/main/java/org/neo4j/gds/spanningtree/Prim.java @@ -25,6 +25,7 @@ import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue; +import org.neo4j.gds.termination.TerminationFlag; import java.util.function.DoubleUnaryOperator; @@ -53,12 +54,14 @@ public Prim( Graph graph, DoubleUnaryOperator minMax, long startNodeId, - ProgressTracker progressTracker + ProgressTracker progressTracker, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; this.minMax = minMax; this.startNodeId = startNodeId; + this.terminationFlag = terminationFlag; } @Override diff --git a/algo/src/main/java/org/neo4j/gds/spanningtree/SpanningTreeAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/spanningtree/SpanningTreeAlgorithmFactory.java index bddb082212..3fde84455a 100644 --- a/algo/src/main/java/org/neo4j/gds/spanningtree/SpanningTreeAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/spanningtree/SpanningTreeAlgorithmFactory.java @@ -25,6 +25,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.termination.TerminationFlag; public class SpanningTreeAlgorithmFactory extends GraphAlgorithmFactory { @@ -37,7 +38,8 @@ public Prim build(Graph graph, SpanningTreeParameters parameters, ProgressTracke graph, parameters.objective(), graph.toMappedNodeId(parameters.sourceNode()), - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java b/algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java index 1a6c7bc6ad..3ed5f1357a 100644 --- a/algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java +++ b/algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java @@ -30,6 +30,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.paths.PathResult; import org.neo4j.gds.paths.dijkstra.PathFindingResult; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.concurrent.ExecutorService; @@ -59,7 +60,8 @@ public ShortestPathsSteinerAlgorithm( Concurrency concurrency, boolean applyRerouting, ExecutorService executorService, - ProgressTracker progressTracker + ProgressTracker progressTracker, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; @@ -73,6 +75,7 @@ public ShortestPathsSteinerAlgorithm( this.binSizeThreshold = SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD; this.examinationQueue = createExaminationQueue(graph, applyRerouting, terminals.size()); this.indexQueue = new LongAdder(); + this.terminationFlag = terminationFlag; } @TestOnly diff --git a/algo/src/main/java/org/neo4j/gds/steiner/SteinerTreeAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/steiner/SteinerTreeAlgorithmFactory.java index a5356d6d8f..1c73837bdf 100644 --- a/algo/src/main/java/org/neo4j/gds/steiner/SteinerTreeAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/steiner/SteinerTreeAlgorithmFactory.java @@ -26,6 +26,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.termination.TerminationFlag; import java.util.ArrayList; import java.util.stream.Collectors; @@ -48,7 +49,8 @@ public ShortestPathsSteinerAlgorithm build( parameters.concurrency(), parameters.applyRerouting(), DefaultPool.INSTANCE, - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java b/algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java index 6ce5f33c97..a9383f9b17 100644 --- a/algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java +++ b/algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java @@ -64,7 +64,8 @@ public static RandomWalk create( int walkBufferSize, Optional randomSeed, ProgressTracker progressTracker, - ExecutorService executorService + ExecutorService executorService, + TerminationFlag terminationFlag ) { if (graph.hasRelationshipProperty()) { EmbeddingUtils.validateRelationshipWeightPropertyValue( @@ -84,7 +85,8 @@ public static RandomWalk create( sourceNodes, walkBufferSize, randomSeed, - progressTracker + progressTracker, + terminationFlag ); } @@ -96,7 +98,8 @@ private RandomWalk( List sourceNodes, int walkBufferSize, Optional maybeRandomSeed, - ProgressTracker progressTracker + ProgressTracker progressTracker, + TerminationFlag terminationFlag ) { super(progressTracker); this.concurrency = concurrency; @@ -107,6 +110,7 @@ private RandomWalk( this.walkParameters = walkParameters; this.sourceNodes = sourceNodes; this.randomSeed = maybeRandomSeed.orElseGet(() -> new Random().nextLong()); + this.terminationFlag = terminationFlag; } @Override diff --git a/algo/src/main/java/org/neo4j/gds/traversal/RandomWalkAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/traversal/RandomWalkAlgorithmFactory.java index 4289cf56fa..2b06b415be 100644 --- a/algo/src/main/java/org/neo4j/gds/traversal/RandomWalkAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/traversal/RandomWalkAlgorithmFactory.java @@ -27,6 +27,7 @@ import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; import org.neo4j.gds.degree.DegreeCentralityFactory; +import org.neo4j.gds.termination.TerminationFlag; import java.util.ArrayList; @@ -50,7 +51,8 @@ public RandomWalk build( configuration.walkBufferSize(), configuration.randomSeed(), progressTracker, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/test/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPathsTest.java b/algo/src/test/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPathsTest.java index 71aa037335..4549d3b38d 100644 --- a/algo/src/test/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPathsTest.java +++ b/algo/src/test/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPathsTest.java @@ -30,6 +30,7 @@ import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.gdl.GdlFactory; +import org.neo4j.gds.termination.TerminationFlag; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -100,7 +101,7 @@ void testResults() { TriConsumer mock = mock(TriConsumer.class); - new WeightedAllShortestPaths(graph, DefaultPool.INSTANCE, new Concurrency(4)) + new WeightedAllShortestPaths(graph, DefaultPool.INSTANCE, new Concurrency(4), TerminationFlag.RUNNING_TRUE) .compute() .forEach(r -> { assertNotEquals(Double.POSITIVE_INFINITY, r.distance); @@ -125,7 +126,7 @@ void shouldThrowIfGraphHasNoRelationshipProperty() { var gdlGraph = GdlFactory.of("(a)-[:r]->(b)").build().getUnion(); UnsupportedOperationException exception = assertThrows(UnsupportedOperationException.class, () -> { - new WeightedAllShortestPaths(gdlGraph, DefaultPool.INSTANCE, new Concurrency(4)); + new WeightedAllShortestPaths(gdlGraph, DefaultPool.INSTANCE, new Concurrency(4), TerminationFlag.RUNNING_TRUE); }); assertTrue(exception.getMessage().contains("not supported")); diff --git a/algo/src/test/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortTest.java b/algo/src/test/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortTest.java index 47b3ecf8c2..521f63dcce 100644 --- a/algo/src/test/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortTest.java +++ b/algo/src/test/java/org/neo4j/gds/dag/topologicalsort/TopologicalSortTest.java @@ -32,6 +32,7 @@ import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.Random; @@ -71,7 +72,8 @@ void shouldSortRight() { basicGraph, ProgressTracker.NULL_TRACKER, CONFIG.concurrency(), - CONFIG.computeMaxDistanceFromSource() + CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); HugeLongArray nodes = result.sortedNodes(); @@ -119,7 +121,8 @@ void allCycleShouldGiveEmptySorting() { allCycleGraph, ProgressTracker.NULL_TRACKER, BASIC_CONFIG.concurrency(), - BASIC_CONFIG.computeMaxDistanceFromSource() + BASIC_CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); HugeLongArray nodes = result.sortedNodes(); @@ -133,7 +136,8 @@ void shouldNotAllocateArraysOnBasicConfig() { allCycleGraph, ProgressTracker.NULL_TRACKER, BASIC_CONFIG.concurrency(), - BASIC_CONFIG.computeMaxDistanceFromSource() + BASIC_CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); @@ -158,7 +162,8 @@ void ShouldExcludeSelfLoops() { TopologicalSort ts = new TopologicalSort(selfLoopGraph, ProgressTracker.NULL_TRACKER, CONFIG.concurrency(), - CONFIG.computeMaxDistanceFromSource() + CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); HugeLongArray nodes = result.sortedNodes(); @@ -241,7 +246,8 @@ void hundredShouldComeLast() { TopologicalSort ts = new TopologicalSort(lastGraph, ProgressTracker.NULL_TRACKER, CONFIG.concurrency(), - CONFIG.computeMaxDistanceFromSource() + CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); HugeLongArray nodes = result.sortedNodes(); @@ -330,7 +336,8 @@ void shouldNotIncludeCycles() { TopologicalSort ts = new TopologicalSort(cyclesGraph, ProgressTracker.NULL_TRACKER, CONFIG.concurrency(), - CONFIG.computeMaxDistanceFromSource() + CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); HugeLongArray nodes = result.sortedNodes(); @@ -372,7 +379,8 @@ void randomShouldContainAllNodesOnDag() { TopologicalSort ts = new TopologicalSort(graph, ProgressTracker.NULL_TRACKER, BASIC_CONFIG.concurrency(), - BASIC_CONFIG.computeMaxDistanceFromSource() + BASIC_CONFIG.computeMaxDistanceFromSource(), + TerminationFlag.RUNNING_TRUE ); TopologicalSortResult result = ts.compute(); assertEquals(100, result.size()); diff --git a/algo/src/test/java/org/neo4j/gds/kspanningtree/KSpanningTreeTest.java b/algo/src/test/java/org/neo4j/gds/kspanningtree/KSpanningTreeTest.java index deaacf7489..943f2e2ea9 100644 --- a/algo/src/test/java/org/neo4j/gds/kspanningtree/KSpanningTreeTest.java +++ b/algo/src/test/java/org/neo4j/gds/kspanningtree/KSpanningTreeTest.java @@ -38,6 +38,7 @@ import org.neo4j.gds.extension.Inject; import org.neo4j.gds.gdl.GdlFactory; import org.neo4j.gds.spanningtree.Prim; +import org.neo4j.gds.termination.TerminationFlag; import java.util.HashSet; @@ -90,7 +91,7 @@ void setUp() { @Test void testMaximumKSpanningTree() { - var spanningTree = new KSpanningTree(graph, Prim.MAX_OPERATOR, a, 2, ProgressTracker.NULL_TRACKER) + var spanningTree = new KSpanningTree(graph, Prim.MAX_OPERATOR, a, 2, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE) .compute(); assertThat(spanningTree).matches(tree -> tree.head(a) == tree.head(b) ^ tree.head(c) == tree.head(d)); @@ -102,7 +103,7 @@ void testMaximumKSpanningTree() { @Test void testMinimumKSpanningTree() { - var spanningTree = new KSpanningTree(graph, Prim.MIN_OPERATOR, a, 2, ProgressTracker.NULL_TRACKER) + var spanningTree = new KSpanningTree(graph, Prim.MIN_OPERATOR, a, 2, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE) .compute(); assertThat(spanningTree).matches(tree -> tree.head(a) == tree.head(d) ^ tree.head(b) == tree.head(c)); @@ -133,7 +134,8 @@ void shouldProduceSingleConnectedTree() { Prim.MIN_OPERATOR, startNode, k, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); // if there are more than k nodes then there is more than one root @@ -173,7 +175,8 @@ void shouldProduceSingleTreeWithKMinusOneEdges(int k, double expected) { Prim.MIN_OPERATOR, startNode, k, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); var counter = new MutableLong(0); @@ -215,7 +218,8 @@ void worstCaseForPruningLeaves() { Prim.MIN_OPERATOR, startNode, 4, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); var counter = new MutableLong(0); @@ -254,7 +258,8 @@ void shouldWorkForComponentSmallerThanK() { Prim.MIN_OPERATOR, startNode, 5, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(spanningTree.effectiveNodeCount()).isEqualTo(4); diff --git a/algo/src/test/java/org/neo4j/gds/spanningtree/PrimTest.java b/algo/src/test/java/org/neo4j/gds/spanningtree/PrimTest.java index fcefb5b60a..83e260e061 100644 --- a/algo/src/test/java/org/neo4j/gds/spanningtree/PrimTest.java +++ b/algo/src/test/java/org/neo4j/gds/spanningtree/PrimTest.java @@ -37,6 +37,7 @@ import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.stream.Stream; @@ -126,7 +127,8 @@ void testMaximum(String nodeId, String parentA, String parentB, String parentC, graph, Prim.MAX_OPERATOR, idFunction.of(nodeId), - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(mst.totalWeight()).isEqualTo(17L); assertTreeIsCorrect(mst, parentA, parentB, parentC, parentD, parentE); @@ -139,7 +141,8 @@ void testMinimum(String nodeId, String parentA, String parentB, String parentC, graph, Prim.MIN_OPERATOR, idFunction.of(nodeId), - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(mst.totalWeight()).isEqualTo(12L); assertTreeIsCorrect(mst, parentA, parentB, parentC, parentD, parentE); diff --git a/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java b/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java index 91494ffffe..fa34a10880 100644 --- a/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java +++ b/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java @@ -34,6 +34,7 @@ import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; import org.neo4j.gds.paths.PathResult; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.stream.Stream; @@ -177,7 +178,8 @@ void shouldWorkCorrectlyWithLineGraph() { new Concurrency(1), false, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ) .compute(); @@ -231,7 +233,8 @@ void shouldWorkIfRevisitsVertices() { new Concurrency(1), false, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); long[] parentArray = new long[]{ @@ -260,7 +263,8 @@ void shouldWorkOnTriangle() { new Concurrency(1), false, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); long[] parentArray = new long[]{ShortestPathsSteinerAlgorithm.ROOT_NODE, a[0], a[1], a[2]}; diff --git a/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmReroutingTest.java b/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmReroutingTest.java index 2cc32b671b..b084744337 100644 --- a/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmReroutingTest.java +++ b/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmReroutingTest.java @@ -35,6 +35,7 @@ import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.stream.Collectors; @@ -192,7 +193,8 @@ void shouldPruneUnusedIfRerouting() { new Concurrency(1), false, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResult.totalCost()).isEqualTo(7.0); assertThat(steinerResult.effectiveNodeCount()).isEqualTo(5); @@ -206,7 +208,8 @@ void shouldPruneUnusedIfRerouting() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResultWithReroute.totalCost()).isEqualTo(4.0); assertThat(steinerResultWithReroute.effectiveNodeCount()).isEqualTo(3); @@ -226,7 +229,8 @@ void shouldPruneUnusedIfReroutingOnInvertedIndex() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResultWithReroute.totalCost()).isEqualTo(4.0); assertThat(steinerResultWithReroute.effectiveNodeCount()).isEqualTo(3); @@ -245,7 +249,8 @@ void rerouteShouldNotCreateLoops() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); var parent = steinerResult.parentArray().toArray(); @@ -270,7 +275,8 @@ void rerouteShouldNotCreateLoopsOnInvertedIndex() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); var parent = steinerResult.parentArray().toArray(); @@ -297,7 +303,8 @@ void shouldWorkForUnreachableAndReachableTerminals() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerTreeResult.effectiveTargetNodesCount()).isEqualTo(2); }); @@ -318,7 +325,8 @@ void shouldWorkIfNoReachableTerminals() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerTreeResult.effectiveTargetNodesCount()).isEqualTo(0); assertThat(steinerTreeResult.effectiveNodeCount()).isEqualTo(1); @@ -352,7 +360,8 @@ void shouldLogProgress() { concurrency, false, DefaultPool.INSTANCE, - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(log.getMessages(TestLog.INFO)) @@ -397,7 +406,8 @@ void shouldLogProgressWithRerouting() { concurrency, applyRerouting, DefaultPool.INSTANCE, - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(log.getMessages(TestLog.INFO)) @@ -448,7 +458,8 @@ void shouldLogProgressWithInverseRerouting() { concurrency, applyRerouting, DefaultPool.INSTANCE, - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(log.getMessages(TestLog.INFO)) @@ -489,7 +500,8 @@ void shouldNotGetOptimalWithoutBetterRerouting() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResultWithReroute.totalCost()).isEqualTo(25.0); assertThat(steinerResultWithReroute.effectiveNodeCount()).isEqualTo(8); @@ -513,7 +525,8 @@ void shouldHandleMultiplePruningsOnSameTreeAndGetBetter() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResultWithReroute.totalCost()).isEqualTo(22.0); assertThat(steinerResultWithReroute.effectiveNodeCount()).isEqualTo(5); @@ -537,7 +550,8 @@ void shouldNotPruneUnprunableNodes() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResultWithReroute.totalCost()).isEqualTo(170.0 - 19); assertThat(steinerResultWithReroute.effectiveNodeCount()).isEqualTo(6); @@ -561,7 +575,8 @@ void shouldTakeAdvantageOfNewSingleParents() { new Concurrency(1), true, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); assertThat(steinerResultWithReroute.totalCost()).isEqualTo(20); diff --git a/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmTest.java b/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmTest.java index d2c641eefc..5266e4e85a 100644 --- a/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmTest.java +++ b/algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmTest.java @@ -28,6 +28,7 @@ import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; @@ -77,7 +78,8 @@ void shouldWorkCorrectly() { new Concurrency(1), false, DefaultPool.INSTANCE, - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute(); var pruned = ShortestPathsSteinerAlgorithm.PRUNED; var rootnode = ShortestPathsSteinerAlgorithm.ROOT_NODE; diff --git a/algo/src/test/java/org/neo4j/gds/traversal/RandomWalkTest.java b/algo/src/test/java/org/neo4j/gds/traversal/RandomWalkTest.java index c1e714b5e5..da7ce0c34c 100644 --- a/algo/src/test/java/org/neo4j/gds/traversal/RandomWalkTest.java +++ b/algo/src/test/java/org/neo4j/gds/traversal/RandomWalkTest.java @@ -42,6 +42,7 @@ import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -95,7 +96,8 @@ void testWithDefaultConfig() { 1000, Optional.empty(), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); List result = randomWalk.compute().collect(Collectors.toList()); @@ -128,7 +130,8 @@ void shouldBeDeterministic() { walkBufferSize, randomSeed, ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ).compute().collect(Collectors.toList()); var secondResult = RandomWalk.create( @@ -139,7 +142,8 @@ void shouldBeDeterministic() { walkBufferSize, randomSeed, ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ).compute().collect(Collectors.toList()); var firstResultAsSet = new TreeSet(Arrays::compare); @@ -164,7 +168,8 @@ void testSampleFromMultipleRelationshipTypes() { 1000, Optional.empty(), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); int expectedNumberOfWalks = walkParameters.walksPerNode() * 3; @@ -201,7 +206,8 @@ void returnFactorShouldMakeWalksIncludeStartNodeMoreOften() { 1000, Optional.of(42L), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); var nodeCounter = new HashMap(); @@ -257,7 +263,8 @@ void largeInOutFactorShouldMakeTheWalkKeepTheSameDistance() { 1000, Optional.of(87L), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); var nodeCounter = new HashMap(); @@ -302,7 +309,8 @@ void shouldRespectRelationshipWeights() { 100, Optional.of(23L), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); var nodeCounter = new HashMap(); @@ -335,7 +343,8 @@ void failOnInvalidRelationshipWeights(double invalidWeight) { 100, Optional.of(23L), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ) ).isInstanceOf(RuntimeException.class) .hasMessage( @@ -368,7 +377,8 @@ void parallelWeighted() { 100, Optional.of(23L), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); assertThat(randomWalk.compute().collect(Collectors.toList())) @@ -397,7 +407,8 @@ void testWithConfiguredOffsetStartNodes() { 1000, Optional.empty(), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); assertThat(randomWalk.compute().collect(Collectors.toList())) @@ -421,7 +432,8 @@ void testSetTerminationFlagAndMultipleRuns() { 1, Optional.empty(), ProgressTracker.NULL_TRACKER, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + TerminationFlag.RUNNING_TRUE ); var stream = randomWalk.compute(); diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java index 1f2d10bdcf..8a10947c2d 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java @@ -180,7 +180,8 @@ SpanningTree kSpanningTree(Graph graph, KSpanningTreeBaseConfig configuration) { parameters.objective(), graph.toMappedNodeId(parameters.sourceNode()), parameters.k(), - progressTracker + progressTracker, + requestScopedDependencies.getTerminationFlag() ); return algorithm.compute(); @@ -195,7 +196,8 @@ PathFindingResult longestPath(Graph graph, AlgoBaseConfig configuration) { var algorithm = new DagLongestPath( graph, progressTracker, - configuration.concurrency() + configuration.concurrency(), + requestScopedDependencies.getTerminationFlag() ); return algorithm.compute(); @@ -218,7 +220,8 @@ Stream randomWalk(Graph graph, RandomWalkBaseConfig configuration) { configuration.walkBufferSize(), configuration.randomSeed(), progressTracker, - DefaultPool.INSTANCE + DefaultPool.INSTANCE, + requestScopedDependencies.getTerminationFlag() ); return algorithm.compute(); @@ -317,7 +320,8 @@ SpanningTree spanningTree(Graph graph, SpanningTreeBaseConfig configuration) { graph, parameters.objective(), graph.toMappedNodeId(parameters.sourceNode()), - progressTracker + progressTracker, + requestScopedDependencies.getTerminationFlag() ); return prim.compute(); @@ -349,7 +353,8 @@ SteinerTreeResult steinerTree(Graph graph, SteinerTreeBaseConfig configuration) parameters.concurrency(), parameters.applyRerouting(), DefaultPool.INSTANCE, - progressTracker + progressTracker, + requestScopedDependencies.getTerminationFlag() ); return steiner.compute(); @@ -368,7 +373,8 @@ public TopologicalSortResult topologicalSort(Graph graph, TopologicalSortBaseCon graph, progressTracker, configuration.concurrency(), - configuration.computeMaxDistanceFromSource() + configuration.computeMaxDistanceFromSource(), + requestScopedDependencies.getTerminationFlag() ); return algorithm.compute(); @@ -379,7 +385,8 @@ private MSBFSASPAlgorithm selectAlgorithm(Graph graph, AllShortestPathsConfig co return new WeightedAllShortestPaths( graph, DefaultPool.INSTANCE, - configuration.concurrency() + configuration.concurrency(), + requestScopedDependencies.getTerminationFlag() ); } else { return new MSBFSAllShortestPaths( diff --git a/proc/path-finding/src/main/java/org/neo4j/gds/paths/all/AllShortestPathsStreamSpec.java b/proc/path-finding/src/main/java/org/neo4j/gds/paths/all/AllShortestPathsStreamSpec.java index 1cde163fb6..567363c160 100644 --- a/proc/path-finding/src/main/java/org/neo4j/gds/paths/all/AllShortestPathsStreamSpec.java +++ b/proc/path-finding/src/main/java/org/neo4j/gds/paths/all/AllShortestPathsStreamSpec.java @@ -35,6 +35,7 @@ import org.neo4j.gds.executor.ExecutionContext; import org.neo4j.gds.executor.GdsCallable; import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; +import org.neo4j.gds.termination.TerminationFlag; import java.util.stream.Stream; @@ -71,7 +72,8 @@ public MSBFSASPAlgorithm build( return new WeightedAllShortestPaths( graph, DefaultPool.INSTANCE, - configuration.concurrency() + configuration.concurrency(), + TerminationFlag.RUNNING_TRUE ); } else { return new MSBFSAllShortestPaths(