diff --git a/drevalpy/datasets/dataset.py b/drevalpy/datasets/dataset.py index d653afa..9049b57 100644 --- a/drevalpy/datasets/dataset.py +++ b/drevalpy/datasets/dataset.py @@ -105,7 +105,11 @@ def __str__(self): f"Response {self.response[:3]}..." ) else: - string = f"DrugResponseDataset: CLs {self.cell_line_ids}; " f"Drugs {self.drug_ids}; " f"Response {self.response}" + string = ( + f"DrugResponseDataset: CLs {self.cell_line_ids}; " + f"Drugs {self.drug_ids}; " + f"Response {self.response}" + ) if self.predictions is not None: if len(self.predictions) > 3: string += f"; Predictions {self.predictions[:3]}..." diff --git a/drevalpy/datasets/loader.py b/drevalpy/datasets/loader.py index e829e5a..d2a5d5d 100644 --- a/drevalpy/datasets/loader.py +++ b/drevalpy/datasets/loader.py @@ -5,8 +5,9 @@ from .utils import download_dataset -def load_gdsc1(path_data: str = "data", file_name: str = "response_GDSC1.csv", dataset_name: str = "GDSC1") -> ( - DrugResponseDataset): +def load_gdsc1( + path_data: str = "data", file_name: str = "response_GDSC1.csv", dataset_name: str = "GDSC1" +) -> DrugResponseDataset: """ Loads the GDSC1 dataset. diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index c0d43bf..e2a86bc 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -163,7 +163,9 @@ def drug_response_experiment( model = model_class() - if not os.path.isfile(prediction_file): # if this split has not been run yet (or for a single drug model, this drug_id) + if not os.path.isfile( + prediction_file + ): # if this split has not been run yet (or for a single drug model, this drug_id) tuning_inputs = { "model": model, @@ -234,7 +236,9 @@ def drug_response_experiment( print(f"Randomization tests for {model_class.model_name}") # if this line changes, it also needs to be changed in pipeline: # randomization_split.py - randomization_test_views = get_randomization_test_views(model=model, randomization_mode=randomization_mode) + randomization_test_views = get_randomization_test_views( + model=model, randomization_mode=randomization_mode + ) randomization_test( randomization_test_views=randomization_test_views, model=model, @@ -308,7 +312,11 @@ def consolidate_single_drug_model_predictions( "randomization": {}, } # list all dirs in model_path/drugs - drugs = [d for d in os.listdir(os.path.join(model_path, "drugs")) if os.path.isdir(os.path.join(model_path, "drugs", d))] + drugs = [ + d + for d in os.listdir(os.path.join(model_path, "drugs")) + if os.path.isdir(os.path.join(model_path, "drugs", d)) + ] for drug in drugs: single_drug_prediction_path = os.path.join(model_path, "drugs", drug) @@ -343,7 +351,9 @@ def consolidate_single_drug_model_predictions( f = f"robustness_{trial+1}_split_{split}.csv" if trial not in predictions["robustness"]: predictions["robustness"][trial] = [] - predictions["robustness"][trial].append(pd.read_csv(os.path.join(robustness_path, f), index_col=0)) + predictions["robustness"][trial].append( + pd.read_csv(os.path.join(robustness_path, f), index_col=0) + ) # Randomization predictions if randomization_mode is not None: diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py index 1881e2c..0f293a5 100644 --- a/drevalpy/models/SRMF/srmf.py +++ b/drevalpy/models/SRMF/srmf.py @@ -76,7 +76,9 @@ def train( drug_response_matrix = drug_response_matrix.groupby(["cell_line_id", "drug_id"]).mean().reset_index() drug_response_matrix = drug_response_matrix.pivot(index="cell_line_id", columns="drug_id", values="response") - drug_response_matrix = drug_response_matrix.reindex(index=cell_lines, columns=drugs) # missing rows and columns are filled with NaN + drug_response_matrix = drug_response_matrix.reindex( + index=cell_lines, columns=drugs + ) # missing rows and columns are filled with NaN self.w = ~np.isnan(drug_response_matrix) drug_response_matrix = drug_response_matrix.copy() diff --git a/drevalpy/models/utils.py b/drevalpy/models/utils.py index 9ccf8a9..538a223 100644 --- a/drevalpy/models/utils.py +++ b/drevalpy/models/utils.py @@ -112,22 +112,10 @@ def load_drug_fingerprint_features(data_path: str, dataset_name: str) -> Feature :return: """ if dataset_name == "Toy_Data": - fingerprints = pd.read_csv( - os.path.join( - data_path, - dataset_name, - "fingerprints.csv" - ), - index_col=0 - ) + fingerprints = pd.read_csv(os.path.join(data_path, dataset_name, "fingerprints.csv"), index_col=0) else: fingerprints = pd.read_csv( - os.path.join( - data_path, - dataset_name, - "drug_fingerprints", - "drug_name_to_demorgan_128_map.csv" - ), + os.path.join(data_path, dataset_name, "drug_fingerprints", "drug_name_to_demorgan_128_map.csv"), index_col=0, ).T return FeatureDataset( diff --git a/drevalpy/utils.py b/drevalpy/utils.py index a407135..d0a43ba 100644 --- a/drevalpy/utils.py +++ b/drevalpy/utils.py @@ -295,4 +295,7 @@ def get_response_transformation(response_transformation: str): return MinMaxScaler() if response_transformation == "robust": return RobustScaler() - raise ValueError(f"Unknown response transformation {response_transformation}. Choose from 'None', " f"'standard', 'minmax', 'robust'") + raise ValueError( + f"Unknown response transformation {response_transformation}. Choose from 'None', " + f"'standard', 'minmax', 'robust'" + ) diff --git a/drevalpy/visualization/corr_comp_scatter.py b/drevalpy/visualization/corr_comp_scatter.py index 3daef17..fe03e38 100644 --- a/drevalpy/visualization/corr_comp_scatter.py +++ b/drevalpy/visualization/corr_comp_scatter.py @@ -154,7 +154,10 @@ def write_to_html(lpo_lco_ldo: str, f: TextIO, *args, **kwargs) -> TextIO: ] listed_files.sort() for group_comparison in listed_files: - f.write(f'