skip thd + CP + fused attn tests for cuDNN 9.6+ due to different stats shapes

Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Dec 20, 2024
1 parent 16a9d04 commit bb458ee
Showing 2 changed files with 7 additions and 0 deletions.
tests/pytorch/fused_attn/test_fused_attn_with_cp.py (2 additions, 0 deletions)
@@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     if dtype == "fp8" and get_device_compute_capability() < (9, 0):
         pytest.skip("FP8 attention is only supported on sm90+!")
+    if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
+        pytest.skip("THD format is not supported for cuDNN 9.6+!")
 
     config = model_configs_fused_attn[model]
     if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
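The skip compares a (major, minor, patch) tuple against (9, 6, 0). As a point of reference only, the sketch below shows one way such a tuple can be derived from the cuDNN build number reported by PyTorch; it is an illustration, not necessarily how the test suite's get_cudnn_version helper is implemented, and it assumes the cuDNN 9.x integer encoding major*10000 + minor*100 + patch (e.g. 90600 for 9.6.0).

import torch


def cudnn_version_tuple():
    """Illustrative only: build a (major, minor, patch) tuple comparable to (9, 6, 0)."""
    v = torch.backends.cudnn.version()  # e.g. 90600 for cuDNN 9.6.0; None if cuDNN is unavailable
    if v is None:
        return (0, 0, 0)
    return (v // 10000, (v % 10000) // 100, v % 100)


# Usage mirroring the new skip condition:
# if qkv_format == "thd" and cudnn_version_tuple() >= (9, 6, 0): skip the test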
transformer_engine/pytorch/attention.py (5 additions, 0 deletions)
@@ -602,6 +602,11 @@ def get_attention_backend(
                 "Disabling FusedAttention as it does not support context parallelism with MLA"
             )
             use_fused_attention = False
+        elif cudnn_version >= (9, 6, 0) and qkv_format == "thd":
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with THD for cuDNN 9.6+"
+            )
+            use_fused_attention = False
 
     # Filter: Attention mask
     # attn_mask_type | attention_mask | supported backends
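For readers tracing backend selection, the following sketch isolates just this new filter. The function name and signature are hypothetical and not part of Transformer Engine's API; cudnn_version is assumed to be the same (major, minor, patch) tuple used inside get_attention_backend.

def cp_thd_filter_allows_fused_attention(cudnn_version, qkv_format):
    """Hypothetical illustration (not library code) of the filter added above.

    Returning False means this particular context-parallelism check would
    disable FusedAttention; other filters in get_attention_backend can still
    disable it for unrelated reasons.
    """
    if qkv_format == "thd" and cudnn_version >= (9, 6, 0):
        return False
    return True


assert cp_thd_filter_allows_fused_attention((9, 6, 0), "thd") is False   # falls back on 9.6+
assert cp_thd_filter_allows_fused_attention((9, 5, 1), "thd") is True    # older cuDNN unaffected
assert cp_thd_filter_allows_fused_attention((9, 6, 0), "bshd") is True   # other layouts unaffected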
