[BUG] AttributeError: 'NoneType' object has no attribute 'swap_folder' #4998
Comments
@tjruwase @mrwyattii @loadams
I met the same error. It is caused by `_configure_offloading`:

```python
    def _configure_offloading(self, offload_optimizer_config, offload_param_config):
        ###################### offload optimizer setup ##################################
        if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
            self.offload_optimizer = True
            self.offload_optimizer_pin_memory = offload_optimizer_config.pin_memory
            # swap_optimizer is True only for NVMe offload, not for CPU offload
            self.swap_optimizer = offload_optimizer_config.device == OffloadDeviceEnum.nvme
            self.offload_optimizer_fast_init = offload_optimizer_config.fast_init

        ###################### offload param setup ##################################
        if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
            self.offload_param = True
            self.offload_param_pin_memory = offload_param_config.pin_memory
            self.params_in_nvme_and_cpu = offload_param_config.device == OffloadDeviceEnum.nvme
            self.max_params_in_cpu = offload_param_config.max_in_cpu
            print_rank_0(
                f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
                force=False)
```
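To make the control flow of the `_configure_offloading` excerpt concrete, here is a minimal, self-contained sketch that mirrors only its optimizer-offload branch. The `OffloadDevice` enum, `OffloadConfig` dataclass and `OffloadFlags` class are hypothetical stand-ins, not DeepSpeed's actual classes, and the `False` defaults are an assumption. The point it illustrates is that `swap_optimizer` becomes `True` only for NVMe offload, so with CPU offload (or no offload) anything that depends on it, such as an optimizer swapper, would never be created.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class OffloadDevice(Enum):
    # Hypothetical stand-in for DeepSpeed's OffloadDeviceEnum.
    none = "none"
    cpu = "cpu"
    nvme = "nvme"


@dataclass
class OffloadConfig:
    # Hypothetical stand-in for the offload_optimizer config section.
    device: OffloadDevice
    pin_memory: bool = False
    fast_init: bool = False


class OffloadFlags:
    """Mirrors only the optimizer-offload branch of _configure_offloading (sketch)."""

    def __init__(self, offload_optimizer_config: Optional[OffloadConfig]):
        # Assumed defaults when optimizer offloading is disabled.
        self.offload_optimizer = False
        self.swap_optimizer = False
        cfg = offload_optimizer_config
        if cfg is not None and cfg.device != OffloadDevice.none:
            self.offload_optimizer = True
            self.offload_optimizer_pin_memory = cfg.pin_memory
            # Swapping (and hence a swapper object) only applies to NVMe.
            self.swap_optimizer = cfg.device == OffloadDevice.nvme
            self.offload_optimizer_fast_init = cfg.fast_init


# Print which flags each device setting produces.
for device in OffloadDevice:
    flags = OffloadFlags(OffloadConfig(device=device))
    print(device.value, flags.offload_optimizer, flags.swap_optimizer)
```

With `device=cpu` the sketch reports offloading enabled but swapping disabled, which is the combination under which a swapper-dependent code path would find nothing to use.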
Then the error happened, so I tried to change the
The original code is:
This may fix the error. The other way is to set
Describe the bug
I tried to train Mixtral-8x7B with ZeRO-3 and got an error when saving a checkpoint, as follows:
To Reproduce
Expected behavior
Successfully save the checkpoint.
System info:
I have updated to the latest version, deepspeed-0.13.0, which supports saving checkpoints under ZeRO-3, but unfortunately saving still fails for the Mixtral-8x7B model. It seems that the error
`AttributeError: 'NoneType' object has no attribute 'swap_folder'`
was caused by `self.optimizer.optimizer_swapper` being `None`?
Could you please resolve this issue? Any help or comments are most welcome!
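The error message itself is the generic Python one for reading an attribute from an object that is still `None`. A minimal sketch, assuming a hypothetical `DummyOptimizer` whose `optimizer_swapper` stays `None` when NVMe swapping is off, reproduces the message and shows one possible defensive check before touching `swap_folder`. The classes, the `swap_folder_or_none` helper and the `/tmp/swap` path are all illustrative; this is not DeepSpeed's code or its actual fix.

```python
from typing import Optional


class DummySwapper:
    # Hypothetical swapper; only the attribute named in the traceback.
    def __init__(self, swap_folder: str):
        self.swap_folder = swap_folder


class DummyOptimizer:
    # Hypothetical optimizer: a swapper exists only when NVMe swapping is on.
    def __init__(self, swap_optimizer: bool):
        self.optimizer_swapper = DummySwapper("/tmp/swap") if swap_optimizer else None


def swap_folder_or_none(optimizer) -> Optional[str]:
    """Return the swap folder if a swapper exists, else None instead of raising."""
    swapper = getattr(optimizer, "optimizer_swapper", None)
    return swapper.swap_folder if swapper is not None else None


opt = DummyOptimizer(swap_optimizer=False)
try:
    opt.optimizer_swapper.swap_folder  # unguarded access, as in the reported error
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'swap_folder'

print(swap_folder_or_none(opt))  # None
print(swap_folder_or_none(DummyOptimizer(swap_optimizer=True)))  # /tmp/swap
```

Whether the right fix is such a guard at save time or a stricter check of the offload device before taking the NVMe path is a design decision for the maintainers; the sketch only shows why the unguarded access fails.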