diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index d64d55cb..bed753ff 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -928,7 +928,7 @@ def train_and_predict( output_earlystopping=early_stopping_dataset, cell_line_input=cl_features, drug_input=drug_features, - model_checkpoint_dir=model_checkpoint_dir, + model_checkpoint_dir=temp_dir, ) else: if not os.path.exists(model_checkpoint_dir): diff --git a/drevalpy/models/MOLIR/utils.py b/drevalpy/models/MOLIR/utils.py index 79e46fd8..072c271c 100644 --- a/drevalpy/models/MOLIR/utils.py +++ b/drevalpy/models/MOLIR/utils.py @@ -170,7 +170,7 @@ def create_dataset_and_loaders( train_dataset, batch_size=batch_size, shuffle=False, - num_workers=1, + num_workers=1 if os.name == "nt" else 4, # multiprocessing on Windows is not supported persistent_workers=True, drop_last=True, # avoids batch norm errors if last batch < batch_size ) @@ -354,11 +354,11 @@ def fit( [secrets.choice("0123456789abcdef") for _ in range(20)] ) # preventing conflicts of filenames self.checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=model_checkpoint_dir, + dirpath=os.path.join(model_checkpoint_dir, name), monitor=monitor, mode="min", save_top_k=1, - filename=name, + save_weights_only=True, ) # Initialize the Lightning trainer @@ -367,9 +367,10 @@ def fit( callbacks=[ early_stop_callback, self.checkpoint_callback, - TQDMProgressBar(), + TQDMProgressBar(refresh_rate=0), ], - default_root_dir=os.path.join(model_checkpoint_dir, "moli_checkpoints/lightning_logs/" + name), + devices=1, + enable_model_summary=False, ) if val_loader is None: trainer.fit(self, train_loader) @@ -377,7 +378,7 @@ def fit( trainer.fit(self, train_loader, val_loader) # load best model if self.checkpoint_callback.best_model_path is not None: - checkpoint = torch.load(self.checkpoint_callback.best_model_path) # noqa: S614 + checkpoint = torch.load(self.checkpoint_callback.best_model_path, weights_only=True) # noqa: S614 self.load_state_dict(checkpoint["state_dict"]) def predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 4df6f8e1..ba0d84fe 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -1,5 +1,6 @@ """Contains the SimpleNeuralNetwork model.""" +import platform import warnings import numpy as np @@ -102,7 +103,7 @@ def train( output_earlystopping=output_earlystopping, batch_size=16, patience=5, - num_workers=8, + num_workers=1 if platform.system() == "Windows" else 8, model_checkpoint_dir=model_checkpoint_dir, )