From 99b6bd25249150a00bae4745d581ec29ecc6e50d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 6 Dec 2024 10:07:37 -0800 Subject: [PATCH] fix thread fork Signed-off-by: Phuong Nguyen --- .../encoder/test_multiprocessing_encoder.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 5f5571672c..5509d01e98 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -552,14 +552,28 @@ def encoder_parser(args): return parser.parse_args(args) -def unittest_query_gpu(): +def query_gpu(q): """Query GPU info on the system""" gpu_has_fp8 = is_fp8_available() gpu_has_bf16 = is_bf16_supported() # Avoid using len(jax.devices()) here as jax does not allow starting any jax computational graph # before calling jax.distributed.initialize() num_gpus = len(subprocess.check_output(["nvidia-smi", "-L"]).decode().strip().split("\n")) - return (num_gpus, gpu_has_fp8, gpu_has_bf16) + q.put([num_gpus, gpu_has_fp8, gpu_has_bf16]) + +def unittest_query_gpu(): + r""" + It is only used by TestEncoder. + The `jax.distributed.initialize` must be called before any other JAX or Flax API, + otherwise `jax.local_devices` will be incorrect. + Thus, fork another process to query number of GPUs and FP8 capability. + """ + q = mp.Queue() + p = mp.Process(target=query_gpu, args=(q,)) + p.start() + num_gpu, gpu_has_fp8, gpu_has_bf16 = q.get() + p.join() + return num_gpu, gpu_has_fp8, gpu_has_bf16 class TestEncoder(unittest.TestCase):