multi-gpu example with >1 GPU crashes without fuse_qkv=True #533

@ggalal

Description

Hello! I've been experimenting with TransformerEngine and cannot get the multi-GPU example to run successfully on more than one GPU with fuse_qkv_params disabled:

import os
import torch
import transformer_engine.pytorch as te

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.float16

# Synthetic data
x = torch.rand(sequence_length//int(os.environ['WORLD_SIZE']), batch_size, hidden_size).cuda().to(dtype=dtype)
dy = torch.rand(sequence_length//int(os.environ['WORLD_SIZE']), batch_size, hidden_size).cuda().to(dtype=dtype)

# Configure parallel groups
world_group = torch.distributed.init_process_group(
    "nccl",
    init_method="file:///tmp/rdzv",
    world_size=int(os.environ['WORLD_SIZE']),
    rank=int(os.environ['LOCAL_RANK']),
)
data_parallel_group = torch.distributed.new_group(ranks=list(range(int(os.environ['WORLD_SIZE']))), backend="nccl")
tensor_parallel_group = torch.distributed.new_group(ranks=list(range(int(os.environ['WORLD_SIZE']))), backend="nccl")

# Construct layer
parallel_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    set_parallel_mode=True,
    tp_group=tensor_parallel_group,
    sequence_parallel=True,
)
parallel_transformer.to(dtype=dtype).cuda()
parallel_transformer = torch.nn.parallel.DistributedDataParallel(
    parallel_transformer,
    process_group=data_parallel_group,
)

# Training step
with torch.autocast(device_type='cuda', dtype=dtype):
    y = parallel_transformer(x, attention_mask=None)
y.backward(dy)

with the following run command:
torchrun --nproc_per_node=4 te_example.py

and this is my traceback:

Traceback (most recent call last):
  File "/workspace/local/nvme/rita/src/rita/modeling/tiled/te_example.py", line 30, in <module>
    parallel_transformer = te.TransformerLayer(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 333, in __init__
    self.self_attention = MultiheadAttention(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2450, in __init__
    self.layernorm_qkv = LayerNormLinear(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 787, in __init__
    self.out_features == overall_split_size
AssertionError: Overall sum of parameters_split (=12288) does not match to out features (=3072)
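For what it's worth, the two numbers in the assertion look consistent with a tensor-parallel sharding mismatch. This is only a sketch of where I assume the values come from (4-way tensor parallelism from `--nproc_per_node=4`, separate q/k/v parameter splits of `hidden_size` each), not TransformerEngine's actual code path:

```python
# Reconstruct the assertion's numbers arithmetically.
hidden_size = 4096
tp_size = 4  # from torchrun --nproc_per_node=4

# Assumed: with fuse_qkv_params disabled, parameters_split keeps the
# full (unsharded) q/k/v sizes, so their sum is 3 * hidden_size.
parameters_split = {"query_": hidden_size, "key_": hidden_size, "value_": hidden_size}
overall_split_size = sum(parameters_split.values())  # 12288

# Meanwhile out_features of the QKV projection is divided across TP ranks.
out_features = 3 * hidden_size // tp_size  # 3072

print(overall_split_size, out_features)  # the two values in the assertion
```

So the unsharded split sizes (12288) are being compared against the per-rank `out_features` (3072), which would explain why this only fails with more than one GPU.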

Any insights would be helpful. Thank you!

Metadata

Labels: bug (Something isn't working)