
Commit a6a55e4

Merge branch 'main' into activation-ops

2 parents: 5a091c0 + c0a539c

50 files changed: +1972, -660 lines. (Large commit: only part of the diff is shown below.)

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ jobs:
     name: 'PaddlePaddle'
     runs-on: ubuntu-latest
     container:
-      image: nvcr.io/nvidia/paddlepaddle:24.07-py3
+      image: nvcr.io/nvidia/paddlepaddle:24.10-py3
       options: --user root
     steps:
       - name: 'Checkout'

.github/workflows/trigger-ci.yml

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ jobs:
         || github.actor == 'pggPL'
         || github.actor == 'vasunvidia'
         || github.actor == 'erhoo82'
+        || github.actor == 'kocchop'
       )
     steps:
       - name: Check if comment is issued by authorized person

build_tools/paddle.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def setup_paddle_extension(
     # Source files
     csrc_source_files = Path(csrc_source_files)
     sources = [
-        csrc_source_files / "extensions.cu",
+        csrc_source_files / "extensions.cpp",
         csrc_source_files / "common.cpp",
         csrc_source_files / "custom_ops.cu",
     ]

build_tools/pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def setup_pytorch_extension(
     csrc_source_files = Path(csrc_source_files)
     extensions_dir = csrc_source_files / "extensions"
     sources = [
-        csrc_source_files / "common.cu",
+        csrc_source_files / "common.cpp",
         csrc_source_files / "ts_fp8_op.cpp",
     ] + all_files_in_dir(extensions_dir)


qa/L0_jax_unittest/test.sh

Lines changed: 2 additions & 0 deletions
@@ -18,5 +18,7 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

+# Make the encoder tests run-to-run deterministic for stable CI results
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
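
The added export appends --xla_gpu_deterministic_ops to whatever XLA_FLAGS the environment already carries, so earlier flags are preserved. For readers driving the encoder examples from Python rather than this shell script, a minimal sketch of the same idea follows (set the flag before JAX is first imported, since XLA reads XLA_FLAGS at initialization); only the flag name comes from the script, the rest is illustrative:

    import os

    # Append the determinism flag to any pre-existing XLA_FLAGS (illustrative sketch).
    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops"
    ).strip()

    import jax  # imported after XLA_FLAGS is set so XLA picks the flag up

    print(jax.devices())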

qa/L1_jax_distributed_unittest/test.sh

Lines changed: 8 additions & 1 deletion
@@ -5,4 +5,11 @@
 set -xe

 : ${TE_PATH:=/opt/transformerengine}
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
+
+# Skip ring attention tests since they need fixed environment vars
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
+
+# Test ring attention with and without scan loop
+NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
+NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
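
The two ring-attention invocations differ only in NVTE_FUSED_RING_ATTENTION_USE_SCAN and, for the scan variant, an extra XLA flag. A hedged Python sketch of the same loop is shown below; it assumes pytest is on PATH and reuses the TE_PATH convention from the script, but the wrapper itself is not part of the commit:

    import os
    import subprocess

    te_path = os.environ.get("TE_PATH", "/opt/transformerengine")
    test_file = f"{te_path}/tests/jax/test_distributed_fused_attn.py"

    # Run the ring-attention tests once without and once with the scan loop.
    for use_scan, extra_xla in ((0, ""), (1, "--xla_experimental_ignore_channel_id")):
        env = dict(os.environ)
        env["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = str(use_scan)
        if extra_xla:
            env["XLA_FLAGS"] = (env.get("XLA_FLAGS", "") + " " + extra_xla).strip()
        subprocess.run(
            ["pytest", "-c", f"{te_path}/tests/jax/pytest.ini", "-v", test_file,
             "-k", "test_context_parallel_ring_attn"],
            env=env,
            check=True,
        )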

tests/jax/test_custom_call_compute.py

Lines changed: 29 additions & 1 deletion
@@ -22,6 +22,7 @@
 from transformer_engine.jax.cpp_extensions.transpose import (
     _jax_transpose,
     _jax_cast_transpose,
+    _jax_dbias_cast_transpose,
 )
 from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
 from transformer_engine.jax import cpp_extensions as tex

@@ -504,7 +505,6 @@ def _prim_func_bwd(ctx, g):
                     scale_inv,
                     FP8Helper.BWD_DTYPE,
                     -1,
-                    -2,
                     self.activation_type,
                 )
             )

@@ -812,6 +812,34 @@ def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
         assert_tree_like_allclose(jax_output, ffi_output)
         assert_tree_like_allclose(noffi_output, ffi_output)

+    @pytest.mark.parametrize(
+        "out_dtype",
+        [
+            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
+            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
+        ],
+    )
+    def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
+        amax = jnp.zeros(1, jnp.float32)
+        scale = jnp.ones(1, jnp.float32)
+        scale_inv = jnp.ones(1, jnp.float32)
+        key = jax.random.PRNGKey(0)
+        input = jax.random.uniform(key, input_shape, in_dtype)
+        static_axis_boundary = -1
+        jax_output = _jax_dbias_cast_transpose(
+            input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
+        )
+        os.environ["NVTE_JAX_WITH_FFI"] = "0"
+        noffi_output = tex.dbias_cast_transpose(
+            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
+        )
+        os.environ["NVTE_JAX_WITH_FFI"] = "1"
+        ffi_output = tex.dbias_cast_transpose(
+            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
+        )
+        assert_tree_like_allclose(jax_output, ffi_output)
+        assert_tree_like_allclose(noffi_output, ffi_output)
+

 @pytest.mark.skipif(not is_fp8_supported, reason=reason)
 @pytest.mark.parametrize(
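
The new test toggles NVTE_JAX_WITH_FFI by writing os.environ directly, first to "0" for the non-FFI path and then to "1" for the FFI path, and compares both against the pure-JAX reference. If one wanted to make that toggle self-restoring, a small context manager along these lines would do; it is a sketch, not something the commit adds:

    import os
    from contextlib import contextmanager

    @contextmanager
    def nvte_jax_with_ffi(enabled: bool):
        """Temporarily set NVTE_JAX_WITH_FFI and restore the previous value on exit."""
        old = os.environ.get("NVTE_JAX_WITH_FFI")
        os.environ["NVTE_JAX_WITH_FFI"] = "1" if enabled else "0"
        try:
            yield
        finally:
            if old is None:
                os.environ.pop("NVTE_JAX_WITH_FFI", None)
            else:
                os.environ["NVTE_JAX_WITH_FFI"] = old

    # Usage sketch inside the test body:
    #   with nvte_jax_with_ffi(False):
    #       noffi_output = tex.dbias_cast_transpose(...)
    #   with nvte_jax_with_ffi(True):
    #       ffi_output = tex.dbias_cast_transpose(...)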

tests/jax/test_distributed_fused_attn.py

Lines changed: 89 additions & 31 deletions
@@ -35,7 +35,9 @@
     get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
+    CPStrategy,
 )
+from transformer_engine.jax.sharding import MeshResource

 # We will use the golden reference model from our non distributed attention test fixture.
 from test_fused_attn import general_dot_product_attention, make_mask

@@ -333,6 +335,36 @@ def ref_func(query, kv, mask):
     )


+@pytest.mark.parametrize(
+    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+)
+@pytest.mark.parametrize(
+    "data_shape",
+    [
+        pytest.param([2, 512, 12, 128], id="2-512-12-128"),
+        pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
+    ],
+)
+@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
+@pytest.mark.parametrize(
+    "attn_mask_type",
+    [
+        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
+        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+    ],
+)
+@pytest.mark.parametrize("dtype", [jnp.bfloat16])
+@pytest.mark.parametrize(
+    "qkv_layout",
+    [
+        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
+        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+    ],
+)
+@pytest.mark.parametrize(
+    "load_balanced",
+    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
+)
 class TestDistributedContextParallelSelfAttn:

     def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):

@@ -370,37 +402,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
             raise ValueError(f"Unsupported {qkv_layout=}")
         return qkv_args

-    @pytest.mark.parametrize(
-        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
-    )
-    @pytest.mark.parametrize(
-        "data_shape",
-        [
-            pytest.param([2, 512, 12, 128], id="2-512-12-128"),
-            pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
-        ],
-    )
-    @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
-    @pytest.mark.parametrize(
-        "attn_mask_type",
-        [
-            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
-            pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
-        ],
-    )
-    @pytest.mark.parametrize("dtype", [jnp.bfloat16])
-    @pytest.mark.parametrize(
-        "qkv_layout",
-        [
-            pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
-            pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
-        ],
-    )
-    @pytest.mark.parametrize(
-        "load_balanced",
-        [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
-    )
-    def test_contex_parallel_self_attn(
+    def impl_test_contex_parallel_attn(
         self,
         device_count,
         mesh_shape,

@@ -412,6 +414,7 @@ def test_contex_parallel_self_attn(
         dtype,
         qkv_layout,
         load_balanced,
+        cp_strategy,
     ):
         attn_bias_type = AttnBiasType.NO_BIAS
         dropout_prob = 0.0

@@ -469,6 +472,7 @@ def target_func(q, k, v, mask):
                 scaling_factor=scaling_factor,
                 dropout_probability=dropout_prob,
                 is_training=is_training,
+                context_parallel_strategy=cp_strategy,
                 context_parallel_causal_load_balanced=load_balanced,
                 context_parallel_axis="cp",
             ).astype(dtype)

@@ -574,6 +578,60 @@ def grad_func(func, *args, **kwargs):

         assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)

+    def test_contex_parallel_allgather_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        return self.impl_test_contex_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.ALL_GATHER,
+        )
+
+    def test_context_parallel_ring_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        return self.impl_test_contex_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.RING,
+        )
+

 class TestReorderCausalLoadBalancing:
     @pytest.mark.parametrize("cp_size", [2, 4, 8])
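
The parametrize stack moves from the single test method to the class, so the same device/mesh/shape/mask/dtype/layout grid now applies to every test method, including the two new strategy-specific wrappers that forward to impl_test_contex_parallel_attn with CPStrategy.ALL_GATHER or CPStrategy.RING. A minimal, self-contained illustration of that pytest pattern follows; the names here are made up, not from the commit:

    import pytest

    # Class-level parametrize: every test method in the class inherits the grid.
    @pytest.mark.parametrize("size", [2, 4])
    @pytest.mark.parametrize("balanced", [False, True])
    class TestSharedGrid:
        def _impl(self, size, balanced, strategy):
            # Shared body; the strategy argument is supplied by the wrappers below.
            assert strategy in ("all_gather", "ring")
            assert size in (2, 4) and isinstance(balanced, bool)

        def test_all_gather(self, size, balanced):
            self._impl(size, balanced, "all_gather")

        def test_ring(self, size, balanced):
            self._impl(size, balanced, "ring")

Splitting the strategies into separately named tests is what lets the CI script above select only the ring variant with -k test_context_parallel_ring_attn.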

tests/jax/test_fused_attn.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from functools import partial
 from math import sqrt
 from typing import Tuple, Optional
+import random

 import jax
 import jax.numpy as jnp

tests/jax/test_misc.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+from functools import partial
+import os
+
+from transformer_engine.jax.cpp_extensions.misc import get_xla_flag
+
+
+@pytest.fixture(autouse=True, scope="function")
+def preserve_xla_flags():
+    """Ensures the XLA flags environment variable is restored after any tests in this file run."""
+    old_flags = os.getenv("XLA_FLAGS")
+    yield
+    if old_flags is not None:
+        os.environ["XLA_FLAGS"] = old_flags
+
+
+def test_get_xla_flag(request):
+    os.environ["XLA_FLAGS"] = ""
+    assert get_xla_flag("") is None
+    assert get_xla_flag("--foo") is None
+    assert get_xla_flag("--bar=1") is None
+
+    os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
+    assert get_xla_flag("--foo") == True
+    assert get_xla_flag("--bar") == "1"
+    assert get_xla_flag("--bar", cast=int) == 1
+    assert get_xla_flag("--bar", cast=bool) == True
+    assert get_xla_flag("--baz") == "biz"
+    with pytest.raises(ValueError):
+        # cast will fail
+        assert get_xla_flag("--baz", cast=int)
+    assert get_xla_flag("--xla") is None
+
+    os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
+    assert get_xla_flag("--xla_abc") == True
+    assert get_xla_flag("--xla_abb") == True