Skip to content

Commit f0cac3a

Browse files
committed
change to non None default for easier handling in nf pipeline
1 parent 65ab1f2 commit f0cac3a

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

drevalpy/experiment.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def drug_response_experiment(
3838
path_out: str = "results/",
3939
overwrite: bool = False,
4040
path_data: str = "data",
41-
model_checkpoint_dir: str | None = None,
41+
model_checkpoint_dir: str = "TEMPORARY",
4242
) -> None:
4343
"""
4444
Run the drug response prediction experiment. Save results to disc.
@@ -84,7 +84,7 @@ def drug_response_experiment(
8484
:param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out)
8585
:param overwrite: whether to overwrite existing results
8686
:param path_data: path to the data directory, usually data/
87-
:param model_checkpoint_dir: directory to save model checkpoints. If None, a temporary directory is created.
87+
:param model_checkpoint_dir: directory to save model checkpoints. If "TEMPORARY", a temporary directory is created.
8888
:raises ValueError: if no cv splits are found
8989
"""
9090
if baselines is None:
@@ -172,7 +172,6 @@ def drug_response_experiment(
172172
if not os.path.isfile(
173173
prediction_file
174174
): # if this split has not been run yet (or for a single drug model, this drug_id)
175-
176175
tuning_inputs = {
177176
"model": model,
178177
"train_dataset": train_dataset,
@@ -311,7 +310,6 @@ def consolidate_single_drug_model_predictions(
311310
"""
312311
for model in models:
313312
if model.get_model_name() in SINGLE_DRUG_MODEL_FACTORY:
314-
315313
model_instance = MODEL_FACTORY[model.get_model_name()]()
316314
model_path = os.path.join(results_path, model.get_model_name())
317315
out_path = os.path.join(out_path, model.get_model_name())
@@ -324,7 +322,6 @@ def consolidate_single_drug_model_predictions(
324322
os.makedirs(os.path.join(out_path, "robustness"), exist_ok=True)
325323

326324
for split in range(n_cv_splits):
327-
328325
# Collect predictions for drugs across all scenarios (main, cross_study, robustness, randomization)
329326
predictions: Any = {
330327
"main": [],
@@ -594,7 +591,7 @@ def robustness_test(
594591
path_out: str,
595592
split_index: int,
596593
response_transformation: Optional[TransformerMixin] = None,
597-
model_checkpoint_dir: str | None = None,
594+
model_checkpoint_dir: str = "TEMPORARY",
598595
):
599596
"""
600597
Run robustness tests for the given model and dataset.
@@ -612,7 +609,7 @@ def robustness_test(
612609
:param split_index: index of the split
613610
:param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale
614611
the target
615-
:param model_checkpoint_dir: directory to save model checkpoints
612+
:param model_checkpoint_dir: directory to save model checkpoints, if "TEMPORARY": temporary directory is used
616613
"""
617614
robustness_test_path = os.path.join(path_out, "robustness")
618615
os.makedirs(robustness_test_path, exist_ok=True)
@@ -648,7 +645,7 @@ def robustness_train_predict(
648645
hpam_set: dict,
649646
path_data: str,
650647
response_transformation: Optional[TransformerMixin] = None,
651-
model_checkpoint_dir: str | None = None,
648+
model_checkpoint_dir: str = "TEMPORARY",
652649
) -> None:
653650
"""
654651
Train and predict for the robustness test.
@@ -662,7 +659,7 @@ def robustness_train_predict(
662659
:param hpam_set: hyperparameters to use
663660
:param path_data: path to the data directory, e.g., data/
664661
:param response_transformation: sklearn.preprocessing scaler like StandardScaler or MinMaxScaler to use to scale
665-
:param model_checkpoint_dir: directory to save model checkpoints
662+
:param model_checkpoint_dir: directory to save model checkpoints. If "TEMPORARY", a temporary directory is created.
666663
"""
667664
train_dataset.shuffle(random_state=trial)
668665
test_dataset.shuffle(random_state=trial)
@@ -693,7 +690,7 @@ def randomization_test(
693690
split_index: int,
694691
randomization_type: str = "permutation",
695692
response_transformation=Optional[TransformerMixin],
696-
model_checkpoint_dir: str | None = None,
693+
model_checkpoint_dir: str = "TEMPORARY",
697694
) -> None:
698695
"""
699696
Run randomization tests for the given model and dataset.
@@ -762,7 +759,7 @@ def randomize_train_predict(
762759
test_dataset: DrugResponseDataset,
763760
early_stopping_dataset: Optional[DrugResponseDataset],
764761
response_transformation: Optional[TransformerMixin],
765-
model_checkpoint_dir: str | None = None,
762+
model_checkpoint_dir: str = "TEMPORARY",
766763
) -> None:
767764
"""
768765
Randomize the features for a given view and run the model.
@@ -859,7 +856,7 @@ def train_and_predict(
859856
response_transformation: TransformerMixin | None = None,
860857
cl_features: FeatureDataset | None = None,
861858
drug_features: FeatureDataset | None = None,
862-
model_checkpoint_dir: str | None = None,
859+
model_checkpoint_dir: str = "TEMPORARY",
863860
) -> DrugResponseDataset:
864861
"""
865862
Train the model and predict the response for the prediction dataset.
@@ -873,7 +870,8 @@ def train_and_predict(
873870
:param response_transformation: normalizer to use for the response data, e.g., StandardScaler
874871
:param cl_features: cell line features
875872
:param drug_features: drug features
876-
:param model_checkpoint_dir: directory for model checkpoints, if None, checkpoints are not saved. Default is None
873+
:param model_checkpoint_dir: directory for model checkpoints, if "TEMPORARY", checkpoints are not saved.
874+
Default is "TEMPORARY"
877875
:returns: prediction dataset with predictions
878876
:raises ValueError: if train_dataset does not have a dataset_name
879877
"""
@@ -924,7 +922,7 @@ def train_and_predict(
924922
"output_earlystopping": early_stopping_dataset,
925923
}
926924

927-
if model_checkpoint_dir is None:
925+
if model_checkpoint_dir == "TEMPORARY":
928926
with tempfile.TemporaryDirectory() as temp_dir:
929927
print(f"Using temporary directory: {temp_dir} for model checkpoints")
930928
train_inputs["model_checkpoint_dir"] = temp_dir
@@ -963,7 +961,7 @@ def train_and_evaluate(
963961
early_stopping_dataset: Optional[DrugResponseDataset] = None,
964962
response_transformation: Optional[TransformerMixin] = None,
965963
metric: str = "rmse",
966-
model_checkpoint_dir: str = "",
964+
model_checkpoint_dir: str = "TEMPORARY",
967965
) -> dict[str, float]:
968966
"""
969967
Train and evaluate the model, i.e., call train_and_predict() and then evaluate().
@@ -1001,7 +999,7 @@ def hpam_tune(
1001999
response_transformation: Optional[TransformerMixin] = None,
10021000
metric: str = "RMSE",
10031001
path_data: str = "data",
1004-
model_checkpoint_dir: str = "",
1002+
model_checkpoint_dir: str = "TEMPORARY",
10051003
) -> dict:
10061004
"""
10071005
Tune the hyperparameters for the given model in an iterative manner.
@@ -1065,7 +1063,7 @@ def hpam_tune_raytune(
10651063
metric: str = "RMSE",
10661064
ray_path: str = "raytune",
10671065
path_data: str = "data",
1068-
model_checkpoint_dir: str = "",
1066+
model_checkpoint_dir: str = "TEMPORARY",
10691067
) -> dict:
10701068
"""
10711069
Tune the hyperparameters for the given model using raytune.

drevalpy/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def get_parser() -> argparse.ArgumentParser:
177177
parser.add_argument(
178178
"--model_checkpoint_dir",
179179
type=str,
180-
default="None",
180+
default="TEMPORARY",
181181
help="Directory to save model checkpoints",
182182
)
183183

@@ -311,7 +311,7 @@ def main(args) -> None:
311311
run_id=args.run_id,
312312
overwrite=args.overwrite,
313313
path_data=args.path_data,
314-
model_checkpoint_dir=args.model_checkpoint_dir if args.model_checkpoint_dir != "None" else None,
314+
model_checkpoint_dir=args.model_checkpoint_dir,
315315
)
316316

317317

0 commit comments

Comments
 (0)