@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""
 
     @nn.compact
     def __call__(
@@ -308,6 +310,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
             """kvpacked format, treat
@@ -331,6 +335,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
             if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.
 
     Optimization parameters
     -----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""
 
     @nn.compact
     def __call__(
@@ -614,6 +627,8 @@ def __call__(
             transpose_batch_sequence=self.transpose_batch_sequence,
             qkv_layout=qkv_layout,
             window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
         )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
 
         return x
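
Note on the new flag: context_parallel_causal_load_balanced=True tells the module that the caller has already reordered the sequence so the causal-mask work is spread evenly across context-parallel ranks. A minimal sketch of one common reordering scheme is shown below (splitting the sequence into 2*CP chunks and pairing chunk i with chunk 2*CP-1-i); this illustrates the idea only, is not taken from this diff, and the helper name is hypothetical.

import numpy as np

def reorder_for_causal_load_balance(x: np.ndarray, cp_size: int, seq_dim: int = 1) -> np.ndarray:
    # Hypothetical helper (not from this diff): split the sequence into 2*cp_size
    # chunks and give rank i the pair (chunk i, chunk 2*cp_size - 1 - i), so every
    # context-parallel rank owns one "early" and one "late" slice of the causal mask.
    chunks = np.split(x, 2 * cp_size, axis=seq_dim)
    paired = []
    for i in range(cp_size):
        paired.append(chunks[i])
        paired.append(chunks[2 * cp_size - 1 - i])
    return np.concatenate(paired, axis=seq_dim)

# Example with batch=1, seqlen=8, CP=2: chunk order becomes [0, 3, 1, 2].
tokens = np.arange(8).reshape(1, 8)
print(reorder_for_causal_load_balance(tokens, cp_size=2))  # [[0 1 6 7 2 3 4 5]]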