
Commit d346d9c

Merge branch 'NVIDIA:main' into fused_out_correction

2 parents: 89bbeb7 + 838345e
File tree

30 files changed: +950 -385 lines

.github/workflows/trigger-ci.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -42,6 +42,7 @@ jobs:
           || github.actor == 'kocchop'
           || github.actor == 'youngeunkwon0405'
           || github.actor == 'KshitijLakhani'
+          || github.actor == 'jberchtold-nvidia'
         )
     steps:
       - name: Check if comment is issued by authorized person
```

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 43 files

examples/pytorch/comm_gemm_overlap/README.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -16,7 +16,7 @@
 Forward and backward passes with layer weights distributed over all GPUs in a single node.
 
 ```bash
-$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
+$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py
 
 # Sample output on 8x H100s:
 # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
@@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across
 groups in a single node.
 
 ```bash
-$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2
+$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2
 
 # Sample output on 8x H100s:
 # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
````
Lines changed: 15 additions & 0 deletions

```diff
@@ -0,0 +1,15 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+set -xe
+
+: ${TE_PATH:=/opt/transformerengine}
+
+pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
+
+# Make encoder tests to have run-to-run deterministic to have the stable CI results
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
```
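The new script exports `--xla_gpu_deterministic_ops` so the multi-GPU encoder tests give stable, run-to-run reproducible CI results. The same effect can be had from inside a Python session instead of a shell script; a minimal sketch, using only the flag name shown in the script above (everything else is illustrative):

```python
# Minimal sketch: set the same determinism flag the QA script exports, from
# Python, before JAX is first imported (XLA reads XLA_FLAGS at startup).
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops"
)

import jax  # noqa: E402 -- must come after XLA_FLAGS is set to take effect

print(jax.devices())
```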

qa/L0_jax_unittest/test.sh

Lines changed: 1 addition & 2 deletions

```diff
@@ -20,5 +20,4 @@ pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
 
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
```

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
@@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
+NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
```

qa/L1_pytorch_distributed_unittest/test.sh

Lines changed: 1 addition & 0 deletions

```diff
@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
```

tests/jax/conftest.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -20,7 +20,7 @@ def clear_live_arrays():
 
 
 @pytest.fixture(autouse=True, scope="module")
-def enable_fused_attn():
+def enable_fused_attn_after_hopper():
     """
     Enable fused attn for hopper+ arch.
     Fused attn kernels on pre-hopper arch are not deterministic.
```
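The hunk only renames the fixture and shows its docstring; the body is not part of the diff. As a rough sketch of what a module-scoped autouse fixture with this contract could look like (the `NVTE_FUSED_ATTN` environment toggle and the architecture check are assumptions, not taken from the diff):

```python
import os

import pytest


def _is_hopper_or_newer() -> bool:
    """Hypothetical stand-in for the fixture's real architecture check."""
    return True  # assume Hopper+ so the sketch enables fused attention


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.
    """
    if _is_hopper_or_newer():
        os.environ["NVTE_FUSED_ATTN"] = "1"  # toggle assumed for illustration
    yield
    os.environ.pop("NVTE_FUSED_ATTN", None)  # clean up after the module's tests
```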

tests/jax/test_distributed_fused_attn.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -20,7 +20,6 @@
 from utils import (
     make_causal_mask,
     make_self_mask,
-    assert_tree_like_allclose,
     assert_allclose,
     print_debug_tensor_stats,
 )
@@ -32,7 +31,6 @@
     AttnMaskType,
     QKVLayout,
     QKVFormat,
-    get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
     CPStrategy,
@@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn(
         dropout_prob = 0.0
         is_training = True
         dp_size, cp_size, tp_size = mesh_shape
-        qkv_format = get_qkv_format(qkv_layout)
+        qkv_format = qkv_layout.get_qkv_format()
 
         batch, seqlen, num_head, hidden = data_shape
 
@@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs):
         # Gradient is small, use a gradient multiplier to amplify the gradient
         _, max_seq_len, num_heads, _ = data_shape
         gradient_multiplier = max_seq_len * num_heads
-        if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+        if attn_mask_type.is_causal():
             gradient_multiplier /= 10
         ret_valid = func(*args, **kwargs)
         return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
```
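Two of the hunks replace call sites rather than behavior: the free function `get_qkv_format(qkv_layout)` becomes the method `qkv_layout.get_qkv_format()`, and a membership test against a list of mask types becomes `attn_mask_type.is_causal()`. A minimal, self-contained sketch of that refactor pattern; the member lists are trimmed and the bodies are illustrative, not Transformer Engine's actual definitions:

```python
# Sketch of moving helpers onto the enums themselves (illustrative only).
from enum import Enum


class QKVFormat(Enum):
    BSHD = "bshd"
    THD = "thd"


class QKVLayout(Enum):
    BS3HD = "bs3hd"
    T3HD = "t3hd"

    def get_qkv_format(self) -> QKVFormat:
        # Derive the format from the layout name instead of an external lookup.
        return QKVFormat.THD if self.name.startswith("T") else QKVFormat.BSHD


class AttnMaskType(Enum):
    NO_MASK = "no_mask"
    CAUSAL_MASK = "causal"
    CAUSAL_BOTTOM_RIGHT_MASK = "causal_bottom_right"

    def is_causal(self) -> bool:
        # Replaces `mask in [CAUSAL_MASK, CAUSAL_BOTTOM_RIGHT_MASK]` at call sites.
        return self in (
            AttnMaskType.CAUSAL_MASK,
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
        )


assert QKVLayout.T3HD.get_qkv_format() is QKVFormat.THD
assert AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK.is_causal()
```

Attaching the predicate to the enum keeps call sites short, and any future causal-mask variant only needs to be registered in one place.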
