Merge branch 'main' into jax_multi_test
phu0ngng committed Dec 18, 2024
2 parents 1b961b0 + f033498 commit 518d071
Showing 23 changed files with 730 additions and 284 deletions.
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -42,6 +42,7 @@ jobs:
|| github.actor == 'kocchop'
|| github.actor == 'youngeunkwon0405'
|| github.actor == 'KshitijLakhani'
+|| github.actor == 'jberchtold-nvidia'
)
steps:
- name: Check if comment is issued by authorized person
4 changes: 2 additions & 2 deletions examples/pytorch/comm_gemm_overlap/README.md
@@ -16,7 +16,7 @@
Forward and backward passes with layer weights distributed over all GPUs in a single node.

```bash
-$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
+$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py

# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
@@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicating the model across
groups in a single node.

```bash
-$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2
+$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2

# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
2 changes: 1 addition & 1 deletion tests/jax/conftest.py
@@ -20,7 +20,7 @@ def clear_live_arrays():


@pytest.fixture(autouse=True, scope="module")
-def enable_fused_attn():
+def enable_fused_attn_after_hopper():
"""
Enable fused attn for hopper+ arch.
Fused attn kernels on pre-hopper arch are not deterministic.
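Only the fixture's signature and docstring appear in this hunk. For context, a minimal sketch of what a module-scoped fused-attention fixture can look like is shown below; it assumes the NVTE_FUSED_ATTN environment variable as the toggle and omits the architecture detection the repository's actual conftest.py would perform.

```python
# Illustrative sketch only, not the repository's conftest.py. It assumes
# NVTE_FUSED_ATTN is the environment switch for TransformerEngine's
# fused-attention backend and leaves out the Hopper (SM90+) check a real
# fixture would make before flipping it on.
import os

import pytest


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.
    """
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    os.environ.pop("NVTE_FUSED_ATTN", None)
```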
6 changes: 2 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -20,7 +20,6 @@
from utils import (
make_causal_mask,
make_self_mask,
-assert_tree_like_allclose,
assert_allclose,
print_debug_tensor_stats,
)
@@ -32,7 +31,6 @@
AttnMaskType,
QKVLayout,
QKVFormat,
-get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
@@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn(
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
-qkv_format = get_qkv_format(qkv_layout)
+qkv_format = qkv_layout.get_qkv_format()

batch, seqlen, num_head, hidden = data_shape

@@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
_, max_seq_len, num_heads, _ = data_shape
gradient_multiplier = max_seq_len * num_heads
-if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+if attn_mask_type.is_causal():
gradient_multiplier /= 10
ret_valid = func(*args, **kwargs)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
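The two call-site changes in this file lean on helpers defined on the enums themselves rather than free functions. The sketch below illustrates that pattern for `is_causal()`; the member list is abbreviated and it is an illustration of the refactor, not a copy of the `transformer_engine.jax` definitions. `qkv_layout.get_qkv_format()` follows the same idea, moving a former standalone helper onto the QKVLayout enum.

```python
# Abbreviated sketch of the enum-method pattern the updated test relies on.
# The real AttnMaskType in transformer_engine.jax defines more members; the
# method body simply mirrors the membership check the old code spelled out inline.
from enum import Enum, auto


class AttnMaskType(Enum):
    NO_MASK = auto()
    CAUSAL_MASK = auto()
    CAUSAL_BOTTOM_RIGHT_MASK = auto()

    def is_causal(self) -> bool:
        # Both causal variants count as causal, matching the replaced list check.
        return self in (AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)


assert AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK.is_causal()
assert not AttnMaskType.NO_MASK.is_causal()
```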