
Merge branch 'main' into swa_left_brcm_padding
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa authored Dec 18, 2024
2 parents cdf5d56 + f033498 commit 0581953
Showing 11 changed files with 356 additions and 249 deletions.
2 changes: 1 addition & 1 deletion tests/jax/conftest.py
@@ -20,7 +20,7 @@ def clear_live_arrays():


 @pytest.fixture(autouse=True, scope="module")
-def enable_fused_attn():
+def enable_fused_attn_after_hopper():
     """
     Enable fused attn for hopper+ arch.
     Fused attn kernels on pre-hopper arch are not deterministic.
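The rename makes the fixture's intent explicit: fused attention is only force-enabled on Hopper (sm90) and newer, since the pre-Hopper kernels are not deterministic. A minimal sketch of such a fixture, assuming a hypothetical `arch_at_least_hopper()` check and the `NVTE_FUSED_ATTN` environment variable as the toggle:

```python
import os

import pytest


def arch_at_least_hopper() -> bool:
    """Hypothetical stand-in for however the suite queries the GPU arch.

    Hopper corresponds to compute capability 9.0 (sm90); this rough
    heuristic just inspects the JAX device kind.
    """
    import jax

    return "H100" in jax.devices()[0].device_kind


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """Enable fused attn for hopper+ arch only."""
    if arch_at_least_hopper():
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    # Sketch-level teardown: unset the flag after the module's tests run.
    os.environ.pop("NVTE_FUSED_ATTN", None)
```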
6 changes: 2 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -20,7 +20,6 @@
 from utils import (
     make_causal_mask,
     make_self_mask,
-    assert_tree_like_allclose,
     assert_allclose,
     print_debug_tensor_stats,
 )
@@ -32,7 +31,6 @@
     AttnMaskType,
     QKVLayout,
     QKVFormat,
-    get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
     CPStrategy,
@@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn(
         dropout_prob = 0.0
         is_training = True
         dp_size, cp_size, tp_size = mesh_shape
-        qkv_format = get_qkv_format(qkv_layout)
+        qkv_format = qkv_layout.get_qkv_format()

         batch, seqlen, num_head, hidden = data_shape

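The call-site change above reflects `get_qkv_format` moving from a module-level helper (dropped from the imports) to a method on the `QKVLayout` enum, so each layout carries its own tensor format. A sketch of the pattern with illustrative members only (the real enums live in `transformer_engine.jax.attention` and define more layouts):

```python
from enum import Enum


class QKVFormat(Enum):
    BSHD = 0
    SBHD = 1
    THD = 2


class QKVLayout(Enum):
    # Illustrative members: (unique id, underlying tensor format).
    BS3HD = (0, "BSHD")
    BSHD_BS2HD = (1, "BSHD")
    T3HD = (2, "THD")

    def get_qkv_format(self) -> "QKVFormat":
        # Each layout knows the memory format of its Q/K/V tensors, so
        # call sites write `qkv_layout.get_qkv_format()` instead of
        # importing a separate helper.
        return QKVFormat[self.value[1]]


assert QKVLayout.BS3HD.get_qkv_format() is QKVFormat.BSHD
```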
@@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
             _, max_seq_len, num_heads, _ = data_shape
             gradient_multiplier = max_seq_len * num_heads
-            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+            if attn_mask_type.is_causal():
                 gradient_multiplier /= 10
             ret_valid = func(*args, **kwargs)
             return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
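Similarly, the membership test against an explicit list of mask types is folded into an `is_causal()` method on `AttnMaskType`, so the definition of "causal" lives in one place. A sketch with an illustrative subset of members (the real enum may define additional causal variants):

```python
from enum import Enum


class AttnMaskType(Enum):
    # Illustrative subset of mask types.
    NO_MASK = 0
    PADDING_MASK = 1
    CAUSAL_MASK = 2
    CAUSAL_BOTTOM_RIGHT_MASK = 3

    def is_causal(self) -> bool:
        # Call sites no longer enumerate the causal members themselves.
        return self in (
            AttnMaskType.CAUSAL_MASK,
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        )


assert AttnMaskType.CAUSAL_MASK.is_causal()
assert not AttnMaskType.PADDING_MASK.is_causal()
```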