Commit 6c1a8bb

yaox12, timmoon10, and ptrendx authored
[Common][PyTorch] Fused apply_rotorary_pos_emb (#517)
* fused apply rope
* Apply suggestions from code review
* resolve comments
* make rotary_percent optional
* fix ci
* fix test
* add rope test to qa
* fix linting
* sync apex: add transpose_output_memory
* small fix
* sync apex: fuse sin/cos
* sync apex: fused rope for thd format
* fix lint
* Fix license headers
* add support for bshd format
* support different seq length
* update
* update copyright
* remove transpose_output_memory
* Make outputs contiguous in SBHD case

---------

Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
1 parent b957aa4 commit 6c1a8bb
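The tests added below exercise the new fused path through the existing apply_rotary_pos_emb API. A minimal sketch of how that call looks for the dense "sbhd" layout, assuming only what the tests in this commit show (the sizes, dtype, and device here are illustrative, not taken from the commit):

    import torch
    from transformer_engine.pytorch.attention import (
        RotaryPositionEmbedding,
        apply_rotary_pos_emb,
    )

    # Illustrative sizes; the tests below use seq_length in {2048, 4096},
    # batch_size=2, head_num=64, and hidden_size in {128, 256}.
    seq_length, batch_size, head_num, hidden_size = 2048, 2, 64, 128

    # Query/key-style activations in "sbhd" layout: [s, b, h, d].
    t = torch.rand(
        (seq_length, batch_size, head_num, hidden_size),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )

    # Build the rotary frequencies. rotary_percent is optional (see the commit
    # message); the tests parametrize it over 0.5 and 1.0.
    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent=1.0)
    emb = rotary_pos_emb(seq_length)

    # fused=False keeps the existing unfused implementation; fused=True routes
    # to the fused kernel (fused_rope/fused_rope.cu) added by this commit.
    out = apply_rotary_pos_emb(t, emb, tensor_format="sbhd", fused=True)
    out.sum().backward()  # the fused path is also tested through backward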

File tree

10 files changed: +1038 −28 lines

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
 PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
+pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py

tests/pytorch/test_fused_rope.py

Lines changed: 162 additions & 0 deletions
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from typing import Callable, Dict, Tuple, Union
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)


def apply_rotary_pos_emb_thd(
    t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for `thd` format.

    Args:
        t (Tensor): Input tensor T is of shape [t, h, d]
        cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`,
            with shape [b + 1] and dtype torch.int32.
        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return torch.cat(
        [
            apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
            for x in torch.split(t, seqlens)
        ]
    ).squeeze(1)


def get_tol(dtype: torch.dtype) -> Dict:
    if dtype == torch.bfloat16:
        return dict(atol=1e-2, rtol=1e-2)
    elif dtype == torch.float16:
        return dict(atol=1e-3, rtol=1e-3)
    return dict(atol=1e-5, rtol=1.3e-6)


# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2


# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    t = torch.ones_like(output)
    return torch.sum(output * t)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(seq_length)

    # unfused
    output_unfused = apply_rotary_pos_emb(
        t, emb, tensor_format=tensor_format, fused=False
    )
    loss_unfused = loss_func(output_unfused)
    loss_unfused.backward()
    grad_unfused = t.grad.detach().clone()
    t.grad = None

    # fused
    output_fused = apply_rotary_pos_emb(
        t,
        emb,
        tensor_format=tensor_format,
        fused=True,
    )
    loss_fused = loss_func(output_fused)
    loss_fused.backward()
    grad_fused = t.grad.detach().clone()
    t.grad = None

    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
    assert output_fused.is_contiguous()


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    transpose: Union[Tuple, None],
    loss_func: Callable,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    cu_seqlens = torch.tensor(
        [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
        dtype=torch.int32,
        device=device,
    )
    t = torch.rand(
        (cu_seqlens[-1], head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if transpose:
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(cu_seqlens[-1])

    # unfused
    output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
    loss_unfused = loss_func(output_unfused)
    loss_unfused.backward()
    grad_unfused = t.grad.detach().clone()
    t.grad = None

    # fused
    output_fused = apply_rotary_pos_emb(
        t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
    )
    loss_fused = loss_func(output_fused)
    loss_fused.backward()
    grad_fused = t.grad.detach().clone()
    t.grad = None

    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
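
For the variable-length "thd" path tested above, all sequences are packed along the token dimension and their boundaries are described by cu_seqlens. A minimal sketch of that setup, assuming only the API shown in test_fused_rope_thd (the per-sequence lengths, sizes, and dtype here are illustrative, not from the commit):

    import torch
    from transformer_engine.pytorch.attention import (
        RotaryPositionEmbedding,
        apply_rotary_pos_emb,
    )

    head_num, hidden_size = 64, 128          # illustrative sizes
    seqlens = torch.tensor([400, 142, 169])  # illustrative per-sequence token counts

    # cu_seqlens has shape [b + 1] and dtype torch.int32: a leading zero followed
    # by the running total of sequence lengths, as the docstring above describes.
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int64), torch.cumsum(seqlens, dim=0)]
    ).to(dtype=torch.int32, device="cuda")
    total_tokens = int(cu_seqlens[-1])

    # Packed activations in "thd" layout: [total_tokens, h, d].
    t = torch.rand(
        (total_tokens, head_num, hidden_size),
        dtype=torch.float16,
        device="cuda",
        requires_grad=True,
    )

    # The frequencies must cover the longest sequence; like the test, this uses
    # the total token count as a simple upper bound. rotary_percent is left at
    # its default since the commit makes it optional.
    rotary_pos_emb = RotaryPositionEmbedding(hidden_size)
    emb = rotary_pos_emb(total_tokens)

    out = apply_rotary_pos_emb(
        t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
    )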

transformer_engine/common/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,8 @@ list(APPEND transformer_engine_SOURCES
     fused_softmax/scaled_masked_softmax.cu
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_softmax/scaled_masked_softmax.cu
-    fused_softmax/scaled_upper_triang_masked_softmax.cu)
+    fused_softmax/scaled_upper_triang_masked_softmax.cu
+    fused_rope/fused_rope.cu)
 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")

0 commit comments
