Commit b2a2fc2

minor fixes based on review comments
Signed-off-by: xiaoyao0115 <[email protected]>
xiaoyao0115 committed Dec 24, 2024
1 parent d346d9c commit b2a2fc2
Showing 5 changed files with 32 additions and 70 deletions.
9 changes: 4 additions & 5 deletions transformer_engine/common/fused_attn/thd_utils.h
@@ -102,7 +102,7 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
}

/***************************************************************************************************
-* Support THD(including SBHD and BSHD) format for Context Parallel: Out correction in forward
+* Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward
**************************************************************************************************/

// format of out and lse, ignoring d as it's always the last dimension.
@@ -258,20 +258,19 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<max_tensors>
}

for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
-size_t idx_out;
-size_t idx_lse;
float4 data = reinterpret_cast<float4 *>(cur_out)[j];
dtype *p = reinterpret_cast<dtype *>(&data);

for (int i = start; i < end; i++) {
-if (id[1] >= 0 && start + tensors.start_tensor_this_launch > full_num && i > rank) {
+size_t idx_out;
+size_t idx_lse;
+if (causal && id[1] >= 0 && i > rank) {
idx_out = idx_out_half;
idx_lse = idx_lse_half;
} else {
idx_out = idx_out_full;
idx_lse = idx_lse_full;
}

dtype *cur_out_per_step =
reinterpret_cast<dtype *>(tensors.addresses_out[i]) + idx_out * dim_per_head;
float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
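
For context on what the kernel above computes: the "out correction" is the standard log-sum-exp rescaling used to merge per-step partial attention outputs under context parallelism. A minimal PyTorch sketch of that math (tensor names and shapes are illustrative only, not the kernel's interface):

import torch

def out_correction_reference(out_per_step, lse_per_step):
    # out_per_step: list of [s, h, d] partial outputs, one per CP step
    # lse_per_step: list of [h, s] log-sum-exp statistics, one per CP step
    lse = torch.logsumexp(torch.stack(lse_per_step), dim=0)        # combined stats, [h, s]
    out = torch.zeros_like(out_per_step[0], dtype=torch.float32)
    for o, l in zip(out_per_step, lse_per_step):
        # Each step's output is weighted by exp(lse_step - lse).
        weight = torch.exp(l - lse).transpose(0, 1).unsqueeze(-1)  # [s, h, 1]
        out = out + o.float() * weight
    return out.to(out_per_step[0].dtype)

# Tiny smoke test with random data.
outs = [torch.randn(4, 2, 8) for _ in range(2)]
lses = [torch.randn(2, 4) for _ in range(2)]
print(out_correction_reference(outs, lses).shape)  # torch.Size([4, 2, 8])
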
53 changes: 13 additions & 40 deletions transformer_engine/pytorch/attention.py
@@ -2421,46 +2421,6 @@ def forward(

softmax_lse = softmax_lse.to(torch.float)

-if qkv_format == "sbhd":
-    tex.fused_out_correction(
-        out.view(-1, *out.shape[-3:]),
-        out_per_step,
-        softmax_lse,
-        softmax_lse_per_step,
-        cu_seqlens_q_padded,
-        qkv_format,
-        cp_size,
-        rank,
-        causal,
-        softmax_lse_in_packed_format,
-    )
-elif qkv_format == "bshd":
-    tex.fused_out_correction(
-        out.view(out.shape[-4], -1, *out.shape[-2:]),
-        out_per_step,
-        softmax_lse,
-        softmax_lse_per_step,
-        cu_seqlens_q_padded,
-        qkv_format,
-        cp_size,
-        rank,
-        causal,
-        softmax_lse_in_packed_format,
-    )
-else:
-    tex.fused_out_correction(
-        out,
-        out_per_step,
-        softmax_lse,
-        softmax_lse_per_step,
-        cu_seqlens_q_padded,
-        qkv_format,
-        cp_size,
-        rank,
-        causal,
-        softmax_lse_in_packed_format,
-    )

if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
@@ -2472,6 +2432,19 @@ def forward(
out = out.view(-1, *out.shape[-3:])
ctx.batch_size = out.shape[1]

+tex.fused_out_correction(
+    out,
+    out_per_step,
+    softmax_lse,
+    softmax_lse_per_step,
+    cu_seqlens_q_padded,
+    qkv_format,
+    cp_size,
+    rank,
+    causal,
+    softmax_lse_in_packed_format,
+)

if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
out = flash_attn_a2a_communicate(
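
The three removed branches above differed only in the view applied to out; the remaining single call relies on out already being in the layout the kernel indexes (e.g. the out.view(-1, *out.shape[-3:]) shown earlier in this hunk). A small sketch, with illustrative shapes, of why the per-format views are no-op reshapes once out is a plain 4-D tensor:

import torch

s, b, h, d = 128, 2, 8, 64
out_sbhd = torch.randn(s, b, h, d)
out_bshd = torch.randn(b, s, h, d)

# View from the removed "sbhd" branch: flattening everything before the
# last three dims leaves a 4-D tensor unchanged.
assert out_sbhd.view(-1, *out_sbhd.shape[-3:]).shape == (s, b, h, d)

# View from the removed "bshd" branch: likewise an identity reshape on 4-D input.
assert out_bshd.view(out_bshd.shape[-4], -1, *out_bshd.shape[-2:]).shape == (b, s, h, d)
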
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -136,11 +136,6 @@ std::vector<at::Tensor> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

-void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
-                          const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
-                          const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
-                          int rank, bool causal, bool softmax_lse_in_packed_format);

/***************************************************************************************************
* GEMM
**************************************************************************************************/
@@ -458,6 +453,11 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank);

+void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
+                          const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
+                          const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
+                          int rank, bool causal, bool softmax_lse_in_packed_format);

/***************************************************************************************************
* multi_tensor_* kernels
**************************************************************************************************/
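
For orientation, the arguments of this declaration line up one-to-one with the Python call site in attention.py. A signature-only Python stub (names taken from that call site; the dtype notes are assumptions based on the data_ptr types used in attention.cu):

from typing import List
import torch

def fused_out_correction_args(
    out: torch.Tensor,                      # accumulated output, corrected in place
    out_per_step: List[torch.Tensor],       # one partial output per CP step
    lse: torch.Tensor,                      # combined log-sum-exp stats (float32)
    lse_per_step: List[torch.Tensor],       # per-step log-sum-exp stats (float32)
    cu_seqlens: torch.Tensor,               # cumulative sequence lengths (int32)
    qkv_format: str,                        # "bshd", "sbhd", or "thd"
    cp_size: int,
    rank: int,
    causal: bool,
    softmax_lse_in_packed_format: bool,
) -> None:
    """Signature-only stub mirroring the C++ declaration above; no implementation."""
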
23 changes: 7 additions & 16 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1223,15 +1223,14 @@ std::vector<at::Tensor> fused_attn_bwd(
}

/***************************************************************************************************
-* Support THD(including SBHD and BSHD) format for Context Parallel: Fused out correction in forward
+* Support BSHD, SBHD, and THD formats for Context Parallel: Fused out correction in forward
**************************************************************************************************/

template <typename dtype, bool causal>
void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
-int rank, bool softmax_lse_in_packed_format,
-const at::Tensor *lse_ = nullptr) {
+int rank, bool softmax_lse_in_packed_format) {
int lse_seqlen;
int batch;
int num_heads;
@@ -1278,19 +1277,11 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
tensors.addresses_lse[j] = lse_per_step[i + j].data_ptr<float>();
}
if (qkv_format == "sbhd") {
-if (softmax_lse_in_packed_format) {
-fused_out_correction_kernel<dtype, tile, causal, QKVFormat::SBH, QKVFormat::HBS,
-max_tensors>
-<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
-batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-} else {
-fused_out_correction_kernel<dtype, tile, causal, QKVFormat::SBH, QKVFormat::BHS,
-max_tensors>
-<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
-batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-}
+NVTE_CHECK(softmax_lse_in_packed_format == false, "Packed lse doesn't support SBHD format.");
+fused_out_correction_kernel<dtype, tile, causal, QKVFormat::SBH, QKVFormat::BHS, max_tensors>
+<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
+batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);

} else if (qkv_format == "bshd") {
if (softmax_lse_in_packed_format) {
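
The SBHD path above now rejects the packed log-sum-exp layout outright instead of dispatching on it. As a reminder of what "packed" means for softmax_lse (mirroring the "[np, b, sq] -> [np, t]" comment in attention.py), a small illustrative sketch:

import torch

np_heads, b, sq = 8, 2, 128
t = b * sq  # total tokens once sequences are laid out back to back

lse = torch.randn(np_heads, b, sq)    # regular layout: [np, b, sq]
lse_packed = lse.view(np_heads, -1)   # packed layout:  [np, t]
assert lse_packed.shape == (np_heads, t)
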
7 changes: 3 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -41,10 +41,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());

m.def("fused_out_correction", &fused_out_correction,
"fused out correction after qkv calculation without lse_",
py::call_guard<py::gil_scoped_release>());

// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
@@ -199,6 +195,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_out_correction", &fused_out_correction,
"fused out correction after qkv calculation without lse_",
py::call_guard<py::gil_scoped_release>());

// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
