From eb2a4cf57ec573e557c24934d9b22bd1d53ecce1 Mon Sep 17 00:00:00 2001
From: PascalIversen
Date: Tue, 29 Oct 2024 10:40:26 +0100
Subject: [PATCH] black am
---
drevalpy/datasets/dataset.py | 6 +++++-
drevalpy/datasets/loader.py | 5 +++--
drevalpy/experiment.py | 18 ++++++++++++++----
drevalpy/models/SRMF/srmf.py | 4 +++-
drevalpy/models/utils.py | 16 ++--------------
drevalpy/utils.py | 5 ++++-
drevalpy/visualization/corr_comp_scatter.py | 5 ++++-
drevalpy/visualization/utils.py | 7 +++++--
drevalpy/visualization/vioheat.py | 4 +++-
drevalpy/visualization/violin.py | 17 +++++++++++++++--
tests/individual_models/conftest.py | 12 +++++++-----
tests/test_run_suite.py | 4 ++--
12 files changed, 67 insertions(+), 36 deletions(-)
diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py
index d653afa..9049b57 100644
--- a/drevalpy/datasets/dataset.py
+++ b/drevalpy/datasets/dataset.py
@@ -105,7 +105,11 @@ def __str__(self):
f"Response {self.response[:3]}..."
)
else:
- string = f"DrugResponseDataset: CLs {self.cell_line_ids}; " f"Drugs {self.drug_ids}; " f"Response {self.response}"
+ string = (
+ f"DrugResponseDataset: CLs {self.cell_line_ids}; "
+ f"Drugs {self.drug_ids}; "
+ f"Response {self.response}"
+ )
if self.predictions is not None:
if len(self.predictions) > 3:
string += f"; Predictions {self.predictions[:3]}..."
diff --git a/drevalpy/datasets/loader.py b/drevalpy/datasets/loader.py
index e829e5a..d2a5d5d 100644
--- a/drevalpy/datasets/loader.py
+++ b/drevalpy/datasets/loader.py
@@ -5,8 +5,9 @@
from .utils import download_dataset
-def load_gdsc1(path_data: str = "data", file_name: str = "response_GDSC1.csv", dataset_name: str = "GDSC1") -> (
- DrugResponseDataset):
+def load_gdsc1(
+ path_data: str = "data", file_name: str = "response_GDSC1.csv", dataset_name: str = "GDSC1"
+) -> DrugResponseDataset:
"""
Loads the GDSC1 dataset.
diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py
index c0d43bf..e2a86bc 100644
--- a/drevalpy/experiment.py
+++ b/drevalpy/experiment.py
@@ -163,7 +163,9 @@ def drug_response_experiment(
model = model_class()
- if not os.path.isfile(prediction_file): # if this split has not been run yet (or for a single drug model, this drug_id)
+ if not os.path.isfile(
+ prediction_file
+ ): # if this split has not been run yet (or for a single drug model, this drug_id)
tuning_inputs = {
"model": model,
@@ -234,7 +236,9 @@ def drug_response_experiment(
print(f"Randomization tests for {model_class.model_name}")
# if this line changes, it also needs to be changed in pipeline:
# randomization_split.py
- randomization_test_views = get_randomization_test_views(model=model, randomization_mode=randomization_mode)
+ randomization_test_views = get_randomization_test_views(
+ model=model, randomization_mode=randomization_mode
+ )
randomization_test(
randomization_test_views=randomization_test_views,
model=model,
@@ -308,7 +312,11 @@ def consolidate_single_drug_model_predictions(
"randomization": {},
}
# list all dirs in model_path/drugs
- drugs = [d for d in os.listdir(os.path.join(model_path, "drugs")) if os.path.isdir(os.path.join(model_path, "drugs", d))]
+ drugs = [
+ d
+ for d in os.listdir(os.path.join(model_path, "drugs"))
+ if os.path.isdir(os.path.join(model_path, "drugs", d))
+ ]
for drug in drugs:
single_drug_prediction_path = os.path.join(model_path, "drugs", drug)
@@ -343,7 +351,9 @@ def consolidate_single_drug_model_predictions(
f = f"robustness_{trial+1}_split_{split}.csv"
if trial not in predictions["robustness"]:
predictions["robustness"][trial] = []
- predictions["robustness"][trial].append(pd.read_csv(os.path.join(robustness_path, f), index_col=0))
+ predictions["robustness"][trial].append(
+ pd.read_csv(os.path.join(robustness_path, f), index_col=0)
+ )
# Randomization predictions
if randomization_mode is not None:
diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py
index 1881e2c..0f293a5 100644
--- a/drevalpy/models/SRMF/srmf.py
+++ b/drevalpy/models/SRMF/srmf.py
@@ -76,7 +76,9 @@ def train(
drug_response_matrix = drug_response_matrix.groupby(["cell_line_id", "drug_id"]).mean().reset_index()
drug_response_matrix = drug_response_matrix.pivot(index="cell_line_id", columns="drug_id", values="response")
- drug_response_matrix = drug_response_matrix.reindex(index=cell_lines, columns=drugs) # missing rows and columns are filled with NaN
+ drug_response_matrix = drug_response_matrix.reindex(
+ index=cell_lines, columns=drugs
+ ) # missing rows and columns are filled with NaN
self.w = ~np.isnan(drug_response_matrix)
drug_response_matrix = drug_response_matrix.copy()
diff --git a/drevalpy/models/utils.py b/drevalpy/models/utils.py
index 9ccf8a9..538a223 100644
--- a/drevalpy/models/utils.py
+++ b/drevalpy/models/utils.py
@@ -112,22 +112,10 @@ def load_drug_fingerprint_features(data_path: str, dataset_name: str) -> Feature
:return:
"""
if dataset_name == "Toy_Data":
- fingerprints = pd.read_csv(
- os.path.join(
- data_path,
- dataset_name,
- "fingerprints.csv"
- ),
- index_col=0
- )
+ fingerprints = pd.read_csv(os.path.join(data_path, dataset_name, "fingerprints.csv"), index_col=0)
else:
fingerprints = pd.read_csv(
- os.path.join(
- data_path,
- dataset_name,
- "drug_fingerprints",
- "drug_name_to_demorgan_128_map.csv"
- ),
+ os.path.join(data_path, dataset_name, "drug_fingerprints", "drug_name_to_demorgan_128_map.csv"),
index_col=0,
).T
return FeatureDataset(
diff --git a/drevalpy/utils.py b/drevalpy/utils.py
index a407135..d0a43ba 100644
--- a/drevalpy/utils.py
+++ b/drevalpy/utils.py
@@ -295,4 +295,7 @@ def get_response_transformation(response_transformation: str):
return MinMaxScaler()
if response_transformation == "robust":
return RobustScaler()
- raise ValueError(f"Unknown response transformation {response_transformation}. Choose from 'None', " f"'standard', 'minmax', 'robust'")
+ raise ValueError(
+ f"Unknown response transformation {response_transformation}. Choose from 'None', "
+ f"'standard', 'minmax', 'robust'"
+ )
diff --git a/drevalpy/visualization/corr_comp_scatter.py b/drevalpy/visualization/corr_comp_scatter.py
index 3daef17..fe03e38 100644
--- a/drevalpy/visualization/corr_comp_scatter.py
+++ b/drevalpy/visualization/corr_comp_scatter.py
@@ -154,7 +154,10 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO:
]
listed_files.sort()
for group_comparison in listed_files:
- f.write(f'' f"{group_comparison}\n")
+ f.write(
+ f''
+ f"{group_comparison}\n"
+ )
f.write("\n")
return f
diff --git a/drevalpy/visualization/utils.py b/drevalpy/visualization/utils.py
index df82e5e..da35c91 100644
--- a/drevalpy/visualization/utils.py
+++ b/drevalpy/visualization/utils.py
@@ -47,7 +47,9 @@ def parse_results(path_to_results: str):
result_dir = pathlib.Path(path_to_results)
result_files = list(result_dir.rglob("*.csv"))
# filter for all files that follow this pattern: result_dir/*/{predictions|cross_study|randomization|robustness}/*.csv
- pattern = re.compile(fr"{result_dir}/(LPO|LCO|LDO)/[^/]+/(predictions|cross_study|randomization|robustness)/.*\.csv$")
+ pattern = re.compile(
+ rf"{result_dir}/(LPO|LCO|LDO)/[^/]+/(predictions|cross_study|randomization|robustness)/.*\.csv$"
+ )
result_files = [file for file in result_files if pattern.match(str(file))]
# inititalize dictionaries to store the evaluation results
@@ -371,7 +373,8 @@ def create_index_html(custom_id: str, test_modes: list[str], prefix_results: str
)
shutil.copyfile(img_path, os.path.join(prefix_results, f"{lpo_lco_ldo}.png"))
f.write(
- f'\n'
+ f'\n'
)
f.write("\n")
f.write("\n")
diff --git a/drevalpy/visualization/vioheat.py b/drevalpy/visualization/vioheat.py
index d4ca7ec..29f4ec4 100644
--- a/drevalpy/visualization/vioheat.py
+++ b/drevalpy/visualization/vioheat.py
@@ -65,7 +65,9 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO:
]
f.write(f'{plot} Plots of Performance Measures over CV runs
\n')
f.write(f"{plot} plots comparing all models
\n")
- f.write(f'\n')
+ f.write(
+ f'\n'
+ )
f.write(f"{plot} plots comparing all models with normalized metrics
\n")
f.write(
f"Before calculating the evaluation metrics, all values were normalized by the mean of the drug or cell line. "
diff --git a/drevalpy/visualization/violin.py b/drevalpy/visualization/violin.py
index 7e19cc7..21267c6 100644
--- a/drevalpy/visualization/violin.py
+++ b/drevalpy/visualization/violin.py
@@ -97,7 +97,14 @@ def __draw__(self) -> None:
{
"visible": [False] * (self.count_r2 + self.count_pearson + self.count_spearman)
+ [True] * self.count_kendall
- + [False] * (count_sum - self.count_r2 - self.count_pearson - self.count_spearman - self.count_kendall)
+ + [False]
+ * (
+ count_sum
+ - self.count_r2
+ - self.count_pearson
+ - self.count_spearman
+ - self.count_kendall
+ )
},
{"title": "Kendall"},
],
@@ -108,7 +115,13 @@ def __draw__(self) -> None:
args=[
{
"visible": [False]
- * (count_sum - self.count_partial_correlation - self.count_mse - self.count_rmse - self.count_mae)
+ * (
+ count_sum
+ - self.count_partial_correlation
+ - self.count_mse
+ - self.count_rmse
+ - self.count_mae
+ )
+ [True] * self.count_partial_correlation
+ [False] * (self.count_mse + self.count_rmse + self.count_mae)
},
diff --git a/tests/individual_models/conftest.py b/tests/individual_models/conftest.py
index 0c188c2..75b2895 100644
--- a/tests/individual_models/conftest.py
+++ b/tests/individual_models/conftest.py
@@ -2,17 +2,19 @@
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
from drevalpy.datasets.loader import load_toy
-from drevalpy.models.utils import (get_multiomics_feature_dataset, load_drug_fingerprint_features,
- load_drug_ids_from_csv, load_cl_ids_from_csv)
+from drevalpy.models.utils import (
+ get_multiomics_feature_dataset,
+ load_drug_fingerprint_features,
+ load_drug_ids_from_csv,
+ load_cl_ids_from_csv,
+)
@pytest.fixture(scope="session")
def sample_dataset() -> tuple[DrugResponseDataset, FeatureDataset, FeatureDataset]:
path_data = "../data"
drug_response = load_toy(path_data)
- cell_line_input = get_multiomics_feature_dataset(
- data_path=path_data, dataset_name="Toy_Data", gene_list=None
- )
+ cell_line_input = get_multiomics_feature_dataset(data_path=path_data, dataset_name="Toy_Data", gene_list=None)
cell_line_ids = load_cl_ids_from_csv(path=path_data, dataset_name="Toy_Data")
cell_line_input._add_features(cell_line_ids)
# Load the drug features
diff --git a/tests/test_run_suite.py b/tests/test_run_suite.py
index 5813511..f3859aa 100644
--- a/tests/test_run_suite.py
+++ b/tests/test_run_suite.py
@@ -44,7 +44,7 @@ def test_run_suite(args):
args = Namespace(**args)
main(args)
assert os.listdir(temp_dir.name) == ["test_run"]
- '''
+ """
(
evaluation_results,
evaluation_results_per_drug,
@@ -86,4 +86,4 @@ def test_run_suite(args):
assert all(test_mode in evaluation_results.LPO_LCO_LDO.unique() for test_mode in args.test_mode)
assert evaluation_results.CV_split.astype(int).max() == (args.n_cv_splits - 1)
assert evaluation_results.Pearson.astype(float).max() > 0.5
- '''
+ """