Fix that ds_secondary_tensor may be dirty when loading the model or zero checkpoint for zero++. #7707
Conversation
This picture proves that the bug has been fixed. The experimental conditions for …
…ero checkpoint for zero++. Signed-off-by: zhengchenyu <[email protected]>
(force-pushed from 7f90ee5 to 9214092)
@zhengchenyu thanks for the PR. We are taking a look.
The unit test …
@zhengchenyu thanks for this PR. My opinion is that invalidating the secondary tensor is the correct solution in both of these cases, so I am aligned with your solution. What do you think? For context, …
@sfc-gh-truwase However, I found the root cause was that … In fact, we have two solutions to this problem:
1. Your solution maintains consistent logic: if the weight changes, invalidate the secondary tensor. If that's the case, I think here we might also need to invalidate the secondary tensor. However, the drawback of this approach is that it wastes a useful secondary tensor.
2. This approach avoids wasting the `ds_secondary_tensor`.

In fact, I think both are OK, but I prefer (2). However, if you think we need to maintain consistent logic, I would change it to (1).
And do you mean the case …?
Yes, this would be incorrect usage, but it is not the API's responsibility to detect such cases. So let's not worry about it.
Thanks for making the change.
@zhengchenyu I apologize; I realize I gave you misleading information because I didn't read the existing `GatheredParameters.__exit__()` carefully. In summary, your current PR is fine as is. I will approve to unblock merging. I will explain a bit more below, just for the record.
Apologies for the confusion and extra work.
Thanks very much for your review!

`ds_secondary_tensor` may be dirty during model loading or zero checkpointing for zero++.

My task is transformers SFT. In the transformers code, initialization is done using code like the following:
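The original code block did not survive extraction. As a stand-in, here is a minimal sketch of the pattern being described, assuming a ZeRO stage-3 setup under a `deepspeed` distributed launch; the linear layer, config, and `state_dict` are placeholders, not the actual transformers initialization code:

```python
import torch
import deepspeed

# Placeholder config; a real zero++ (hpZ) run would also set
# "zero_hpz_partition_size" so that ds_secondary_tensor is created.
ds_config = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # Parameters are partitioned as they are created: each param gains a
    # ds_tensor, and with hpZ enabled a ds_secondary_tensor as well.
    model = torch.nn.Linear(1024, 1024)

# load_model then reloads the (already partitioned) weights, e.g.:
state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
for name, param in model.named_parameters():
    with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            param.data.copy_(state_dict[name])
# On __exit__, GatheredParameters re-partitions the updated weights,
# which is exactly where the staleness described below arises.
```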
After this, `param` is already a DS tensor, meaning both `ds_tensor` and `ds_secondary_tensor` exist. Then `load_model` is called to reload the model.

In `GatheredParameters.__exit__`, `params[0].partition` is called, and `has_been_updated` is set to `True`, indicating that data updates are needed. However, `has_been_updated` was not passed on to `_partition_param_sec`. This results in `ds_secondary_tensor` being dirty.
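For illustration only, here is a paraphrased toy version of the control flow at issue; the real logic lives in DeepSpeed's `partition_parameters.py`, and plain dicts stand in for DS parameters here:

```python
# Toy model of partition(): the primary shard honors has_been_updated,
# but the secondary shard was updated without the flag.
class ParamPartitioner:
    def _partition_param(self, param, has_been_updated=False):
        if has_been_updated:
            param["ds_tensor"] = param["data"]  # refresh primary shard

    def _partition_param_sec(self, param, has_been_updated=False):
        if has_been_updated:
            param["ds_secondary_tensor"] = param["data"]  # refresh secondary shard

    def partition(self, param, has_been_updated=False):
        self._partition_param(param, has_been_updated=has_been_updated)
        # Bug: the flag was dropped here, leaving ds_secondary_tensor stale.
        self._partition_param_sec(param)
        # Fix idea: forward the flag (or invalidate the secondary tensor), e.g.
        # self._partition_param_sec(param, has_been_updated=has_been_updated)

param = {"data": "new", "ds_tensor": "old", "ds_secondary_tensor": "old"}
ParamPartitioner().partition(param, has_been_updated=True)
print(param)  # ds_tensor becomes 'new'; ds_secondary_tensor stays 'old'
```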
The zero checkpoint is loaded into `fp16_partitioned_groups_flat`, meaning `param.ds_tensor` has been updated. However, the data in `param.ds_secondary_tensor` has not been updated, yet the next `allgather` will use the dirty `param.ds_secondary_tensor`.
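A hedged sketch of that load path, assuming (as this PR describes) that `fp16_partitioned_groups_flat` aliases each `param.ds_tensor`; the function name and loop are illustrative, not the actual DeepSpeed loader:

```python
import torch

def load_flat_checkpoint(flat_groups, saved_flat_tensors):
    # Copying into fp16_partitioned_groups_flat updates every aliased
    # param.ds_tensor in place...
    for flat, saved in zip(flat_groups, saved_flat_tensors):
        flat.data.copy_(saved)
    # ...but nothing here touches param.ds_secondary_tensor, so the next
    # allgather would serve the stale secondary copy unless it is invalidated.
```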
A dirty `ds_secondary_tensor` can lead to abnormal loss. After calling `invalidate_secondary_tensor` in `_post_step`, the loss returns to normal. This is why the loss anomaly only occurs during the first steps.

Related issue: #7606
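Conceptually, the invalidation amounts to the following; the real `invalidate_secondary_tensor` hook belongs to DeepSpeed itself, and this standalone helper only mimics its effect:

```python
def invalidate_secondary_tensors(params):
    # Dropping the secondary copy forces the next allgather to rebuild it
    # from the authoritative ds_tensor instead of serving stale data.
    for p in params:
        if getattr(p, "ds_secondary_tensor", None) is not None:
            p.ds_secondary_tensor = None
```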