Skip to content

Commit c23175a

Browse files
committed
migrate link prediction configuration
1 parent 4c082c0 commit c23175a

File tree

9 files changed

+172
-95
lines changed

9 files changed

+172
-95
lines changed

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
import org.neo4j.gds.core.ConfigKeyValidation;
2424
import org.neo4j.gds.ml.api.TrainingMethod;
2525
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
26-
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
2726
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
2827
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
2928
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
3029
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
30+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
3131
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
32+
import org.neo4j.procedure.Context;
3233
import org.neo4j.procedure.Description;
3334
import org.neo4j.procedure.Internal;
3435
import org.neo4j.procedure.Name;
@@ -40,24 +41,16 @@
4041
import static org.neo4j.procedure.Mode.READ;
4142

4243
public class LinkPredictionPipelineAddTrainerMethodProcs extends BaseProc {
44+
@Context
45+
public GraphDataScienceProcedures facade;
4346

4447
@Procedure(name = "gds.beta.pipeline.linkPrediction.addLogisticRegression", mode = READ)
4548
@Description("Add a logistic regression configuration to the parameter space of the link prediction train pipeline.")
4649
public Stream<PipelineInfoResult> addLogisticRegression(
4750
@Name("pipelineName") String pipelineName,
4851
@Name(value = "config", defaultValue = "{}") Map<String, Object> logisticRegressionClassifierConfig
4952
) {
50-
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, LinkPredictionTrainingPipeline.class);
51-
52-
var allowedKeys = LogisticRegressionTrainConfig.DEFAULT.configKeys();
53-
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, logisticRegressionClassifierConfig.keySet());
54-
55-
var tunableTrainerConfig = TunableTrainerConfig.of(logisticRegressionClassifierConfig, TrainingMethod.LogisticRegression);
56-
pipeline.addTrainerConfig(
57-
tunableTrainerConfig
58-
);
59-
60-
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
53+
return facade.pipelines().linkPrediction().addLogisticRegression(pipelineName, logisticRegressionClassifierConfig);
6154
}
6255

6356
@Procedure(name = "gds.beta.pipeline.linkPrediction.addRandomForest", mode = READ)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.pipelines;
21+
22+
import org.neo4j.gds.api.User;
23+
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
24+
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
25+
26+
import java.util.function.BiConsumer;
27+
import java.util.function.BiFunction;
28+
import java.util.function.Function;
29+
import java.util.function.Supplier;
30+
import java.util.stream.Stream;
31+
32+
class Configurer {
33+
private final PipelineRepository pipelineRepository;
34+
private final User user;
35+
36+
Configurer(PipelineRepository pipelineRepository, User user) {
37+
this.pipelineRepository = pipelineRepository;
38+
this.user = user;
39+
}
40+
41+
/**
42+
* Some dull scaffolding
43+
*/
44+
<CONFIGURATION> Stream<PipelineInfoResult> configureLinkPredictionTrainingPipeline(
45+
String pipelineNameAsString,
46+
Supplier<CONFIGURATION> configurationSupplier,
47+
BiConsumer<LinkPredictionTrainingPipeline, CONFIGURATION> action
48+
) {
49+
return configure(
50+
pipelineNameAsString,
51+
pipelineName -> pipelineRepository.getLinkPredictionTrainingPipeline(user, pipelineName),
52+
configurationSupplier,
53+
action,
54+
PipelineInfoResult::create
55+
);
56+
}
57+
58+
/**
59+
* Some more dull scaffolding
60+
*/
61+
<CONFIGURATION> Stream<NodePipelineInfoResult> configureNodeClassificationTrainingPipeline(
62+
String pipelineNameAsString,
63+
Supplier<CONFIGURATION> configurationSupplier,
64+
BiConsumer<NodeClassificationTrainingPipeline, CONFIGURATION> action
65+
) {
66+
return configure(
67+
pipelineNameAsString,
68+
pipelineName -> pipelineRepository.getNodeClassificationTrainingPipeline(user, pipelineName),
69+
configurationSupplier,
70+
action,
71+
NodePipelineInfoResult::create
72+
);
73+
}
74+
75+
/**
76+
* Some dull and very generic scaffolding
77+
*/
78+
private <CONFIGURATION, PIPELINE, RESULT> Stream<RESULT> configure(
79+
String pipelineNameAsString,
80+
Function<PipelineName, PIPELINE> pipelineSupplier,
81+
Supplier<CONFIGURATION> configurationSupplier,
82+
BiConsumer<PIPELINE, CONFIGURATION> action,
83+
BiFunction<PipelineName, PIPELINE, RESULT> resultRenderer
84+
) {
85+
var pipelineName = PipelineName.parse(pipelineNameAsString);
86+
var pipeline = pipelineSupplier.apply(pipelineName);
87+
88+
var configuration = configurationSupplier.get();
89+
90+
action.accept(pipeline, configuration);
91+
92+
var r = resultRenderer.apply(pipelineName, pipeline);
93+
94+
return Stream.of(r);
95+
}
96+
}

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,38 @@
1919
*/
2020
package org.neo4j.gds.procedures.pipelines;
2121

