Fix pipeline parallelism with FusedAttn (#635)
Signed-off-by: Przemek Tredak <[email protected]>
ptrendx authored Jan 26, 2024
1 parent bcbe9b0 commit e7319f5
Showing 1 changed file with 39 additions and 47 deletions.
transformer_engine/pytorch/attention.py (39 additions, 47 deletions)
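What the change does: before this commit, the cumulative sequence lengths (cu_seqlens) and unpadding indices for padded inputs were computed only when self.layer_number == 1 and stashed in module-level globals for reuse by later layers. With pipeline parallelism, stages other than the first never run layer 1, so those values were never populated there. The diff below drops the layer-number guard so every layer derives cu_seqlens from its own inputs. As background, the snippet below is a minimal sketch of how cumulative sequence lengths can be obtained from a padding mask; it is not TransformerEngine's get_cu_seqlens implementation, and the mask shape and "True means padded" convention are assumptions for illustration.

import torch

def cu_seqlens_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Assumed convention: [batch, 1, 1, seqlen] boolean mask where True marks
    # padded (masked-out) positions, so a sequence's valid length is the
    # number of False entries.
    seqlens = (~attention_mask).squeeze(1).squeeze(1).sum(dim=1, dtype=torch.int32)
    # cu_seqlens[i] is the running offset of sequence i in the packed layout.
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32,
                             device=seqlens.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    return cu_seqlens  # e.g. valid lengths [3, 5, 2] -> [0, 3, 8, 10]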
@@ -1587,32 +1587,30 @@ def forward(
                 assert (
                     max_seqlen_q == max_seqlen_kv
                 ), "Maximum sequence length for Q and KV should be the same."
-                if self.layer_number == 1:
-                    if cu_seqlens_q is None:
-                        assert (attention_mask is not None
-                        ), "Please provide attention_mask for padding!"
-                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
-                    else:
-                        _cu_seqlens_q = cu_seqlens_q
-                        _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                if cu_seqlens_q is None:
+                    assert (attention_mask is not None
+                    ), "Please provide attention_mask for padding!"
+                    _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
+                else:
+                    _cu_seqlens_q = cu_seqlens_q
+                    _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                 _cu_seqlens_kv = _cu_seqlens_q
                 query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
                     _indices_q, query_layer, key_layer, value_layer
                 )
             else:
-                if self.layer_number == 1:
-                    if cu_seqlens_q is None or cu_seqlens_kv is None:
-                        assert (attention_mask is not None
-                        ), "Please provide attention_mask for padding!"
-                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
-                            attention_mask[0])
-                        _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
-                            attention_mask[1])
-                    else:
-                        _cu_seqlens_q = cu_seqlens_q
-                        _cu_seqlens_kv = cu_seqlens_kv
-                        _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
-                        _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
+                if cu_seqlens_q is None or cu_seqlens_kv is None:
+                    assert (attention_mask is not None
+                    ), "Please provide attention_mask for padding!"
+                    _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
+                        attention_mask[0])
+                    _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
+                        attention_mask[1])
+                else:
+                    _cu_seqlens_q = cu_seqlens_q
+                    _cu_seqlens_kv = cu_seqlens_kv
+                    _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                    _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                 query_layer_packed = PackTensors.apply(_indices_q, query_layer)
                 key_layer_packed, value_layer_packed = PackTensors.apply(
                     _indices_kv, key_layer, value_layer
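The second hunk, below, removes the same layer-number caching from the fused-attention path. In its no-padding branch the cumulative sequence lengths degenerate to evenly spaced offsets, which is exactly what the torch.arange call in the diff constructs. A small standalone example (the batch size and sequence length are assumed values for illustration):

import torch

# With no padding every sequence has the full length, so cu_seqlens are
# simply multiples of max_seqlen.
batch_size, max_seqlen = 4, 6
cu_seqlens = torch.arange(
    0,
    (batch_size + 1) * max_seqlen,
    step=max_seqlen,
    dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  6, 12, 18, 24], dtype=torch.int32)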
@@ -2030,39 +2028,33 @@ def forward(
                 global _cu_seqlens_q, _cu_seqlens_kv
                 if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
                     # use cu_seqlens when both cu_seqlens and attention_mask are present
-                    if self.layer_number == 1:
-                        _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
+                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
                 elif attention_mask is not None:
                     if self.attention_type == "self":
-                        if self.layer_number == 1:
-                            _cu_seqlens_q = get_cu_seqlens(attention_mask)
-                            _cu_seqlens_kv = _cu_seqlens_q
+                        _cu_seqlens_q = get_cu_seqlens(attention_mask)
+                        _cu_seqlens_kv = _cu_seqlens_q
                     else:
-                        if self.layer_number == 1:
-                            _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
-                            _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
+                        _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
+                        _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
                 else:
                     raise Exception("Please provide attention_mask or cu_seqlens for padding!")
                 cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
             else:
-                if self.layer_number == 1:
-                    if cu_seqlens_q is None:
-                        cu_seqlens_q = torch.arange(
-                            0,
-                            (batch_size + 1) * max_seqlen_q,
-                            step=max_seqlen_q,
-                            dtype=torch.int32,
-                            device=query_layer.device)
-                    if cu_seqlens_kv is None:
-                        cu_seqlens_kv = torch.arange(
-                            0,
-                            (batch_size + 1) * max_seqlen_kv,
-                            step=max_seqlen_kv,
-                            dtype=torch.int32,
-                            device=key_layer.device)
-                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
-                else:
-                    cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
+                if cu_seqlens_q is None:
+                    cu_seqlens_q = torch.arange(
+                        0,
+                        (batch_size + 1) * max_seqlen_q,
+                        step=max_seqlen_q,
+                        dtype=torch.int32,
+                        device=query_layer.device)
+                if cu_seqlens_kv is None:
+                    cu_seqlens_kv = torch.arange(
+                        0,
+                        (batch_size + 1) * max_seqlen_kv,
+                        step=max_seqlen_kv,
+                        dtype=torch.int32,
+                        device=key_layer.device)
+                _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
 
             qkv_dtype = TE_DType[query_layer.dtype]

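Why the old layer_number == 1 guard broke pipeline parallelism: each pipeline stage owns only a contiguous slice of the layers, so any stage other than the first never sees layer 1 and the cached globals were never filled there. A minimal sketch of that situation; the even split below is an assumption for illustration, not the exact partitioning logic used by the framework.

# Illustrative only: which pipeline stage would hold layer 1 under an
# assumed even split of layers across stages.
num_layers, pipeline_stages = 8, 2
layers_per_stage = num_layers // pipeline_stages
for stage in range(pipeline_stages):
    layers = range(stage * layers_per_stage + 1, (stage + 1) * layers_per_stage + 1)
    print(f"stage {stage}: layers {list(layers)}, has layer 1: {1 in layers}")
# stage 0: layers [1, 2, 3, 4], has layer 1: True
# stage 1: layers [5, 6, 7, 8], has layer 1: False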