Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions iris/ccl/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
"""

from dataclasses import dataclass
import functools
import iris


@functools.lru_cache(maxsize=1)
def _cached_num_xcc():
    """Return the XCC count from the HIP runtime, memoized for the process.

    The XCC count is a hardware property that does not change while the
    process is alive, so the runtime query is performed once and every
    later call returns the cached value.
    """
    return iris.hip.get_num_xcc()


@dataclass
class Config:
"""
Expand All @@ -32,9 +39,10 @@ class Config:
use_gluon: If True, use Gluon-based implementation (default: False)
Gluon provides better control over warp-level traffic shaping
all_gather_variant: Variant for all-gather operation (default: "persistent")
Options: "persistent", "partitioned"
- "persistent": Each PID handles multiple tiles and sends to all ranks
- "partitioned": PIDs partitioned across ranks, eliminates inner loop
Options: "persistent", "partitioned", "pull"
- "persistent": Each PID handles multiple tiles and sends to all ranks (PUSH)
- "partitioned": PIDs partitioned across ranks, eliminates inner loop (PUSH)
- "pull": Each rank reads from all ranks into its own output (PULL)
all_reduce_variant: Variant for all-reduce operation (default: "atomic")
Options: "atomic", "ring", "two_shot", "one_shot", "spinlock"
all_reduce_distribution: Distribution for two-shot all-reduce (default: 0)
Expand Down Expand Up @@ -84,7 +92,7 @@ class Config:
num_xcds: int | None = None
chunk_size: int | None = None
use_gluon: bool = False
all_gather_variant: str = "persistent"
all_gather_variant: str = "pull"
all_reduce_variant: str = "two_shot"
all_reduce_distribution: int = 1
all_reduce_num_rings: int = 1
Expand All @@ -98,7 +106,7 @@ class Config:
def __post_init__(self):
"""Validate and auto-detect num_xcds if not set."""
if self.num_xcds is None:
self.num_xcds = iris.hip.get_num_xcc()
self.num_xcds = _cached_num_xcc()
Comment on lines 106 to +109

