Skip to content

Commit

Permalink
Merge branch 'main' into debug-mcore-test
Browse files Browse the repository at this point in the history
  • Loading branch information
timmoon10 authored Dec 5, 2024
2 parents 101ec84 + d3cbccd commit 07a5d1e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,9 @@ def ref_func(query, kv, mask):
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
Expand Down Expand Up @@ -423,6 +424,12 @@ def impl_test_contex_parallel_attn(
qkv_format = get_qkv_format(qkv_layout)

batch, seqlen, num_head, hidden = data_shape

# Scale the sequence length by 2*CP so its never too small as we scale up test.
# 2*CP is used since we split into two CP groups for load balancing.
seqlen = seqlen * cp_size * 2
data_shape = batch, seqlen, num_head, hidden

num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

Expand Down

0 comments on commit 07a5d1e

Please sign in to comment.