2424import org .neo4j .gds .applications .algorithms .machinery .MemoryEstimateResult ;
2525import org .neo4j .gds .core .model .ModelCatalog ;
2626import org .neo4j .gds .ml .pipeline .PipelineCompanion ;
27+ import org .neo4j .gds .ml .pipeline .TrainingPipeline ;
2728import org .neo4j .gds .ml .pipeline .nodePipeline .NodeFeatureStep ;
2829import org .neo4j .gds .ml .pipeline .nodePipeline .classification .NodeClassificationTrainingPipeline ;
2930
3031import java .util .ArrayList ;
3132import java .util .List ;
3233import java .util .Map ;
33- import java .util .function .BiFunction ;
34- import java .util .function .Supplier ;
3534import java .util .stream .Stream ;
3635
3736public final class NodeClassificationFacade {
37+ private final Configurer configurer ;
3838 private final NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor ;
39+
3940 private final PipelineConfigurationParser pipelineConfigurationParser ;
4041 private final PipelineApplications pipelineApplications ;
4142
4243 NodeClassificationFacade (
44+ Configurer configurer ,
4345 NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor ,
4446 PipelineConfigurationParser pipelineConfigurationParser ,
4547 PipelineApplications pipelineApplications
4648 ) {
49+ this .configurer = configurer ;
4750 this .nodeClassificationPredictConfigPreProcessor = nodeClassificationPredictConfigPreProcessor ;
4851 this .pipelineConfigurationParser = pipelineConfigurationParser ;
4952 this .pipelineApplications = pipelineApplications ;
@@ -53,14 +56,17 @@ static NodeClassificationFacade create(
5356 ModelCatalog modelCatalog ,
5457 User user ,
5558 PipelineConfigurationParser pipelineConfigurationParser ,
56- PipelineApplications pipelineApplications
59+ PipelineApplications pipelineApplications ,
60+ PipelineRepository pipelineRepository
5761 ) {
62+ var configurer = new Configurer (pipelineRepository , user );
5863 var nodeClassificationPredictConfigPreProcessor = new NodeClassificationPredictConfigPreProcessor (
5964 modelCatalog ,
6065 user
6166 );
6267
6368 return new NodeClassificationFacade (
69+ configurer ,
6470 nodeClassificationPredictConfigPreProcessor ,
6571 pipelineConfigurationParser ,
6672 pipelineApplications
@@ -71,18 +77,18 @@ public Stream<NodePipelineInfoResult> addLogisticRegression(
7177 String pipelineName ,
7278 Map <String , Object > configuration
7379 ) {
74- return configure (
80+ return configurer . configureNodeClassificationTrainingPipeline (
7581 pipelineName ,
7682 () -> pipelineConfigurationParser .parseLogisticRegressionTrainerConfig (configuration ),
77- pipelineApplications :: addTrainerConfiguration
83+ TrainingPipeline :: addTrainerConfig
7884 );
7985 }
8086
8187 public Stream <NodePipelineInfoResult > addMLP (String pipelineName , Map <String , Object > configuration ) {
82- return configure (
88+ return configurer . configureNodeClassificationTrainingPipeline (
8389 pipelineName ,
8490 () -> pipelineConfigurationParser .parseMLPClassifierTrainConfig (configuration ),
85- pipelineApplications :: addTrainerConfiguration
91+ TrainingPipeline :: addTrainerConfig
8692 );
8793 }
8894
@@ -105,28 +111,26 @@ public Stream<NodePipelineInfoResult> addNodeProperty(
105111 }
106112
107113 public Stream <NodePipelineInfoResult > addRandomForest (String pipelineName , Map <String , Object > configuration ) {
108- return configure (
114+ return configurer . configureNodeClassificationTrainingPipeline (
109115 pipelineName ,
110- () -> pipelineConfigurationParser .parseRandomForestClassifierTrainerConfig (
111- configuration ),
112- pipelineApplications ::addTrainerConfiguration
116+ () -> pipelineConfigurationParser .parseRandomForestClassifierTrainerConfig (configuration ),
117+ TrainingPipeline ::addTrainerConfig
113118 );
114119 }
115120
116121 public Stream <NodePipelineInfoResult > configureAutoTuning (String pipelineName , Map <String , Object > configuration ) {
117- return configure (
122+ return configurer . configureNodeClassificationTrainingPipeline (
118123 pipelineName ,
119124 () -> pipelineConfigurationParser .parseAutoTuningConfig (configuration ),
120- pipelineApplications :: configureAutoTuning
125+ TrainingPipeline :: setAutoTuningConfig
121126 );
122127 }
123128
124129 public Stream <NodePipelineInfoResult > configureSplit (String pipelineName , Map <String , Object > configuration ) {
125- return configure (
130+ return configurer . configureNodeClassificationTrainingPipeline (
126131 pipelineName ,
127- () -> pipelineConfigurationParser .parseNodePropertyPredictionSplitConfig (
128- configuration ),
129- pipelineApplications ::configureSplit
132+ () -> pipelineConfigurationParser .parseNodePropertyPredictionSplitConfig (configuration ),
133+ NodeClassificationTrainingPipeline ::setSplitConfig
130134 );
131135 }
132136
@@ -280,22 +284,6 @@ public Stream<MemoryEstimateResult> writeEstimate(
280284 return Stream .of (result );
281285 }
282286
283- private <CONFIGURATION > Stream <NodePipelineInfoResult > configure (
284- String pipelineNameAsString ,
285- Supplier <CONFIGURATION > configurationSupplier ,
286- BiFunction <PipelineName , CONFIGURATION , NodeClassificationTrainingPipeline > configurationAction
287- ) {
288- var pipelineName = PipelineName .parse (pipelineNameAsString );
289-
290- var configuration = configurationSupplier .get ();
291-
292- var pipeline = configurationAction .apply (pipelineName , configuration );
293-
294- var result = NodePipelineInfoResult .create (pipelineName , pipeline );
295-
296- return Stream .of (result );
297- }
298-
299287 private List <NodeFeatureStep > parseNodeProperties (Object nodeProperties ) {
300288 if (nodeProperties instanceof String ) return List .of (NodeFeatureStep .of ((String ) nodeProperties ));
301289
0 commit comments