diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py index 8ec0fc57..b33c01ba 100644 --- a/drevalpy/visualization/utils.py +++ b/drevalpy/visualization/utils.py @@ -251,6 +251,8 @@ def generate_model_names(test_mode, model_name, pred_file): pred_setting = "randomize-" + "-".join(file_parts[1:-2]) elif pred_rand_rob == "robustness": pred_setting = "-".join(file_parts[:2]) + elif pred_rand_rob == "cross": + pred_setting = "cross-study-" + file_parts[2] else: raise ValueError(f"Unknown prediction setting: {pred_rand_rob}") split = "_".join(os.path.basename(pred_file).split(".")[0].split("_")[-2:]) diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py index 70cd4fc1..1a3d1ed3 100644 --- a/tests/test_run_suite.py +++ b/tests/test_run_suite.py @@ -23,7 +23,7 @@ "randomization_mode": ["SVRC"], "randomization_type": "permutation", "n_trials_robustness": 2, - "cross_study_datasets": [], + "cross_study_datasets": ["GDSC2"], "curve_curator": False, "overwrite": False, "optim_metric": "RMSE",