From 4b605ed38131227febaf6f56475c8928f14c28f2 Mon Sep 17 00:00:00 2001
From: sichu
Date: Wed, 27 Nov 2024 19:25:14 +0000
Subject: [PATCH] add spawn cuda process testing

---
 .../test_megatron_parallel_state_utils.py | 48 +++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/sub-packages/bionemo-testing/tests/bionemo/testing/test_megatron_parallel_state_utils.py b/sub-packages/bionemo-testing/tests/bionemo/testing/test_megatron_parallel_state_utils.py
index 3c25b4cf43..52cb2706ae 100644
--- a/sub-packages/bionemo-testing/tests/bionemo/testing/test_megatron_parallel_state_utils.py
+++ b/sub-packages/bionemo-testing/tests/bionemo/testing/test_megatron_parallel_state_utils.py
@@ -13,11 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
+import pytest
 import torch
 import torch.distributed as dist
+import torch.multiprocessing
 from megatron.core import parallel_state
 from nemo import lightning as nl
+from pytest import MonkeyPatch
 
 from bionemo.testing import megatron_parallel_state_utils
 
 
@@ -106,3 +110,47 @@ def test_reduce_scatter():
         dist.reduce_scatter(output_tensor, to_reduce_scatter)
 
         assert tuple(output_tensor.shape) == (3, 3)
+
+
+# def test_all_reduce_sum():
+#     with megatron_parallel_state_utils.mock_distributed_parallel_state(world_size=2, rank=1):
+#         tensor = torch.tensor([dist.get_rank() + 1])
+#         dist.all_reduce(tensor)
+#         assert tensor.item() == (1 + 2) / 2  # TODO does not work; there is no barrier for the actual communication; got 2
+
+
+# TODO: move dist_environment to src
+@contextmanager
+def dist_environment(
+    world_size: int = 1,
+    rank: int = 0,
+):
+    with MonkeyPatch.context() as context:
+        torch.cuda.empty_cache()
+        parallel_state.destroy_model_parallel()
+
+        context.setenv("MASTER_ADDR", "localhost")
+        context.setenv("MASTER_PORT", "29500")
+        dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+        yield
+        dist.destroy_process_group()
+        torch.cuda.empty_cache()
+        parallel_state.destroy_model_parallel()
+
+
+def _test_all_reduce_sum(rank: int, world_size: int):
+    with dist_environment(rank=rank, world_size=world_size):
+        device = torch.device(f"cuda:{rank}")
+        tensor = torch.tensor([rank + 1], device=device)
+        dist.all_reduce(tensor)  # default reduce op is SUM
+        assert tensor.item() == world_size * (world_size + 1) / 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason=f"Requires 2 devices but got {torch.cuda.device_count()}")
+def test_all_reduce_sum():
+    world_size = 2
+    torch.multiprocessing.spawn(
+        fn=_test_all_reduce_sum,
+        args=(world_size,),
+        nprocs=world_size,
+    )
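
A possible follow-up, not part of this patch: the same torch.multiprocessing.spawn plus dist_environment pattern can be reused to exercise other collectives. Below is a minimal sketch for dist.broadcast; the import path for dist_environment, the worker name _broadcast_worker, and the test name test_broadcast are hypothetical (the helper currently lives in this test file, pending the TODO to move it to src).

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing

# Hypothetical import path: assumes dist_environment has been moved to src
# as the TODO in the patch suggests; today it lives in this test module.
from bionemo.testing.megatron_parallel_state_utils import dist_environment


def _broadcast_worker(rank: int, world_size: int):
    with dist_environment(rank=rank, world_size=world_size):
        device = torch.device(f"cuda:{rank}")
        # rank 0 owns the payload; every other rank starts from zeros
        tensor = torch.full((3,), 42.0, device=device) if rank == 0 else torch.zeros(3, device=device)
        dist.broadcast(tensor, src=0)
        # after the broadcast every rank holds rank 0's payload
        assert torch.equal(tensor, torch.full((3,), 42.0, device=device))


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices")
def test_broadcast():
    world_size = 2
    torch.multiprocessing.spawn(
        fn=_broadcast_worker,
        args=(world_size,),
        nprocs=world_size,
    )

Spawning one process per rank gives each NCCL communicator its own CUDA context, and with the default join=True a failed assertion in a child surfaces in the parent test as a ProcessRaisedException, so pytest still reports the failure.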