From a3b32ec6cb15dac8dc96ae03e40f51dfd072f195 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Dec 2024 10:47:36 -0500 Subject: [PATCH] [JAX] Move parallel encoder tests to L0 distributed test set. (#1356) * Move test distributed encoder to L0 distributed test suit --------- Signed-off-by: Phuong Nguyen Co-authored-by: Reese Wang --- qa/L0_jax_distributed_unittest/test.sh | 15 +++++++++++++++ qa/L0_jax_unittest/test.sh | 3 +-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 qa/L0_jax_distributed_unittest/test.sh diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh new file mode 100644 index 0000000000..f9e16793a4 --- /dev/null +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -xe + +: ${TE_PATH:=/opt/transformerengine} + +pip install -r $TE_PATH/examples/jax/encoder/requirements.txt + +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index db3aa31951..278a3c8b44 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -20,5 +20,4 @@ pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py