multi-gpu example with >1 GPU crashes without fuse_qkv=True #533

Closed
ggalal opened this issue Nov 22, 2023 · 2 comments · Fixed by #590
Labels
bug Something isn't working

Comments
ggalal commented Nov 22, 2023

Hello! I've been experimenting with TransformerEngine and cannot get the multi-gpu example to run successfully on more than a single GPU with fuse_qkv_params disabled:

import os
import torch
import transformer_engine.pytorch as te

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.float16

# Synthetic data
x = torch.rand(sequence_length//int(os.environ['WORLD_SIZE']), batch_size, hidden_size).cuda().to(dtype=dtype)
dy = torch.rand(sequence_length//int(os.environ['WORLD_SIZE']), batch_size, hidden_size).cuda().to(dtype=dtype)

# Configure parallel groups
world_group = torch.distributed.init_process_group(
    "nccl",
    init_method="file:///tmp/rdzv",
    world_size=int(os.environ['WORLD_SIZE']),
    rank=int(os.environ['LOCAL_RANK']),
)
data_parallel_group = torch.distributed.new_group(ranks=list(range(int(os.environ['WORLD_SIZE']))), backend="nccl")
tensor_parallel_group = torch.distributed.new_group(ranks=list(range(int(os.environ['WORLD_SIZE']))), backend="nccl")

# Construct layer
parallel_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    set_parallel_mode=True,
    tp_group=tensor_parallel_group,
    sequence_parallel=True,
)
parallel_transformer.to(dtype=dtype).cuda()
parallel_transformer = torch.nn.parallel.DistributedDataParallel(
    parallel_transformer,
    process_group=data_parallel_group,
)

# Training step
with torch.autocast(device_type='cuda', dtype=dtype):
    y = parallel_transformer(x, attention_mask=None)
y.backward(dy)

with the following run command:
torchrun --nproc_per_node=4 te_example.py

and this is my traceback:

Traceback (most recent call last):
  File "/workspace/local/nvme/rita/src/rita/modeling/tiled/te_example.py", line 30, in <module>
    parallel_transformer = te.TransformerLayer(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 333, in __init__
    self.self_attention = MultiheadAttention(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2450, in __init__
    self.layernorm_qkv = LayerNormLinear(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 787, in __init__
    self.out_features == overall_split_size
AssertionError: Overall sum of parameters_split (=12288) does not match to out features (=3072)

Any insights would be helpful. Thank you!
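A rough sketch of where those two numbers appear to come from, assuming the tensor-parallel size equals the 4-GPU world size used in the launch command (this is inferred from the traceback, not stated in it):

# Sketch of the assertion arithmetic (inferred; tp_size is an assumption based on the 4-GPU run)
hidden_size = 4096
tp_size = 4                                           # tp_group == world group in the reproducer
qkv_out_features = 3 * hidden_size                    # 12288: fused query/key/value output features
per_rank_out_features = qkv_out_features // tp_size   # 3072: out_features on each tensor-parallel rank
# The unfused parameters_split apparently still sums to the full 12288, which is then
# compared against the per-rank 3072 and trips the assertion in LayerNormLinear.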

@ggalal ggalal changed the title Testing multi-gpu example with >1 GPU Failed multi-gpu example with >1 GPU Nov 27, 2023
@ggalal ggalal changed the title Failed multi-gpu example with >1 GPU multi-gpu example with >1 GPU crashes Nov 27, 2023
@ggalal ggalal changed the title multi-gpu example with >1 GPU crashes multi-gpu example with >1 GPU crashes with fuse_qkv=True Nov 28, 2023
@ggalal ggalal changed the title multi-gpu example with >1 GPU crashes with fuse_qkv=True multi-gpu example with >1 GPU crashes without fuse_qkv=True Nov 28, 2023
ptrendx (Member) commented Dec 1, 2023

@timmoon10 Could you take a look at this?

timmoon10 (Collaborator) commented

Thanks for the report! I can reproduce the error, and #590 fixes it for me. Let me know if you run into any more issues.

I also notice you're setting both the tensor-parallel and data-parallel process groups to be the world process group. These groups should be orthogonal, maybe something like:

# Rank and world size from the process group initialized earlier
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

tensor_parallel_size = 2
num_tensor_parallel_groups = world_size // tensor_parallel_size
data_parallel_size = num_tensor_parallel_groups
num_data_parallel_groups = world_size // data_parallel_size

# Build the groups so each rank lands in exactly one data-parallel and one tensor-parallel group
data_parallel_group = None
tensor_parallel_group = None
for i in range(num_data_parallel_groups):
    ranks = list(range(i, world_size, num_data_parallel_groups))
    group = torch.distributed.new_group(ranks=ranks, backend="nccl")
    if rank in ranks:
        data_parallel_group = group
for i in range(num_tensor_parallel_groups):
    ranks = list(range(i*tensor_parallel_size, (i+1)*tensor_parallel_size))
    group = torch.distributed.new_group(ranks=ranks, backend="nccl")
    if rank in ranks:
        tensor_parallel_group = group

There is a high-level overview of tensor parallelism in the docs (link), and the Megatron-LM paper goes into more detail.
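
As a rough sketch of the layout those loops produce, assuming world_size = 4 and tensor_parallel_size = 2 (values matching the launch command above):

world_size = 4
tensor_parallel_size = 2
num_tensor_parallel_groups = world_size // tensor_parallel_size       # 2
num_data_parallel_groups = world_size // num_tensor_parallel_groups   # 2

# Rank layout implied by the loops above
dp_layout = [list(range(i, world_size, num_data_parallel_groups))
             for i in range(num_data_parallel_groups)]                # [[0, 2], [1, 3]]
tp_layout = [list(range(i * tensor_parallel_size, (i + 1) * tensor_parallel_size))
             for i in range(num_tensor_parallel_groups)]              # [[0, 1], [2, 3]]

# Every rank is in exactly one group of each kind, so tensor parallelism and
# data parallelism partition the ranks orthogonally instead of both spanning the world group.

The resulting tensor_parallel_group is what goes into te.TransformerLayer via tp_group, and data_parallel_group is what goes into DistributedDataParallel via process_group, as in the original script.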
