Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions iris/mem/triton/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def all_reduce_spinlock(self, tile: Tile, dst_view: TensorView, locks):
self.atomic_xchg(locks + tile_id, 0, to_rank=dest_rank, sem="release", scope="sys")

@triton.jit
def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks):
def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks, signal_value=1):
"""
Tile-level all-reduce using one-shot algorithm.

Expand All @@ -743,6 +743,9 @@ def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor
src_view: TensorView for source tensor (to load remote data).
dst_view: TensorView for output tensor.
locks: Pointer to lock array used as ready flags.
signal_value: Expected lock value indicating tile is ready.
Use a monotonically increasing counter to avoid zeroing
locks between calls.
"""
src_tile_ptr, mask = src_view.tile_ptr(tile)
dst_tile_ptr, _ = dst_view.tile_ptr(tile)
Expand All @@ -755,7 +758,7 @@ def all_reduce_one_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor
for remote_rank in range(self.world_size):
if remote_rank != self.rank:
lock_ptr = locks + tile_id
while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != 1:
while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != signal_value:
pass
partial = self.load(src_tile_ptr, from_rank=remote_rank, mask=mask)
acc += partial.to(acc_dtype)
Expand Down Expand Up @@ -799,7 +802,7 @@ def all_reduce_ring(self, tile: Tile, src_view: TensorView, dst_view: TensorView
tl.store(dst_tile_ptr, remote_result, mask=mask)

@triton.jit
def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks):
def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks, signal_value=1):
"""
Tile-level all-reduce using two-shot algorithm with work distribution.

Expand All @@ -811,6 +814,9 @@ def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor
src_view: TensorView for source tensor.
dst_view: TensorView for output tensor.
locks: Pointer to lock array used as ready flags.
signal_value: Expected lock value indicating tile is ready.
Use a monotonically increasing counter to avoid zeroing
locks between calls.
"""
num_tiles_n = tl.cdiv(dst_view.N, tile.block_n)
tile_id = tile.pid_m * num_tiles_n + tile.pid_n
Expand All @@ -826,7 +832,7 @@ def all_reduce_two_shot(self, tile: Tile, src_view: TensorView, dst_view: Tensor
for remote_rank in range(self.world_size):
if remote_rank != self.rank:
lock_ptr = locks + tile_id
while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != 1:
while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="sys") != signal_value:
pass
partial = self.load(src_tile_ptr, from_rank=remote_rank, mask=mask)
acc += partial.to(acc_dtype)
Expand Down Expand Up @@ -951,7 +957,7 @@ def all_to_all(self, tile: TileView, src_view: TensorView, dst_view: TensorView,
tl.store(dst_view.ptr + dst_offsets, data, mask=combined_mask)

@triton.jit
def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks):
def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView, locks, signal_value=1):
"""
Tile-level reduce-scatter using contiguous work distribution.

Expand All @@ -963,6 +969,9 @@ def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView,
src_view: TensorView for source tensor.
dst_view: TensorView for output tensor.
locks: Pointer to lock array used as ready flags.
signal_value: Expected lock value indicating tile is ready.
Use a monotonically increasing counter to avoid zeroing
locks between calls.
"""
num_tiles_n = tl.cdiv(dst_view.N, tile.block_n)
num_tiles_m = tl.cdiv(dst_view.M, tile.block_m)
Expand All @@ -988,7 +997,7 @@ def reduce_scatter(self, tile: Tile, src_view: TensorView, dst_view: TensorView,
for remote_rank in range(self.world_size):
if remote_rank != self.rank:
lock_ptr = locks + tile_id
while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="gpu") != 1:
while self.atomic_add(lock_ptr, 0, to_rank=remote_rank, sem="acquire", scope="gpu") != signal_value:
pass
partial = self.load(src_tile_ptr, from_rank=remote_rank, mask=mask)
acc += partial.to(acc_dtype)
Expand Down
20 changes: 16 additions & 4 deletions iris/ops/matmul_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def _fused_matmul_all_reduce_kernel(
BLOCK_SIZE_K: tl.constexpr,
EVEN_K: tl.constexpr,
VARIANT: tl.constexpr,
signal_value=1,
):
"""
Fused GEMM + All-Reduce kernel with configurable all-reduce variant.
Expand Down Expand Up @@ -133,15 +134,17 @@ def _fused_matmul_all_reduce_kernel(
# Use atomic_xchg with release semantics to ensure memory ordering
tile_id = pid_m * num_tiles_n + pid_n
lock_ptr = locks + tile_id
tl.atomic_xchg(lock_ptr, 1, sem="release", scope="sys") # Release ensures prior stores visible to remote GPUs
tl.atomic_xchg(
lock_ptr, signal_value, sem="release", scope="sys"
) # Release ensures prior stores visible to remote GPUs

# Create source view only when needed (aux_buffer is not None)
src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn)

if VARIANT == "one_shot":
ctx.all_reduce_one_shot(tile_obj, src_view, dst_view, locks)
ctx.all_reduce_one_shot(tile_obj, src_view, dst_view, locks, signal_value)
elif VARIANT == "two_shot":
ctx.all_reduce_two_shot(tile_obj, src_view, dst_view, locks)
ctx.all_reduce_two_shot(tile_obj, src_view, dst_view, locks, signal_value)


def matmul_all_reduce_preamble(
Expand Down Expand Up @@ -196,7 +199,9 @@ def matmul_all_reduce_preamble(
if config.all_reduce_variant in ["spinlock", "one_shot", "two_shot"]:
if workspace.locks is None or workspace.locks.numel() != total_tiles:
workspace.locks = shmem.zeros((total_tiles,), dtype=torch.int32)
else:
elif config.all_reduce_variant == "spinlock":
# Spinlock uses CAS(0→1) and releases back to 0, so needs zeroing.
# one_shot/two_shot use monotonic signal_value, no zeroing needed.
workspace.locks.zero_()
else:
workspace.locks = None
Expand Down Expand Up @@ -331,6 +336,12 @@ def matmul_all_reduce(

even_k = K % config.block_size_k == 0

# Increment call counter for producer-consumer signal value.
# Each call uses a unique value so consumers don't see stale signals.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 1044149 — call_counter now wraps at INT32_MAX (0x7FFFFFFF).

# Wrap at INT32_MAX since locks are int32 tensors.
workspace.call_counter = (workspace.call_counter % 0x7FFFFFFF) + 1
signal_value = workspace.call_counter

iris_launch(
_fused_matmul_all_reduce_kernel,
grid,
Expand All @@ -356,6 +367,7 @@ def matmul_all_reduce(
config.block_size_k,
even_k,
config.all_reduce_variant,
signal_value,
algorithm="matmul_all_reduce",
rank=rank,
dtype=A.dtype,
Expand Down
13 changes: 9 additions & 4 deletions iris/ops/matmul_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _fused_matmul_reduce_scatter_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
EVEN_K: tl.constexpr,
signal_value=1,
):
"""
Fused GEMM + Reduce-Scatter kernel.
Expand Down Expand Up @@ -105,7 +106,7 @@ def _fused_matmul_reduce_scatter_kernel(
# Signal tile is ready
tile_id = pid_m * num_tiles_n + pid_n
lock_ptr = locks + tile_id
tl.atomic_xchg(lock_ptr, 1, sem="release", scope="gpu")
tl.atomic_xchg(lock_ptr, signal_value, sem="release", scope="gpu")

# Create tile object and context
tile_obj = iris.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c)
Expand All @@ -115,7 +116,7 @@ def _fused_matmul_reduce_scatter_kernel(
src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn)
dst_view = iris.make_tensor_view(C, M, N, stride_cm, stride_cn)

