fused out correction in CP #1248
base: main
Changes from 1 commit
@@ -134,6 +134,17 @@ std::vector<at::Tensor> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

void fused_out_correction_lse_(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                               const at::Tensor &lse, const at::Tensor *lse_,
                               const std::vector<at::Tensor> &lse_per_step,
                               const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
                               int rank, bool causal, bool softmax_lse_in_packed_format);

void fused_out_correction_(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                           const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
                           const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
                           int rank, bool causal, bool softmax_lse_in_packed_format);
Reviewer: Put these two functions close to the other CP helper functions.
Author: OK.
Reviewer: Why don't you fix this?
Author: I will fix this.
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/
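For background, the correction these new entry points fuse is the standard log-sum-exp merge used to combine per-step attention partials in context parallelism: each step's partial output is rescaled by exp(lse_step - lse_total) before accumulation, where lse_total is the combined softmax LSE. A minimal per-element sketch of that rescale (illustrative only; `correct_out_elt` and its argument names are hypothetical and not part of this diff):

// Sketch of the per-element LSE rescale that the fused_out_correction_*
// entry points apply across all CP steps. The real kernels additionally
// handle qkv_format, causal masking, and packed-LSE layouts.
__device__ inline float correct_out_elt(float out_val,    // partial output from one CP step
                                        float lse_step,   // softmax LSE of that step
                                        float lse_total)  // combined softmax LSE over all steps
{
  // Rescale the per-step partial so that summing over all steps
  // reproduces softmax over the full (context-parallel) sequence.
  return out_val * expf(lse_step - lse_total);
}

Summing the rescaled partials over all steps yields the softmax-weighted output for the full sequence.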
@@ -438,17 +449,15 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
                                    bool lse_packed);

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                        const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
                        bool only_second_half, bool lse_packed);

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens, const std::string &first_half,
                         const std::string &second_half);

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                       int world_size, int rank);

__forceinline__ __device__ int binary_search(int target, int *array, int len);
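For THD (token-packed) layouts, per-token kernels need to map a flat token index back to its sequence via cu_seqlens, and the declared `binary_search` is the usual tool for that. A plausible implementation matching the signature above (a sketch, not necessarily the code in this repository):

// Given a sorted array (e.g., cu_seqlens), return the index i such that
// array[i] <= target < array[i+1], i.e., the sequence containing token
// `target`. Sketch only; assumes len >= 2 and array[0] <= target < array[len-1].
__forceinline__ __device__ int binary_search(int target, int *array, int len) {
  int left = 1, right = len - 1;
  while (left < right) {
    int mid = (left + right) / 2;
    if (array[mid] <= target) {
      left = mid + 1;   // target lies at or beyond array[mid]
    } else {
      right = mid;      // target lies before array[mid]
    }
  }
  return left - 1;
}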
/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/
Reviewer: Can we merge these two functions into one? I see they are almost the same. Also, I think you do not necessarily need `softmax_lse_` in the CUDA code; it's just a view of `softmax_lse`, so `softmax_lse` should be enough.
Author: OK, they are merged, and `softmax_lse_` is removed from the CUDA code.
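With `softmax_lse_` dropped, the lse-view variant presumably collapses into the base version. A hedged sketch of what the single surviving entry point could look like (the PR's final merged signature is not shown in this hunk and may differ):

// Hypothetical merged declaration: only the base variant remains, and
// callers pass the packed lse directly instead of a separate lse_ view.
void fused_out_correction_(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                           const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
                           const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
                           int rank, bool causal, bool softmax_lse_in_packed_format);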
Reviewer: Lots of the arguments are the same; only the first one differs. The duplicated code should be removed.
Author: OK, thanks for your suggestion.