22+
import org.neo4j.gds.api.User;
23+
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
24+
2225
import java.util.Map;
2326
import java.util.stream.Stream;
2427

25-
public class LinkPredictionFacade {
28+
public final class LinkPredictionFacade {
29+
private final Configurer configurer;
30+
2631
private final PipelineConfigurationParser pipelineConfigurationParser;
2732
private final PipelineApplications pipelineApplications;
2833

29-
LinkPredictionFacade(
30-
PipelineConfigurationParser pipelineConfigurationParser,
34+
private LinkPredictionFacade(
35+
Configurer configurer, PipelineConfigurationParser pipelineConfigurationParser,
3136
PipelineApplications pipelineApplications
3237
) {
38+
this.configurer = configurer;
3339
this.pipelineConfigurationParser = pipelineConfigurationParser;
3440
this.pipelineApplications = pipelineApplications;
3541
}
3642

43+
static LinkPredictionFacade create(
44+
User user,
45+
PipelineConfigurationParser pipelineConfigurationParser,
46+
PipelineApplications pipelineApplications,
47+
PipelineRepository pipelineRepository
48+
) {
49+
var configurer = new Configurer(pipelineRepository, user);
50+
51+
return new LinkPredictionFacade(configurer, pipelineConfigurationParser, pipelineApplications);
52+
}
53+
3754
public Stream<PipelineInfoResult> addFeature(
3855
String pipelineNameAsString,
3956
String featureType,
@@ -44,11 +61,19 @@ public Stream<PipelineInfoResult> addFeature(
4461

4562
var pipeline = pipelineApplications.addFeature(pipelineName, featureType, configuration);
4663

47-
var result = PipelineInfoResult.create(pipelineName.value, pipeline);
64+
var result = PipelineInfoResult.create(pipelineName, pipeline);
4865

4966
return Stream.of(result);
5067
}
5168

69+
public Stream<PipelineInfoResult> addLogisticRegression(String pipelineName, Map<String, Object> configuration) {
70+
return configurer.configureLinkPredictionTrainingPipeline(
71+
pipelineName,
72+
() -> pipelineConfigurationParser.parseLogisticRegressionTrainerConfig(configuration),
73+
TrainingPipeline::addTrainerConfig
74+
);
75+
}
76+
5277
public Stream<PipelineInfoResult> addNodeProperty(
5378
String pipelineNameAsString,
5479
String taskName,
@@ -62,7 +87,7 @@ public Stream<PipelineInfoResult> addNodeProperty(
6287
procedureConfig
6388
);
6489

65-
var result = PipelineInfoResult.create(pipelineName.value, pipeline);
90+
var result = PipelineInfoResult.create(pipelineName, pipeline);
6691

6792
return Stream.of(result);
6893
}

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacade.java

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,29 @@
2424
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
2525
import org.neo4j.gds.core.model.ModelCatalog;
2626
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
27+
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
2728
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep;
2829
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
2930

3031
import java.util.ArrayList;
3132
import java.util.List;
3233
import java.util.Map;
33-
import java.util.function.BiFunction;
34-
import java.util.function.Supplier;
3534
import java.util.stream.Stream;
3635

3736
public final class NodeClassificationFacade {
37+
private final Configurer configurer;
3838
private final NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor;
39+
3940
private final PipelineConfigurationParser pipelineConfigurationParser;
4041
private final PipelineApplications pipelineApplications;
4142

4243
NodeClassificationFacade(
44+
Configurer configurer,
4345
NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor,
4446
PipelineConfigurationParser pipelineConfigurationParser,
4547
PipelineApplications pipelineApplications
4648
) {
49+
this.configurer = configurer;
4750
this.nodeClassificationPredictConfigPreProcessor = nodeClassificationPredictConfigPreProcessor;
4851
this.pipelineConfigurationParser = pipelineConfigurationParser;
4952
this.pipelineApplications = pipelineApplications;
@@ -53,14 +56,17 @@ static NodeClassificationFacade create(
5356
ModelCatalog modelCatalog,
5457
User user,
5558
PipelineConfigurationParser pipelineConfigurationParser,
56-
PipelineApplications pipelineApplications
59+
PipelineApplications pipelineApplications,
60+
PipelineRepository pipelineRepository
5761
) {
62+
var configurer = new Configurer(pipelineRepository, user);
5863
var nodeClassificationPredictConfigPreProcessor = new NodeClassificationPredictConfigPreProcessor(
5964
modelCatalog,
6065
user
6166
);
6267

6368
return new NodeClassificationFacade(
69+
configurer,
6470
nodeClassificationPredictConfigPreProcessor,
6571
pipelineConfigurationParser,
6672
pipelineApplications
@@ -71,18 +77,18 @@ public Stream<NodePipelineInfoResult> addLogisticRegression(
7177
String pipelineName,
7278
Map<String, Object> configuration
7379
) {
74-
return configure(
80+
return configurer.configureNodeClassificationTrainingPipeline(
7581
pipelineName,
7682
() -> pipelineConfigurationParser.parseLogisticRegressionTrainerConfig(configuration),
77-
pipelineApplications::addTrainerConfiguration
83+
TrainingPipeline::addTrainerConfig
7884
);
7985
}
8086

8187
public Stream<NodePipelineInfoResult> addMLP(String pipelineName, Map<String, Object> configuration) {
82-
return configure(
88+
return configurer.configureNodeClassificationTrainingPipeline(
8389
pipelineName,
8490
() -> pipelineConfigurationParser.parseMLPClassifierTrainConfig(configuration),
85-
pipelineApplications::addTrainerConfiguration
91+
TrainingPipeline::addTrainerConfig
8692
);
8793
}
8894

@@ -105,28 +111,26 @@ public Stream<NodePipelineInfoResult> addNodeProperty(
105111
}
106112

107113
public Stream<NodePipelineInfoResult> addRandomForest(String pipelineName, Map<String, Object> configuration) {
108-
return configure(
114+
return configurer.configureNodeClassificationTrainingPipeline(
109115
pipelineName,
110-
() -> pipelineConfigurationParser.parseRandomForestClassifierTrainerConfig(
111-
configuration),
112-
pipelineApplications::addTrainerConfiguration
116+
() -> pipelineConfigurationParser.parseRandomForestClassifierTrainerConfig(configuration),
117+
TrainingPipeline::addTrainerConfig
113118
);
114119
}
115120

116121
public Stream<NodePipelineInfoResult> configureAutoTuning(String pipelineName, Map<String, Object> configuration) {
117-
return configure(
122+
return configurer.configureNodeClassificationTrainingPipeline(
118123
pipelineName,
119124
() -> pipelineConfigurationParser.parseAutoTuningConfig(configuration),
120-
pipelineApplications::configureAutoTuning
125+
TrainingPipeline::setAutoTuningConfig
121126
);
122127
}
123128

124129
public Stream<NodePipelineInfoResult> configureSplit(String pipelineName, Map<String, Object> configuration) {
125-
return configure(
130+
return configurer.configureNodeClassificationTrainingPipeline(
126131
pipelineName,
127-
() -> pipelineConfigurationParser.parseNodePropertyPredictionSplitConfig(
128-
configuration),
129-
pipelineApplications::configureSplit
132+
() -> pipelineConfigurationParser.parseNodePropertyPredictionSplitConfig(configuration),
133+
NodeClassificationTrainingPipeline::setSplitConfig
130134
);
131135
}
132136

@@ -280,22 +284,6 @@ public Stream<MemoryEstimateResult> writeEstimate(
280284
return Stream.of(result);
281285
}
282286

283-
private <CONFIGURATION> Stream<NodePipelineInfoResult> configure(
284-
String pipelineNameAsString,
285-
Supplier<CONFIGURATION> configurationSupplier,
286-
BiFunction<PipelineName, CONFIGURATION, NodeClassificationTrainingPipeline> configurationAction
287-
) {
288-
var pipelineName = PipelineName.parse(pipelineNameAsString);
289-
290-
var configuration = configurationSupplier.get();
291-
292-
var pipeline = configurationAction.apply(pipelineName, configuration);
293-
294-
var result = NodePipelineInfoResult.create(pipelineName, pipeline);
295-
296-
return Stream.of(result);
297-
}
298-
299287
private List<NodeFeatureStep> parseNodeProperties(Object nodeProperties) {
300288
if (nodeProperties instanceof String) return List.of(NodeFeatureStep.of((String) nodeProperties));
301289

0 commit comments

Comments
 (0)