Conversation

@mawad-amd (Collaborator)

Motivation

Add an RDMA + proxy thread backend.

Technical Details

  1. CPU-GPU queue
  2. Device-side building of CPU-GPU packet
  3. Proxy thread talking to the NIC (see the sketch below)
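
For illustration, a conceptual sketch of how the queue/proxy split can look. The packet layout, field names, and function names below are hypothetical, not this PR's actual code:

    // Conceptual sketch only: device code fills WorkEntry slots in a
    // host-visible ring buffer; the CPU proxy thread drains them and
    // posts the corresponding RDMA verbs to the NIC.
    #include <atomic>
    #include <cstddef>
    #include <cstdint>

    struct WorkEntry {                 // one CPU-GPU packet (hypothetical layout)
        uint32_t opcode;               // PUT / GET / ATOMIC
        uint32_t dst_rank;             // destination PE
        uint64_t local_offset;        // offset into the local symmetric heap
        uint64_t remote_offset;       // offset into the remote symmetric heap
        uint64_t size;                 // bytes to transfer
        std::atomic<uint32_t> ready;   // set by the device once the entry is complete
    };

    void proxy_loop(WorkEntry* ring, size_t slots, std::atomic<bool>& running) {
        size_t head = 0;
        while (running.load(std::memory_order_acquire)) {
            WorkEntry& e = ring[head % slots];
            if (e.ready.load(std::memory_order_acquire)) {
                // post_to_nic(e);  // e.g., build an ibv_send_wr and call ibv_post_send()
                e.ready.store(0, std::memory_order_release);  // hand the slot back
                ++head;
            }
        }
    }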

It is not yet clear how to merge this backend into the Iris RMA backend, but the goal is a single backend for both.

Test Plan

Test Result

Submission Checklist

@github-actions bot added labels in-progress (We are working on it) and iris (Iris project issue) on Oct 31, 2025
@mawad-amd requested a review from Copilot on November 7, 2025

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR adds experimental InfiniBand RDMA (Remote Direct Memory Access) support to Iris for multi-node GPU communication. The implementation provides a symmetric heap model with RDMA operations (put/get/atomics) accessible from Triton kernels, using PyTorch distributed for bootstrapping and InfiniBand for high-performance inter-node communication.

Key changes:

  • RDMA backend with InfiniBand support (optional build via CMake)
  • CPU-GPU work queue for asynchronous RDMA operations
  • Triton device APIs for RDMA put/get/atomic operations with symmetric heap addressing (address translation sketched below)
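
To illustrate the symmetric heap addressing in the last bullet: every rank allocates a heap of the same size and exchanges base addresses at bootstrap, so a peer's address can be recomputed from a local offset. A minimal sketch (names hypothetical, not the PR's actual code):

    #include <cstdint>

    // Sketch: translate a pointer into this rank's symmetric heap to the
    // corresponding address in a peer's heap. The offset is identical on
    // every rank; only the base address differs.
    uint64_t translate_to_peer(uint64_t local_ptr,
                               uint64_t local_heap_base,
                               uint64_t peer_heap_base) {
        uint64_t offset = local_ptr - local_heap_base;
        return peer_heap_base + offset;
    }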

Reviewed Changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 15 comments.

Summary per file:

  setup.py: Adds CMake build system for optional RDMA C++ extension with InfiniBand detection
  iris/experimental/iris_rdma.py: Main Python API providing RDMA context, symmetric heap, and Triton device APIs
  iris/experimental/iris_rdma/python/bindings.cpp: PyBind11 bindings exposing C++ RDMA backend to Python
  iris/experimental/iris_rdma/src/*.hpp: C++ implementation: network backend, queue pairs, work queue, proxy thread, logging
  iris/experimental/__init__.py: Exports iris_rdma module with optional import handling
  examples/22-24_rdma_*: Example programs demonstrating producer-consumer, GET, and atomic operations
  docker/*: Updated Dockerfile and scripts with InfiniBand device support
  run.sh, rebuild.sh: Helper scripts for running and rebuilding
Comments suppressed due to low confidence (1)

iris/experimental/iris_rdma/src/iris_manager.hpp:1

  • Corrected spelling of 'its' to 'it's' in comment.
    // SPDX-License-Identifier: MIT

Comment on lines +101 to +110
  void dump_cq_info() const {
    LOG_DEBUG("cq: %p", cq_);
    LOG_DEBUG("handle: %u", cq_->channel);
    LOG_DEBUG("cq_context: %p", cq_->cq_context);
    LOG_DEBUG("context: %p", cq_->context);
    LOG_DEBUG("cqe: %u", cq_->cqe);
    LOG_DEBUG("comp_events_completed: %u", cq_->comp_events_completed);
    LOG_DEBUG("async_events_completed: %u", cq_->async_events_completed);
  }

Copilot AI commented on Nov 7, 2025:

The dump_cq_info() function appears to be a debugging utility that is called in production code (see network_backend.hpp line 489). Consider removing this call from the hot path (poll_cq) or guarding it behind a debug flag to avoid performance overhead in production.
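
For illustration, one way to guard it behind a compile-time flag; the IRIS_RDMA_DEBUG macro name is hypothetical, not part of this PR:

    // Hypothetical guard: in release builds the dump compiles away entirely,
    // keeping the poll_cq hot path free of the logging calls.
    #ifdef IRIS_RDMA_DEBUG
      #define DUMP_CQ_INFO() dump_cq_info()
    #else
      #define DUMP_CQ_INFO() ((void)0)
    #endif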

Comment on lines +690 to +691
int sq_length = 64; // Send queue length // TODO: FIX THAT

Copilot AI commented on Nov 7, 2025:

The TODO comment "FIX THAT" is vague and doesn't explain what needs to be fixed. Consider clarifying what specific issue needs to be addressed (e.g., "TODO: Make queue length configurable" or "TODO: Calculate optimal queue length based on workload").

Suggested change:

  - int sq_length = 64; // Send queue length // TODO: FIX THAT
  + int sq_length = 64; // Send queue length
  + // TODO: Make send queue length (sq_length) configurable or calculate based on workload/device capabilities

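For illustration, one way to make the length configurable at runtime; the IRIS_SQ_LENGTH environment variable is hypothetical, not part of this PR:

    #include <cstdlib>

    // Hypothetical: read the send queue length from the environment,
    // falling back to the current hardcoded default of 64.
    static int send_queue_length() {
        if (const char* env = std::getenv("IRIS_SQ_LENGTH")) {
            int len = std::atoi(env);
            if (len > 0) return len;
        }
        return 64;
    }
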
    std::this_thread::sleep_for(std::chrono::microseconds(10));
  }
  if (n <= 0) {
    LOG_DEBUG("Warning: PUT completion not polled (may be OK if async)");

Copilot AI commented on Nov 7, 2025:

Using LOG_DEBUG for a warning message is inconsistent. Consider using LOG_WARN for warning messages to maintain proper log level semantics.

Suggested change:

  - LOG_DEBUG("Warning: PUT completion not polled (may be OK if async)");
  + LOG_WARN("Warning: PUT completion not polled (may be OK if async)");

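More generally, the retry-then-check pattern around these polls could be factored into a small helper so each call site logs at the appropriate level; a sketch (the helper is hypothetical):

    #include <chrono>
    #include <thread>

    // Hypothetical helper: invoke poll() until it reports a completion or the
    // retry budget runs out. Returning the count lets the caller decide how to
    // log (e.g., LOG_WARN when 0 completions were observed).
    template <typename PollFn>
    int poll_with_retries(PollFn poll, int max_retries) {
        for (int i = 0; i < max_retries; ++i) {
            int n = poll();
            if (n > 0) return n;
            std::this_thread::sleep_for(std::chrono::microseconds(10));
        }
        return 0;
    }
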
    std::this_thread::sleep_for(std::chrono::microseconds(10));
  }
  if (n <= 0) {
    LOG_DEBUG("Warning: GET completion not polled (may be OK if async)");

Copilot AI commented on Nov 7, 2025:

Similar to Comment 4, these warning messages use LOG_DEBUG instead of LOG_WARN. For consistency and proper log level semantics, warnings should use LOG_WARN.

Suggested change:

  - LOG_DEBUG("Warning: GET completion not polled (may be OK if async)");
  + LOG_WARN("Warning: GET completion not polled (may be OK if async)");

    std::this_thread::sleep_for(std::chrono::microseconds(10));
  }
  if (n <= 0) {
    LOG_DEBUG("Warning: ATOMIC_EXCH completion not polled (may be OK if async)");

Copilot AI commented on Nov 7, 2025:

Similar to Comment 4, these warning messages use LOG_DEBUG instead of LOG_WARN. For consistency and proper log level semantics, warnings should use LOG_WARN.

# Extract source address (min of pointer block where data is stored)
src_ptr_u64 = src_ptr.to(tl.uint64)
src_ptr_val = tl.min(src_ptr_u64, axis=0)
max_src_ptr = tl.max(src_ptr_u64, axis=0)

Copilot AI commented on Nov 7, 2025:

Variable max_src_ptr is not used.

Suggested change (delete the line):

  - max_src_ptr = tl.max(src_ptr_u64, axis=0)


def build_extension(self, ext):
    if not isinstance(ext, CMakeExtension):
        return super().build_extension(ext)

Copilot AI commented on Nov 7, 2025:

Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

Suggested change:

  - return super().build_extension(ext)
  + super().build_extension(ext)

import triton
import triton.language as tl
import numpy as np
import sys

Copilot AI commented on Nov 7, 2025:

Import of 'sys' is not used.

Suggested change (delete the line):

  - import sys

import torch.distributed as dist
import triton
import triton.language as tl
import time

Copilot AI commented on Nov 7, 2025:

Import of 'time' is not used.

Suggested change (delete the line):

  - import time

import torch.distributed as dist
import triton
import triton.language as tl
import time

Copilot AI commented on Nov 7, 2025:

Import of 'time' is not used.

Suggested change (delete the line):

  - import time


Labels

in-progress (We are working on it), iris (Iris project issue)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants