Hello! I've been doing some experimenting with TransformerEngine and cannot get the multi-GPU example to run successfully on more than a single GPU with fuse_qkv_params disabled:
import os
import torch
import transformer_engine.pytorch as te

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.float16

# Synthetic data
x = torch.rand(sequence_length // int(os.environ['WORLD_SIZE']), batch_size, hidden_size).cuda().to(dtype=dtype)
dy = torch.rand(sequence_length // int(os.environ['WORLD_SIZE']), batch_size, hidden_size).cuda().to(dtype=dtype)

# Configure parallel groups
world_group = torch.distributed.init_process_group(
    "nccl",
    init_method="file:///tmp/rdzv",
    world_size=int(os.environ['WORLD_SIZE']),
    rank=int(os.environ['LOCAL_RANK']),
)
data_parallel_group = torch.distributed.new_group(ranks=list(range(int(os.environ['WORLD_SIZE']))), backend="nccl")
tensor_parallel_group = torch.distributed.new_group(ranks=list(range(int(os.environ['WORLD_SIZE']))), backend="nccl")

# Construct layer
parallel_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    set_parallel_mode=True,
    tp_group=tensor_parallel_group,
    sequence_parallel=True,
)
parallel_transformer.to(dtype=dtype).cuda()
parallel_transformer = torch.nn.parallel.DistributedDataParallel(
    parallel_transformer,
    process_group=data_parallel_group,
)

# Training step
with torch.autocast(device_type='cuda', dtype=dtype):
    y = parallel_transformer(x, attention_mask=None)
y.backward(dy)
with the following run command:
torchrun --nproc_per_node=4 te_example.py
and this is my traceback:
Traceback (most recent call last):
  File "/workspace/local/nvme/rita/src/rita/modeling/tiled/te_example.py", line 30, in <module>
    parallel_transformer = te.TransformerLayer(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 333, in __init__
    self.self_attention = MultiheadAttention(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2450, in __init__
    self.layernorm_qkv = LayerNormLinear(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 787, in __init__
    self.out_features == overall_split_size
AssertionError: Overall sum of parameters_split (=12288) does not match to out features (=3072)
Any insights would be helpful. Thank you!
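For context, the numbers in the assertion line up with the fused QKV projection being sharded 4 ways for tensor parallelism. A rough sanity check of that reading (it assumes parameters_split holds the unsharded query/key/value widths, which is an assumption, not something confirmed by the traceback):

# Arithmetic behind the assertion message (assumes 4-way tensor parallelism and
# that parameters_split contains the unsharded q/k/v widths).
hidden_size = 4096
tp_size = 4                                  # from --nproc_per_node=4
split_sum = 3 * hidden_size                  # 12288, "sum of parameters_split"
local_out_features = split_sum // tp_size    # 3072, "out features" on each GPU
assert split_sum == 12288 and local_out_features == 3072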
Thanks for the report! I can reproduce the error, and #590 fixes it for me. Let me know if you run into any more issues.
I also notice you're setting both the tensor-parallel and the data-parallel process group to the world process group. These groups should be orthogonal, maybe something like the sketch below:
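A minimal sketch, assuming a single node with 4 GPUs split into 2-way tensor parallelism and 2-way data parallelism (the tp_size of 2 is only an example, not a recommendation):

import os
import torch

# Hypothetical 2-way tensor parallel x 2-way data parallel layout on 4 GPUs.
# torch.distributed.init_process_group must already have been called, as in the
# script above.
world_size = int(os.environ['WORLD_SIZE'])
rank = torch.distributed.get_rank()
tp_size = 2                       # example value
dp_size = world_size // tp_size

# new_group must be called by every rank for every group, even the groups the
# calling rank is not a member of.
tensor_parallel_group = None
for i in range(dp_size):
    ranks = list(range(i * tp_size, (i + 1) * tp_size))    # e.g. [0, 1] and [2, 3]
    group = torch.distributed.new_group(ranks=ranks, backend="nccl")
    if rank in ranks:
        tensor_parallel_group = group

data_parallel_group = None
for i in range(tp_size):
    ranks = list(range(i, world_size, tp_size))             # e.g. [0, 2] and [1, 3]
    group = torch.distributed.new_group(ranks=ranks, backend="nccl")
    if rank in ranks:
        data_parallel_group = group

With this layout the tensor-parallel groups are {0, 1} and {2, 3} and the data-parallel groups are {0, 2} and {1, 3}, so every rank belongs to exactly one group of each kind and the two kinds of groups are orthogonal.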