@@ -38,7 +38,7 @@ def drug_response_experiment(
38
38
path_out : str = "results/" ,
39
39
overwrite : bool = False ,
40
40
path_data : str = "data" ,
41
- model_checkpoint_dir : str | None = None ,
41
+ model_checkpoint_dir : str = "TEMPORARY" ,
42
42
) -> None :
43
43
"""
44
44
Run the drug response prediction experiment. Save results to disc.
@@ -84,7 +84,7 @@ def drug_response_experiment(
84
84
:param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out)
85
85
:param overwrite: whether to overwrite existing results
86
86
:param path_data: path to the data directory, usually data/
87
- :param model_checkpoint_dir: directory to save model checkpoints. If None , a temporary directory is created.
87
+ :param model_checkpoint_dir: directory to save model checkpoints. If "TEMPORARY" , a temporary directory is created.
88
88
:raises ValueError: if no cv splits are found
89
89
"""
90
90
if baselines is None :
@@ -172,7 +172,6 @@ def drug_response_experiment(
172
172
if not os .path .isfile (
173
173
prediction_file
174
174
): # if this split has not been run yet (or for a single drug model, this drug_id)
175
-
176
175
tuning_inputs = {
177
176
"model" : model ,
178
177
"train_dataset" : train_dataset ,
@@ -311,7 +310,6 @@ def consolidate_single_drug_model_predictions(
311
310
"""
312
311
for model in models :
313
312
if model .get_model_name () in SINGLE_DRUG_MODEL_FACTORY :
314
-
315
313
model_instance = MODEL_FACTORY [model .get_model_name ()]()
316
314
model_path = os .path .join (results_path , model .get_model_name ())
317
315
out_path = os .path .join (out_path , model .get_model_name ())
@@ -324,7 +322,6 @@ def consolidate_single_drug_model_predictions(
324
322
os .makedirs (os .path .join (out_path , "robustness" ), exist_ok = True )
325
323
326
324
for split in range (n_cv_splits ):
327
-
328
325
# Collect predictions for drugs across all scenarios (main, cross_study, robustness, randomization)
329
326
predictions : Any = {
330
327
"main" : [],
@@ -594,7 +591,7 @@ def robustness_test(
594
591
path_out : str ,
595
592
split_index : int ,
596
593
response_transformation : Optional [TransformerMixin ] = None ,
597
- model_checkpoint_dir : str | None = None ,
594
+ model_checkpoint_dir : str = "TEMPORARY" ,
598
595
):
599
596
"""
600
597
Run robustness tests for the given model and dataset.
@@ -612,7 +609,7 @@ def robustness_test(
612
609
:param split_index: index of the split
613
610
:param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale
614
611
the target
615
- :param model_checkpoint_dir: directory to save model checkpoints
612
+ :param model_checkpoint_dir: directory to save model checkpoints, if "TEMPORARY": temporary directory is used
616
613
"""
617
614
robustness_test_path = os .path .join (path_out , "robustness" )
618
615
os .makedirs (robustness_test_path , exist_ok = True )
@@ -648,7 +645,7 @@ def robustness_train_predict(
648
645
hpam_set : dict ,
649
646
path_data : str ,
650
647
response_transformation : Optional [TransformerMixin ] = None ,
651
- model_checkpoint_dir : str | None = None ,
648
+ model_checkpoint_dir : str = "TEMPORARY" ,
652
649
) -> None :
653
650
"""
654
651
Train and predict for the robustness test.
@@ -662,7 +659,7 @@ def robustness_train_predict(
662
659
:param hpam_set: hyperparameters to use
663
660
:param path_data: path to the data directory, e.g., data/
664
661
:param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale
665
- :param model_checkpoint_dir: directory to save model checkpoints
662
+ :param model_checkpoint_dir: directory to save model checkpoints. If "TEMPORARY", a temporary directory is created.
666
663
"""
667
664
train_dataset .shuffle (random_state = trial )
668
665
test_dataset .shuffle (random_state = trial )
@@ -693,7 +690,7 @@ def randomization_test(
693
690
split_index : int ,
694
691
randomization_type : str = "permutation" ,
695
692
response_transformation = Optional [TransformerMixin ],
696
- model_checkpoint_dir : str | None = None ,
693
+ model_checkpoint_dir : str = "TEMPORARY" ,
697
694
) -> None :
698
695
"""
699
696
Run randomization tests for the given model and dataset.
@@ -762,7 +759,7 @@ def randomize_train_predict(
762
759
test_dataset : DrugResponseDataset ,
763
760
early_stopping_dataset : Optional [DrugResponseDataset ],
764
761
response_transformation : Optional [TransformerMixin ],
765
- model_checkpoint_dir : str | None = None ,
762
+ model_checkpoint_dir : str = "TEMPORARY" ,
766
763
) -> None :
767
764
"""
768
765
Randomize the features for a given view and run the model.
@@ -859,7 +856,7 @@ def train_and_predict(
859
856
response_transformation : TransformerMixin | None = None ,
860
857
cl_features : FeatureDataset | None = None ,
861
858
drug_features : FeatureDataset | None = None ,
862
- model_checkpoint_dir : str | None = None ,
859
+ model_checkpoint_dir : str = "TEMPORARY" ,
863
860
) -> DrugResponseDataset :
864
861
"""
865
862
Train the model and predict the response for the prediction dataset.
@@ -873,7 +870,8 @@ def train_and_predict(
873
870
:param response_transformation: normalizer to use for the response data, e.g., StandardScaler
874
871
:param cl_features: cell line features
875
872
:param drug_features: drug features
876
- :param model_checkpoint_dir: directory for model checkpoints, if None, checkpoints are not saved. Default is None
873
+ :param model_checkpoint_dir: directory for model checkpoints. If "TEMPORARY", a temporary directory is used and discarded afterwards.
874
+ Default is "TEMPORARY"
877
875
:returns: prediction dataset with predictions
878
876
:raises ValueError: if train_dataset does not have a dataset_name
879
877
"""
@@ -924,7 +922,7 @@ def train_and_predict(
924
922
"output_earlystopping" : early_stopping_dataset ,
925
923
}
926
924
927
- if model_checkpoint_dir is None :
925
+ if model_checkpoint_dir == "TEMPORARY" :
928
926
with tempfile .TemporaryDirectory () as temp_dir :
929
927
print (f"Using temporary directory: { temp_dir } for model checkpoints" )
930
928
train_inputs ["model_checkpoint_dir" ] = temp_dir
@@ -963,7 +961,7 @@ def train_and_evaluate(
963
961
early_stopping_dataset : Optional [DrugResponseDataset ] = None ,
964
962
response_transformation : Optional [TransformerMixin ] = None ,
965
963
metric : str = "rmse" ,
966
- model_checkpoint_dir : str = "" ,
964
+ model_checkpoint_dir : str = "TEMPORARY" ,
967
965
) -> dict [str , float ]:
968
966
"""
969
967
Train and evaluate the model, i.e., call train_and_predict() and then evaluate().
@@ -1001,7 +999,7 @@ def hpam_tune(
1001
999
response_transformation : Optional [TransformerMixin ] = None ,
1002
1000
metric : str = "RMSE" ,
1003
1001
path_data : str = "data" ,
1004
- model_checkpoint_dir : str = "" ,
1002
+ model_checkpoint_dir : str = "TEMPORARY" ,
1005
1003
) -> dict :
1006
1004
"""
1007
1005
Tune the hyperparameters for the given model in an iterative manner.
@@ -1065,7 +1063,7 @@ def hpam_tune_raytune(
1065
1063
metric : str = "RMSE" ,
1066
1064
ray_path : str = "raytune" ,
1067
1065
path_data : str = "data" ,
1068
- model_checkpoint_dir : str = "" ,
1066
+ model_checkpoint_dir : str = "TEMPORARY" ,
1069
1067
) -> dict :
1070
1068
"""
1071
1069
Tune the hyperparameters for the given model using raytune.
0 commit comments