
Conversation

@adrianosantospb

Changing the previous CustomDataParallel to DistributedDataParallel to improve the training speed.

```diff
 def save_checkpoint(model, name):
-    if isinstance(model, CustomDataParallel):
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
```
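One practical wrinkle on the save/load side: a `state_dict` taken from the wrapper itself (rather than from `.module`) carries a `module.` prefix on every key, which then fails to load into a plain model. A minimal sketch of the usual key-stripping workaround, with an illustrative dict standing in for a real checkpoint (the helper name is hypothetical, not from this PR):

```python
# Hypothetical helper: drop the "module." prefix that DP/DDP wrappers
# add to every state_dict key, so the checkpoint loads into an
# unwrapped model. Plain ints stand in for real tensors here.
def strip_module_prefix(state_dict):
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

wrapped = {"module.backbone.weight": 1, "module.head.bias": 2}
assert strip_module_prefix(wrapped) == {"backbone.weight": 1, "head.bias": 2}
```

Saving `model.module.model.state_dict()` as the patched function does avoids the prefix in the first place; the stripper is only needed for checkpoints saved the other way.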
@zylo117 (Owner) Aug 25, 2020

I'm not sure whether DDP's `.module` is the real model, the way it is with DP.

@zylo117 (Owner) Aug 25, 2020

AFAIK, it's `model.model.state_dict()` for DDP, the same as for the ordinary model. Or are both OK?

@zylo117 (Owner) commented Aug 25, 2020

Hi, thanks for your contribution.
Have you tested it by training and then successfully loading the last saved weights? I think the weight saving & loading could be a problem.
Another concern is that the console output will be printed N times (N = the number of GPUs) and the TensorBoard events will be recorded N times as well.
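The duplicated printing/logging is usually handled by guarding those calls so only the rank-0 process runs them. A minimal sketch, assuming the `RANK` environment variable set by the standard PyTorch launchers (the helper names are illustrative, not from this PR):

```python
import os

# Sketch: under torchrun / torch.distributed.launch, each process gets
# a RANK env var; only rank 0 should print and write TensorBoard logs.
# Defaulting to "0" keeps single-GPU (non-distributed) runs working.
def is_main_process():
    return int(os.environ.get("RANK", "0")) == 0

def log(msg):
    if is_main_process():
        print(msg)  # likewise, guard writer.add_scalar(...) calls
```

With this guard, an N-GPU run would print each message and record each TensorBoard scalar once instead of N times.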

@adrianosantospb (Author) commented Aug 25, 2020

It's a pleasure. You did a great job. Yes, I tested it this afternoon, but I'm training a big model now. Tomorrow I will run more tests to check; this model is running on my machine at work.



3 participants