From b2a2fc2df4490b5107cf3dddad0f8cf7b80a63a8 Mon Sep 17 00:00:00 2001
From: xiaoyao0115 <1804647152@qq.com>
Date: Tue, 24 Dec 2024 16:13:15 +0000
Subject: [PATCH] minor fixes based on review comments

Signed-off-by: xiaoyao0115 <1804647152@qq.com>
---
 .../common/fused_attn/thd_utils.h             |  9 ++--
 transformer_engine/pytorch/attention.py       | 53 +++++--------
 transformer_engine/pytorch/csrc/extensions.h  | 10 ++--
 .../pytorch/csrc/extensions/attention.cu      | 23 +++-----
 .../pytorch/csrc/extensions/pybind.cpp        |  7 ++-
 5 files changed, 32 insertions(+), 70 deletions(-)

diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/common/fused_attn/thd_utils.h
index 25becab38f..44f9abba76 100644
--- a/transformer_engine/common/fused_attn/thd_utils.h
+++ b/transformer_engine/common/fused_attn/thd_utils.h
@@ -102,7 +102,7 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
 }
 
 /***************************************************************************************************
- * Support THD(including SBHD and BSHD) format for Context Parallel: Out correction in forward
+ * Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward
 **************************************************************************************************/
 
 // format of out and lse, ignoring d as it's always the last dimension.
@@ -258,20 +258,19 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<max_tensors> tensors,
   }
 
   for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
-    size_t idx_out;
-    size_t idx_lse;
     float4 data = reinterpret_cast<float4 *>(cur_out)[j];
     dtype *p = reinterpret_cast<dtype *>(&data);
 
     for (int i = start; i < end; i++) {
-      if (id[1] >= 0 && start + tensors.start_tensor_this_launch > full_num && i > rank) {
+      size_t idx_out;
+      size_t idx_lse;
+      if (causal && id[1] >= 0 && i > rank) {
         idx_out = idx_out_half;
         idx_lse = idx_lse_half;
       } else {
         idx_out = idx_out_full;
         idx_lse = idx_lse_full;
       }
-
       dtype *cur_out_per_step =
           reinterpret_cast<dtype *>(tensors.addresses_out[i]) + idx_out * dim_per_head;
       float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 2875af1f83..0d2e9d4ed7 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2421,46 +2421,6 @@ def forward(
         softmax_lse = softmax_lse.to(torch.float)
 
-        if qkv_format == "sbhd":
-            tex.fused_out_correction(
-                out.view(-1, *out.shape[-3:]),
-                out_per_step,
-                softmax_lse,
-                softmax_lse_per_step,
-                cu_seqlens_q_padded,
-                qkv_format,
-                cp_size,
-                rank,
-                causal,
-                softmax_lse_in_packed_format,
-            )
-        elif qkv_format == "bshd":
-            tex.fused_out_correction(
-                out.view(out.shape[-4], -1, *out.shape[-2:]),
-                out_per_step,
-                softmax_lse,
-                softmax_lse_per_step,
-                cu_seqlens_q_padded,
-                qkv_format,
-                cp_size,
-                rank,
-                causal,
-                softmax_lse_in_packed_format,
-            )
-        else:
-            tex.fused_out_correction(
-                out,
-                out_per_step,
-                softmax_lse,
-                softmax_lse_per_step,
-                cu_seqlens_q_padded,
-                qkv_format,
-                cp_size,
-                rank,
-                causal,
-                softmax_lse_in_packed_format,
-            )
-
         if qkv_format != "thd" and softmax_lse_in_packed_format:
             # [np, b, sq] -> [np, t]
             softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
@@ -2472,6 +2432,19 @@ def forward(
             out = out.view(-1, *out.shape[-3:])
             ctx.batch_size = out.shape[1]
 
+        tex.fused_out_correction(
+            out,
+            out_per_step,
+            softmax_lse,
+            softmax_lse_per_step,
+            cu_seqlens_q_padded,
+            qkv_format,
+            cp_size,
+            rank,
+            causal,
+            softmax_lse_in_packed_format,
+        )
+
         if cp_size_a2a > 1:
             chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
             out = flash_attn_a2a_communicate(
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index b22b60ddb8..0d3cfc148b 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -136,11 +136,6 @@ std::vector<at::Tensor> fused_attn_bwd(
 at::Tensor fa_prepare_fwd(at::Tensor qkvi);
 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
 
-void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
-                          const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
-                          const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
-                          int rank, bool causal, bool softmax_lse_in_packed_format);
-
 /***************************************************************************************************
  * GEMM
 **************************************************************************************************/
@@ -458,6 +453,11 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
 at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                        int world_size, int rank);
 
+void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
+                          const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
+                          const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
+                          int rank, bool causal, bool softmax_lse_in_packed_format);
+
 /***************************************************************************************************
  * multi_tensor_* kernels
 **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index d1184c6904..d9747f55c4 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1223,15 +1223,14 @@ std::vector<at::Tensor> fused_attn_bwd(
 }
 
 /***************************************************************************************************
- * Support THD(including SBHD and BSHD) format for Context Parallel: Fused out correction in forward
+ * Support BSHD, SBHD, and THD formats for Context Parallel: Fused out correction in forward
 **************************************************************************************************/
 
 template <typename dtype, int tile_size, bool causal>
 void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                                  const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
                                  const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
-                                 int rank, bool softmax_lse_in_packed_format,
-                                 const at::Tensor *lse_ = nullptr) {
+                                 int rank, bool softmax_lse_in_packed_format) {
   int lse_seqlen;
   int batch;
   int num_heads;
@@ -1278,19 +1277,11 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
       tensors.addresses_lse[j] = lse_per_step[i + j].data_ptr<float>();
     }
 
     if (qkv_format == "sbhd") {
-      if (softmax_lse_in_packed_format) {
-        fused_out_correction_kernel<dtype, tile_size, causal, QKVFormat::SBH,
-                                    QKVFormat::HT>
-            <<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
-                out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
-                batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-      } else {
-        fused_out_correction_kernel<dtype, tile_size, causal, QKVFormat::SBH,
-                                    QKVFormat::BHS>
-            <<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
-                out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
-                batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-      }
+      NVTE_CHECK(softmax_lse_in_packed_format == false, "Packed lse doesn't support SBHD format.");
+      fused_out_correction_kernel<dtype, tile_size, causal, QKVFormat::SBH, QKVFormat::BHS>
+          <<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
+              out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
+              batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
     } else if (qkv_format == "bshd") {
       if (softmax_lse_in_packed_format) {
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 4b7ddb7bc9..d07bd7f0b1 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -41,10 +41,6 @@
         "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
         py::call_guard<py::gil_scoped_release>());
 
-  m.def("fused_out_correction", &fused_out_correction,
-        "fused out correction after qkv calculation without lse_",
-        py::call_guard<py::gil_scoped_release>());
-
   // Other granular functions
   m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8",
         py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
@@ -199,6 +195,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());
+  m.def("fused_out_correction", &fused_out_correction,
+        "fused out correction after qkv calculation without lse_",
+        py::call_guard<py::gil_scoped_release>());
 
   // multi-tensor functions
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
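
Note for reviewers: the operation being fused here is the standard log-sum-exp
recombination of per-CP-step partial attention outputs,
out = sum_i(out_i * exp(lse_i - lse)) with lse = logsumexp_i(lse_i). The sketch
below is a minimal PyTorch reference for the non-causal BSHD case, written for
this review only and not taken from the patch; it assumes out_per_step entries
of shape [b, s, np, d] and lse_per_step entries of shape [b, np, sq].

import torch

def out_correction_reference(out_per_step, lse_per_step):
    # out_per_step: per-CP-step partial outputs, each [b, s, np, d] (BSHD)
    # lse_per_step: per-step log-sum-exp tensors, each [b, np, sq]
    lse = torch.logsumexp(torch.stack(lse_per_step), dim=0)  # combined lse, [b, np, sq]
    out = torch.zeros_like(out_per_step[0], dtype=torch.float32)
    for out_i, lse_i in zip(out_per_step, lse_per_step):
        # per-step rescaling weight exp(lse_i - lse), broadcast to [b, s, np, 1];
        # across all steps these weights sum to one at every (b, np, s) position
        w = torch.exp(lse_i - lse).transpose(1, 2).unsqueeze(-1)
        out = out + w * out_i.to(torch.float32)
    return out

# Toy check: batch 2, seqlen 8, 4 heads, head dim 16, 3 CP steps.
outs = [torch.randn(2, 8, 4, 16) for _ in range(3)]
lses = [torch.randn(2, 4, 8) for _ in range(3)]
ref = out_correction_reference(outs, lses)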