if self.chunk_size is None:
self.chunk_size = self.swizzle_size * self.swizzle_size
Expand All @@ -114,9 +122,9 @@ def __post_init__(self):
raise ValueError(f"comm_sms must be positive, got {self.comm_sms}")
if self.num_xcds <= 0:
raise ValueError(f"num_xcds must be positive, got {self.num_xcds}")
if self.all_gather_variant not in ["persistent", "partitioned"]:
if self.all_gather_variant not in ["persistent", "partitioned", "pull"]:
raise ValueError(
f"all_gather_variant must be one of: 'persistent', 'partitioned', got {self.all_gather_variant}"
f"all_gather_variant must be one of: 'persistent', 'partitioned', 'pull', got {self.all_gather_variant}"
)
if self.all_reduce_variant not in ["atomic", "ring", "two_shot", "one_shot", "spinlock"]:
raise ValueError(
Expand Down
65 changes: 65 additions & 0 deletions iris/ccl/triton/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,69 @@ def persistent_all_gather_partitioned(
)


@triton.jit()
def persistent_all_gather_pull(
    input_ptr,
    output_ptr,
    M,
    N,
    stride_in_m,
    stride_in_n,
    stride_out_m,
    stride_out_n,
    heap_bases: tl.tensor,
    group_rank: tl.constexpr,
    iris_rank: tl.constexpr,
    world_size: tl.constexpr,
    rank_start: tl.constexpr,
    rank_stride: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    COMM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
):
    """
    Pull-model all-gather: each rank reads input from all ranks into its own output.

    Each rank pulls every other rank's (M, N) input and writes it into the
    row band [i*M, (i+1)*M) of its local output, so the output spans
    (world_size * M, N) rows — TODO confirm output allocation against the caller.

    Uses simple linear tile indexing (no swizzle/chiplet transform) for clean
    codegen on MI300X. GROUP_SIZE_M/NUM_XCDS/CHUNK_SIZE accepted but unused
    (kept so this kernel shares the launch signature of the other variants).

    Args:
        input_ptr/output_ptr: local input tile and gathered output buffer.
        M, N: logical dimensions of one rank's input.
        stride_in_*/stride_out_*: element strides of input and output.
        heap_bases: per-rank heap base addresses consumed by iris.load for
            remote reads.
        group_rank: this rank's index within the group (compile-time).
        iris_rank: this rank's global iris rank, passed to iris.load.
        world_size: number of ranks in the group (compile-time, unrolled).
        rank_start, rank_stride: map group index i -> source iris rank.
        BLOCK_SIZE_M/BLOCK_SIZE_N: tile shape.
        COMM_SMS: number of persistent programs striding over the tiles.
    """
    pid = tl.program_id(0)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n

    # Persistent grid-stride loop: COMM_SMS programs cooperatively cover all tiles.
    for tile_id in range(pid, total_tiles, COMM_SMS):
        # Linear row-major tile decomposition (deliberately no swizzling).
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n

        # Compiler hints: indices are non-negative (helps address arithmetic).
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        # Contiguity/alignment hints for vectorized memory access.
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        mask = (rm[:, None] < M) & (rn[None, :] < N)
        input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n

        # world_size is constexpr, so this loop is fully unrolled at compile time.
        for i in tl.static_range(world_size):
            source_rank = rank_start + i * rank_stride
            # Rank i's contribution lands in output row band i*M .. i*M + M.
            rm_output = rm + i * M
            output_offset = rm_output[:, None] * stride_out_m + rn[None, :] * stride_out_n

            if i == group_rank:
                # Own slice: plain local load, no remote heap translation needed.
                data = tl.load(input_ptr + input_offset, mask=mask, other=0.0)
                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")
            else:
                # Remote slice: pull directly from the source rank's heap.
                data = iris.load(input_ptr + input_offset, iris_rank, source_rank, heap_bases, mask=mask)
                # Write-through store so results bypass cache and are globally visible.
                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")


def launch(
input_tensor,
output_tensor,
Expand Down Expand Up @@ -307,6 +370,8 @@ def launch(
kernel_fn = persistent_all_gather
elif config.all_gather_variant == "partitioned":
kernel_fn = persistent_all_gather_partitioned
elif config.all_gather_variant == "pull":
kernel_fn = persistent_all_gather_pull
else:
raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}")

Expand Down
74 changes: 73 additions & 1 deletion iris/ccl/triton/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,75 @@ def persistent_all_to_all(
)


@triton.jit()
def persistent_all_to_all_pull(
    input_ptr,
    output_ptr,
    M,
    N,
    stride_in_m,
    stride_in_n,
    stride_out_m,
    stride_out_n,
    heap_bases: tl.tensor,
    group_rank: tl.constexpr,
    iris_rank: tl.constexpr,
    world_size: tl.constexpr,
    rank_start: tl.constexpr,
    rank_stride: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    COMM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
):
    """
    Pull-model all-to-all: each rank reads its data from all remote ranks.

    Instead of pushing local chunks to remote outputs, each rank pulls
    the chunks destined for it from all remote inputs.

    Layout (derived from the offsets below): both input and output are laid
    out as world_size column chunks of width N. Rank i's input chunk for us
    lives at column offset group_rank * N in rank i's input; it is written to
    column offset i * N of our output. M/N here are per-chunk dimensions.

    Args:
        input_ptr/output_ptr: local input and output buffers.
        M, N: per-chunk tile dimensions (one rank-to-rank chunk).
        stride_in_*/stride_out_*: element strides of input and output.
        heap_bases: per-rank heap base addresses consumed by iris.load.
        group_rank: this rank's index within the group (compile-time).
        iris_rank: this rank's global iris rank, passed to iris.load.
        world_size: number of ranks (compile-time, loop unrolled).
        rank_start, rank_stride: map group index i -> source iris rank.
        BLOCK_SIZE_M/BLOCK_SIZE_N: tile shape.
        COMM_SMS: number of persistent programs striding over the tiles.
    """
    pid = tl.program_id(0)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n

    # Persistent grid-stride loop over the tiles of one chunk.
    for tile_id in range(pid, total_tiles, COMM_SMS):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n

        # Compiler hints: indices are non-negative.
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        # Contiguity/alignment hints for vectorized memory access.
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        mask = (rm[:, None] < M) & (rn[None, :] < N)

        # For each source rank i, pull the chunk that rank i sends to us (group_rank)
        for i in tl.static_range(world_size):
            source_rank = rank_start + i * rank_stride

            # Source rank i's input for us is at column offset group_rank * N
            input_col_offset = group_rank * N
            input_offset = rm[:, None] * stride_in_m + (rn[None, :] + input_col_offset) * stride_in_n

            # Output goes to column offset i * N
            output_col_offset = i * N
            output_offset = rm[:, None] * stride_out_m + (rn[None, :] + output_col_offset) * stride_out_n

            if i == group_rank:
                # Own chunk: local copy, no remote heap translation needed.
                data = tl.load(input_ptr + input_offset, mask=mask, other=0.0)
                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")
            else:
                # Remote chunk: pull directly from the source rank's heap.
                data = iris.load(input_ptr + input_offset, iris_rank, source_rank, heap_bases, mask=mask)
                # Write-through store so results bypass cache and are globally visible.
                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")


def launch(
input_tensor,
output_tensor,
Expand All @@ -195,8 +264,11 @@ def launch(
stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1)
stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1)

# Use PULL model by default — PUSH model has poor performance on MI300X
kernel_fn = persistent_all_to_all_pull

Comment on lines +267 to +269
iris_launch(
persistent_all_to_all,
kernel_fn,
(config.comm_sms,),
input_tensor,
output_tensor,
Expand Down
82 changes: 81 additions & 1 deletion iris/ccl/triton/reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,83 @@ def persistent_reduce_scatter_two_shot(
tl.store(out_ptr, reduced, mask=mask, cache_modifier=".wt")


@triton.jit()
def persistent_reduce_scatter_simple(
    input_ptr,
    output_ptr,
    M,
    N,
    stride_in_m,
    stride_in_n,
    stride_out_m,
    stride_out_n,
    heap_bases: tl.tensor,
    group_rank: tl.constexpr,
    iris_rank: tl.constexpr,
    world_size: tl.constexpr,
    rank_start: tl.constexpr,
    rank_stride: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    COMM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DISTRIBUTION: tl.constexpr,
):
    """
    Simplified reduce-scatter using PULL model.

    Each rank reduces its assigned tiles by pulling data from all ranks
    and stores the result locally. Avoids the dual-path is_full/masked
    structure that causes poor codegen on MI300X.

    Tiles are partitioned in contiguous blocks: rank r owns tiles
    [r * ceil(total/world_size), ...), clamped to the tile count. Only the
    tiles owned by this rank are written to output; other output regions
    are left untouched. Sums accumulate in fp32 (int32 for int8 data) and
    are cast back to the output element type on store.

    Args:
        input_ptr/output_ptr: local input and output buffers (same offsets
            are used for both, so they are indexed with the same M x N space).
        M, N: logical input dimensions.
        stride_in_*/stride_out_*: element strides of input and output.
        heap_bases: per-rank heap base addresses consumed by iris.load.
        group_rank: this rank's index within the group (compile-time).
        iris_rank: this rank's global iris rank, passed to iris.load.
        world_size: number of ranks (compile-time, loop unrolled).
        rank_start, rank_stride: map group index i -> source iris rank.
        BLOCK_SIZE_M/BLOCK_SIZE_N: tile shape.
        COMM_SMS: number of persistent programs striding over owned tiles.
        DISTRIBUTION: accepted for signature compatibility; unused here.
    """
    pid = tl.program_id(0)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n

    # Accumulate in fp32, except int8 inputs which accumulate in int32.
    acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32

    # Block distribution: each rank handles contiguous chunk of tiles
    tiles_per_rank = tl.cdiv(total_tiles, world_size)
    start_tile = group_rank * tiles_per_rank
    remaining = total_tiles - start_tile
    # Clamp so the last rank(s) do not run past the tile count.
    remaining = tl.maximum(remaining, 0)
    max_tiles = tl.minimum(tiles_per_rank, remaining)

    # Persistent grid-stride loop over this rank's owned tiles only.
    for tile_offset in range(pid, max_tiles, COMM_SMS):
        tile_id = start_tile + tile_offset

        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n

        # Compiler hints: indices are non-negative.
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        # Contiguity/alignment hints for vectorized memory access.
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        mask = (rm[:, None] < M) & (rn[None, :] < N)
        input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n
        output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n

        # Accumulate from all ranks via PULL
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
        for i in tl.static_range(world_size):
            source_rank = rank_start + i * rank_stride
            if i == group_rank:
                # NOTE(review): other=0.0 on a possibly-integer tensor relies on
                # implicit conversion — confirm behavior for the int8 path.
                data = tl.load(input_ptr + input_offset, mask=mask, other=0.0)
            else:
                data = iris.load(input_ptr + input_offset, iris_rank, source_rank, heap_bases, mask=mask)
            acc += data.to(acc_dtype)

        # Cast back to the output element type; write-through store for visibility.
        tl.store(output_ptr + output_offset, acc.to(output_ptr.type.element_ty), mask=mask, cache_modifier=".wt")


def launch(
output_tensor,
input_tensor,
Expand All @@ -159,8 +236,11 @@ def launch(
heap_bases = ctx.get_heap_bases()
distribution = config.all_reduce_distribution

# Use simplified kernel by default — the two_shot variant has poor codegen on MI300X
kernel_fn = persistent_reduce_scatter_simple

Comment on lines +239 to +241
iris_launch(
persistent_reduce_scatter_two_shot,
kernel_fn,
(config.comm_sms,),
input_tensor,
output_tensor,
Expand Down