Skip to content

Commit

Permalink
Revert "Avoid redundant computation for cu_seqlens (#535)"
Browse files (browse the repository at this point in the history)
This reverts commit fad3044.

Signed-off-by: Przemek Tredak <email address redacted>
  • Loading branch information
ptrendx committed Jan 25, 2024
1 parent 2bce81f commit bcbe9b0
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,24 +1621,20 @@ def forward(
query_layer_packed, key_layer_packed, value_layer_packed)
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
else:
if self.layer_number == 1:
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
_cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
else:
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
elif qkv_format == 'thd':
assert not context_parallel, "thd format not supported with context parallelism!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
Expand Down

0 comments on commit bcbe9b0

Please sign in to comment.