|
| 1 | +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# See LICENSE for license information. |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | +from typing import Callable, Dict, Tuple, Union |
| 7 | +from transformer_engine.pytorch.attention import ( |
| 8 | + RotaryPositionEmbedding, |
| 9 | + apply_rotary_pos_emb, |
| 10 | +) |
| 11 | + |
| 12 | + |
| 13 | +def apply_rotary_pos_emb_thd( |
| 14 | + t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor |
| 15 | +) -> torch.Tensor: |
| 16 | + """A baseline implementation of applying RoPE for `thd` format. |
| 17 | +
|
| 18 | + Args: |
| 19 | + t (Tensor): Input tensor T is of shape [t, h, d] |
| 20 | + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, |
| 21 | + with shape [b + 1] and dtype torch.int32. |
| 22 | + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] |
| 23 | +
|
| 24 | + Returns: |
| 25 | + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. |
| 26 | + """ |
| 27 | + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() |
| 28 | + return torch.cat( |
| 29 | + [ |
| 30 | + apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) |
| 31 | + for x in torch.split(t, seqlens) |
| 32 | + ] |
| 33 | + ).squeeze(1) |
| 34 | + |
| 35 | + |
| 36 | +def get_tol(dtype: torch.dtype) -> Dict: |
| 37 | + if dtype == torch.bfloat16: |
| 38 | + return dict(atol=1e-2, rtol=1e-2) |
| 39 | + elif dtype == torch.float16: |
| 40 | + return dict(atol=1e-3, rtol=1e-3) |
| 41 | + return dict(atol=1e-5, rtol=1.3e-6) |
| 42 | + |
| 43 | + |
| 44 | +# Gradient is a broadcasted scalar |
| 45 | +def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: |
| 46 | + return output.sum() * 2 |
| 47 | + |
| 48 | +# Gradient is a full tensor |
| 49 | +def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: |
| 50 | + t = torch.ones_like(output) |
| 51 | + return torch.sum(output * t) |
| 52 | + |
| 53 | + |
| 54 | +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) |
| 55 | +@pytest.mark.parametrize("seq_length", [2048, 4096]) |
| 56 | +@pytest.mark.parametrize("hidden_size", [128, 256]) |
| 57 | +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) |
| 58 | +@pytest.mark.parametrize("margin", [0, 10]) |
| 59 | +@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) |
| 60 | +@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) |
| 61 | +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) |
| 62 | +def test_fused_rope( |
| 63 | + dtype: torch.dtype, |
| 64 | + seq_length: int, |
| 65 | + hidden_size: int, |
| 66 | + rotary_percent: float, |
| 67 | + margin: int, |
| 68 | + transpose: Union[Tuple, None], |
| 69 | + tensor_format: str, |
| 70 | + loss_func: Callable, |
| 71 | +) -> None: |
| 72 | + device = torch.device("cuda:0") |
| 73 | + batch_size, head_num = 2, 64 |
| 74 | + t = torch.rand( |
| 75 | + (seq_length - margin, batch_size, head_num, hidden_size), |
| 76 | + dtype=dtype, |
| 77 | + device=device, |
| 78 | + ) |
| 79 | + if tensor_format == "bshd": |
| 80 | + t = t.transpose(0, 1).contiguous() |
| 81 | + if transpose: |
| 82 | + t = t.transpose(*transpose).contiguous().transpose(*transpose) |
| 83 | + t.requires_grad = True |
| 84 | + |
| 85 | + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) |
| 86 | + emb = rotary_pos_emb(seq_length) |
| 87 | + |
| 88 | + # unfused |
| 89 | + output_unfused = apply_rotary_pos_emb( |
| 90 | + t, emb, tensor_format=tensor_format, fused=False |
| 91 | + ) |
| 92 | + loss_unfused = loss_func(output_unfused) |
| 93 | + loss_unfused.backward() |
| 94 | + grad_unfused = t.grad.detach().clone() |
| 95 | + t.grad = None |
| 96 | + |
| 97 | + # fused |
| 98 | + output_fused = apply_rotary_pos_emb( |
| 99 | + t, |
| 100 | + emb, |
| 101 | + tensor_format=tensor_format, |
| 102 | + fused=True, |
| 103 | + ) |
| 104 | + loss_fused = loss_func(output_fused) |
| 105 | + loss_fused.backward() |
| 106 | + grad_fused = t.grad.detach().clone() |
| 107 | + t.grad = None |
| 108 | + |
| 109 | + torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype)) |
| 110 | + torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype)) |
| 111 | + assert output_fused.is_contiguous() |
| 112 | + |
| 113 | + |
| 114 | +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) |
| 115 | +@pytest.mark.parametrize("hidden_size", [128, 256]) |
| 116 | +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) |
| 117 | +@pytest.mark.parametrize("transpose", [None, (1, 2)]) |
| 118 | +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) |
| 119 | +def test_fused_rope_thd( |
| 120 | + dtype: torch.dtype, |
| 121 | + hidden_size: int, |
| 122 | + rotary_percent: float, |
| 123 | + transpose: Union[Tuple, None], |
| 124 | + loss_func: Callable, |
| 125 | +) -> None: |
| 126 | + device = torch.device("cuda:0") |
| 127 | + batch_size, head_num = 2, 64 |
| 128 | + cu_seqlens = torch.tensor( |
| 129 | + [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048], |
| 130 | + dtype=torch.int32, |
| 131 | + device=device, |
| 132 | + ) |
| 133 | + t = torch.rand( |
| 134 | + (cu_seqlens[-1], head_num, hidden_size), |
| 135 | + dtype=dtype, |
| 136 | + device=device, |
| 137 | + ) |
| 138 | + if transpose: |
| 139 | + t = t.transpose(*transpose).contiguous().transpose(*transpose) |
| 140 | + t.requires_grad = True |
| 141 | + |
| 142 | + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) |
| 143 | + emb = rotary_pos_emb(cu_seqlens[-1]) |
| 144 | + |
| 145 | + # unfused |
| 146 | + output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb) |
| 147 | + loss_unfused = loss_func(output_unfused) |
| 148 | + loss_unfused.backward() |
| 149 | + grad_unfused = t.grad.detach().clone() |
| 150 | + t.grad = None |
| 151 | + |
| 152 | + # fused |
| 153 | + output_fused = apply_rotary_pos_emb( |
| 154 | + t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens |
| 155 | + ) |
| 156 | + loss_fused = loss_func(output_fused) |
| 157 | + loss_fused.backward() |
| 158 | + grad_fused = t.grad.detach().clone() |
| 159 | + t.grad = None |
| 160 | + |
| 161 | + torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype)) |
| 162 | + torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype)) |
0 commit comments