Skip to content

Commit b210114

Browse files
committed
Fix save_dir and strategy args
1 parent f70cce8 commit b210114

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

train.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from prefigure.prefigure import get_all_args, push_wandb_config
22
import json
3+
import os
34
import torch
45
import pytorch_lightning as pl
56

@@ -44,15 +45,21 @@ def main():
4445

4546
training_wrapper = create_training_wrapper_from_config(model_config, model)
4647

48+
wandb_logger = pl.loggers.WandbLogger(project=args.name)
49+
wandb_logger.watch(training_wrapper)
50+
4751
exc_callback = ExceptionCallback()
48-
ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
52+
53+
if args.save_dir and isinstance(wandb_logger.experiment.id, str):
54+
checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints")
55+
else:
56+
checkpoint_dir = None
57+
58+
ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1)
4959
save_model_config_callback = ModelConfigEmbedderCallback(model_config)
5060

5161
demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl)
5262

53-
wandb_logger = pl.loggers.WandbLogger(project=args.name)
54-
wandb_logger.watch(training_wrapper)
55-
5663
#Combine args and config dicts
5764
args_dict = vars(args)
5865
args_dict.update({"model_config": model_config})
@@ -74,7 +81,7 @@ def main():
7481
else:
7582
strategy = args.strategy
7683
else:
77-
strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else None
84+
strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"
7885

7986
trainer = pl.Trainer(
8087
devices=args.num_gpus,

0 commit comments

Comments (0)