diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
index 1007d6aa34..fd8e543adc 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     if dtype == "fp8" and get_device_compute_capability() < (9, 0):
         pytest.skip("FP8 attention is only supported on sm90+!")
+    if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
+        pytest.skip("THD format is not supported for cuDNN 9.6+!")
 
     config = model_configs_fused_attn[model]
     if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 6bfcde85dc..5c2618b559 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -602,6 +602,11 @@ def get_attention_backend(
                 "Disabling FusedAttention as it does not support context parallelism with MLA"
             )
             use_fused_attention = False
+        elif cudnn_version >= (9, 6, 0) and qkv_format == "thd":
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with THD for cuDNN 9.6+"
+            )
+            use_fused_attention = False
 
     # Filter: Attention mask
     #     attn_mask_type              |     attention_mask             |     supported backends
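
For context, a minimal sketch of the tuple-based version gate both hunks rely on. The `supports_thd_with_context_parallelism` helper and the stubbed `get_cudnn_version` below are illustrative only and are not part of this patch (the real helper in Transformer Engine queries the cuDNN runtime); the point is that comparing `(major, minor, patch)` tuples gives the lexicographic ordering the checks above depend on.

```python
def get_cudnn_version():
    # Stub for illustration only; the real Transformer Engine helper
    # returns the runtime cuDNN version as a (major, minor, patch) tuple.
    return (9, 6, 0)


def supports_thd_with_context_parallelism(qkv_format: str) -> bool:
    """Hypothetical mirror of the new filter: reject THD + CP on cuDNN 9.6+."""
    if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
        return False
    return True


if __name__ == "__main__":
    # Tuple comparison is lexicographic: (9, 5, 1) < (9, 6, 0) <= (9, 6, 1).
    print(supports_thd_with_context_parallelism("thd"))   # False on cuDNN >= 9.6
    print(supports_thd_with_context_parallelism("bshd"))  # True
```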