
Encountered a very confusing issue while resuming from checkpoint. #11

DragonDRLI opened this issue Aug 12, 2023 · 4 comments

@DragonDRLI

When I resumed training from a checkpoint, load_state_dict() failed with an error indicating that the loaded checkpoint is incompatible with the current UNet model structure. However, I am certain that the saved checkpoint is sd_small, and the UNet model used for resuming training is also sd_small, so this is a very confusing issue.

The log is as follows:
Traceback (most recent call last):
  File "distill_training.py", line 1140, in <module>
    main()
  File "distill_training.py", line 899, in main
    accelerator.load_state(os.path.join(args.output_dir, path))
  File "/opt/anaconda3/lib/python3.7/site-packages/accelerate/accelerator.py", line 2347, in load_state
    hook(models, input_dir)
  File "distill_training.py", line 656, in load_model_hook
    model.load_state_dict(load_model.state_dict())
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1498, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
    Unexpected key(s) in state_dict: "down_blocks.0.attentions.1.norm.weight", "down_blocks.0.attentions.1.norm.bias",

Could you check this?

Thanks in advance.

@Gothos (Contributor) commented Aug 12, 2023

The problem here is that the Hugging Face save method saves the originally loaded U-Net config, which no longer represents the actual state of the U-Net once the model has been cut down to size. The current implementation has this fault, but it is easy to work around: download the sd_small config and replace the config in the checkpoint directory with it. It should work then. You only need to do this once, i.e. after a fresh distillation.

Please let me know if your issue gets resolved. We are looking at a permanent fix for this.
I've added a notice suggesting this temporary fix to people in the meantime.
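
For anyone hitting this, here is a minimal sketch of that one-time workaround. The Hub repo id ("segmind/small-sd") and the checkpoint path are assumptions, not anything the script produces; substitute whichever sd_small config and output directory you are actually using.

# Minimal sketch of the workaround above (not the repo's code).
# Assumptions: the sd_small U-Net config is available on the Hub under
# "segmind/small-sd", and the checkpoint was written by accelerator.save_state
# into output_dir/checkpoint-XXXX.
import shutil
from huggingface_hub import hf_hub_download

checkpoint_dir = "output_dir/checkpoint-5000"  # placeholder: your actual checkpoint path

# Fetch the sd_small UNet config from the Hub.
small_config = hf_hub_download("segmind/small-sd", filename="config.json", subfolder="unet")

# Overwrite the stale full-size config that save_state wrote next to the weights,
# so from_pretrained rebuilds a UNet whose shape matches the saved state_dict.
shutil.copy(small_config, f"{checkpoint_dir}/unet/config.json")

This just automates the manual replacement described above; it does not change anything else inside the checkpoint.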

@DragonDRLI (Author)

Hi @Gothos, I downloaded the sd_small model configuration file from Hugging Face and replaced the configuration file in the checkpoint directory. Now I am encountering another error, as shown below:
Traceback (most recent call last):
  File "distill_training.py", line 1141, in <module>
    main()
  File "distill_training.py", line 900, in main
    accelerator.load_state(os.path.join(args.output_dir, path))
  File "/opt/anaconda3/lib/python3.7/site-packages/accelerate/accelerator.py", line 2347, in load_state
    hook(models, input_dir)
  File "distill_training.py", line 654, in load_model_hook
    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
  File "/opt/anaconda3/lib/python3.7/site-packages/diffusers/modeling_utils.py", line 473, in from_pretrained
    set_module_tensor_to_device(model, param_name, param_device, value=param)
  File "/opt/anaconda3/lib/python3.7/site-packages/accelerate/utils/modeling.py", line 124, in set_module_tensor_to_device
    new_module = getattr(module, split)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1186, in __getattr__
    type(self).__name__, name))
AttributeError: 'ModuleList' object has no attribute '1'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 104) of binary: /opt/anaconda3/bin/python

@DragonDRLI (Author)

Hi @Gothos, do you know if there are any other viable solutions?

Gothos self-assigned this Aug 17, 2023
@Gothos (Contributor) commented Aug 17, 2023

Update:
I've added extra config params to the distill_training script. You can turn off the U-Net shortening step by not passing prepare_unet. I am working on solving the config replacement issue. Sorry for the inconvenience caused.
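
In the meantime, one quick way to check whether a given checkpoint still carries the stale full-size config is to diff its unet/config.json against the sd_small config. The sketch below is illustrative only; both file paths are placeholders, not anything the script writes automatically.

# Illustrative diagnostic: print every config field that differs between the
# checkpoint's saved U-Net config and the reference sd_small config.
import json

with open("output_dir/checkpoint-5000/unet/config.json") as f:  # placeholder checkpoint path
    ckpt_cfg = json.load(f)
with open("sd_small_config.json") as f:                         # reference sd_small U-Net config
    ref_cfg = json.load(f)

for key in sorted(set(ckpt_cfg) | set(ref_cfg)):
    if ckpt_cfg.get(key) != ref_cfg.get(key):
        print(f"{key}: checkpoint={ckpt_cfg.get(key)!r}  sd_small={ref_cfg.get(key)!r}")

If any architectural fields differ, from_pretrained will rebuild the wrong U-Net on resume and loading the saved state_dict will fail as in the tracebacks above.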
