diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py
index f2ffd94e7..1350dca28 100644
--- a/iris/ccl/all_gather.py
+++ b/iris/ccl/all_gather.py
@@ -7,7 +7,7 @@
 Routes to triton/ or gluon/ based on config.use_gluon.
 """
 
-from iris.ccl.utils import extract_group_info
+from iris.ccl.utils import extract_group_info, _validate_output_symmetric
 
 
 def all_gather(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
@@ -16,8 +16,12 @@ def all_gather(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
 
     Output is (world_size * M, N) — inputs concatenated along dim 0.
 
+    The output tensor must be on the symmetric heap (other ranks write to it
+    via RMA). The input tensor is only read locally, so it does not need to
+    be symmetric.
+
     Args:
-        output_tensor: Shape (world_size * M, N)
+        output_tensor: Shape (world_size * M, N) — must be on symmetric heap
         input_tensor: Shape (M, N)
         ctx: Iris instance
         group: ProcessGroup or None
@@ -29,6 +33,9 @@ def all_gather(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
     if config is None:
         config = Config(block_size_m=32, block_size_n=64)
 
+    # Output is written remotely by other ranks — must be pre-allocated on symmetric heap
+    _validate_output_symmetric(ctx, output_tensor, "output_tensor")
+
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)
 
     M, N = input_tensor.shape[:2]
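A minimal caller-side sketch of the contract this file now enforces: the output comes from the heap, the input can be any CUDA tensor. The calls mirror the tests at the end of this diff; the `from iris.ccl import Config` path and the heap size are assumptions, not taken from this diff.

    import torch
    import iris
    from iris.ccl import Config  # assumed import path

    ctx = iris.iris(2**30)  # 1 GiB symmetric heap (size is illustrative)
    M, N = 128, 64
    world_size = ctx.get_num_ranks()

    # Input is only read locally: a plain CUDA tensor is fine.
    x = torch.randn(M, N, device=f"cuda:{ctx.get_rank()}")
    # Output is written remotely by every rank: it must live on the heap.
    out = ctx.zeros((world_size * M, N), dtype=torch.float32)

    ctx.ccl.all_gather(out, x, config=Config(block_size_m=32, block_size_n=64))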
diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py
index 4754dd178..dd7ec16d4 100644
--- a/iris/ccl/all_reduce.py
+++ b/iris/ccl/all_reduce.py
@@ -7,7 +7,7 @@
 Triton only (no gluon support).
 """
 
-from iris.ccl.utils import extract_group_info
+from iris.ccl.utils import extract_group_info, _ensure_symmetric, _validate_output_symmetric
 
 
 def all_reduce_preamble(output_tensor, input_tensor, ctx, config=None, workspace=None):
@@ -21,6 +21,16 @@ def all_reduce(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None):
     """
     All-reduce: sum inputs across all ranks, result on every rank.
 
+    Which tensors need the symmetric heap depends on the variant:
+    - atomic/spinlock: output is remote-accessed (input is local)
+    - one_shot: input is remote-read (output is local)
+    - two_shot: both are remote-accessed
+    - ring: neither (workspace ring_buffer/flags are, but those are internal)
+
+    We ensure symmetry only for the tensors that each variant actually
+    accesses remotely. Workspace tensors (ring_buffer, flags, locks) are
+    already allocated internally on the heap.
+
     Args:
         output_tensor: Shape (M, N)
         input_tensor: Shape (M, N)
@@ -55,6 +65,14 @@ def all_reduce(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None):
     if variant not in valid_variants:
         raise ValueError(f"Invalid all_reduce_variant: {variant}. Must be one of: {', '.join(valid_variants)}")
 
+    # Ensure/validate symmetric only for tensors that are remotely accessed per variant
+    if variant in ("one_shot", "two_shot"):
+        # Input is remote-read — auto-import if needed
+        input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
+    if variant in ("atomic", "spinlock", "two_shot"):
+        # Output is remote-written — must be pre-allocated on heap
+        _validate_output_symmetric(ctx, output_tensor, "output_tensor")
+
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)
 
     from iris.ccl.triton.all_reduce import launch
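A hedged sketch of the variant table above, reusing `ctx`, `M`, `N` from the previous sketch. This diff does not show how the variant is selected; the `all_reduce_variant` keyword below is an assumption inferred from the error message.

    # two_shot: both tensors are remotely accessed, so both come from the heap.
    inp = ctx.zeros((M, N), dtype=torch.float32)
    inp.fill_(float(ctx.get_rank() + 1))  # this rank's contribution
    out = ctx.zeros((M, N), dtype=torch.float32)
    ctx.ccl.all_reduce(out, inp, config=Config(all_reduce_variant="two_shot"))  # kwarg assumed
    # Every rank now holds sum(r + 1 for r in range(world_size)).

    # one_shot: only the input is remote-read. An off-heap input is
    # auto-imported by _ensure_symmetric; the output may be plain CUDA memory.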
diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py
index aa2d5bd6c..caf841f3e 100644
--- a/iris/ccl/all_to_all.py
+++ b/iris/ccl/all_to_all.py
@@ -7,7 +7,7 @@
 Routes to triton/ or gluon/ based on config.use_gluon.
 """
 
-from iris.ccl.utils import extract_group_info
+from iris.ccl.utils import extract_group_info, _ensure_symmetric, _validate_output_symmetric
 
 
 def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
@@ -16,9 +16,12 @@ def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
 
     Input/output shape: (M, N * world_size).
 
+    Both tensors are accessed remotely (other ranks read input slices and
+    write to output slices via RMA), so both must be on the symmetric heap.
+
     Args:
-        output_tensor: Shape (M, N * world_size)
-        input_tensor: Shape (M, N * world_size)
+        output_tensor: Shape (M, N * world_size) — must be on symmetric heap
+        input_tensor: Shape (M, N * world_size) — must be on symmetric heap
         ctx: Iris instance
         group: ProcessGroup or None
         async_op: If True, skip trailing barrier
@@ -29,6 +32,11 @@ def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
     if config is None:
         config = Config(block_size_m=32, block_size_n=128)
 
+    # Input is remote-read by other ranks — auto-import if needed
+    input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
+    # Output is remote-written by other ranks — must be pre-allocated on heap
+    _validate_output_symmetric(ctx, output_tensor, "output_tensor")
+
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)
 
     if config.use_gluon:
diff --git a/iris/ccl/reduce_scatter.py b/iris/ccl/reduce_scatter.py
index 4902aa33d..12471057a 100644
--- a/iris/ccl/reduce_scatter.py
+++ b/iris/ccl/reduce_scatter.py
@@ -7,16 +7,20 @@
 Triton only (no gluon support).
 """
 
-from iris.ccl.utils import extract_group_info
+from iris.ccl.utils import extract_group_info, _ensure_symmetric
 
 
 def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None):
     """
     Reduce-scatter: each rank reduces its assigned tiles, stores locally.
 
+    The input tensor is read remotely by all ranks (via iris.load).
+    The output tensor is only written locally — it does not need to be
+    on the symmetric heap.
+
     Args:
         output_tensor: Shape (M, N)
-        input_tensor: Shape (M, N)
+        input_tensor: Shape (M, N) — must be on symmetric heap
         ctx: Iris instance
         op: ReduceOp (only SUM supported)
         group: ProcessGroup or None
@@ -46,6 +50,9 @@ def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None):
     if variant != "two_shot":
         raise ValueError(f"reduce_scatter only supports variant='two_shot', got '{variant}'.")
 
+    # Input is remote-read by all ranks — auto-import if needed
+    input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
+
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)
 
     M, N = input_tensor.shape[:2]
diff --git a/iris/ccl/utils.py b/iris/ccl/utils.py
index 07a8db416..573ebdab1 100644
--- a/iris/ccl/utils.py
+++ b/iris/ccl/utils.py
@@ -76,3 +76,30 @@ def extract_group_info(group, ctx) -> Tuple[int, int, int, int, int]:
     """
     return _extract_group_info(group, ctx.get_rank(), ctx.get_num_ranks())
+
+
+def _ensure_symmetric(ctx, tensor, name="tensor"):
+    """Return tensor on symmetric heap, importing if needed.
+
+    For input tensors that are only read: the kernel reads from the heap
+    copy, so the caller doesn't need the returned tensor back.
+
+    Do NOT use this for output tensors — see _validate_output_symmetric.
+    """
+    if ctx.is_symmetric(tensor):
+        return tensor
+    return ctx.as_symmetric(tensor)
+
+
+def _validate_output_symmetric(ctx, tensor, name="output_tensor"):
+    """Raise if output tensor is not on symmetric heap.
+
+    Output tensors cannot be auto-imported because the kernel writes results
+    into the heap copy, but the caller's original tensor would never see
+    those results (torch allocator makes an independent copy).
+    """
+    if not ctx.is_symmetric(tensor):
+        raise ValueError(
+            f"{name} must be on the symmetric heap. "
+            f"Allocate with ctx.zeros() or import with ctx.as_symmetric() before calling."
+        )
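The asymmetry the two helpers encode, in miniature (hypothetical tensors, reusing `ctx` from the earlier sketch):

    ext = torch.empty(4, device="cuda")  # caller's off-heap tensor
    sym = ctx.as_symmetric(ext)          # independent copy on the heap
    # Reads are safe: the kernel sees ext's values through sym.
    # Writes are not: a kernel writing into sym never updates ext, so the
    # caller would read back stale results. For outputs the only safe
    # behavior is to raise, which _validate_output_symmetric does.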
diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py
index 7858ed18d..01a69161e 100644
--- a/tests/ccl/test_all_gather.py
+++ b/tests/ccl/test_all_gather.py
@@ -170,3 +170,84 @@ def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n):
         import gc
 
         gc.collect()
+
+
+def test_all_gather_rejects_non_symmetric_output():
+    """Test that non-symmetric output tensor raises ValueError.
+
+    In all_gather, other ranks write to the output tensor via RMA (iris.store).
+    If the output is not on the symmetric heap, those remote writes would
+    silently go to the wrong address. We reject early with a clear error.
+    """
+    if not dist.is_initialized():
+        pytest.skip("torch.distributed not initialized")
+
+    heap_size = 2**30
+    shmem = iris.iris(heap_size)
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+    M, N = 128, 64
+
+    # Input on symmetric heap (fine)
+    iris_input = shmem.zeros((M, N), dtype=torch.float32)
+    # Output on regular CUDA memory (NOT on symmetric heap)
+    bad_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}")
+    assert not shmem.is_symmetric(bad_output)
+
+    try:
+        with pytest.raises(ValueError, match="output_tensor must be on the symmetric heap"):
+            config = Config(block_size_m=32, block_size_n=64)
+            shmem.ccl.all_gather(bad_output, iris_input, config=config)
+    finally:
+        shmem.barrier()
+        del shmem
+        import gc
+
+        gc.collect()
+
+
+def test_all_gather_non_symmetric_input_ok():
+    """Test that non-symmetric input tensor works fine (input is local-only).
+
+    In all_gather, each rank only reads its own input locally — it's never
+    accessed remotely. So non-symmetric input tensors should work as-is.
+    """
+    if not dist.is_initialized():
+        pytest.skip("torch.distributed not initialized")
+
+    heap_size = 2**30
+    shmem = iris.iris(heap_size)
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+    M, N = 128, 64
+
+    # Input on regular CUDA memory (NOT on symmetric heap — that's fine for all_gather)
+    external_input = torch.randn(M, N, dtype=torch.float32, device=f"cuda:{rank}")
+    external_input.fill_(float(rank + 1))
+    assert not shmem.is_symmetric(external_input)
+
+    # Output on symmetric heap (required — remote writes)
+    iris_output = shmem.zeros((world_size * M, N), dtype=torch.float32)
+
+    # Reference via PyTorch
+    pytorch_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}")
+    shmem.barrier()
+    dist.all_gather_into_tensor(pytorch_output, external_input)
+    torch.cuda.synchronize()
+
+    # Should work — input doesn't need symmetric heap
+    shmem.barrier()
+    config = Config(block_size_m=32, block_size_n=64)
+    shmem.ccl.all_gather(iris_output, external_input, config=config)
+    torch.cuda.synchronize()
+
+    try:
+        assert torch.allclose(iris_output, pytorch_output, atol=1e-5), (
+            f"Non-symmetric input: max diff {torch.abs(iris_output - pytorch_output).max().item()}"
+        )
+    finally:
+        shmem.barrier()
+        del shmem
+        import gc
+
+        gc.collect()
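For completeness, the caller-side fix for the rejected call in the first test, following the error message's own suggestion (reusing that test's variables):

    # Either allocate the output on the heap up front...
    good_output = shmem.zeros((world_size * M, N), dtype=torch.float32)
    shmem.ccl.all_gather(good_output, iris_input, config=config)

    # ...or import once and keep using the returned heap tensor:
    imported = shmem.as_symmetric(bad_output)
    shmem.ccl.all_gather(imported, iris_input, config=config)
    # Read results from imported, not from bad_output.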