
Conversation

@neoblizz (Member) commented Nov 1, 2025

Summary

This pull request introduces a new Collective Communication Library (CCL) for Iris, providing standalone collective primitives such as all-to-all communication, along with supporting infrastructure and documentation updates. The changes focus on enabling high-performance distributed collective operations that match PyTorch's RCCL/NCCL interface, including benchmarking and validation tools.

Collective Communication Library Implementation

  • Added the new iris/ccl package, including the all_to_all collective operation (iris/ccl/all_to_all.py) and a configuration structure for kernel tuning (iris/ccl/config.py). These modules provide a flexible and efficient interface for distributed tensor communication; a usage sketch follows this list. [1] [2]
  • Introduced the top-level iris/ccl/__init__.py to expose the collective primitives and configuration, matching PyTorch's interface for easy adoption.
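
As a rough illustration, a minimal usage sketch could look like the following. The exact all_to_all argument order, the Config fields, and the Iris context methods shown here (iris.iris, get_rank, get_num_ranks, zeros, barrier) are assumptions for illustration, not a definitive description of the iris.ccl API.

# Hypothetical usage sketch; signatures and tensor layout are assumptions.
import torch
import iris
from iris.ccl import all_to_all, Config

shmem = iris.iris(2**33)                      # Iris symmetric-heap context (8 GiB heap assumed)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()

M, N = 8192, 4096
# Each rank holds one N-column chunk per peer, laid out side by side.
input_tensor = shmem.zeros((M, N * world_size), dtype=torch.float16)
output_tensor = shmem.zeros((M, N * world_size), dtype=torch.float16)
input_tensor += rank                          # make each rank's contribution identifiable

config = Config()                             # kernel-tuning knobs (block sizes, etc.)
all_to_all(input_tensor, output_tensor, shmem, config=config)
shmem.barrier()                               # ensure all remote writes have landed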

Benchmarking and Validation Tools

  • Added a comprehensive benchmark script examples/ccl/benchmark.py to measure bandwidth and validate correctness of the all-to-all operation, supporting multiple datatypes and configurable parameters.

Containerization and Environment Setup

  • Created a new Dockerfile docker/Dockerfile.ccl to provide a ready-to-use environment for CCL development and validation, including installation of dependencies, Triton, ROCm tools, and entrypoint scripts for testing.

Documentation Updates

  • Updated examples/README.md to document the new ccl directory and provide usage instructions for benchmarking the all-to-all collective operation. [1] [2]

Test Infrastructure

  • Added a test package initializer for future CCL tests (tests/ccl/__init__.py).

Submission Checklist

@github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels on Nov 1, 2025
@neoblizz marked this pull request as ready for review on November 1, 2025 05:17
Copilot AI review was requested due to automatic review settings on November 1, 2025 05:17

Copilot AI left a comment

Pull Request Overview

This PR introduces the iris-ccl (Collective Communication Library) module with an all-to-all collective operation. The implementation provides a standalone collective primitive matching PyTorch's RCCL/NCCL interface, enabling efficient multi-rank tensor exchange operations.

  • Adds iris.ccl module with all_to_all collective operation and Config dataclass for kernel parameters
  • Implements persistent kernel using Triton with remote PUT operations for cross-rank communication
  • Includes comprehensive test suite, benchmark utilities, and Docker support for validation

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.

Summary per file:

  • iris/ccl/__init__.py: Module initialization exposing all_to_all and Config
  • iris/ccl/config.py: Configuration dataclass with auto-detection of XCD count and validation
  • iris/ccl/all_to_all.py: Core all-to-all implementation using persistent Triton kernel with remote PUTs
  • tests/ccl/__init__.py: Test module initialization
  • tests/ccl/test_all_to_all.py: Parametrized test suite for all-to-all with various dtypes and sizes
  • examples/ccl/benchmark.py: Benchmark script with validation and performance measurement
  • examples/README.md: Documentation update adding ccl examples
  • docker/Dockerfile.ccl: Docker configuration for ccl validation environment
  • examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py: Synchronization bug fix replacing xchg with add and improving memory fences

if target_rank == cur_rank:
    # Local path: copy input[cur_rank] chunk to output[cur_rank] chunk
    data = tl.load(input_ptr + input_offset_send, mask=mask)
    output_offset_local = rm[:, None] * stride_out_m + (rn[None, :] + cur_rank * N) * stride_out_n

Copilot AI Nov 1, 2025

The .wt (write-through) cache modifier is used here but not for the remote PUTs on lines 97-104. Consider adding a comment explaining why write-through caching is only appropriate for local stores and not remote operations.

Suggested change
output_offset_local = rm[:, None] * stride_out_m + (rn[None, :] + cur_rank * N) * stride_out_n
output_offset_local = rm[:, None] * stride_out_m + (rn[None, :] + cur_rank * N) * stride_out_n
# Use write-through cache modifier only for local stores.
# Write-through caching is not used for remote PUTs because remote memory operations
# may not support cache modifiers and could result in undefined behavior or inefficiency.


# Clone and install Triton
WORKDIR $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH

Copilot AI Nov 1, 2025

The git commit hash should include a comment explaining why this specific Triton commit is being used (e.g., compatibility requirements, specific feature needs, or bug fixes).

Suggested change
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH
# Pin Triton to commit dd58234 for ROCm 6.3 compatibility and MI300X/MI350X support.

@neoblizz requested a review from Copilot on November 3, 2025 16:35

@mawad-amd (Collaborator) left a comment

Looking good so far. Thanks for adding this.


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution."""
    backend = "nccl" if torch.cuda.is_available() else "gloo"

@mawad-amd (Collaborator) commented:

I would really like to separate timing/benchmarking code into the benchmarking directory and have this as an example with no timing or anything. We can discuss if this would not be ideal but I could see a wrapper around all_to_all.
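
For illustration only, such a wrapper could look roughly like the sketch below; the all_to_all signature and argument order are assumptions, not an agreed design.

# Hypothetical timing wrapper kept in the benchmarking directory, so the
# example itself stays free of timing code.
import torch
from iris.ccl import all_to_all

def timed_all_to_all(input_tensor, output_tensor, shmem, config=None, iters=10):
    """Run all_to_all `iters` times and return the average milliseconds per call."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        all_to_all(input_tensor, output_tensor, shmem, config=config)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters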

@neoblizz (Member, Author) replied:

I don't understand how that would work.

Copilot AI left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@neoblizz requested reviews from Copilot and mawad-amd on November 9, 2025 17:44

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 15 comments.

if not dist.is_initialized():
    pytest.skip("torch.distributed not initialized")

heap_size = 2**33  # 1GB

Copilot AI Nov 9, 2025

The comment states "1GB" but the heap size is 2^33 bytes which equals 8GB. Either the comment should be updated to "8GB" or the heap size should be changed to 2^30 for 1GB.

Suggested change
heap_size = 2**33 # 1GB
heap_size = 2**30 # 1GB

all_reduce_variant: str = "atomic"
all_reduce_distribution: int = 0
all_reduce_num_rings: int = 1
all_reduce_ring_slice_n: int | None = None

Copilot AI Nov 9, 2025

The type hint int | None uses Python 3.10+ union syntax. For compatibility with Python 3.9 and earlier, either use Optional[int] from typing (which is already imported at line 10 for other purposes in this file), or add from __future__ import annotations at the top of the file. Consider using Optional[int] to maintain consistency with common Python practices.

Suggested change
all_reduce_ring_slice_n: int | None = None
all_reduce_ring_slice_n: Optional[int] = None

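For reference, either spelling keeps the dataclass importable on Python 3.9; a minimal sketch using the fields from the snippet above (the rest of the Config dataclass is elided):

# Sketch only: Optional[int] works on Python 3.9; `int | None` needs 3.10+
# or `from __future__ import annotations` at the top of the file.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    all_reduce_variant: str = "atomic"
    all_reduce_distribution: int = 0
    all_reduce_num_rings: int = 1
    all_reduce_ring_slice_n: Optional[int] = None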

# Clone and install Triton
WORKDIR $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH

Copilot AI Nov 9, 2025

The git clone on line 33 targets $TRITON_PATH while WORKDIR is already set to $TRITON_PATH on line 32, so although the git checkout on line 34 does run inside the cloned repository, the clone effectively targets the current directory through its absolute path, cloning into itself. This should be either git clone https://github.com/triton-lang/triton.git . or the WORKDIR should be set differently.

Suggested change
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git .
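
A minimal sketch of the suggested pattern, combining the clone and the pinned checkout in one layer (<commit-sha> is a placeholder, not the commit pinned in this PR):

# Dockerfile sketch; clone into the current WORKDIR explicitly.
WORKDIR $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git . \
    && git checkout <commit-sha>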

rn = gl.max_contiguous(gl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)

# Pre-compute base offsets - maximize VGPR usage by keeping all offsets in registers
row_offsets_m = rm * stride_in_m

Copilot AI Nov 9, 2025

Variable row_offsets_m is not used.

Suggested change
row_offsets_m = rm * stride_in_m


# Pre-compute base offsets - maximize VGPR usage by keeping all offsets in registers
row_offsets_m = rm * stride_in_m
row_offsets_out_m = rm * stride_out_m

Copilot AI Nov 9, 2025

Variable row_offsets_out_m is not used.

Suggested change
row_offsets_out_m = rm * stride_out_m

@neoblizz requested a review from Copilot on November 9, 2025 18:04

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.

expected_slice = config.block_size_n // world_size
if slice_n is None or slice_n * world_size != config.block_size_n:
    slice_n = expected_slice
config.all_reduce_ring_slice_n = slice_n

Copilot AI Nov 9, 2025

Mutating the config object passed by the user can lead to unexpected behavior. Consider creating a copy of the config or documenting that the config object may be modified. This could cause issues when the same config is reused across multiple calls with different world_size values.
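
One way to avoid this, sketched with dataclasses.replace (assuming Config is a dataclass, as the PR description indicates, and reusing the field names from the snippet above):

from dataclasses import replace

def _resolve_config(config, world_size):
    """Return a per-call copy of config with all_reduce_ring_slice_n normalized."""
    slice_n = config.all_reduce_ring_slice_n
    if slice_n is None or slice_n * world_size != config.block_size_n:
        slice_n = config.block_size_n // world_size
    # replace() returns a new Config instance; the caller's object is left untouched.
    return replace(config, all_reduce_ring_slice_n=slice_n)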


# Optimization to vectorize the load/store - similar to iris.py
# This enables the compiler to generate dwordx4 or wider loads
# Note: Gluon uses scalar multiples, not 2D tuples like Triton

Copilot AI Nov 9, 2025

[nitpick] Commented-out optimization code should either be removed or have a clear explanation of why it's commented out and under what conditions it should be enabled. If this is a work-in-progress optimization, consider using a TODO comment or feature flag instead.

Suggested change
# Note: Gluon uses scalar multiples, not 2D tuples like Triton
# Note: Gluon uses scalar multiples, not 2D tuples like Triton
# TODO: Enable the following optimization once Gluon supports pointer alignment
# and vectorized memory accesses in the same way as Triton. Currently disabled
# due to potential incompatibility or lack of support in Gluon for these features.

if config.use_gluon and GLUON_AVAILABLE:
    # Check if shmem is Iris Gluon (has get_device_context method)
    if not hasattr(shmem, 'get_device_context'):
        raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()")

Copilot AI Nov 9, 2025

The error message suggests using iris.experimental.iris_gluon.iris() but the correct import path is import iris.experimental.iris_gluon as iris_gluon followed by iris_gluon.iris(). Consider updating the message to: "use_gluon=True requires Iris Gluon context. Use iris_gluon.iris() where iris_gluon is imported from iris.experimental.iris_gluon"

Suggested change
raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()")
raise ValueError("use_gluon=True requires Iris Gluon context. Use iris_gluon.iris() where iris_gluon is imported from iris.experimental.iris_gluon")
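
For context, the import pattern the message points at would look roughly like this (hedged sketch based on the module path in the error message; the iris() arguments are assumptions):

import iris.experimental.iris_gluon as iris_gluon

# Build the Gluon-capable Iris context instead of the default iris.iris() one.
shmem = iris_gluon.iris(2**33)
assert hasattr(shmem, "get_device_context")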

shmem.barrier()
config = Config(all_reduce_variant=variant)
if variant == "two_shot":
    # Test both distribution modes for two_shot

Copilot AI Nov 9, 2025

The comment says "Test both distribution modes for two_shot" but the code only tests striding mode (distribution=0). Either remove the misleading comment or test both modes (0 and 1) for the two_shot variant. There's a separate test function test_all_reduce_two_shot_distribution that tests both modes, so this comment is misleading.

Suggested change
# Test both distribution modes for two_shot

# Remote store offset: write into target's output at columns [cur_rank*N : (cur_rank+1)*N]
# This is constant for all target_rank iterations since it only depends on cur_rank
output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n)
output_ptr_remote = tl.multiple_of(output_ptr + output_offset_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N))

Copilot AI Nov 9, 2025

Variable output_ptr_remote is not used.

Suggested change
output_ptr_remote = tl.multiple_of(output_ptr + output_offset_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N))
