diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py index bb46ea8..0613e2e 100644 --- a/drevalpy/datasets/dataset.py +++ b/drevalpy/datasets/dataset.py @@ -274,7 +274,7 @@ def split_dataset( response = self.response if mode == "LPO": - cv_splits = leave_pair_out_cv( + cv_splits = _leave_pair_out_cv( n_cv_splits, response, cell_line_ids, @@ -287,7 +287,7 @@ def split_dataset( elif mode in ["LCO", "LDO"]: group = "cell_line" if mode == "LCO" else "drug" - cv_splits = leave_group_out_cv( + cv_splits = _leave_group_out_cv( group=group, n_cv_splits=n_cv_splits, response=response, @@ -303,7 +303,7 @@ def split_dataset( if split_validation and split_early_stopping: for split in cv_splits: - validation_es, early_stopping = split_early_stopping_data(split["validation"], test_mode=mode) + validation_es, early_stopping = _split_early_stopping_data(split["validation"], test_mode=mode) split["validation_es"] = validation_es split["early_stopping"] = early_stopping self.cv_splits = cv_splits @@ -444,7 +444,7 @@ def inverse_transform(self, response_transformation: TransformerMixin) -> None: self.predictions = response_transformation.inverse_transform(self.predictions.reshape(-1, 1)).squeeze() -def split_early_stopping_data( +def _split_early_stopping_data( validation_dataset: DrugResponseDataset, test_mode: str ) -> tuple[DrugResponseDataset, DrugResponseDataset]: """ @@ -468,7 +468,7 @@ def split_early_stopping_data( return validation_dataset, early_stopping_dataset -def leave_pair_out_cv( +def _leave_pair_out_cv( n_cv_splits: int, response: ArrayLike, cell_line_ids: ArrayLike, @@ -543,7 +543,7 @@ def leave_pair_out_cv( return cv_sets -def leave_group_out_cv( +def _leave_group_out_cv( group: str, n_cv_splits: int, response: ArrayLike,