fused out correction in CP #1248

Open · wants to merge 18 commits into base: main · Changes from 1 commit
94 changes: 28 additions & 66 deletions transformer_engine/pytorch/attention.py
@@ -1362,24 +1362,6 @@ def flash_attn_p2p_communicate(
return send_recv_reqs


@jit_fuser
def flash_attn_fwd_out_correction(
out: torch.Tensor,
out_per_step: torch.Tensor,
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
movedim_src: int,
movedim_dst: int,
):
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(
movedim_src, movedim_dst
)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected)


@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
softmax_lse: torch.Tensor,
@@ -2233,54 +2215,34 @@ def forward(
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
out_ = out[:, 1, ...]
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
out_ = out[1]

if i <= rank or not causal:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape),
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q_padded,
False,
softmax_lse_in_packed_format,
)
else:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
out_,
out_per_step[i],
softmax_lse_[..., 1, :],
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q_padded,
True,
softmax_lse_in_packed_format,
)

if qkv_format == "thd" or not causal:
tex.fused_out_correction_(
out,
out_per_step,
softmax_lse,
softmax_lse_per_step,
cu_seqlens_q_padded,
qkv_format,
cp_size,
rank,
causal,
softmax_lse_in_packed_format,
)
else:
tex.fused_out_correction_lse_(
out,
out_per_step,
softmax_lse,
softmax_lse_,
softmax_lse_per_step,
cu_seqlens_q_padded,
qkv_format,
cp_size,
rank,
causal,
softmax_lse_in_packed_format,
)
Collaborator:
Can we merge these two functions into one? They look almost the same. Also, I don't think you need softmax_lse_ in the CUDA code; it is just a view of softmax_lse, so softmax_lse alone should be enough.

Author:
OK, they are merged now, and softmax_lse_ has been removed from the CUDA code.
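
For readers following this thread: the merged binding is not part of this commit's diff, so the following is only a sketch of a plausible single Python-side call, assuming the softmax_lse_ view is dropped as discussed and the remaining arguments keep the order shown in the diff above (the name fused_out_correction is hypothetical here).

# Hypothetical sketch of the merged call discussed above; the actual merged
# binding is not shown in this commit. It assumes softmax_lse_ is no longer
# passed, since it is only a view of softmax_lse that the kernel can derive.
tex.fused_out_correction(
    out,                           # accumulated attention output, corrected in place
    out_per_step,                  # list of per-CP-step partial outputs
    softmax_lse,                   # merged log-sum-exp statistics
    softmax_lse_per_step,          # list of per-CP-step log-sum-exp statistics
    cu_seqlens_q_padded,           # cumulative sequence lengths (used for thd format)
    qkv_format,                    # "bshd", "sbhd", or "thd"
    cp_size,
    rank,
    causal,
    softmax_lse_in_packed_format,
)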

Collaborator:
Most of the arguments are the same; only the first one differs. The duplicated code should be removed.

Author:
OK, thanks for the suggestion.


if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
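For context on the change to attention.py above: the fused kernels implement the same log-sum-exp based merge of partial attention outputs that the removed per-step Python loop performed. A minimal sketch of that correction, mirroring the removed flash_attn_fwd_out_correction but eliding the layout handling (movedim, bshd/sbhd/thd views, and the second-half slicing used for causal steps with i > rank):

import torch

def out_correction_sketch(out, out_per_step, softmax_lse, softmax_lse_per_step):
    # Each partial output is rescaled by exp(lse_step - lse) and accumulated
    # into the merged output; the fused kernel does this for all CP steps in
    # a single launch instead of one jit-fused call per step.
    for out_step, lse_step in zip(out_per_step, softmax_lse_per_step):
        scale = torch.exp(lse_step - softmax_lse).unsqueeze(-1)
        out.add_(out_step * scale)

Fusing the loop into a single kernel avoids one jit-fused launch and one intermediate exp tensor per CP step, which appears to be the motivation for the change.
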
17 changes: 13 additions & 4 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -134,6 +134,17 @@ std::vector<at::Tensor> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

void fused_out_correction_lse_(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const at::Tensor *lse_,
const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);

void fused_out_correction_(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);

Collaborator:
Put these two functions close to the other CP helper functions.

Author:
OK.

Collaborator:
Why don't you fix this?

Author:
I will fix this.

/***************************************************************************************************
* GEMM
**************************************************************************************************/
@@ -438,17 +449,15 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed);

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half, bool lse_packed);

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half);

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank);

__forceinline__ __device__ int binary_search(int target, int *array, int len);

/***************************************************************************************************
* multi_tensor_* kernels
**************************************************************************************************/
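
On the __forceinline__ __device__ binary_search declaration added above: its kernel-side use is not shown in this hunk, but given the signature it presumably locates, for a token index, the containing sequence in a cumulative-sequence-length array such as cu_seqlens (this purpose is an assumption). An illustrative Python equivalent of that lookup:

def binary_search(target: int, array: list[int]) -> int:
    # Returns the index i with array[i] <= target < array[i + 1], assuming
    # array is non-decreasing, array[0] <= target, and target < array[-1].
    # For a cu_seqlens array this maps a token index to its sequence index.
    lo, hi = 0, len(array) - 1
    while lo + 1 < hi:
        mid = (lo + hi) // 2
        if array[mid] <= target:
            lo = mid
        else:
            hi = mid
    return lo

For example, binary_search(5, [0, 3, 7, 10]) returns 1, i.e. token 5 falls in the second sequence.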