From 535af66e7195d224ab47af769aca2df6a563700a Mon Sep 17 00:00:00 2001
From: Md Fahim Faysal Khan
Date: Sun, 27 Oct 2024 01:22:22 -0700
Subject: [PATCH 1/2] exposed cp params to DPA api

Signed-off-by: Md Fahim Faysal Khan
---
 tests/jax/test_distributed_fused_attn.py   |  3 ---
 transformer_engine/jax/attention.py        |  6 ------
 transformer_engine/jax/flax/transformer.py | 15 +++++++++++++++
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 23a26087d4..d34afb340c 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -133,7 +133,6 @@ def test_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -438,7 +436,6 @@ def test_contex_parallel_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            cp_size > 1,
         ):
             pytest.skip(f"No FusedAttn backend found")

diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index b3b11bb9dd..218fd174de 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
-    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
     if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
         return False

-    # For context parallel need to check additional masking types
-    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
-        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
-            return False
-
     return True


diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index b91584219f..cb71188221 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

     @nn.compact
     def __call__(
@@ -308,6 +310,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
             """kvpacked format, treat
@@ -331,6 +335,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
             if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates whether the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.

     Optimization parameters
     -----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

     @nn.compact
     def __call__(
@@ -614,6 +627,8 @@ def __call__(
             transpose_batch_sequence=self.transpose_batch_sequence,
             qkv_layout=qkv_layout,
             window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
         )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

         return x

From 0b302d4ee250026eb7734c96f570689f772d9628 Mon Sep 17 00:00:00 2001
From: Michael Goldfarb
Date: Wed, 30 Oct 2024 20:34:08 +0000
Subject: [PATCH 2/2] Clean up checks in unit test.

Signed-off-by: Michael Goldfarb
---
 tests/jax/test_distributed_fused_attn.py | 41 +++++++++++++++---------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index d34afb340c..4ef56b1b10 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -423,21 +423,32 @@ def test_contex_parallel_self_attn(
         num_kv_heads = num_head // kv_groups
         scaling_factor = 1.0 / np.sqrt(num_head)

-        if not is_fused_attn_kernel_available(
-            dtype,
-            dtype,
-            qkv_layout,
-            attn_bias_type,
-            attn_mask_type,
-            dropout_prob,
-            num_head,
-            num_kv_heads,
-            seqlen,
-            seqlen,
-            hidden,
-            None,  # no window
-        ):
-            pytest.skip(f"No FusedAttn backend found")
+        def check_has_backend_for_mask(mask_type):
+            return is_fused_attn_kernel_available(
+                dtype,
+                dtype,
+                qkv_layout,
+                attn_bias_type,
+                mask_type,
+                dropout_prob,
+                num_head,
+                num_kv_heads,
+                seqlen,
+                seqlen,
+                hidden,
+                None,
+            )  # no SWA for CP
+
+        # For causal masking we depend on having bottom right support also.
+        # The API does not check this and instead we rely on lower level checks to raise
+        # an exception if the selected backend is not supported. This was a deliberate API
+        # decision to keep the CP size or flag out of the function.
+        has_backend = check_has_backend_for_mask(attn_mask_type)
+        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
+
+        if not has_backend:
+            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

         if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
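
A minimal, hypothetical usage sketch of the arguments exposed by PATCH 1/2 (not part of either patch). Only context_parallel_causal_load_balanced and context_parallel_axis come from the diff above; the remaining constructor fields, tensor shapes, and the "cp" mesh-axis name are illustrative assumptions that may not match a given TransformerEngine version.

# Hypothetical sketch: only the two context-parallel arguments are taken from the
# patch; other field names, sizes, and the "cp" axis name are assumptions.
import jax
import jax.numpy as jnp

from transformer_engine.jax.flax.transformer import DotProductAttention

attn = DotProductAttention(
    head_dim=64,                     # assumed sizes for illustration
    num_attention_heads=16,
    transpose_batch_sequence=False,  # inputs below are (batch, seqlen, heads, head_dim)
    # Parameters newly exposed by PATCH 1/2:
    context_parallel_causal_load_balanced=True,  # sequences pre-reordered for causal CP load balancing
    context_parallel_axis="cp",                  # assumed name of the context-parallel mesh axis
)

# Dummy (batch, seqlen, heads, head_dim) tensors; in a real run these would be
# sharded along the "cp" axis of the surrounding jax.sharding.Mesh.
q = k = v = jnp.zeros((2, 1024, 16, 64), jnp.float32)
variables = attn.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
out = attn.apply(variables, q, k, v, deterministic=True)

As the comment added in PATCH 2/2 notes, is_fused_attn_kernel_available() deliberately no longer takes a context-parallel flag; callers that run context parallelism with a causal mask are expected to additionally query the bottom-right causal variant themselves, as the updated test does.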