[PyTorch] Let Fused RoPE support CP with THD format (#1238)
* Let Fused RoPE support THD with CP

Signed-off-by: Xin Yao <[email protected]>

* add comment

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: Xiaowei Ren <[email protected]>
yaox12 and xrennvidia authored Oct 12, 2024
1 parent b36bd0a commit 55dcbb4
Showing 6 changed files with 237 additions and 138 deletions.
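
After this change, the fused `thd` RoPE path takes the context-parallel world size and rank directly. A minimal caller-side sketch, assuming the same padded-`cu_seqlens` convention as the updated test below (the concrete sizes, dtype, and device here are illustrative, not part of the commit):

```python
import math

import torch
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

# Illustrative CP setup and sequence lengths (assumed values).
cp_size, cp_rank = 2, 0
head_num, hidden_size = 64, 256
cu_seqlens = [0, 400, 542, 711]

# Pad each sequence length to a multiple of 2 * cp_size, as the new test does,
# so every sequence splits evenly into 2 * cp_size chunks across CP ranks.
cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)):
    seqlen = cu_seqlens[i] - cu_seqlens[i - 1]
    cu_seqlens_padded.append(
        cu_seqlens_padded[i - 1] + math.ceil(seqlen / (cp_size * 2)) * (cp_size * 2)
    )
cu_seqlens_padded = torch.tensor(cu_seqlens_padded, dtype=torch.int32, device="cuda")

# This rank holds 1 / cp_size of the padded tokens in `thd` layout.
t = torch.rand(
    (int(cu_seqlens_padded[-1]) // cp_size, head_num, hidden_size),
    dtype=torch.bfloat16,
    device="cuda",
)

rotary_pos_emb = RotaryPositionEmbedding(hidden_size)
emb = rotary_pos_emb(int(cu_seqlens_padded[-1]))

# New in this commit: cp_size / cp_rank are forwarded to the fused THD kernel.
output = apply_rotary_pos_emb(
    t,
    emb,
    fused=True,
    tensor_format="thd",
    cu_seqlens=cu_seqlens_padded,
    cp_size=cp_size,
    cp_rank=cp_rank,
)
```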
120 changes: 83 additions & 37 deletions tests/pytorch/test_fused_rope.py
@@ -1,17 +1,38 @@
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+import math
 import pytest
 import torch
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Tuple, Union
 from transformer_engine.pytorch.attention import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
 )


+def _get_thd_freqs_on_this_cp_rank(
+    cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
+) -> torch.Tensor:
+    if cp_size > 1:
+        cp_seg = x.size(0) // 2
+        full_seqlen = cp_size * x.size(0)
+        return torch.cat(
+            [
+                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
+                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
+            ]
+        )
+    else:
+        return freqs[: x.size(0)]
+
+
 def apply_rotary_pos_emb_thd(
-    t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
+    t: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    freqs: torch.Tensor,
+    cp_size: int = 1,
+    cp_rank: int = 0,
 ) -> torch.Tensor:
     """A baseline implementation of applying RoPE for `thd` format.
@@ -24,20 +45,18 @@ def apply_rotary_pos_emb_thd(
     Returns:
         Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
     """
+    cu_seqlens = cu_seqlens // cp_size
     seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
     return torch.cat(
-        [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)]
+        [
+            apply_rotary_pos_emb(
+                x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
+            )
+            for x in torch.split(t, seqlens)
+        ]
     ).squeeze(1)


-def get_tol(dtype: torch.dtype) -> Dict:
-    if dtype == torch.bfloat16:
-        return dict(atol=1e-2, rtol=1e-2)
-    elif dtype == torch.float16:
-        return dict(atol=1e-3, rtol=1e-3)
-    return dict(atol=1e-5, rtol=1.3e-6)
-
-
 # Gradient is a broadcasted scalar
 def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
     return output.sum() * 2
@@ -84,7 +103,11 @@ def test_fused_rope(
     emb = rotary_pos_emb(seq_length)

     # unfused
-    output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False)
+    # The fused kernel computes in float32 internally, so we force the unfused func to use float32
+    # for more accurate comparison
+    output_unfused = apply_rotary_pos_emb(
+        t.float(), emb, tensor_format=tensor_format, fused=False
+    ).to(dtype)
     loss_unfused = loss_func(output_unfused)
     loss_unfused.backward()
     grad_unfused = t.grad.detach().clone()
@@ -102,8 +125,8 @@ def test_fused_rope(
     grad_fused = t.grad.detach().clone()
     t.grad = None

-    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
-    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
+    torch.testing.assert_close(output_fused, output_unfused)
+    torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()


@@ -112,22 +135,34 @@ def test_fused_rope(
 @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
 @pytest.mark.parametrize("transpose", [None, (1, 2)])
 @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("cp_size", [1, 2, 3])
 def test_fused_rope_thd(
     dtype: torch.dtype,
     hidden_size: int,
     rotary_percent: float,
     transpose: Union[Tuple, None],
     loss_func: Callable,
+    cp_size: int,
 ) -> None:
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
-    cu_seqlens = torch.tensor(
-        [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
+    cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
+    if cp_size > 1:
+        cu_seqlens_padded = [0]
+        for i in range(1, len(cu_seqlens)):
+            cu_seqlens_padded.append(
+                cu_seqlens_padded[i - 1]
+                + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2)
+            )
+    else:
+        cu_seqlens_padded = cu_seqlens
+    cu_seqlens_padded = torch.tensor(
+        cu_seqlens_padded,
         dtype=torch.int32,
         device=device,
     )
     t = torch.rand(
-        (cu_seqlens[-1], head_num, hidden_size),
+        (cu_seqlens_padded[-1] // cp_size, head_num, hidden_size),
         dtype=dtype,
         device=device,
     )
@@ -136,23 +171,34 @@ def test_fused_rope_thd(
     t.requires_grad = True

     rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
-    emb = rotary_pos_emb(cu_seqlens[-1])
-
-    # unfused
-    output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
-    loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
-    t.grad = None
-
-    # fused
-    output_fused = apply_rotary_pos_emb(
-        t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
-    )
-    loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
-    t.grad = None
-
-    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
-    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
+    emb = rotary_pos_emb(cu_seqlens_padded[-1])
+
+    for cp_rank in range(cp_size):
+        # unfused
+        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
+        # for more accurate comparison
+        output_unfused = apply_rotary_pos_emb_thd(
+            t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
+        ).to(dtype)
+        loss_unfused = loss_func(output_unfused)
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+        t.grad = None
+
+        # fused
+        output_fused = apply_rotary_pos_emb(
+            t,
+            emb,
+            fused=True,
+            tensor_format="thd",
+            cu_seqlens=cu_seqlens_padded,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        )
+        loss_fused = loss_func(output_fused)
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+        t.grad = None
+
+        torch.testing.assert_close(output_fused, output_unfused)
+        torch.testing.assert_close(grad_fused, grad_unfused)
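
For intuition on the index arithmetic in `_get_thd_freqs_on_this_cp_rank`: each padded sequence is split into 2 * cp_size chunks, and every rank keeps one chunk from the front and the mirrored chunk from the back, so its RoPE frequencies are gathered from those two global position ranges. A small standalone illustration (example values only, not from the commit):

```python
import torch

# Stand-in for the frequency table of one padded sequence of 8 positions.
freqs = torch.arange(8)
cp_size = 2
local_len = freqs.size(0) // cp_size  # tokens held per rank: 4

for cp_rank in range(cp_size):
    cp_seg = local_len // 2            # half of this rank's shard
    full_seqlen = cp_size * local_len  # total padded sequence length
    rank_freqs = torch.cat(
        [
            # front chunk owned by this rank
            freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
            # mirrored back chunk owned by this rank
            freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
        ]
    )
    print(cp_rank, rank_freqs.tolist())

# Prints:
# 0 [0, 1, 6, 7]
# 1 [2, 3, 4, 5]
```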