From d799f378a3fae5c83104350954dfb12bbf3c0d37 Mon Sep 17 00:00:00 2001 From: Lasse Westh-Nielsen Date: Wed, 22 May 2024 15:46:43 +0200 Subject: [PATCH] fix injection of termination flag into centrality algorithms --- .../gds/betweenness/BetweennessCentrality.java | 6 ++++-- .../betweenness/BetweennessCentralityFactory.java | 4 +++- .../gds/betweenness/BetweennessCentralityTest.java | 7 +++++-- .../WeightedBetweennessCentralityTest.java | 10 +++++++--- applications/algorithms/centrality/build.gradle | 1 + .../centrality/CentralityAlgorithms.java | 13 ++++++++++--- .../gds/applications/CentralityApplications.java | 2 +- ...eClassificationPredictPipelineExecutorTest.java | 14 +++++++++++++- 8 files changed, 44 insertions(+), 13 deletions(-) diff --git a/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentrality.java b/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentrality.java index d7cd319550..d032c05bf3 100644 --- a/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentrality.java +++ b/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentrality.java @@ -33,6 +33,7 @@ import org.neo4j.gds.core.utils.paged.HugeLongArrayStack; import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.termination.TerminationFlag; import java.util.concurrent.ExecutorService; import java.util.function.Consumer; @@ -57,7 +58,8 @@ public BetweennessCentrality( ForwardTraverser.Factory traverserFactory, ExecutorService executorService, Concurrency concurrency, - ProgressTracker progressTracker + ProgressTracker progressTracker, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; @@ -69,7 +71,7 @@ public BetweennessCentrality( this.selectionStrategy.init(graph, executorService, concurrency); this.divisor = graph.schema().isUndirected() ? 2.0 : 1.0; this.traverserFactory = traverserFactory; - + this.terminationFlag = terminationFlag; } @Override diff --git a/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentralityFactory.java b/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentralityFactory.java index 688d9da354..b115e84463 100644 --- a/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentralityFactory.java +++ b/algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentralityFactory.java @@ -26,6 +26,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.termination.TerminationFlag; import java.util.Optional; @@ -58,7 +59,8 @@ public BetweennessCentrality build( traverserFactory, DefaultPool.INSTANCE, parameters.concurrency(), - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/test/java/org/neo4j/gds/betweenness/BetweennessCentralityTest.java b/algo/src/test/java/org/neo4j/gds/betweenness/BetweennessCentralityTest.java index 59c653ec4b..2caf89e665 100644 --- a/algo/src/test/java/org/neo4j/gds/betweenness/BetweennessCentralityTest.java +++ b/algo/src/test/java/org/neo4j/gds/betweenness/BetweennessCentralityTest.java @@ -33,6 +33,7 @@ import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.Map; import java.util.Optional; @@ -139,7 +140,8 @@ void sampling(int concurrency, TestGraph graph, int samplingSize, Map