Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion comms/torchcomms/TorchCommPy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ Register a tensor buffer with the window for RMA operations.
Raises:
RuntimeError: If tensor is not contiguous or a buffer is already registered.

Note:
In CUDA graph capture mode, the window holds a **non-owned** reference to the
underlying buffer — it does not prevent the tensor from being deallocated.
The caller must ensure the tensor remains alive for the entire lifetime of
the window, and the window must not outlive the CUDA graph it was captured in.

Example:

.. code-block:: python
Expand Down Expand Up @@ -369,7 +375,9 @@ Get the size of the registered window buffer in bytes.
Get the registered tensor buffer, if any.

Returns:
Optional[torch.Tensor]: The registered tensor, or None if no tensor is registered.
Optional[torch.Tensor]: The registered tensor, or ``None`` if no tensor is
registered or if the window was created in CUDA graph capture mode (where
the buffer is non-owned).

)")
.def_property_readonly(
Expand Down
7 changes: 3 additions & 4 deletions comms/torchcomms/TorchCommWindow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class TorchCommWindow {
return win_size_;
}

// Returns the registered tensor buffer, if any.
// Returns the registered tensor buffer, or nullopt in graph capture mode.
// Returns a copy of the std::optional holding the registered tensor buffer.
// Empty (std::nullopt) when no tensor has been registered, or when the
// window was created during CUDA graph capture — in that mode
// tensor_register() deliberately does not store the tensor (non-owned
// buffer), so the caller is responsible for keeping it alive.
std::optional<at::Tensor> get_tensor() const {
  return buf_tensor_;
}
Expand All @@ -88,9 +88,8 @@ class TorchCommWindow {
// while the communicator operates on the GPU. However, if both are using the
// GPU, they should reside on the same device.
size_t win_size_{0};
// Store a copy of the user-provided tensor buffer to ensure its storage
// remains valid for the lifetime of the window. This prevents use-after-free
// issues by holding a reference count on the tensor's storage.
// Holds a reference to the registered tensor to keep its storage alive.
// Nullopt in graph capture mode (non-owned buffer); see tensor_register docs.
std::optional<at::Tensor> buf_tensor_;
at::ScalarType buf_dtype_{at::kFloat};
c10::Device buf_device_{c10::kCUDA};
Expand Down
18 changes: 17 additions & 1 deletion comms/torchcomms/ncclx/TorchCommWindowNCCLX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,23 @@ void TorchCommWindowNCCLX<Backend>::tensor_register(const at::Tensor& tensor) {
initNcclOrigWindow(tensor.data_ptr(), win_size_);
#endif

buf_tensor_ = tensor;
// In graph capture mode, we create a non-owned buffer: the window does not
// hold a reference to the tensor. This relies on the caller keeping the
// tensor alive for the lifetime of the window. The NCCL window registration
// (commWindowRegister) independently tracks the underlying physical buffer,
// so the window remains functional without buf_tensor_.
//
// IMPORTANT: In graph capture mode, the window must not outlive the graph
// or the tensor that was registered. The user is responsible for ensuring
// the tensor's storage remains valid for the window's entire lifetime.
if (torch_comm_->getGraphCaptureMode()) {
TC_LOG(WARNING)
<< "[TorchCommWindowNCCLX]: Graph capture mode active — window holds "
<< "a non-owned buffer. The registered tensor must remain alive for "
<< "the lifetime of this window. get_tensor() will return nullopt.";
} else {
buf_tensor_ = tensor;
}
buf_device_ = tensor.device();
}

Expand Down
63 changes: 63 additions & 0 deletions comms/torchcomms/ncclx/tests/unit/cpp/TorchCommWindowNCCLXTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,69 @@ TEST_F(TorchCommWindowNCCLXTest, WindowOperationsAfterFinalizeThrowException) {
testOperation([&]() { comm->new_window(); });
}

// =============================================================================
// Graph Capture Mode Tests
// =============================================================================
//
// These tests verify that tensor_register() skips storing buf_tensor_ in
// graph capture mode to save memory, while still storing it in normal mode.

TEST_F(
    TorchCommWindowNCCLXTest,
    TensorRegisterSkipsBufTensorInGraphCaptureMode) {
  // Verifies: In graph capture mode, tensor_register() does NOT store
  // buf_tensor_ so the tensor can be released to save memory.
  setupRankAndSize(0, 2);
  setupCCAExpectations(1, 2, 1);
  auto comm = createMockedTorchComm();

  // Install default mock behaviors BEFORE init so communicator setup
  // succeeds; the capture-mode override below is layered on top afterwards.
  cuda_mock_->setupDefaultBehaviors();
  nccl_mock_->setupDefaultBehaviors();

  EXPECT_NO_THROW(
      comm->init(*device_, "test_graph_buf_tensor", default_options_));

  // Simulate graph capture mode: streamIsCapturing returns Active.
  // ON_CALL (rather than EXPECT_CALL) only changes the default behavior —
  // it does not require the call to occur, keeping the test non-brittle.
  ON_CALL(*cuda_mock_, streamIsCapturing(_, _))
      .WillByDefault(DoAll(
          SetArgPointee<1>(cudaStreamCaptureStatusActive),
          Return(cudaSuccess)));

  auto tensor = createTestTensor({10, 10});
  auto win = comm->new_window();
  win->tensor_register(tensor);

  // In graph capture mode, get_tensor() should return nullopt
  // (the window holds a non-owned buffer; see tensor_register()).
  EXPECT_FALSE(win->get_tensor().has_value())
      << "buf_tensor_ should not be stored in graph capture mode";

  EXPECT_NO_THROW(comm->finalize());
}

TEST_F(TorchCommWindowNCCLXTest, TensorRegisterStoresBufTensorInNormalMode) {
  // Verifies: In normal (non-graph-capture) mode, tensor_register() stores
  // buf_tensor_ as usual, and the stored tensor aliases the exact buffer the
  // caller registered (not merely *some* tensor).
  setupRankAndSize(0, 2);
  setupCCAExpectations(1, 2, 1);
  auto comm = createMockedTorchComm();

  cuda_mock_->setupDefaultBehaviors();
  nccl_mock_->setupDefaultBehaviors();

  EXPECT_NO_THROW(
      comm->init(*device_, "test_normal_buf_tensor", default_options_));

  // No capture-mode override is installed here, so streamIsCapturing keeps
  // its default (not capturing) and tensor_register() takes the owning path.
  auto tensor = createTestTensor({10, 10});
  auto win = comm->new_window();
  win->tensor_register(tensor);

  // In normal mode, get_tensor() should return the registered tensor.
  auto stored = win->get_tensor();
  ASSERT_TRUE(stored.has_value())
      << "buf_tensor_ should be stored in normal mode";
  // Checking only has_value() would pass even if the wrong tensor were
  // stored; compare the underlying storage pointer to pin the contract.
  EXPECT_EQ(stored->data_ptr(), tensor.data_ptr())
      << "stored tensor should alias the buffer passed to tensor_register()";

  EXPECT_NO_THROW(comm->finalize());
}

// =============================================================================
// Device API Tests for get_device_window()
// =============================================================================
Expand Down
Loading