@@ -1435,6 +1435,9 @@ def forward(
        elif qkv_format == 'bshd':
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()
+       elif qkv_format == 'thd':
+           # thd -> t(hd)
+           output = output.view(output.shape[0], -1).contiguous()

        return output
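For context, a minimal sketch (not part of the change) of what the new 'thd' branch does: it collapses the head and head-dim axes of a packed [total_tokens, heads, head_dim] tensor into [total_tokens, heads * head_dim]. The tensor sizes below are illustrative only.

import torch

# Illustrative sizes: 6 packed tokens, 2 heads, head dim 4 (thd layout).
output = torch.randn(6, 2, 4)
# thd -> t(hd): merge the head and head-dim axes, as the added branch does.
merged = output.view(output.shape[0], -1).contiguous()
print(merged.shape)  # torch.Size([6, 8])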
@@ -2299,6 +2302,7 @@ def forward(
                    and is_backend_avail)

        if use_flash_attention:
+            print("[DotProductAttention]: using flash-attn", _flash_attn_version)
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
@@ -2316,6 +2320,8 @@ def forward(
        ), "Context parallelism is only implemented with Flash Attention!"

        if use_fused_attention:
+            print("[DotProductAttention]: using cuDNN fused attention (backend "
+                  + str(int(fused_attention_backend)) + ")")
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(self.fused_attention,
                                                            query_layer,
@@ -2341,6 +2347,7 @@ def forward(
                                        core_attention_bias=core_attention_bias,
                                        fast_zero_fill=fast_zero_fill)

+        print("[DotProductAttention]: using unfused DPA")
        if checkpoint_core_attention:
            return self._checkpointed_attention_forward(
                self.unfused_attention,
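Taken together, the three added prints trace the backend dispatch order in DotProductAttention.forward: flash-attn is tried first, then cuDNN fused attention, then the unfused fallback. A standalone sketch of that cascade, using the flag names from the diff (the function itself is illustrative, not the TransformerEngine implementation):

def dispatch_attention(use_flash_attention: bool, use_fused_attention: bool) -> str:
    """Illustrative only: mirrors the order in which the debug prints above fire."""
    if use_flash_attention:
        print("[DotProductAttention]: using flash-attn")
        return "flash"
    if use_fused_attention:
        print("[DotProductAttention]: using cuDNN fused attention")
        return "fused"
    print("[DotProductAttention]: using unfused DPA")
    return "unfused"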