diff --git a/iris/ccl/config.py b/iris/ccl/config.py
index 81188c5fb..8b8d90c70 100644
--- a/iris/ccl/config.py
+++ b/iris/ccl/config.py
@@ -6,9 +6,16 @@
 """
 
 from dataclasses import dataclass
+import functools
 
 import iris
 
+
+@functools.lru_cache(maxsize=1)
+def _cached_num_xcc():
+    """Cache the XCC count since it never changes during a process."""
+    return iris.hip.get_num_xcc()
+
+
 @dataclass
 class Config:
     """
@@ -32,9 +39,10 @@ class Config:
         use_gluon: If True, use Gluon-based implementation (default: False)
             Gluon provides better control over warp-level traffic shaping
-        all_gather_variant: Variant for all-gather operation (default: "persistent")
-            Options: "persistent", "partitioned"
-            - "persistent": Each PID handles multiple tiles and sends to all ranks
-            - "partitioned": PIDs partitioned across ranks, eliminates inner loop
+        all_gather_variant: Variant for all-gather operation (default: "pull")
+            Options: "persistent", "partitioned", "pull"
+            - "persistent": Each PID handles multiple tiles and sends to all ranks (PUSH)
+            - "partitioned": PIDs partitioned across ranks, eliminates inner loop (PUSH)
+            - "pull": Each rank reads from all ranks into its own output (PULL)
         all_reduce_variant: Variant for all-reduce operation (default: "atomic")
             Options: "atomic", "ring", "two_shot", "one_shot", "spinlock"
         all_reduce_distribution: Distribution for two-shot all-reduce (default: 0)
@@ -84,7 +92,7 @@ class Config:
     num_xcds: int | None = None
     chunk_size: int | None = None
     use_gluon: bool = False
-    all_gather_variant: str = "persistent"
+    all_gather_variant: str = "pull"
     all_reduce_variant: str = "two_shot"
     all_reduce_distribution: int = 1
     all_reduce_num_rings: int = 1
@@ -98,7 +106,7 @@ def __post_init__(self):
         """Validate and auto-detect num_xcds if not set."""
         if self.num_xcds is None:
-            self.num_xcds = iris.hip.get_num_xcc()
+            self.num_xcds = _cached_num_xcc()
 
         if self.chunk_size is None:
             self.chunk_size = self.swizzle_size * self.swizzle_size
@@ -114,9 +122,9 @@ def __post_init__(self):
             raise ValueError(f"comm_sms must be positive, got {self.comm_sms}")
         if self.num_xcds <= 0:
             raise ValueError(f"num_xcds must be positive, got {self.num_xcds}")
-        if self.all_gather_variant not in ["persistent", "partitioned"]:
+        if self.all_gather_variant not in ["persistent", "partitioned", "pull"]:
             raise ValueError(
-                f"all_gather_variant must be one of: 'persistent', 'partitioned', got {self.all_gather_variant}"
+                f"all_gather_variant must be one of: 'persistent', 'partitioned', 'pull', got {self.all_gather_variant}"
             )
         if self.all_reduce_variant not in ["atomic", "ring", "two_shot", "one_shot", "spinlock"]:
             raise ValueError(
diff --git a/iris/ccl/triton/all_gather.py b/iris/ccl/triton/all_gather.py
index cd2891ee7..dfa5c5a71 100644
--- a/iris/ccl/triton/all_gather.py
+++ b/iris/ccl/triton/all_gather.py
@@ -277,6 +277,69 @@ def persistent_all_gather_partitioned(
     )
 
 
+@triton.jit()
+def persistent_all_gather_pull(
+    input_ptr,
+    output_ptr,
+    M,
+    N,
+    stride_in_m,
+    stride_in_n,
+    stride_out_m,
+    stride_out_n,
+    heap_bases: tl.tensor,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
+    world_size: tl.constexpr,
+    rank_start: tl.constexpr,
+    rank_stride: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    COMM_SMS: tl.constexpr,
+    NUM_XCDS: tl.constexpr,
+    CHUNK_SIZE: tl.constexpr,
+):
+    """
+    Pull-model all-gather: each rank reads input from all ranks into its own output.
+
+    Uses simple linear tile indexing (no swizzle/chiplet transform) for clean
+    codegen on MI300X. GROUP_SIZE_M, NUM_XCDS, and CHUNK_SIZE are accepted
+    but unused.
+    """
+    pid = tl.program_id(0)
+
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    total_tiles = num_pid_m * num_pid_n
+
+    for tile_id in range(pid, total_tiles, COMM_SMS):
+        pid_m = tile_id // num_pid_n
+        pid_n = tile_id % num_pid_n
+
+        tl.assume(pid_m >= 0)
+        tl.assume(pid_n >= 0)
+
+        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+
+        mask = (rm[:, None] < M) & (rn[None, :] < N)
+        input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n
+
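+        # Walk every peer: row block i of the output receives rank i's shard.
+        # The local shard comes from a plain tl.load; remote shards are pulled
+        # with one-sided iris.load. Every store targets this rank's own output,
+        # so no peer ever writes into our memory.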
+        for i in tl.static_range(world_size):
+            source_rank = rank_start + i * rank_stride
+            rm_output = rm + i * M
+            output_offset = rm_output[:, None] * stride_out_m + rn[None, :] * stride_out_n
+
+            if i == group_rank:
+                data = tl.load(input_ptr + input_offset, mask=mask, other=0.0)
+                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")
+            else:
+                data = iris.load(input_ptr + input_offset, iris_rank, source_rank, heap_bases, mask=mask)
+                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")
+
+
 def launch(
     input_tensor,
     output_tensor,
@@ -307,6 +370,8 @@ def launch(
         kernel_fn = persistent_all_gather
     elif config.all_gather_variant == "partitioned":
         kernel_fn = persistent_all_gather_partitioned
+    elif config.all_gather_variant == "pull":
+        kernel_fn = persistent_all_gather_pull
     else:
         raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}")
diff --git a/iris/ccl/triton/all_to_all.py b/iris/ccl/triton/all_to_all.py
index a224a93a6..290a2d566 100644
--- a/iris/ccl/triton/all_to_all.py
+++ b/iris/ccl/triton/all_to_all.py
@@ -177,6 +177,75 @@ def persistent_all_to_all(
     )
 
 
+@triton.jit()
+def persistent_all_to_all_pull(
+    input_ptr,
+    output_ptr,
+    M,
+    N,
+    stride_in_m,
+    stride_in_n,
+    stride_out_m,
+    stride_out_n,
+    heap_bases: tl.tensor,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
+    world_size: tl.constexpr,
+    rank_start: tl.constexpr,
+    rank_stride: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    COMM_SMS: tl.constexpr,
+    NUM_XCDS: tl.constexpr,
+    CHUNK_SIZE: tl.constexpr,
+):
+    """
+    Pull-model all-to-all: each rank reads its data from all remote ranks.
+
+    Instead of pushing local chunks to remote outputs, each rank pulls
+    the chunks destined for it from all remote inputs.
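+
+    Both buffers are M x (world_size * N): input column block j holds the
+    data destined for rank j, and output column block i receives the data
+    pulled from rank i. GROUP_SIZE_M, NUM_XCDS, and CHUNK_SIZE are accepted
+    but unused.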
+    """
+    pid = tl.program_id(0)
+
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    total_tiles = num_pid_m * num_pid_n
+
+    for tile_id in range(pid, total_tiles, COMM_SMS):
+        pid_m = tile_id // num_pid_n
+        pid_n = tile_id % num_pid_n
+
+        tl.assume(pid_m >= 0)
+        tl.assume(pid_n >= 0)
+
+        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+
+        mask = (rm[:, None] < M) & (rn[None, :] < N)
+
+        # For each source rank i, pull the chunk that rank i sends to us (group_rank)
+        for i in tl.static_range(world_size):
+            source_rank = rank_start + i * rank_stride
+
+            # Source rank i's input for us is at column offset group_rank * N
+            input_col_offset = group_rank * N
+            input_offset = rm[:, None] * stride_in_m + (rn[None, :] + input_col_offset) * stride_in_n
+
+            # Output goes to column offset i * N
+            output_col_offset = i * N
+            output_offset = rm[:, None] * stride_out_m + (rn[None, :] + output_col_offset) * stride_out_n
+
+            if i == group_rank:
+                data = tl.load(input_ptr + input_offset, mask=mask, other=0.0)
+                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")
+            else:
+                data = iris.load(input_ptr + input_offset, iris_rank, source_rank, heap_bases, mask=mask)
+                tl.store(output_ptr + output_offset, data, mask=mask, cache_modifier=".wt")
+
+
 def launch(
     input_tensor,
     output_tensor,
@@ -195,8 +264,11 @@ def launch(
     stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1)
     stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1)
 
+    # Use the PULL kernel unconditionally: the PUSH model performs poorly on MI300X.
+    kernel_fn = persistent_all_to_all_pull
+
     iris_launch(
-        persistent_all_to_all,
+        kernel_fn,
         (config.comm_sms,),
         input_tensor,
         output_tensor,
diff --git a/iris/ccl/triton/reduce_scatter.py b/iris/ccl/triton/reduce_scatter.py
index 31f20ebf6..f023ea641 100644
--- a/iris/ccl/triton/reduce_scatter.py
+++ b/iris/ccl/triton/reduce_scatter.py
@@ -140,6 +140,83 @@ def persistent_reduce_scatter_two_shot(
     tl.store(out_ptr, reduced, mask=mask, cache_modifier=".wt")
 
 
+@triton.jit()
+def persistent_reduce_scatter_simple(
+    input_ptr,
+    output_ptr,
+    M,
+    N,
+    stride_in_m,
+    stride_in_n,
+    stride_out_m,
+    stride_out_n,
+    heap_bases: tl.tensor,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
+    world_size: tl.constexpr,
+    rank_start: tl.constexpr,
+    rank_stride: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    COMM_SMS: tl.constexpr,
+    NUM_XCDS: tl.constexpr,
+    CHUNK_SIZE: tl.constexpr,
+    DISTRIBUTION: tl.constexpr,
+):
+    """
+    Simplified reduce-scatter using the PULL model.
+
+    Each rank reduces its assigned tiles by pulling data from all ranks
+    and stores the result locally. Avoids the dual-path is_full/masked
+    structure that causes poor codegen on MI300X.
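+    GROUP_SIZE_M, NUM_XCDS, CHUNK_SIZE, and DISTRIBUTION are accepted
+    but unused.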
+    """
+    pid = tl.program_id(0)
+
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    total_tiles = num_pid_m * num_pid_n
+
+    acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32
+
+    # Block distribution: each rank handles a contiguous chunk of tiles
+    tiles_per_rank = tl.cdiv(total_tiles, world_size)
+    start_tile = group_rank * tiles_per_rank
+    remaining = total_tiles - start_tile
+    remaining = tl.maximum(remaining, 0)
+    max_tiles = tl.minimum(tiles_per_rank, remaining)
+
+    for tile_offset in range(pid, max_tiles, COMM_SMS):
+        tile_id = start_tile + tile_offset
+
+        pid_m = tile_id // num_pid_n
+        pid_n = tile_id % num_pid_n
+
+        tl.assume(pid_m >= 0)
+        tl.assume(pid_n >= 0)
+
+        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+
+        mask = (rm[:, None] < M) & (rn[None, :] < N)
+        input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n
+        output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n
+
+        # Accumulate from all ranks via PULL
+        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
+        for i in tl.static_range(world_size):
+            source_rank = rank_start + i * rank_stride
+            if i == group_rank:
+                data = tl.load(input_ptr + input_offset, mask=mask, other=0.0)
+            else:
+                data = iris.load(input_ptr + input_offset, iris_rank, source_rank, heap_bases, mask=mask)
+            acc += data.to(acc_dtype)
+
+        tl.store(output_ptr + output_offset, acc.to(output_ptr.type.element_ty), mask=mask, cache_modifier=".wt")
+
+
 def launch(
     output_tensor,
     input_tensor,
@@ -159,8 +236,11 @@ def launch(
     heap_bases = ctx.get_heap_bases()
     distribution = config.all_reduce_distribution
 
+    # Use the simplified PULL kernel: the two_shot variant has poor codegen on MI300X.
+    kernel_fn = persistent_reduce_scatter_simple
+
     iris_launch(
-        persistent_reduce_scatter_two_shot,
+        kernel_fn,
         (config.comm_sms,),
         input_tensor,
         output_tensor,
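
Usage sketch (not part of the patch): how the new "pull" default and the extended validation behave. This assumes the remaining Config fields (e.g. comm_sms, swizzle_size) have defaults; num_xcds=8 is an arbitrary placeholder that skips XCC auto-detection so no GPU is needed.

```python
from iris.ccl.config import Config

# all_gather_variant now defaults to the PULL variant.
cfg = Config(num_xcds=8)
assert cfg.all_gather_variant == "pull"

# The PUSH variants remain selectable.
cfg_push = Config(num_xcds=8, all_gather_variant="partitioned")

# __post_init__ now lists "pull" among the accepted options.
try:
    Config(num_xcds=8, all_gather_variant="bogus")
except ValueError as err:
    print(err)  # ... must be one of: 'persistent', 'partitioned', 'pull', got bogus
```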