Hello,
I think we found an error in the computation of the weight gradient in the Transformer Engine Linear layer when FP8 and sequence parallelism are enabled. The problem originates from `gather_along_last_dim`, which is used in the backward pass (here). The `torch.distributed.all_gather_into_tensor` call used in `gather_along_last_dim` does not concatenate the data correctly, because it merges the per-rank data along the first dimension rather than the last. This breaks the computation of the weight gradient when FP8 and sequence parallelism are enabled. Linear is the only module using `gather_along_last_dim`, so other parts of the code are not affected.
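To illustrate the issue without a distributed run, the semantics can be simulated with NumPy: `all_gather_into_tensor` writes each rank's shard contiguously into the output buffer, which is equivalent to concatenating along dim 0, while a last-dim gather needs the rank axis interleaved next to the last dimension. This is a minimal sketch of the behavior, not the Transformer Engine code itself:

```python
import numpy as np

world_size = 2
full = np.arange(24).reshape(4, 6)  # the "unsharded" activation
# sequence-parallel shards: split along the LAST dim
shards = np.split(full, world_size, axis=-1)  # two (4, 3) shards

# What all_gather_into_tensor effectively does: write each rank's
# shard contiguously into one buffer, i.e. concatenate along dim 0.
buggy = np.concatenate(shards, axis=0)  # shape (8, 3)
# Reinterpreting that buffer as (4, 6) does NOT recover `full`.
print(np.array_equal(buggy.reshape(4, 6), full))  # False

# A correct last-dim gather: view the buffer as (world_size, 4, 3),
# move the rank axis next to the last dim, then merge the two axes.
stacked = buggy.reshape(world_size, 4, 3)
fixed = np.moveaxis(stacked, 0, -2).reshape(4, 6)
print(np.array_equal(fixed, full))  # True
```

The same reshape-and-`movedim` pattern (or an explicit concatenation along the last dim after the all-gather) is one way a fix could look in PyTorch.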
Script that shows the problem in the function `gather_along_last_dim`
Script that shows the problem in the transformer engine Linear
Both scripts can be executed with:
`torchrun --nproc-per-node 2 <script>.py`
I prepared a fix that I am going to post as well.