Skip to content

Commit

Permalink
fix injection of termination flag into centrality algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed May 22, 2024
1 parent 2e5367f commit d799f37
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.neo4j.gds.core.utils.paged.HugeLongArrayStack;
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
Expand All @@ -57,7 +58,8 @@ public BetweennessCentrality(
ForwardTraverser.Factory traverserFactory,
ExecutorService executorService,
Concurrency concurrency,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand All @@ -69,7 +71,7 @@ public BetweennessCentrality(
this.selectionStrategy.init(graph, executorService, concurrency);
this.divisor = graph.schema().isUndirected() ? 2.0 : 1.0;
this.traverserFactory = traverserFactory;

this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;

Expand Down Expand Up @@ -58,7 +59,8 @@ public BetweennessCentrality build(
traverserFactory,
DefaultPool.INSTANCE,
parameters.concurrency(),
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -139,7 +140,8 @@ void sampling(int concurrency, TestGraph graph, int samplingSize, Map<String, Do
ForwardTraverser.Factory.unweighted(),
DefaultPool.INSTANCE,
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
).compute().centralities();

assertEquals(expectedResult.size(), actualResult.size());
Expand All @@ -158,7 +160,8 @@ void noSampling(int concurrency) {
ForwardTraverser.Factory.unweighted(),
DefaultPool.INSTANCE,
new Concurrency(concurrency),
ProgressTracker.NULL_TRACKER
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
).compute().centralities();

assertEquals(5, actualResult.size(), "Expected 5 centrality values");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;

Expand Down Expand Up @@ -80,15 +81,17 @@ void shouldEqualWithUnweightedWhenWeightsAreEqual() {
ForwardTraverser.Factory.weighted(),
DefaultPool.INSTANCE,
new Concurrency(8),
ProgressTracker.NULL_TRACKER
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
);
var algoUnweighted = new BetweennessCentrality(
equallyWeightedGraph,
new RandomDegreeSelectionStrategy(7, Optional.of(42L)),
ForwardTraverser.Factory.unweighted(),
DefaultPool.INSTANCE,
new Concurrency(8),
ProgressTracker.NULL_TRACKER
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
);
var resultWeighted = algoWeighted.compute().centralities();
var resultUnweighted = algoUnweighted.compute().centralities();
Expand All @@ -114,7 +117,8 @@ void shouldComputeWithWeights() {
ForwardTraverser.Factory.weighted(),
DefaultPool.INSTANCE,
new Concurrency(8),
ProgressTracker.NULL_TRACKER
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
);
var result = bc.compute().centralities();
var softAssertions = new SoftAssertions();
Expand Down
1 change: 1 addition & 0 deletions applications/algorithms/centrality/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ dependencies {
implementation project(":logging")
implementation project(":memory-usage")
implementation project(":progress-tracking")
implementation project(":termination")
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@
import org.neo4j.gds.degree.DegreeCentrality;
import org.neo4j.gds.degree.DegreeCentralityConfig;
import org.neo4j.gds.degree.DegreeCentralityResult;
import org.neo4j.gds.termination.TerminationFlag;

public class CentralityAlgorithms {
private final ProgressTrackerCreator progressTrackerCreator;
private final TerminationFlag terminationFlag;

public CentralityAlgorithms(ProgressTrackerCreator progressTrackerCreator) {
public CentralityAlgorithms(ProgressTrackerCreator progressTrackerCreator, TerminationFlag terminationFlag) {
this.progressTrackerCreator = progressTrackerCreator;
this.terminationFlag = terminationFlag;
}

BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentralityBaseConfig configuration) {
Expand All @@ -60,7 +63,10 @@ BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentral
? ForwardTraverser.Factory.weighted()
: ForwardTraverser.Factory.unweighted();

var task = Tasks.leaf(LabelForProgressTracking.BetweennessCentrality.value, samplingSize.orElse(graph.nodeCount()));
var task = Tasks.leaf(
LabelForProgressTracking.BetweennessCentrality.value,
samplingSize.orElse(graph.nodeCount())
);
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);

var algorithm = new BetweennessCentrality(
Expand All @@ -69,7 +75,8 @@ BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentral
traverserFactory,
DefaultPool.INSTANCE,
parameters.concurrency(),
progressTracker
progressTracker,
terminationFlag
);

return algorithm.compute();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ static CentralityApplications create(
ProgressTrackerCreator progressTrackerCreator
) {
var estimation = new CentralityAlgorithmsEstimationModeBusinessFacade(estimationTemplate);
var algorithms = new CentralityAlgorithms(progressTrackerCreator);
var algorithms = new CentralityAlgorithms(progressTrackerCreator, requestScopedDependencies.getTerminationFlag());
var mutateNodePropertyService = new MutateNodePropertyService(log);
var mutateNodeProperty = new MutateNodeProperty(mutateNodePropertyService);
var mutation = new CentralityAlgorithmsMutateModeBusinessFacade(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.applications.ApplicationsFacade;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.catalog.GraphProjectProc;
import org.neo4j.gds.catalog.GraphStreamNodePropertiesProc;
import org.neo4j.gds.core.CypherMapWrapper;
Expand Down Expand Up @@ -62,6 +63,7 @@
import org.neo4j.gds.procedures.algorithms.centrality.CentralityProcedureFacade;
import org.neo4j.gds.procedures.algorithms.configuration.ConfigurationParser;
import org.neo4j.gds.procedures.algorithms.stubs.GenericStub;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.test.TestProc;

import java.util.ArrayList;
Expand Down Expand Up @@ -423,7 +425,17 @@ void shouldEstimateMemory() {
* But let's be honest, there is enough work in front of us that such a fix is lower priority right now.
*/
private static AlgorithmsProcedureFacade createAlgorithmsProcedureFacade() {
var applicationsFacade = ApplicationsFacade.create(null, Optional.empty(), null, null, null, null, null, null);
var applicationsFacade = ApplicationsFacade.create(
null,
Optional.empty(),
null,
null,
null,
null,
RequestScopedDependencies.builder().with(
TerminationFlag.RUNNING_TRUE).build(),
null
);
var configurationParser = new ConfigurationParser(null, null);
var genericStub = new GenericStub(null, null, null, configurationParser, null, null);
var centralityProcedureFacade = CentralityProcedureFacade.create(
Expand Down

0 comments on commit d799f37

Please sign in to comment.