Context
During the CCL split refactor (#523), Copilot code review flagged several pre-existing issues in kernel code. These exist on main today and were not introduced by the split. Filing here to track as follow-up work.
Issues
1. Gluon all-to-all: % M/% N modulo wrapping on boundary tiles
File: iris/ccl/gluon/all_to_all.py (lines 94, 108, 113)
Computing rm = (...) % M and rn = (...) % N makes boundary tiles wrap back to the start of the tensor. After wrapping, col_mask = rn < N is trivially true, so out-of-bounds elements are never masked. This can produce incorrect results when M or N is not an exact multiple of the block size.
Fix: Remove % M/% N and use proper row/col masks, as the Triton backend does. Needs GPU testing to verify the modulo isn't intentional for gluon's masking model.
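The wrapping effect can be reproduced on the host without a GPU. The sketch below is plain Python, not the gluon kernel: `tile_offsets`, the sizes M = 10 and BLOCK_M = 4, and the `wrap` flag are all hypothetical names and values chosen for illustration. It only shows how a `< size` mask behaves with and without the modulo on the last tile.

```python
def tile_offsets(pid, block, size, wrap):
    """Return (offsets, mask) for one tile along a dimension of length `size`.

    Hypothetical helper: mimics `pid * block + arange(block)` followed by an
    optional `% size` and a `< size` bounds mask.
    """
    offsets = [pid * block + i for i in range(block)]
    if wrap:
        # The pattern flagged in the issue: wrap indices back into range.
        offsets = [o % size for o in offsets]
    mask = [o < size for o in offsets]
    return offsets, mask


M, BLOCK_M = 10, 4                            # M is not a multiple of BLOCK_M
last_pid = (M + BLOCK_M - 1) // BLOCK_M - 1   # index of the boundary tile

wrapped_offsets, wrapped_mask = tile_offsets(last_pid, BLOCK_M, M, wrap=True)
plain_offsets, plain_mask = tile_offsets(last_pid, BLOCK_M, M, wrap=False)

print("with % M:   ", wrapped_offsets, wrapped_mask)
print("without % M:", plain_offsets, plain_mask)
```

With the modulo, the last two lanes of the boundary tile point at rows 0 and 1 and the mask is true for every lane. Without it, those lanes keep their out-of-range indices and the mask disables them.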
2. Gluon chiplet_transform: > vs >= off-by-one
File: iris/ccl/gluon/all_to_all.py (line 25)
if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
Should be >=: the first tail PID should not be transformed. The same pattern exists in iris/ccl/utils.py.
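A small host-side sketch makes the boundary visible. It evaluates only the comparison quoted above and does not model the transform itself. The helper names `tail_start` and `tail_pids` and the sizes (10 workgroups, 2 XCDs, chunk size 2) are hypothetical and chosen purely for illustration.

```python
def tail_start(num_workgroups, num_xcds, chunk_size):
    """First PID that does not fit into a full (num_xcds * chunk_size) group."""
    block = num_xcds * chunk_size
    return (num_workgroups // block) * block


def tail_pids(num_workgroups, num_xcds, chunk_size, inclusive):
    """PIDs that the early-return check would leave untransformed."""
    start = tail_start(num_workgroups, num_xcds, chunk_size)
    if inclusive:
        # Proposed check: pid >= start
        return [pid for pid in range(num_workgroups) if pid >= start]
    # Current check: pid > start
    return [pid for pid in range(num_workgroups) if pid > start]


num_workgroups, num_xcds, chunk_size = 10, 2, 2   # illustrative values only

start = tail_start(num_workgroups, num_xcds, chunk_size)
current = tail_pids(num_workgroups, num_xcds, chunk_size, inclusive=False)
proposed = tail_pids(num_workgroups, num_xcds, chunk_size, inclusive=True)

print("tail starts at PID", start)
print("skipped by '>': ", current)
print("skipped by '>=':", proposed)
```

For these values the tail starts at PID 8. The `>` check skips only PID 9, so PID 8 still goes through the transform, while `>=` skips both tail PIDs.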
3. Triton all-gather: unused NUM_XCDS/CHUNK_SIZE parameters
File: iris/ccl/triton/all_gather.py (lines 35, 61)
Both persistent_all_gather and persistent_all_gather_partitioned accept NUM_XCDS/CHUNK_SIZE but never call chiplet_transform_chunked. The other collectives (all_reduce, reduce_scatter, all_to_all) all apply the transform. Either apply it here or remove the parameters.
4. Ring all-reduce: SLICE_SIZE_N passed but unused
File: iris/ccl/triton/all_reduce.py (lines 447, 454, 573)
SLICE_SIZE_N is passed to the ring kernel and launch() enforces slice-related constraints, but the kernel body ignores it and always uses BLOCK_SIZE_N. Either implement column slicing or remove the parameter and its validation.
Testing
All fixes need multi-GPU testing on MI300X+ hardware:
torchrun --nproc_per_node=8 tests/run_tests_distributed.py tests/ccl/ -v