
Commit fab62cc

kocchop authored and mgoldfarb-nvidia committed
[JAX] Expose context parallel params to jax DPA api (NVIDIA#1292)
Exposed context parallel params to DPA api

Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>

---------

Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>
Co-authored-by: Michael Goldfarb <[email protected]>
1 parent 5fac3ee commit fab62cc

File tree

3 files changed: +41 −24 lines changed

tests/jax/test_distributed_fused_attn.py (26 additions, 18 deletions)

@@ -133,7 +133,6 @@ def test_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -425,22 +423,32 @@ def test_contex_parallel_self_attn(
         num_kv_heads = num_head // kv_groups
         scaling_factor = 1.0 / np.sqrt(num_head)

-        if not is_fused_attn_kernel_available(
-            dtype,
-            dtype,
-            qkv_layout,
-            attn_bias_type,
-            attn_mask_type,
-            dropout_prob,
-            num_head,
-            num_kv_heads,
-            seqlen,
-            seqlen,
-            hidden,
-            None,  # no window
-            cp_size > 1,
-        ):
-            pytest.skip(f"No FusedAttn backend found")
+        def check_has_backend_for_mask(mask_type):
+            return is_fused_attn_kernel_available(
+                dtype,
+                dtype,
+                qkv_layout,
+                attn_bias_type,
+                attn_mask_type,
+                dropout_prob,
+                num_head,
+                num_kv_heads,
+                seqlen,
+                seqlen,
+                hidden,
+                None,
+            )  # no SWA for CP
+
+        # For causal masking we depend on having bottom right support also.
+        # The API does not check this and instead we rely on lower level checks to raise
+        # and exception if the step backend is not supported. This was a deliberate API
+        # decision to keep the CP size or flag out of the function.
+        has_backend = check_has_backend_for_mask(attn_mask_type)
+        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
+
+        if not has_backend:
+            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

         if dp_size > 1 and batch % dp_size != 0:
             pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

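Aside on the two mask types the updated test distinguishes: CAUSAL_MASK anchors the causal diagonal at the top-left of the score matrix, while CAUSAL_BOTTOM_RIGHT_MASK anchors it at the bottom-right, which only matters once the key/value length a query chunk attends to differs from the query length. The sketch below is purely illustrative and not part of this commit; it assumes NumPy only.

```python
import numpy as np

def causal_mask(q_len, kv_len):
    # Top-left aligned (CAUSAL_MASK): query i attends to keys j <= i.
    i = np.arange(q_len)[:, None]
    j = np.arange(kv_len)[None, :]
    return j <= i

def causal_bottom_right_mask(q_len, kv_len):
    # Bottom-right aligned (CAUSAL_BOTTOM_RIGHT_MASK): the diagonal ends at the
    # bottom-right corner, so query i attends to keys j <= i + (kv_len - q_len).
    i = np.arange(q_len)[:, None]
    j = np.arange(kv_len)[None, :]
    return j <= i + (kv_len - q_len)

# With equal lengths the two coincide; they differ once kv_len > q_len,
# e.g. when a local query chunk attends to a longer run of keys.
assert (causal_mask(4, 4) == causal_bottom_right_mask(4, 4)).all()
print(causal_bottom_right_mask(2, 4).astype(int))
# [[1 1 1 0]
#  [1 1 1 1]]
```

With equal query and key lengths the two alignments coincide, which is presumably why the extra backend check above is only added for the context-parallel case (cp_size > 1).
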
transformer_engine/jax/attention.py (0 additions, 6 deletions)

@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
    kv_max_seqlen,
    head_dim,
    window_size: Optional[Tuple[int, int]] = None,
-    is_context_parallel: bool = False,
):
    """
    To check whether the fused attention kernel is supported

@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
        return False

-    # For context parallel need to check additional masking types
-    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
-        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
-            return False
-
    return True

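With `is_context_parallel` removed from `is_fused_attn_kernel_available`, a caller that wants the old behaviour has to query the bottom-right causal variant itself, exactly as the updated test above does. A minimal caller-side sketch; the wrapper name and its argument list are hypothetical and not part of this commit:

```python
from transformer_engine.jax.attention import (
    AttnMaskType,
    is_fused_attn_kernel_available,
)

def fused_attn_available_for_cp(
    dtype, qkv_layout, attn_bias_type, attn_mask_type, dropout_prob,
    num_head, num_kv_heads, seqlen, hidden, cp_size,
):
    # Hypothetical helper mirroring the updated test: check the requested mask
    # type, and additionally the bottom-right causal variant when running
    # context parallelism with a plain causal mask.
    def check(mask_type):
        return is_fused_attn_kernel_available(
            dtype,
            dtype,
            qkv_layout,
            attn_bias_type,
            mask_type,
            dropout_prob,
            num_head,
            num_kv_heads,
            seqlen,
            seqlen,
            hidden,
            None,  # no sliding window
        )

    available = check(attn_mask_type)
    if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        available &= check(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
    return available
```
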
transformer_engine/jax/flax/transformer.py (15 additions, 0 deletions)

@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

    @nn.compact
    def __call__(

@@ -308,6 +310,8 @@ def __call__(
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
                window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
            )
        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
            """kvpacked format, treat

@@ -331,6 +335,8 @@ def __call__(
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
                window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
            )
        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            if self.transpose_batch_sequence:

@@ -349,6 +355,8 @@ def __call__(
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
                window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
            )
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.

    Optimization parameters
    -----------------------

@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

    @nn.compact
    def __call__(

@@ -614,6 +627,8 @@ def __call__(
            transpose_batch_sequence=self.transpose_batch_sequence,
            qkv_layout=qkv_layout,
            window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
        )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

        return x

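For completeness, the newly exposed knobs can be set directly when constructing the Flax module. A minimal sketch, assuming head_dim/num_attention_heads style constructor arguments and a device mesh with a context-parallel axis named "cp" (both are assumptions for illustration, not confirmed by this commit):

```python
from transformer_engine.jax.flax import DotProductAttention

# The two context_parallel_* fields are the ones exposed by this commit; the
# remaining constructor arguments are illustrative assumptions about the
# module's other fields.
attn = DotProductAttention(
    head_dim=64,             # assumed field name
    num_attention_heads=16,  # assumed field name
    attention_dropout=0.0,
    transpose_batch_sequence=False,
    window_size=None,  # no sliding window ("no SWA for CP")
    context_parallel_causal_load_balanced=True,  # inputs pre-reordered for causal CP load balancing
    context_parallel_axis="cp",                  # assumed mesh axis name
)
```

Per the new docstring entries, `context_parallel_axis` is the name of the context parallel axis, so it should match the mesh axis the sequence dimension is sharded over, and `context_parallel_causal_load_balanced=True` indicates the sequences were already reordered for causal-mask load balancing.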