@@ -1,5 +1,6 @@
 from prefigure.prefigure import get_all_args, push_wandb_config
 import json
+import os
 import torch
 import pytorch_lightning as pl
 
@@ -44,15 +45,21 @@ def main():
 
     training_wrapper = create_training_wrapper_from_config(model_config, model)
 
+    wandb_logger = pl.loggers.WandbLogger(project=args.name)
+    wandb_logger.watch(training_wrapper)
+
     exc_callback = ExceptionCallback()
-    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
+
+    if args.save_dir and isinstance(wandb_logger.experiment.id, str):
+        checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints")
+    else:
+        checkpoint_dir = None
+
+    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1)
     save_model_config_callback = ModelConfigEmbedderCallback(model_config)
 
     demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl)
 
-    wandb_logger = pl.loggers.WandbLogger(project=args.name)
-    wandb_logger.watch(training_wrapper)
-
     #Combine args and config dicts
     args_dict = vars(args)
     args_dict.update({"model_config": model_config})
@@ -74,7 +81,7 @@ def main():
         else:
             strategy = args.strategy
     else:
-        strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else None
+        strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"
 
     trainer = pl.Trainer(
         devices=args.num_gpus,
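
For reference, a minimal sketch of the checkpoint-directory logic introduced above, using placeholder values for `args.save_dir` and for the project and run id that `wandb_logger.experiment` would expose once the W&B run is initialized (all concrete names here are hypothetical):

import os

# Placeholders standing in for args.save_dir and the attributes of
# wandb_logger.experiment after the W&B run has started.
save_dir = "./runs"        # args.save_dir
project = "my_project"     # wandb_logger.experiment.project
run_id = "abc123"          # wandb_logger.experiment.id

# Mirrors the diff: each run gets its own checkpoint directory,
# namespaced by W&B project and run id.
checkpoint_dir = os.path.join(save_dir, project, run_id, "checkpoints")
print(checkpoint_dir)  # ./runs/my_project/abc123/checkpoints

When no save directory is configured, `dirpath` stays `None` and `ModelCheckpoint` falls back to its default location under the trainer's log directory, so the previous behavior is preserved.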