From d3cbccdf9e98a5b3eb756a61e5c1a744b6daf06f Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Wed, 4 Dec 2024 09:52:54 -0600 Subject: [PATCH] [JAX] Scale sequence length in CP tests to avoid tiny sizes. (#1347) Scale sequence length in CP tests to avoid tiny sizes. Signed-off-by: Michael Goldfarb --- tests/jax/test_distributed_fused_attn.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 7ef0d68474..e194a228d2 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -341,8 +341,9 @@ def ref_func(query, kv, mask): @pytest.mark.parametrize( "data_shape", [ - pytest.param([2, 512, 12, 128], id="2-512-12-128"), - pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"), + pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ], ) @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) @@ -423,6 +424,12 @@ def impl_test_contex_parallel_attn( qkv_format = get_qkv_format(qkv_layout) batch, seqlen, num_head, hidden = data_shape + + # Scale the sequence length by 2*CP so its never too small as we scale up test. + # 2*CP is used since we split into two CP groups for load balancing. + seqlen = seqlen * cp_size * 2 + data_shape = batch, seqlen, num_head, hidden + num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head)