Hello,
I think we found an error in the computation of the weight gradient in the Transformer Engine `Linear` module when fp8 and sequence parallel are enabled. The problem originates from `gather_along_last_dim`, which is used in the backward pass (here). The `torch.distributed.all_gather_into_tensor` call used here in `gather_along_last_dim` does not concatenate the data correctly: it merges the gathered chunks along the first dimension instead of the last one. This breaks the weight gradient computation when fp8 and sequence parallel are enabled. `Linear` is the only module that uses `gather_along_last_dim`, so other parts of the code are not affected.
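To make the failure mode concrete, here is a small single-process sketch of the layout mismatch (the shapes are just an example): writing the per-rank shards back-to-back into a pre-shaped output buffer, as `all_gather_into_tensor` does, amounts to concatenating along the first dimension and reinterpreting the memory, which is not the same as concatenating along the last dimension.

```python
import torch

# Single-process illustration of the layout mismatch; shapes are arbitrary.
world_size = 2
x_ref = torch.rand(128, 16 * world_size)
shards = [c.contiguous() for c in torch.chunk(x_ref, world_size, dim=-1)]  # one shard per rank

# What a gather along the last dimension should produce:
correct = torch.cat(shards, dim=-1)

# What writing the shards back-to-back into a (128, 16 * world_size) buffer
# produces instead (concatenation along dim 0, then a reinterpretation):
broken = torch.cat(shards, dim=0).reshape(128, 16 * world_size)

print(torch.equal(correct, x_ref))  # True
print(torch.equal(broken, x_ref))   # False
```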
Script that shows the problem in the function `gather_along_last_dim`:
import os
import torch
import transformer_engine as te
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
tp_group = torch.distributed.new_group(ranks=range(world_size))
torch.manual_seed(0)
x_ref = torch.rand((128, 16*world_size)).cuda()
xi = torch.chunk(x_ref, world_size, dim=1)[rank]
assert xi.size(1) == 16
x_test, _ = te.pytorch.distributed.gather_along_last_dim(xi, tp_group)
assert torch.equal(x_test, x_ref)
Script that shows the problem in the Transformer Engine `Linear`:
import os
import torch
import transformer_engine as te
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
tp_group = torch.distributed.new_group(ranks=range(world_size))
def get_linear_wgrad(fp8, sequence_parallel):
    hidden_size = 16
    sequence_length = 16
    model = te.pytorch.Linear(hidden_size, hidden_size,
                              parallel_mode="column",
                              tp_group=tp_group,
                              sequence_parallel=sequence_parallel)
    x = torch.eye(sequence_length, hidden_size).cuda()
    if sequence_parallel:
        x = torch.chunk(x, world_size)[rank]
    with te.pytorch.fp8_autocast(enabled=fp8):
        output = model(x)
    output_grad = torch.ones((sequence_length, hidden_size // world_size)).cuda()
    output.backward(output_grad)
    return model.weight.grad.detach().clone()
t1 = get_linear_wgrad(fp8=True, sequence_parallel=True)
t2 = get_linear_wgrad(fp8=True, sequence_parallel=False)
deviation = torch.max(torch.abs(t2 - t1)).item()
print(f"Rank {rank}: Deviation sequence_parallel=True/False: {deviation}")
Both scripts can be executed with `torchrun --nproc-per-node 2 <script>.py`.
I prepared a fix that I am going to post as well.
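For illustration only, here is a minimal sketch of the underlying idea (the function name and signature are made up for this example, and this is not necessarily the patch I will post): gather the per-rank shards into a list and concatenate them along the last dimension explicitly.

```python
import torch
import torch.distributed as dist

def gather_along_last_dim_sketch(inp: torch.Tensor, group) -> torch.Tensor:
    """Gather `inp` from all ranks in `group` and concatenate along the last dimension."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return inp
    inp = inp.contiguous()
    shards = [torch.empty_like(inp) for _ in range(world_size)]
    dist.all_gather(shards, inp, group=group)
    return torch.cat(shards, dim=-1)
```

Note that the real `gather_along_last_dim` also returns a second value (as used in the first script above), which this sketch omits.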