From 7c814e5ad718666731be01979af3e74d7686961b Mon Sep 17 00:00:00 2001 From: Florence Townend Date: Wed, 15 May 2024 13:43:34 +0200 Subject: [PATCH 1/3] added training_modifications argument to train_and_save_models where the accelerator and number of devices can be specified. Also changed the metrics_utils calculations to put the metric equations onto whatever device the data is on. --- fusilli/train.py | 75 +++++++++++++++++++--------------- fusilli/utils/metrics_utils.py | 60 ++++++++++++++++++--------- 2 files changed, 83 insertions(+), 52 deletions(-) diff --git a/fusilli/train.py b/fusilli/train.py index 1ac9f0d..0596751 100644 --- a/fusilli/train.py +++ b/fusilli/train.py @@ -16,19 +16,19 @@ def train_and_test( - data_module, - k, - fusion_model, - kfold, - extra_log_string_dict=None, - layer_mods=None, - max_epochs=1000, - enable_checkpointing=True, - show_loss_plot=False, - wandb_logging=False, - project_name=None, - training_modifications=None, - metrics_list=None, + data_module, + k, + fusion_model, + kfold, + extra_log_string_dict=None, + layer_mods=None, + max_epochs=1000, + enable_checkpointing=True, + show_loss_plot=False, + wandb_logging=False, + project_name=None, + training_modifications=None, + metrics_list=None, ): """ Trains and tests a model and, if k_fold trained, a fold. @@ -127,13 +127,14 @@ def train_and_test( val_dataloader = data_module.val_dataloader() output_paths = data_module.output_paths - logger = set_logger(fold=k, - project_name=project_name, - output_paths=output_paths, - fusion_model=fusion_model, - extra_log_string_dict=extra_log_string_dict, - wandb_logging=wandb_logging, - ) # set logger + logger = set_logger( + fold=k, + project_name=project_name, + output_paths=output_paths, + fusion_model=fusion_model, + extra_log_string_dict=extra_log_string_dict, + wandb_logging=wandb_logging, + ) # set logger trainer = init_trainer( logger, @@ -182,7 +183,9 @@ def train_and_test( # if logger is CSVLogger, plot loss curve if isinstance(logger, CSVLogger): - plot_loss_curve(figures_path=output_paths["figures"], logger=logger, show=show_loss_plot) + plot_loss_curve( + figures_path=output_paths["figures"], logger=logger, show=show_loss_plot + ) return pl_model @@ -232,16 +235,17 @@ def _store_trained_model(trained_model, trained_models_dict): def train_and_save_models( - data_module, - fusion_model, - wandb_logging=False, - extra_log_string_dict=None, - layer_mods=None, - max_epochs=1000, - enable_checkpointing=True, - show_loss_plot=False, - project_name=None, - metrics_list=None, + data_module, + fusion_model, + wandb_logging=False, + extra_log_string_dict=None, + layer_mods=None, + max_epochs=1000, + enable_checkpointing=True, + show_loss_plot=False, + project_name=None, + metrics_list=None, + training_modifications=None, ): """ Trains/tests the model and saves the trained model to a dictionary for further analysis. @@ -282,6 +286,8 @@ def train_and_save_models( (AUROC, accuracy for binary/multiclass, R2 and MAE for regression). The first metric in the list will be used in the comparison evaluation figures to rank the models' performances. Length must be 2 or more. + training_modifications : dict + Dictionary of training modifications. Used to modify the training process. Keys could be "accelerator", "devices" Returns ------- @@ -300,7 +306,10 @@ def train_and_save_models( kfold = True num_folds = data_module.num_folds elif isinstance(data_module, list): - if hasattr(data_module[0], "num_folds") and data_module[0].num_folds is not None: + if ( + hasattr(data_module[0], "num_folds") + and data_module[0].num_folds is not None + ): kfold = True num_folds = data_module[0].num_folds else: @@ -322,6 +331,7 @@ def train_and_save_models( wandb_logging=wandb_logging, project_name=project_name, metrics_list=metrics_list, + training_modifications=training_modifications, ) trained_models_list.append(trained_model) @@ -343,6 +353,7 @@ def train_and_save_models( wandb_logging=wandb_logging, project_name=project_name, metrics_list=metrics_list, + training_modifications=training_modifications, ) trained_models_list.append(trained_model) diff --git a/fusilli/utils/metrics_utils.py b/fusilli/utils/metrics_utils.py index 5809345..6fa3e3f 100644 --- a/fusilli/utils/metrics_utils.py +++ b/fusilli/utils/metrics_utils.py @@ -1,6 +1,7 @@ """ Calculates metrics of the models and houses list of the available metrics to use. """ + import torch import torchmetrics as tm @@ -41,9 +42,11 @@ def auroc(self, preds, labels, logits): """ if self.prediction_task == "binary": - auroc_equation = tm.AUROC(task="binary") + auroc_equation = tm.AUROC(task="binary").to(preds.device) elif self.prediction_task == "multiclass": - auroc_equation = tm.AUROC(num_classes=self.model.multiclass_dimensions, task="multiclass") + auroc_equation = tm.AUROC( + num_classes=self.model.multiclass_dimensions, task="multiclass" + ).to(preds.device) else: raise ValueError("Invalid prediction task for AUROC.") @@ -71,11 +74,13 @@ def accuracy(self, preds, labels, logits): if self.prediction_task == "binary": # do binary accuracy - accuracy_equation = tm.Accuracy(task="binary") + accuracy_equation = tm.Accuracy(task="binary").to(preds.device) elif self.prediction_task == "multiclass": # do multiclass accuracy - accuracy_equation = tm.Accuracy(num_classes=self.model.multiclass_dimensions, task="multiclass", top_k=1) + accuracy_equation = tm.Accuracy( + num_classes=self.model.multiclass_dimensions, task="multiclass", top_k=1 + ).to(preds.device) else: raise ValueError("Invalid prediction task for accuracy.") @@ -104,7 +109,7 @@ def r2(self, preds, labels, logits): if self.prediction_task != "regression": raise ValueError("Invalid prediction task for R2.") - return tm.R2Score()(preds, labels) + return tm.R2Score().to(preds.device)(preds, labels) def mse(self, preds, labels, logits): """ @@ -128,7 +133,7 @@ def mse(self, preds, labels, logits): if self.prediction_task != "regression": raise ValueError("Invalid prediction task for mse.") - return tm.MeanSquaredError()(preds, labels) + return tm.MeanSquaredError().to(preds.device)(preds, labels) def mae(self, preds, labels, logits): """ @@ -153,7 +158,7 @@ def mae(self, preds, labels, logits): if self.prediction_task != "regression": raise ValueError("Invalid prediction task for mae.") - return tm.MeanAbsoluteError()(preds, labels) + return tm.MeanAbsoluteError().to(preds.device)(preds, labels) def recall(self, preds, labels, logits): """ @@ -176,9 +181,11 @@ def recall(self, preds, labels, logits): """ if self.prediction_task == "binary": - recall_equation = tm.Recall(task="binary") + recall_equation = tm.Recall(task="binary").to(preds.device) elif self.prediction_task == "multiclass": - recall_equation = tm.Recall(num_classes=self.model.multiclass_dimensions, task="multiclass") + recall_equation = tm.Recall( + num_classes=self.model.multiclass_dimensions, task="multiclass" + ).to(preds.device) else: raise ValueError("Invalid prediction task for recall.") @@ -205,9 +212,11 @@ def specificity(self, preds, labels, logits): """ if self.prediction_task == "binary": - specificity_equation = tm.Specificity(task="binary") + specificity_equation = tm.Specificity(task="binary").to(preds.device) elif self.prediction_task == "multiclass": - specificity_equation = tm.Specificity(num_classes=self.model.multiclass_dimensions, task="multiclass") + specificity_equation = tm.Specificity( + num_classes=self.model.multiclass_dimensions, task="multiclass" + ).to(preds.device) else: raise ValueError("Invalid prediction task for specificity.") @@ -234,9 +243,11 @@ def precision(self, preds, labels, logits): """ if self.prediction_task == "binary": - precision_equation = tm.Precision(task="binary") + precision_equation = tm.Precision(task="binary").to(preds.device) elif self.prediction_task == "multiclass": - precision_equation = tm.Precision(num_classes=self.model.multiclass_dimensions, task="multiclass") + precision_equation = tm.Precision( + num_classes=self.model.multiclass_dimensions, task="multiclass" + ).to(preds.device) else: raise ValueError("Invalid prediction task for precision.") @@ -263,9 +274,11 @@ def f1(self, preds, labels, logits): """ if self.prediction_task == "binary": - f1_equation = tm.F1Score(task="binary") + f1_equation = tm.F1Score(task="binary").to(preds.device) elif self.prediction_task == "multiclass": - f1_equation = tm.F1Score(num_classes=self.model.multiclass_dimensions, task="multiclass") + f1_equation = tm.F1Score( + num_classes=self.model.multiclass_dimensions, task="multiclass" + ).to(preds.device) else: raise ValueError("Invalid prediction task for F1.") @@ -291,9 +304,11 @@ def auprc(self, preds, labels, logits): """ if self.prediction_task == "binary": - auprc_equation = tm.AveragePrecision(task="binary") + auprc_equation = tm.AveragePrecision(task="binary").to(preds.device) elif self.prediction_task == "multiclass": - auprc_equation = tm.AveragePrecision(num_classes=self.model.multiclass_dimensions, task="multiclass") + auprc_equation = tm.AveragePrecision( + num_classes=self.model.multiclass_dimensions, task="multiclass" + ).to(preds.device) else: raise ValueError("Invalid prediction task for AUPRC.") @@ -319,10 +334,15 @@ def balanced_accuracy(self, preds, labels, logits): """ if self.prediction_task == "binary": - balanced_accuracy_equation = tm.Accuracy(task='multiclass', num_classes=2, average='macro') + balanced_accuracy_equation = tm.Accuracy( + task="multiclass", num_classes=2, average="macro" + ).to(preds.device) elif self.prediction_task == "multiclass": - balanced_accuracy_equation = tm.Accuracy(task='multiclass', num_classes=self.model.multiclass_dimensions, - average='macro') + balanced_accuracy_equation = tm.Accuracy( + task="multiclass", + num_classes=self.model.multiclass_dimensions, + average="macro", + ).to(preds.device) else: raise ValueError("Invalid prediction task for balanced accuracy.") From 84af094113745e69b37b8f2a902f4dc99ee555d1 Mon Sep 17 00:00:00 2001 From: Florence Townend Date: Wed, 15 May 2024 13:52:54 +0200 Subject: [PATCH 2/3] added documentation for GPU on customising training section. Currently not implemented for subspace methods but that is on the list --- docs/customising_training.rst | 33 +++ fusilli/data.py | 237 ++++++++++-------- .../fusionmodels/tabularfusion/mcvae_model.py | 17 +- 3 files changed, 176 insertions(+), 111 deletions(-) diff --git a/docs/customising_training.rst b/docs/customising_training.rst index 5a455c4..f24cfdd 100644 --- a/docs/customising_training.rst +++ b/docs/customising_training.rst @@ -5,6 +5,7 @@ This page will show you how to customise the training and evaluation of your fus We will cover the following topics: +* Using GPU * Early stopping * Valildation metrics * Batch size @@ -13,6 +14,38 @@ We will cover the following topics: * Number of workers in PyTorch DataLoader * Train/test and cross-validation splitting yourself +Using GPU +------------ + +If you want to use a GPU to train your model, you can pass the ``training_modifications`` argument to the :func:`~.fusilli.data.prepare_fusion_data` and :func:`~.fusilli.train.train_and_save_models` functions. By default, the model will train on the CPU. + +For example, to train on a single GPU, you can do the following: + +.. code-block:: python + + from fusilli.data import prepare_fusion_data + from fusilli.train import train_and_save_models + + datamodule = prepare_fusion_data( + prediction_task="binary", + fusion_model=example_model, + data_paths=data_paths, + output_paths=output_path, + ) + + trained_model_list = train_and_save_models( + data_module=datamodule, + fusion_model=example_model, + training_modifications={"accelerator": "gpu", "devices": 1}, + ) + +.. warning:: + + This is currently not implemented for subspace-based models as of May 2024. + When this is implemented, the documentation will be updated. + + + Early stopping -------------- diff --git a/fusilli/data.py b/fusilli/data.py index b6c54ef..5a39f56 100644 --- a/fusilli/data.py +++ b/fusilli/data.py @@ -228,7 +228,9 @@ def __init__(self, sources, img_downsample_dims=None): if "ID" not in tab1_df.columns: raise ValueError("The CSV must have an index column named 'ID'.") if "prediction_label" not in tab1_df.columns: - raise ValueError("The CSV must have a label column named 'prediction_label'.") + raise ValueError( + "The CSV must have a label column named 'prediction_label'." + ) # if tabular2_source exists, check it has the right columns if self.tabular2_source != "": @@ -236,7 +238,9 @@ def __init__(self, sources, img_downsample_dims=None): if "ID" not in tab2_df.columns: raise ValueError("The CSV must have an index column named 'ID'.") if "prediction_label" not in tab2_df.columns: - raise ValueError("The CSV must have a label column named 'prediction_label'.") + raise ValueError( + "The CSV must have a label column named 'prediction_label'." + ) def load_tabular1(self): """ @@ -337,11 +341,17 @@ def load_tabular_tabular(self): tab1_df.set_index("ID", inplace=True) tab2_df.set_index("ID", inplace=True) - tab1_pred_features = torch.Tensor(tab1_df.drop(columns=["prediction_label"]).values) - tab2_pred_features = torch.Tensor(tab2_df.drop(columns=["prediction_label"]).values) + tab1_pred_features = torch.Tensor( + tab1_df.drop(columns=["prediction_label"]).values + ) + tab2_pred_features = torch.Tensor( + tab2_df.drop(columns=["prediction_label"]).values + ) prediction_label = tab1_df[["prediction_label"]] - dataset = CustomDataset([tab1_pred_features, tab2_pred_features], prediction_label) + dataset = CustomDataset( + [tab1_pred_features, tab2_pred_features], prediction_label + ) mod1_dim = tab1_pred_features.shape[1] mod2_dim = tab2_pred_features.shape[1] @@ -430,23 +440,23 @@ class TrainTestDataModule(pl.LightningDataModule): """ def __init__( - self, - fusion_model, - sources, - output_paths, - prediction_task, - batch_size, - test_size, - multiclass_dimensions, - subspace_method=None, - image_downsample_size=None, - layer_mods=None, - max_epochs=1000, - extra_log_string_dict=None, - own_early_stopping_callback=None, - num_workers=0, - test_indices=None, - kwargs=None, + self, + fusion_model, + sources, + output_paths, + prediction_task, + batch_size, + test_size, + multiclass_dimensions, + subspace_method=None, + image_downsample_size=None, + layer_mods=None, + max_epochs=1000, + extra_log_string_dict=None, + own_early_stopping_callback=None, + num_workers=0, + test_indices=None, + kwargs=None, ): """ Parameters @@ -539,8 +549,8 @@ def prepare_data(self): self.dataset, self.data_dims = self.modality_methods[self.modality_type]() def setup( - self, - checkpoint_path=None, + self, + checkpoint_path=None, ): """ Splits the data into train and test sets, and runs the subspace method if specified. @@ -565,30 +575,31 @@ def setup( self.dataset, [1 - self.test_size, self.test_size] ) else: - self.test_dataset = torch.utils.data.Subset( - self.dataset, self.test_indices - ) + self.test_dataset = torch.utils.data.Subset(self.dataset, self.test_indices) self.train_dataset = torch.utils.data.Subset( - self.dataset, list(set(range(len(self.dataset))) - set(self.test_indices)) + self.dataset, + list(set(range(len(self.dataset))) - set(self.test_indices)), ) if self.subspace_method is not None: # if subspace method is specified if ( - checkpoint_path is None + checkpoint_path is None ): # if no checkpoint path specified, train the subspace method self.subspace_method_train = self.subspace_method( datamodule=self, max_epochs=self.max_epochs, k=None, - train_subspace=True + train_subspace=True, ) # modify the subspace method architecture if specified if self.layer_mods is not None: - self.subspace_method_train = model_modifier.modify_model_architecture( - self.subspace_method_train, - self.layer_mods, + self.subspace_method_train = ( + model_modifier.modify_model_architecture( + self.subspace_method_train, + self.layer_mods, + ) ) # train the subspace method and convert train dataset to the latent space @@ -612,17 +623,16 @@ def setup( # we have already trained the subspace method, so load it from the checkpoint self.subspace_method_train = self.subspace_method( - self, - max_epochs=self.max_epochs, - k=None, - train_subspace=False + self, max_epochs=self.max_epochs, k=None, train_subspace=False ) # will return a init subspace method with the subspace models as instance attributes # modify the subspace method architecture if specified if self.layer_mods is not None: - self.subspace_method_train = model_modifier.modify_model_architecture( - self.subspace_method_train, - self.layer_mods, + self.subspace_method_train = ( + model_modifier.modify_model_architecture( + self.subspace_method_train, + self.layer_mods, + ) ) # load checkpoint state dict @@ -656,7 +666,10 @@ def train_dataloader(self): Dataloader for training. """ return DataLoader( - self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, ) def val_dataloader(self): @@ -669,7 +682,10 @@ def val_dataloader(self): Dataloader for validation. """ return DataLoader( - self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, ) @@ -728,23 +744,23 @@ class KFoldDataModule(pl.LightningDataModule): """ def __init__( - self, - fusion_model, - sources, - output_paths, - prediction_task, - batch_size, - num_folds, - multiclass_dimensions, - subspace_method=None, - image_downsample_size=None, - layer_mods=None, - max_epochs=1000, - extra_log_string_dict=None, - own_early_stopping_callback=None, - num_workers=0, - own_kfold_indices=None, - kwargs=None, + self, + fusion_model, + sources, + output_paths, + prediction_task, + batch_size, + num_folds, + multiclass_dimensions, + subspace_method=None, + image_downsample_size=None, + layer_mods=None, + max_epochs=1000, + extra_log_string_dict=None, + own_early_stopping_callback=None, + num_workers=0, + own_kfold_indices=None, + kwargs=None, ): """ Parameters @@ -877,8 +893,8 @@ def kfold_split(self): return folds # list of tuples of (train_dataset, test_dataset) def setup( - self, - checkpoint_path=None, + self, + checkpoint_path=None, ): """ Splits the data into train and test sets, and runs the subspace method if specified @@ -1014,7 +1030,10 @@ def train_dataloader(self, fold_idx): self.train_dataset, self.test_dataset = self.folds[fold_idx] return DataLoader( - self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, ) def val_dataloader(self, fold_idx): @@ -1034,7 +1053,10 @@ def val_dataloader(self, fold_idx): self.train_dataset, self.test_dataset = self.folds[fold_idx] return DataLoader( - self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, ) @@ -1078,15 +1100,15 @@ class TrainTestGraphDataModule: """ def __init__( - self, - fusion_model, - sources, - graph_creation_method, - test_size, - image_downsample_size=None, - layer_mods=None, - extra_log_string_dict=None, - own_test_indices=None, + self, + fusion_model, + sources, + graph_creation_method, + test_size, + image_downsample_size=None, + layer_mods=None, + extra_log_string_dict=None, + own_test_indices=None, ): """ Parameters @@ -1174,9 +1196,7 @@ def setup(self): self.test_idxs = test_dataset.indices else: self.test_idxs = self.own_test_indices - self.train_idxs = list( - set(range(len(self.dataset))) - set(self.test_idxs) - ) + self.train_idxs = list(set(range(len(self.dataset))) - set(self.test_idxs)) # get the graph data structure self.graph_maker_instance = self.graph_creation_method(self.dataset) @@ -1247,15 +1267,15 @@ class KFoldGraphDataModule: """ def __init__( - self, - num_folds, - fusion_model, - sources, - graph_creation_method, - image_downsample_size=None, - layer_mods=None, - extra_log_string_dict=None, - own_kfold_indices=None, + self, + num_folds, + fusion_model, + sources, + graph_creation_method, + image_downsample_size=None, + layer_mods=None, + extra_log_string_dict=None, + own_kfold_indices=None, ): """ Parameters @@ -1369,7 +1389,7 @@ def setup(self): # modify the graph maker architecture if specified if self.layer_mods is not None: - graph_maker = model_modifier.modify_model_architecture( + self.graph_maker_instance = model_modifier.modify_model_architecture( self.graph_maker_instance, self.layer_mods, ) @@ -1414,25 +1434,25 @@ def get_lightning_module(self): def prepare_fusion_data( - prediction_task, - fusion_model, - data_paths, - output_paths, - kfold=False, - num_folds=None, - test_size=0.2, - batch_size=8, - multiclass_dimensions=None, - image_downsample_size=None, - layer_mods=None, - max_epochs=1000, - checkpoint_path=None, - extra_log_string_dict=None, - own_early_stopping_callback=None, - num_workers=0, - test_indices=None, - own_kfold_indices=None, - **kwargs, + prediction_task, + fusion_model, + data_paths, + output_paths, + kfold=False, + num_folds=None, + test_size=0.2, + batch_size=8, + multiclass_dimensions=None, + image_downsample_size=None, + layer_mods=None, + max_epochs=1000, + checkpoint_path=None, + extra_log_string_dict=None, + own_early_stopping_callback=None, + num_workers=0, + test_indices=None, + own_kfold_indices=None, + **kwargs, ): """ Gets the data module for a specific fusion model and training protocol. @@ -1497,7 +1517,8 @@ def prepare_fusion_data( if kfold and own_early_stopping_callback is not None: raise ValueError( - "Cannot use own early stopping callback with kfold cross validation yet. Working on fixing this currently (Nov 2023)") + "Cannot use own early stopping callback with kfold cross validation yet. Working on fixing this currently (Nov 2023)" + ) # Getting the data paths from the data_paths dictionary into a list data_sources = [ @@ -1519,7 +1540,7 @@ def prepare_fusion_data( image_downsample_size=image_downsample_size, layer_mods=layer_mods, extra_log_string_dict=extra_log_string_dict, - # here is where the kfold split will go + own_kfold_indices=own_kfold_indices, ) else: graph_data_module = TrainTestGraphDataModule( @@ -1543,7 +1564,9 @@ def prepare_fusion_data( for dm_instance in data_module: dm_instance.data_dims = graph_data_module.data_dims dm_instance.own_early_stopping_callback = own_early_stopping_callback - dm_instance.graph_maker_instance = graph_data_module.graph_maker_instance + dm_instance.graph_maker_instance = ( + graph_data_module.graph_maker_instance + ) dm_instance.output_paths = output_paths dm_instance.num_folds = num_folds dm_instance.prediction_task = prediction_task diff --git a/fusilli/fusionmodels/tabularfusion/mcvae_model.py b/fusilli/fusionmodels/tabularfusion/mcvae_model.py index d1563ab..f11981e 100644 --- a/fusilli/fusionmodels/tabularfusion/mcvae_model.py +++ b/fusilli/fusionmodels/tabularfusion/mcvae_model.py @@ -11,6 +11,7 @@ import pandas as pd import numpy as np from fusilli.utils.training_utils import get_checkpoint_filenames_for_subspace_models +import sys from fusilli.utils import check_model_validity @@ -136,7 +137,9 @@ def load_ckpt(self, checkpoint_path): init_dict = { "n_channels": 2, "lat_dim": self.num_latent_dims, - "n_feats": tuple([self.datamodule.data_dims[0], self.datamodule.data_dims[1]]), + "n_feats": tuple( + [self.datamodule.data_dims[0], self.datamodule.data_dims[1]] + ), } self.fit_model = Mcvae(**init_dict, sparse=True) @@ -261,7 +264,9 @@ def train(self, train_dataset, val_dataset=None): with contextlib.redirect_stdout(None): mcvae_fit.optimize(epochs=self.max_epochs, data=mcvae_training_data) ideal_epoch = mcvae_early_stopping_tol( - tolerance=mcvae_tolerance, patience=mcvae_patience, loss_logs=mcvae_fit.loss["total"] + tolerance=mcvae_tolerance, + patience=mcvae_patience, + loss_logs=mcvae_fit.loss["total"], ) mcvae_esfit = Mcvae(**init_dict, sparse=True) @@ -284,7 +289,9 @@ def train(self, train_dataset, val_dataset=None): # getting mean latent space mean_latents = self.get_latents(mcvae_training_data) - return torch.Tensor(mean_latents), pd.DataFrame(labels, columns=["prediction_label"]) + return torch.Tensor(mean_latents), pd.DataFrame( + labels, columns=["prediction_label"] + ) def convert_to_latent(self, test_dataset): """ @@ -373,7 +380,9 @@ def __init__(self, prediction_task, data_dims, multiclass_dimensions): multiclass_dimensions : int Number of classes in the multiclass classification task. """ - ParentFusionModel.__init__(self, prediction_task, data_dims, multiclass_dimensions) + ParentFusionModel.__init__( + self, prediction_task, data_dims, multiclass_dimensions + ) self.prediction_task = prediction_task From c208ab031927389dc559a3ce7bc51c6b3a23fa1c Mon Sep 17 00:00:00 2001 From: Florence Townend Date: Wed, 15 May 2024 13:54:14 +0200 Subject: [PATCH 3/3] formatting --- tests/test_utils/test_training_utils.py | 61 ++++++++++++++----------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/tests/test_utils/test_training_utils.py b/tests/test_utils/test_training_utils.py index 003485b..ab7415e 100644 --- a/tests/test_utils/test_training_utils.py +++ b/tests/test_utils/test_training_utils.py @@ -13,7 +13,7 @@ import os import tempfile import lightning.pytorch as pl -from lightning.pytorch.callbacks import (EarlyStopping, ModelCheckpoint) +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch import Trainer from torchmetrics import Accuracy, R2Score import numpy as np @@ -50,13 +50,11 @@ class SomeFusionModelClass: fold = 1 extra_log_string_dict = {"param1": "value1", "param2": 42} - checkpoint_name = set_checkpoint_name( - fusion_model, fold, extra_log_string_dict - ) + checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict) assert ( - checkpoint_name - == "SomeFusionModelClass_fold_1_param1_value1_param2_42_{epoch:02d}" + checkpoint_name + == "SomeFusionModelClass_fold_1_param1_value1_param2_42_{epoch:02d}" ) @@ -68,9 +66,7 @@ class SomeFusionModelClass: fold = None extra_log_string_dict = {"param1": "value1", "param2": 42} - checkpoint_name = set_checkpoint_name( - fusion_model, fold, extra_log_string_dict - ) + checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict) assert checkpoint_name == "SomeFusionModelClass_param1_value1_param2_42_{epoch:02d}" @@ -83,9 +79,7 @@ class SomeFusionModelClass: fold = 2 extra_log_string_dict = None - checkpoint_name = set_checkpoint_name( - fusion_model, fold, extra_log_string_dict - ) + checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict) assert checkpoint_name == "SomeFusionModelClass_fold_2_{epoch:02d}" @@ -149,7 +143,8 @@ class SomeFusionModelClass: k = None checkpoint_filenames = get_checkpoint_filenames_for_subspace_models( - subspace_method, k) + subspace_method, k + ) expected_filenames = [ "subspace_SomeFusionModelClass_SubspaceModel1_key_value", @@ -228,10 +223,10 @@ def test_get_checkpoint_filename_for_trained_fusion_model_not_found(params, mode # Attempt to get a checkpoint filename when no matching file exists with pytest.raises( - ValueError, match=r"Could not find checkpoint file with name .*" + ValueError, match=r"Could not find checkpoint file with name .*" ): get_checkpoint_filename_for_trained_fusion_model( - params['checkpoint_dir'], model, checkpoint_file_suffix + params["checkpoint_dir"], model, checkpoint_file_suffix ) @@ -253,7 +248,7 @@ def test_get_checkpoint_filename_for_trained_fusion_model_multiple_files(params, # Attempt to get a checkpoint filename when multiple matching files exist with pytest.raises( - ValueError, match=r"Found multiple checkpoint files with name .*" + ValueError, match=r"Found multiple checkpoint files with name .*" ): get_checkpoint_filename_for_trained_fusion_model( params["checkpoint_dir"], model, checkpoint_file_suffix @@ -279,7 +274,10 @@ def mock_logger(): def test_init_trainer_default(mock_logger): # Test initializing trainer with default parameters - trainer = init_trainer(mock_logger, output_paths={}, ) + trainer = init_trainer( + mock_logger, + output_paths={}, + ) assert trainer is not None assert isinstance(trainer, Trainer) assert trainer.max_epochs == 1000 @@ -289,14 +287,18 @@ def test_init_trainer_default(mock_logger): assert trainer.checkpoint_callback is not None -@pytest.mark.filterwarnings("ignore:.*GPU available but not used*.", ) +@pytest.mark.filterwarnings( + "ignore:.*GPU available but not used*.", +) def test_init_trainer_custom_early_stopping(mock_logger): # Test initializing trainer with a custom early stopping callback # custom_early_stopping = Mock() - custom_early_stopping = EarlyStopping(monitor="val_loss", - patience=3, - verbose=True, - mode="max", ) + custom_early_stopping = EarlyStopping( + monitor="val_loss", + patience=3, + verbose=True, + mode="max", + ) trainer = init_trainer( mock_logger, output_paths={}, own_early_stopping_callback=custom_early_stopping ) @@ -311,7 +313,10 @@ def test_init_trainer_custom_early_stopping(mock_logger): assert isinstance(trainer.callbacks[0], EarlyStopping) assert trainer.callbacks[0] == custom_early_stopping for key in custom_early_stopping.__dict__: - assert custom_early_stopping.__dict__[key] == trainer.early_stopping_callback.__dict__[key] + assert ( + custom_early_stopping.__dict__[key] + == trainer.early_stopping_callback.__dict__[key] + ) assert trainer.checkpoint_callback is not None @@ -339,7 +344,11 @@ def test_init_trainer_with_accelerator_and_devices(mock_logger): # Test initializing trainer with custom accelerator and devices params = {"accelerator": "cpu", "devices": 3} - trainer = init_trainer(mock_logger, output_paths={}, training_modifications={"accelerator": "cpu", "devices": 3}) + trainer = init_trainer( + mock_logger, + output_paths={}, + training_modifications={"accelerator": "cpu", "devices": 3}, + ) assert trainer is not None assert isinstance(trainer, Trainer) @@ -441,7 +450,7 @@ def __init__(self, model): # Get the final validation metrics with pytest.raises( - ValueError, - match=r"not in trainer.callback_metrics.keys()", + ValueError, + match=r"not in trainer.callback_metrics.keys()", ): get_final_val_metrics(trainer)