     get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
+    CPStrategy,
 )
+from transformer_engine.jax.sharding import MeshResource
 
 # We will use the golden reference model from our non distributed attention test fixture.
 from test_fused_attn import general_dot_product_attention, make_mask
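The import hunk pulls in the new `CPStrategy` enum alongside `MeshResource`. This diff only exercises its `ALL_GATHER` and `RING` members; as a rough stand-in for readers (the actual definition lives in TransformerEngine's JAX attention module and may differ), it behaves like:

```python
from enum import Enum

class CPStrategy(Enum):
    """Stand-in sketch; mirrors only the members this diff uses."""

    ALL_GATHER = 0  # gather the full KV sequence onto every CP rank
    RING = 1        # circulate KV blocks between neighboring CP ranks
```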
@@ -333,6 +335,36 @@ def ref_func(query, kv, mask):
         )
 
 
+@pytest.mark.parametrize(
+    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+)
+@pytest.mark.parametrize(
+    "data_shape",
+    [
+        pytest.param([2, 512, 12, 128], id="2-512-12-128"),
+        pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
+    ],
+)
+@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
+@pytest.mark.parametrize(
+    "attn_mask_type",
+    [
+        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
+        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+    ],
+)
+@pytest.mark.parametrize("dtype", [jnp.bfloat16])
+@pytest.mark.parametrize(
+    "qkv_layout",
+    [
+        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
+        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+    ],
+)
+@pytest.mark.parametrize(
+    "load_balanced",
+    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
+)
 class TestDistributedContextParallelSelfAttn:
 
     def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
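Moving the `@pytest.mark.parametrize` stack from the test method up to the class means the same parameter grid now applies to every `test_*` method in `TestDistributedContextParallelSelfAttn`, which is what lets the two strategy-specific tests added further down share it without duplication. A minimal illustration of that pytest behavior (hypothetical names):

```python
import pytest

# Class-level parametrize fans out over every test method in the class,
# so a shared grid is declared once instead of once per method.
@pytest.mark.parametrize("balanced", [False, True])
class TestSharedGrid:
    def test_allgather_variant(self, balanced):
        assert balanced in (False, True)

    def test_ring_variant(self, balanced):
        assert isinstance(balanced, bool)
```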
@@ -370,37 +402,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
                 raise ValueError(f"Unsupported {qkv_layout=}")
         return qkv_args
 
-    @pytest.mark.parametrize(
-        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
-    )
-    @pytest.mark.parametrize(
-        "data_shape",
-        [
-            pytest.param([2, 512, 12, 128], id="2-512-12-128"),
-            pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
-        ],
-    )
-    @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
-    @pytest.mark.parametrize(
-        "attn_mask_type",
-        [
-            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
-            pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
-        ],
-    )
-    @pytest.mark.parametrize("dtype", [jnp.bfloat16])
-    @pytest.mark.parametrize(
-        "qkv_layout",
-        [
-            pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
-            pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
-        ],
-    )
-    @pytest.mark.parametrize(
-        "load_balanced",
-        [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
-    )
-    def test_contex_parallel_self_attn(
+    def impl_test_contex_parallel_attn(
         self,
         device_count,
         mesh_shape,
@@ -412,6 +414,7 @@ def test_contex_parallel_self_attn(
         dtype,
         qkv_layout,
         load_balanced,
+        cp_strategy,
     ):
         attn_bias_type = AttnBiasType.NO_BIAS
         dropout_prob = 0.0
@@ -469,6 +472,7 @@ def target_func(q, k, v, mask):
                 scaling_factor=scaling_factor,
                 dropout_probability=dropout_prob,
                 is_training=is_training,
+                context_parallel_strategy=cp_strategy,
                 context_parallel_causal_load_balanced=load_balanced,
                 context_parallel_axis="cp",
             ).astype(dtype)
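The only change inside `target_func` is threading the new `context_parallel_strategy` keyword through to `fused_attn`. As a hedged sketch of the choice being toggled (the helper below is made up, the import path is assumed to match the test's import block, and the trade-off described is the usual characterization of all-gather vs. ring context parallelism, not something this diff states):

```python
from transformer_engine.jax.attention import CPStrategy

def pick_cp_strategy(use_ring: bool) -> CPStrategy:
    # ALL_GATHER materializes the full KV sequence on every CP rank:
    # simple, but KV memory scales with the global sequence length.
    # RING streams KV blocks around the CP ranks, keeping per-rank KV
    # at the local shard size at the cost of more communication steps.
    return CPStrategy.RING if use_ring else CPStrategy.ALL_GATHER
```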
@@ -574,6 +578,60 @@ def grad_func(func, *args, **kwargs):
 
             assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
 
+    def test_contex_parallel_allgather_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        return self.impl_test_contex_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.ALL_GATHER,
+        )
+
+    def test_context_parallel_ring_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        return self.impl_test_contex_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.RING,
+        )
+
 
 class TestReorderCausalLoadBalancing:
     @pytest.mark.parametrize("cp_size", [2, 4, 8])
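Splitting the strategies into two named test methods, rather than adding `cp_strategy` as one more parametrize axis, gives each strategy its own test node ID (e.g. `test_context_parallel_ring_attn[...]`), so a single strategy can be selected or deselected from the command line with pytest's `-k` filter.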