Fused all-gather+GEMM HBM-buffer kernel for iris.ops#346
Conversation
…yaswann/iris_xops_perf
…yaswann/iris_xops_perf
…yaswann/iris_xops_perf
…yaswann/iris_xops_perf
…yaswann/iris_xops_perf
… use randn in benchmark Agent-Logs-Url: https://github.com/ROCm/iris/sessions/6b5a926e-b45d-43ac-abac-7b8805d2d367 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
…dn inputs Agent-Logs-Url: https://github.com/ROCm/iris/sessions/6b5a926e-b45d-43ac-abac-7b8805d2d367 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
…iris tensors Agent-Logs-Url: https://github.com/ROCm/iris/sessions/6b5a926e-b45d-43ac-abac-7b8805d2d367 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Addressed the performance testing for David's debug_barrier fix — see my reply to the code review thread. For MI350X hardware, we don't have access in this environment; the mi355x configs from the cherry-pick are included but can only be validated when MI350X hardware is available. |
…00X) Agent-Logs-Url: https://github.com/ROCm/iris/sessions/306af9bb-b705-4c03-9d49-a31600e6df43 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
|
@copilot remove the charts you uploaded and get this PR ready for final review. |
…at; add bar chart ws8 Agent-Logs-Url: https://github.com/ROCm/iris/sessions/84eadb5a-429c-47b4-9636-dc3a70726a1a Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
|
@copilot resolve the merge conflicts in this pull request |
…chart images Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Done in commit 801adb7. All 4 chart images removed from
Merge conflicts resolved in commit 801adb7. Changes made:
|
|
@copilot can you expand test coverage for the kernels added in this PR? |
Agent-Logs-Url: https://github.com/ROCm/iris/sessions/fc86711f-8150-48e4-85d6-c4c72c9f26b9 Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Expanded in commit 63d84c4. Added 16 new tests across two categories: Distributed kernel tests:
Pure-Python
All 31 tests pass. |
|
@copilot I identified that the linked diff fixes the CI changes. Can you cherry pick minimal changes to fix the CI error: ryanswann-amd/iris@39c213d...5f30678 |
…uction Cherry-picked from ryanswann-amd/iris@5f30678. Two bugs caused deadlocks when iris.iris() is constructed/destroyed repeatedly (parametrized tests): 1. NCCL collective ordering: replaced _allgather_paths_tensor (which still uses dist.all_gather NCCL collective) with _allgather_paths_store using the dist.Store key-value API (pure TCP, zero NCCL). 2. Stale socket reuse: added per-process monotonic instance_id to socket paths and eager socket cleanup via SymmetricHeap.close_fd_conns(). Agent-Logs-Url: https://github.com/ROCm/iris/sessions/4824970a-ba41-4384-ab0e-650090f96f71 Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Cherry-picked in commit 677cb6f. Applied the minimal changes from
|
Adds
all_gather_matmul_hbm_buffer: a fused kernel that pipelines all-gather and GEMM by splitting workgroups into dedicated fetchers and GEMM workers. Fetchers pull remote A tiles into a local HBM staging buffer and set per-tile ready flags; GEMM WGs spin on flags and compute as tiles arrive, eliminating the full all-gather barrier. Delivers 2.7–3.4× lower latency vs the barrier-based baseline on 8× MI325X.New kernel
iris/ops/all_gather_matmul_hbm_buffer.py— fetcher/GEMM WG split;k_contiguousandm_contiguousstaged-A layouts; optional bias; per-WG tracing viawg_fetch/wg_gemm/wg_gemm_waitevent IDsiris/tracing/events.py— trace event IDs for per-workgroup profilingAPI / config changes
iris/x/gather.py—hintvectorization parameter forwarded to_translate()iris/ops/__init__.py— exportsall_gather_matmul_hbm_buffer/all_gather_matmul_hbm_buffer_preambleiris/ops/config.py— removed unusedall_gather_matmul_variantfield and dead "push" workspace allocation fromall_gather_matmul_preambleBenchmark & tests
benchmark/ops/bench_all_gather_matmul.py— merged baseline and HBM-buffer variants under@bench.axis("algorithm", ["baseline", "hbm_buffer"]);bench_all_gather_matmul_hbm_buffer.pydeletedtests/ops/test_all_gather_matmul.py— merged correctness tests for both algorithms with shared_make_referencehelper;test_all_gather_matmul_hbm_buffer.pydeletedResults (8× AMD MI325X, float16, N=3584, K=8192)