-
Notifications
You must be signed in to change notification settings - Fork 39
Fix hardcoded lock value in fused GEMM+CCL operations #529
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,7 @@ def _fused_matmul_reduce_scatter_kernel( | |
| BLOCK_SIZE_N: tl.constexpr, | ||
| BLOCK_SIZE_K: tl.constexpr, | ||
| EVEN_K: tl.constexpr, | ||
| signal_value=1, | ||
| ): | ||
| """ | ||
| Fused GEMM + Reduce-Scatter kernel. | ||
|
|
@@ -105,7 +106,7 @@ def _fused_matmul_reduce_scatter_kernel( | |
| # Signal tile is ready | ||
| tile_id = pid_m * num_tiles_n + pid_n | ||
| lock_ptr = locks + tile_id | ||
| tl.atomic_xchg(lock_ptr, 1, sem="release", scope="gpu") | ||
| tl.atomic_xchg(lock_ptr, signal_value, sem="release", scope="gpu") | ||
|
|
||
| # Create tile object and context | ||
| tile_obj = iris.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) | ||
|
|
@@ -115,7 +116,7 @@ def _fused_matmul_reduce_scatter_kernel( | |
| src_view = iris.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) | ||
| dst_view = iris.make_tensor_view(C, M, N, stride_cm, stride_cn) | ||
|
|
||
| ctx.reduce_scatter(tile_obj, src_view, dst_view, locks) | ||
| ctx.reduce_scatter(tile_obj, src_view, dst_view, locks, signal_value) | ||
|
|
||
|
|
||
| def matmul_reduce_scatter_preamble( | ||
|
|
@@ -173,8 +174,6 @@ def matmul_reduce_scatter_preamble( | |
|
|
||
| if workspace.locks is None or workspace.locks.numel() != total_tiles: | ||
| workspace.locks = shmem.zeros((total_tiles,), dtype=torch.int32) | ||
| else: | ||
| workspace.locks.zero_() | ||
|
|
||
| if workspace.aux_buffer is None or workspace.aux_buffer.shape != (M, N): | ||
| workspace.aux_buffer = shmem.zeros((M, N), dtype=dtype) | ||
|
|
@@ -258,6 +257,11 @@ def matmul_reduce_scatter( | |
|
|
||
| even_k = K % config.block_size_k == 0 | ||
|
|
||
| # Increment call counter for producer-consumer signal value. | ||
| # Wrap at INT32_MAX since locks are int32 tensors. | ||
| workspace.call_counter = (workspace.call_counter % 0x7FFFFFFF) + 1 | ||
| signal_value = workspace.call_counter | ||
|
Comment on lines
+260
to
+263
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Workspaces are per-process objects — in distributed training, each rank runs its own Python process with its own workspace instance. The call_counter increments identically on all ranks because fused ops are collective (all ranks must call them together). Divergence would indicate a program bug (one rank skipping a collective call), which would deadlock regardless of the signal value. |
||
|
|
||
| iris_launch( | ||
| _fused_matmul_reduce_scatter_kernel, | ||
| grid, | ||
|
|
@@ -282,6 +286,7 @@ def matmul_reduce_scatter( | |
| config.block_size_n, | ||
| config.block_size_k, | ||
| even_k, | ||
| signal_value, | ||
| algorithm="matmul_reduce_scatter", | ||
| rank=rank, | ||
| dtype=A.dtype, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 1044149 — call_counter now wraps at INT32_MAX (0x7FFFFFFF).