Does Torch.ao Support FullyShardedDataParallel? #1413
Comments
The issue was ultimately traced to the function _replace_with_custom_fn_if_matches_filter, as shown below:
If the model is wrapped with FSDP, printing model and model.weight gives different results:
With FSDP (bad case): (Pdb) p model.weight (output not preserved here)
Without FSDP (good case): (Pdb) p model.weight (output not preserved here)
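For context, here is a simplified sketch (not the actual torchao source) of what a replace-by-filter pass like _replace_with_custom_fn_if_matches_filter does: it recurses over submodules and swaps the ones the filter matches, which is why wrapping the model with FSDP first changes which modules and names the filter sees.

```python
# Simplified illustration of a replace-by-filter module pass
# (not the actual torchao implementation).
import torch.nn as nn

def replace_if_matches(model: nn.Module, replacement_fn, filter_fn, fqn: str = "") -> nn.Module:
    # If this module matches the filter, replace it outright.
    if filter_fn(model, fqn):
        return replacement_fn(model)
    # Otherwise recurse into children. Under FSDP1 the children (and their
    # fully qualified names) are the wrapped/flattened modules, so the filter
    # may no longer match the original Linear layers.
    for name, child in list(model.named_children()):
        child_fqn = f"{fqn}.{name}" if fqn else name
        new_child = replace_if_matches(child, replacement_fn, filter_fn, child_fqn)
        if new_child is not child:
            setattr(model, name, new_child)
    return model
```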
Generally you should quantize the model first, before applying FSDP, because you can't re-assign the weights once FSDP has wrapped and flattened them. I'm also not sure whether AQT (the tensor subclass backing the quantized weights) composes with FSDP.
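A minimal sketch of the suggested order (quantize first, then wrap); the toy model and the omitted process-group/FSDP settings are placeholders, not from the issue:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

net = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()

# 1) Swap the weights for quantized tensor subclasses while they are still
#    plain nn.Parameters on the unwrapped module.
quantize_(net, int8_dynamic_activation_int8_weight())

# 2) Only then wrap with FSDP (requires torch.distributed to be initialized).
net_model_fsdp = FullyShardedDataParallel(net)
```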
I have already tried that, and it still results in a memory access error. I will continue to investigate how to address this issue. If you have already resolved it, please share the solution with me.
Oh, I didn't notice you were using FSDP1. I don't think FSDP1 will be supported. FSDP2 can be supported (similar to NF4 + FSDP2 in torchtune), but I'm not sure whether it currently works. You can try FSDP2; see, for example, ao/test/prototype/test_quantized_training.py, lines 318 to 323 at commit cbd7c29.
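A hedged sketch of the FSDP2 flow that comment points to; this is not the verbatim code from test_quantized_training.py, and the quantization API used there may differ. It assumes fully_shard is importable from torch.distributed._composable.fsdp, as in recent PyTorch releases.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

def prepare(model: nn.Module) -> nn.Module:
    # Quantize before sharding, as suggested above.
    quantize_(model, int8_dynamic_activation_int8_weight())
    # FSDP2: shard each child module, then the root, instead of wrapping
    # the model in the FSDP1 FullyShardedDataParallel class.
    for child in model.children():
        fully_shard(child)
    fully_shard(model)
    return model
```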
When I add FullyShardedDataParallel to the model,
net_model_fsdp = FullyShardedDataParallel(net, **settings)
and then try to quantize it using:
quantize_(net_model_fsdp, int8_dynamic_activation_int8_weight())
I encounter the following error with torch.ao:
RuntimeError: CUDA error: an illegal memory access was encountered.
If I do not use FullyShardedDataParallel and directly quantize net (as shown below), there is no problem:
quantize_(net, int8_dynamic_activation_int8_weight())
Please help me analyze the reason.
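A hypothetical minimal reproduction of the failing sequence described above (FSDP1 wrap first, then quantize_); the model sizes and launch details are made up, and it would be run under torchrun with one CUDA device per rank.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    net = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()

    # Order reported to fail: wrap with FSDP1 first, then quantize the wrapped module.
    net_model_fsdp = FullyShardedDataParallel(net)
    quantize_(net_model_fsdp, int8_dynamic_activation_int8_weight())

    x = torch.randn(8, 1024, device="cuda")
    net_model_fsdp(x)  # reported: CUDA error: an illegal memory access was encountered

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```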