
fused out correction in CP #1248

Open · wants to merge 18 commits into main

Conversation

xiaoyao0115

Description

Fuses the multiple kernels in the out correction computation of attention under context parallelism (CP) into a single kernel, reducing kernel launch overhead.
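
For context, the out correction step combines the partial attention outputs from each CP step: every element of the final output is a sum of the per-step outputs, each rescaled by exp(lse_per_step - lse), where lse is the final log-sum-exp. A minimal sketch of that math as one fused pass (illustrative only; the names, flat indexing, and pointer-array argument are assumptions, not the PR's actual kernel):

```cuda
#include <cuda_runtime.h>
#include <math.h>

// One thread per output element; each thread accumulates across all CP steps,
// so the whole correction is a single launch instead of one launch per step.
template <typename dtype>
__global__ void fused_out_correction_sketch(dtype *out, dtype *const *out_per_step,
                                            const float *lse, float *const *lse_per_step,
                                            int num_steps, int dim_per_head, size_t total) {
  size_t elem = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (elem >= total) return;
  size_t row = elem / dim_per_head;  // the (token, head) row that owns this element's lse
  float acc = 0.f;
  for (int s = 0; s < num_steps; ++s) {
    // Rescale each step's partial softmax output into the global softmax scale.
    acc += static_cast<float>(out_per_step[s][elem]) * expf(lse_per_step[s][row] - lse[row]);
  }
  out[elem] = static_cast<dtype>(acc);
}
```

The unfused path launches one correction kernel per CP step, so fusing removes cp_size - 1 launches per attention forward.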

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • [✓] New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: xiaoyao0115 <[email protected]>
@xrennvidia self-requested a review October 14, 2024 08:09
transformer_engine/pytorch/csrc/extensions/attention.cu
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);

Collaborator

Put these two functions close to the other CP helper functions.

Author

ok

Collaborator

why don't you fix this?

Author

I will fix this.

rank,
causal,
softmax_lse_in_packed_format,
)
Collaborator

Can we merge these two functions into one? I see they are almost the same.
I think you do not necessarily need softmax_lse_ in the CUDA code? It's just a view of softmax_lse, so softmax_lse should be enough.

Author

OK, they are merged, and softmax_lse_ has been removed from the CUDA code.

Collaborator

Lots of the arguments are the same and only the first one is different? The duplicated code should be removed.

Author

ok, thanks for your suggestion

transformer_engine/pytorch/csrc/extensions/attention.cu
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);

Collaborator

why don't you fix this?

)
elif qkv_format == "bshd":
tex.fused_out_correction(
out.view(out.shape[-4], -1, *out.shape[-2:]),
Collaborator

why out.shape[-4]? I guess you mean batch size, which should be out.shape[0]?

Author

Yes, I will fix this.

rank,
causal,
softmax_lse_in_packed_format,
)
Collaborator

Lots of the arguments are the same and only the first one is different? The duplicated code should be removed.

m.def("fused_out_correction", &fused_out_correction,
"fused out correction after qkv calculation without lse_",
py::call_guard<py::gil_scoped_release>());

// Other granular functions
Collaborator

Move this next to the other THD helper functions.

Author

OK.

@@ -1222,6 +1222,152 @@ std::vector<at::Tensor> fused_attn_bwd(
return {dQ, dK, dV, dBias};
}

/***************************************************************************************************
* Support THD(including SBHD and BSHD) format for Context Parallel: Fused out correction in forward
**************************************************************************************************/
Collaborator

just say "Support BSHD, SBHD, and THD formats for Context Parallel: Fused out correction in forward"

Author

ok

num_heads = out.size(1);
dim_per_head = out.size(2);
batch = cu_seqlens.size(0) - 1;
if (softmax_lse_in_packed_format) {
Collaborator

The current CP implementation uses varlen_fwd only, so softmax_lse_in_packed_format can be True for the SBHD format as well; you cannot put this if-else statement under the THD format only.

Author

Yes, I understand that softmax_lse_in_packed_format can be true not only in the THD format. However, in the THD format, if softmax_lse_in_packed_format is true, lse_seqlen needs special handling, whereas in the SBHD and BSHD formats it does not.
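
For readers following along: in the packed layout the LSE is stored per token rather than per (batch, position), so for THD the row length has to come from the total token count in cu_seqlens rather than a fixed max_seqlen. A hypothetical illustration of the two index computations (the shapes and names here are assumed for illustration, not taken from the PR):

```cuda
// Assumed layouts, for illustration only:
//   packed:   lse[num_heads][total_tokens]      (total_tokens = cu_seqlens[batch])
//   unpacked: lse[batch][num_heads][max_seqlen]
__device__ size_t lse_index(bool packed, int batch_idx, int head, int pos_in_seq,
                            size_t token_global, int num_heads, int max_seqlen,
                            size_t total_tokens) {
  if (packed) {
    // THD must derive total_tokens from cu_seqlens at runtime; SBHD/BSHD can
    // still use a fixed sequence length even when the LSE is packed.
    return static_cast<size_t>(head) * total_tokens + token_global;
  }
  return (static_cast<size_t>(batch_idx) * num_heads + head) * max_seqlen + pos_in_seq;
}
```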

@@ -102,58 +102,187 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
}

/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
* Support THD(including SBHD and BSHD) format for Context Parallel: Out correction in forward
Collaborator

change this to "Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward"

Author

ok

constexpr int max_tensors = 64;
TensorList<max_tensors> tensors;

for (int i = 0; i < cp_size; i += max_tensors) {
Collaborator

What's the particular reason to have this for loop? Why can't the fused kernel handle cp_size > 64 in a single launch?

Author

When cp_size is very large, this CUDA kernel suffers performance degradation due to the excessive number of kernel arguments (a total of 2 * cp_size + 2 tensor addresses). Therefore, following the implementation in https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh, I adopted a batched approach that processes the steps in chunks when cp_size is too large.
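
In other words, the kernel receives its per-step pointers by value in a fixed-capacity struct (kernel arguments live in limited constant-memory space), so the steps are processed in chunks of at most 64. A sketch of that pattern, loosely modeled on multi_tensor_apply.cuh; this TensorList and the launch loop are assumptions, not the exact PR code:

```cuda
#include <algorithm>
#include <vector>

constexpr int max_tensors = 64;

// Fixed-capacity argument struct passed to the kernel by value, so the
// argument size stays bounded no matter how large cp_size grows.
template <typename dtype, int capacity>
struct TensorList {
  dtype *out_per_step[capacity];
  float *lse_per_step[capacity];
  int num_tensors;
};

template <typename dtype>
void launch_in_chunks(const std::vector<dtype *> &outs, const std::vector<float *> &lses,
                      int cp_size) {
  for (int i = 0; i < cp_size; i += max_tensors) {
    TensorList<dtype, max_tensors> tensors;
    tensors.num_tensors = std::min(max_tensors, cp_size - i);
    for (int j = 0; j < tensors.num_tensors; ++j) {
      tensors.out_per_step[j] = outs[i + j];
      tensors.lse_per_step[j] = lses[i + j];
    }
    // One launch per chunk of up to max_tensors steps, e.g.:
    // fused_out_correction_kernel<<<grid, block, smem, stream>>>(tensors, ...);
  }
}
```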

dtype *cur_out_per_step = out_per_step + idx_per_step;
for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
size_t idx_out;
size_t idx_lse;
Collaborator

Move idx_out and idx_lse into the next level of the inner loop.

Author

OK.


if constexpr (out_format == QKVFormat::TH) {
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i];
Collaborator

why is this needed?

Author

We need this to initialize TensorFormat when the THD format is applied.
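
To unpack that a bit: in the THD layout a flat token index has to be mapped back to its owning sequence via cu_seqlens, and staging cu_seqlens in shared memory makes that lookup cheap for every thread. A hypothetical sketch of the lookup such a kernel would run after the copy above (the function name is illustrative):

```cuda
// Find the sequence that owns a flat token index, i.e. the largest i with
// cu_seqlens_s[i] <= token. cu_seqlens_s holds batch + 1 cumulative offsets
// already copied into shared memory, as in the loop quoted above.
__device__ int seq_id_of_token(const int *cu_seqlens_s, int batch, int token) {
  int lo = 0, hi = batch;
  while (lo < hi) {
    int mid = (lo + hi + 1) / 2;
    if (cu_seqlens_s[mid] <= token) lo = mid;
    else hi = mid - 1;
  }
  return lo;
}
```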
