Skip to content

Commit 84ebe66

Browse files
committed
migrate node classification predict estimate
1 parent 2b11fb8 commit 84ebe66

File tree

35 files changed

+444
-345
lines changed

35 files changed

+444
-345
lines changed

proc/community/src/integrationTest/java/org/neo4j/gds/labelpropagation/LabelPropagationMutateProcTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.neo4j.gds.metrics.procedures.DeprecatedProceduresMetricService;
7474
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
7575
import org.neo4j.gds.procedures.GraphDataScienceProceduresBuilder;
76+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
7677
import org.neo4j.gds.procedures.algorithms.community.CommunityProcedureFacade;
7778
import org.neo4j.gds.procedures.algorithms.configuration.ConfigurationParser;
7879
import org.neo4j.gds.procedures.algorithms.configuration.UserSpecificConfigurationParser;
@@ -502,11 +503,7 @@ private GraphDataScienceProcedures constructFacade() {
502503

503504
var configurationParser = new UserSpecificConfigurationParser(new ConfigurationParser(DefaultsConfiguration.Instance, LimitsConfiguration.Instance),requestScopedDependencies.getUser());
504505

505-
var genericStub = GenericStub.create(
506-
graphStoreCatalogService,
507-
configurationParser,
508-
requestScopedDependencies
509-
);
506+
var genericStub = new GenericStub(configurationParser, null);
510507
var applicationsFacade = ApplicationsFacade.create(
511508
logMock,
512509
null,
@@ -537,7 +534,7 @@ private GraphDataScienceProcedures constructFacade() {
537534
);
538535

539536
return new GraphDataScienceProceduresBuilder(Log.noOpLog())
540-
.with(communityProcedureFacade)
537+
.with(new AlgorithmsProcedureFacade(null, communityProcedureFacade, null, null, null, null, null))
541538
.with(DeprecatedProceduresMetricService.PASSTHROUGH)
542539
.build();
543540
}

proc/community/src/integrationTest/java/org/neo4j/gds/modularityoptimization/ModularityOptimizationMutateProcTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.neo4j.gds.metrics.projections.ProjectionMetricsService;
7474
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
7575
import org.neo4j.gds.procedures.GraphDataScienceProceduresBuilder;
76+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
7677
import org.neo4j.gds.procedures.algorithms.community.CommunityProcedureFacade;
7778
import org.neo4j.gds.procedures.algorithms.configuration.ConfigurationParser;
7879
import org.neo4j.gds.procedures.algorithms.configuration.UserSpecificConfigurationParser;
@@ -509,11 +510,7 @@ private GraphDataScienceProcedures createFacade() {
509510

510511
var configurationParser = new UserSpecificConfigurationParser(new ConfigurationParser(DefaultsConfiguration.Instance, LimitsConfiguration.Instance),requestScopedDependencies.getUser());
511512

512-
var genericStub = GenericStub.create(
513-
graphStoreCatalogService,
514-
configurationParser,
515-
requestScopedDependencies
516-
);
513+
var genericStub = new GenericStub(configurationParser, null);
517514
var applicationsFacade = ApplicationsFacade.create(
518515
logMock,
519516
null,
@@ -544,7 +541,7 @@ private GraphDataScienceProcedures createFacade() {
544541
);
545542

546543
return new GraphDataScienceProceduresBuilder(logMock)
547-
.with(communityProcedureFacade)
544+
.with(new AlgorithmsProcedureFacade(null, communityProcedureFacade, null, null, null, null, null))
548545
.with(DeprecatedProceduresMetricService.PASSTHROUGH)
549546
.build();
550547
}

proc/community/src/integrationTest/java/org/neo4j/gds/wcc/WccMutateProcTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
import org.neo4j.gds.metrics.projections.ProjectionMetricsService;
7575
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
7676
import org.neo4j.gds.procedures.GraphDataScienceProceduresBuilder;
77+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
7778
import org.neo4j.gds.procedures.algorithms.community.CommunityProcedureFacade;
7879
import org.neo4j.gds.procedures.algorithms.community.WccMutateResult;
7980
import org.neo4j.gds.procedures.algorithms.configuration.ConfigurationParser;
@@ -605,11 +606,7 @@ private GraphDataScienceProcedures constructGraphDataScienceProcedures() {
605606

606607
var configurationParser = new UserSpecificConfigurationParser(new ConfigurationParser(DefaultsConfiguration.Instance, LimitsConfiguration.Instance),requestScopedDependencies.getUser());
607608

608-
var genericStub = GenericStub.create(
609-
graphStoreCatalogService,
610-
configurationParser,
611-
requestScopedDependencies
612-
);
609+
var genericStub = new GenericStub(configurationParser, null);
613610
var communityProcedureFacade = CommunityProcedureFacade.create(
614611
genericStub,
615612
applicationsFacade,
@@ -622,7 +619,7 @@ private GraphDataScienceProcedures constructGraphDataScienceProcedures() {
622619
);
623620

624621
return new GraphDataScienceProceduresBuilder(Log.noOpLog())
625-
.with(communityProcedureFacade)
622+
.with(new AlgorithmsProcedureFacade(null, communityProcedureFacade, null, null, null, null, null))
626623
.with(DeprecatedProceduresMetricService.PASSTHROUGH)
627624
.build();
628625
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/NodePropertyPredictPipelineFilterUtil.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.gds.core.model.ModelCatalog;
2525
import org.neo4j.gds.ml.models.BaseModelData;
2626
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPipelineBaseTrainConfig;
27+
import org.neo4j.gds.procedures.pipelines.NodePropertyPredictPipelineBaseConfig;
2728

2829
import java.util.List;
2930

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineMutateProc.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
package org.neo4j.gds.ml.pipeline.node.classification.predict;
2121

2222
import org.neo4j.gds.BaseProc;
23+
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
2324
import org.neo4j.gds.core.model.ModelCatalog;
2425
import org.neo4j.gds.executor.ExecutionContext;
25-
import org.neo4j.gds.executor.MemoryEstimationExecutor;
2626
import org.neo4j.gds.executor.ProcedureExecutor;
2727
import org.neo4j.gds.ml.pipeline.node.PredictMutateResult;
28-
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
28+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
2929
import org.neo4j.procedure.Context;
3030
import org.neo4j.procedure.Description;
3131
import org.neo4j.procedure.Mode;
@@ -43,6 +43,9 @@ public class NodeClassificationPipelineMutateProc extends BaseProc {
4343
@Context
4444
public ModelCatalog internalModelCatalog;
4545

46+
@Context
47+
public GraphDataScienceProcedures facade;
48+
4649
@Procedure(name = "gds.beta.pipeline.nodeClassification.predict.mutate", mode = Mode.READ)
4750
@Description(PREDICT_DESCRIPTION)
4851
public Stream<PredictMutateResult> mutate(
@@ -59,15 +62,10 @@ public Stream<PredictMutateResult> mutate(
5962
@Procedure(name = "gds.beta.pipeline.nodeClassification.predict.mutate.estimate", mode = Mode.READ)
6063
@Description(ESTIMATE_PREDICT_DESCRIPTION)
6164
public Stream<MemoryEstimateResult> estimate(
62-
@Name(value = "graphName") Object graphName,
63-
@Name(value = "configuration") Map<String, Object> configuration
65+
@Name(value = "graphName") Object graphNameOrConfiguration,
66+
@Name(value = "configuration") Map<String, Object> algorithmConfiguration
6467
) {
65-
preparePipelineConfig(graphName, configuration);
66-
return new MemoryEstimationExecutor<>(
67-
new NodeClassificationPipelineMutateSpec(),
68-
executionContext(),
69-
transactionContext()
70-
).computeEstimate(graphName, configuration);
68+
return facade.pipelines().nodeClassificationMutateEstimate(graphNameOrConfiguration, algorithmConfiguration);
7169
}
7270

7371
@Override

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineMutateSpec.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import org.neo4j.gds.executor.GdsCallable;
2929
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
3030
import org.neo4j.gds.ml.pipeline.node.PredictMutateResult;
31+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictConfigPreProcessor;
32+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineMutateConfig;
3133
import org.neo4j.gds.result.AbstractResultBuilder;
3234

3335
import java.util.Map;

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.executor.ExecutionContext;
2929
import org.neo4j.gds.executor.GdsCallable;
3030
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
31+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictConfigPreProcessor;
3132

3233
import java.util.List;
3334
import java.util.Map;

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictNewMutateConfigFn.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.neo4j.gds.core.CypherMapWrapper;
2323
import org.neo4j.gds.core.model.ModelCatalog;
2424
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
25+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineMutateConfig;
26+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineMutateConfigImpl;
2527

2628
import static org.neo4j.gds.ml.pipeline.node.NodePropertyPredictPipelineFilterUtil.generatePredictPipelineFilter;
2729

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.neo4j.gds.ml.models.Classifier;
3333
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo;
3434
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
35+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineBaseConfig;
3536

3637
public class NodeClassificationPredictPipelineAlgorithmFactory
3738
<CONFIG extends NodeClassificationPredictPipelineBaseConfig>

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineExecutor.java

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.neo4j.gds.core.model.Model;
2525
import org.neo4j.gds.core.model.ModelCatalog;
2626
import org.neo4j.gds.mem.MemoryEstimation;
27-
import org.neo4j.gds.mem.MemoryEstimations;
2827
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2928
import org.neo4j.gds.core.utils.progress.tasks.Task;
3029
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
@@ -42,18 +41,18 @@
4241
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo;
4342
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
4443
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
44+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineBaseConfig;
45+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineConstants;
46+
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineEstimator;
4547
import org.neo4j.gds.utils.StringJoining;
4648

47-
import java.util.List;
48-
4949
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
5050

5151
public class NodeClassificationPredictPipelineExecutor extends PredictPipelineExecutor<
5252
NodeClassificationPredictPipelineBaseConfig,
5353
NodePropertyPredictPipeline,
5454
NodeClassificationPipelineResult
5555
> {
56-
private static final int MIN_BATCH_SIZE = 100;
5756
private final Classifier.ClassifierData modelData;
5857
private final LocalIdMap classIdMap;
5958
private final PipelineGraphFilter predictGraphFilter;
@@ -90,35 +89,9 @@ public static MemoryEstimation estimate(
9089
ModelCatalog modelCatalog,
9190
AlgorithmsProcedureFacade algorithmsProcedureFacade
9291
) {
93-
var pipeline = model.customInfo().pipeline();
94-
var classCount = model.customInfo().classes().size();
95-
var featureCount = model.data().featureDimension();
96-
97-
var combinedNodeLabels = configuration.targetNodeLabels().isEmpty() ? model.trainConfig().targetNodeLabels() : configuration.targetNodeLabels();
98-
var combinedRelationshipTypes = configuration.relationshipTypes().isEmpty() ? model.trainConfig().relationshipTypes() : configuration.relationshipTypes();
99-
100-
MemoryEstimation nodePropertyStepEstimation = NodePropertyStepExecutor.estimateNodePropertySteps(
101-
algorithmsProcedureFacade,
102-
modelCatalog,
103-
configuration.username(),
104-
pipeline.nodePropertySteps(),
105-
combinedNodeLabels,
106-
combinedRelationshipTypes
107-
);
108-
109-
var predictionEstimation = MemoryEstimations.builder().add(
110-
"Pipeline Predict",
111-
NodeClassificationPredict.memoryEstimationWithDerivedBatchSize(
112-
model.data().trainerMethod(),
113-
configuration.includePredictedProbabilities(),
114-
MIN_BATCH_SIZE,
115-
featureCount,
116-
classCount,
117-
false
118-
)
119-
).build();
92+
var estimator = new NodeClassificationPredictPipelineEstimator(modelCatalog, algorithmsProcedureFacade);
12093

121-
return MemoryEstimations.maxEstimation(List.of(nodePropertyStepEstimation, predictionEstimation));
94+
return estimator.estimate(model, configuration);
12295
}
12396

12497
@Override
@@ -143,7 +116,7 @@ protected NodeClassificationPipelineResult execute() {
143116
var nodeClassificationResult = new NodeClassificationPredict(
144117
ClassifierFactory.create(modelData),
145118
features,
146-
MIN_BATCH_SIZE,
119+
NodeClassificationPredictPipelineConstants.MIN_BATCH_SIZE,
147120
config.concurrency(),
148121
config.includePredictedProbabilities(),
149122
progressTracker,

0 commit comments

Comments
 (0)