
[PyTorch] Linear: error in computation for wgrad if sequence_parallel=True #530

Closed
@Marks101

Description


Hello,
I think we found an error in the computation of the weight gradient in the Transformer Engine `Linear` module when FP8 and sequence parallelism are enabled. The problem originates from `gather_along_last_dim`, which is used in the backward pass (here). The `torch.distributed.all_gather_into_tensor` call inside `gather_along_last_dim` does not concatenate the data correctly: it merges the shards along the first dimension instead of the last one. This breaks the weight gradient computation when FP8 and sequence parallelism are enabled. `Linear` is the only module using `gather_along_last_dim`, so other parts of the code are not affected.
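For context, `torch.distributed.all_gather_into_tensor` writes each rank's shard into consecutive chunks of the output buffer, which amounts to a concatenation along the first dimension; a gather along the last dimension needs an explicit `all_gather` into per-rank buffers followed by `torch.cat(..., dim=-1)`. A minimal sketch illustrating the difference (names and shapes are illustrative, not taken from the TE source; run with torchrun on 2 processes):

import os
import torch

world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])

torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)

# Each rank holds a (2, 3) shard of a (2, 3 * world_size) tensor split along the last dimension.
shard = torch.full((2, 3), float(rank), device="cuda")

# all_gather_into_tensor lays the shards out rank after rank, i.e. a concatenation along the first dimension.
gathered = torch.empty((2 * world_size, 3), device="cuda")
torch.distributed.all_gather_into_tensor(gathered, shard)
wrong = gathered.view(2, 3 * world_size)  # reinterpreting the buffer does not fix the layout

# A gather along the last dimension needs an explicit concatenation.
chunks = [torch.empty_like(shard) for _ in range(world_size)]
torch.distributed.all_gather(chunks, shard)
right = torch.cat(chunks, dim=-1)

print(f"Rank {rank}: wrong[0] = {wrong[0].tolist()}, right[0] = {right[0].tolist()}")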

Script that shows the problem in the function `gather_along_last_dim`
import os
import torch
import transformer_engine as te


world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])

torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
tp_group = torch.distributed.new_group(ranks=range(world_size))


torch.manual_seed(0)

# Reference tensor; each rank keeps one shard along the last dimension.
x_ref = torch.rand((128, 16 * world_size)).cuda()

xi = torch.chunk(x_ref, world_size, dim=1)[rank]
assert xi.size(1) == 16

# Gathering the shards back along the last dimension should reconstruct the reference tensor.
x_test, _ = te.pytorch.distributed.gather_along_last_dim(xi, tp_group)

assert torch.equal(x_test, x_ref)
Script that shows the problem with the Transformer Engine `Linear`
import os
import torch
import transformer_engine as te


world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])

torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
tp_group = torch.distributed.new_group(ranks=range(world_size))


def get_linear_wgrad(fp8, sequence_parallel):
    hidden_size = 16
    sequence_length = 16
    # Column-parallel Linear: the output features are split across the tensor-parallel ranks.
    model = te.pytorch.Linear(hidden_size, hidden_size,
                              parallel_mode="column",
                              tp_group=tp_group,
                              sequence_parallel=sequence_parallel)

    x = torch.eye(sequence_length, hidden_size).cuda()
    if sequence_parallel:
        # With sequence parallelism, each rank only holds its slice of the sequence dimension.
        x = torch.chunk(x, world_size)[rank]

    with te.pytorch.fp8_autocast(enabled=fp8):
        output = model(x)

    # The output gradient matches the per-rank output shape (output features split across ranks).
    output_grad = torch.ones((sequence_length, hidden_size // world_size)).cuda()
    output.backward(output_grad)
    return model.weight.grad.detach().clone()


# The FP8 weight gradient should be identical with and without sequence parallelism.
t1 = get_linear_wgrad(fp8=True, sequence_parallel=True)
t2 = get_linear_wgrad(fp8=True, sequence_parallel=False)
deviation = torch.max(torch.abs(t2 - t1)).item()

print(f"Rank {rank}: Deviation sequence_parallel=True/False: {deviation}")

Both scripts can be executed with: `torchrun --nproc-per-node 2 <script>.py`

I prepared a fix that I am going to post as well.
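For illustration, the general direction of such a fix is to gather every rank's shard explicitly and concatenate along the last dimension. A minimal sketch, assuming a standalone helper (this is not the actual patch and not TE's real function signature):

def gather_along_last_dim_sketch(inp, group):
    # Hypothetical helper: gather all shards and concatenate them along the
    # last dimension, instead of relying on all_gather_into_tensor, which
    # lays the shards out along the first dimension.
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return inp
    chunks = [torch.empty_like(inp) for _ in range(world_size)]
    torch.distributed.all_gather(chunks, inp.contiguous(), group=group)
    return torch.cat(chunks, dim=-1)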
