Skip to content

Commit 0ae1951

Browse files
authored
Merge branch 'main' into fused_attn/graph_api_v1
Signed-off-by: cyanguwa <[email protected]>
2 parents 282e004 + 4d444db commit 0ae1951

File tree

6 files changed

+270
-214
lines changed

6 files changed

+270
-214
lines changed

tests/jax/test_distributed_fused_attn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def target_func(q, kv, mask):
197197
return jnp.mean(
198198
cross_fused_attn(q,
199199
kv,
200+
None,
200201
mask,
201202
None,
202203
attn_bias_type=attn_bias_type,

tests/jax/test_fused_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
163163
# mask invert
164164
mask = (mask == 0)
165165

166-
return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)
166+
return cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs)
167167

168168

169169
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)

0 commit comments

Comments
 (0)