
[PyTorch] Linear: error in computation for wgrad if sequence_parallel=True #530

Closed · Marks101 opened this issue Nov 22, 2023 · 1 comment

Marks101 (Contributor) commented Nov 22, 2023

Hello,
I think we found an error in the computation of the weight gradient in the Transformer Engine `Linear` module when fp8 and sequence parallelism are enabled. The problem originates from `gather_along_last_dim`, which is used in the backward pass (here). The `torch.distributed.all_gather_into_tensor` call inside `gather_along_last_dim` does not concatenate the data correctly: it assumes the chunks should be merged along the first dimension, which breaks the weight-gradient computation. `Linear` is the only module that uses `gather_along_last_dim`, so other parts of the code are not affected.
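For illustration, here is a minimal standalone sketch (plain `torch.distributed`, not Transformer Engine code; the function names are made up for this example) of the difference: `all_gather_into_tensor` writes the per-rank chunks back-to-back, which is equivalent to stacking them along the first dimension, whereas gathering along the last dimension requires collecting the chunks individually and concatenating them along `dim=-1`.

```python
import torch
import torch.distributed as dist


def gather_last_dim_naive(inp, group):
    # Mirrors the behavior described above: all_gather_into_tensor lays the
    # per-rank chunks out contiguously, i.e. stacked along the first
    # dimension -- not interleaved along the last one.
    world_size = dist.get_world_size(group)
    out = torch.empty(
        (world_size * inp.size(0), *inp.shape[1:]),
        dtype=inp.dtype, device=inp.device,
    )
    dist.all_gather_into_tensor(out, inp.contiguous(), group=group)
    return out


def gather_last_dim_correct(inp, group):
    # Gather each rank's chunk into a list, then concatenate along dim=-1.
    world_size = dist.get_world_size(group)
    chunks = [torch.empty_like(inp) for _ in range(world_size)]
    dist.all_gather(chunks, inp.contiguous(), group=group)
    return torch.cat(chunks, dim=-1)
```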

Script that shows the problem in the function `gather_along_last_dim`:

```python
import os
import torch
import transformer_engine as te


world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])

torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
tp_group = torch.distributed.new_group(ranks=range(world_size))


torch.manual_seed(0)

x_ref = torch.rand((128, 16*world_size)).cuda()

xi = torch.chunk(x_ref, world_size, dim=1)[rank]
assert xi.size(1) == 16

x_test, _ = te.pytorch.distributed.gather_along_last_dim(xi, tp_group)

assert torch.equal(x_test, x_ref)
```

Script that shows the problem with the Transformer Engine `Linear` module:

```python
import os
import torch
import transformer_engine as te


world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])

torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
tp_group = torch.distributed.new_group(ranks=range(world_size))


def get_linear_wgrad(fp8, sequence_parallel):
    hidden_size = 16
    sequence_length = 16
    model = te.pytorch.Linear(hidden_size, hidden_size,
                              parallel_mode="column",
                              tp_group=tp_group,
                              sequence_parallel=sequence_parallel)

    x = torch.eye(sequence_length, hidden_size).cuda()
    if sequence_parallel:
        x = torch.chunk(x, world_size)[rank]

    with te.pytorch.fp8_autocast(enabled=fp8):
        output = model(x)

    output_grad = torch.ones((sequence_length, hidden_size // world_size)).cuda()
    output.backward(output_grad)
    return model.weight.grad.detach().clone()


t1 = get_linear_wgrad(fp8=True, sequence_parallel=True)
t2 = get_linear_wgrad(fp8=True, sequence_parallel=False)
deviation = torch.max(torch.abs(t2 - t1)).item()

print(f"Rank {rank}: Deviation sequence_parallel=True/False: {deviation}")

Both scripts can be executed with: torchrun --nproc-per-node 2 <script>.py

I prepared a fix that I am going to post as well.

ptrendx (Member) commented Dec 1, 2023

Closing since the fix was merged.

ptrendx closed this as completed Dec 1, 2023