-
Notifications
You must be signed in to change notification settings - Fork 39
Validate symmetric heap placement in CCL collectives #526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
| Routes to triton/ or gluon/ based on config.use_gluon. | ||
| """ | ||
|
|
||
| from iris.ccl.utils import extract_group_info | ||
| from iris.ccl.utils import extract_group_info, _ensure_symmetric, _validate_output_symmetric | ||
|
|
||
|
|
||
| def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None): | ||
|
|
@@ -16,9 +16,12 @@ def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, con | |
|
|
||
| Input/output shape: (M, N * world_size). | ||
|
|
||
| Both tensors are accessed remotely (other ranks read input slices and | ||
| write to output slices via RMA), so both must be on the symmetric heap. | ||
|
|
||
| Args: | ||
| output_tensor: Shape (M, N * world_size) | ||
| input_tensor: Shape (M, N * world_size) | ||
| output_tensor: Shape (M, N * world_size) — must be on symmetric heap | ||
| input_tensor: Shape (M, N * world_size) — must be on symmetric heap | ||
|
Comment on lines
+19
to
+24
|
||
| ctx: Iris instance | ||
| group: ProcessGroup or None | ||
| async_op: If True, skip trailing barrier | ||
|
|
@@ -29,6 +32,11 @@ def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, con | |
| if config is None: | ||
| config = Config(block_size_m=32, block_size_n=128) | ||
|
|
||
| # Input is remote-read by other ranks — auto-import if needed | ||
| input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor") | ||
|
Comment on lines
+35
to
+36
|
||
| # Output is remote-written by other ranks — must be pre-allocated on heap | ||
| _validate_output_symmetric(ctx, output_tensor, "output_tensor") | ||
|
|
||
| rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx) | ||
|
|
||
| if config.use_gluon: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,16 +7,20 @@ | |
| Triton only (no gluon support). | ||
| """ | ||
|
|
||
| from iris.ccl.utils import extract_group_info | ||
| from iris.ccl.utils import extract_group_info, _ensure_symmetric | ||
|
|
||
|
|
||
| def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None): | ||
|
||
| """ | ||
| Reduce-scatter: each rank reduces its assigned tiles, stores locally. | ||
|
|
||
| The input tensor is read remotely by all ranks (via iris.load). | ||
| The output tensor is only written locally — it does not need to be | ||
| on the symmetric heap. | ||
|
|
||
| Args: | ||
| output_tensor: Shape (M, N) | ||
| input_tensor: Shape (M, N) | ||
| input_tensor: Shape (M, N) — must be on symmetric heap | ||
|
||
| ctx: Iris instance | ||
| op: ReduceOp (only SUM supported) | ||
| group: ProcessGroup or None | ||
|
|
@@ -46,6 +50,9 @@ def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_ | |
| if variant != "two_shot": | ||
| raise ValueError(f"reduce_scatter only supports variant='two_shot', got '{variant}'.") | ||
|
|
||
| # Input is remote-read by all ranks — must be on symmetric heap | ||
| input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor") | ||
|
Comment on lines
+53
to
+54
|
||
|
|
||
| rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx) | ||
| M, N = input_tensor.shape[:2] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -76,3 +76,30 @@ def extract_group_info(group, ctx) -> Tuple[int, int, int, int, int]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _extract_group_info(group, ctx.get_rank(), ctx.get_num_ranks()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _ensure_symmetric(ctx, tensor, name="tensor"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return tensor on symmetric heap, importing if needed. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| For input tensors that are only read: the kernel reads from the heap | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| copy, so the caller doesn't need the returned tensor back. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Do NOT use this for output tensors — see _validate_output_symmetric. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ctx.is_symmetric(tensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ctx.as_symmetric(tensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+82
to
+91
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return tensor on symmetric heap, importing if needed. | |
| For input tensors that are only read: the kernel reads from the heap | |
| copy, so the caller doesn't need the returned tensor back. | |
| Do NOT use this for output tensors — see _validate_output_symmetric. | |
| """ | |
| if ctx.is_symmetric(tensor): | |
| return tensor | |
| return ctx.as_symmetric(tensor) | |
| """Return ``tensor`` on the symmetric heap, importing if needed. | |
| If ``tensor`` is not already symmetric, this returns the imported heap | |
| copy. The caller must use the returned tensor for the subsequent kernel | |
| launch; otherwise the import has no effect. | |
| For read-only inputs, users may not need to keep the returned tensor after | |
| the kernel launch completes. | |
| Do NOT use this for output tensors — see _validate_output_symmetric. | |
| """ | |
| if ctx.is_symmetric(tensor): | |
| return tensor | |
| try: | |
| return ctx.as_symmetric(tensor) | |
| except Exception as exc: | |
| raise type(exc)(f"Failed to import {name} to the symmetric heap: {exc}") from exc |
Copilot
AI
Apr 30, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The message suggests ctx.as_symmetric() can be used as a pre-step, but it does not make it explicit that as_symmetric returns a new tensor that must be passed to the collective. Consider clarifying to avoid callers doing ctx.as_symmetric(output_tensor) without using the returned tensor (which would still fail validation or would lead to confusion).
| f"Allocate with ctx.zeros() or import with ctx.as_symmetric() before calling." | |
| f"Allocate it with ctx.zeros(), or assign the result of " | |
| f"ctx.as_symmetric({name}) to a new tensor and pass that returned " | |
| f"tensor to the collective." |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -170,3 +170,84 @@ def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n): | |
| import gc | ||
|
|
||
| gc.collect() | ||
|
|
||
|
|
||
| def test_all_gather_rejects_non_symmetric_output(): | ||
| """Test that non-symmetric output tensor raises ValueError. | ||
|
|
||
| In all_gather, other ranks write to the output tensor via RMA (iris.store). | ||
| If the output is not on the symmetric heap, those remote writes would | ||
| silently go to the wrong address. We reject early with a clear error. | ||
| """ | ||
| if not dist.is_initialized(): | ||
| pytest.skip("torch.distributed not initialized") | ||
|
|
||
| heap_size = 2**30 | ||
| shmem = iris.iris(heap_size) | ||
| rank = shmem.get_rank() | ||
|
||
| world_size = shmem.get_num_ranks() | ||
| M, N = 128, 64 | ||
|
|
||
| # Input on symmetric heap (fine) | ||
| iris_input = shmem.zeros((M, N), dtype=torch.float32) | ||
| # Output on regular CUDA memory (NOT on symmetric heap) | ||
| bad_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}") | ||
|
||
| assert not shmem.is_symmetric(bad_output) | ||
|
|
||
| try: | ||
| with pytest.raises(ValueError, match="output_tensor must be on the symmetric heap"): | ||
| config = Config(block_size_m=32, block_size_n=64) | ||
| shmem.ccl.all_gather(bad_output, iris_input, config=config) | ||
| finally: | ||
| shmem.barrier() | ||
| del shmem | ||
| import gc | ||
|
|
||
| gc.collect() | ||
|
|
||
|
|
||
| def test_all_gather_non_symmetric_input_ok(): | ||
| """Test that non-symmetric input tensor works fine (input is local-only). | ||
|
|
||
| In all_gather, each rank only reads its own input locally — it's never | ||
| accessed remotely. So non-symmetric input tensors should work as-is. | ||
| """ | ||
| if not dist.is_initialized(): | ||
| pytest.skip("torch.distributed not initialized") | ||
|
|
||
| heap_size = 2**30 | ||
| shmem = iris.iris(heap_size) | ||
| rank = shmem.get_rank() | ||
| world_size = shmem.get_num_ranks() | ||
| M, N = 128, 64 | ||
|
|
||
| # Input on regular CUDA memory (NOT on symmetric heap — that's fine for all_gather) | ||
| external_input = torch.randn(M, N, dtype=torch.float32, device=f"cuda:{rank}") | ||
|
||
| external_input.fill_(float(rank + 1)) | ||
| assert not shmem.is_symmetric(external_input) | ||
|
|
||
| # Output on symmetric heap (required — remote writes) | ||
| iris_output = shmem.zeros((world_size * M, N), dtype=torch.float32) | ||
|
|
||
| # Reference via PyTorch | ||
| pytorch_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}") | ||
|
||
| shmem.barrier() | ||
| dist.all_gather_into_tensor(pytorch_output, external_input) | ||
| torch.cuda.synchronize() | ||
|
|
||
| # Should work — input doesn't need symmetric heap | ||
| shmem.barrier() | ||
| config = Config(block_size_m=32, block_size_n=64) | ||
| shmem.ccl.all_gather(iris_output, external_input, config=config) | ||
| torch.cuda.synchronize() | ||
|
|
||
| try: | ||
| assert torch.allclose(iris_output, pytorch_output, atol=1e-5), ( | ||
| f"Non-symmetric input: max diff {torch.abs(iris_output - pytorch_output).max().item()}" | ||
| ) | ||
| finally: | ||
| shmem.barrier() | ||
| del shmem | ||
| import gc | ||
|
|
||
| gc.collect() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New variant-dependent behavior was introduced (auto-import for
one_shot/two_shot inputs; strict validation for atomic/spinlock/two_shot outputs). Consider adding focused tests for at least one variant in each category to prevent regressions and to verify the intended access-pattern policy.