iris/ccl/all_gather.py (11 changes: 9 additions & 2 deletions)
@@ -7,7 +7,7 @@
Routes to triton/ or gluon/ based on config.use_gluon.
"""

from iris.ccl.utils import extract_group_info
from iris.ccl.utils import extract_group_info, _validate_output_symmetric


def all_gather(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
@@ -16,8 +16,12 @@ def all_gather(output_tensor, input_tensor, ctx, group=None, async_op=False, con

Output is (world_size * M, N) — inputs concatenated along dim 0.

The output tensor must be on the symmetric heap (other ranks write to it
via RMA). The input tensor is only read locally, so it does not need to
be symmetric.

Args:
output_tensor: Shape (world_size * M, N)
output_tensor: Shape (world_size * M, N) — must be on symmetric heap
input_tensor: Shape (M, N)
ctx: Iris instance
group: ProcessGroup or None
@@ -29,6 +33,9 @@ def all_gather(output_tensor, input_tensor, ctx, group=None, async_op=False, con
if config is None:
config = Config(block_size_m=32, block_size_n=64)

# Output is written remotely by other ranks — must be pre-allocated on symmetric heap
_validate_output_symmetric(ctx, output_tensor, "output_tensor")

rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)

M, N = input_tensor.shape[:2]
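A caller-side sketch of the policy described above, mirroring the call pattern used in the tests added later in this PR. The heap size, shapes, and the Config import path are illustrative assumptions, not part of this diff:

import torch
import iris
from iris.ccl import Config  # assumed import path, matching the tests' use of Config

ctx = iris.iris(2**30)            # symmetric-heap context
world_size = ctx.get_num_ranks()
M, N = 128, 64

# Output is remote-written by other ranks, so it must be allocated on the heap.
output = ctx.zeros((world_size * M, N), dtype=torch.float32)
# Input is only read locally, so a plain CUDA tensor is fine.
local_input = torch.randn(M, N, dtype=torch.float32, device="cuda")

ctx.ccl.all_gather(output, local_input, config=Config(block_size_m=32, block_size_n=64))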
iris/ccl/all_reduce.py (20 changes: 19 additions & 1 deletion)
@@ -7,7 +7,7 @@
Triton only (no gluon support).
"""

from iris.ccl.utils import extract_group_info
from iris.ccl.utils import extract_group_info, _ensure_symmetric, _validate_output_symmetric


def all_reduce_preamble(output_tensor, input_tensor, ctx, config=None, workspace=None):
@@ -21,6 +21,16 @@ def all_reduce(output_tensor, input_tensor, ctx, op=None, group=None, async_op=F
"""
All-reduce: sum inputs across all ranks, result on every rank.

Which tensors need the symmetric heap depends on the variant:
- atomic/spinlock: output is remote-accessed (input is local)
- one_shot: input is remote-read (output is local)
- two_shot: both are remote-accessed
- ring: neither (workspace ring_buffer/flags are, but those are internal)

We ensure symmetric on the tensors that each variant actually accesses
remotely. Workspace tensors (ring_buffer, flags, locks) are allocated
internally on the heap already.

Args:
output_tensor: Shape (M, N)
input_tensor: Shape (M, N)
@@ -55,6 +65,14 @@ def all_reduce(output_tensor, input_tensor, ctx, op=None, group=None, async_op=F
if variant not in valid_variants:
raise ValueError(f"Invalid all_reduce_variant: {variant}. Must be one of: {', '.join(valid_variants)}")

# Ensure/validate symmetric only for tensors that are remotely accessed per variant
if variant in ("one_shot", "two_shot"):
# Input is remote-read — auto-import if needed
input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
if variant in ("atomic", "spinlock", "two_shot"):
# Output is remote-written — must be pre-allocated on heap
_validate_output_symmetric(ctx, output_tensor, "output_tensor")
Comment on lines +68 to +74
Copilot AI (Apr 30, 2026): New variant-dependent behavior was introduced (auto-import for one_shot/two_shot inputs; strict validation for atomic/spinlock/two_shot outputs). Consider adding focused tests for at least one variant in each category to prevent regressions and to verify the intended access-pattern policy.

rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)

from iris.ccl.triton.all_reduce import launch
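Per the review note above, a focused test covering one variant from each category might look roughly like the sketch below. It is not part of the PR: how the variant is selected is not visible in this diff, so the all_reduce_variant attribute is a guess inferred from the error message, the Config import path is assumed, and barriers/cleanup used elsewhere in the suite are omitted for brevity:

import pytest
import torch
import iris
from iris.ccl import Config  # assumed import path


def test_all_reduce_symmetric_policy_sketch():
    shmem = iris.iris(2**30)
    M, N = 128, 64
    plain_in = torch.randn(M, N, dtype=torch.float32, device="cuda")
    plain_out = torch.zeros(M, N, dtype=torch.float32, device="cuda")

    cfg = Config(block_size_m=32, block_size_n=64)

    # one_shot: input is remote-read and auto-imported; a plain local output is fine.
    cfg.all_reduce_variant = "one_shot"  # assumed attribute name
    shmem.ccl.all_reduce(plain_out, plain_in, config=cfg)

    # atomic: output is remote-written, so a non-symmetric output must be rejected.
    cfg.all_reduce_variant = "atomic"  # assumed attribute name
    with pytest.raises(ValueError, match="output_tensor must be on the symmetric heap"):
        shmem.ccl.all_reduce(plain_out, plain_in, config=cfg)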
iris/ccl/all_to_all.py (14 changes: 11 additions & 3 deletions)
@@ -7,7 +7,7 @@
Routes to triton/ or gluon/ based on config.use_gluon.
"""

from iris.ccl.utils import extract_group_info
from iris.ccl.utils import extract_group_info, _ensure_symmetric, _validate_output_symmetric


def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, config=None):
@@ -16,9 +16,12 @@ def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, con

Input/output shape: (M, N * world_size).

Both tensors are accessed remotely (other ranks read input slices and
write to output slices via RMA), so both must be on the symmetric heap.

Args:
output_tensor: Shape (M, N * world_size)
input_tensor: Shape (M, N * world_size)
output_tensor: Shape (M, N * world_size) — must be on symmetric heap
input_tensor: Shape (M, N * world_size) — must be on symmetric heap
Comment on lines +19 to +24
Copilot AI (Apr 30, 2026): The docstring currently states both tensors 'must be on the symmetric heap', but input_tensor is auto-imported if needed. To avoid confusing API semantics, update the docstring to distinguish requirements: input 'will be imported if needed', while output 'must already be allocated on the symmetric heap' (strict).
ctx: Iris instance
group: ProcessGroup or None
async_op: If True, skip trailing barrier
Expand All @@ -29,6 +32,11 @@ def all_to_all(output_tensor, input_tensor, ctx, group=None, async_op=False, con
if config is None:
config = Config(block_size_m=32, block_size_n=128)

# Input is remote-read by other ranks — auto-import if needed
input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
# Output is remote-written by other ranks — must be pre-allocated on heap
_validate_output_symmetric(ctx, output_tensor, "output_tensor")

rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)

if config.use_gluon:
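One practical consequence of the auto-import noted in the review comment above, as a sketch under the same assumptions as the earlier example (not part of the diff): the import copies the input on every call, so a caller that reuses the same input buffer can import it once with ctx.as_symmetric() and pass the returned tensor thereafter.

import torch
import iris
from iris.ccl import Config  # assumed import path

ctx = iris.iris(2**30)
world_size = ctx.get_num_ranks()
M, N = 64, 32

# Output is remote-written: it must already live on the symmetric heap.
output = ctx.zeros((M, N * world_size), dtype=torch.float32)

# Import the input once up front instead of paying a copy inside every call.
staged_input = ctx.as_symmetric(torch.randn(M, N * world_size, dtype=torch.float32, device="cuda"))

cfg = Config(block_size_m=32, block_size_n=128)
for _ in range(10):
    ctx.ccl.all_to_all(output, staged_input, config=cfg)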
iris/ccl/reduce_scatter.py (11 changes: 9 additions & 2 deletions)
@@ -7,16 +7,20 @@
Triton only (no gluon support).
"""

from iris.ccl.utils import extract_group_info
from iris.ccl.utils import extract_group_info, _ensure_symmetric


def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None):
Copilot AI (Apr 30, 2026): The function docstring says input_tensor ... must be on symmetric heap, but the implementation now auto-imports via _ensure_symmetric. Update the docstring/comments to reflect the actual behavior (e.g., 'will be imported to the symmetric heap if needed') so callers understand non-symmetric inputs are accepted and will incur a copy.
"""
Reduce-scatter: each rank reduces its assigned tiles, stores locally.

The input tensor is read remotely by all ranks (via iris.load).
The output tensor is only written locally — it does not need to be
on the symmetric heap.

Args:
output_tensor: Shape (M, N)
input_tensor: Shape (M, N)
input_tensor: Shape (M, N) — must be on symmetric heap
ctx: Iris instance
op: ReduceOp (only SUM supported)
group: ProcessGroup or None
@@ -46,6 +50,9 @@ def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_
if variant != "two_shot":
raise ValueError(f"reduce_scatter only supports variant='two_shot', got '{variant}'.")

# Input is remote-read by all ranks — must be on symmetric heap
input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")

rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)
M, N = input_tensor.shape[:2]

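The resulting caller contract is the mirror image of all_gather, sketched briefly below (same assumptions as the earlier examples; this diff does not show how the default config resolves the variant, so it is assumed here to select the supported two_shot path). The output can be an ordinary CUDA tensor and receives the result directly, while a non-symmetric input is copied to the heap per call, so changes made to the original input afterwards are not observed by the collective.

import torch
import iris

ctx = iris.iris(2**30)
M, N = 128, 64

# Output is only written locally: a plain CUDA tensor receives the result directly.
local_output = torch.zeros(M, N, dtype=torch.float32, device="cuda")
# Input is remote-read: a non-symmetric tensor is accepted and copied to the heap.
local_input = torch.randn(M, N, dtype=torch.float32, device="cuda")

ctx.ccl.reduce_scatter(local_output, local_input)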
iris/ccl/utils.py (27 changes: 27 additions & 0 deletions)
@@ -76,3 +76,30 @@ def extract_group_info(group, ctx) -> Tuple[int, int, int, int, int]:
"""

return _extract_group_info(group, ctx.get_rank(), ctx.get_num_ranks())


def _ensure_symmetric(ctx, tensor, name="tensor"):
"""Return tensor on symmetric heap, importing if needed.

For input tensors that are only read: the kernel reads from the heap
copy, so the caller doesn't need the returned tensor back.

Do NOT use this for output tensors — see _validate_output_symmetric.
"""
if ctx.is_symmetric(tensor):
return tensor
return ctx.as_symmetric(tensor)
Comment on lines +82 to +91
Copilot AI (Apr 30, 2026): name is currently unused in _ensure_symmetric, which can trigger lint warnings and is misleading. Either remove the parameter, or use it (e.g., in a debug/log message or in an error if as_symmetric can fail). Also consider rewording the docstring: while users may not need the returned tensor, the caller kernel launch must use the returned value—otherwise the import has no effect.

Suggested change:

    """Return ``tensor`` on the symmetric heap, importing if needed.

    If ``tensor`` is not already symmetric, this returns the imported heap
    copy. The caller must use the returned tensor for the subsequent kernel
    launch; otherwise the import has no effect.

    For read-only inputs, users may not need to keep the returned tensor after
    the kernel launch completes.

    Do NOT use this for output tensors — see _validate_output_symmetric.
    """
    if ctx.is_symmetric(tensor):
        return tensor
    try:
        return ctx.as_symmetric(tensor)
    except Exception as exc:
        raise type(exc)(f"Failed to import {name} to the symmetric heap: {exc}") from exc


def _validate_output_symmetric(ctx, tensor, name="output_tensor"):
"""Raise if output tensor is not on symmetric heap.

Output tensors cannot be auto-imported because the kernel writes results
into the heap copy, but the caller's original tensor would never see
those results (torch allocator makes an independent copy).
"""
if not ctx.is_symmetric(tensor):
raise ValueError(
f"{name} must be on the symmetric heap. "
f"Allocate with ctx.zeros() or import with ctx.as_symmetric() before calling."
Copilot AI (Apr 30, 2026): The message suggests ctx.as_symmetric() can be used as a pre-step, but it does not make it explicit that as_symmetric returns a new tensor that must be passed to the collective. Consider clarifying to avoid callers doing ctx.as_symmetric(output_tensor) without using the returned tensor (which would still fail validation or would lead to confusion).

Suggested change:

            f"Allocate it with ctx.zeros(), or assign the result of "
            f"ctx.as_symmetric({name}) to a new tensor and pass that returned "
            f"tensor to the collective."
)
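The pitfall both review comments above describe, as a short sketch (tensor names are illustrative): ctx.as_symmetric() returns a new heap tensor, and that returned tensor is what must be handed to the collective and read afterwards.

import torch
import iris

ctx = iris.iris(2**30)
world_size = ctx.get_num_ranks()
M, N = 128, 64

inp = torch.randn(M, N, dtype=torch.float32, device="cuda")
out = torch.zeros(world_size * M, N, dtype=torch.float32, device="cuda")

# Wrong: the heap copy returned by as_symmetric is discarded, so `out` is still
# non-symmetric and the collective raises ValueError.
# ctx.as_symmetric(out)
# ctx.ccl.all_gather(out, inp)

# Right: keep the returned tensor, pass it to the collective, and read results from it.
out_sym = ctx.as_symmetric(out)
ctx.ccl.all_gather(out_sym, inp)
result = out_sym  # results land in the heap tensor, not in the original `out`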
tests/ccl/test_all_gather.py (81 changes: 81 additions & 0 deletions)
@@ -170,3 +170,84 @@ def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n):
import gc

gc.collect()


def test_all_gather_rejects_non_symmetric_output():
"""Test that non-symmetric output tensor raises ValueError.

In all_gather, other ranks write to the output tensor via RMA (iris.store).
If the output is not on the symmetric heap, those remote writes would
silently go to the wrong address. We reject early with a clear error.
"""
if not dist.is_initialized():
pytest.skip("torch.distributed not initialized")

heap_size = 2**30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
Copilot AI (Apr 30, 2026): These tests select the CUDA device using the global rank (cuda:{rank}), which will fail on multi-node runs (global rank can exceed the local GPU count on a node). Prefer using the local device mapping (e.g., LOCAL_RANK / torch.cuda.current_device() or a helper/fixture already used elsewhere in the test suite) to make the tests robust across single-node and multi-node distributed environments.
world_size = shmem.get_num_ranks()
M, N = 128, 64

# Input on symmetric heap (fine)
iris_input = shmem.zeros((M, N), dtype=torch.float32)
# Output on regular CUDA memory (NOT on symmetric heap)
bad_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}")
assert not shmem.is_symmetric(bad_output)

try:
with pytest.raises(ValueError, match="output_tensor must be on the symmetric heap"):
config = Config(block_size_m=32, block_size_n=64)
shmem.ccl.all_gather(bad_output, iris_input, config=config)
finally:
shmem.barrier()
del shmem
import gc

gc.collect()


def test_all_gather_non_symmetric_input_ok():
"""Test that non-symmetric input tensor works fine (input is local-only).

In all_gather, each rank only reads its own input locally — it's never
accessed remotely. So non-symmetric input tensors should work as-is.
"""
if not dist.is_initialized():
pytest.skip("torch.distributed not initialized")

heap_size = 2**30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
M, N = 128, 64

# Input on regular CUDA memory (NOT on symmetric heap — that's fine for all_gather)
external_input = torch.randn(M, N, dtype=torch.float32, device=f"cuda:{rank}")
external_input.fill_(float(rank + 1))
assert not shmem.is_symmetric(external_input)

# Output on symmetric heap (required — remote writes)
iris_output = shmem.zeros((world_size * M, N), dtype=torch.float32)

# Reference via PyTorch
pytorch_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}")
shmem.barrier()
dist.all_gather_into_tensor(pytorch_output, external_input)
torch.cuda.synchronize()

# Should work — input doesn't need symmetric heap
shmem.barrier()
config = Config(block_size_m=32, block_size_n=64)
shmem.ccl.all_gather(iris_output, external_input, config=config)
torch.cuda.synchronize()

try:
assert torch.allclose(iris_output, pytorch_output, atol=1e-5), (
f"Non-symmetric input: max diff {torch.abs(iris_output - pytorch_output).max().item()}"
)
finally:
shmem.barrier()
del shmem
import gc

gc.collect()
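A sketch of the local-device selection the review comment above suggests. LOCAL_RANK is the standard torchrun environment variable; whether the test suite already exposes a helper or fixture for this is not visible in this diff:

import os
import torch

def local_cuda_device() -> torch.device:
    """Pick the GPU for this process using the node-local rank, not the global rank."""
    local_rank = int(os.environ.get("LOCAL_RANK", torch.cuda.current_device()))
    return torch.device(f"cuda:{local_rank}")

# In the tests above, tensors would then be created with device=local_cuda_device()
# instead of device=f"cuda:{rank}".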