fused out correction in CP #1248
base: main
Conversation
Signed-off-by: xiaoyao0115 <[email protected]>
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);
put these two functions close to other CP helper functions.
ok
why don't you fix this?
I will fix this.
rank,
causal,
softmax_lse_in_packed_format,
)
Can we merge these two functions into one? I see they are almost the same. I think you do not necessarily need softmax_lse_ in the CUDA code; it's just a view of softmax_lse, so softmax_lse should be enough?
OK, they are merged, and softmax_lse_ has been removed from the CUDA code.
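A small hedged check of that point (the shapes below are made up for illustration): a view shares storage with its base tensor, so a kernel given softmax_lse can compute the reshaped offsets itself without a separate softmax_lse_ argument.

```cpp
#include <torch/torch.h>
#include <cassert>

int main() {
  // Illustrative shapes only, not the PR's tensors.
  at::Tensor softmax_lse = at::randn({2, 8, 128});
  at::Tensor softmax_lse_ = softmax_lse.view({2, 8, 2, 64});  // a view, no copy
  // Same underlying storage, so passing softmax_lse alone is sufficient.
  assert(softmax_lse_.data_ptr() == softmax_lse.data_ptr());
  return 0;
}
```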
Lots of arguments are the same and only the first one is different; you should remove the duplicated code.
ok, thanks for your suggestion
Signed-off-by: tailaim <[email protected]>
Signed-off-by: tailaim <[email protected]>
merge main to fused_out_correction
Signed-off-by: tailaim <[email protected]>
Signed-off-by: xiaoyao0115 <[email protected]>
Signed-off-by: xiaoyao0115 <[email protected]>
Signed-off-by: xiaoyao0115 <[email protected]>
fa158ed to f977de3
Signed-off-by: xiaoyao0115 <[email protected]>
for more information, see https://pre-commit.ci
)
elif qkv_format == "bshd":
    tex.fused_out_correction(
        out.view(out.shape[-4], -1, *out.shape[-2:]),
Why out.shape[-4]? I guess you mean the batch size, which should be out.shape[0]?
Yes, I will fix this.
m.def("fused_out_correction", &fused_out_correction, | ||
"fused out correction after qkv calculation without lse_", | ||
py::call_guard<py::gil_scoped_release>()); | ||
|
||
// Other granular functions |
Move this to the place next to the other THD helper functions.
okok
@@ -1222,6 +1222,152 @@ std::vector<at::Tensor> fused_attn_bwd(
  return {dQ, dK, dV, dBias};
}

/***************************************************************************************************
 * Support THD(including SBHD and BSHD) format for Context Parallel: Fused out correction in forward
 **************************************************************************************************/
just say "Support BSHD, SBHD, and THD formats for Context Parallel: Fused out correction in forward"
ok
num_heads = out.size(1);
dim_per_head = out.size(2);
batch = cu_seqlens.size(0) - 1;
if (softmax_lse_in_packed_format) {
The current CP implementation uses varlen_fwd only, so softmax_lse_in_packed_format can be True for the SBHD format as well; you cannot put this if-else statement under the THD format.
Yes, I understand that softmax_lse_in_packed_format can be true not only in the THD format. However, in the THD format, if softmax_lse_in_packed_format is true, lse_seqlen needs to be specially handled, whereas in the SBHD and BSHD formats it does not.
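As a hedged illustration of the layout difference being discussed (shapes and names below are assumptions, not the PR's code): unpacked LSE is addressed as [batch, num_heads, max_seqlen], packed LSE as [num_heads, total_tokens], and only the THD path needs cu_seqlens to map a (sequence, position) pair onto the packed token axis.

```cpp
#include <cstddef>

// Illustrative index computation only. For BSHD/SBHD the packed token index is
// regular (batch_idx * max_seqlen + seq_pos); for THD it comes from cu_seqlens.
size_t lse_index(bool packed, size_t batch_idx, size_t head, size_t seq_pos,
                 size_t num_heads, size_t max_seqlen, size_t total_tokens,
                 const int *cu_seqlens /* batch + 1 prefix sums, THD only */) {
  if (!packed) {
    return (batch_idx * num_heads + head) * max_seqlen + seq_pos;  // [b, h, s]
  }
  size_t token = (cu_seqlens != nullptr)
                     ? static_cast<size_t>(cu_seqlens[batch_idx]) + seq_pos  // THD
                     : batch_idx * max_seqlen + seq_pos;                     // BSHD/SBHD
  return head * total_tokens + token;  // [h, total_tokens]
}
```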
@@ -102,58 +102,187 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
}

/***************************************************************************************************
- * Support THD format for Context Parallel: Out correction in forward
+ * Support THD(including SBHD and BSHD) format for Context Parallel: Out correction in forward
change this to "Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward"
ok
constexpr int max_tensors = 64;
TensorList<max_tensors> tensors;

for (int i = 0; i < cp_size; i += max_tensors) {
what's the particular reason to have this for loop? why can't the fusion kernel handle CP>64 in a single fused kernel?
When cp_size is too large, this CUDA kernel experiences performance degradation due to the excessive number of input parameters (a total of 2 * cp_size + 2 tensor addresses). Therefore, following the implementation in https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh, I adopted a batch processing approach when cp_size is too large.
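A rough host-side sketch of that chunking idea (all names here are illustrative, not the PR's exact TensorList or launch code): the per-step tensor pointers are packed into a fixed-capacity struct of at most 64 entries, and the host loops over cp_size, launching the fused kernel once per chunk.

```cpp
#include <algorithm>
#include <vector>

constexpr int max_tensors = 64;

struct TensorChunk {
  const void *out_per_step[max_tensors];  // partial attention outputs for this chunk
  const void *lse_per_step[max_tensors];  // matching per-step softmax LSE tensors
  int start_step;                         // global index of the first step in the chunk
  int num_tensors;                        // how many of the 64 slots are in use
};

// Split cp_size steps into chunks of at most max_tensors; each chunk would be
// passed by value to one launch of the fused correction kernel.
std::vector<TensorChunk> build_chunks(const std::vector<const void *> &out_per_step,
                                      const std::vector<const void *> &lse_per_step,
                                      int cp_size) {
  std::vector<TensorChunk> chunks;
  for (int i = 0; i < cp_size; i += max_tensors) {
    TensorChunk c{};
    c.start_step = i;
    c.num_tensors = std::min(max_tensors, cp_size - i);
    for (int j = 0; j < c.num_tensors; ++j) {
      c.out_per_step[j] = out_per_step[i + j];
      c.lse_per_step[j] = lse_per_step[i + j];
    }
    chunks.push_back(c);
  }
  return chunks;
}
```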
dtype *cur_out_per_step = out_per_step + idx_per_step;
for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
  size_t idx_out;
  size_t idx_lse;
Move idx_out and idx_lse into the next level of the inner loop.
okok
if constexpr (out_format == QKVFormat::TH) {
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    cu_seqlens_s[i] = cu_seqlens[i];
why is this needed?
We need this to initialize TensorFormat when the THD format is applied.
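For illustration, a hedged sketch of why cu_seqlens is staged in shared memory on the THD path (the kernel shape and names are assumptions, not the PR's code): each block copies the prefix sums once, and threads can then binary-search which sequence a token belongs to without repeated global-memory reads.

```cuda
#include <cuda_runtime.h>

// Illustrative kernel only. Launch with (batch + 1) * sizeof(int) bytes of
// dynamic shared memory; cu_seqlens holds batch + 1 prefix sums of sequence lengths.
__global__ void thd_seq_id_sketch(const int *cu_seqlens, int batch, int total_tokens,
                                  int *seq_id_of_token) {
  extern __shared__ int cu_seqlens_s[];
  // Stage the prefix sums in shared memory once per block.
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    cu_seqlens_s[i] = cu_seqlens[i];
  }
  __syncthreads();

  for (int t = blockIdx.x * blockDim.x + threadIdx.x; t < total_tokens;
       t += gridDim.x * blockDim.x) {
    // Binary search: find s such that cu_seqlens_s[s] <= t < cu_seqlens_s[s + 1].
    int lo = 0, hi = batch;
    while (lo + 1 < hi) {
      int mid = (lo + hi) / 2;
      if (t < cu_seqlens_s[mid]) hi = mid; else lo = mid;
    }
    seq_id_of_token[t] = lo;
  }
}
```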
ed446e7 to d346d9c
Signed-off-by: xiaoyao0115 <[email protected]>
Description
This PR fuses the multiple kernels in the out-correction computation of attention in CP into a single kernel, reducing kernel launch time.
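For context, a minimal sketch of the math the out correction performs (this is not the PR's kernel; the flat [tokens, heads, dim] layout, packed [heads, tokens] LSE, and names are assumptions): each per-step partial output is weighted by exp(lse_per_step - lse) and accumulated, and fusing all CP steps into one kernel replaces one correction launch per step with a single loop inside the kernel.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Hedged sketch only; the real kernel additionally handles BSHD/SBHD/THD
// layouts, causal half-steps, and chunked tensor lists.
__global__ void out_correction_sketch(float *out,                        // [tokens, heads, dim]
                                      const float *const *out_per_step,  // cp_size pointers
                                      const float *lse,                  // [heads, tokens]
                                      const float *const *lse_per_step,  // cp_size pointers
                                      int cp_size, int tokens, int heads, int dim) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= tokens * heads * dim) return;
  int token = idx / (heads * dim);
  int head = (idx / dim) % heads;
  float acc = 0.0f;
  // All CP steps are combined in one pass, instead of one correction kernel per step.
  for (int i = 0; i < cp_size; ++i) {
    float w = expf(lse_per_step[i][head * tokens + token] - lse[head * tokens + token]);
    acc += w * out_per_step[i][idx];
  }
  out[idx] = acc;
}
```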
Fixes # (issue)
Type of change
Checklist: