From 55dcbb4b02f560d52dc1215a9de348b37487ee3d Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Sat, 12 Oct 2024 13:18:03 +0800
Subject: [PATCH] [PyTorch] Let Fused RoPE support CP with THD format (#1238)

* Let Fused RoPE support THD with CP

Signed-off-by: Xin Yao

* add comment

Signed-off-by: Xin Yao

---------

Signed-off-by: Xin Yao
Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
---
 tests/pytorch/test_fused_rope.py              | 120 +++++++----
 .../common/fused_rope/fused_rope.cu           | 193 +++++++++++-------
 .../include/transformer_engine/fused_rope.h   |  24 ++-
 transformer_engine/pytorch/attention.py       |  21 +-
 transformer_engine/pytorch/csrc/extensions.h  |   4 +-
 .../pytorch/csrc/extensions/apply_rope.cu     |  13 +-
 6 files changed, 237 insertions(+), 138 deletions(-)

diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index d6ba66cbbc..81c4973756 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -1,17 +1,38 @@
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+import math
 import pytest
 import torch
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Tuple, Union
 from transformer_engine.pytorch.attention import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
 )
 
 
+def _get_thd_freqs_on_this_cp_rank(
+    cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
+) -> torch.Tensor:
+    if cp_size > 1:
+        cp_seg = x.size(0) // 2
+        full_seqlen = cp_size * x.size(0)
+        return torch.cat(
+            [
+                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
+                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
+            ]
+        )
+    else:
+        return freqs[: x.size(0)]
+
+
 def apply_rotary_pos_emb_thd(
-    t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
+    t: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    freqs: torch.Tensor,
+    cp_size: int = 1,
+    cp_rank: int = 0,
 ) -> torch.Tensor:
     """A baseline implementation of applying RoPE for `thd` format.
 
@@ -24,20 +45,18 @@ def apply_rotary_pos_emb_thd(
     Returns:
         Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
""" + cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return torch.cat( - [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)] + [ + apply_rotary_pos_emb( + x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) + ) + for x in torch.split(t, seqlens) + ] ).squeeze(1) -def get_tol(dtype: torch.dtype) -> Dict: - if dtype == torch.bfloat16: - return dict(atol=1e-2, rtol=1e-2) - elif dtype == torch.float16: - return dict(atol=1e-3, rtol=1e-3) - return dict(atol=1e-5, rtol=1.3e-6) - - # Gradient is a broadcasted scalar def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: return output.sum() * 2 @@ -84,7 +103,11 @@ def test_fused_rope( emb = rotary_pos_emb(seq_length) # unfused - output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False) + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + output_unfused = apply_rotary_pos_emb( + t.float(), emb, tensor_format=tensor_format, fused=False + ).to(dtype) loss_unfused = loss_func(output_unfused) loss_unfused.backward() grad_unfused = t.grad.detach().clone() @@ -102,8 +125,8 @@ def test_fused_rope( grad_fused = t.grad.detach().clone() t.grad = None - torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype)) - torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype)) + torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() @@ -112,22 +135,34 @@ def test_fused_rope( @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2, 3]) def test_fused_rope_thd( dtype: torch.dtype, hidden_size: int, rotary_percent: float, transpose: Union[Tuple, None], loss_func: Callable, + cp_size: int, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 - cu_seqlens = torch.tensor( - [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048], + cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] + if cp_size > 1: + cu_seqlens_padded = [0] + for i in range(1, len(cu_seqlens)): + cu_seqlens_padded.append( + cu_seqlens_padded[i - 1] + + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2) + ) + else: + cu_seqlens_padded = cu_seqlens + cu_seqlens_padded = torch.tensor( + cu_seqlens_padded, dtype=torch.int32, device=device, ) t = torch.rand( - (cu_seqlens[-1], head_num, hidden_size), + (cu_seqlens_padded[-1] // cp_size, head_num, hidden_size), dtype=dtype, device=device, ) @@ -136,23 +171,34 @@ def test_fused_rope_thd( t.requires_grad = True rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) - emb = rotary_pos_emb(cu_seqlens[-1]) - - # unfused - output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - output_fused = apply_rotary_pos_emb( - t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None - - torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype)) - torch.testing.assert_close(grad_fused, 
+    emb = rotary_pos_emb(cu_seqlens_padded[-1])
+
+    for cp_rank in range(cp_size):
+        # unfused
+        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
+        # for more accurate comparison
+        output_unfused = apply_rotary_pos_emb_thd(
+            t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
+        ).to(dtype)
+        loss_unfused = loss_func(output_unfused)
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+        t.grad = None
+
+        # fused
+        output_fused = apply_rotary_pos_emb(
+            t,
+            emb,
+            fused=True,
+            tensor_format="thd",
+            cu_seqlens=cu_seqlens_padded,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        )
+        loss_fused = loss_func(output_fused)
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+        t.grad = None
+
+        torch.testing.assert_close(output_fused, output_unfused)
+        torch.testing.assert_close(grad_fused, grad_unfused)
diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu
index e7cf940a57..26f104d3ed 100644
--- a/transformer_engine/common/fused_rope/fused_rope.cu
+++ b/transformer_engine/common/fused_rope/fused_rope.cu
@@ -4,6 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
 
+#include <assert.h>
 #include <cuda_runtime.h>
 #include <transformer_engine/fused_rope.h>
 
@@ -15,11 +16,10 @@ namespace transformer_engine {
 
 template <typename scalar_t>
 __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
-                                         const int offset_block, const int offset_block_dst,
-                                         const int h, const int d, const int d2, const int stride_h,
-                                         const int stride_d, const int o_stride_h,
-                                         const int o_stride_d) {
-  int s_id = blockIdx.x;
+                                         const int s_id, const int offset_block,
+                                         const int offset_block_dst, const int h, const int d,
+                                         const int d2, const int stride_h, const int stride_d,
+                                         const int o_stride_h, const int o_stride_d) {
 #pragma unroll
   for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
     float v_cos, v_sin;
@@ -52,11 +52,10 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
 
 template <typename scalar_t>
 __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
-                                          const int offset_block, const int offset_block_dst,
-                                          const int h, const int d, const int d2,
-                                          const int stride_h, const int stride_d,
+                                          const int s_id, const int offset_block,
+                                          const int offset_block_dst, const int h, const int d,
+                                          const int d2, const int stride_h, const int stride_d,
                                           const int o_stride_h, const int o_stride_d) {
-  int s_id = blockIdx.x;
 #pragma unroll
   for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
     float v_cos = cosf(freqs[s_id * d2 + d_id]);
@@ -97,8 +96,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freq
   int s_id = blockIdx.x, b_id = blockIdx.y;
   int offset_block = s_id * stride_s + b_id * stride_b;
   int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
-  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
-                           stride_d, o_stride_h, o_stride_d);
+  fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
+                           stride_h, stride_d, o_stride_h, o_stride_d);
 }
 
 template <typename scalar_t>
@@ -111,40 +110,72 @@ __global__ void fused_rope_backward_kernel(const scalar_t *src, const float *fre
   int s_id = blockIdx.x, b_id = blockIdx.y;
   int offset_block = s_id * stride_s + b_id * stride_b;
   int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
-  fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
-                            stride_d, o_stride_h, o_stride_d);
+  fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
+                            stride_h, stride_d, o_stride_h, o_stride_d);
 }
 
 template <typename scalar_t>
 __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
-                                              const float *freqs, scalar_t *dst, const int h,
-                                              const int d, const int d2, const int stride_t,
-                                              const int stride_h, const int stride_d,
-                                              const int o_stride_t, const int o_stride_h,
-                                              const int o_stride_d) {
+                                              const float *freqs, scalar_t *dst, const int cp_size,
+                                              const int cp_rank, const int h, const int d,
+                                              const int d2, const int stride_t, const int stride_h,
+                                              const int stride_d, const int o_stride_t,
+                                              const int o_stride_h, const int o_stride_d) {
   int s_id = blockIdx.x, b_id = blockIdx.y;
-  int t_id = s_id + cu_seqlens[b_id];
-  if (t_id >= cu_seqlens[b_id + 1]) return;
+  int start = cu_seqlens[b_id] / cp_size;
+  int end = cu_seqlens[b_id + 1] / cp_size;
+  int t_id = s_id + start;
+  if (t_id >= end) return;
   int offset_block = t_id * stride_t;
   int offset_block_dst = t_id * o_stride_t;
-  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
-                           stride_d, o_stride_h, o_stride_d);
+
+  int s_id_for_freqs;
+  if (cp_size > 1) {
+    int cur_seqlens = end - start;
+    assert(cur_seqlens % 2 == 0);
+    if (s_id < cur_seqlens / 2) {
+      s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
+    } else {
+      s_id_for_freqs =
+          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
+    }
+  } else {
+    s_id_for_freqs = s_id;
+  }
+  fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
+                           d2, stride_h, stride_d, o_stride_h, o_stride_d);
 }
 
 template <typename scalar_t>
 __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens,
-                                               const float *freqs, scalar_t *dst, const int h,
-                                               const int d, const int d2, const int stride_t,
-                                               const int stride_h, const int stride_d,
-                                               const int o_stride_t, const int o_stride_h,
-                                               const int o_stride_d) {
+                                               const float *freqs, scalar_t *dst, const int cp_size,
+                                               const int cp_rank, const int h, const int d,
+                                               const int d2, const int stride_t, const int stride_h,
+                                               const int stride_d, const int o_stride_t,
+                                               const int o_stride_h, const int o_stride_d) {
   int s_id = blockIdx.x, b_id = blockIdx.y;
-  int t_id = s_id + cu_seqlens[b_id];
-  if (t_id >= cu_seqlens[b_id + 1]) return;
+  int start = cu_seqlens[b_id] / cp_size;
+  int end = cu_seqlens[b_id + 1] / cp_size;
+  int t_id = s_id + start;
+  if (t_id >= end) return;
   int offset_block = t_id * stride_t;
   int offset_block_dst = t_id * o_stride_t;
-  fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
-                            stride_d, o_stride_h, o_stride_d);
+
+  int s_id_for_freqs;
+  if (cp_size > 1) {
+    int cur_seqlens = end - start;
+    assert(cur_seqlens % 2 == 0);
+    if (s_id < cur_seqlens / 2) {
+      s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
+    } else {
+      s_id_for_freqs =
+          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
+    }
+  } else {
+    s_id_for_freqs = s_id;
+  }
+  fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
+                            d2, stride_h, stride_d, o_stride_h, o_stride_d);
 }
 
 template <typename scalar_t>
@@ -182,35 +213,37 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const float *fre
 template <typename scalar_t>
 void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
-                                     const float *freqs, scalar_t *output, const int max_s,
-                                     const int b, const int h, const int d, const int d2,
-                                     const int stride_t, const int stride_h, const int stride_d,
-                                     const int o_stride_t, const int o_stride_h,
-                                     const int o_stride_d, cudaStream_t stream) {
+                                     const float *freqs, scalar_t *output, const int cp_size,
+                                     const int cp_rank, const int max_s, const int b, const int h,
+                                     const int d, const int d2, const int stride_t,
+                                     const int stride_h, const int stride_d, const int o_stride_t,
+                                     const int o_stride_h, const int o_stride_d,
+                                     cudaStream_t stream) {
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(max_s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
-  fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(input, cu_seqlens, freqs, output, h,
-                                                                d, d2, stride_t, stride_h, stride_d,
-                                                                o_stride_t, o_stride_h, o_stride_d);
+  fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
+      input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
+      o_stride_t, o_stride_h, o_stride_d);
 
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 template <typename scalar_t>
 void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
-                                      const float *freqs, scalar_t *input_grads, const int max_s,
-                                      const int b, const int h, const int d, const int d2,
-                                      const int stride_t, const int stride_h, const int stride_d,
-                                      const int o_stride_t, const int o_stride_h,
-                                      const int o_stride_d, cudaStream_t stream) {
+                                      const float *freqs, scalar_t *input_grads, const int cp_size,
+                                      const int cp_rank, const int max_s, const int b, const int h,
+                                      const int d, const int d2, const int stride_t,
+                                      const int stride_h, const int stride_d, const int o_stride_t,
+                                      const int o_stride_h, const int o_stride_d,
+                                      cudaStream_t stream) {
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(max_s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
   fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
-      output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, stride_h, stride_d,
-      o_stride_t, o_stride_h, o_stride_d);
+      output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h,
+      stride_d, o_stride_t, o_stride_h, o_stride_d);
 
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
@@ -243,33 +276,34 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor
 }
 
 void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
-                            Tensor *output, const int max_s, const int b, const int h, const int d,
-                            const int d2, const int stride_t, const int stride_h,
-                            const int stride_d, const int o_stride_t, const int o_stride_h,
-                            const int o_stride_d, cudaStream_t stream) {
+                            Tensor *output, const int cp_size, const int cp_rank, const int max_s,
+                            const int b, const int h, const int d, const int d2, const int stride_t,
+                            const int stride_h, const int stride_d, const int o_stride_t,
+                            const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, scalar_t,
      fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
                                       reinterpret_cast<const int *>(cu_seqlens.data.dptr),
                                       reinterpret_cast<const float *>(freqs.data.dptr),
-                                      reinterpret_cast<scalar_t *>(output->data.dptr), max_s, b, h,
-                                      d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
-                                      o_stride_d, stream););
+                                      reinterpret_cast<scalar_t *>(output->data.dptr), cp_size,
+                                      cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d,
+                                      o_stride_t, o_stride_h, o_stride_d, stream););
 }
 
 void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
-                             const Tensor &freqs, Tensor *input_grads, const int max_s, const int b,
-                             const int h, const int d, const int d2, const int stride_t,
-                             const int stride_h, const int stride_d, const int o_stride_t,
-                             const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
+                             const Tensor &freqs, Tensor *input_grads, const int cp_size,
+                             const int cp_rank, const int max_s, const int b, const int h,
+                             const int d, const int d2, const int stride_t, const int stride_h,
+                             const int stride_d, const int o_stride_t, const int o_stride_h,
+                             const int o_stride_d, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      output_grads.data.dtype, scalar_t,
      fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
                                        reinterpret_cast<const int *>(cu_seqlens.data.dptr),
                                        reinterpret_cast<const float *>(freqs.data.dptr),
-                                       reinterpret_cast<scalar_t *>(input_grads->data.dptr), max_s,
-                                       b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
-                                       o_stride_h, o_stride_d, stream););
+                                       reinterpret_cast<scalar_t *>(input_grads->data.dptr),
+                                       cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h,
+                                       stride_d, o_stride_t, o_stride_h, o_stride_d, stream););
 }
 
 }  // end namespace transformer_engine
 
@@ -302,30 +336,31 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr
 
 void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                                 const NVTETensor freqs, NVTETensor output, const int max_s,
-                                 const int b, const int h, const int d, const int d2,
-                                 const int stride_t, const int stride_h, const int stride_d,
-                                 const int o_stride_t, const int o_stride_h, const int o_stride_d,
-                                 cudaStream_t stream) {
+                                 const NVTETensor freqs, NVTETensor output, const int cp_size,
+                                 const int cp_rank, const int max_s, const int b, const int h,
+                                 const int d, const int d2, const int stride_t, const int stride_h,
+                                 const int stride_d, const int o_stride_t, const int o_stride_h,
+                                 const int o_stride_d, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_thd_forward);
   using namespace transformer_engine;
-  fused_rope_thd_forward(
-      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
-      *reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output), max_s, b, h, d,
-      d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
+  fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
+                         *reinterpret_cast<const Tensor *>(cu_seqlens),
+                         *reinterpret_cast<const Tensor *>(freqs),
+                         reinterpret_cast<Tensor *>(output), cp_size, cp_rank, max_s, b, h, d, d2,
+                         stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
 }
 
 void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
-                                  const NVTETensor freqs, NVTETensor input_grads, const int max_s,
-                                  const int b, const int h, const int d, const int d2,
-                                  const int stride_t, const int stride_h, const int stride_d,
-                                  const int o_stride_t, const int o_stride_h, const int o_stride_d,
-                                  cudaStream_t stream) {
+                                  const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
+                                  const int cp_rank, const int max_s, const int b, const int h,
+                                  const int d, const int d2, const int stride_t, const int stride_h,
+                                  const int stride_d, const int o_stride_t, const int o_stride_h,
+                                  const int o_stride_d, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_thd_backward);
   using namespace transformer_engine;
-  fused_rope_thd_backward(*reinterpret_cast<const Tensor *>(output_grads),
-                          *reinterpret_cast<const Tensor *>(cu_seqlens),
-                          *reinterpret_cast<const Tensor *>(freqs),
-                          reinterpret_cast<Tensor *>(input_grads), max_s, b, h, d, d2, stride_t,
-                          stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
+  fused_rope_thd_backward(
+      *reinterpret_cast<const Tensor *>(output_grads),
+      *reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(freqs),
+      reinterpret_cast<Tensor *>(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
+      stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
 }
diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h
index b92de88eca..b7b9b93881 100644
--- a/transformer_engine/common/include/transformer_engine/fused_rope.h
+++ b/transformer_engine/common/include/transformer_engine/fused_rope.h
@@ -72,6 +72,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr
  *  \param[in]     cu_seqlens      The cumulative sum of sequence lengths tensor.
  *  \param[in]     freqs           The freqs tensor.
  *  \param[out]    output          Output tensor.
+ *  \param[in]     cp_size         Context parallel world size.
+ *  \param[in]     cp_rank         Context parallel rank.
  *  \param[in]     max_s           Max sequence length.
  *  \param[in]     b               Batch size.
  *  \param[in]     h               Length of the h dimension of input.
@@ -86,11 +88,11 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr
  *  \param[in]     stream          CUDA stream used for the operation.
  */
 void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                                 const NVTETensor freqs, NVTETensor output, const int max_s,
-                                 const int b, const int h, const int d, const int d2,
-                                 const int stride_t, const int stride_h, const int stride_d,
-                                 const int o_stride_t, const int o_stride_h, const int o_stride_d,
-                                 cudaStream_t stream);
+                                 const NVTETensor freqs, NVTETensor output, const int cp_size,
+                                 const int cp_rank, const int max_s, const int b, const int h,
+                                 const int d, const int d2, const int stride_t, const int stride_h,
+                                 const int stride_d, const int o_stride_t, const int o_stride_h,
+                                 const int o_stride_d, cudaStream_t stream);
 
 /*! \brief Compute the backward of the fused rope in thd format.
 *
@@ -98,6 +100,8 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq
  *  \param[in]     cu_seqlens      The cumulative sum of sequence lengths tensor.
  *  \param[in]     freqs           The freqs tensor.
  *  \param[out]    input_grads     Input gradient to calculate.
+ *  \param[in]     cp_size         Context parallel world size.
+ *  \param[in]     cp_rank         Context parallel rank.
  *  \param[in]     max_s           Max sequence length.
  *  \param[in]     b               Batch size.
  *  \param[in]     h               Length of the h dimension of output_grads.
@@ -112,11 +116,11 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq
  *  \param[in]     stream          CUDA stream used for the operation.
  */
 void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
-                                  const NVTETensor freqs, NVTETensor input_grads, const int max_s,
-                                  const int b, const int h, const int d, const int d2,
-                                  const int stride_t, const int stride_h, const int stride_d,
-                                  const int o_stride_t, const int o_stride_h, const int o_stride_d,
-                                  cudaStream_t stream);
+                                  const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
+                                  const int cp_rank, const int max_s, const int b, const int h,
+                                  const int d, const int d2, const int stride_t, const int stride_h,
+                                  const int stride_d, const int o_stride_t, const int o_stride_h,
+                                  const int o_stride_d, cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 68f645a7f5..d96ff76cd5 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -4305,6 +4305,8 @@ def forward(
         freqs: torch.Tensor,
         tensor_format: str = "sbhd",
         cu_seqlens: Union[torch.Tensor, None] = None,
+        cp_size: int = 1,
+        cp_rank: int = 0,
     ) -> torch.Tensor:
         if freqs.dtype != torch.float32:
             freqs = freqs.float()
@@ -4313,11 +4315,13 @@ def forward(
         elif tensor_format == "bshd":
             output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
         elif tensor_format == "thd":
-            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
+            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
         else:
             raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
         ctx.save_for_backward(freqs, cu_seqlens)
         ctx.tensor_format = tensor_format
+        ctx.cp_size = cp_size
+        ctx.cp_rank = cp_rank
 
         return output
 
@@ -4331,11 +4335,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 grad_output.transpose(0, 1), freqs, True
             ).transpose(0, 1)
         elif ctx.tensor_format == "thd":
-            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
+            grad_input = tex.fused_rope_thd_backward(
+                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
+            )
         else:
             raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
 
-        return grad_input, None, None, None, None
+        return grad_input, None, None, None, None, None
 
 
 def _rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -4353,6 +4359,8 @@ def apply_rotary_pos_emb(
     tensor_format: str = "sbhd",
     fused: bool = False,
     cu_seqlens: Union[torch.Tensor, None] = None,
+    cp_size: int = 1,
+    cp_rank: int = 0,
 ) -> torch.Tensor:
     """
     Apply rotary positional embedding tensor to the input tensor.
@@ -4373,12 +4381,17 @@ def apply_rotary_pos_emb(
     cu_seqlens: torch.Tensor, default = None.
         Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
         dtype torch.int32. Only valid when `tensor_format` is 'thd'.
+        Should be `cu_seqlens_padded` when cp_size > 1.
+    cp_size: int, default = 1.
+        Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
+    cp_rank: int, default = 0.
+        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
     """
     if fused:
         assert (
             tensor_format != "thd" or cu_seqlens is not None
         ), "cu_seqlens must not be None when tensor_format is 'thd'."
-        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)
+        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
 
     assert tensor_format in ("sbhd", "bshd"), (
         "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index a0ebf6faa7..c30e583178 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -412,10 +412,10 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
                                const bool transpose_output_memory);
 
 at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs);
+                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
 
 at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs);
+                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
 
 /***************************************************************************************************
  * Miscellaneous
diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu
index c58ba91d5e..c0cd2e9920 100644
--- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu
+++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu
@@ -121,7 +121,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
 }
 
 at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs) {
+                                  const at::Tensor &freqs, const int cp_size, const int cp_rank) {
   using namespace transformer_engine;
   TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
   TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
@@ -165,14 +165,15 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
   auto output_cu = makeTransformerEngineTensor(output);
 
   nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                              output_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d,
-                              o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
+                              output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
+                              stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
+                              at::cuda::getCurrentCUDAStream());
 
   return output;
 }
 
 at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs) {
+                                   const at::Tensor &freqs, const int cp_size, const int cp_rank) {
   using namespace transformer_engine;
   TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
   TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
@@ -214,8 +215,8 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
   auto input_grads_cu = makeTransformerEngineTensor(input_grads);
 
   nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                               input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h,
-                               stride_d, o_stride_t, o_stride_h, o_stride_d,
+                               input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
+                               stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
                                at::cuda::getCurrentCUDAStream());
 
   return input_grads;
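
Editor's note: the standalone sketch below is not part of the patch; it restates in plain
PyTorch the freq-index mapping that the kernels and test helper above implement for
load-balanced context parallelism. With cp_size ranks, each padded sequence is split into
2 * cp_size equal chunks and rank r keeps chunks r and (2 * cp_size - 1 - r), so the first
half of a rank's local tokens uses the freqs slice [r * seg, (r + 1) * seg) and the second
half uses the mirrored tail slice. The helper names (pad_cu_seqlens, thd_freqs_for_cp_rank,
seg) are illustrative only and do not exist in Transformer Engine.

import math
import torch

def pad_cu_seqlens(cu_seqlens: list, cp_size: int) -> list:
    # Pad every sequence length up to a multiple of 2 * cp_size, mirroring the
    # cu_seqlens_padded computation in tests/pytorch/test_fused_rope.py above.
    out = [0]
    for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        out.append(out[-1] + math.ceil((b - a) / (2 * cp_size)) * (2 * cp_size))
    return out

def thd_freqs_for_cp_rank(freqs: torch.Tensor, seqlen_padded: int,
                          cp_size: int, cp_rank: int) -> torch.Tensor:
    # Freq rows used by one rank's local tokens of a single padded sequence.
    # Mirrors _get_thd_freqs_on_this_cp_rank (test) and s_id_for_freqs (kernel).
    if cp_size == 1:
        return freqs[:seqlen_padded]
    seg = seqlen_padded // (2 * cp_size)  # seqlen_padded is a multiple of 2 * cp_size
    head = freqs[cp_rank * seg : (cp_rank + 1) * seg]
    tail = freqs[seqlen_padded - (cp_rank + 1) * seg : seqlen_padded - cp_rank * seg]
    return torch.cat([head, tail])

# Example: seqlen_padded = 12, cp_size = 3 gives seg = 2, so
#   rank 0 holds positions [0, 1] and [10, 11]
#   rank 1 holds positions [2, 3] and [8, 9]
#   rank 2 holds positions [4, 5] and [6, 7]
# which is exactly the zigzag sharding the CP attention path expects.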