Fix bug with torch.compile when seqdim is an integer (#1217)
* Fix bug with torch.compile when seqdim is an integer

Signed-off-by: 李金梁 <[email protected]>

* Update attention.py

Change the jit_fuser decorator to torch.compile for flash_attn_fwd_out_correction; see the sketch after the sign-off lines below.

Signed-off-by: 李金梁 <[email protected]>

* Annotate fused functions

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: 李金梁 <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
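
The second bullet above describes switching the fusion decorator to torch.compile for this helper. As a hedged illustration only (not Transformer Engine's actual jit_fuser), the sketch below shows the kind of compile-or-script switch being described and why the integer argument benefits from an explicit annotation: under torch.jit.script, unannotated parameters are assumed to be Tensor, so a plain-Python seq_dim would be mistyped without `seq_dim: int`. The names jit_fuser_sketch and scale_along_dim are hypothetical.

import torch

def jit_fuser_sketch(fn):
    # Prefer torch.compile when available, otherwise fall back to TorchScript.
    # This mirrors the kind of switch described above, not the repo's code.
    if hasattr(torch, "compile"):
        return torch.compile(fn)
    return torch.jit.script(fn)

@jit_fuser_sketch
def scale_along_dim(x: torch.Tensor, seq_dim: int) -> torch.Tensor:
    # Toy stand-in for a fused helper that takes an integer dimension index.
    return (2.0 * x).movedim(0, seq_dim)

print(scale_along_dim(torch.randn(2, 3, 4), 1).shape)  # torch.Size([3, 2, 4])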
wplf and ksivaman authored Oct 11, 2024
1 parent 3b89c36 commit 9ee2dbd
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -1359,7 +1359,13 @@ def flash_attn_p2p_communicate(
 
 
 @jit_fuser
-def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
+def flash_attn_fwd_out_correction(
+    out: torch.Tensor,
+    out_per_step: torch.Tensor,
+    seq_dim: int,
+    softmax_lse: torch.Tensor,
+    softmax_lse_per_step: torch.Tensor,
+):
     """Merge partial outputs of each step in Attention with context parallelism"""
     softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
     softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
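
The two visible body lines above build the per-step correction factor exp(softmax_lse_per_step - softmax_lse). As a minimal sketch (toy 2-D shapes and illustrative names, not the repo's batched tensor layout), the same log-sum-exp merge of per-step partial outputs can be checked against single-pass attention:

import torch

torch.manual_seed(0)
q = torch.randn(4, 8)          # [s_q, d]
k = torch.randn(6, 8)          # [s_k, d]
v = torch.randn(6, 8)

# Reference: attention over all keys at once.
scores = q @ k.T
out_ref = torch.softmax(scores, dim=-1) @ v

# Two "context-parallel steps", each seeing half of the keys.
outs, lses = [], []
for ks, vs in ((k[:3], v[:3]), (k[3:], v[3:])):
    s = q @ ks.T
    lses.append(torch.logsumexp(s, dim=-1))       # per-step softmax stats
    outs.append(torch.softmax(s, dim=-1) @ vs)    # per-step partial output

# Merge the stats, then weight each partial output by exp(lse_step - lse).
lse = torch.logsumexp(torch.stack(lses), dim=0)
out = sum(torch.exp(l - lse).unsqueeze(-1) * o for l, o in zip(lses, outs))

print(torch.allclose(out, out_ref, atol=1e-5))  # True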
@@ -1368,7 +1374,10 @@ def flash_attn_fwd_out_correction(
 
 
 @jit_fuser
-def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
+def flash_attn_fwd_softmax_lse_correction(
+    softmax_lse: torch.Tensor,
+    softmax_lse_per_step: torch.Tensor,
+):
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
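
The hunk above shows the max and min of the two stats but truncates the final formula. Assuming the usual numerically stable pairwise log-sum-exp (an assumption, since that line is not shown), a self-contained sketch that matches torch.logsumexp:

import torch

def merge_lse_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    max_scale = torch.max(a, b)
    min_scale = torch.min(a, b)
    # log(exp(a) + exp(b)) computed without overflowing exp().
    return max_scale + torch.log1p(torch.exp(min_scale - max_scale))

a, b = torch.randn(3), torch.randn(3)
print(torch.allclose(merge_lse_sketch(a, b),
                     torch.logsumexp(torch.stack([a, b]), dim=0)))  # True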
@@ -1378,7 +1387,12 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
 
 @jit_fuser
 def get_cu_seqlens_on_cp_rank(
-    cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half
+    cu_seqlens: torch.Tensor,
+    cu_seqlens_padded_on_cp_rank: torch.Tensor,
+    cp_size: int,
+    cp_rank: int,
+    first_half: bool,
+    second_half: bool,
 ):
     """Compute cu_seqlens of a context parallelism rank"""
     seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
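
For context on the one visible body line of get_cu_seqlens_on_cp_rank: cu_seqlens is a cumulative-length vector, so differencing adjacent entries recovers the per-sequence lengths. A tiny sketch with made-up lengths:

import torch

cu_seqlens = torch.tensor([0, 5, 12, 20])   # 3 packed sequences
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
print(seqlens.tolist())                      # [5, 7, 8]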
