Skip to content

Commit

Permalink
fixed tempdir and MOLIR
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalIversen committed Dec 17, 2024
1 parent 603ff39 commit 17aece9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion drevalpy/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def train_and_predict(
output_earlystopping=early_stopping_dataset,
cell_line_input=cl_features,
drug_input=drug_features,
model_checkpoint_dir=model_checkpoint_dir,
model_checkpoint_dir=temp_dir,
)
else:
if not os.path.exists(model_checkpoint_dir):
Expand Down
13 changes: 7 additions & 6 deletions drevalpy/models/MOLIR/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def create_dataset_and_loaders(
train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=1,
num_workers=1 if os.name == "nt" else 4, # multiprocessing on Windows is not supported
persistent_workers=True,
drop_last=True, # avoids batch norm errors if last batch < batch_size
)
Expand Down Expand Up @@ -354,11 +354,11 @@ def fit(
[secrets.choice("0123456789abcdef") for _ in range(20)]
) # preventing conflicts of filenames
self.checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=model_checkpoint_dir,
dirpath=os.path.join(model_checkpoint_dir, name),
monitor=monitor,
mode="min",
save_top_k=1,
filename=name,
save_weights_only=True,
)

# Initialize the Lightning trainer
Expand All @@ -367,17 +367,18 @@ def fit(
callbacks=[
early_stop_callback,
self.checkpoint_callback,
TQDMProgressBar(),
TQDMProgressBar(refresh_rate=0),
],
default_root_dir=os.path.join(model_checkpoint_dir, "moli_checkpoints/lightning_logs/" + name),
devices=1,
enable_model_summary=False,
)
if val_loader is None:
trainer.fit(self, train_loader)
else:
trainer.fit(self, train_loader, val_loader)
# load best model
if self.checkpoint_callback.best_model_path is not None:
checkpoint = torch.load(self.checkpoint_callback.best_model_path) # noqa: S614
checkpoint = torch.load(self.checkpoint_callback.best_model_path, weights_only=True) # noqa: S614
self.load_state_dict(checkpoint["state_dict"])

def predict(
Expand Down
3 changes: 2 additions & 1 deletion drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Contains the SimpleNeuralNetwork model."""

import platform
import warnings

import numpy as np
Expand Down Expand Up @@ -102,7 +103,7 @@ def train(
output_earlystopping=output_earlystopping,
batch_size=16,
patience=5,
num_workers=8,
num_workers=1 if platform.system() == "Windows" else 8,
model_checkpoint_dir=model_checkpoint_dir,
)

Expand Down

0 comments on commit 17aece9

Please sign in to comment.