
Commit 282e004

fix thd fwd output shape for FlashAttention and add backend info for DPA
Signed-off-by: Charlene Yang <[email protected]>
1 parent ddc7908 commit 282e004

File tree

1 file changed: +7 additions, −0 deletions

transformer_engine/pytorch/attention.py

Lines changed: 7 additions & 0 deletions
@@ -1435,6 +1435,9 @@ def forward(
         elif qkv_format == 'bshd':
             # (bs)hd -> bs(hd)
             output = output.view(batch_size, max_seqlen_q, -1).contiguous()
+        elif qkv_format == 'thd':
+            # thd -> t(hd)
+            output = output.view(output.shape[0], -1).contiguous()
 
         return output
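
A minimal sketch (not part of the commit) of what the new 'thd' branch does: in the packed, variable-sequence-length 'thd' layout the FlashAttention output has shape [total_tokens, num_heads, head_dim], so it is flattened to [total_tokens, hidden_size] rather than reshaped to a padded [batch, seq, hidden] tensor as in the 'bshd' branch. The toy sizes below are assumptions for illustration only.

import torch

total_tokens, num_heads, head_dim = 7, 2, 4   # toy sizes, assumed for illustration
out_thd = torch.randn(total_tokens, num_heads, head_dim)

# thd -> t(hd): collapse the head and head-dim axes, mirroring the added branch
flat = out_thd.view(out_thd.shape[0], -1).contiguous()
assert flat.shape == (total_tokens, num_heads * head_dim)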

@@ -2299,6 +2302,7 @@ def forward(
             and is_backend_avail)
 
         if use_flash_attention:
+            print("[DotProductAttention]: using flash-attn",_flash_attn_version)
             return self.flash_attention(query_layer,
                                         key_layer,
                                         value_layer,
@@ -2316,6 +2320,8 @@ def forward(
             ), "Context parallelism is only implemented with Flash Attention!"
 
         if use_fused_attention:
+            print("[DotProductAttention]: using cuDNN fused attention (backend "
+                  + str(int(fused_attention_backend)) + ")")
             if checkpoint_core_attention:
                 return self._checkpointed_attention_forward(self.fused_attention,
                                                             query_layer,
@@ -2341,6 +2347,7 @@ def forward(
                                core_attention_bias = core_attention_bias,
                                fast_zero_fill = fast_zero_fill)
 
+        print("[DotProductAttention]: using unfused DPA")
         if checkpoint_core_attention:
             return self._checkpointed_attention_forward(
                 self.unfused_attention,
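
A minimal sketch (an assumption, not the TransformerEngine implementation) of the backend fall-through that the added prints make visible: the forward pass tries flash-attn first, then the cuDNN fused backend, and otherwise reports the unfused path. The helper name log_backend and its arguments are hypothetical.

def log_backend(use_flash_attention: bool,
                use_fused_attention: bool,
                fused_attention_backend: int = 0,
                flash_attn_version: str = "2.x") -> str:
    """Print the same messages the commit adds and return the chosen path."""
    if use_flash_attention:
        print("[DotProductAttention]: using flash-attn", flash_attn_version)
        return "flash"
    if use_fused_attention:
        print("[DotProductAttention]: using cuDNN fused attention (backend "
              + str(int(fused_attention_backend)) + ")")
        return "fused"
    print("[DotProductAttention]: using unfused DPA")
    return "unfused"

# Example: a run that selected the cuDNN fused backend with backend id 1
log_backend(use_flash_attention=False, use_fused_attention=True,
            fused_attention_backend=1)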
