From cc665822eafd69cbec3862fbe9b0608c2b8efbd8 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Tue, 8 Oct 2024 14:56:23 +0200
Subject: [PATCH 001/208] hpams
---
.../hyperparameters.yaml | 36 +++++++++++++------
1 file changed, 25 insertions(+), 11 deletions(-)
diff --git a/drevalpy/models/simple_neural_network/hyperparameters.yaml b/drevalpy/models/simple_neural_network/hyperparameters.yaml
index 883bdb7..1a48b09 100644
--- a/drevalpy/models/simple_neural_network/hyperparameters.yaml
+++ b/drevalpy/models/simple_neural_network/hyperparameters.yaml
@@ -3,27 +3,41 @@ SimpleNeuralNetwork:
dropout_prob:
- 0.2
- 0.3
+ - 0.4
units_per_layer:
- - 10
- 10
- 10
- - - 20
- - 10
- - 10
- n_features:
- - 1036
+ - - 32
+ - 16
+ - 8
+ - 4
+ - - 128
+ - 64
+ - 32
+ - - 64
+ - 64
+ - 32
MultiOmicsNeuralNetwork:
dropout_prob:
- 0.2
- 0.3
+ - 0.4
units_per_layer:
- - - 10
- - 10
- - 10
- - - 20
- - 10
- - 10
+ - - 16
+ - 8
+ - 4
+ - - 32
+ - 16
+ - 8
+ - 4
+ - - 128
+ - 64
+ - 32
+ - - 64
+ - 64
+ - 32
methylation_pca_components:
- 100
From ba5f2ea60f0992d14d9a6f40d8a8e97d96feb6d3 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Mon, 14 Oct 2024 09:49:08 +0200
Subject: [PATCH 002/208] black everything
---
drevalpy/models/utils.py | 10 ++++++----
tests/test_dataset.py | 4 +++-
tests/test_drp_model.py | 33 ++++++++++++---------------------
3 files changed, 21 insertions(+), 26 deletions(-)
diff --git a/drevalpy/models/utils.py b/drevalpy/models/utils.py
index 7b7e23f..107d3ce 100644
--- a/drevalpy/models/utils.py
+++ b/drevalpy/models/utils.py
@@ -57,7 +57,6 @@ def load_and_reduce_gene_features(
sep=",",
)
-
genes_in_list = set(gene_info["Symbol"])
genes_in_features = set(cl_features.meta_info[feature_type])
# Ensure that all genes from gene_list are in the dataset
@@ -65,10 +64,13 @@ def load_and_reduce_gene_features(
if missing_genes:
missing_genes_list = list(missing_genes)
if len(missing_genes_list) > 10:
- raise ValueError(f"The following genes are missing from the dataset {dataset_name} for {feature_type}: {', '.join(missing_genes_list[:10])}, ... ({len(missing_genes)} genes in total)")
+ raise ValueError(
+ f"The following genes are missing from the dataset {dataset_name} for {feature_type}: {', '.join(missing_genes_list[:10])}, ... ({len(missing_genes)} genes in total)"
+ )
else:
- raise ValueError(f"The following genes are missing from the dataset {dataset_name} for {feature_type}: {', '.join(missing_genes_list)}")
-
+ raise ValueError(
+ f"The following genes are missing from the dataset {dataset_name} for {feature_type}: {', '.join(missing_genes_list)}"
+ )
# Only proceed with genes that are available
gene_mask = np.array(
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index c682744..c55c099 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -354,7 +354,9 @@ def graph_dataset():
def test_feature_dataset_get_ids(sample_dataset):
- assert np.all(sample_dataset.get_ids() == ["drug1", "drug2", "drug3", "drug4", "drug5"])
+ assert np.all(
+ sample_dataset.get_ids() == ["drug1", "drug2", "drug3", "drug4", "drug5"]
+ )
def test_feature_dataset_get_view_names(sample_dataset):
diff --git a/tests/test_drp_model.py b/tests/test_drp_model.py
index fbf32b8..8bab0e4 100644
--- a/tests/test_drp_model.py
+++ b/tests/test_drp_model.py
@@ -52,30 +52,19 @@ def write_gene_list(temp_dir, gene_list):
if gene_list == "landmark_genes":
with open(temp_file, "w") as f:
f.write(
- 'Entrez ID,Symbol,Name,Gene Family,Type,RNA-Seq Correlation,RNA-Seq Correlation Self-Rank\n'
- '3638,INSIG1,insulin induced gene 1,,landmark,,\n'
- '2309,FOXO3,forkhead box O3,Forkhead boxes,landmark,,\n'
+ "Entrez ID,Symbol,Name,Gene Family,Type,RNA-Seq Correlation,RNA-Seq Correlation Self-Rank\n"
+ "3638,INSIG1,insulin induced gene 1,,landmark,,\n"
+ "2309,FOXO3,forkhead box O3,Forkhead boxes,landmark,,\n"
'672,BRCA1,"BRCA1, DNA repair associated","Ring finger proteins, Fanconi anemia complementation groups, Protein phosphatase 1 regulatory subunits, BRCA1 A complex, BRCA1 B complex, BRCA1 C complex",landmark,,\n'
- '57147,SCYL3,SCY1 like pseudokinase 3,SCY1 like pseudokinases,landmark,,'
+ "57147,SCYL3,SCY1 like pseudokinase 3,SCY1 like pseudokinases,landmark,,"
)
elif gene_list == "drug_target_genes_all_drugs":
with open(temp_file, "w") as f:
- f.write(
- "Symbol\n"
- "TSPAN6\n"
- "SCYL3\n"
- "BRCA1\n"
- )
+ f.write("Symbol\n" "TSPAN6\n" "SCYL3\n" "BRCA1\n")
elif gene_list == "gene_list_paccmann_network_prop":
with open(temp_file, "w") as f:
f.write(
- "Symbol\n"
- "HDAC1\n"
- "ALS2CR12\n"
- "BFAR\n"
- "ZCWPW1\n"
- "ZP1\n"
- "PDZD7"
+ "Symbol\n" "HDAC1\n" "ALS2CR12\n" "BFAR\n" "ZCWPW1\n" "ZP1\n" "PDZD7"
)
@@ -138,8 +127,9 @@ def test_load_and_reduce_gene_features(gene_list):
colnames.sort()
assert np.all(colnames == ["BRCA1", "SCYL3", "TSPAN6"])
elif gene_list == "gene_list_paccmann_network_prop":
- assert ("The following genes are missing from the dataset GDSC1_small"
- in str(valerr.value))
+ assert "The following genes are missing from the dataset GDSC1_small" in str(
+ valerr.value
+ )
def test_iterate_features():
@@ -308,8 +298,9 @@ def test_get_multiomics_feature_dataset(gene_list):
else:
assert np.all(dataset.meta_info[key] == feature_names)
elif gene_list == "gene_list_paccmann_network_prop":
- assert ("The following genes are missing from the dataset GDSC1_small"
- in str(valerr.value))
+ assert "The following genes are missing from the dataset GDSC1_small" in str(
+ valerr.value
+ )
def test_unique():
From e396e096ffee59bddc4099bdd19a5880c41c5e1f Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Mon, 14 Oct 2024 11:14:27 +0200
Subject: [PATCH 003/208] isort imports
---
create_report.py | 40 +++--
drevalpy/datasets/__init__.py | 2 +-
drevalpy/datasets/dataset.py | 118 +++++++++-----
drevalpy/datasets/gdsc1.py | 5 +-
drevalpy/datasets/gdsc2.py | 8 +-
drevalpy/datasets/toy.py | 4 +-
drevalpy/datasets/utils.py | 20 ++-
drevalpy/evaluation.py | 28 +++-
drevalpy/experiment.py | 146 +++++++++++++-----
drevalpy/models/DrugRegNet/DrugRegNetModel.py | 16 +-
drevalpy/models/MOLI/moli_model.py | 3 +-
drevalpy/models/SRMF/srmf.py | 21 ++-
drevalpy/models/__init__.py | 12 +-
.../baselines/multi_omics_random_forest.py | 7 +-
drevalpy/models/baselines/naive_pred.py | 47 ++++--
.../baselines/singledrug_random_forest.py | 2 +
drevalpy/models/baselines/sklearn_models.py | 30 ++--
drevalpy/models/drp_model.py | 33 ++--
.../multiomics_neural_network.py | 23 +--
.../simple_neural_network.py | 27 ++--
.../models/simple_neural_network/utils.py | 38 +++--
drevalpy/models/tCNNs/tcnns.py | 46 ++++--
drevalpy/models/utils.py | 58 ++++---
drevalpy/utils.py | 35 ++++-
drevalpy/visualization/corr_comp_scatter.py | 48 ++++--
.../visualization/critical_difference_plot.py | 57 +++++--
drevalpy/visualization/heatmap.py | 19 ++-
drevalpy/visualization/html_tables.py | 9 +-
.../visualization/regression_slider_plot.py | 19 ++-
drevalpy/visualization/utils.py | 83 ++++++----
drevalpy/visualization/vioheat.py | 15 +-
drevalpy/visualization/violin.py | 56 +++++--
run_suite.py | 6 +-
setup.cfg | 6 +-
setup.py | 2 +-
tests/__init__.py | 0
tests/conftest.py | 1 +
tests/individual_models/test_baselines.py | 52 +++++--
.../test_simple_neural_network.py | 5 +-
tests/individual_models/utils.py | 18 ++-
tests/test_available_data.py | 18 ++-
tests/test_dataset.py | 79 +++++++---
tests/test_drp_model.py | 56 ++++---
tests/test_evaluation.py | 16 +-
tests/test_run_suite.py | 14 +-
45 files changed, 937 insertions(+), 411 deletions(-)
delete mode 100644 tests/__init__.py
diff --git a/create_report.py b/create_report.py
index b8a526c..7ef0d78 100644
--- a/create_report.py
+++ b/create_report.py
@@ -2,24 +2,23 @@
Renders the evaluation results into an HTML report with various plots and tables.
"""
-import os
import argparse
+import os
from drevalpy.visualization import (
- HTMLTable,
+ CorrelationComparisonScatter,
CriticalDifferencePlot,
- Violin,
Heatmap,
- CorrelationComparisonScatter,
+ HTMLTable,
RegressionSliderPlot,
+ Violin,
)
-
from drevalpy.visualization.utils import (
+ create_html,
+ create_index_html,
parse_results,
prep_results,
write_results,
- create_index_html,
- create_html,
)
@@ -34,7 +33,9 @@ def create_output_directories(custom_id):
os.makedirs(f"results/{custom_id}/regression_plots", exist_ok=True)
os.makedirs(f"results/{custom_id}/corr_comp_scatter", exist_ok=True)
os.makedirs(f"results/{custom_id}/html_tables", exist_ok=True)
- os.makedirs(f"results/{custom_id}/critical_difference_plots", exist_ok=True)
+ os.makedirs(
+ f"results/{custom_id}/critical_difference_plots", exist_ok=True
+ )
def draw_setting_plots(
@@ -60,7 +61,9 @@ def draw_setting_plots(
)
# only draw figures for 'real' predictions comparing all models
- eval_results_preds = ev_res_subset[ev_res_subset["rand_setting"] == "predictions"]
+ eval_results_preds = ev_res_subset[
+ ev_res_subset["rand_setting"] == "predictions"
+ ]
# PIPELINE: DRAW_CRITICAL_DIFFERENCE
cd_plot = CriticalDifferencePlot(
@@ -95,7 +98,8 @@ def draw_setting_plots(
whole_name=False,
)
out_plot.draw_and_save(
- out_prefix=f"results/{custom_id}/{out_dir}/", out_suffix=out_suffix
+ out_prefix=f"results/{custom_id}/{out_dir}/",
+ out_suffix=out_suffix,
)
# per group plots
@@ -117,7 +121,9 @@ def draw_setting_plots(
return eval_results_preds["algorithm"].unique()
-def draw_per_grouping_setting_plots(grouping, ev_res_per_group, lpo_lco_ldo, custom_id):
+def draw_per_grouping_setting_plots(
+ grouping, ev_res_per_group, lpo_lco_ldo, custom_id
+):
"""
Draw plots for a specific grouping (drug or cell line) for a specific setting (LPO, LCO, LDO)
:param grouping: drug or cell_line
@@ -181,12 +187,16 @@ def draw_algorithm_plots(
if plt_type == "violinplot":
out_dir = "violin_plots"
out_plot = Violin(
- df=eval_results_algorithm, normalized_metrics=False, whole_name=True
+ df=eval_results_algorithm,
+ normalized_metrics=False,
+ whole_name=True,
)
else:
out_dir = "heatmaps"
out_plot = Heatmap(
- df=eval_results_algorithm, normalized_metrics=False, whole_name=True
+ df=eval_results_algorithm,
+ normalized_metrics=False,
+ whole_name=True,
)
out_plot.draw_and_save(
out_prefix=f"results/{custom_id}/{out_dir}/",
@@ -368,5 +378,7 @@ def draw_per_grouping_algorithm_plots(
)
# PIPELINE: WRITE_INDEX
create_index_html(
- custom_id=run_id, test_modes=settings, prefix_results=f"results/{run_id}"
+ custom_id=run_id,
+ test_modes=settings,
+ prefix_results=f"results/{run_id}",
)
diff --git a/drevalpy/datasets/__init__.py b/drevalpy/datasets/__init__.py
index 8cbe6c8..3610c2f 100644
--- a/drevalpy/datasets/__init__.py
+++ b/drevalpy/datasets/__init__.py
@@ -3,9 +3,9 @@
"""
__all__ = ["GDSC1", "GDSC2", "CCLE", "Toy", "RESPONSE_DATASET_FACTORY"]
+from .ccle import CCLE
from .gdsc1 import GDSC1
from .gdsc2 import GDSC2
-from .ccle import CCLE
from .toy import Toy
RESPONSE_DATASET_FACTORY = {
diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py
index 8cce599..5ace3c1 100644
--- a/drevalpy/datasets/dataset.py
+++ b/drevalpy/datasets/dataset.py
@@ -1,29 +1,30 @@
"""
-Defines the different dataset classes: DrugResponseDataset for response values and
-FeatureDataset for feature values. They both inherit from the abstract class Dataset.
-The DrugResponseDataset class is used to store drug response values per cell line and drug.
-The FeatureDataset class is used to store feature values per cell line or drug. The FeatureDataset
-class can also store meta information for the feature views.
-The DrugResponseDataset class can be split into training, validation and test sets for
-cross-validation.
+Defines the different dataset classes:
+DrugResponseDataset for response values and FeatureDataset for feature values.
+They both inherit from the abstract class Dataset.
+The DrugResponseDataset class is used
+to store drug response values per cell line and drug.
+The FeatureDataset class is used to store
+feature values per cell line or drug.
+The FeatureDataset class can also store meta information
+for the feature views. The DrugResponseDataset class
+can be split into training, validation and test sets for cross-validation.
The FeatureDataset class can be used to randomize feature vectors.
"""
-from abc import ABC, abstractmethod
-import os
import copy
-from typing import Dict, List, Optional, Tuple, Union, Any, Callable
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import networkx as nx
import numpy as np
-from numpy.typing import ArrayLike
import pandas as pd
+from numpy.typing import ArrayLike
from sklearn.base import TransformerMixin
from sklearn.model_selection import GroupKFold, train_test_split
-import networkx as nx
-from .utils import (
- randomize_graph,
- permute_features,
-)
+from .utils import permute_features, randomize_graph
class Dataset(ABC):
@@ -164,11 +165,15 @@ def add_rows(self, other: "DrugResponseDataset") -> None:
:param other: other dataset
"""
self.response = np.concatenate([self.response, other.response])
- self.cell_line_ids = np.concatenate([self.cell_line_ids, other.cell_line_ids])
+ self.cell_line_ids = np.concatenate(
+ [self.cell_line_ids, other.cell_line_ids]
+ )
self.drug_ids = np.concatenate([self.drug_ids, other.drug_ids])
if self.predictions is not None and other.predictions is not None:
- self.predictions = np.concatenate([self.predictions, other.predictions])
+ self.predictions = np.concatenate(
+ [self.predictions, other.predictions]
+ )
def remove_nan_responses(self) -> None:
"""
@@ -208,7 +213,9 @@ def remove_drugs(self, drugs_to_remove: Union[str, list]) -> None:
self.cell_line_ids = self.cell_line_ids[mask]
self.response = self.response[mask]
- def remove_cell_lines(self, cell_lines_to_remove: Union[str, list]) -> None:
+ def remove_cell_lines(
+ self, cell_lines_to_remove: Union[str, list]
+ ) -> None:
"""
Removes cell lines from the dataset.
:param cell_lines_to_remove: name of cell line or list of names of multiple cell lines to
@@ -218,7 +225,8 @@ def remove_cell_lines(self, cell_lines_to_remove: Union[str, list]) -> None:
cell_lines_to_remove = [cell_lines_to_remove]
mask = [
- cell_line not in cell_lines_to_remove for cell_line in self.cell_line_ids
+ cell_line not in cell_lines_to_remove
+ for cell_line in self.cell_line_ids
]
self.drug_ids = self.drug_ids[mask]
self.cell_line_ids = self.cell_line_ids[mask]
@@ -247,7 +255,9 @@ def reduce_to(
self.remove_drugs(list(set(self.drug_ids) - set(drug_ids)))
if cell_line_ids is not None:
- self.remove_cell_lines(list(set(self.cell_line_ids) - set(cell_line_ids)))
+ self.remove_cell_lines(
+ list(set(self.cell_line_ids) - set(cell_line_ids))
+ )
def split_dataset(
self,
@@ -355,13 +365,17 @@ def load_splits(self, path: str) -> None:
train_splits = [file for file in files if "train" in file]
test_splits = [file for file in files if "test" in file]
- validation_es_splits = [file for file in files if "validation_es" in file]
+ validation_es_splits = [
+ file for file in files if "validation_es" in file
+ ]
validation_splits = [
file
for file in files
if "validation" in file and file not in validation_es_splits
]
- early_stopping_splits = [file for file in files if "early_stopping" in file]
+ early_stopping_splits = [
+ file for file in files if "early_stopping" in file
+ ]
for ds in [
train_splits,
@@ -413,7 +427,11 @@ def __hash__(self):
tuple(self.cell_line_ids),
tuple(self.drug_ids),
tuple(self.response),
- tuple(self.predictions) if self.predictions is not None else None,
+ (
+ tuple(self.predictions)
+ if self.predictions is not None
+ else None
+ ),
)
)
@@ -449,7 +467,9 @@ def fit_transform(self, response_transformation: TransformerMixin) -> None:
response_transformation.fit(self.response.reshape(-1, 1))
self.transform(response_transformation)
- def inverse_transform(self, response_transformation: TransformerMixin) -> None:
+ def inverse_transform(
+ self, response_transformation: TransformerMixin
+ ) -> None:
"""
Inverse transform the response data and prediction data of the dataset.
:param response_transformation: e.g., StandardScaler, MinMaxScaler, RobustScaler
@@ -480,7 +500,7 @@ def split_early_stopping_data(
split_early_stopping=False,
random_state=42,
)
- # take the first fold of a 4 cv as the split i.e., 3/4 for validation and 1/4 for early stopping
+ # take the first fold of a 4 cv as the split i.e. 3/4 for validation and 1/4 for early stopping
validation_dataset = cv_v[0]["train"]
early_stopping_dataset = cv_v[0]["test"]
return validation_dataset, early_stopping_dataset
@@ -497,7 +517,7 @@ def leave_pair_out_cv(
dataset_name: Optional[str] = None,
) -> List[dict]:
"""
- Leave pair out cross validation. Splits data into n_cv_splits number of cross validation splits.
+ Leave pair out cross validation. Splits data into n_cv_splits number of cross validation splits
:param n_cv_splits: number of cross validation splits
:param response: response (e.g. ic50 values)
:param cell_line_ids: cell line IDs
@@ -618,7 +638,8 @@ def leave_group_out_cv(
),
}
if split_validation:
- # split training set into training and validation set. The validation set also does
+ # split training set into training and validation set.
+ # The validation set also does
# contain unqiue cell lines/drugs
unique_train_groups = np.unique(group_ids[train_indices])
train_groups, validation_groups = train_test_split(
@@ -628,7 +649,9 @@ def leave_group_out_cv(
random_state=random_state,
)
train_indices = np.where(np.isin(group_ids, train_groups))[0]
- validation_indices = np.where(np.isin(group_ids, validation_groups))[0]
+ validation_indices = np.where(
+ np.isin(group_ids, validation_groups)
+ )[0]
cv_fold["train"] = DrugResponseDataset(
cell_line_ids=cell_line_ids[train_indices],
drug_ids=drug_ids[train_indices],
@@ -661,7 +684,7 @@ def __init__(
:param features: dictionary of features,
key: drug ID/cell line ID, value: Dict of feature views,
key: feature name, value: feature vector
- :param meta_info: additional information for the views, e.g., gene names for gene expression
+ :param meta_info: additional information for the views, e.g. gene names for gene expression
"""
super().__init__()
self.features = features
@@ -717,7 +740,8 @@ def randomize_features(
if randomization_type == "permutation":
# Permute the specified views for each entity (= cell line or drug)
- # E.g. each cell line gets the feature vector/graph/image... of another cell line.
+ # E.g. each cell line gets the feature vector/graph/image...
+ # of another cell line.
# Drawn without replacement.
self.features = permute_features(
features=self.features,
@@ -727,7 +751,8 @@ def randomize_features(
)
elif randomization_type == "invariant":
- # Invariant randomization: Randomize the specified views for each entity in a way that
+ # Invariant randomization:
+ # Randomize the specified views for each entity in a way that
# a key characteristic of the feature is preserved.
# For vectors this is the mean and standard deviation the feature view,
# for networks the degree distribution.
@@ -742,7 +767,9 @@ def randomize_features(
elif isinstance(
self.features[identifier][view], nx.classes.graph.Graph
):
- new_features = randomize_graph(self.features[identifier][view])
+ new_features = randomize_graph(
+ self.features[identifier][view]
+ )
else:
raise ValueError(
@@ -767,16 +794,21 @@ def get_feature_matrix(
self, view: str, identifiers: ArrayLike, stack: bool = True
) -> Union[np.ndarray, List]:
"""
- Returns the feature matrix for the given view. The feature view must be a vector or matrix.
+ Returns the feature matrix for the given view.
+ The feature view must be a vector or matrix.
:param view: view name
:param identifiers: list of identifiers (cell lines oder drugs)
:param stack: if True, stacks the feature vectors to a matrix.
If False, returns a list of features.
:return: feature matrix
"""
- assert len(identifiers) > 0, "get_feature_matrix: No identifiers given."
+ assert (
+ len(identifiers) > 0
+ ), "get_feature_matrix: No identifiers given."
- assert view in self.view_names, f"View '{view}' not in in the FeatureDataset."
+ assert (
+ view in self.view_names
+ ), f"View '{view}' not in in the FeatureDataset."
missing_identifiers = {
id_ for id_ in identifiers if id_ not in self.identifiers
}
@@ -786,12 +818,14 @@ def get_feature_matrix(
)
assert all(
- len(self.features[id_][view]) == len(self.features[identifiers[0]][view])
+ len(self.features[id_][view])
+ == len(self.features[identifiers[0]][view])
for id_ in identifiers
), f"Feature vectors of view {view} have different lengths."
assert all(
- isinstance(self.features[id_][view], np.ndarray) for id_ in identifiers
+ isinstance(self.features[id_][view], np.ndarray)
+ for id_ in identifiers
), f"get_feature_matrix only works for vectors or matrices. {view} is not a numpy array."
out = [self.features[id_][view] for id_ in identifiers]
return np.stack(out, axis=0) if stack else out
@@ -813,7 +847,9 @@ def add_features(self, other: "FeatureDataset") -> None:
if other.meta_info is not None:
self.add_meta_info(other)
- common_identifiers = set(self.identifiers).intersection(other.identifiers)
+ common_identifiers = set(self.identifiers).intersection(
+ other.identifiers
+ )
new_features = {}
for id_ in common_identifiers:
new_features[id_] = {
@@ -898,4 +934,6 @@ def apply(self, function: Callable, view: str):
Applies a function to the features of a view.
"""
for identifier in self.features:
- self.features[identifier][view] = function(self.features[identifier][view])
+ self.features[identifier][view] = function(
+ self.features[identifier][view]
+ )
diff --git a/drevalpy/datasets/gdsc1.py b/drevalpy/datasets/gdsc1.py
index f58025b..7cee71a 100644
--- a/drevalpy/datasets/gdsc1.py
+++ b/drevalpy/datasets/gdsc1.py
@@ -3,6 +3,7 @@
"""
import os
+
import pandas as pd
from .dataset import DrugResponseDataset
@@ -28,7 +29,9 @@ def __init__(
download_dataset(dataset_name, path_data, redownload=True)
response_data = pd.read_csv(path)
- response_data["DRUG_NAME"] = response_data["DRUG_NAME"].str.replace(",", "")
+ response_data["DRUG_NAME"] = response_data["DRUG_NAME"].str.replace(
+ ",", ""
+ )
super().__init__(
response=response_data["LN_IC50"].values,
diff --git a/drevalpy/datasets/gdsc2.py b/drevalpy/datasets/gdsc2.py
index 4c1a6da..8bcdee0 100644
--- a/drevalpy/datasets/gdsc2.py
+++ b/drevalpy/datasets/gdsc2.py
@@ -10,5 +10,9 @@ class GDSC2(GDSC1):
GDSC2 dataset.
"""
- def __init__(self, path_data: str = "data", file_name: str = "response_GDSC2.csv"):
- super().__init__(path_data=path_data, file_name=file_name, dataset_name="GDSC2")
+ def __init__(
+ self, path_data: str = "data", file_name: str = "response_GDSC2.csv"
+ ):
+ super().__init__(
+ path_data=path_data, file_name=file_name, dataset_name="GDSC2"
+ )
diff --git a/drevalpy/datasets/toy.py b/drevalpy/datasets/toy.py
index ac393e1..6aac677 100644
--- a/drevalpy/datasets/toy.py
+++ b/drevalpy/datasets/toy.py
@@ -27,7 +27,9 @@ def __init__(
with open(path, "rb") as f:
response_data = pickle.load(f)
- response_data.drug_ids = [di.replace(",", "") for di in response_data.drug_ids]
+ response_data.drug_ids = [
+ di.replace(",", "") for di in response_data.drug_ids
+ ]
super().__init__(
response=response_data.response,
cell_line_ids=response_data.cell_line_ids,
diff --git a/drevalpy/datasets/utils.py b/drevalpy/datasets/utils.py
index 7273b82..e06bef8 100644
--- a/drevalpy/datasets/utils.py
+++ b/drevalpy/datasets/utils.py
@@ -2,13 +2,14 @@
Utility functions for datasets.
"""
-import zipfile
import os
+import zipfile
from typing import List
-import requests
+
+import networkx as nx
import numpy as np
+import requests
from numpy.typing import ArrayLike
-import networkx as nx
def download_dataset(
@@ -47,7 +48,9 @@ def download_dataset(
os.makedirs(data_path, exist_ok=True)
# Download each file
- name_to_url = {file["key"]: file["links"]["self"] for file in data["files"]}
+ name_to_url = {
+ file["key"]: file["links"]["self"] for file in data["files"]
+ }
file_url = name_to_url[file_name]
# Download the file
print(f"Downloading {dataset} from {file_url}...")
@@ -103,7 +106,10 @@ def randomize_graph(original_graph: nx.Graph) -> nx.Graph:
def permute_features(
- features: dict, identifiers: ArrayLike, views_to_permute: List, all_views: List
+ features: dict,
+ identifiers: ArrayLike,
+ views_to_permute: List,
+ all_views: List,
) -> dict:
"""
Permute the specified views for each entity (= cell line or drug)
@@ -125,5 +131,7 @@ def permute_features(
)
for view in all_views
}
- for entity, other_entity in zip(identifiers, np.random.permutation(identifiers))
+ for entity, other_entity in zip(
+ identifiers, np.random.permutation(identifiers)
+ )
}
diff --git a/drevalpy/evaluation.py b/drevalpy/evaluation.py
index 5ee576e..5c032ec 100644
--- a/drevalpy/evaluation.py
+++ b/drevalpy/evaluation.py
@@ -3,12 +3,13 @@
"""
import warnings
-from typing import Union, List, Tuple
-from sklearn import metrics
-import pandas as pd
+from typing import List, Tuple, Union
+
import numpy as np
-from scipy.stats import pearsonr, spearmanr, kendalltau
+import pandas as pd
import pingouin as pg
+from scipy.stats import kendalltau, pearsonr, spearmanr
+from sklearn import metrics
from .datasets.dataset import DrugResponseDataset
@@ -50,7 +51,9 @@ def partial_correlation(
}
)
- if (len(df["cell_line_ids"].unique()) < 2) or (len(df["drug_ids"].unique()) < 2):
+ if (len(df["cell_line_ids"].unique()) < 2) or (
+ len(df["drug_ids"].unique()) < 2
+ ):
# if we don't have more than one cell line or drug in the data, partial correlation is
# meaningless
global warning_shown
@@ -81,7 +84,9 @@ def partial_correlation(
df["cell_line_ids"] = pd.factorize(df["cell_line_ids"])[0]
df["drug_ids"] = pd.factorize(df["drug_ids"])[0]
# One-hot encode the categorical covariates
- df_encoded = pd.get_dummies(df, columns=["cell_line_ids", "drug_ids"], dtype=int)
+ df_encoded = pd.get_dummies(
+ df, columns=["cell_line_ids", "drug_ids"], dtype=int
+ )
if df.shape[0] < 3:
r, p = np.nan, np.nan
@@ -93,7 +98,8 @@ def partial_correlation(
covar=[
col
for col in df_encoded.columns
- if col.startswith("cell_line_ids") or col.startswith("drug_ids")
+ if col.startswith("cell_line_ids")
+ or col.startswith("drug_ids")
],
method=method,
)
@@ -195,7 +201,13 @@ def kendall(y_pred: np.ndarray, y_true: np.ndarray) -> float:
"Partial_Correlation": partial_correlation,
}
MINIMIZATION_METRICS = ["MSE", "RMSE", "MAE"]
-MAXIMIZATION_METRICS = ["R^2", "Pearson", "Spearman", "Kendall", "Partial_Correlation"]
+MAXIMIZATION_METRICS = [
+ "R^2",
+ "Pearson",
+ "Spearman",
+ "Kendall",
+ "Partial_Correlation",
+]
def get_mode(metric: str):
diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py
index f144e0f..87bcf50 100644
--- a/drevalpy/experiment.py
+++ b/drevalpy/experiment.py
@@ -2,22 +2,23 @@
Main module for running the drug response prediction experiment.
"""
-import os
-from typing import Dict, List, Optional, Tuple, Type
-import warnings
import json
+import os
import shutil
+import warnings
+from typing import Dict, List, Optional, Tuple, Type
+
import numpy as np
+import pandas as pd
import ray
import torch
from ray import tune
from sklearn.base import TransformerMixin
-import pandas as pd
from .datasets.dataset import DrugResponseDataset, FeatureDataset
from .evaluation import evaluate, get_mode
-from .models.drp_model import DRPModel, SingleDrugModel
from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY
+from .models.drp_model import DRPModel, SingleDrugModel
def drug_response_experiment(
@@ -162,7 +163,9 @@ def drug_response_experiment(
validation_dataset,
early_stopping_dataset,
test_dataset,
- ) = get_datasets_from_cv_split(split, model_class, model_name, drug_id)
+ ) = get_datasets_from_cv_split(
+ split, model_class, model_name, drug_id
+ )
model = model_class()
@@ -213,7 +216,9 @@ def drug_response_experiment(
train_dataset=train_dataset,
prediction_dataset=test_dataset,
early_stopping_dataset=(
- early_stopping_dataset if model.early_stopping else None
+ early_stopping_dataset
+ if model.early_stopping
+ else None
),
response_transformation=response_transformation,
)
@@ -230,13 +235,17 @@ def drug_response_experiment(
train_dataset=train_dataset,
path_data=path_data,
early_stopping_dataset=(
- early_stopping_dataset if model.early_stopping else None
+ early_stopping_dataset
+ if model.early_stopping
+ else None
),
response_transformation=response_transformation,
path_out=parent_dir,
split_index=split_index,
single_drug_id=(
- drug_id if model_name in SINGLE_DRUG_MODEL_FACTORY else None
+ drug_id
+ if model_name in SINGLE_DRUG_MODEL_FACTORY
+ else None
),
)
@@ -265,7 +274,9 @@ def drug_response_experiment(
train_dataset=train_dataset,
test_dataset=test_dataset,
early_stopping_dataset=(
- early_stopping_dataset if model.early_stopping else None
+ early_stopping_dataset
+ if model.early_stopping
+ else None
),
path_out=parent_dir,
split_index=split_index,
@@ -282,7 +293,9 @@ def drug_response_experiment(
train_dataset=train_dataset,
test_dataset=test_dataset,
early_stopping_dataset=(
- early_stopping_dataset if model.early_stopping else None
+ early_stopping_dataset
+ if model.early_stopping
+ else None
),
path_out=parent_dir,
split_index=split_index,
@@ -320,11 +333,17 @@ def consolidate_single_drug_model_predictions(
out_path = os.path.join(out_path, str(model.model_name))
os.makedirs(os.path.join(out_path, "predictions"), exist_ok=True)
if cross_study_datasets:
- os.makedirs(os.path.join(out_path, "cross_study"), exist_ok=True)
+ os.makedirs(
+ os.path.join(out_path, "cross_study"), exist_ok=True
+ )
if randomization_mode:
- os.makedirs(os.path.join(out_path, "randomization"), exist_ok=True)
+ os.makedirs(
+ os.path.join(out_path, "randomization"), exist_ok=True
+ )
if n_trials_robustness:
- os.makedirs(os.path.join(out_path, "robustness"), exist_ok=True)
+ os.makedirs(
+ os.path.join(out_path, "robustness"), exist_ok=True
+ )
for split in range(n_cv_splits):
@@ -389,12 +408,15 @@ def consolidate_single_drug_model_predictions(
if trial not in predictions["robustness"]:
predictions["robustness"][trial] = []
predictions["robustness"][trial].append(
- pd.read_csv(os.path.join(robustness_path, f), index_col=0)
+ pd.read_csv(
+ os.path.join(robustness_path, f), index_col=0
+ )
)
# Randomization predictions
randomization_test_views = get_randomization_test_views(
- model=model_instance, randomization_mode=randomization_mode
+ model=model_instance,
+ randomization_mode=randomization_mode,
)
for view in randomization_test_views:
randomization_path = os.path.join(
@@ -405,14 +427,17 @@ def consolidate_single_drug_model_predictions(
predictions["randomization"][view] = []
predictions["randomization"][view].append(
pd.read_csv(
- os.path.join(randomization_path, f), index_col=0
+ os.path.join(randomization_path, f),
+ index_col=0,
)
)
# Save the consolidated predictions
pd.concat(predictions["main"], axis=0).to_csv(
os.path.join(
- out_path, "predictions", f"predictions_split_{split}.csv"
+ out_path,
+ "predictions",
+ f"predictions_split_{split}.csv",
)
)
@@ -427,7 +452,9 @@ def consolidate_single_drug_model_predictions(
)
)
- for trial, trial_predictions in predictions["robustness"].items():
+ for trial, trial_predictions in predictions[
+ "robustness"
+ ].items():
pd.concat(trial_predictions, axis=0).to_csv(
os.path.join(
out_path,
@@ -436,7 +463,9 @@ def consolidate_single_drug_model_predictions(
)
)
- for view, view_predictions in predictions["randomization"].items():
+ for view, view_predictions in predictions[
+ "randomization"
+ ].items():
pd.concat(view_predictions, axis=0).to_csv(
os.path.join(
out_path,
@@ -500,12 +529,16 @@ def cross_study_prediction(
warnings.warn(e)
return
- cell_lines_to_keep = cl_features.identifiers if cl_features is not None else None
+ cell_lines_to_keep = (
+ cl_features.identifiers if cl_features is not None else None
+ )
if single_drug_id is not None:
drugs_to_keep = [single_drug_id]
else:
- drugs_to_keep = drug_features.identifiers if drug_features is not None else None
+ drugs_to_keep = (
+ drug_features.identifiers if drug_features is not None else None
+ )
print(
f"Reducing cross study dataset ... feature data available for "
@@ -522,10 +555,13 @@ def cross_study_prediction(
if test_mode == "LPO":
train_pairs = {
f"{cl}_{drug}"
- for cl, drug in zip(train_dataset.cell_line_ids, train_dataset.drug_ids)
+ for cl, drug in zip(
+ train_dataset.cell_line_ids, train_dataset.drug_ids
+ )
}
dataset_pairs = [
- f"{cl}_{drug}" for cl, drug in zip(dataset.cell_line_ids, dataset.drug_ids)
+ f"{cl}_{drug}"
+ for cl, drug in zip(dataset.cell_line_ids, dataset.drug_ids)
]
dataset.remove_rows(
@@ -535,7 +571,9 @@ def cross_study_prediction(
train_cell_lines = set(train_dataset.cell_line_ids)
dataset.reduce_to(
cell_line_ids=[
- cl for cl in dataset.cell_line_ids if cl not in train_cell_lines
+ cl
+ for cl in dataset.cell_line_ids
+ if cl not in train_cell_lines
],
drug_ids=None,
)
@@ -543,10 +581,14 @@ def cross_study_prediction(
train_drugs = set(train_dataset.drug_ids)
dataset.reduce_to(
cell_line_ids=None,
- drug_ids=[drug for drug in dataset.drug_ids if drug not in train_drugs],
+ drug_ids=[
+ drug for drug in dataset.drug_ids if drug not in train_drugs
+ ],
)
else:
- raise ValueError(f"Invalid test mode: {test_mode}. Choose from LPO, LCO, LDO")
+ raise ValueError(
+ f"Invalid test mode: {test_mode}. Choose from LPO, LCO, LDO"
+ )
if len(dataset) > 0:
dataset.shuffle(random_state=42)
dataset.predictions = model.predict(
@@ -744,7 +786,9 @@ def randomization_test(
randomization_test_file
): # if this splits test has not been run yet
for view in views:
- print(f"Randomizing view {view} for randomization test {test_name} ...")
+ print(
+ f"Randomizing view {view} for randomization test {test_name} ..."
+ )
randomize_train_predict(
view=view,
test_name=test_name,
@@ -801,10 +845,14 @@ def randomize_train_predict(
)
return
cl_features_rand = cl_features.copy() if cl_features is not None else None
- drug_features_rand = drug_features.copy() if drug_features is not None else None
+ drug_features_rand = (
+ drug_features.copy() if drug_features is not None else None
+ )
if view in cl_features.get_view_names():
- cl_features_rand.randomize_features(view, randomization_type=randomization_type)
+ cl_features_rand.randomize_features(
+ view, randomization_type=randomization_type
+ )
elif view in drug_features.get_view_names():
drug_features_rand.randomize_features(
view, randomization_type=randomization_type
@@ -883,24 +931,33 @@ def train_and_predict(
data_path=path_data, dataset_name=train_dataset.dataset_name
)
- cell_lines_to_keep = cl_features.identifiers if cl_features is not None else None
- drugs_to_keep = drug_features.identifiers if drug_features is not None else None
+ cell_lines_to_keep = (
+ cl_features.identifiers if cl_features is not None else None
+ )
+ drugs_to_keep = (
+ drug_features.identifiers if drug_features is not None else None
+ )
# making sure there are no missing features:
len_train_before = len(train_dataset)
len_pred_before = len(prediction_dataset)
- train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep)
+ train_dataset.reduce_to(
+ cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep
+ )
prediction_dataset.reduce_to(
cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep
)
- print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}")
+ print(
+ f"Reduced training dataset from {len_train_before} to {len(train_dataset)}"
+ )
print(
f"Reduced prediction dataset from {len_pred_before} to {len(prediction_dataset)}"
)
if early_stopping_dataset is not None:
early_stopping_dataset.reduce_to(
- cell_line_ids=cl_features.identifiers, drug_ids=drug_features.identifiers
+ cell_line_ids=cl_features.identifiers,
+ drug_ids=drug_features.identifiers,
)
if response_transformation:
@@ -1017,7 +1074,9 @@ def hpam_tune(
best_hyperparameters = hyperparameter
if best_hyperparameters is None:
- warnings.warn("all hpams lead to NaN respone. using last hpam combination.")
+ warnings.warn(
+ "all hpams lead to NaN response. using last hpam combination."
+ )
best_hyperparameters = hyperparameter
return best_hyperparameters
@@ -1094,7 +1153,9 @@ def make_model_list(
for model in models:
if issubclass(model, SingleDrugModel):
for drug in unique_drugs:
- model_list[f"{model.model_name}.{drug}"] = str(model.model_name)
+ model_list[f"{model.model_name}.{drug}"] = str(
+ model.model_name
+ )
else:
model_list[str(model.model_name)] = str(model.model_name)
return model_list
@@ -1148,13 +1209,20 @@ def get_datasets_from_cv_split(split, model_class, model_name, drug_id):
return train_cp, val_cp, es_cp, test_cp
return train_cp, val_cp, None, test_cp
- return train_dataset, validation_dataset, early_stopping_dataset, test_dataset
+ return (
+ train_dataset,
+ validation_dataset,
+ early_stopping_dataset,
+ test_dataset,
+ )
def generate_data_saving_path(model_name, drug_id, result_path, suffix):
is_single_drug_model = model_name in SINGLE_DRUG_MODEL_FACTORY
if is_single_drug_model:
- model_path = os.path.join(result_path, model_name, "drugs", drug_id, suffix)
+ model_path = os.path.join(
+ result_path, model_name, "drugs", drug_id, suffix
+ )
else:
model_path = os.path.join(result_path, model_name, suffix)
os.makedirs(model_path, exist_ok=True)
diff --git a/drevalpy/models/DrugRegNet/DrugRegNetModel.py b/drevalpy/models/DrugRegNet/DrugRegNetModel.py
index a2ed696..f5ac439 100644
--- a/drevalpy/models/DrugRegNet/DrugRegNetModel.py
+++ b/drevalpy/models/DrugRegNet/DrugRegNetModel.py
@@ -1,7 +1,7 @@
-import pandas as pd
-from sklearn.linear_model import Lasso
import numpy as np
+import pandas as pd
from scipy import stats
+from sklearn.linear_model import Lasso
class DrugRegNetModel:
@@ -57,12 +57,16 @@ def train_model(self):
def calculate_pvalues(model, x, y):
params = np.append(model.intercept_, model.coef_)
predictions = model.predict(x)
- newX = pd.DataFrame({"Constant": np.ones(len(x))}, index=x.index).join(x)
+ newX = pd.DataFrame({"Constant": np.ones(len(x))}, index=x.index).join(
+ x
+ )
MSE = (sum((y - predictions) ** 2)) / (len(newX) - len(newX.columns))
var_b = MSE * (np.linalg.inv(np.dot(newX.T, newX)).diagonal())
sd_b = np.sqrt(var_b)
ts_b = params / sd_b
- p_values = [2 * (1 - stats.t.cdf(np.abs(i), (len(newX) - 1))) for i in ts_b]
+ p_values = [
+ 2 * (1 - stats.t.cdf(np.abs(i), (len(newX) - 1))) for i in ts_b
+ ]
p_values = np.round(p_values, 3)
p_values = p_values[1:]
return p_values
@@ -81,7 +85,9 @@ def export_results(self, path):
drug_specific_network = drug_specific_network.str.replace(
"(", ""
).str.replace(")", "")
- drug_specific_network = drug_specific_network.str.replace("'", "")
+ drug_specific_network = drug_specific_network.str.replace(
+ "'", ""
+ )
drug_specific_network = drug_specific_network.str.split(
", ", expand=True
)
diff --git a/drevalpy/models/MOLI/moli_model.py b/drevalpy/models/MOLI/moli_model.py
index e69d50a..28dcbec 100644
--- a/drevalpy/models/MOLI/moli_model.py
+++ b/drevalpy/models/MOLI/moli_model.py
@@ -47,7 +47,8 @@ def __init__(self, input_sizes, output_sizes, dropout_rates):
input_sizes[2], output_sizes[2], dropout_rates[2]
)
self.classifier = MOLIClassifier(
- output_sizes[0] + output_sizes[1] + output_sizes[2], dropout_rates[3]
+ output_sizes[0] + output_sizes[1] + output_sizes[2],
+ dropout_rates[3],
)
def forward_with_features(self, expression, mutation, cna):
diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py
index 774972a..a89e760 100644
--- a/drevalpy/models/SRMF/srmf.py
+++ b/drevalpy/models/SRMF/srmf.py
@@ -1,13 +1,12 @@
+from typing import Dict
+
import numpy as np
import pandas as pd
-from scipy.spatial.distance import jaccard
from numpy.typing import ArrayLike
-from typing import Dict
+from scipy.spatial.distance import jaccard
-from drevalpy.models.drp_model import DRPModel
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
-
-
+from drevalpy.models.drp_model import DRPModel
from drevalpy.models.utils import (
load_and_reduce_gene_features,
load_drug_fingerprint_features,
@@ -151,8 +150,12 @@ def CMF(self, W, intMat, drugMat, cellMat):
WR = W * intMat
for t in range(self.max_iter):
- U = self.alg_update(U0, V0, W, WR, drugMat, self.lambda_l, self.lambda_d)
- V = self.alg_update(V0, U, W.T, WR.T, cellMat, self.lambda_l, self.lambda_c)
+ U = self.alg_update(
+ U0, V0, W, WR, drugMat, self.lambda_l, self.lambda_d
+ )
+ V = self.alg_update(
+ V0, U, W.T, WR.T, cellMat, self.lambda_l, self.lambda_c
+ )
curr_loss = self.compute_loss(U, V, W, intMat, drugMat, cellMat)
if curr_loss < bestloss:
@@ -219,5 +222,7 @@ def load_cell_line_features(
dataset_name=dataset_name,
)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_fingerprint_features(data_path, dataset_name)
diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py
index 46f3ea5..dac5d56 100644
--- a/drevalpy/models/__init__.py
+++ b/drevalpy/models/__init__.py
@@ -20,21 +20,21 @@
"MODEL_FACTORY",
]
+from .baselines.multi_omics_random_forest import MultiOmicsRandomForest
from .baselines.naive_pred import (
- NaivePredictor,
- NaiveDrugMeanPredictor,
NaiveCellLineMeanPredictor,
+ NaiveDrugMeanPredictor,
+ NaivePredictor,
)
+from .baselines.singledrug_random_forest import SingleDrugRandomForest
from .baselines.sklearn_models import (
ElasticNetModel,
+ GradientBoosting,
RandomForest,
SVMRegressor,
- GradientBoosting,
)
-from .baselines.multi_omics_random_forest import MultiOmicsRandomForest
-from .simple_neural_network.simple_neural_network import SimpleNeuralNetwork
from .simple_neural_network.multiomics_neural_network import MultiOmicsNeuralNetwork
-from .baselines.singledrug_random_forest import SingleDrugRandomForest
+from .simple_neural_network.simple_neural_network import SimpleNeuralNetwork
from .SRMF.srmf import SRMF
SINGLE_DRUG_MODEL_FACTORY = {
diff --git a/drevalpy/models/baselines/multi_omics_random_forest.py b/drevalpy/models/baselines/multi_omics_random_forest.py
index d5bdb2b..86f8595 100644
--- a/drevalpy/models/baselines/multi_omics_random_forest.py
+++ b/drevalpy/models/baselines/multi_omics_random_forest.py
@@ -6,11 +6,10 @@
from numpy.typing import ArrayLike
from sklearn.decomposition import PCA
-from drevalpy.datasets.dataset import FeatureDataset, DrugResponseDataset
+from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
+
+from ..utils import get_multiomics_feature_dataset
from .sklearn_models import RandomForest
-from ..utils import (
- get_multiomics_feature_dataset,
-)
class MultiOmicsRandomForest(RandomForest):
diff --git a/drevalpy/models/baselines/naive_pred.py b/drevalpy/models/baselines/naive_pred.py
index 426bcb4..a86cb6c 100644
--- a/drevalpy/models/baselines/naive_pred.py
+++ b/drevalpy/models/baselines/naive_pred.py
@@ -6,10 +6,11 @@
"""
from typing import Dict
+
import numpy as np
from numpy.typing import ArrayLike
-from drevalpy.datasets.dataset import FeatureDataset, DrugResponseDataset
+from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
from drevalpy.models.drp_model import DRPModel
from drevalpy.models.utils import load_cl_ids_from_csv, load_drug_ids_from_csv, unique
@@ -64,17 +65,23 @@ def predict(
return np.full(cell_line_ids.shape[0], self.dataset_mean)
def save(self, path):
- raise NotImplementedError("Naive predictor does not support saving yet ...")
+ raise NotImplementedError(
+ "Naive predictor does not support saving yet ..."
+ )
def load(self, path):
- raise NotImplementedError("Naive predictor does not support loading yet ...")
+ raise NotImplementedError(
+ "Naive predictor does not support loading yet ..."
+ )
def load_cell_line_features(
self, data_path: str, dataset_name: str
) -> FeatureDataset:
return load_cl_ids_from_csv(data_path, dataset_name)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_ids_from_csv(data_path, dataset_name)
@@ -154,17 +161,23 @@ def predict_drug(self, drug_id: str):
return self.dataset_mean
def save(self, path):
- raise NotImplementedError("Naive predictor does not support saving yet ...")
+ raise NotImplementedError(
+ "Naive predictor does not support saving yet ..."
+ )
def load(self, path):
- raise NotImplementedError("Naive predictor does not support loading yet ...")
+ raise NotImplementedError(
+ "Naive predictor does not support loading yet ..."
+ )
def load_cell_line_features(
self, data_path: str, dataset_name: str
) -> FeatureDataset:
return load_cl_ids_from_csv(data_path, dataset_name)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_ids_from_csv(data_path, dataset_name)
@@ -209,10 +222,14 @@ def train(
for cell_line_response, cell_line_feature in zip(
unique(output.cell_line_ids), unique(cell_line_ids)
):
- responses_cl = output.response[cell_line_feature == output.cell_line_ids]
+ responses_cl = output.response[
+ cell_line_feature == output.cell_line_ids
+ ]
if len(responses_cl) > 0:
# prevent nan response
- self.cell_line_means[cell_line_response] = np.mean(responses_cl)
+ self.cell_line_means[cell_line_response] = np.mean(
+ responses_cl
+ )
def predict(
self,
@@ -244,15 +261,21 @@ def predict_cl(self, cl_id: str):
return self.dataset_mean
def save(self, path):
- raise NotImplementedError("Naive predictor does not support saving yet ...")
+ raise NotImplementedError(
+ "Naive predictor does not support saving yet ..."
+ )
def load(self, path):
- raise NotImplementedError("Naive predictor does not support loading yet ...")
+ raise NotImplementedError(
+ "Naive predictor does not support loading yet ..."
+ )
def load_cell_line_features(
self, data_path: str, dataset_name: str
) -> FeatureDataset:
return load_cl_ids_from_csv(data_path, dataset_name)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_ids_from_csv(data_path, dataset_name)
diff --git a/drevalpy/models/baselines/singledrug_random_forest.py b/drevalpy/models/baselines/singledrug_random_forest.py
index 8230085..8357997 100644
--- a/drevalpy/models/baselines/singledrug_random_forest.py
+++ b/drevalpy/models/baselines/singledrug_random_forest.py
@@ -4,10 +4,12 @@
"""
from typing import Optional
+
import numpy as np
from numpy.typing import ArrayLike
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
+
from ..drp_model import SingleDrugModel
from .sklearn_models import RandomForest
diff --git a/drevalpy/models/baselines/sklearn_models.py b/drevalpy/models/baselines/sklearn_models.py
index aee443d..7aba6f6 100644
--- a/drevalpy/models/baselines/sklearn_models.py
+++ b/drevalpy/models/baselines/sklearn_models.py
@@ -3,20 +3,17 @@
"""
from typing import Dict
+
import numpy as np
+from numpy.typing import ArrayLike
+from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import ElasticNet, Ridge
-from sklearn.ensemble import RandomForestRegressor
-from sklearn.ensemble import GradientBoostingRegressor
-
from sklearn.svm import SVR
-from numpy.typing import ArrayLike
-from drevalpy.datasets.dataset import FeatureDataset, DrugResponseDataset
+from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
from drevalpy.models.drp_model import DRPModel
-from ..utils import (
- load_and_reduce_gene_features,
- load_drug_fingerprint_features,
-)
+
+from ..utils import load_and_reduce_gene_features, load_drug_fingerprint_features
class SklearnModel(DRPModel):
@@ -91,10 +88,14 @@ def predict(
return self.model.predict(x)
def save(self, path):
- raise NotImplementedError("ElasticNetModel does not support saving yet ...")
+ raise NotImplementedError(
+ "ElasticNetModel does not support saving yet ..."
+ )
def load(self, path):
- raise NotImplementedError("ElasticNetModel does not support loading yet ...")
+ raise NotImplementedError(
+ "ElasticNetModel does not support loading yet ..."
+ )
def load_cell_line_features(
self, data_path: str, dataset_name: str
@@ -112,7 +113,9 @@ def load_cell_line_features(
dataset_name=dataset_name,
)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_fingerprint_features(data_path, dataset_name)
@@ -132,7 +135,8 @@ def build_model(self, hyperparameters: Dict):
self.model = Ridge(alpha=hyperparameters["alpha"])
else:
self.model = ElasticNet(
- alpha=hyperparameters["alpha"], l1_ratio=hyperparameters["l1_ratio"]
+ alpha=hyperparameters["alpha"],
+ l1_ratio=hyperparameters["l1_ratio"],
)
diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py
index 369aae8..8a53eca 100644
--- a/drevalpy/models/drp_model.py
+++ b/drevalpy/models/drp_model.py
@@ -5,14 +5,15 @@
into a global model by applying a separate model for each drug.
"""
-from abc import ABC, abstractmethod
import inspect
import os
-from typing import Any, Dict, Optional, Type, List
import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Type
+
import numpy as np
-from numpy.typing import ArrayLike
import yaml
+from numpy.typing import ArrayLike
from sklearn.model_selection import ParameterGrid
from ..datasets.dataset import DrugResponseDataset, FeatureDataset
@@ -137,7 +138,9 @@ def load_cell_line_features(
"""
@abstractmethod
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
"""
:return: FeatureDataset
"""
@@ -212,9 +215,13 @@ def get_feature_matrices(
if drug_input is not None:
for drug_view in self.drug_views:
if drug_view not in drug_input.get_view_names():
- raise ValueError(f"Drug input does not contain view {drug_view}")
- drug_feature_matrices[drug_view] = drug_input.get_feature_matrix(
- view=drug_view, identifiers=drug_ids
+ raise ValueError(
+ f"Drug input does not contain view {drug_view}"
+ )
+ drug_feature_matrices[drug_view] = (
+ drug_input.get_feature_matrix(
+ view=drug_view, identifiers=drug_ids
+ )
)
return {**cell_line_feature_matrices, **drug_feature_matrices}
@@ -228,7 +235,9 @@ class SingleDrugModel(DRPModel, ABC):
early_stopping = False
drug_views = []
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return None
@@ -270,7 +279,9 @@ def load_cell_line_features(
data_path=data_path, dataset_name=dataset_name
)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return None
def train(
@@ -300,7 +311,9 @@ def train(
output_drug.mask(output_mask)
output_earlystopping_drug = None
if output_earlystopping is not None:
- output_earlystopping_mask = output_earlystopping.drug_ids == drug
+ output_earlystopping_mask = (
+ output_earlystopping.drug_ids == drug
+ )
output_earlystopping_drug = output_earlystopping.copy()
output_earlystopping_drug.mask(output_earlystopping_mask)
diff --git a/drevalpy/models/simple_neural_network/multiomics_neural_network.py b/drevalpy/models/simple_neural_network/multiomics_neural_network.py
index 3e4cb09..645f76a 100644
--- a/drevalpy/models/simple_neural_network/multiomics_neural_network.py
+++ b/drevalpy/models/simple_neural_network/multiomics_neural_network.py
@@ -3,18 +3,17 @@
"""
import warnings
-from typing import Optional, Dict
+from typing import Dict, Optional
+
import numpy as np
from numpy.typing import ArrayLike
from sklearn.decomposition import PCA
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
-from .utils import FeedForwardNetwork
+
from ..drp_model import DRPModel
-from ..utils import (
- load_drug_fingerprint_features,
- get_multiomics_feature_dataset,
-)
+from ..utils import get_multiomics_feature_dataset, load_drug_fingerprint_features
+from .utils import FeedForwardNetwork
class MultiOmicsNeuralNetwork(DRPModel):
@@ -49,7 +48,9 @@ def build_model(self, hyperparameters: Dict):
n_units_per_layer=hyperparameters["units_per_layer"],
dropout_prob=hyperparameters["dropout_prob"],
)
- self.pca = PCA(n_components=hyperparameters["methylation_pca_components"])
+ self.pca = PCA(
+ n_components=hyperparameters["methylation_pca_components"]
+ )
def train(
self,
@@ -73,7 +74,9 @@ def train(
axis=0,
)
- self.pca.n_components = min(self.pca.n_components, len(unique_methylation))
+ self.pca.n_components = min(
+ self.pca.n_components, len(unique_methylation)
+ )
self.pca = self.pca.fit(unique_methylation)
with warnings.catch_warnings():
@@ -162,5 +165,7 @@ def load_cell_line_features(
data_path=data_path, dataset_name=dataset_name
)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_fingerprint_features(data_path, dataset_name)
diff --git a/drevalpy/models/simple_neural_network/simple_neural_network.py b/drevalpy/models/simple_neural_network/simple_neural_network.py
index 222480d..9d8988a 100644
--- a/drevalpy/models/simple_neural_network/simple_neural_network.py
+++ b/drevalpy/models/simple_neural_network/simple_neural_network.py
@@ -3,18 +3,17 @@
"""
import warnings
-from typing import Optional, Dict
+from typing import Dict, Optional
+
import numpy as np
from numpy.typing import ArrayLike
+from sklearn.preprocessing import StandardScaler
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
-from ..utils import (
- load_drug_fingerprint_features,
- load_and_reduce_gene_features,
-)
-from .utils import FeedForwardNetwork
+
from ..drp_model import DRPModel
-from sklearn.preprocessing import StandardScaler
+from ..utils import load_and_reduce_gene_features, load_drug_fingerprint_features
+from .utils import FeedForwardNetwork
class SimpleNeuralNetwork(DRPModel):
@@ -64,10 +63,12 @@ def train(
if "gene_expression" in self.cell_line_views:
cell_line_input = cell_line_input.copy()
cell_line_input.apply(function=np.arcsinh, view="gene_expression")
- self.gene_expression_scaler = cell_line_input.fit_transform_features(
- train_ids=np.unique(output.cell_line_ids),
- transformer=self.gene_expression_scaler,
- view="gene_expression",
+ self.gene_expression_scaler = (
+ cell_line_input.fit_transform_features(
+ train_ids=np.unique(output.cell_line_ids),
+ transformer=self.gene_expression_scaler,
+ view="gene_expression",
+ )
)
with warnings.catch_warnings():
@@ -141,6 +142,8 @@ def load_cell_line_features(
dataset_name=dataset_name,
)
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
return load_drug_fingerprint_features(data_path, dataset_name)
diff --git a/drevalpy/models/simple_neural_network/utils.py b/drevalpy/models/simple_neural_network/utils.py
index 5c59f38..7741aac 100644
--- a/drevalpy/models/simple_neural_network/utils.py
+++ b/drevalpy/models/simple_neural_network/utils.py
@@ -4,13 +4,14 @@
import os
import random
-from typing import Optional, List
+from typing import List, Optional
+
import numpy as np
+import pytorch_lightning as pl
import torch
+from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar
from torch import nn
from torch.utils.data import DataLoader, Dataset
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
@@ -61,7 +62,9 @@ def __getitem__(self, idx):
if cell_line_features is None:
cell_line_features = feature_mat
else:
- cell_line_features = np.concatenate((cell_line_features, feature_mat))
+ cell_line_features = np.concatenate(
+ (cell_line_features, feature_mat)
+ )
for d_view in self.drug_views:
if drug_features is None:
drug_features = self.drug_input.features[drug_id][d_view]
@@ -135,7 +138,10 @@ def fit(
:return:
"""
if trainer_params is None:
- trainer_params = {"progress_bar_refresh_rate": 300, "max_epochs": 70}
+ trainer_params = {
+ "progress_bar_refresh_rate": 300,
+ "max_epochs": 70,
+ }
train_dataset = RegressionDataset(
output=output_train,
@@ -199,7 +205,11 @@ def fit(
# Initialize the Lightning trainer
trainer = pl.Trainer(
- callbacks=[early_stop_callback, self.checkpoint_callback, progress_bar],
+ callbacks=[
+ early_stop_callback,
+ self.checkpoint_callback,
+ progress_bar,
+ ],
default_root_dir=os.path.join(
os.getcwd(), "model_checkpoints/lightning_logs/" + name
),
@@ -244,15 +254,23 @@ def initialize_model(self, x):
self.fully_connected_layers.append(
nn.Linear(n_features, self.n_units_per_layer[0])
)
- self.batch_norm_layers.append(nn.BatchNorm1d(self.n_units_per_layer[0]))
+ self.batch_norm_layers.append(
+ nn.BatchNorm1d(self.n_units_per_layer[0])
+ )
for i in range(1, len(self.n_units_per_layer)):
self.fully_connected_layers.append(
- nn.Linear(self.n_units_per_layer[i - 1], self.n_units_per_layer[i])
+ nn.Linear(
+ self.n_units_per_layer[i - 1], self.n_units_per_layer[i]
+ )
+ )
+ self.batch_norm_layers.append(
+ nn.BatchNorm1d(self.n_units_per_layer[i])
)
- self.batch_norm_layers.append(nn.BatchNorm1d(self.n_units_per_layer[i]))
- self.fully_connected_layers.append(nn.Linear(self.n_units_per_layer[-1], 1))
+ self.fully_connected_layers.append(
+ nn.Linear(self.n_units_per_layer[-1], 1)
+ )
if self.dropout_prob is not None:
self.dropout_layer = nn.Dropout(p=self.dropout_prob)
self.model_initialized = True
diff --git a/drevalpy/models/tCNNs/tcnns.py b/drevalpy/models/tCNNs/tcnns.py
index b37abd9..8154c6a 100644
--- a/drevalpy/models/tCNNs/tcnns.py
+++ b/drevalpy/models/tCNNs/tcnns.py
@@ -1,14 +1,15 @@
+import warnings
+from typing import Optional
+
import numpy as np
-import torch
-from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
+import torch
import torch.nn.functional as F
from torch import nn
-from typing import Optional
+from torch.utils.data import DataLoader, Dataset
-from drevalpy.models.drp_model import DRPModel
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
-import warnings
+from drevalpy.models.drp_model import DRPModel
class tCNNs(DRPModel):
@@ -79,7 +80,9 @@ def train(
output_earlystopping.response,
)
early_stopping_loader = DataLoader(
- dataset=dataset_earlystopping, batch_size=batch_size, shuffle=False
+ dataset=dataset_earlystopping,
+ batch_size=batch_size,
+ shuffle=False,
)
trainer = pl.Trainer(
@@ -87,7 +90,9 @@ def train(
progress_bar_refresh_rate=0,
gpus=1 if torch.cuda.is_available() else 0,
)
- trainer.fit(self.model, train_loader, val_dataloaders=early_stopping_loader)
+ trainer.fit(
+ self.model, train_loader, val_dataloaders=early_stopping_loader
+ )
# TODO define trainer properly and early stopping via callback
@@ -122,7 +127,9 @@ def load_cell_line_features(
pass
- def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset:
+ def load_drug_features(
+ self, data_path: str, dataset_name: str
+ ) -> FeatureDataset:
pass
@@ -142,7 +149,9 @@ def __len__(self):
def __getitem__(self, idx):
return (
- torch.tensor(self.gene_mutation_sequence[idx], dtype=torch.float32),
+ torch.tensor(
+ self.gene_mutation_sequence[idx], dtype=torch.float32
+ ),
torch.tensor(self.smiles_sequence[idx], dtype=torch.float32),
torch.tensor(self.response[idx], dtype=torch.float32),
)
@@ -151,7 +160,12 @@ def __getitem__(self, idx):
# Custom DataModule class
class DrugCellDataModule(pl.LightningDataModule):
def __init__(
- self, batch_size, drug_smile_dict, drug_cell_dict, cell_mut_dict, label_list
+ self,
+ batch_size,
+ drug_smile_dict,
+ drug_cell_dict,
+ cell_mut_dict,
+ label_list,
):
super().__init__()
self.batch_size = batch_size
@@ -173,7 +187,9 @@ def setup(self, stage=None):
)
value_shape = self.drug_cell_dict["IC50"].shape
- value = np.zeros((value_shape[0], value_shape[1], len(self.label_list)))
+ value = np.zeros(
+ (value_shape[0], value_shape[1], len(self.label_list))
+ )
for i in range(len(self.label_list)):
value[:, :, i] = self.drug_cell_dict[self.label_list[i]]
drug_smile = self.drug_smile_dict["canonical"]
@@ -185,10 +201,14 @@ def setup(self, stage=None):
self.valid_dataset = DrugCellDataset(
drug_smile, cell_mut, value, valid_positions
)
- self.test_dataset = DrugCellDataset(drug_smile, cell_mut, value, test_positions)
+ self.test_dataset = DrugCellDataset(
+ drug_smile, cell_mut, value, test_positions
+ )
def train_dataloader(self):
- return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
+ return DataLoader(
+ self.train_dataset, batch_size=self.batch_size, shuffle=True
+ )
def val_dataloader(self):
return DataLoader(self.valid_dataset, batch_size=self.batch_size)
diff --git a/drevalpy/models/utils.py b/drevalpy/models/utils.py
index 107d3ce..3bab38b 100644
--- a/drevalpy/models/utils.py
+++ b/drevalpy/models/utils.py
@@ -3,13 +3,13 @@
"""
import os.path
+import pickle
import warnings
from typing import Optional
-import pickle
-import pandas as pd
+
import numpy as np
-from numpy.typing import ArrayLike
-from sklearn.base import TransformerMixin
+import pandas as pd
+
from drevalpy.datasets.dataset import FeatureDataset
@@ -23,14 +23,21 @@ def load_cl_ids_from_csv(path: str, dataset_name: str) -> FeatureDataset:
if dataset_name == "Toy_Data":
return load_toy_features(path, dataset_name, "cell_line")
- cl_names = pd.read_csv(f"{path}/{dataset_name}/cell_line_names.csv", index_col=0)
+ cl_names = pd.read_csv(
+ f"{path}/{dataset_name}/cell_line_names.csv", index_col=0
+ )
return FeatureDataset(
- features={cl: {"cell_line_id": np.array([cl])} for cl in cl_names.index}
+ features={
+ cl: {"cell_line_id": np.array([cl])} for cl in cl_names.index
+ }
)
def load_and_reduce_gene_features(
- feature_type: str, gene_list: Optional[str], data_path: str, dataset_name: str
+ feature_type: str,
+ gene_list: Optional[str],
+ data_path: str,
+ dataset_name: str,
) -> FeatureDataset:
"""
Load and reduce gene features.
@@ -44,7 +51,9 @@ def load_and_reduce_gene_features(
cl_features = load_toy_features(data_path, dataset_name, "cell_line")
dataset_name = "GDSC1"
else:
- ge = pd.read_csv(f"{data_path}/{dataset_name}/{feature_type}.csv", index_col=0)
+ ge = pd.read_csv(
+ f"{data_path}/{dataset_name}/{feature_type}.csv", index_col=0
+ )
cl_features = FeatureDataset(
features=iterate_features(df=ge, feature_type=feature_type),
meta_info={feature_type: ge.columns.values},
@@ -76,11 +85,13 @@ def load_and_reduce_gene_features(
gene_mask = np.array(
[gene in genes_in_list for gene in cl_features.meta_info[feature_type]]
)
- cl_features.meta_info[feature_type] = cl_features.meta_info[feature_type][gene_mask]
+ cl_features.meta_info[feature_type] = cl_features.meta_info[feature_type][
+ gene_mask
+ ]
for cell_line in cl_features.features.keys():
- cl_features.features[cell_line][feature_type] = cl_features.features[cell_line][
- feature_type
- ][gene_mask]
+ cl_features.features[cell_line][feature_type] = cl_features.features[
+ cell_line
+ ][feature_type][gene_mask]
return cl_features
@@ -104,7 +115,9 @@ def iterate_features(df: pd.DataFrame, feature_type: str):
return features
-def load_drug_ids_from_csv(data_path: str, dataset_name: str) -> FeatureDataset:
+def load_drug_ids_from_csv(
+ data_path: str, dataset_name: str
+) -> FeatureDataset:
"""
Load drug ids from csv file.
:param data_path:
@@ -113,13 +126,19 @@ def load_drug_ids_from_csv(data_path: str, dataset_name: str) -> FeatureDataset:
"""
if dataset_name == "Toy_Data":
return load_toy_features(data_path, dataset_name, "drug")
- drug_names = pd.read_csv(f"{data_path}/{dataset_name}/drug_names.csv", index_col=0)
+ drug_names = pd.read_csv(
+ f"{data_path}/{dataset_name}/drug_names.csv", index_col=0
+ )
return FeatureDataset(
- features={drug: {"drug_id": np.array([drug])} for drug in drug_names.index}
+ features={
+ drug: {"drug_id": np.array([drug])} for drug in drug_names.index
+ }
)
-def load_drug_fingerprint_features(data_path: str, dataset_name: str) -> FeatureDataset:
+def load_drug_fingerprint_features(
+ data_path: str, dataset_name: str
+) -> FeatureDataset:
"""
Load drug features from fingerprints.
:param data_path:
@@ -129,7 +148,8 @@ def load_drug_fingerprint_features(data_path: str, dataset_name: str) -> Feature
if dataset_name == "Toy_Data":
return load_toy_features(data_path, dataset_name, "drug")
fingerprints = pd.read_csv(
- f"{data_path}/{dataset_name}/drug_fingerprints/drug_name_to_demorgan_128_map.csv",
+ f"{data_path}/{dataset_name}/drug_fingerprints/"
+ "drug_name_to_demorgan_128_map.csv",
index_col=0,
).T
return FeatureDataset(
@@ -141,7 +161,9 @@ def load_drug_fingerprint_features(data_path: str, dataset_name: str) -> Feature
def get_multiomics_feature_dataset(
- data_path: str, dataset_name: str, gene_list: str = "drug_target_genes_all_drugs"
+ data_path: str,
+ dataset_name: str,
+ gene_list: str = "drug_target_genes_all_drugs",
) -> FeatureDataset:
"""
Get multiomics feature dataset.
diff --git a/drevalpy/utils.py b/drevalpy/utils.py
index 13cb584..725ed91 100644
--- a/drevalpy/utils.py
+++ b/drevalpy/utils.py
@@ -7,10 +7,10 @@
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler
-from drevalpy.models import MODEL_FACTORY
from drevalpy.datasets import RESPONSE_DATASET_FACTORY
from drevalpy.evaluation import AVAILABLE_METRICS
from drevalpy.experiment import drug_response_experiment
+from drevalpy.models import MODEL_FACTORY
def get_parser():
@@ -22,15 +22,23 @@ def get_parser():
description="Run the drug response prediction model test suite."
)
parser.add_argument(
- "--run_id", type=str, default="my_run", help="identifier to save the results"
+ "--run_id",
+ type=str,
+ default="my_run",
+ help="identifier to save the results",
)
parser.add_argument(
- "--path_data", type=str, default="data", help="Path to the data directory"
+ "--path_data",
+ type=str,
+ default="data",
+ help="Path to the data directory",
)
parser.add_argument(
- "--models", nargs="+", help="model to evaluate or list of models to compare"
+ "--models",
+ nargs="+",
+ help="model to evaluate or list of models to compare",
)
parser.add_argument(
"--baselines",
@@ -104,14 +112,21 @@ def get_parser():
)
parser.add_argument(
- "--path_out", type=str, default="results/", help="Path to the output directory"
+ "--path_out",
+ type=str,
+ default="results/",
+ help="Path to the output directory",
)
parser.add_argument(
"--curve_curator",
action="store_true",
default=False,
- help="Whether to run " "CurveCurator " "to sort out " "non-reactive " "curves",
+ help="Whether to run "
+ "CurveCurator "
+ "to sort out "
+ "non-reactive "
+ "curves",
)
parser.add_argument(
"--overwrite",
@@ -240,7 +255,9 @@ def main(args):
if args.randomization_mode[0] == "None":
args.randomization_mode = None
- response_transformation = get_response_transformation(args.response_transformation)
+ response_transformation = get_response_transformation(
+ args.response_transformation
+ )
for test_mode in args.test_mode:
drug_response_experiment(
@@ -263,7 +280,9 @@ def main(args):
)
-def load_data(dataset_name: str, cross_study_datasets: List, path_data: str = "data"):
+def load_data(
+ dataset_name: str, cross_study_datasets: List, path_data: str = "data"
+):
"""
Load the response data and cross-study datasets.
:param dataset_name:
diff --git a/drevalpy/visualization/corr_comp_scatter.py b/drevalpy/visualization/corr_comp_scatter.py
index f828b9e..3b8638d 100644
--- a/drevalpy/visualization/corr_comp_scatter.py
+++ b/drevalpy/visualization/corr_comp_scatter.py
@@ -1,13 +1,14 @@
from typing import TextIO
-import pandas as pd
+
import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
import scipy
-from scipy import stats
from plotly.subplots import make_subplots
-import plotly.graph_objects as go
+from scipy import stats
-from drevalpy.visualization.outplot import OutPlot
from drevalpy.models import SINGLE_DRUG_MODEL_FACTORY
+from drevalpy.visualization.outplot import OutPlot
class CorrelationComparisonScatter(OutPlot):
@@ -55,7 +56,9 @@ def __init__(
self.color_by = color_by
self.metric = metric
- self.df["setting"] = self.df["model"].str.split("_").str[0:3].str.join("_")
+ self.df["setting"] = (
+ self.df["model"].str.split("_").str[0:3].str.join("_")
+ )
self.models = self.df["setting"].unique()
self.fig_overall = make_subplots(
@@ -152,7 +155,9 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO:
f'\n'
)
- f.write("Comparison between all models, dropdown menu
\n")
+ f.write(
+ "Comparison between all models, dropdown menu
\n"
+ )
f.write(
f'\n'
@@ -162,7 +167,8 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO:
listed_files = [
elem
for elem in plot_list
- if elem != f"corr_comp_scatter_{lpo_lco_ldo}_{group_by}.html"
+ if elem
+ != f"corr_comp_scatter_{lpo_lco_ldo}_{group_by}.html"
and elem
!= f"corr_comp_scatter_overall_{lpo_lco_ldo}_{group_by}.html"
]
@@ -220,7 +226,9 @@ def __generate_corr_comp_scatterplots__(self):
self.fig_overall.add_trace(
scatterplot, col=run_idx + 1, row=run2_idx + 1
)
- self.fig_overall.add_trace(line_corr, col=run_idx + 1, row=run2_idx + 1)
+ self.fig_overall.add_trace(
+ line_corr, col=run_idx + 1, row=run2_idx + 1
+ )
# create dropdown buttons for y axis only in the first iteration
if run_idx == 0:
@@ -239,15 +247,17 @@ def __generate_corr_comp_scatterplots__(self):
self.fig_overall["layout"]["yaxis"]["title"] = str(
run2
).replace("_", "
", 2)
- self.fig_overall["layout"]["yaxis"]["title"]["font"]["size"] = 6
+ self.fig_overall["layout"]["yaxis"]["title"]["font"][
+ "size"
+ ] = 6
else:
y_axis_idx = (run2_idx) * len(self.models) + 1
- self.fig_overall["layout"][f"yaxis{y_axis_idx}"]["title"] = str(
- run2
- ).replace("_", "
", 2)
- self.fig_overall["layout"][f"yaxis{y_axis_idx}"]["title"][
- "font"
- ]["size"] = 6
+ self.fig_overall["layout"][f"yaxis{y_axis_idx}"][
+ "title"
+ ] = str(run2).replace("_", "
", 2)
+ self.fig_overall["layout"][f"yaxis{y_axis_idx}"][
+ "title"
+ ]["font"]["size"] = 6
def __subset_df__(self, run_id: str):
s_df = self.df[self.df["setting"] == run_id][
@@ -263,7 +273,9 @@ def __draw_subplot__(self, x_df, y_df, run, run2):
common_indices = x_df.index.intersection(y_df.index)
x_df_inter = x_df.loc[common_indices]
y_df = y_df.loc[common_indices]
- x_df_inter["setting"] = x_df_inter["model"].str.split("_").str[4:].str.join("")
+ x_df_inter["setting"] = (
+ x_df_inter["model"].str.split("_").str[4:].str.join("")
+ )
y_df["setting"] = y_df["model"].str.split("_").str[4:].str.join("")
joint_df = pd.concat([x_df_inter, y_df], axis=1)
@@ -292,7 +304,9 @@ def __draw_subplot__(self, x_df, y_df, run, run2):
x=x_df_inter[self.metric],
y=y_df[self.metric],
mode="markers",
- marker=dict(size=4, color=density, colorscale="Viridis", showscale=False),
+ marker=dict(
+ size=4, color=density, colorscale="Viridis", showscale=False
+ ),
showlegend=False,
visible=True,
meta=[run, run2],
diff --git a/drevalpy/visualization/critical_difference_plot.py b/drevalpy/visualization/critical_difference_plot.py
index 19098d5..d3020a9 100644
--- a/drevalpy/visualization/critical_difference_plot.py
+++ b/drevalpy/visualization/critical_difference_plot.py
@@ -1,13 +1,13 @@
+import math
+import operator
from typing import TextIO
-import numpy as np
-import pandas as pd
+
import matplotlib
import matplotlib.pyplot as plt
-import operator
-import math
-from scipy.stats import wilcoxon
-from scipy.stats import friedmanchisquare
import networkx
+import numpy as np
+import pandas as pd
+from scipy.stats import friedmanchisquare, wilcoxon
from drevalpy.evaluation import MINIMIZATION_METRICS
from drevalpy.visualization.outplot import OutPlot
@@ -37,14 +37,18 @@ def __init__(self, eval_results_preds: pd.DataFrame, metric="MSE"):
def draw_and_save(self, out_prefix: str, out_suffix: str) -> None:
try:
self.__draw__()
- path_out = f"{out_prefix}critical_difference_algorithms_{out_suffix}.svg"
+ path_out = (
+ f"{out_prefix}critical_difference_algorithms_{out_suffix}.svg"
+ )
self.fig.savefig(path_out, bbox_inches="tight")
except Exception as e:
print(f"Error in drawing critical difference plot: {e}")
def __draw__(self) -> None:
self.fig = self.__draw_cd_diagram__(
- alpha=0.05, title=f"Critical Difference: {self.metric}", labels=True
+ alpha=0.05,
+ title=f"Critical Difference: {self.metric}",
+ labels=True,
)
@staticmethod
@@ -53,7 +57,9 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO:
f.write(f"")
return f
- def __draw_cd_diagram__(self, alpha=0.05, title=None, labels=False) -> plt.Figure:
+ def __draw_cd_diagram__(
+ self, alpha=0.05, title=None, labels=False
+ ) -> plt.Figure:
"""
Draws the critical difference diagram given the list of pairwise classifiers that are
significant or not
@@ -264,7 +270,11 @@ def line(l, color="k", **kwargs):
def text(x, y, s, *args, **kwargs):
ax.text(wf * x, hf * y, s, *args, **kwargs)
- line([(textspace, cline), (width - textspace, cline)], linewidth=2, color="black")
+ line(
+ [(textspace, cline), (width - textspace, cline)],
+ linewidth=2,
+ color="black",
+ )
bigtick = 0.3
smalltick = 0.15
@@ -431,7 +441,9 @@ def wilcoxon_holm(alpha=0.05, df_perf=None):
)[1]
if friedman_p_value >= alpha:
# then the null hypothesis over the entire classifiers cannot be rejected
- print("the null hypothesis over the entire classifiers cannot be rejected")
+ print(
+ "the null hypothesis over the entire classifiers cannot be rejected"
+ )
exit()
# get the number of classifiers
m = len(classifiers)
@@ -443,7 +455,9 @@ def wilcoxon_holm(alpha=0.05, df_perf=None):
classifier_1 = classifiers[i]
# get the performance of classifier one
perf_1 = np.array(
- df_perf.loc[df_perf["classifier_name"] == classifier_1]["accuracy"],
+ df_perf.loc[df_perf["classifier_name"] == classifier_1][
+ "accuracy"
+ ],
dtype=np.float64,
)
for j in range(i + 1, m):
@@ -451,7 +465,9 @@ def wilcoxon_holm(alpha=0.05, df_perf=None):
classifier_2 = classifiers[j]
# get the performance of classifier one
perf_2 = np.array(
- df_perf.loc[df_perf["classifier_name"] == classifier_2]["accuracy"],
+ df_perf.loc[df_perf["classifier_name"] == classifier_2][
+ "accuracy"
+ ],
dtype=np.float64,
)
# calculate the p_value
@@ -469,7 +485,12 @@ def wilcoxon_holm(alpha=0.05, df_perf=None):
new_alpha = float(alpha / (k - i))
# test if significant after holm's correction of alpha
if p_values[i][2] <= new_alpha:
- p_values[i] = (p_values[i][0], p_values[i][1], p_values[i][2], True)
+ p_values[i] = (
+ p_values[i][0],
+ p_values[i][1],
+ p_values[i][2],
+ True,
+ )
else:
# stop
break
@@ -479,7 +500,9 @@ def wilcoxon_holm(alpha=0.05, df_perf=None):
df_perf["classifier_name"].isin(classifiers)
].sort_values(["classifier_name", "dataset_name"])
# get the rank data
- rank_data = np.array(sorted_df_perf["accuracy"]).reshape(m, max_nb_datasets)
+ rank_data = np.array(sorted_df_perf["accuracy"]).reshape(
+ m, max_nb_datasets
+ )
# create the data frame containg the accuracies
df_ranks = pd.DataFrame(
@@ -494,7 +517,9 @@ def wilcoxon_holm(alpha=0.05, df_perf=None):
# average the ranks
average_ranks = (
- df_ranks.rank(ascending=False).mean(axis=1).sort_values(ascending=False)
+ df_ranks.rank(ascending=False)
+ .mean(axis=1)
+ .sort_values(ascending=False)
)
# return the p-values and the average ranks
return p_values, average_ranks, max_nb_datasets
diff --git a/drevalpy/visualization/heatmap.py b/drevalpy/visualization/heatmap.py
index 6d61db3..c13f576 100644
--- a/drevalpy/visualization/heatmap.py
+++ b/drevalpy/visualization/heatmap.py
@@ -7,9 +7,13 @@
class Heatmap(VioHeat):
- def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False):
+ def __init__(
+ self, df: pd.DataFrame, normalized_metrics=False, whole_name=False
+ ):
super().__init__(df, normalized_metrics, whole_name)
- self.df = self.df[[col for col in self.df.columns if col in self.all_metrics]]
+ self.df = self.df[
+ [col for col in self.df.columns if col in self.all_metrics]
+ ]
if self.normalized_metrics:
titles = [
"Standard Errors over CV folds",
@@ -32,7 +36,12 @@ def __init__(self, df: pd.DataFrame, normalized_metrics=False, whole_name=False)
"Mean Errors",
]
nr_subplots = 4
- self.plot_settings = ["standard_errors", "r2", "correlations", "errors"]
+ self.plot_settings = [
+ "standard_errors",
+ "r2",
+ "correlations",
+ "errors",
+ ]
self.fig = make_subplots(
rows=nr_subplots,
cols=1,
@@ -50,7 +59,9 @@ def __draw__(self) -> None:
for plot_setting in self.plot_settings:
self.__draw_subplots__(plot_setting)
self.fig.update_layout(
- height=1000, width=1100, title_text="Heatmap of the evaluation metrics"
+ height=1000,
+ width=1100,
+ title_text="Heatmap of the evaluation metrics",
)
self.fig.update_traces(showscale=False)
diff --git a/drevalpy/visualization/html_tables.py b/drevalpy/visualization/html_tables.py
index caa7bde..fb552e2 100644
--- a/drevalpy/visualization/html_tables.py
+++ b/drevalpy/visualization/html_tables.py
@@ -1,7 +1,8 @@
-from typing import TextIO, List
+import os
+from typing import List, TextIO
import pandas as pd
-import os
+
from drevalpy.visualization.outplot import OutPlot
@@ -74,7 +75,9 @@ def write_to_html(
if prefix != "":
prefix = os.path.join(prefix, "html_tables")
f.write(' Evaluation Results Table
\n')
- whole_table = __get_table__(files=files, file_table=f"table_{lpo_lco_ldo}.html")
+ whole_table = __get_table__(
+ files=files, file_table=f"table_{lpo_lco_ldo}.html"
+ )
__write_table__(f=f, table=whole_table, prefix=prefix)
if lpo_lco_ldo != "LCO":
diff --git a/drevalpy/visualization/regression_slider_plot.py b/drevalpy/visualization/regression_slider_plot.py
index b0cf0ea..426dc24 100644
--- a/drevalpy/visualization/regression_slider_plot.py
+++ b/drevalpy/visualization/regression_slider_plot.py
@@ -1,11 +1,12 @@
-from typing import TextIO, List
-import plotly.express as px
-from scipy.stats import pearsonr
+from typing import List, TextIO
+
import numpy as np
import pandas as pd
+import plotly.express as px
+from scipy.stats import pearsonr
-from drevalpy.visualization.outplot import OutPlot
from drevalpy.models import SINGLE_DRUG_MODEL_FACTORY
+from drevalpy.visualization.outplot import OutPlot
class RegressionSliderPlot(OutPlot):
@@ -18,7 +19,8 @@ def __init__(
normalize=False,
):
self.df = df[
- (df["LPO_LCO_LDO"] == lpo_lco_ldo) & (df["rand_setting"] == "predictions")
+ (df["LPO_LCO_LDO"] == lpo_lco_ldo)
+ & (df["rand_setting"] == "predictions")
]
self.df = self.df[(self.df["algorithm"] == model)]
self.group_by = group_by
@@ -65,7 +67,9 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO:
f.write('Regression plots
\n')
f.write("\n")
regr_files = [
- f for f in files if lpo_lco_ldo in f and f.startswith("regression_lines")
+ f
+ for f in files
+ if lpo_lco_ldo in f and f.startswith("regression_lines")
]
regr_files.sort()
for regr_file in regr_files:
@@ -159,5 +163,6 @@ def __make_slider__(self, setting_title):
]
self.fig.update_layout(
- sliders=sliders, legend=dict(yanchor="top", y=1.0, xanchor="left", x=1.05)
+ sliders=sliders,
+ legend=dict(yanchor="top", y=1.0, xanchor="left", x=1.05),
)
diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py
index b33c01b..c4c73de 100644
--- a/drevalpy/visualization/utils.py
+++ b/drevalpy/visualization/utils.py
@@ -6,17 +6,17 @@
import pathlib
import shutil
from typing import List
+
import importlib_resources
import pandas as pd
-
from drevalpy.datasets.dataset import DrugResponseDataset
-from drevalpy.evaluation import evaluate, AVAILABLE_METRICS
+from drevalpy.evaluation import AVAILABLE_METRICS, evaluate
from drevalpy.visualization import HTMLTable
-from drevalpy.visualization.vioheat import VioHeat
from drevalpy.visualization.corr_comp_scatter import CorrelationComparisonScatter
-from drevalpy.visualization.regression_slider_plot import RegressionSliderPlot
from drevalpy.visualization.critical_difference_plot import CriticalDifferencePlot
+from drevalpy.visualization.regression_slider_plot import RegressionSliderPlot
+from drevalpy.visualization.vioheat import VioHeat
def parse_layout(f, path_to_layout):
@@ -77,7 +77,9 @@ def parse_results(path_to_results: str):
eval_results_per_cl,
t_vs_p,
model_name,
- ) = evaluate_file(pred_file=file, test_mode=lpo_lco_ldo, model_name=algorithm)
+ ) = evaluate_file(
+ pred_file=file, test_mode=lpo_lco_ldo, model_name=algorithm
+ )
evaluation_results = (
overall_eval
@@ -85,21 +87,27 @@ def parse_results(path_to_results: str):
else pd.concat([evaluation_results, overall_eval])
)
true_vs_pred = (
- t_vs_p if true_vs_pred is None else pd.concat([true_vs_pred, t_vs_p])
+ t_vs_p
+ if true_vs_pred is None
+ else pd.concat([true_vs_pred, t_vs_p])
)
if eval_results_per_drug is not None:
evaluation_results_per_drug = (
eval_results_per_drug
if evaluation_results_per_drug is None
- else pd.concat([evaluation_results_per_drug, eval_results_per_drug])
+ else pd.concat(
+ [evaluation_results_per_drug, eval_results_per_drug]
+ )
)
if eval_results_per_cl is not None:
evaluation_results_per_cell_line = (
eval_results_per_cl
if evaluation_results_per_cell_line is None
- else pd.concat([evaluation_results_per_cell_line, eval_results_per_cl])
+ else pd.concat(
+ [evaluation_results_per_cell_line, eval_results_per_cl]
+ )
)
return (
@@ -149,12 +157,14 @@ def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str):
norm_cl_eval_results = {}
if "LPO" in model or "LCO" in model:
- norm_drug_eval_results, evaluation_results_per_drug = evaluate_per_group(
- df=true_vs_pred,
- group_by="drug",
- norm_group_eval_results=norm_drug_eval_results,
- eval_results_per_group=evaluation_results_per_drug,
- model=model,
+ norm_drug_eval_results, evaluation_results_per_drug = (
+ evaluate_per_group(
+ df=true_vs_pred,
+ group_by="drug",
+ norm_group_eval_results=norm_drug_eval_results,
+ eval_results_per_group=evaluation_results_per_drug,
+ model=model,
+ )
)
if "LPO" in model or "LDO" in model:
norm_cl_eval_results, evaluation_results_per_cl = evaluate_per_group(
@@ -166,9 +176,13 @@ def evaluate_file(pred_file: pathlib.Path, test_mode: str, model_name: str):
)
overall_eval = pd.DataFrame.from_dict(overall_eval, orient="index")
if len(norm_drug_eval_results) > 0:
- overall_eval = concat_results(norm_drug_eval_results, "drug", overall_eval)
+ overall_eval = concat_results(
+ norm_drug_eval_results, "drug", overall_eval
+ )
if len(norm_cl_eval_results) > 0:
- overall_eval = concat_results(norm_cl_eval_results, "cell_line", overall_eval)
+ overall_eval = concat_results(
+ norm_cl_eval_results, "cell_line", overall_eval
+ )
return (
overall_eval,
@@ -219,7 +233,9 @@ def prep_results(
"CV_split",
]
new_columns.index = eval_results.index
- eval_results = pd.concat([new_columns.drop("split", axis=1), eval_results], axis=1)
+ eval_results = pd.concat(
+ [new_columns.drop("split", axis=1), eval_results], axis=1
+ )
if eval_results_per_drug is not None:
eval_results_per_drug[
["algorithm", "rand_setting", "LPO_LCO_LDO", "split", "CV_split"]
@@ -228,11 +244,16 @@ def prep_results(
eval_results_per_cell_line[
["algorithm", "rand_setting", "LPO_LCO_LDO", "split", "CV_split"]
] = eval_results_per_cell_line["model"].str.split("_", expand=True)
- t_vs_p[["algorithm", "rand_setting", "LPO_LCO_LDO", "split", "CV_split"]] = t_vs_p[
- "model"
- ].str.split("_", expand=True)
+ t_vs_p[
+ ["algorithm", "rand_setting", "LPO_LCO_LDO", "split", "CV_split"]
+ ] = t_vs_p["model"].str.split("_", expand=True)
- return eval_results, eval_results_per_drug, eval_results_per_cell_line, t_vs_p
+ return (
+ eval_results,
+ eval_results_per_drug,
+ eval_results_per_cell_line,
+ t_vs_p,
+ )
def generate_model_names(test_mode, model_name, pred_file):
@@ -273,10 +294,16 @@ def evaluate_per_group(
"""
# calculate the mean of y_true per drug
print(f"Calculating {group_by}-wise evaluation measures …")
- df[f"mean_y_true_per_{group_by}"] = df.groupby(group_by)["y_true"].transform("mean")
+ df[f"mean_y_true_per_{group_by}"] = df.groupby(group_by)[
+ "y_true"
+ ].transform("mean")
norm_df = df.copy()
- norm_df["y_true"] = norm_df["y_true"] - norm_df[f"mean_y_true_per_{group_by}"]
- norm_df["y_pred"] = norm_df["y_pred"] - norm_df[f"mean_y_true_per_{group_by}"]
+ norm_df["y_true"] = (
+ norm_df["y_true"] - norm_df[f"mean_y_true_per_{group_by}"]
+ )
+ norm_df["y_pred"] = (
+ norm_df["y_pred"] - norm_df[f"mean_y_true_per_{group_by}"]
+ )
norm_group_eval_results[model] = evaluate(
DrugResponseDataset(
response=norm_df["y_true"],
@@ -348,7 +375,9 @@ def write_results(
t_vs_p.to_csv(f"{path_out}true_vs_pred.csv", index=True)
-def create_index_html(custom_id: str, test_modes: List[str], prefix_results: str):
+def create_index_html(
+ custom_id: str, test_modes: List[str], prefix_results: str
+):
"""
Create the index.html file.
:param custom_id:
@@ -412,7 +441,9 @@ def create_index_html(custom_id: str, test_modes: List[str], prefix_results: str
f.write("