Skip to content

Commit

Permalink
fix thread fork
Browse files Browse the repository at this point in the history
Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng committed Dec 11, 2024
1 parent 47aa3fa commit 99b6bd2
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions examples/jax/encoder/test_multiprocessing_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,14 +552,28 @@ def encoder_parser(args):
return parser.parse_args(args)


def unittest_query_gpu():
def query_gpu(q):
    """Query GPU info on the system and publish it on ``q``.

    Written to be used as the ``target`` of a separate process, so the result
    is handed back through the queue instead of being returned.

    Args:
        q: Queue-like object exposing ``put``. Exactly one item is put: the
            list ``[num_gpus, gpu_has_fp8, gpu_has_bf16]``, where the last two
            entries are whatever ``is_fp8_available()`` and
            ``is_bf16_supported()`` return.

    Raises:
        FileNotFoundError: If the ``nvidia-smi`` executable cannot be found.
        subprocess.CalledProcessError: If ``nvidia-smi -L`` exits non-zero.
    """
    gpu_has_fp8 = is_fp8_available()
    gpu_has_bf16 = is_bf16_supported()
    # Avoid using len(jax.devices()) here as jax does not allow starting any jax computational graph
    # before calling jax.distributed.initialize()
    listing = subprocess.check_output(["nvidia-smi", "-L"]).decode()
    # Count only non-blank lines: splitting an empty listing on "\n" yields
    # [""], which would otherwise be reported as one GPU.
    num_gpus = sum(1 for line in listing.splitlines() if line.strip())
    # The result must go through the queue; returning early here would leave
    # the consumer waiting on an empty queue.
    q.put([num_gpus, gpu_has_fp8, gpu_has_bf16])

def unittest_query_gpu():
    r"""
    Query the number of GPUs and the FP8/BF16 capability from a child process.

    It is only used by TestEncoder.
    The `jax.distributed.initialize` must be called before any other JAX or Flax API,
    otherwise `jax.local_devices` will be incorrect.
    Thus, fork another process to query number of GPUs and FP8 capability.

    Returns:
        The tuple ``(num_gpu, gpu_has_fp8, gpu_has_bf16)`` as reported by
        `query_gpu` through the queue.

    Raises:
        RuntimeError: If the child process exits without reporting a result
            (e.g. `query_gpu` raised), instead of blocking forever on the queue.
    """
    # Local import keeps this block self-contained; `Empty` is what
    # `Queue.get(timeout=...)` raises when nothing arrives in time.
    from queue import Empty

    q = mp.Queue()
    p = mp.Process(target=query_gpu, args=(q,))
    p.start()
    try:
        while True:
            # Sample liveness *before* waiting: if the child was already dead
            # and the queue is still empty after the wait, no result is coming.
            child_exited = not p.is_alive()
            try:
                num_gpu, gpu_has_fp8, gpu_has_bf16 = q.get(timeout=1.0)
                break
            except Empty:
                if child_exited:
                    raise RuntimeError(
                        "GPU query subprocess exited with code "
                        f"{p.exitcode} without reporting a result"
                    ) from None
    finally:
        # Join only after draining the queue so the child's feeder thread
        # can never block on a full pipe.
        p.join()
    return num_gpu, gpu_has_fp8, gpu_has_bf16


class TestEncoder(unittest.TestCase):
Expand Down

0 comments on commit 99b6bd2

Please sign in to comment.