ctx.reduce_scatter(tile_obj, src_view, dst_view, locks)
ctx.reduce_scatter(tile_obj, src_view, dst_view, locks, signal_value)


def matmul_reduce_scatter_preamble(
Expand Down Expand Up @@ -173,8 +174,6 @@ def matmul_reduce_scatter_preamble(

if workspace.locks is None or workspace.locks.numel() != total_tiles:
workspace.locks = shmem.zeros((total_tiles,), dtype=torch.int32)
else:
workspace.locks.zero_()

if workspace.aux_buffer is None or workspace.aux_buffer.shape != (M, N):
workspace.aux_buffer = shmem.zeros((M, N), dtype=dtype)
Expand Down Expand Up @@ -258,6 +257,11 @@ def matmul_reduce_scatter(

even_k = K % config.block_size_k == 0

# Increment call counter for producer-consumer signal value.
# Wrap at INT32_MAX since locks are int32 tensors.
workspace.call_counter = (workspace.call_counter % 0x7FFFFFFF) + 1
signal_value = workspace.call_counter
Comment on lines +260 to +263
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Workspaces are per-process objects — in distributed training, each rank runs its own Python process with its own workspace instance. The call_counter increments identically on all ranks because fused ops are collective (all ranks must call them together). Divergence would indicate a program bug (one rank skipping a collective call), which would deadlock regardless of the signal value.


iris_launch(
_fused_matmul_reduce_scatter_kernel,
grid,
Expand All @@ -282,6 +286,7 @@ def matmul_reduce_scatter(
config.block_size_n,
config.block_size_k,
even_k,
signal_value,
algorithm="matmul_reduce_scatter",
rank=rank,
dtype=A.dtype,
Expand Down
6 changes: 6 additions & 0 deletions iris/ops/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class FusedWorkspace:
aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results
locks: Optional[torch.Tensor] = None # Synchronization primitives

# Monotonic call counter for producer-consumer lock signaling.
# Each call uses a unique signal value, eliminating the need to zero locks
# between calls for one_shot/two_shot/reduce_scatter variants.
call_counter: int = 0

prepared: bool = False

def matches(
Expand Down Expand Up @@ -82,4 +87,5 @@ def clear(self):
"""Free all allocated buffers."""
self.aux_buffer = None
self.locks = None
self.call_counter = 0
self.prepared = False
Loading