From f34d7323f9480e09cd190cb21b2efd5ebb8d97e9 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Sat, 2 May 2026 08:14:40 -0700 Subject: [PATCH 1/3] Fix hardcoded lock value in fused GEMM+CCL operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace hardcoded lock signal value `1` with a monotonically increasing call_counter on FusedWorkspace. Each call to matmul_all_reduce or matmul_reduce_scatter increments the counter and passes it as the signal value to both producer (atomic_xchg) and consumer (spin loop) sides. This eliminates the need to zero locks between calls for one_shot, two_shot, and reduce_scatter variants, since each call uses a unique signal value that won't collide with previous calls. The spinlock variant still uses CAS(0→1)/release(0) mutex semantics and continues to require zeroed locks. The signal_value parameter defaults to 1 for backward compatibility with existing test kernels and examples that zero locks manually. Closes #465 Co-Authored-By: Claude Opus 4.6 --- iris/mem/triton/context.py | 21 +++++++++++++++------ iris/ops/matmul_all_reduce.py | 17 +++++++++++++---- iris/ops/matmul_reduce_scatter.py | 12 ++++++++---- iris/ops/workspace.py | 6 ++++++ 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/iris/mem/triton/context.py b/iris/mem/triton/context.py index a797dbc43..6d3adfb03 100644 --- a/iris/mem/triton/context.py +++ b/iris/mem/triton/context.py @@ -731,7 +731,7 @@ def all_reduce_spinlock(self, tile: Tile, dst_view: TensorView, locks): self.atomic_xchg(locks + tile_id, 0, to_rank=dest_rank, sem="release", scope="sys") @triton.jit - def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks): + def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks, signal_value=1): """ Tile-level all-reduce using one-shot algorithm. @@ -743,6 +743,9 @@ def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor src_view: TensorView for source tensor (to load remote data). dst_view: TensorView for output tensor. locks: Pointer to lock array used as ready flags. + signal_value: Expected lock value indicating tile is ready. + Use a monotonically increasing counter to avoid zeroing + locks between calls. """ src_tile_ptr, mask = src_view.tile_ptr(tile) dst_tile_ptr, _ = dst_view.tile_ptr(tile) @@ -755,7 +758,7 @@ def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor for remote_rank in range(self.world_size): if remote_rank != self.rank: lock_ptr = locks + tile_id - while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != 1: + while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != signal_value: pass partial = self.load(src_tile_ptr, from_rank=remote_rank, mask=mask) acc += partial.to(acc_dtype) @@ -799,7 +802,7 @@ def all_reduce_ring(self, tile: Tile, src_view: TensorView, dst_view: TensorView tl.store(dst_tile_ptr, remote_result, mask=mask) @triton.jit - def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks): + def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks, signal_value=1): """ Tile-level all-reduce using two-shot algorithm with work distribution. @@ -811,6 +814,9 @@ def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor src_view: TensorView for source tensor. dst_view: TensorView for output tensor. locks: Pointer to lock array used as ready flags. + signal_value: Expected lock value indicating tile is ready. + Use a monotonically increasing counter to avoid zeroing + locks between calls. """ num_tiles_n = tl.cdiv(dst_view.N, tile.block_n) tile_id = tile.pid_m * num_tiles_n + tile.pid_n @@ -826,7 +832,7 @@ def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor for remote_rank in range(self.world_size): if remote_rank != self.rank: lock_ptr = locks + tile_id - while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != 1: + while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != signal_value: pass partial = self.load(src_tile_ptr, from_rank=remote_rank, mask=mask) acc += partial.to(acc_dtype) @@ -951,7 +957,7 @@ def all_to_all(self, tile: TileView, src_view: TensorView, dst_view: TensorView, tl.store(dst_view.ptr + dst_offsets, data, mask=combined_mask) @triton.jit - def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks): + def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks, signal_value=1): """ Tile-level reduce-scatter using contiguous work distribution. @@ -963,6 +969,9 @@ def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView, src_view: TensorView for source tensor. dst_view: TensorView for output tensor. locks: Pointer to lock array used as ready flags. + signal_value: Expected lock value indicating tile is ready. + Use a monotonically increasing counter to avoid zeroing + locks between calls. """ num_tiles_n = tl.cdiv(dst_view.N, tile.block_n) num_tiles_m = tl.cdiv(dst_view.M, tile.block_m) @@ -988,7 +997,7 @@ def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView, for remote_rank in range(self.world_size): if remote_rank != self.rank: lock_ptr = locks + tile_id - while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="gpu") != 1: + while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="gpu") != signal_value: pass partial = self.load(src_tile_ptr, from_rank=remote_rank, mask=mask) acc += partial.to(acc_dtype) diff --git a/iris/ops/matmul_all_reduce.py b/iris/ops/matmul_all_reduce.py index 29b84b6e3..494c8e61e 100644 --- a/iris/ops/matmul_all_reduce.py +++ b/iris/ops/matmul_all_reduce.py @@ -46,6 +46,7 @@ def _fused_matmul_all_reduce_kernel( BLOCK_SIZE_K: tl.constexpr, EVEN_K: tl.constexpr, VARIANT: tl.constexpr, + SIGNAL_VALUE = 1, ): """ Fused GEMM + All-Reduce kernel with configurable all-reduce variant. @@ -133,15 +134,15 @@ def _fused_matmul_all_reduce_kernel( # Use atomic_xchg with release semantics to ensure memory ordering tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id - tl.atomic_xchg(lock_ptr, 1, sem="release", scope="sys") # Release ensures prior stores visible to remote GPUs + tl.atomic_xchg(lock_ptr, SIGNAL_VALUE, sem="release", scope="sys") # Release ensures prior stores visible to remote GPUs # Create source view only when needed (aux_buffer is not None) src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) if VARIANT == "one_shot": - ctx.all_reduce_one_shot(tile_obj, src_view, dst_view, locks) + ctx.all_reduce_one_shot(tile_obj, src_view, dst_view, locks, SIGNAL_VALUE) elif VARIANT == "two_shot": - ctx.all_reduce_two_shot(tile_obj, src_view, dst_view, locks) + ctx.all_reduce_two_shot(tile_obj, src_view, dst_view, locks, SIGNAL_VALUE) def matmul_all_reduce_preamble( @@ -196,7 +197,9 @@ def matmul_all_reduce_preamble( if config.all_reduce_variant in ["spinlock", "one_shot", "two_shot"]: if workspace.locks is None or workspace.locks.numel() != total_tiles: workspace.locks = shmem.zeros((total_tiles,), dtype=torch.int32) - else: + elif config.all_reduce_variant == "spinlock": + # Spinlock uses CAS(0→1) and releases back to 0, so needs zeroing. + # one_shot/two_shot use monotonic signal_value, no zeroing needed. workspace.locks.zero_() else: workspace.locks = None @@ -331,6 +334,11 @@ def matmul_all_reduce( even_k = K % config.block_size_k == 0 + # Increment call counter for producer-consumer signal value. + # Each call uses a unique value so consumers don't see stale signals. + workspace.call_counter += 1 + signal_value = workspace.call_counter + iris_launch( _fused_matmul_all_reduce_kernel, grid, @@ -356,6 +364,7 @@ def matmul_all_reduce( config.block_size_k, even_k, config.all_reduce_variant, + signal_value, algorithm="matmul_all_reduce", rank=rank, dtype=A.dtype, diff --git a/iris/ops/matmul_reduce_scatter.py b/iris/ops/matmul_reduce_scatter.py index 47454d630..8f2f7a05c 100644 --- a/iris/ops/matmul_reduce_scatter.py +++ b/iris/ops/matmul_reduce_scatter.py @@ -45,6 +45,7 @@ def _fused_matmul_reduce_scatter_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, EVEN_K: tl.constexpr, + SIGNAL_VALUE = 1, ): """ Fused GEMM + Reduce-Scatter kernel. @@ -105,7 +106,7 @@ def _fused_matmul_reduce_scatter_kernel( # Signal tile is ready tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id - tl.atomic_xchg(lock_ptr, 1, sem="release", scope="gpu") + tl.atomic_xchg(lock_ptr, SIGNAL_VALUE, sem="release", scope="gpu") # Create tile object and context tile_obj = iris.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) @@ -115,7 +116,7 @@ def _fused_matmul_reduce_scatter_kernel( src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) dst_view = iris.make_tensor_view(C, M, N, stride_cm, stride_cn) - ctx.reduce_scatter(tile_obj, src_view, dst_view, locks) + ctx.reduce_scatter(tile_obj, src_view, dst_view, locks, SIGNAL_VALUE) def matmul_reduce_scatter_preamble( @@ -173,8 +174,6 @@ def matmul_reduce_scatter_preamble( if workspace.locks is None or workspace.locks.numel() != total_tiles: workspace.locks = shmem.zeros((total_tiles,), dtype=torch.int32) - else: - workspace.locks.zero_() if workspace.aux_buffer is None or workspace.aux_buffer.shape != (M, N): workspace.aux_buffer = shmem.zeros((M, N), dtype=dtype) @@ -258,6 +257,10 @@ def matmul_reduce_scatter( even_k = K % config.block_size_k == 0 + # Increment call counter for producer-consumer signal value. + workspace.call_counter += 1 + signal_value = workspace.call_counter + iris_launch( _fused_matmul_reduce_scatter_kernel, grid, @@ -282,6 +285,7 @@ def matmul_reduce_scatter( config.block_size_n, config.block_size_k, even_k, + signal_value, algorithm="matmul_reduce_scatter", rank=rank, dtype=A.dtype, diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index a9c7cb616..a50f06bf4 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -42,6 +42,11 @@ class FusedWorkspace: aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives + # Monotonic call counter for producer-consumer lock signaling. + # Each call uses a unique signal value, eliminating the need to zero locks + # between calls for one_shot/two_shot/reduce_scatter variants. + call_counter: int = 0 + prepared: bool = False def matches( @@ -82,4 +87,5 @@ def clear(self): """Free all allocated buffers.""" self.aux_buffer = None self.locks = None + self.call_counter = 0 self.prepared = False From 54da9646cdd2f85975de4a8fc2a3cdc04cbfdeee Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 2 May 2026 15:15:12 +0000 Subject: [PATCH 2/3] Apply Ruff auto-fixes --- iris/ops/matmul_all_reduce.py | 6 ++++-- iris/ops/matmul_reduce_scatter.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/iris/ops/matmul_all_reduce.py b/iris/ops/matmul_all_reduce.py index 494c8e61e..6d5ec6f6a 100644 --- a/iris/ops/matmul_all_reduce.py +++ b/iris/ops/matmul_all_reduce.py @@ -46,7 +46,7 @@ def _fused_matmul_all_reduce_kernel( BLOCK_SIZE_K: tl.constexpr, EVEN_K: tl.constexpr, VARIANT: tl.constexpr, - SIGNAL_VALUE = 1, + SIGNAL_VALUE=1, ): """ Fused GEMM + All-Reduce kernel with configurable all-reduce variant. @@ -134,7 +134,9 @@ def _fused_matmul_all_reduce_kernel( # Use atomic_xchg with release semantics to ensure memory ordering tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id - tl.atomic_xchg(lock_ptr, SIGNAL_VALUE, sem="release", scope="sys") # Release ensures prior stores visible to remote GPUs + tl.atomic_xchg( + lock_ptr, SIGNAL_VALUE, sem="release", scope="sys" + ) # Release ensures prior stores visible to remote GPUs # Create source view only when needed (aux_buffer is not None) src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) diff --git a/iris/ops/matmul_reduce_scatter.py b/iris/ops/matmul_reduce_scatter.py index 8f2f7a05c..0b172d1fa 100644 --- a/iris/ops/matmul_reduce_scatter.py +++ b/iris/ops/matmul_reduce_scatter.py @@ -45,7 +45,7 @@ def _fused_matmul_reduce_scatter_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, EVEN_K: tl.constexpr, - SIGNAL_VALUE = 1, + SIGNAL_VALUE=1, ): """ Fused GEMM + Reduce-Scatter kernel. From 10441496ba596187a6faeea021895059b4a1636f Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Sat, 2 May 2026 09:03:47 -0700 Subject: [PATCH 3/3] Address review feedback: rename SIGNAL_VALUE to signal_value, wrap at INT32_MAX - Rename kernel parameter from SIGNAL_VALUE (constexpr style) to signal_value (runtime parameter style) to avoid confusion with compile-time constants - Wrap call_counter at INT32_MAX since lock tensors are int32 Co-Authored-By: Claude Opus 4.6 --- iris/ops/matmul_all_reduce.py | 11 ++++++----- iris/ops/matmul_reduce_scatter.py | 9 +++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/iris/ops/matmul_all_reduce.py b/iris/ops/matmul_all_reduce.py index 6d5ec6f6a..a27983d2d 100644 --- a/iris/ops/matmul_all_reduce.py +++ b/iris/ops/matmul_all_reduce.py @@ -46,7 +46,7 @@ def _fused_matmul_all_reduce_kernel( BLOCK_SIZE_K: tl.constexpr, EVEN_K: tl.constexpr, VARIANT: tl.constexpr, - SIGNAL_VALUE=1, + signal_value=1, ): """ Fused GEMM + All-Reduce kernel with configurable all-reduce variant. @@ -135,16 +135,16 @@ def _fused_matmul_all_reduce_kernel( tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id tl.atomic_xchg( - lock_ptr, SIGNAL_VALUE, sem="release", scope="sys" + lock_ptr, signal_value, sem="release", scope="sys" ) # Release ensures prior stores visible to remote GPUs # Create source view only when needed (aux_buffer is not None) src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) if VARIANT == "one_shot": - ctx.all_reduce_one_shot(tile_obj, src_view, dst_view, locks, SIGNAL_VALUE) + ctx.all_reduce_one_shot(tile_obj, src_view, dst_view, locks, signal_value) elif VARIANT == "two_shot": - ctx.all_reduce_two_shot(tile_obj, src_view, dst_view, locks, SIGNAL_VALUE) + ctx.all_reduce_two_shot(tile_obj, src_view, dst_view, locks, signal_value) def matmul_all_reduce_preamble( @@ -338,7 +338,8 @@ def matmul_all_reduce( # Increment call counter for producer-consumer signal value. # Each call uses a unique value so consumers don't see stale signals. - workspace.call_counter += 1 + # Wrap at INT32_MAX since locks are int32 tensors. + workspace.call_counter = (workspace.call_counter % 0x7FFFFFFF) + 1 signal_value = workspace.call_counter iris_launch( diff --git a/iris/ops/matmul_reduce_scatter.py b/iris/ops/matmul_reduce_scatter.py index 0b172d1fa..6edee27b0 100644 --- a/iris/ops/matmul_reduce_scatter.py +++ b/iris/ops/matmul_reduce_scatter.py @@ -45,7 +45,7 @@ def _fused_matmul_reduce_scatter_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, EVEN_K: tl.constexpr, - SIGNAL_VALUE=1, + signal_value=1, ): """ Fused GEMM + Reduce-Scatter kernel. @@ -106,7 +106,7 @@ def _fused_matmul_reduce_scatter_kernel( # Signal tile is ready tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id - tl.atomic_xchg(lock_ptr, SIGNAL_VALUE, sem="release", scope="gpu") + tl.atomic_xchg(lock_ptr, signal_value, sem="release", scope="gpu") # Create tile object and context tile_obj = iris.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) @@ -116,7 +116,7 @@ def _fused_matmul_reduce_scatter_kernel( src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) dst_view = iris.make_tensor_view(C, M, N, stride_cm, stride_cn) - ctx.reduce_scatter(tile_obj, src_view, dst_view, locks, SIGNAL_VALUE) + ctx.reduce_scatter(tile_obj, src_view, dst_view, locks, signal_value) def matmul_reduce_scatter_preamble( @@ -258,7 +258,8 @@ def matmul_reduce_scatter( even_k = K % config.block_size_k == 0 # Increment call counter for producer-consumer signal value. - workspace.call_counter += 1 + # Wrap at INT32_MAX since locks are int32 tensors. + workspace.call_counter = (workspace.call_counter % 0x7FFFFFFF) + 1 signal_value = workspace.call_counter iris_launch(