diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..dd7e703
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "third-party/nccl"]
+	path = third_party/nccl
+	url = https://github.com/NVIDIA/nccl.git
diff --git a/Cargo.toml b/Cargo.toml
index e8d6b8a..aafc5b4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,6 +22,7 @@ tonic = "0.12.2"
 
 [build-dependencies]
 tonic-build = "0.12.2"
+cmake = "0.1"
 
 [lib]
 name = "torchft"
diff --git a/build.rs b/build.rs
index f10cb3b..333da43 100644
--- a/build.rs
+++ b/build.rs
@@ -8,5 +8,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     tonic_build::configure()
         .protoc_arg("--experimental_allow_proto3_optional")
         .compile_protos(&["proto/torchft.proto"], &["proto"])?;
+
+    let dst = cmake::build("csrc");
+
     Ok(())
 }
diff --git a/cmake/External/nccl.cmake b/cmake/External/nccl.cmake
new file mode 100644
index 0000000..f849e76
--- /dev/null
+++ b/cmake/External/nccl.cmake
@@ -0,0 +1,75 @@
+if(NOT __NCCL_INCLUDED)
+  set(__NCCL_INCLUDED TRUE)
+
+  if(USE_SYSTEM_NCCL)
+    # NCCL_ROOT, NCCL_LIB_DIR, NCCL_INCLUDE_DIR will be accounted for in the following line.
+    find_package(NCCL REQUIRED)
+    if(NCCL_FOUND)
+      add_library(__caffe2_nccl INTERFACE)
+      target_link_libraries(__caffe2_nccl INTERFACE ${NCCL_LIBRARIES})
+      target_include_directories(__caffe2_nccl INTERFACE ${NCCL_INCLUDE_DIRS})
+    endif()
+  else()
+    cuda_select_nvcc_arch_flags(NVCC_GENCODE ${TORCH_CUDA_ARCH_LIST})
+
+    string(REPLACE "-gencode;" "-gencode=" NVCC_GENCODE "${NVCC_GENCODE}")
+    # this second replacement is needed when there are multiple archs
+    string(REPLACE ";-gencode" " -gencode" NVCC_GENCODE "${NVCC_GENCODE}")
+
+    if(DEFINED ENV{MAX_JOBS})
+      set(MAX_JOBS "$ENV{MAX_JOBS}")
+    else()
+      include(ProcessorCount)
+      ProcessorCount(NUM_HARDWARE_THREADS)
+      # Assume 2 hardware threads per cpu core
+      math(EXPR MAX_JOBS "${NUM_HARDWARE_THREADS} / 2")
+      # ProcessorCount might return 0, set to a positive number
+      if(MAX_JOBS LESS 2)
+        set(MAX_JOBS 2)
+      endif()
+    endif()
+
+    if("${CMAKE_GENERATOR}" MATCHES "Make")
+      # Recursive make with jobserver for parallelism, and also put a load limit
+      # here to avoid flaky OOM, https://www.gnu.org/software/make/manual/html_node/Parallel.html
+      set(MAKE_COMMAND "$(MAKE)" "-l${MAX_JOBS}")
+    else()
+      # Parallel build with CPU load limit to avoid oversubscription
+      set(MAKE_COMMAND "make" "-j${MAX_JOBS}" "-l${MAX_JOBS}")
+    endif()
+
+    set(__NCCL_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/nccl")
+    ExternalProject_Add(nccl_external
+      SOURCE_DIR ${PROJECT_SOURCE_DIR}/../third_party/nccl
+      BUILD_IN_SOURCE 1
+      CONFIGURE_COMMAND ""
+      BUILD_COMMAND
+        ${MAKE_COMMAND}
+        "CXX=${CMAKE_CXX_COMPILER}"
+        "CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}"
+        "NVCC=${CUDA_NVCC_EXECUTABLE}"
+        "NVCC_GENCODE=${NVCC_GENCODE}"
+        "BUILDDIR=${__NCCL_BUILD_DIR}"
+        "VERBOSE=0"
+        "DEBUG=0"
+      BUILD_BYPRODUCTS "${__NCCL_BUILD_DIR}/lib/libnccl_static.a"
+      INSTALL_COMMAND ""
+      )
+
+    set(__NCCL_LIBRARY_DEP nccl_external)
+    set(NCCL_LIBRARIES ${__NCCL_BUILD_DIR}/lib/libnccl_static.a)
+
+    set(NCCL_FOUND TRUE)
+    add_library(__caffe2_nccl INTERFACE)
+    # The following old-style variables are set so that other libs, such as Gloo,
+    # can still use it.
+    set(NCCL_INCLUDE_DIRS ${__NCCL_BUILD_DIR}/include)
+    add_dependencies(__caffe2_nccl ${__NCCL_LIBRARY_DEP})
+    target_link_libraries(__caffe2_nccl INTERFACE ${NCCL_LIBRARIES})
+    target_include_directories(__caffe2_nccl INTERFACE ${NCCL_INCLUDE_DIRS})
+    # nccl includes calls to shm_open/shm_close and therefore must depend on librt on Linux
+    if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+      target_link_libraries(__caffe2_nccl INTERFACE rt)
+    endif()
+  endif()
+endif()
diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
new file mode 100644
index 0000000..ce000f7
--- /dev/null
+++ b/csrc/CMakeLists.txt
@@ -0,0 +1,54 @@
+cmake_minimum_required(VERSION 3.4...3.18)
+project(_torchft_cpp)
+
+# Python
+find_package (Python3 COMPONENTS Interpreter Development)
+set(PYTORCH_ROOT "${Python3_SITELIB}")
+include_directories(BEFORE "${Python3_INCLUDE_DIRS}")
+
+# CUDA
+find_package(CUDA REQUIRED)
+include_directories(BEFORE "${CUDA_INCLUDE_DIRS}")
+message(STATUS "CUDA_VERSION: ${CUDA_VERSION}")
+
+set(TORCH_CUDA_ARCH_LIST "5.0;6.0;7.0;7.5;8.0;8.6;9.0")
+
+# NCCL
+include(ExternalProject)
+include(${CMAKE_CURRENT_LIST_DIR}/../cmake/External/nccl.cmake)
+
+# torch
+# if pytorch was installed in develop mode we need to resolve the egg-link
+set(PYTORCH_EGG_LINK "${PYTORCH_ROOT}/torch.egg-link")
+if (EXISTS "${PYTORCH_EGG_LINK}")
+  file (STRINGS "${PYTORCH_EGG_LINK}" PYTORCH_ROOT LIMIT_COUNT 1)
+endif()
+
+message(STATUS "PYTORCH_ROOT: ${PYTORCH_ROOT}")
+
+include_directories(BEFORE "${PYTORCH_ROOT}/torch/include")
+include_directories(BEFORE "${PYTORCH_ROOT}/torch/include/torch/csrc/api/include/")
+LINK_DIRECTORIES("${PYTORCH_ROOT}/torch/lib")
+
+#include_directories(BEFORE "${Python3_SITELIB}/triton/backends/nvidia/include/")
+#include_directories(BEFORE "${Python3_SITELIB}/nvidia/cuda_runtime/include/")
+#include_directories(BEFORE "${Python3_SITELIB}/nvidia/cusparse/include/")
+
+add_definitions(-DUSE_C10D_NCCL)
+
+add_library(${PROJECT_NAME} SHARED ProcessGroupNCCL.cpp init.cpp NCCLUtils.cpp cuda_utils.cpp)
+
+target_link_libraries(${PROJECT_NAME} ${CUDA_LIBRARIES})
+target_link_libraries(${PROJECT_NAME} __caffe2_nccl)
+target_link_libraries(${PROJECT_NAME} torch_cpu torch_cuda c10_cuda)
+
+set_target_properties(${PROJECT_NAME} PROPERTIES
+  PREFIX ""
+  SUFFIX ".${Python3_SOABI}.so"
+)
+
+
+install(
+  TARGETS ${PROJECT_NAME}
+  DESTINATION "${PROJECT_SOURCE_DIR}/../torchft"
+)
diff --git a/csrc/NCCLUtils.cpp b/csrc/NCCLUtils.cpp
new file mode 100644
index 0000000..dff8a5f
--- /dev/null
+++ b/csrc/NCCLUtils.cpp
@@ -0,0 +1,571 @@
+#include
+
+#include
+
+#ifdef USE_C10D_NCCL
+#include
+#include
+
+namespace c10d {
+
+NCCLComm::NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
+
+NCCLComm::~NCCLComm() noexcept {
+  // (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver
+  // shutdown error if CUDA context has exited first. Thus, we are not
+  // destroying or aborting NCCL communicators here. We just detect and warn
+  // about the risk of memory leak. Normally, a user would have called
+  // `destroy_process_group` or `abort_process_group`, and such risk would be
+  // avoided.
+  LockType lock(mutex_);
+  if (ncclComm_ && initialized_ && !aborted_) {
+    TORCH_WARN_ONCE(
+        "WARNING: NCCL communicator hasn't been destroyed. This may cause "
+        "memory leaks. 
To avoid the risk, you can call `destroy_process_group` " + "during normal exit or `_abort_process_group` when handling failures.") + } +} + +// NOLINTNEXTLINE(*-noexcept-move-*) +NCCLComm::NCCLComm(NCCLComm&& other) { + // Using other's lock, as it reads other's states + // Can not use this.mutex_, as this object is being constructed. + LockType lock(other.mutex_); + std::swap(ncclComm_, other.ncclComm_); + std::swap(aborted_, other.aborted_); + std::swap(ncclAsyncErr_, other.ncclAsyncErr_); + std::swap(initialized_, other.initialized_); + std::swap(nonBlocking_, other.nonBlocking_); + std::swap(deviceIndex_, other.deviceIndex_); +} + +ncclUniqueId NCCLComm::getNcclId() { + return ncclId_; +} + +std::shared_ptr NCCLComm::create( + int numRanks, + int rank, + ncclUniqueId commId, + at::DeviceIndex deviceIndex) { + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex); + auto comm = std::make_shared(); + C10D_NCCL_CHECK( + ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), + std::nullopt); + comm->ncclId_ = commId; + comm->rank_ = rank; + comm->deviceIndex_ = deviceIndex; + comm->initialized_ = true; + // Old style comm is always blocking. + comm->nonBlocking_ = false; + return comm; +} + +#ifdef NCCL_HAS_CONFIG +std::shared_ptr NCCLComm::create( + int numRanks, + int rank, + ncclUniqueId commId, + at::DeviceIndex deviceIndex, + ncclConfig_t& config) { + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex); + auto comm = std::make_shared(); + comm->nonBlocking_ = config.blocking == 0; + LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: " + << (comm->nonBlocking_ ? "nonblocking" : "blocking"); + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommInitRankConfig( + &(comm->ncclComm_), numRanks, commId, rank, &config), + std::nullopt); + comm->ncclId_ = commId; + comm->rank_ = rank; + comm->deviceIndex_ = deviceIndex; + // Under blocking mode, comm is initialized immediately after NCCL init + // returns; Under nonblocking mode, we check whether comm is initialized the + // *next* time ncclComm_ is accessed. + comm->initialized_ = !comm->nonBlocking_; + return comm; +} +#ifdef NCCL_HAS_INIT_RANK_SCALABLE +std::shared_ptr NCCLComm::create_scalable( + int numRanks, + int rank, + std::vector& commIds, + ncclConfig_t& config) { + auto comm = std::make_shared(); + comm->nonBlocking_ = config.blocking == 0; + LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: " + << (comm->nonBlocking_ ? "nonblocking" : "blocking") + << " with scalable init."; + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommInitRankScalable( + &(comm->ncclComm_), + numRanks, + rank, + commIds.size(), + commIds.data(), + &config), + std::nullopt); + // Only the first ncclUniqueId will be used to create the + // communicator hash id, which is used to identify the communicator + // in the log file and in the replay tool. + comm->ncclId_ = commIds[0]; + comm->rank_ = rank; + comm->initialized_ = !comm->nonBlocking_; + return comm; +} +#endif // NCCL_HAS_INIT_RANK_SCALABLE +#endif // NCCL_HAS_CONFIG + +ncclComm_t NCCLComm::getNcclComm() { + LockType lock(mutex_); + if (aborted_) { + auto commFailureMsg = commFailureReason_ != std::nullopt + ? c10::str(" Original reason for failure was: ", *commFailureReason_) + : ""; + TORCH_CHECK_WITH( + DistBackendError, + false, + c10::str( + "NCCL communicator was aborted on rank ", + rank_, + ". ", + commFailureMsg)); + } + // In non-blocking mode, ensure comm is ready. + if (nonBlocking_) { + // Wait with long interval if communicator is being initialized. 
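For readers following the non-blocking path here: under `NCCL_HAS_COMM_NONBLOCKING`, the `C10D_NCCL_CHECK_TIMEOUT*` macros conceptually reduce to polling `ncclCommGetAsyncError` until the communicator leaves `ncclInProgress`. A rough standalone sketch, not part of the patch (`waitForNcclComm` and the timeout handling are assumptions):

```cpp
// Illustrative sketch of waiting for a nonblocking NCCL communicator to become
// ready; assumes an NCCL build with nonblocking support (ncclInProgress).
#include <nccl.h>

#include <chrono>
#include <stdexcept>
#include <thread>

// Hypothetical helper, loosely mirroring what waitReady()/the CHECK_TIMEOUT
// macros do: poll the async state until it leaves ncclInProgress or times out.
inline void waitForNcclComm(ncclComm_t comm, int timeoutSec) {
  const auto deadline =
      std::chrono::steady_clock::now() + std::chrono::seconds(timeoutSec);
  ncclResult_t state = ncclInProgress;
  while (state == ncclInProgress) {
    if (ncclCommGetAsyncError(comm, &state) != ncclSuccess) {
      throw std::runtime_error("ncclCommGetAsyncError failed");
    }
    if (std::chrono::steady_clock::now() > deadline) {
      throw std::runtime_error("timed out waiting for NCCL communicator");
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1)); // "long interval" case
  }
  if (state != ncclSuccess) {
    throw std::runtime_error("NCCL communicator entered an error state");
  }
}
```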
+ bool longInterval = !initialized_; + waitReady(longInterval); + // ncclComm_ should be initialized by now + } + if (!initialized_) { + // TODO: see if we can consolidate other `initialized_` flipping here. + // Maintaining it elsewhere is some work. + initialized_ = true; + LOG(INFO) << "Rank " << rank_ << ": NCCL communicator " << repr() + << " is initialized."; + } + return ncclComm_; +} + +// Wait for the communicator to be ready. This is a blocking function. +// Arguments: +// longInterval: if true, wait with sleep of an interval; otherwise, wait +// with `sched_yield` which is faster (but acquires CPU more frequently). +void NCCLComm::waitReady(bool longInterval) { + LockType lock(mutex_); + if (aborted_) + return; + // If timeout is reached, throw an exception. + if (longInterval) { + C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT(ncclInProgress, ncclComm_, std::nullopt); + } +} + +std::optional NCCLComm::getNcclCommFailureReason() const { + LockType lock(mutex_); + return commFailureReason_; +} + +// TODO: why do we have `!defined(FBCODE_CAFFE2)` here? +#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2) +// last argument to split() API is not used to support +// multiple implementations +std::shared_ptr NCCLComm::split( + NCCLComm* source, + int color_id, + int rank, + ncclConfig_t& config, + std::vector& ranks_ull) { + TORCH_CHECK( + color_id >= NCCL_SPLIT_NOCOLOR, + "Color must be a non-negative value or NCCL_SPLIT_NOCOLOR (-1)" + ", but got ", + color_id); + LOG(INFO) << "Rank " << source->rank_ << ": split from parent comm " + << source->repr() << " with color_id " << color_id << " and rank " + << rank; + at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_); + auto comm = std::make_shared(); + // This call will block until the source communicator is initialized + auto sourceComm = source->getNcclComm(); +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config), + std::nullopt); +#else + // After calling ncclCommSplit in non-blocking mode, we should wait for the + // source communicator to be out of ncclInProgress state. + // Reason 1: + // it's unsafe to call new operations on the parent comm while it's in + // ncclInProgress state. + // Reason 2: + // as of NCCL 2.23, the ptr value of child comm will not be filled until the + // state of parent comm is ncclSuccess. This may change in the future. See: + // https://github.com/NVIDIA/nccl/issues/1472 + C10D_NCCL_CHECK_TIMEOUT_SLEEP( + ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config), + sourceComm, // wait on parent comm + std::nullopt); + if (color_id >= 0) { + // Waiting for parent comm above still does not seem to guarantee the child + // comm ptr is valid. Therefore we add a manual wait here for safety. + // TODO: remove this wait after NCCL fix the semantics. + auto startTime = std::chrono::steady_clock::now(); + auto timeout = nccl_nonblocking_timeout(); + while (!comm->ncclComm_) { + C10D_CHECK_TIMEOUT(startTime, timeout); + C10D_SCHED_SLEEP(); + } + } + // comm->ncclComm_ should have valid ptr by now, but not necessarily + // initialized. Rely on getNcclComm() to wait for its initialization. 
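For context on what `NCCLComm::split` wraps: a minimal sketch of the underlying `ncclCommSplit` call (illustrative only, not from this patch; assumes NCCL >= 2.18 with `NCCL_HAS_COMM_SPLIT`). Ranks that pass the same non-negative color land in the same child communicator, ordered by `key`; `NCCL_SPLIT_NOCOLOR` opts a rank out and yields a null child.

```cpp
// Illustrative use of ncclCommSplit, the call wrapped by NCCLComm::split().
#include <nccl.h>

inline ncclComm_t splitExample(ncclComm_t parent, int color, int key) {
  ncclComm_t child = nullptr;
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER; // blocking mode by default
  // Every rank of `parent` must make this call collectively.
  ncclCommSplit(parent, color, key, &child, &config);
  return child; // nullptr when color == NCCL_SPLIT_NOCOLOR
}
```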
+#endif + ++source->ncclCommSplitCounter_; + comm->rank_ = rank; + // Child comm should be on the same device as parent comm + comm->deviceIndex_ = source->deviceIndex_; + comm->nonBlocking_ = config.blocking == 0; + LOG(INFO) << "Rank " << source->rank_ << ": created child comm " + << comm->repr() << " with color_id " << color_id; + return comm; +} +#endif + +void NCCLComm::finalize() { + LockType lock(mutex_); + if (aborted_) { + LOG(INFO) << "Rank " << rank_ + << ": NCCL communicator already Invalidated. Skip finalize."; + return; + } + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_); + auto comm = getNcclComm(); + C10D_NCCL_CHECK_NONBLOCKING(ncclCommFinalize(comm), std::nullopt); +} + +void NCCLComm::destroy() { + LockType lock(mutex_); + if (aborted_) { + LOG(INFO) << "Rank " << rank_ + << ": NCCL communicator already Invalidated. Skip destroy."; + return; + } + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_); + auto comm = getNcclComm(); + C10D_NCCL_CHECK(ncclCommDestroy(comm), std::nullopt); + // Poison future getNcclComm + aborted_ = true; +} + +void NCCLComm::abort(std::optional commFailureReason) { + LockType lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_); +#ifdef ENABLE_NCCL_ERROR_CHECKING + if (aborted_ && !initialized_) { + // Should not abort twice. + return; + } + +#ifdef NCCL_HAS_COMM_REGISTER + // Deregister all registered segments before aborting. + for (auto& it : registeredSegmentHandles_) { + void* handle = it.second; + C10D_NCCL_CHECK( + ::ncclCommDeregister(ncclComm_, handle), + c10::str( + "Failed to deregister segment handle ", + handle, + " on ncclComm_ ", + ncclComm_)); + } + registeredSegmentHandles_.clear(); +#endif + + // Set true failure reason if provided by ProcessGroupNCCL (e.g. work + // timeout) + commFailureReason_ = commFailureReason; + LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: " + << (commFailureReason ? *commFailureReason + : "No abort reason provided."); +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_); +#else + C10D_NCCL_CHECK_TIMEOUT( + ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_); +#endif + aborted_ = true; + ncclComm_ = nullptr; + + // Set an appropriate error so that we avoid using the communicator. + if (ncclAsyncErr_ == ncclSuccess) { + ncclAsyncErr_ = ncclSystemError; + } +#else + // This is a NOOP, if error checks are disabled. + return; +#endif +} + +bool NCCLComm::isInitialized() const { + LockType lock(mutex_); + return initialized_; +} + +bool NCCLComm::isAborted() const { + LockType lock(mutex_); + return aborted_; +} + +uint64_t NCCLComm::getCommSplitCounter() const { + return ncclCommSplitCounter_; +} + +ncclResult_t NCCLComm::checkForNcclError() { + LockType lock(mutex_); +#ifdef ENABLE_NCCL_ERROR_CHECKING + if (ncclAsyncErr_ != ncclSuccess) { + return ncclAsyncErr_; + } + C10D_NCCL_CHECK( + ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_); + return ncclAsyncErr_; +#else + // Always return success, if error checks are disabled. + return ncclSuccess; +#endif +} + +ncclResult_t NCCLComm::registerSegment( + void* ptr, + size_t size, + bool errorOnRereg /*=true*/) { + LockType lock(mutex_); +#ifdef NCCL_HAS_COMM_REGISTER + // We register only segments from cache allocator + // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always + // maps to a unique handle and should not be registered before the current + // ptr is deregistered and freed. 
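`registerSegment()`/`deregisterSegment()` below wrap NCCL's user-buffer registration API. A minimal sketch of the raw calls (illustrative only, not from this patch; assumes NCCL >= 2.19, i.e. `NCCL_HAS_COMM_REGISTER`, and an already-initialized communicator):

```cpp
// Illustrative buffer registration round trip with the raw NCCL API.
#include <cuda_runtime.h>
#include <nccl.h>

inline void registerExample(ncclComm_t comm) {
  void* buf = nullptr;
  const size_t bytes = 1 << 20;
  cudaMalloc(&buf, bytes);

  void* handle = nullptr;
  ncclCommRegister(comm, buf, bytes, &handle); // register for zero-copy/NVLS paths
  // ... launch collectives that use buf ...
  ncclCommDeregister(comm, handle);            // deregister before freeing
  cudaFree(buf);
}
```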
+ if (registeredSegmentHandles_.count(ptr) > 0) { + TORCH_CHECK( + !errorOnRereg, + "Segment with ptr ", + ptr, + " has already been registered on ncclComm_ ", + ncclComm_); + // Skip below + return ncclSuccess; + } + + void* handle = nullptr; + // Use getNcclComm to make sure comm is ready before calling nccl APIs + auto comm = getNcclComm(); + C10D_NCCL_CHECK( + ncclCommRegister(comm, ptr, size, &handle), + c10::str( + "Failed to register segment with ptr ", + ptr, + ", size ", + size, + " on ncclComm_ ", + comm)); + registeredSegmentHandles_[ptr] = handle; + return ncclSuccess; +#else + return ncclInvalidUsage; +#endif +} + +ncclResult_t NCCLComm::deregisterSegment(void* ptr) { + LockType lock(mutex_); +#ifdef NCCL_HAS_COMM_REGISTER + TORCH_CHECK( + registeredSegmentHandles_.count(ptr) == 1, + "Segment with ptr ", + ptr, + " is not registered on ncclComm_ ", + ncclComm_); + + void* handle = registeredSegmentHandles_[ptr]; + // Use getNcclComm to make sure comm is ready before calling nccl APIs + auto comm = getNcclComm(); + C10D_NCCL_CHECK( + ncclCommDeregister(comm, handle), + c10::str( + "Failed to deregister segment handle ", + handle, + ", with ptr ", + ptr, + " on ncclComm_ ", + comm)); + registeredSegmentHandles_.erase(ptr); + return ncclSuccess; +#else + return ncclInvalidUsage; +#endif +} + +std::string NCCLComm::repr() const { + return c10::str((void*)ncclComm_); +} + +#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) +std::unordered_map NCCLComm::ncclCommDump() { + std::unordered_map dump; + if (isAborted()) { + LOG(INFO) << "Communicator was aborted before trying to dump its state."; + return dump; + } + C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt); + return dump; +} +#endif + +std::string getNcclVersion() { + static std::string versionString = []() { + int version = 0; + std::string versionString; + ncclResult_t status = ncclGetVersion(&version); + // can't compute the version if call did not return successfully or version + // code < 100 (corresponding to 0.1.0) + if (status != ncclSuccess || version < 100) { + versionString = "Unknown NCCL version"; + } else { + // NCCL changed version coding starting 2.9 + const int majorBase = version < 2900 ? 1000 : 10000; + const int minorBase = 100; + auto ncclMajor = version / majorBase; + auto ncclMinor = (version % majorBase) / minorBase; + auto ncclPatch = + version % (ncclMajor * majorBase + ncclMinor * minorBase); + versionString = std::to_string(ncclMajor) + "." + + std::to_string(ncclMinor) + "." + std::to_string(ncclPatch); +#ifdef NCCL_SUFFIX + const auto ncclSuffix = std::string(NCCL_SUFFIX); + if (!ncclSuffix.empty()) { + versionString += "." + ncclSuffix; + } +#endif + } + return versionString; + }(); + + return versionString; +} + +size_t hashTensors(const std::vector& tensors) { + size_t hash = 0; + for (auto& tensor : tensors) { + if (tensor.numel() > 0 && tensor.storage()) { + size_t data_size = tensor.storage().nbytes(); + if (data_size > 0 && tensor.storage().data_ptr()) { + auto src = static_cast(tensor.storage().data_ptr().get()); + std::vector dst(data_size); + // This is needed so that we trigger a device synchronization so we can + // get the collective finished if launched on GPU and hash its output. 
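As a side note on the version decoding in `getNcclVersion()` above, the arithmetic can be checked in isolation (standalone example, not part of the patch; the sample version codes are illustrative):

```cpp
// Standalone check of the NCCL version-code decoding used above.
#include <cstdio>
#include <string>

std::string decodeNcclVersion(int version) {
  const int majorBase = version < 2900 ? 1000 : 10000; // coding changed in 2.9
  const int minorBase = 100;
  const int major = version / majorBase;
  const int minor = (version % majorBase) / minorBase;
  const int patch = version % (major * majorBase + minor * minorBase);
  return std::to_string(major) + "." + std::to_string(minor) + "." +
      std::to_string(patch);
}

int main() {
  std::printf("%s\n", decodeNcclVersion(22005).c_str()); // prints 2.20.5
  std::printf("%s\n", decodeNcclVersion(2804).c_str());  // prints 2.8.4 (pre-2.9 coding)
  return 0;
}
```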
+ AT_CUDA_CHECK( + cudaMemcpy(dst.data(), src, data_size, cudaMemcpyDeviceToHost)); + for (size_t i = 0; i < data_size; ++i) { + // Update the hash for each byte in the tensor + hash = c10::hash_combine(hash, c10::get_hash(dst[i], data_size)); + } + } + } + } + return hash; +} + +// Default value: 30 minutes +int nccl_nonblocking_timeout() { + static int timeout = -2; // -2 means not initialized + if (timeout == -2) { + const auto val = c10::utils::get_env("TORCH_NCCL_NONBLOCKING_TIMEOUT"); + if (val.has_value() && !val.value().empty()) { + timeout = stoi(val.value()); + } else { + // Default value consistent with kBackendDefaultTimeout + timeout = 30 * 60; + } + } + return timeout; +} + +std::string ncclGetErrorWithVersion(ncclResult_t error) { + return std::string(ncclGetErrorString(error)) + ", NCCL version " + + getNcclVersion(); +} + +// Provides additional detail into NCCL error codes based on when these are +// thrown in the NCCL codebase. +std::string getNcclErrorDetailStr( + ncclResult_t error, + std::optional processGroupFailureReason /* = std::nullopt */ +) { + // Prioritize failure reason provided by PG NCCL first, as it can abort + // communicators when it encounters collective timeouts, etc. + if (processGroupFailureReason != std::nullopt) { + return *processGroupFailureReason; + } + std::string interpret; + std::string err; +#ifdef ENABLE_NCCL_GET_LAST_ERROR + auto ret = ncclGetLastError(nullptr); + if (ret) { + err = "\nLast error:\n" + std::string(ret); + } else { + err = "\nLast error: Unknown NCCL Error\n"; + } +#endif + switch (error) { + case ncclUnhandledCudaError: + interpret = "ncclUnhandledCudaError: Call to CUDA function failed."; + break; + case ncclSystemError: + interpret = + "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
"; +#ifndef NCCL_REMOTE_ERROR + // Before ncclRemoteError was created, unexpected remote disconnect was + // categorized as ncclSystemError + interpret += "It can be also caused by unexpected exit of a remote peer."; +#endif + break; + case ncclInternalError: + interpret = "ncclInternalError: Internal check failed."; + break; + case ncclInvalidArgument: + interpret = "ncclInvalidArgument: Invalid value for an argument."; + break; + case ncclInvalidUsage: + interpret = + "ncclInvalidUsage: This usually reflects invalid usage of NCCL library."; + break; +#ifdef NCCL_REMOTE_ERROR + case ncclRemoteError: + interpret = + "ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely."; + break; +#endif + default: + interpret = "Unknown NCCL error!"; + } + return interpret + err; +} + +// Dump proxyTrace log to stdout +void printNcclCommProxyTrace( + const std::string& dumpReason, + const std::unordered_map& dumpMap) { + LOG(INFO) << "Dumping nccl comm trace, reason: " << dumpReason; + for (auto& [key, value] : dumpMap) { + LOG(INFO) << "key: " << key << ", value: " << value; + } + LOG(INFO) << "----------------------"; +} + +} // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/csrc/ProcessGroupNCCL.cpp b/csrc/ProcessGroupNCCL.cpp new file mode 100644 index 0000000..c5e2414 --- /dev/null +++ b/csrc/ProcessGroupNCCL.cpp @@ -0,0 +1,5237 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ProcessGroupNCCL.hpp" + +using namespace c10d; + +namespace torchft { + +constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM"; + +namespace { + +#if defined(NCCL_MAJOR) && \ + ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) +#define NCCL_HAS_AVG 1 +#endif // NCCL version >= 2.10 + +// NCCL op mapping +const std::map ncclOp = { + {ReduceOp::MIN, ncclMin}, + {ReduceOp::MAX, ncclMax}, + {ReduceOp::SUM, ncclSum}, + {ReduceOp::PRODUCT, ncclProd}, +#ifdef NCCL_HAS_AVG + {ReduceOp::AVG, ncclAvg}, +#endif // NCCL_HAS_AVG +}; + +// NCCL type typing +std::map ncclDataType = { + {at::kChar, ncclInt8}, + {at::kByte, ncclUint8}, + {at::kFloat, ncclFloat}, + {at::kDouble, ncclDouble}, + {at::kInt, ncclInt32}, + {at::kLong, ncclInt64}, + {at::kHalf, ncclHalf}, + {at::kBool, ncclUint8}, + {at::kFloat8_e5m2, ncclUint8}, + {at::kFloat8_e4m3fn, ncclUint8}, + {at::kFloat8_e4m3fnuz, ncclUint8}, + {at::kFloat8_e5m2fnuz, ncclUint8}, +#if HAS_NCCL_BF16_DATATYPE + {at::kBFloat16, ncclBfloat16}, +#endif // HAS_NCCL_BF16_DATATYPE +}; + +// Helper function that gets the data type and issues error if not supported +ncclDataType_t getNcclDataType(at::ScalarType type) { + auto it = ncclDataType.find(type); + TORCH_CHECK_WITH( + TypeError, + it != ncclDataType.end(), + "Input tensor data type is not supported for NCCL process group: ", + type); + return it->second; +} + +bool complexViewAsRealAllowed(const ReduceOp& reduceOp) { + switch (reduceOp) { + // NOLINTNEXTLINE(bugprone-branch-clone) + case ReduceOp::SUM: + return true; + case ReduceOp::AVG: + return true; + case ReduceOp::PREMUL_SUM: + return true; + case ReduceOp::UNUSED: + return true; + default: + return false; + } + return false; +} + +#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT +template +ncclRedOpRAII unpackPreMulSum( + const ReduceOp& reduceOp, + const ncclComm_t& comm) { + 
const auto* preMulSupplement = + reinterpret_cast(reduceOp.supplement_.get()); + ncclRedOp_t preMulSum{}; + bool has_tensor = preMulSupplement->tensor_factor.defined(); + auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; + const T* ptr_factor = has_tensor + ? preMulSupplement->tensor_factor.const_data_ptr() + : nullptr; + T scalar_factor = T(preMulSupplement->double_factor); + ncclRedOpCreatePreMulSum( + &preMulSum, + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/ops.html#ncclredopcreatepremulsum + // tells us that the scalar input is strictly a multiplier. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + /*scalar=*/has_tensor ? const_cast(ptr_factor) : &scalar_factor, + dataType, + residence, + comm); + return ncclRedOpRAII(preMulSum, comm); +} +#endif // ENABLE_NCCL_PREMUL_SUM_SUPPORT + +ncclRedOpRAII getNcclReduceOp( + const ReduceOp& reduceOp, + at::Tensor& input, + const ncclDataType_t& dataType, + const ncclComm_t& comm) { + try { + if (input.scalar_type() == at::kBool) { + if (reduceOp == ReduceOp::SUM) { + // For bool tensors, map sum to max, which both represent a bitwise or. + // This is to prevent overflow issues with sum, since we use uint8 to + // represent a bool (see ncclDataType mapping). + return ncclMax; + } +#ifdef NCCL_HAS_AVG + if (reduceOp == ReduceOp::AVG) { + C10_THROW_ERROR( + TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); + } +#endif // NCCL_HAS_AVG + } + if (reduceOp == ReduceOp::PREMUL_SUM) { +#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT + switch (dataType) { + case ncclHalf: + return unpackPreMulSum(reduceOp, comm); + case ncclFloat: + return unpackPreMulSum(reduceOp, comm); + case ncclDouble: + return unpackPreMulSum(reduceOp, comm); + default: + C10_THROW_ERROR( + TypeError, "PreMulSum Data type must be half, float, or double"); + return ncclRedOp_t{}; + } +#else + C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); +#endif // ENABLE_NCCL_PREMUL_SUM_SUPPORT + } + return ncclOp.at(reduceOp); + } catch (const std::out_of_range&) { + switch (reduceOp) { + case ReduceOp::AVG: + C10_THROW_ERROR( + ValueError, + c10::str( + "AVG requires NCCL 2.10+. The current version is ", + NCCL_MAJOR, + ".", + NCCL_MINOR)); + break; + case ReduceOp::BAND: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); + break; + case ReduceOp::BOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); + break; + case ReduceOp::BXOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); + break; + default: + C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); + break; + } + } +} + +// Get a key string from device +inline std::string getKeyFromDevice(at::Device& device) { + return std::to_string(device.index()); +} + +inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { + // initialize the device index to -1, which is an invalid value. + int index = -1; + try { + index = std::stoi(deviceKey); + } catch (const std::invalid_argument& e) { + LOG(ERROR) << c10::str( + "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); + } catch (const std::out_of_range& e) { + LOG(ERROR) << "Out of range: " << e.what(); + } + return static_cast(index); +} + +std::string getKeySendRecv(int myRank, int peer) { + int lowRank = myRank < peer ? myRank : peer; + int highRank = myRank < peer ? 
peer : myRank; + std::string sendRecvPair = + std::to_string(lowRank) + ":" + std::to_string(highRank); + return sendRecvPair; +} + +// Get device from tensor +inline at::Device getDevice(at::Tensor& tensor) { + return tensor.device(); +} + +// [Sync Streams] Helper that lets the input ncclStreams to wait for the current +// stream. NCCL communications run on ncclStreams, but input tensors are +// allocated on different streams (i.e., current streams). Communications on +// ncclStreams cannot start before pending input tensor ops on current streams +// finish. Otherwise, ops on two streams might read/write same tensors +// concurrently. +// +// The synchronization above alone is not enough. We also need to make sure +// input tensors are not freed before their usages on ncclStreams finish. This +// can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream, +// which remembers the usage stream (ncclStream), creates an event on the usage +// stream when GC attempts to free the input tensor, and delays GC until that +// event is done. +void syncStream( + at::Device& device, + at::cuda::CUDAEvent& ncclEvent, + at::cuda::CUDAStream& ncclStream) { + ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index())); + ncclEvent.block(ncclStream); +} + +// Given a ncclUniqueId, convert it to a string representation that can be put +// in the store. +std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { + const uint8_t* bytes = reinterpret_cast(&ncclID); + std::ostringstream oss; + for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { + oss << std::hex << static_cast(bytes[i]); + } + return oss.str(); +} + +std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) { + return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; +} + +// Returns exception's what() given an exception_ptr instance. +std::string getExceptionMsgFromExceptionPtr( + const std::exception_ptr& exceptionPtr) { + TORCH_CHECK(exceptionPtr != nullptr); + try { + std::rethrow_exception(exceptionPtr); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown exception type"; + } +} + +inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { + // parentheses avoid some compiler warnings + static const uint64_t min_version = + (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); + static const uint64_t cur_version = torch::cuda::nccl::version(); + if (cur_version < min_version) { + TORCH_CHECK_WITH( + NotImplementedError, + status == c10::cuda::CaptureStatus::None, + "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); + } +} + +} // namespace + +// Map from each communicator to its device index. +// This map is used when register/deregister cache segments from cache +// allocator. See design notes below: +// - Each segment should be registered only to the communicator on the +// same device. +// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be +// ranks rather than device in point-to-point case. +// - This map has also to be maintained as global variable since the register +// hooks are called outside the scope of any PG, thus we need traverse +// communicators in all PGs. 
+static std::unordered_map, int> ncclCommDevIdxMap; +static std::mutex ncclCommDevIdxMapMutex; +static bool allocatorHooksAttached = false; + +std::atomic ProcessGroupNCCL::shouldDump_(false); + +static void cacheAllocatorRegisterHook( + const c10::cuda::CUDACachingAllocator::TraceEntry& te) { + // Register after SEGMENT_ALLOC + if (te.action_ != + c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) { + return; + } + + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& it : ncclCommDevIdxMap) { + auto& ncclComm = it.first; + auto& devIdx = it.second; + if (te.device_ == devIdx) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + } + } +} + +static void cacheAllocatorDeregisterHook( + const c10::cuda::CUDACachingAllocator::TraceEntry& te) { + // deregister before SEGMENT_FREE + if (te.action_ != + c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) { + return; + } + + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& it : ncclCommDevIdxMap) { + auto& ncclComm = it.first; + auto& devIdx = it.second; + if (te.device_ == devIdx) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + } + } +} + +static std:: + unordered_map> + getNCCLCommDumpMap() { + return std::unordered_map< + std::string, + std::unordered_map>(); +} + +std::optional)>>& +get_cpp_trace_dumper() { + static std::optional< + std::function)>> + dumper(std::nullopt); + return dumper; +} + +gil_checker_t& get_gil_checker() { + static gil_checker_t gil_checker = nullptr; + return gil_checker; +} + +static std::future launchAsyncGilCheck() { + std::promise resultPromise; + std::future resultFuture = resultPromise.get_future(); + TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker"); + std::thread workerThread([promise = std::move(resultPromise)]() mutable { + c10::setThreadName("pt_nccl_gil_chk"); + + try { + auto& gil_checker = get_gil_checker(); + promise.set_value((*gil_checker)()); + } catch (...) 
{ + promise.set_exception(std::current_exception()); + } + }); + + // Detach the thread to allow it to run independently + workerThread.detach(); + + return resultFuture; +} + +const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; +constexpr int64_t kSynchronizeBusyWaitMillis = 1; +thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; + +std::ostream& operator<<( + std::ostream& output, + const ProcessGroupNCCL::WorkNCCL& workNCCL) { + std::string workInfo; + workInfo = c10::str( + "WorkNCCL(", + "SeqNum=", + workNCCL.seq_, + ", OpType=", + opTypeToString(workNCCL.opType_), + ", NumelIn=", + workNCCL.numelIn_, + ", NumelOut=", + workNCCL.numelOut_, + ", Timeout(ms)=", + workNCCL.opTimeout_.count(), + ")"); + return output << workInfo; +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL( + std::string pgUID, + std::string pgDesc, + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + bool isP2P, + const char* profilingTitle, + const std::optional>& inputs, + bool desyncDebug, + bool enableTiming, + bool cudaEventCacheEnabled, + DebugLevel distDebugLevel) + : Work(rank, opType, profilingTitle, inputs), + pgUID_(std::move(pgUID)), + pgDesc_(std::move(pgDesc)), + device_(device), + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq), + isP2P_(isP2P), + timingEnabled_(enableTiming), + distDebugLevel_(distDebugLevel) { + // Creates the CUDA event wrappers + // Note: The actual events are lazily created when first recorded to with + // DEFAULT_FLAGS = cudaEventDisableTiming. + if (cudaEventCacheEnabled) { + ncclStartEvent_ = enableTiming + ? ProcessGroupNCCL::CUDAEventCache::get(device.index()) + ->create(enableTiming) + : nullptr; + ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index()) + ->create(enableTiming); + } else { + ncclStartEvent_ = enableTiming + ? std::make_shared(cudaEventDefault) + : nullptr; + ncclEndEvent_ = std::make_shared( + enableTiming ? cudaEventDefault : cudaEventDisableTiming); + } + futureWorkResult_ = + c10::make_intrusive(c10::AnyEnumType::get()); +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) + : Work(w.rank_, w.opType_), + std::enable_shared_from_this(w), + pgUID_(w.pgUID_), + pgDesc_(w.pgDesc_), + device_(w.device_), + ncclStartEvent_(w.ncclStartEvent_), + ncclEndEvent_(w.ncclEndEvent_), + ncclComm_(w.ncclComm_), + blockingWait_(w.blockingWait_), + opTimeout_(w.opTimeout_), + ownedEphermeralTimeout_(w.ownedEphermeralTimeout_), + workStartTime_(w.workStartTime_), + seq_(w.seq_), + isP2P_(w.isP2P_), + startTraceUpdated_(w.startTraceUpdated_), + numelIn_(w.numelIn_), + numelOut_(w.numelOut_), + store_(w.store_), + futureWorkResult_(w.futureWorkResult_), + timingEnabled_(w.timingEnabled_), + trace_id_(w.trace_id_), + distDebugLevel_(w.distDebugLevel_) { + exception_ = w.exception_; +} + +bool ProcessGroupNCCL::WorkNCCL::isCompleted() { + if (!ncclComm_->isAborted()) { + checkAndSetException(); + } + return exception() || finishedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::isStarted() { + if (!ncclComm_->isAborted()) { + checkAndSetException(); + } + return exception() || startedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::isSuccess() const { + C10_THROW_ERROR(NotImplementedError, "WorkNCCL::isSuccess() is deprecated"); +} + +void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { + if (exception()) { + // We already have an exception. 
+ return; + } + + auto exception_ptr = checkForNCCLErrors(); + std::unique_lock lock(mutex_); + exception_ = exception_ptr; + if (exception_) { + LOG(ERROR) << logPrefix() << "Collective " << *this + << " raised the following async exception: " + << getExceptionMsgFromExceptionPtr(exception_); + + // Mark future result as ERROR + if (futureWorkResult_ && !futureWorkResult_->completed()) { + futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::COMM_ERROR))); + } + } +} + +const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { + static std::string prefix = c10::str("[Rank ", rank_, "] "); + return prefix; +} + +void ProcessGroupNCCL::WorkNCCL::setException( + std::exception_ptr exception_ptr) { + std::unique_lock lock(mutex_); + exception_ = std::move(exception_ptr); +} + +// Helper that checks if the NCCL kernels are completed on the GPUs +bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { + checkAndSetException(); + return finishedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const { + // if timing is disabled we won't have allocated start events + if (!timingEnabled_) { + return false; + } + // Checking the work's corresponding CUDA event's status + if (!ncclStartEvent_->query()) { + return false; + } + return true; +} + +bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { + // Checking the work's corresponding CUDA event's status + // It calls `cudaEventQuery` eventually. Although this seems to be a + // non-blocking call, but we did notice hangs in the past. It can + // hang if another thread is holding the CUDA global context lock. For + // example, when doing a `cudaDeviceSynchronize` or even + // `cudaStreamSynchronize`. + if (!ncclEndEvent_->query()) { + return false; + } + return true; +} + +bool ProcessGroupNCCL::WorkNCCL::checkTimeout( + std::optional timeout) { + STATIC_SCOPED_WAIT_COUNTER( + pytorch.wait_counter.ProcessGroupNCCL__checkTimeout); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + auto workTimeout = timeout ? *timeout : opTimeout_; + + if (timeElapsed < workTimeout) { + return false; + } + + // Timed out + + std::string exceptionMsg = c10::str( + logPrefix(), + "Watchdog caught collective operation timeout: ", + *this, + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + + LOG(ERROR) << exceptionMsg; + + std::exception_ptr exception_ptr = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg)); + if (!exception()) { + // if there is already an error, we don't override it + setException(exception_ptr); + } + + // Mark future result as TIMEOUT + if (futureWorkResult_ && !futureWorkResult_->completed()) { + futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::TIMEOUT))); + } + return true; +} + +// Print the traceback of the collective at call time +void ProcessGroupNCCL::WorkNCCL::printTraceback() const { + // First step we get the corresponding record entry from FR, based on work's + // trace_id_ +} + +void ProcessGroupNCCL::WorkNCCL::handleException( + ErrorHandlingMode errorHandling) { + if (exception_) { + auto exceptionMsg = c10::str( + "Some NCCL operations have failed or timed out. 
Due to the ", + "asynchronous nature of CUDA kernels, subsequent GPU operations ", + "might run on corrupted/incomplete data."); + LOG(ERROR) << logPrefix() << exceptionMsg; + C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException"); + + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + ::c10d::C10dLoggingData data; + data.strings["work_nccl_exception"] = + getExceptionMsgFromExceptionPtr(exception_); + logger->log(data); + } + + if (SHOULD_TEAR_DOWN(errorHandling)) { + auto tearDownMsg = c10::str( + "To avoid data inconsistency, we are taking the entire process down."); + LOG(ERROR) << logPrefix() << tearDownMsg; + std::rethrow_exception(exception_); + } + } +} + +void ProcessGroupNCCL::WorkNCCL::synchronize() { + synchronizeStream(); + if (c10d::allow_inflight_collective_as_graph_input()) { + c10d::unregister_work( + c10::intrusive_ptr< + ProcessGroupNCCL::WorkNCCL>::unsafe_reclaim_from_nonowning(this)); + } +} + +void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // Block the current stream on the NCCL stream + ncclEndEvent_->block(currentStream); + + if (avoidRecordStreams_) { + stashed_for_allocator_safety_->clear(); + } +} + +// Same as calling synchronize() when blockingWait_ is false +bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { + RECORD_PARAM_COMMS( + std::make_tuple(static_cast(this->seq_), this->isP2P_), // seq + std::make_tuple(pgUID_, pgDesc_), // PG name tuple + rank_, // rank + "wait", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, + -1, + static_cast(1)); // number of device? + + // synchronize() will block the current stream on the NCCL stream + synchronize(); + + // In case of blockingWait or a timeout value is specified by the user, we + // block the CPU thread until the work is completed or timed out. + if (blockingWait_ || timeout != kNoTimeout) { + while (!isCompleted()) { + bool timedOut = checkTimeout( + timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + // Explicitly abort ncclComms here before throwing this timed out + // exception to users. + // If throwing timed out excepiton without aborting nccl communicators + // here, it was observed that CUDA GPU will have 100% utilization and + // can not run new events successfully. + if (timedOut) { + std::string exceptionMsg = c10::str( + logPrefix(), "Work ", (*this), " timed out in blocking wait."); + LOG(ERROR) << exceptionMsg; + break; + } + // Yield + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } else if (isBarrierOp_ && !isCompleted()) { + // For barrier wait when timeout is unspecified, we block the CPU thread on + // current stream. This is to minimize the CPU barrier wait time in healthy + // path + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // CUDAStream wrapper will correctly use a DeviceGuard here + currentStream.synchronize(); + } + + // If exception is detected, throw it from the main CPU thread + if (exception()) { + // Abort NCCL communicators + abort(); + // Throw exception (from main thread here) + handleException(TearDown); + } + + // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL + // upgrade. Once a NCCL version is qualified, this code should not be needed + // at runtime. 
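The `synchronizeStream()`/`syncStream()` logic above boils down to CUDA's event-based cross-stream ordering. A minimal sketch in raw CUDA runtime terms (illustrative only, not from this patch; `orderStreams` is a hypothetical helper):

```cpp
// Illustrative cross-stream ordering: work already queued on streamA becomes a
// prerequisite for anything submitted to streamB after this call.
#include <cuda_runtime.h>

inline void orderStreams(cudaStream_t streamA, cudaStream_t streamB) {
  cudaEvent_t ev = nullptr;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
  cudaEventRecord(ev, streamA);        // capture streamA's current tail
  cudaStreamWaitEvent(streamB, ev, 0); // streamB waits for that point
  cudaEventDestroy(ev);                // safe once the wait has been enqueued
}
```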
+#ifdef PGNCCL_ENABLE_HASH + if (distDebugLevel_ >= DebugLevel::Detail) { + auto numel = getTensorsNumel(*outputs_); + auto hashValue = hashTensors(*outputs_); + PRINT_COLLECTIVE_HASH_SIGNATURE( + "output", opTypeToString(opType_), numel, hashValue); + } +#endif // PGNCCL_ENABLE_HASH + // Always return true, because abort API is not implemented. + return true; +} + +void ProcessGroupNCCL::WorkNCCL::abort() { + // dump before aborting for rcclexp +#if defined(USE_ROCM) && defined(NCCL_COMM_DUMP) + auto dumpMap = ncclComm_->ncclCommDump(); + printNcclCommProxyTrace("WorkNCCL::abort", dumpMap); +#endif // USE_ROCM && NCCL_COMM_DUMP + + // Abort all communicators of this work + ncclComm_->abort(); + + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.erase(ncclComm_); + ncclCommDevIdxMapMutex.unlock(); +} + +ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; + +// CUDA event is used to record the start/end of one Work. +// Instead of let the CUDA event gets destroyed, we now reuse it after the Work +// has been erased from workMetaList_. +// This is to avoid the potential deadlock caused by CudaEventDestroy. +std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( + bool timing) { + // Register the deleter as a callback when the WorkNCCL object is destroyed. + // Each deleter keeps a ref count to the cache object, so that even when + // the thread that creates the cache is gone, the cache object won't be + // destroyed until all the events in the cache are destroyed (ref number drops + // to zero). + auto deleter = [cache = shared_from_this(), + timing](at::cuda::CUDAEvent* event) { + std::lock_guard lock(cache->cacheMutex_); + // We put the event back to the cache deque once the WorkNCCL object is + // destroyed. + cache->eventsArray_[timing ? 1 : 0].push_back(event); + }; + at::cuda::CUDAEvent* event = nullptr; + { + std::lock_guard lock(cacheMutex_); + auto& events = eventsArray_[timing ? 1 : 0]; + // If we still have events in the cache, we reuse it. Otherwise, we create a + // new one. + if (!events.empty()) { + event = events.front(); + events.pop_front(); + } else { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); + } + } + return std::shared_ptr(event, std::move(deleter)); +} + +std::shared_ptr ProcessGroupNCCL:: + CUDAEventCache::get(at::DeviceIndex device) { + // A per-thread singleton of device-to-CUDAEventCache map. + // Map is needed because events cannot be reused across devices. + // Per-thread ownership is needed to support multi-threaded case (instead of + // multi-process case). + static thread_local std:: + map> + cacheDeviceMap; + // Check if device has already been in the map, if not, add a new entry + auto it = cacheDeviceMap.find(device); + if (it == cacheDeviceMap.end()) { + cacheDeviceMap.emplace( + device, std::make_shared()); + } + return cacheDeviceMap[device]; +} + +static std::atomic process_group_id = 0; + +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple. You are probably using multiple " + "devices under one thread. The support for such usage has been deprecated. " + "For details, please refer to " + "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. 
" + "ProcessGroupNCCL continues supporting multi-process and multi-thread modes."; + +ProcessGroupNCCL::ProcessGroupNCCL( + c10::intrusive_ptr store, + int rank, + int size, + c10::intrusive_ptr options) + : Backend(rank, size), + store_(std::move(store)), + options_(std::move(options)), + terminateProcessGroup_(false), + terminateHeartbeatMonitorThread_(false), + local_id_(process_group_id++), + intraNodeComm_(initIntraNodeComm()) { + TORCH_CHECK_WITH( + ValueError, + at::cuda::getNumGPUs() != 0, + "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); + + // getNcclVersion needs to get called before launching threads which can + // potentially call getenv. getNcclVersion internally calls setenv to set some + // environment variables from config file, which can race with getenv from + // other threads and cause segfaults. + const auto ncclVersion = getNcclVersion(); + this->setGroupUid(options_->group_name); + this->localDeviceCount_ = static_cast(at::cuda::getNumGPUs()); + logPrefix_ = createLogPrefix(); + blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); + asyncErrorHandling_ = static_cast( + getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); + desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (dist_debug_level_ >= DebugLevel::Detail); + rethrowCUDAErrors_ = false; + // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT + // or change its name to reflect that dump happens on exception including + // both timeout and other errors. + dumpOnTimeoutOrEx_ = false; + propagatePgError_ = false; + // logging C++ stack isn't safe. Introduce a variable to control it. + logCppStackOnUncleanShutdown_ = + getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); + enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false); + heartbeat_ = 1ULL; + monitorThreadEnabled_.store(false); + cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true)); + heartbeatTimeoutInSec_ = + getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/); + waitTimeoutDumpInMilSec_ = + getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/); + coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); + traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000); + enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); + // store_ usually is wrapped with PrefixStore and the prefix is different + // across different ProcessGroupNCCL(PG) instances. We need to get the + // underlying non-PrefixStore for sharing global information shared across + // different PGs. + PrefixStore* prefixStore = dynamic_cast(store_.get()); + globalStore_ = + prefixStore ? 
prefixStore->getUnderlyingNonPrefixStore() : store_; +#ifdef ENABLE_NCCL_ERROR_CHECKING + enableTiming_.store( + getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); +#endif // ENABLE_NCCL_ERROR_CHECKING + avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); +#ifdef NCCL_HAS_COMM_REGISTER + useTensorRegisterAllocatorHook_ = + getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); + if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + expandable_segments()) { + useTensorRegisterAllocatorHook_ = false; + LOG(INFO) + << logPrefix() + << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + } +#endif // NCCL_HAS_COMM_REGISTER + + if (blockingWait_) { + LOG(INFO) + << logPrefix() + << "TORCH_NCCL_BLOCKING_WAIT is enabled, NO watchdog thread is created."; + } else { + if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { + LOG(INFO) + << logPrefix() + << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " + << "must both be enabled. " + << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING."; + asyncErrorHandling_ = SkipCleanUp; + } + } + +#ifdef ENABLE_NCCL_ERROR_CHECKING + // in blockingWait mode, we don't need to enable the watchdog thread to check + // the timeout or nccl error because the main thread would throw an exception + // and it is the user's responsibility to handle the exception. + if (!blockingWait_) { + ncclCommWatchdogThread_ = + std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); + } +#endif // ENABLE_NCCL_ERROR_CHECKING + + init(); + const std::string OFF = "OFF"; + std::string torch_distributed_debug = + getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL initialization options: " << "size: " << size + << ", global rank: " << globalRank() + << ", TIMEOUT(ms): " << options_->timeout.count() + << ", USE_HIGH_PRIORITY_STREAM: " + << options_->is_high_priority_stream + << ", SPLIT_FROM: " << options_->split_from + << ", SPLIT_COLOR: " << options_->split_color + << ", PG Name: " << options_->group_name; + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: " + << "NCCL version: " << ncclVersion + << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_ + << ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_ + << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: " + << waitTimeoutDumpInMilSec_ + << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ + << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() + << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ + << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug +#ifdef NCCL_HAS_COMM_REGISTER + << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " + << useTensorRegisterAllocatorHook_ +#endif // NCCL_HAS_COMM_REGISTER + << ", TORCH_NCCL_ENABLE_MONITORING: " + << monitorThreadEnabled_.load() + << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_ + << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << traceBufferSize_ + << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_ + << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_ + << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_ + << ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: " + << logCppStackOnUncleanShutdown_; + + getGlobalRankStartAndStride( + options_->global_ranks_in_group, + this->globalRankStart, + this->globalRankStride); + + // Attach hooks to cache allocator to trigger the 
hooks whenever a traced + // action is called. In the following hooks, we register a newly allocated + // segment when SEGMENT_ALLOC action occurs, and deregister a segment when + // SEGMENT_FREE action occurs. + // We attach hooks only once at the first PG creation. + // Attaching hooks fails if CUDACachingAllocator is not initialized, so + // Init for CUDA is called (and is a no-op if CUDA is already + // initialized). + if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorRegisterHook); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorDeregisterHook); + allocatorHooksAttached = true; + } + + // Enable Desync Debugger per user setting + if (desyncDebug_) { + desyncDebugger_.init(rank, size, store_); + } +} + +void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { + const auto key = getKeyFromDevice(device); + LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " + << device; + initNCCLComm(key, device, OpType::ALLREDUCE); +} + +bool ProcessGroupNCCL::useNonblocking() { +#ifndef NCCL_HAS_COMM_NONBLOCKING + return false; +#endif // NCCL_HAS_COMM_NONBLOCKING + // Already parsed, return the cached value + if (useNonblocking_.has_value()) { + return useNonblocking_.value(); + } + // Get environment variable. + auto nbEnv = c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING"); + + // 1st priority: Respect the user's setting + if (options_->config.blocking != NCCL_CONFIG_UNDEF_INT) { + useNonblocking_ = options_->config.blocking == 0; + } + // 2nd priority: Respect the environment variable + else if (nbEnv.has_value()) { + useNonblocking_ = nbEnv; + } + // 3rd priority: automatically use nonblocking if we are in eager init mode + else if (getBoundDeviceId()) { + useNonblocking_ = true; + } + // 4th priority: otherwise, nonblocking = false to preserve old behavior + else { + useNonblocking_ = false; + } + + LOG(INFO) << logPrefix() + << "Using non-blocking mode: " << useNonblocking_.value(); + return useNonblocking_.value(); +} + +void ProcessGroupNCCL::performNocolorSplit(at::Device device) { + // If our backend doesn't support splitting, this is a no-op for + // ranks not in the new subgroup (and ranks that would be in it will + // just use a new communicator rather than split). +#ifdef NCCL_HAS_COMM_SPLIT + const auto key = getKeyFromDevice(device); + LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " + << device << ", key " << key << ", i am " << this; + bool useNb = useNonblocking(); + options_->config.blocking = useNb ? 
0 : 1; + auto comm = getNCCLComm(key); + if (comm == nullptr) { + LOG(ERROR) << logPrefix() + << "No parent communicator exists for nocolor split"; + } + NCCLComm::split( + comm.get(), + NCCL_SPLIT_NOCOLOR, + rank_, + options_->config, + options_->global_ranks_in_group); +#endif // NCCL_HAS_COMM_SPLIT +} + +bool ProcessGroupNCCL::isInitialized() { + if (devNCCLCommMap_.empty()) { + return false; + } + std::lock_guard lock(mutex_); + bool initialized = true; + for (const auto& [_, comm] : devNCCLCommMap_) { + if (!comm->isInitialized()) { + initialized = false; + break; + } + } + return initialized; +} + +ErrorType ProcessGroupNCCL::getError() { + std::lock_guard lock(errorMutex_); + return error_; +} + +void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { + const auto key = std::to_string(pool->device()); + auto device = at::Device(at::DeviceType::CUDA, pool->device()); + LOG(INFO) << logPrefix() + << "Performing NCCL user buffer registration for all buffers in " + << "MemPool: " << pool->id() << ", device index: " << key + << ", i am " << this; + auto ncclComm = getNCCLComm(key); + if (ncclComm == nullptr) { + // HACK: currently we are using this function for NVLS + // reductions, and that's why using OpType::ALLREDUCE. + // If we end up using this API for zero-copy P2P, we might + // need to refactor and account for different OpType. + ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE); + } + TORCH_INTERNAL_ASSERT(ncclComm != nullptr); + auto ctx = c10::cuda::MemPoolContext(pool); + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + for (const auto& segmentInfo : snapshot.segments) { + TORCH_INTERNAL_ASSERT( + segmentInfo.device == pool->device(), + "Mismatch between CUDA memory segment device and pool's device"); + ncclComm->registerSegment( + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(segmentInfo.address), + segmentInfo.total_size, + /*errorOnRereg=*/false); // ignores reregistration error + } +} + +void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) { + const auto key = std::to_string(pool->device()); + auto device = at::Device(at::DeviceType::CUDA, pool->device()); + LOG(INFO) << logPrefix() + << "Performing NCCL user buffer deregistration for all buffers in " + << "MemPool: " << pool->id() << ", device index: " << key + << ", i am " << this; + auto ncclComm = getNCCLComm(key); + if (ncclComm == nullptr) { + // HACK: currently we are using this function for NVLS + // reductions, and that's why using OpType::ALLREDUCE. + // If we end up using this API for zero-copy P2P, we might + // need to refactor and account for different OpType. 
+ ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE); + } + TORCH_INTERNAL_ASSERT(ncclComm != nullptr); + auto ctx = c10::cuda::MemPoolContext(pool); + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + for (const auto& segmentInfo : snapshot.segments) { + TORCH_INTERNAL_ASSERT( + segmentInfo.device == pool->device(), + "Mismatch between CUDA memory segment device and pool's device"); + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->deregisterSegment(reinterpret_cast(segmentInfo.address)); + } +} + +c10::intrusive_ptr ProcessGroupNCCL:: + initIntraNodeComm() { + using IntraNodeComm = intra_node_comm::IntraNodeComm; + if (!IntraNodeComm::isEnabled()) { + return nullptr; + } + auto prefixStore = c10::make_intrusive("IntraNodeComm", store_); + auto comm = c10::make_intrusive(prefixStore, rank_, size_); + if (comm->rendezvous()) { + return comm; + } else { + return nullptr; + } +} + +void ProcessGroupNCCL::setSequenceNumberForGroup() { +} // NCCL just starts sequence numbers at 0. + +uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { + return seqCollective_; +} + +void ProcessGroupNCCL::registerOnCompletionHook( + std::function)>&& hook) { + TORCH_WARN_ONCE( + "ProcessGroupNCCL OnCompletion hook will be deprecated in favor of Flight Recorder. " + "Please check out FlightRecorder.hpp for information that is recorded at work completion. " + "You can file an issue if you want additional information to be recorded. " + "You can also file an RFC if you want Flight Recorder to accept plugins that customize the recording.") + + TORCH_CHECK_WITH( + DistBackendError, + onCompletionHook_ == nullptr, + "ProcessGroupNCCL OnCompletion hook already registered"); + + TORCH_CHECK_WITH( + ValueError, + enableTiming_.load(), + "ProcessGroupNCCL OnCompletion hook requires recording start and end " + "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. " + "This is only available for NCCL version >= 2.4."); + onCompletionHook_ = std::move(hook); + onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this); +} + +// must release GIL when calling this method +void ProcessGroupNCCL::waitForPendingWorks() { + // Reasoning about hook completion: + // 1. waitForPendingWorks should be called after user code has finished + // calling + // all collectives. This means, when we got here, all of the collectives + // are either in workMetaList_ or has been erased from workMetaList_. + // 2. The watchdog thread grabs both locks to move Work object from the + // workMetaList_ to the completedWorkList_, and the hook thread only erases + // a Work object after the hook is returned. Therefore, after user code + // calls a collective, its Work object is either in workMetaList_ or in + // completedWorkList_ before it finishes. + // 3. We have three threads and two locks. + // a. main thread (this function) grabs two locks atomically + // b. watchdog thread (watchdogHandler function) always grabs + // workMetaListMutex_ + // first and then grabs completedWorkListMutex_. + // c. hook thread (runHookLoop function) only grabs + // completedWorkListMutex_. Therefore, locks are always acquired in the + // same order and hence no deadlocks. 
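+  // Illustrative only (not used by the loop below): the same "grab both locks
+  // atomically" pattern can be written with C++17 std::scoped_lock, which
+  // applies the same deadlock-avoidance algorithm as the std::lock +
+  // std::adopt_lock idiom used here:
+  //
+  //   {
+  //     std::scoped_lock both(workMetaListMutex_, completedWorkListMutex_);
+  //     bool drained = workMetaList_.empty() && completedWorkList_.empty();
+  //   }  // both mutexes released together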
+ while (true) { + { + std::lock(workMetaListMutex_, completedWorkListMutex_); + std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( + completedWorkListMutex_, std::adopt_lock); + + if (workMetaList_.empty() && completedWorkList_.empty()) { + return; + } + } + + std::this_thread::sleep_for( + std::chrono::milliseconds(kWatchdogThreadSleepMillis)); + } +} + +void ProcessGroupNCCL::enableCollectivesTiming() { + enableTiming_.store(true); +} + +bool ProcessGroupNCCL::waitForFutureOrTimeout( + std::future& fut, + const std::chrono::milliseconds& timeOutMilSec, + const std::string& futDescription, + ::c10d::C10dLoggingData& debugLog, + bool throwException) { + std::string errorMsg; + bool complete = false; + + TORCH_CHECK(fut.valid(), "Expected a valid future"); + std::future_status status = fut.wait_for(timeOutMilSec); + if (status == std::future_status::ready) { + // Calling .get() will re-raise any exception from the future, and we don't + // care about the retval + try { + bool result = fut.get(); + if (result) { + VLOG(2) << logPrefix() + << "future successfully executed for: " << futDescription; + debugLog.strings["status"] = "SUCCESS"; + complete = true; + } + } catch (const std::exception& e) { + errorMsg = c10::str( + logPrefix(), + "Exception thrown when waiting for future ", + futDescription, + ": ", + e.what()); + + debugLog.strings["status"] = "EXCEPTION"; + debugLog.strings["exception"] = e.what(); + LOG(ERROR) << errorMsg; + } catch (...) { + errorMsg = c10::str( + logPrefix(), + "Unknown exception thrown when waiting for future ", + futDescription); + debugLog.strings["status"] = "EXCEPTION"; + debugLog.strings["exception"] = "Unknown exception"; + LOG(ERROR) << errorMsg; + } + } else { + errorMsg = c10::str( + logPrefix(), + "Future for ", + futDescription, + " timed out after ", + timeOutMilSec.count(), + " ms"); + debugLog.strings["status"] = "TIMEOUT"; + LOG(ERROR) << errorMsg; + } + if (throwException && !errorMsg.empty()) { + C10_THROW_ERROR(DistBackendError, errorMsg); + } + return complete; +} + +void ProcessGroupNCCL::abortCommsFromMap( + std::unordered_map>& ncclCommsMap, + const std::optional& abortReason) { + // The process may control multiple devices, loop through the communicators on + // each device + for (auto& it : ncclCommsMap) { + auto& devName = it.first; + auto& ncclComm = it.second; + VLOG(2) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " + << ncclComm->repr() << " on CUDA device: " << devName; + // abort() call now has GPU guard inside + ncclComm->abort(abortReason); + // Note that we don't remove the aborted communicators from the + // cache. The reason is that if we do remove the communicator + // from the cache, it is possible that a new collective operation + // calls `ncclCommInitRank` to create a new communicator whereas + // other ranks might have failed/timed out and didn't enter + // `ncclCommInitRank`. As a result, when there is a failure on + // a communicator the application receives an exception and its + // their responsibility to destroy the process group and recreate + // it to recover from errors. + + VLOG(2) << logPrefix() << "ProcessGroupNCCL destroyed " + << " communicator on CUDA device: " << devName; + } +} + +// Abort all communicators on this rank +// Note: original name of this method is `abort`. It was renamed to +// `abortComms` to distinguish from the `abort` method below. The `abort` +// method calls `abortComms` but does more destruction than the latter. 
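+// Call-order sketch for an abort (illustrative summary of the definitions
+// that follow):
+//
+//   abort()
+//     terminateProcessGroup_ = true;  workMetaListCV_.notify_one();
+//     std::async -> abortComms()
+//       - erase this PG's entries from the global ncclCommDevIdxMap
+//       - abortCommsFromMap(devNCCLCommMap_, ...)
+//       - abortCommsFromMap(inInitializationCommMap_, ...)
+//     waitForFutureOrTimeout(fut, options_->timeout, ...)
+//     terminateHeartbeatMonitorThread_ = true;  monitorWakeUpCV_.notify_one();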
+bool ProcessGroupNCCL::abortComms(
+    const std::optional<std::string>& abortReason) {
+  // Remove this PG's records from the global ncclCommDevIdxMap before
+  // aborting, so that a new cache segment would not register to already
+  // aborted communicators. Note that ncclCommDevIdxMap is a global container
+  // which may contain other PGs' communicators, thus we need to only erase
+  // communicators for the current PG.
+  ncclCommDevIdxMapMutex.lock();
+  for (auto& it : devNCCLCommMap_) {
+    auto& ncclComm = it.second;
+    ncclCommDevIdxMap.erase(ncclComm);
+  }
+  ncclCommDevIdxMapMutex.unlock();
+
+  std::lock_guard<std::mutex> lock(mutex_);
+  abortCommsFromMap(devNCCLCommMap_, abortReason);
+  abortCommsFromMap(inInitializationCommMap_, abortReason);
+  return true;
+}
+
+// Abort this backend.
+void ProcessGroupNCCL::abort() {
+  // This will log a counter for how long the abort actually takes.
+  STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
+
+  // Don't join threads here since the purpose of this method is to abort all
+  // communicators and signal the threads to exit. Joining on the threads could
+  // potentially block and hence avoid it in this method.
+  terminateProcessGroup_.store(true);
+  workMetaListCV_.notify_one();
+
+  // Launch abort asynchronously and wait for it to complete or time out.
+  LOG(INFO) << logPrefix()
+            << "Launching ProcessGroupNCCL abort asynchronously.";
+  std::future<bool> fut =
+      std::async(std::launch::async, [this]() { return this->abortComms(); });
+
+  ::c10d::C10dLoggingData debugLog;
+  waitForFutureOrTimeout(
+      fut, options_->timeout, "ProcessGroup abort", debugLog, true);
+  LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully.";
+
+  // We need to wait for abort to finish before we can safely shut down the
+  // heartbeat monitoring thread.
+  terminateHeartbeatMonitorThread_.store(true);
+  monitorWakeUpCV_.notify_one();
+}
+
+// Difference between `abort()` and `shutdown()`:
+// 1. `abort()` will signal communicators to terminate all NCCL kernels
+// immediately.
+// 2. `shutdown()` will wait for all NCCL kernels to finish before destroying
+// communicators.
+
+// Destroy (shutdown) this backend -- normal exit.
+void ProcessGroupNCCL::shutdown() {
+  LOG(INFO) << logPrefix()
+            << "Starting to destroy process group, flushing operations.";
+  // Flush all collectives
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    for (auto& it : devNCCLCommMap_) {
+      auto& ncclComm = it.second;
+      ncclComm->finalize();
+    }
+  }
+  // Wait for all operations to complete. If the NCCL comm is non-blocking and
+  // the timeout is reached, this will throw an exception.
+  for (auto& it : devNCCLCommMap_) {
+    auto& ncclComm = it.second;
+    // Use a long interval to avoid acquiring the CPU too frequently.
+    ncclComm->waitReady(true);
+  }
+  // Deregister the memory pool after finalizing all collectives.
+  if (memPool_) {
+    try {
+      deregisterMemPool(memPool_.get());
+    } catch (...)
{ + LOG(ERROR) << logPrefix() << "Failed to deregister memory pool, ignoring"; + } + } + // Tell watchdog to (1) flush its queue and (2) do not use comm objects + // anymore because I am going to destroy them now + LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread."; + terminateProcessGroup_.store(true); + workMetaListCV_.notify_one(); + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + } + if (onCompletionHookThread_.joinable()) { + onCompletionHookThread_.join(); + } + // Watchdog thread exiting, retire heartbeat monitoring thread now to avoid + // false alarm + terminateHeartbeatMonitorThread_.store(true); + monitorWakeUpCV_.notify_one(); + // Destroy the communicator, reclaim resources + LOG(INFO) << logPrefix() << "Watchdog joined, destroying NCCL communicators."; + { + std::lock_guard lock(mutex_); + for (auto& it : devNCCLCommMap_) { + auto& ncclComm = it.second; + ncclComm->destroy(); + } + } + LOG(INFO) << logPrefix() << "Destroy complete."; +} + +// NOLINTNEXTLINE(bugprone-exception-escape) +ProcessGroupNCCL::~ProcessGroupNCCL() { + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered."; + + // `shutdown()` or `abort` already called. Skip the favor of disposing + // communicators. + if (!terminateProcessGroup_.load()) { + // If user haven't explicitly destroy/shutdown process group, destructor + // needs to do so + // First print warning on first rank of each node + if (rank_ % localDeviceCount_ == 0) { + TORCH_WARN_ONCE( + "WARNING: destroy_process_group() was not called before program exit, " + "which can leak resources. For more info, please see " + "https://pytorch.org/docs/stable/distributed.html#shutdown"); + } + + // Note 1: in distributed_c10d.py, a reference to PG is held by the global + // context. Therefore, we are here only when the global context is tearing + // down, which means the entire program is exiting. At this point, user + // will no longer care about the result of any collective, thus we can use + // abort instead of destroy to make the destruction non-blocking. + + // TODO: Note 1 is not true in case of a C++ program using libtorch, which + // does not have the global context mentioned. In that case, calling + // `abort()` here could lead to corrupted result. We should consider not + // doing anything and just let things leak. Adversarial example: + /* + Work routine(Tensor& t) { + pg = ProcessGroupNCCL(…); + w = pg.allReduce(t); + return w; + } + */ + abort(); + } + + // Make sure we've told threads to stop; doesn't hurt if we'd done so before. + // Tell watchdog and onCompletionHook: + terminateProcessGroup_.store(true); + workMetaListCV_.notify_one(); + // Tell heartbeat thread: + terminateHeartbeatMonitorThread_.store(true); + monitorWakeUpCV_.notify_one(); + + // Wait for all threads to finish before returning + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; + } + if (ncclHeartbeatMonitorThread_.joinable()) { + ncclHeartbeatMonitorThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL heart beat monitor thread joined."; + } + if (onCompletionHookThread_.joinable()) { + onCompletionHookThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL onCompletionHookThread thread joined."; + } +} + +void ProcessGroupNCCL::terminateProcess(const std::string& errMsg) { + // Logging with `FATAL`, after errMsg printed, it calls `std::abort()` + // to terminate the program execution. 
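+  // Illustrative call site (see heartbeatMonitor() below): when the monitor
+  // thread decides the process must exit and monitoring is enabled, its last
+  // action is
+  //   terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
+  // LOG(FATAL) writes the message and then terminates; it does not return.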
+ LOG(FATAL) << logPrefix() << errMsg; +} + +static long computeDeltaMS( + std::chrono::time_point start, + std::chrono::time_point end) { + return std::chrono::duration_cast(end - start) + .count(); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg( + const std::string& extraMsg) { + return c10::str( + logPrefix(), + "Received a dump signal due to a collective timeout from ", + extraMsg, + " and we will try our best to dump the debug info. ", + "Last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + ".", + "This is most likely caused by incorrect usages of collectives, e.g., wrong ", + "sizes used across ranks, the order of collectives is not same for all ranks ", + "or the scheduled collective, for some reason, didn't run. Additionally, ", + "this can be caused by GIL deadlock or other reasons such as network errors or ", + "bugs in the communications library (e.g. NCCL), etc. "); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg( + const std::string& exitReason) { + return c10::str( + logPrefix(), + "Terminating the process after attempting to dump debug info, due to ", + exitReason, + "."); +} + +void ProcessGroupNCCL::heartbeatMonitor() { + c10::setThreadName("pt_nccl_heartbt"); + + uint64_t heartBeatCounter = 0ULL; + std::string errorMsg; + std::string exitReason; + bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0); + int monitorPollInterval = checkDumpSignal || propagatePgError_ + ? coordCheckIntervalMilSec_ + : heartbeatTimeoutInSec_ * 1000; + auto lastTimePollStore = std::chrono::steady_clock::now(); + auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); + std::optional dumpPipe = std::nullopt; + if (local_id_ == 0) { + // DumpPipe is one per-trainer process, and its convenient to name them + // after 'global' ranks in the system, So we assume processgroup (uid)==0 is + // the global PG and has globally unique rank ids across trainers. + dumpPipe.emplace(rank_); + } + while (true) { + // This won't have any lock since this lock is only used here. + // Please be aware that mutex `monitorMutex_` should not be used + // somewhere else to avoid the deadlock. + std::unique_lock lock(monitorMutex_); + if (monitorWakeUpCV_.wait_for( + lock, std::chrono::milliseconds(monitorPollInterval), [&] { + return terminateHeartbeatMonitorThread_.load(); + })) { + // For the normal complete or user interception, monitorWakeUpCV_ + // will get notified, we early return and exit heartbeatMonitor. + return; + } + auto currentTime = std::chrono::steady_clock::now(); + + if (propagatePgError_) { + // Check and set remote error if it has not been set before + checkAndSetRemoteError(); + } + + // We put extra functionality in the thread for the default PG (aka, + // local_id_=0) because the signal is same across different PGs. We only + // need to run once per process to avoid duplicate things performed in too + // many separate threads. For example, we check a global flag on the + // TCPStore periodically to see if any PG on any rank observed a timeout and + // signaled peers to dump debugging info, and we avoid hammering the + // TCPStore from all PGs on the same rank. + if (checkDumpSignal) { + // There are two scenarios where monitor thread will dump on timeout: + // 1. The current rank is the first to observe a timeout in watchdog. + // (shouldDump_ was set to true by the watchdog thread). + // 2. Other ranks detected the timeout and signal the current rank to + // dump. 
In addtion, monitor threads will dump if watchdog threads has no + // heartbeat or dumpPipe is not empty. + if (shouldDump_.load()) { + errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank"); + exitReason = "collective timeout or exception"; + break; + } + // We poll store to see if some ranks have flagged a timeout when + // we haven't polled for `heartbeat_timeout` seconds and there haven't + // any work added or removed for `watchdog_timeout` seconds. + if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >= + kWatchdogThreadSleepMillis && + computeDeltaMS(lastTimePollStore, currentTime) >= + coordCheckIntervalMilSec_) { + lastTimePollStore = currentTime; + auto handleError = [&](const std::string& errorMessage) { + LOG(WARNING) + << logPrefix() + << "Failed to check the \"should dump\" flag on TCPStore, " + << "(maybe TCPStore server has shut down too early), with error: " + << errorMessage; + // We give up for now assuming TCPStore has been torn down. + return; + }; + // Wrap globalStore_->check() in a try-catch block to avoid crashing if + // the store is not available. + bool checkExceptionDump = false; + try { + checkExceptionDump = + globalStore_->check({std::string(kStoreDumpKey)}); + } catch (const c10::DistNetworkError& e) { + handleError(e.msg()); + } catch (const std::exception& e) { + handleError(e.what()); + } + + if (checkExceptionDump) { + int timeOutRank = -1; + if (!shouldDump_.load()) { + LOG(ERROR) + << logPrefix() + << "Observed flight recorder dump signal from another rank via TCPStore."; + } + shouldDump_.store(true); + try { + auto vec = globalStore_->get(std::string(kStoreDumpKey)); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == sizeof(int), + "Invalid size for the timeout rank ID"); + std::memcpy(&timeOutRank, vec.data(), vec.size()); + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() + << "Failed to get timeout rank ID from TCPStore." + << e.what(); + } + errorMsg = + getNCCLWatchdogTimeoutErrorMsg(c10::str(" rank ", timeOutRank)); + exitReason = "collective timeout or exception"; + break; + } + } + } + + if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >= + heartbeatTimeoutInSec_ * 1000l) { + // Check the heart beat of watchdog thread. + lastTimeHeartBeatCheck = currentTime; + auto heartbeat = heartbeat_.load(); + if (heartbeat != heartBeatCounter) { + heartBeatCounter = heartbeat; + } else { + shouldDump_.store(true); + // Watchdog heartbeat timeout. + errorMsg = c10::str( + logPrefix(), + "ProcessGroupNCCL's watchdog got stuck for ", + heartbeatTimeoutInSec_, + " seconds without making progress in monitoring enqueued collectives. ", + "This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, ", + "and could be triggered by another thread holding the GIL inside a ", + "CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.", + "If you suspect the watchdog is not actually stuck and a longer timeout would help, ", + "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value " + "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)." + "If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout " + "or false positive abort; otherwise, please attempt to debug the hang. "); + exitReason = "ProcessGroupNCCL watchdog hang"; + break; + } + } + // process a request to dump the trace. 
only PG uid 0 will respond to dump + // requests, but this is fine since all PG's feed into the same flight + // recorder and dump. After dump, the training should continue. + if (dumpPipe.has_value() && dumpPipe->shouldDump()) { + // best effort dump, not waiting for the dump here + std::future fut = std::async( + std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); + } + } + LOG(ERROR) << errorMsg; + + // We perform some checks to help users debug the timeout/hang issue: + // 1. Dump the nccl trace (flight recorder) to help debug the issue + // (timeout after waitTimeoutDumpInMilSec_, which is one minute). + // 2. Check if there is a GIL deadlock (timeout after 300ms). + // 3. Try to dump the c++ stacktraces (blocking and would hang, + // users can turn this off by set + // TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN=0). + + // Dump the nccl trace (flight recorder). + if (checkDumpSignal && shouldDump_.load()) { + // Store debug info to storage if no other thread does it. (By default to + // local disk) + bool dumpStackTrace = true; + ::c10d::C10dLoggingData debugLog; + debugLog.integers["pg_id"] = static_cast(local_id_); + debugLog.integers["rank"] = rank_; + debugLog.integers["global_rank"] = globalRank(); + debugLog.integers["world_size"] = getSize(); + for (int i = 0; i < 2; i++) { + std::future asyncDebugDump = + std::async(std::launch::async, [this, dumpStackTrace]() { + return this->dumpDebuggingInfo(dumpStackTrace); + }); + + // wait for the dump until timeout - log data + auto complete = waitForFutureOrTimeout( + asyncDebugDump, + std::chrono::milliseconds(waitTimeoutDumpInMilSec_), + "Flight recorder dump in heartbeatMonitor", + debugLog, + false); + + if (complete) { + LOG(INFO) + << logPrefix() + << "Finished flight recorder successfully. Output can be analyzed using the fr_trace script."; + break; + } + // If we failed to dump, try dumping without stack trace in the 2nd + // iteration. + dumpStackTrace = false; + } + debugLog.integers["trace_enabled"] = int64_t(dumpStackTrace); + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(debugLog); + } + // Indicate to watchdog thread that we have finished dumping. + promiseFlightRecorderDump_.set_value(); + } + + // GIL deadlock check. + if (get_gil_checker() != nullptr) { + auto fut = launchAsyncGilCheck(); + auto kGilCheckTimeout = std::chrono::milliseconds(300); + auto futStatus = fut.wait_for(kGilCheckTimeout); + if (futStatus != std::future_status::ready) { + TORCH_CHECK( + futStatus != std::future_status::deferred, + "Expected the future to have been launched eagerly."); + LOG(ERROR) + << logPrefix() + << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang"; + } + } else { + VLOG(2) + << logPrefix() + << "GIL checker was not registered, perhaps this is a no-python build?"; + } + + // Dump the c++ stacktraces. + auto& cpp_dumper = get_cpp_trace_dumper(); + if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) { + LOG(INFO) << logPrefix() << "Dumping c++ stacktraces:"; + cpp_dumper.value()( + [&](const std::string& line) { LOG(INFO) << logPrefix() << line; }); + LOG(INFO) << logPrefix() << "Finished c++ stacktraces dump."; + } + + // There are two possible cases for the watchdog thread exit: + // Case one: desync report runs quickly, and it follows the step: + // collective timeout -> desync -> exception handling -> destructors + // -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_. 
+ // So the code either early returns above or will skip the sleep below. + // Case two: desync might be slow or get stuck. Or we get stuck in + // destructors, we will sleep for some time before calling std::abort() to + // kill the whole process. + if ((terminateProcessGroup_.load() || desyncDebug_ || shouldDump_.load()) && + !terminateHeartbeatMonitorThread_.load()) { + // Leave another two mins for desync report generation or process group + // destroy. + std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " waiting for desync report or process group destroy."; + } + + // At this point, we either already sleep for another `heartbeatTimeoutInSec_` + // or the thread has finished. Because we don't want to block the monitor + // thread, so We mark the thread detach and the dump of debug info becomes + // "best effort". If the process exit normally, marking it detach also makes + // sense because we don't really care about dumping the debug info. + + // We already log completion inside the thread, so it may not be necessary to + // check the return value here. We mainly use a future so we can exit early + // if done. + + if (!terminateHeartbeatMonitorThread_.load()) { + // Create a error message reported from MonitorThread, so + // we throw exception and make the whole process to be killed. + // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki + // url here. + if (monitorThreadEnabled_.load()) { + terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason)); + } else { + // Ideally we want to merge this one with the above one, but we are going + // to remove the kill switch for monitor thread soon, so we keep this one + // for now. + LOG(ERROR) + << logPrefix() + << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process" + << "after attempting to dump debug info, due to " << exitReason + << "."; + } + } +} + +void ProcessGroupNCCL::ncclCommWatchdog() { + c10::setThreadName("pt_nccl_watchdg"); + + try { + VLOG(2) << logPrefix() << "Process group watchdog thread started!"; + ncclHeartbeatMonitorThread_ = + std::thread(&ProcessGroupNCCL::heartbeatMonitor, this); + watchdogHandler(); + VLOG(2) << logPrefix() + << "Process group watchdog thread terminated normally"; + } catch (std::exception& e) { + if (std::string(e.what()).find("driver shutting down") != + std::string::npos) { + VLOG(2) + << logPrefix() + << "main process destroyed cuda before watchdog loop exited, terminating watchdog." + << " (Watchdog caught exception: " << e.what(); + + } else { + // Append error message reported from watchdogHandler + const auto exitMsg = c10::str( + logPrefix(), + "Process group watchdog thread terminated with exception: ", + e.what()); + LOG(ERROR) << exitMsg; + if (C10_LIKELY(rethrowCUDAErrors_) || + !(std::string(e.what()).find("CUDA Error"))) { + // TODO(whc) clean up the rethrow - why is it stored in a class var and + // rethrown? + watchDogException_ = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + std::rethrow_exception(watchDogException_); + } + } + } catch (...) 
{ + const auto exitMsg = c10::str( + logPrefix(), + "Process group watchdog thread terminated with exception: unknown"); + LOG(ERROR) << exitMsg; + watchDogException_ = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + std::rethrow_exception(watchDogException_); + } +} + +// Initialize and enable DesyncDebugger +void ProcessGroupNCCL::DesyncDebugger::init( + int rank, + int size, + c10::intrusive_ptr store) { + rank_ = rank; + size_ = size; + store_ = std::move(store); + enabled_ = true; + traceKeyStart_ = getTraceStartKey("NCCL", rank); + traceKeyEnd_ = getTraceEndKey("NCCL", rank); +} + +// Run desync debug. This function is called by watchdog at time of timeout. +void ProcessGroupNCCL::DesyncDebugger::run() { + if (!enabled_) + return; + auto logPrefix = c10::str("Rank ", rank_); + try { + std::string desyncMsg = retrieveDesyncReport(store_, "NCCL", rank_, size_); + LOG(ERROR) << logPrefix << desyncMsg; + } catch (const std::exception& e) { + enabled_ = false; + LOG(ERROR) << logPrefix + << " Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " + << " Please file an issue. Error: " << e.what(); + } catch (...) { + enabled_ = false; + LOG(ERROR) + << logPrefix + << " Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." + << " Please file an issue."; + } +} + +// Log work start to store. +void ProcessGroupNCCL::DesyncDebugger::logWorkStart(WorkNCCL& work) { + if (!enabled_) + return; + if (work.startTraceUpdated_) + return; + + work.startTraceUpdated_ = true; + // If not successful, disable the debugger + enabled_ = c10d::traceUpdate( + store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_)); +} + +// Log work end to store. +void ProcessGroupNCCL::DesyncDebugger::logWorkEnd(WorkNCCL& work) { + if (!enabled_) + return; + + // In case the start of the work hasn't been logged + if (!work.startTraceUpdated_) { + logWorkStart(work); + } + + // If not successful, disable the debugger + enabled_ = c10d::traceUpdate( + store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_)); +} + +// We want to have both PG ID and global unique ID (guid) for the logging +// prefix. PG ID records how many ProcessGroupNCCL objects were created on a +// specific rank and is a stable index across ranks, which lets users reason +// about, for example, the second PG we initialized on this rank is for FSDP, +// and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or +// group name) is a global unique ID across ranks. The guid is either a hash of +// all the ranks in the group or a counter of how many times +// `_process_group_name` is called, essentially it means how many times we +// have PGs users have created. Before using split_group, even if +// we are creating a new sub-PG, all ranks have to call the API at the same +// time, and this makes `group_name` a unique identifier for a group (PG). 
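+// Example of the prefix produced below (values made up for illustration):
+//   "[PG ID 1 PG GUID 1(FSDP) Rank 3] "    when pg_desc_ is set,
+//   "[PG ID 0 PG GUID 0 Rank 3] "          when pg_desc_ is empty or
+//                                          "undefined".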
+std::string ProcessGroupNCCL::createLogPrefix() const { + if (!pg_desc_.empty() && pg_desc_ != "undefined") { + return c10::str( + "[PG ID ", + local_id_, + " PG GUID ", + pg_uid_, + "(", + pg_desc_, + ") Rank ", + rank_, + "] "); + } + return c10::str( + "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] "); +} + +const std::string& ProcessGroupNCCL::logPrefix() const { + return logPrefix_; +} + +const int& ProcessGroupNCCL::globalRank() const { + static int globalRank = rank_; + return globalRank; +} + +const std::vector& ProcessGroupNCCL::groupRanks() const { + if (options_->global_ranks_in_group.empty() && local_id_ == 0) { + static std::vector globalRanks(size_); + std::iota(globalRanks.begin(), globalRanks.end(), 0); + return globalRanks; + } + return options_->global_ranks_in_group; +} + +void ProcessGroupNCCL::addEphemeralTimeout( + const std::chrono::milliseconds& timeout) { + std::lock_guard timeoutLock(mtxTimeoutExtension_); + ephemeralTimeoutActive_ += timeout; +} + +bool ProcessGroupNCCL::verifyWorkTimeoutForTest( + const c10::intrusive_ptr& work, + const std::chrono::milliseconds& timeout) { + // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. + if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { + // workNCCL is now a c10::intrusive_ptr + return workNCCL->opTimeout_ == timeout; + } + C10_THROW_ERROR( + DistBackendError, "Non c10d::WorkNCCL object returned from collective"); +} + +void ProcessGroupNCCL::broadcastSignal( + c10::intrusive_ptr& store, + const std::string& signal, + int srcRank) { + try { + auto vec = std::vector( + reinterpret_cast(&srcRank), + reinterpret_cast(&srcRank) + sizeof(srcRank)); + store->set(signal, vec); + LOG(INFO) << logPrefix() << "Broadcasting signal " << signal + << " to other ranks via TCPStore."; + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() << "Failed to broadcast signal " << signal + << " through TCPStore. Error: " << e.what(); + } +} + +int ProcessGroupNCCL::getSignalSrcRank( + c10::intrusive_ptr& store, + const std::string& signal) { + // This function is 'non blocking'. We first 'check' if the key exists in the + // store, then read/get the value only if the key exists. + int srcRank = -1; + bool signalExists = false; + try { + signalExists = store->check({signal}); + } catch (const std::exception& e) { + LOG(WARNING) << logPrefix() << "Failed to check the signal " << signal + << " on TCPStore, " << e.what(); + } + if (!signalExists) { + return srcRank; + } + + // key exists, now read and parse the value (source rank) + std::vector vec; + try { + vec = store->get(std::string(signal)); + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() << "Failed to get source rank of the signal " + << signal << " from TCPStore." << e.what(); + } + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == sizeof(int), + "Invalid size for the timeout rank ID"); + std::memcpy(&srcRank, vec.data(), vec.size()); + return srcRank; +} + +void ProcessGroupNCCL::broadcastDumpSignal() { + // broadcast dump signal to all other global ranks. + broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank()); + // signal the local rank to start dumping + if (shouldDump_.load()) { + // already signaled dump, skipping signal again and wait for the dump + // future. 
+ return; + } + LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; + // signal the monitor thread on PG0 to start dumping + shouldDump_.store(true); + // Give time for dumping before throwing exception + auto start = std::chrono::steady_clock::now(); + auto status = promiseFlightRecorderDump_.get_future().wait_for( + std::chrono::milliseconds(waitTimeoutDumpInMilSec_)); + if (status == std::future_status::timeout) { + LOG(WARNING) << logPrefix() << "timed out after waiting for " + << waitTimeoutDumpInMilSec_ << "ms" + << " flight recorder dumps to finish."; + } else if (status == std::future_status::ready) { + auto end = std::chrono::steady_clock::now(); + LOG(INFO) << logPrefix() << "slept for " << computeDeltaMS(start, end) + << "ms" << " giving time for flight recorder dumps to finish."; + } +} + +void ProcessGroupNCCL::checkAndSetRemoteError() { + // if the error is already set, no need to check again + if (getError() != ErrorType::SUCCESS) { + return; + } + // key/signal to read from the tcpstore is a string and pg specific: + // format is: remote_error:pg_uid + int remoteErrorRank = getSignalSrcRank( + store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_); + if (remoteErrorRank != -1) { + std::lock_guard lock(errorMutex_); + error_ = ErrorType::REMOTE_ERROR; + LOG(ERROR) << c10::str( + logPrefix(), " remote error detected from rank: ", remoteErrorRank); + } +} + +// NCCL recommends to evenly distribute ncclUniqueIds across the ranks +// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config +// Let’s consider an example where: +// nRanks = 10 (total ranks), +// nIds = 3 (roots), +// rmr = 10 % 3 = 1 (1 larger group), +// rpr = 10 / 3 = 3 (base number of ranks per group). +// rlim = 4 +// Output root: +// For ranks [0, 1, 2, 3], root rank is 0 and index is 0. +// For ranks [4, 5, 6], root rank is 4 and index is 1. +// For ranks [7, 8, 9], root rank is 7 and index is 2. +static int getRootIndex(const int rank, const int nRanks, const int nIds) { + const int rmr = nRanks % nIds; + const int rpr = nRanks / nIds; + // For the first rmr roots, we assign one more rank to the root. + const int rlim = rmr * (rpr + 1); + if (rank < rlim) { + // Root with `rpr + 1` ranks, (0, 1, 2, ..., rmr - 1). + return rank % (rpr + 1) ? -1 : rank / (rpr + 1); + } else { + // Root with `rpr` ranks, (rmr, rmr + 1, ..., nIds - 1). + return (rank - rlim) % rpr ? -1 : ((rank - rlim) / rpr) + rmr; + } +} + +void ProcessGroupNCCL::watchdogHandler() { + bool done = false; + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + auto lastStatusUpdateTime = std::chrono::steady_clock::now(); + std::list completedWorkList; + + while (!done || !terminateProcessGroup_.load()) { + std::unique_lock lock(workMetaListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + workMetaListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { return terminateProcessGroup_.load(); }); + // Bump up heart beat by one. + heartbeat_++; + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS +// in which case we don't want to spam the logs. 
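+    // Shape of one watchdog iteration (illustrative summary of the loop
+    // below): wait up to kWatchdogThreadSleepMillis for a notification, bump
+    // heartbeat_, then scan workMetaList_ for NCCL errors, timeouts and
+    // completed work; heartbeat_ is bumped again after every work item in
+    // case the scan itself is slow. heartbeatMonitor() above only declares a
+    // watchdog hang when heartbeat_ stops changing for heartbeatTimeoutInSec_.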
+#ifdef LOG_EVERY_MS + // Log the progress of this PG periodically + C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str( + logPrefix(), + "NCCL Work update periodically: ", + "last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); +#endif // LOG_EVERY_MS + auto logger = ::c10d::C10dLogger::getLogger(); + if (logger && + computeDeltaMS( + lastStatusUpdateTime, std::chrono::steady_clock::now()) >= + kWorkStatusUpdatePeriodMs) { + ::c10d::C10dLoggingData data; + // logging integers + data.integers["pg_id"] = static_cast(local_id_); + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq; + data.integers["last_started_work"] = pgStatus_->lastStartedSeq; + data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq; + data.integers["last_enqueued_numel_in"] = + static_cast(pgStatus_->lastEnqueuedNumelIn); + data.integers["last_enqueued_numel_out"] = + static_cast(pgStatus_->lastEnqueuedNumelOut); + data.integers["last_completed_numel_in"] = + static_cast(pgStatus_->lastCompletedNumelIn); + data.integers["last_completed_numel_out"] = + static_cast(pgStatus_->lastCompletedNumelOut); + data.integers["last_started_numel_in"] = + static_cast(pgStatus_->lastStartedNumelIn); + data.integers["last_started_numel_out"] = + static_cast(pgStatus_->lastStartedNumelOut); + // logging strings + data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName; + data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName; + data.strings["last_completed_work_name"] = + pgStatus_->lastCompletedWorkName; + data.strings["pg_name"] = pg_uid_; + data.strings["pg_desc"] = pg_desc_; + logger->log(data); + lastStatusUpdateTime = std::chrono::steady_clock::now(); + } + + for (auto it = workMetaList_.begin(); it != workMetaList_.end(); + /* no increment */) { + auto& work = *it; + // When terminateProcessGroup_ is true, communicators have already been + // aborted, So cannot check exception based on them. But watchdog needs to + // finish the check for the works that have already been enqueued to + // workMetaList_ + + // check NCCL errors first + if (!terminateProcessGroup_.load()) { + work.checkAndSetException(); + } + + if (work.exception()) { + // set the error to the first error found + std::lock_guard lock(errorMutex_); + if (error_ == ErrorType::SUCCESS) { + error_ = ErrorType::COMM_ERROR; + } + } + + // Then check if work has timed out + // Skip if work has encountered an error + bool timedout = !work.exception() && work.checkTimeout(); + + // Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is + // turned on; otherwise, run() is no-op) + if (timedout) { + std::lock_guard lock(errorMutex_); + if (error_ == ErrorType::SUCCESS) { + error_ = ErrorType::TIMEOUT; + } + desyncDebugger_.run(); + } + + // If work hits an exception (either an error or timeout) + if (work.exception()) { + LOG(ERROR) << c10::str( + logPrefix(), + " failure detected by watchdog at work sequence id: ", + work.seq_, + " PG status: last enqueued work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed work: ", + pgStatus_->lastCompletedSeq); + + // Print the traceback of the collective at call time + work.printTraceback(); + + // broadcast remote error signal to all other ranks in this specific PG. 
+ // key/signal to write in the tcpstore is a string and pg specific: + // format is: remote_error:pg_uid + if (propagatePgError_) { + broadcastSignal( + store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_, rank_); + } + + // try to notify other ranks via global TCPStore to dump the flight + // recorder when a collective timeout or exception happens. Flight + // recorder behavior is independent of desync Debug. + if (dumpOnTimeoutOrEx_) { + broadcastDumpSignal(); + } + + if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { + // Abort work and corresponding communicators + work.abort(); + // PG level abort, which would abort all other communicators on this + // rank + abortComms(); + } + // Throw exception + work.handleException(asyncErrorHandling_); + } + + // Work status logging for desync debug + desyncDebugger_.logWorkStart(work); + + // a work could be started but not completed, so we should not update + // lastStartedSeq and lastStartedOpName if the work state is checked + // multiple times after the start + if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && + work.isStarted()) { + pgStatus_->lastStartedSeq = static_cast(work.seq_); + pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); + pgStatus_->lastStartedNumelIn = work.numelIn_; + pgStatus_->lastStartedNumelOut = work.numelOut_; + } + + // allow watchdog to do an event query on a side thread + at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); + at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; + + // Clean up completed work + if (work.isCompleted()) { + // Work status logging for desync debug + desyncDebugger_.logWorkEnd(work); + + if (work.futureWorkResult_ && work.finishedGPUExecutionInternal() && + !work.futureWorkResult_->completed()) { + work.futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::SUCCESS))); + } + { + // Reset the timeout and first work if the work is completed. + std::lock_guard timeoutLock(mtxTimeoutExtension_); + if (work.ownedEphermeralTimeout_.count() > 0) { + ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; + ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; + } + } + pgStatus_->lastCompletedSeq = static_cast(work.seq_); + pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); + pgStatus_->lastCompletedNumelIn = work.numelIn_; + pgStatus_->lastCompletedNumelOut = work.numelOut_; + if (onCompletionHook_) { + // Move Work object to completedWorkList_ to be consumed by the hook + // thread + { + const std::lock_guard lock(completedWorkListMutex_); + completedWorkList_.splice( + completedWorkList_.end(), workMetaList_, it++); + } + completedWorkListCV_.notify_one(); + } else { + it = workMetaList_.erase(it); + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + } + } else { + // Increment the iterator if the current WorkNCCL object is not + // completed. + ++it; + } + // Increment heartbeat after each work processed, + // in case processing is slowed down (but not hung) by cuda api contention + heartbeat_++; + } + done = workMetaList_.empty(); + } +} + +void ProcessGroupNCCL::runHookLoop() { + c10::setThreadName("pt_nccl_runhook"); + + bool done = false; + while (!done || !terminateProcessGroup_.load()) { + std::unique_lock lock(completedWorkListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. 
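+    // For reference, a minimal completion hook as consumed by this loop might
+    // look like the sketch below (assumption: c10d::WorkInfo exposes the
+    // fields passed to its constructor further down -- opType, seq,
+    // timeStarted, timeFinished, activeDuration):
+    //
+    //   pg->registerOnCompletionHook([](std::shared_ptr<WorkInfo> info) {
+    //     LOG(INFO) << "op " << opTypeToString(info->opType)
+    //               << " seq " << info->seq
+    //               << " active for " << info->activeDuration.count();
+    //   });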
+ completedWorkListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { + return !completedWorkList_.empty() || terminateProcessGroup_.load(); + }); + + try { + for (auto it = completedWorkList_.begin(); it != completedWorkList_.end(); + /* no increment */) { + const WorkNCCL& work = *it; + // Hook might grab GIL, unlock first to prevent deadlock + lock.unlock(); + + auto timeStarted = + std::chrono::system_clock::now() + + std::chrono::duration_cast( + work.workStartTime_ - std::chrono::steady_clock::now()); + onCompletionHook_(std::make_shared( + work.retrieveOpType(), // OpType + work.getSequencenumber(), // seq + timeStarted, // timeStarted + std::chrono::system_clock::now(), // timeFinished + std::chrono::duration( + work.getDuration()) // activeDuration + )); + + lock.lock(); + it = completedWorkList_.erase(it); + } + } catch (std::exception& e) { + if (std::string(e.what()).find("driver shutting down") != + std::string::npos) { + LOG(INFO) + << logPrefix() + << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop." + << " (runHookLoop caught exception: " << e.what(); + + } else { + // PythonOnCompletionHook has already extracted Python exception message + // and wrapped it with a cpp one. So we no longer need to acquire GIL + // here. + const auto errorStr = c10::str( + "Caught exception on rank ", + rank_, + " while running onCompletion hook for ProcessGroupNCCL: ", + e.what(), + ". Aborting all communicators."); + + // No need to call abort() on WorkNCCL here as that collective has + // already finished successfully at this point. We just need to abort + // the process Abort all NCCL Communicators on this ProcessGroupNCCL + // instance. + abortComms(errorStr); + } + } + + // Lock is still acquired at this point + done = completedWorkList_.empty(); + } +} + +std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors() { + return checkForNCCLErrorsInternal(ncclComm_); +} + +std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors( + std::shared_ptr& ncclComm) { + return checkForNCCLErrorsInternal(ncclComm); +} + +std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( + std::shared_ptr& ncclComm) { + // Prioritize commFailureReason over checkForNcclError() result if + // commFailureReason is set. + auto commFailureReason = ncclComm->getNcclCommFailureReason(); + if (commFailureReason != std::nullopt) { + return std::make_exception_ptr(C10_BUILD_ERROR( + DistBackendError, + c10::str( + "NCCL communicator encountered error set by ProcessGroupNCCL: ", + *commFailureReason))); + } + ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); + // When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING, + // ncclInProgress could be returned when there are pending NCCL calls. 
+ // In this case, no exception should be thrown +#ifdef NCCL_HAS_COMM_NONBLOCKING + // ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined + if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) { +#else + if (ncclAsyncErr != ncclSuccess) { +#endif // NCCL_HAS_COMM_NONBLOCKING + return std::make_exception_ptr(C10_BUILD_ERROR( + DistBackendError, + "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" + + getNcclErrorDetailStr(ncclAsyncErr))); + } + + return nullptr; +} + +void ProcessGroupNCCL::broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + bool isSingleP2POp, + const std::string& p2pKey, + int p2pRank) { + // For collective operations: + // For every NCCL communicator that we create we need to broadcast + // a unique ID from rank 0 to all other ranks. This broadcast is + // done by rank 0 setting a key in the store and all other ranks + // retrieving the contents of that key. A single process group + // may create multiple NCCL communicators, so we use a sequence + // number to differentiate between them. + // For single point-to-point operations: + // The sequence number will only be increased on 2 out of all the + // processes in a Process Group. So all following collective + // operations will see different sequence numbers which will cause + // runtime errors. To avoid that, use the src:target pair instead + // of sequence number for p2p communications. + + std::string storeKey; + if (!isSingleP2POp) { + storeKey = std::to_string(ncclCommCounter_++); + } else { + storeKey = p2pKey; + } + if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { + auto vec = std::vector( + reinterpret_cast(ncclID), + reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); + store_->set(storeKey, vec); + } else { + try { + auto vec = store_->get(storeKey); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == NCCL_UNIQUE_ID_BYTES, + "Invalid size for ncclUniqueId"); + std::memcpy(ncclID, vec.data(), vec.size()); + } catch (const std::exception& e) { + std::string exceptionMsg = c10::str( + "[", + rank_, + "] is setting up NCCL communicator and " + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "', but store->get('", + storeKey, + "') got error: "); + C10_THROW_ERROR( + DistBackendError, + exceptionMsg + e.what() + + ". This may indicate a possible application crash on rank 0 or a network set up issue."); + } catch (...) { + C10_THROW_ERROR( + DistBackendError, + c10::str( + "Unknown exception while [", + rank_, + "] is setting up NCCL communicator and " + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "'", + ". This may indicate a possible application crash on rank 0 or a network set up issue.")); + } + } +} + +// We want to all-gather unique NCCL IDs from all roots using TCPStore. +// This is first done by setting the ID by each root and then `multiGet` by all +// ranks. +void ProcessGroupNCCL::allgatherUniqueNCCLIDs( + int rootIdx, + ncclUniqueId* ncclID, + std::vector& ncclIDs) { + std::vector storeKeys; + std::vector> results; + for (size_t r = 0; r < ncclIDs.size(); r++) { + storeKeys.emplace_back("UniqueNCCLID:" + std::to_string(r)); + } + // For non-root rank, rootIdx is set to -1. 
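+  // Store layout sketch for, e.g., 3 roots (keys are the storeKeys built
+  // above, values are NCCL_UNIQUE_ID_BYTES blobs, each written by exactly one
+  // root):
+  //   "UniqueNCCLID:0"  <- set by the rank whose rootIdx is 0
+  //   "UniqueNCCLID:1"  <- set by the rank whose rootIdx is 1
+  //   "UniqueNCCLID:2"  <- set by the rank whose rootIdx is 2
+  // Every rank, root or not, then reads all of them with multiGet() below.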
+ if (rootIdx >= 0) { + auto vec = std::vector( + reinterpret_cast(ncclID), + reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); + store_->set(storeKeys[rootIdx], vec); + } + try { + results = store_->multiGet(storeKeys); + } catch (const std::exception& e) { + std::string exceptionMsg = c10::str( + "[", + rank_, + "] is setting up NCCL communicators and " + "retrieving ncclUniqueId from roots via TCPStore by key '", + "', but got error: "); + C10_THROW_ERROR( + DistBackendError, + exceptionMsg + e.what() + + ". This may indicate a possible application crash on rank 0 or a network set up issue."); + } catch (...) { + C10_THROW_ERROR( + DistBackendError, + c10::str( + "Unknown exception while [", + rank_, + "] is setting up NCCL communicators and " + "retrieving ncclUniqueIds from roots via TCPStore by key '", + "'", + ". This may indicate a possible application crash on rank 0 or a network set up issue.")); + } + + for (size_t r = 0; r < ncclIDs.size(); r++) { + TORCH_CHECK_WITH( + DistBackendError, + results[r].size() == NCCL_UNIQUE_ID_BYTES, + "Invalid size for ncclUniqueId"); + std::memcpy(&ncclIDs[r], results[r].data(), results[r].size()); + } +} + +void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { + TORCH_INTERNAL_ASSERT( + false, + "Expected to find key ", + devNCCLCommMapKey, + " in NCCL communicator map."); + } + std::shared_ptr& ncclComm = devNCCLCommMap_[devNCCLCommMapKey]; + // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being + // destroyed, so using ncclCommAbort here. + ncclComm->abort(); + // Remove communicators from the cache. + devNCCLCommMap_.erase(devNCCLCommMapKey); + // Clear used device indices. + usedDeviceIdxs_.clear(); + + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.erase(ncclComm); + ncclCommDevIdxMapMutex.unlock(); +} + +std::shared_ptr ProcessGroupNCCL::initNCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { + // Sanity check + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the NCCL Communicator since " + "the GPU devices are not known"); + } + if (bound_device_id_) { + if (*bound_device_id_ != device) { + LOG(ERROR) << logPrefix() << "Tensor found on device " << device + << " but backend constrained to " << *bound_device_id_; + C10_THROW_ERROR( + DistBackendError, + "Attempt to perform collective on tensor not on device passed to init_process_group"); + } + } + + usedDeviceIdxs_.insert(device.index()); + + // NCCL communicator not cached, create a new entry + std::shared_ptr ncclComm; + + // Create the unique NCCL ID and broadcast it + ncclUniqueId ncclID; + + // reset log prefix to include group_desc + logPrefix_ = createLogPrefix(); + +#ifdef NCCL_COMM_DESCRIPTION + // Pass process group name and description to NCCL communicator + std::string commDesc = pg_desc_ + ':' + pg_uid_; + options_->config.commDesc = strdup(commDesc.c_str()); +#endif // NCCL_COMM_DESCRIPTION + + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + bool singleP2POp = isP2POp(opType, batchP2P); + + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // [Group Start/End Note] This is used to ensure that nccl communicator will + // be created before communication primitives are called. 
Let's look at this + // example: Using the batch_isend_irecv to send a tensor to a target process. + // On the sender side, the corresponding underlying NCCL calls will look like + // ncclGroupStart() // This is in batch_isend_irecv + // ncclCommInitRank() // Inside NCCLComm::create + // ncclSend() + // ncclGroupEnd() // This is in batch_isend_irecv + // With this pattern, the nccl communicator will be created in the last + // ncclGroupEnd which means when ncclSend is processed, the passed + // communicator argument is NULL which will lead to runtime error. So we need + // to "close" all active nccl groups to ensure nccl communicator is actually + // created before encountering any communication calls. This is why we need + // the following for loop. + for (const auto i : c10::irange(ncclActiveGroupCounter_)) { + (void)i; + // comms have not been initiated yet, so can only check in blocking-way + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + } + + // GPU world size and GPU rank + int numRanks = -1, rank = -1; + + if (!singleP2POp) { + // Collective, all-to-all, or batch P2P + numRanks = getSize(); + rank = getRank(); + } else if (isSendRecvSelf) { + // Same process send and recv. + numRanks = 1; + rank = 0; + } else { + // For single point-to-point operation, there are only 2 processes + // involved so the GPU rank is either 0 or 1. + numRanks = 2; + rank = p2pRank; + } + + RECORD_PARAM_COMMS( + std::make_tuple(0, false), // seq + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank, // rank + "init", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + size_); // worldSize + +#ifdef NCCL_HAS_COMM_NONBLOCKING + bool useNb = useNonblocking(); + options_->config.blocking = useNb ? 0 : 1; +#endif // NCCL_HAS_COMM_NONBLOCKING + +#ifdef NCCL_HAS_COMM_SPLIT + // Use split to create a new communicator only if: + // 1. The parent comm is known; AND + // 2. The new comm is not for a point-to-point operation. + // ncclCommSplit() is a collective call, so it does not work for P2P + // operations. + if (options_->split_from && !singleP2POp) { + // Find a valid, healthy communicator to split from if possible. + std::lock_guard lock(options_->split_from->mutex_); + auto& other_comms = options_->split_from->devNCCLCommMap_; + auto dit = other_comms.find(getKeyFromDevice(device)); + if (dit != other_comms.end()) { + auto& parentComm = dit->second; + if (parentComm != nullptr && !parentComm->isAborted()) { + LOG(INFO) << logPrefix() << "Splitting NCCL communicator from " + << parentComm->repr(); + ncclComm = NCCLComm::split( + parentComm.get(), + options_->split_color, + rank, + options_->config, + options_->global_ranks_in_group); + } + } + } +#endif // NCCL_HAS_COMM_SPLIT + + bool useScalableInit = false; + // (nranks / nroots) == 128 was the default NCCL recommended + // accoring to + // https://github.com/pytorch/pytorch/pull/136789#discussion_r1779171615. 
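+  // Worked example (illustrative, for builds where the scalable-init path is
+  // compiled in): with 512 ranks and the default of 128 ranks per root,
+  // scalable init is selected (512 > 128) and
+  //   numRoots = (512 + 128 - 1) / 128 = 4,
+  // so ranks 0, 128, 256 and 384 each generate one ncclUniqueId via
+  // getRootIndex() above.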
+ auto ranksPerRoot = getCvarInt(TORCH_NCCL_RANKS_PER_ROOT, 128); +#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG) + useScalableInit = !singleP2POp && (getSize() > ranksPerRoot); +#endif // NCCL_HAS_INIT_RANK_SCALABLE && NCCL_HAS_CONFIG + + if (useScalableInit) { + auto numRoots = (getSize() + ranksPerRoot - 1) / ranksPerRoot; + std::vector ncclIDs(numRoots); + + if (!ncclComm) { + auto rootIdx = getRootIndex(rank_, getSize(), numRoots); + // We only need to get unique IDs for roots. For non-root rank, index is + // set to -1. + if (rootIdx >= 0) { + C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); + } + // We only need to all-gather the ncclID if the rank is root. + auto timeStarted = std::chrono::steady_clock::now(); + allgatherUniqueNCCLIDs(rootIdx, &ncclID, ncclIDs); + auto timerDeltaMs = + std::chrono::duration_cast>( + std::chrono::steady_clock::now() - timeStarted) + .count() * + 1000; + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL all-gather unique IDs through store took " + << timerDeltaMs << " ms"; +#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG) + ncclComm = + NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config); +#else + C10_THROW_ERROR( + DistBackendError, + c10::str( + logPrefix(), + "create_scalable is called when useScalableInit is enabled but ", + "neither NCCL_HAS_INIT_RANK_SCALABLE nor NCCL_HAS_CONFIG is not defined, this should not happen ")); +#endif // NCCL_HAS_INIT_RANK_SCALABLE + } + } else { + // To simplify conditional nesting, just create the ncclComms[i] + // entry if it hasn't been yet rather than untangling the + // conditions that might have resulted in a split above. + if (!ncclComm) { + if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) { + // For point-to-point communication, lower rank of the two will get + // unique id. 
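+      // Example (illustrative): for a single send/recv pair, only the rank
+      // with p2pRank == 0 calls ncclGetUniqueId() and publishes the ID under
+      // the deviceKey; the peer waits in store_->get() inside
+      // broadcastUniqueNCCLID() above until that key appears (or the store
+      // timeout fires).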
+ if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { + C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); + } + + // Broadcast so that each process can have a unique NCCL ID + auto timeStarted = std::chrono::steady_clock::now(); + broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); + auto timerDeltaMs = + std::chrono::duration_cast>( + std::chrono::steady_clock::now() - timeStarted) + .count() * + 1000; + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL broadcast unique ID through store took " + << timerDeltaMs << " ms"; + } + +#ifdef NCCL_HAS_CONFIG + ncclComm = NCCLComm::create( + numRanks, rank, ncclID, deviceIndex, options_->config); +#else + ncclComm = NCCLComm::create(numRanks, rank, ncclID, deviceIndex); +#endif // NCCL_HAS_CONFIG + } + } + + // Creates the NCCL streams + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto streamVal = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + { + std::lock_guard lock(mutex_); + inInitializationCommMap_.emplace(deviceKey, ncclComm); + } + + VLOG(2) << logPrefix() << "ProcessGroupNCCL created ncclComm_ " + << ncclComm->repr() + << " on CUDA device: " << static_cast(deviceIndex); + + // At this point NCCL should have been initialized, hence we can accurately + // get the env value even if NCCL sets it by reading from nccl.conf file + LOG(INFO) << logPrefix() + << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A"); + + // See [Group Start/End Note] + for (const auto i : c10::irange(ncclActiveGroupCounter_)) { + (void)i; + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + } + + ncclStreams_.emplace(deviceKey, streamVal); + + // Note: these events are created with the (default) cudaEventDisableTiming + // flag This flag provides the best performance when used with + // cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't measure the + // performance using cudaEvent, this should be set. + // TODO(kwen2501): is ncclEvents_ used anywhere else? + ncclEvents_.emplace(deviceKey, at::cuda::CUDAEvent(cudaEventDisableTiming)); + + // Move the NCCL resource to cache + auto it = inInitializationCommMap_.find(deviceKey); + // A previous thread could've already removed devicesKey from + // inInitializationCommMap_ and added it to devNCCLCommMap_ + if (it != inInitializationCommMap_.end()) { + devNCCLCommMap_.emplace(deviceKey, std::move(it->second)); + inInitializationCommMap_.erase(deviceKey); + + // Now ncclComms are fully initialized. + // Register all active CUDA memory segments in cache allocator to + // the new NCCL communicators + if (useTensorRegisterAllocatorHook_) { + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + // Register the segment to a new NCCL communicator if on the same device + for (const auto& segmentInfo : snapshot.segments) { + TORCH_INTERNAL_ASSERT( + segmentInfo.device == device.index(), + "Mismatch between CUDA memory segment device and current device"); + ncclComm->registerSegment( + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(segmentInfo.address), + segmentInfo.total_size); + } + } + // Record the mapping between ncclComm and device index so that later + // register hook can register a newly allocated segment to communicators + // on the same device. + // NOTE: we need remove the communicator from this map when it is + // destroyed, otherwise may register onto an invalid communicator. 
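+  // Registration sketch (conceptual only): with useTensorRegisterAllocatorHook_
+  // enabled, every caching-allocator segment already resident on this device
+  // is handed to the new communicator above, roughly equivalent to
+  //   ncclCommRegister(comm, segment.address, segment.total_size, &handle);
+  // segments allocated later are registered through the ncclCommDevIdxMap
+  // entry recorded just below.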
+ ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.emplace(ncclComm, device.index()); + ncclCommDevIdxMapMutex.unlock(); + } + + it = devNCCLCommMap_.find(deviceKey); + TORCH_INTERNAL_ASSERT( + it != devNCCLCommMap_.end(), "Communicators not populated in cache!"); + return it->second; +} + +std::shared_ptr ProcessGroupNCCL::getNCCLComm( + const std::string& deviceKey) { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + return devNCCLCommMap_[deviceKey]; + } + return nullptr; +} + +uint64_t ProcessGroupNCCL::getCommSplitCounter() const { + uint64_t ret = 0; + for (const auto& i : devNCCLCommMap_) { + auto& ncclComm = i.second; + ret += ncclComm->getCommSplitCounter(); + } + return ret; +} + +namespace { + +// Check validity of tensor +void check_gpu_single_tensor( + const at::Tensor& tensor, + const bool p2p = false // whether operation is a P2P operation +) { + if (!tensor.is_cuda() || tensor.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + } + // Skip the following requirements for P2P operations + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + if (p2p) { + TORCH_WARN_ONCE( + "Detected non-contiguous tensor in P2P operations. It is user " + "responsibility to guarantee that source and destination tensors have " + "the same contiguity format."); + } else { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } + } +} + +// Checks that all `tensors' have the same type and shape and reside on the same +// GPU. +// TODO: test_c10d_nccl.py should consider adding tests for the error conditions +// here, ie, that deliberately pass invalid tensors and check the right +// exception is thrown. The "Expected list of tensors on the same device" +// condition may be a challenge because the test would need to pass tensors on +// different devices in the same process. +int64_t check_gpu_tensors_same_device(const std::vector& tensors) { + if (tensors.empty()) { + C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); + } + + const auto& first = tensors.front(); + + int64_t total_numel = 0; + for (const auto& t : tensors) { + if (!t.is_cuda() || t.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + } + if (t.scalar_type() != first.scalar_type()) { + C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + } + if (!t.is_non_overlapping_and_dense()) { + C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); + } + // If we're in this function, the user called a _coalesced collective + // on a set of tensors with potentially different sizes and strides. + // Therefore, we don't check for matching sizes and strides, + // but we do double-check tensors are on the same device. + TORCH_CHECK_WITH( + ValueError, + t.get_device() == tensors[0].get_device(), + "Expected list of tensors on the same device"); + total_numel += t.numel(); + } + + return total_numel; +} + +bool check_same_size(const std::vector& input_tensors) { + for (const auto& input_tensor : input_tensors) { + if (!input_tensors[0].is_same_size(input_tensor)) { + return false; + } + } + return true; +} + +} // namespace + +c10::intrusive_ptr ProcessGroupNCCL::initWork( + at::Device& device, + int rank, + OpType opType, + bool isP2P, + const char* profilingTitle, + const std::vector& inputs, + const std::vector& outputs, // TODO(kwen2501): necessary? 
+ bool record) { + auto r = c10::make_intrusive( + pg_uid_, + pg_desc_, + device, + rank, + opType, + isP2P ? seqP2P_ : seqCollective_, + isP2P, + profilingTitle, + profilingTitle != nullptr ? std::optional>(inputs) + : std::nullopt, + desyncDebug_, + enableTiming_.load(), + cudaEventCacheEnabled_.load(), + dist_debug_level_); + if (record) { + bool isP2P = isP2POp(opType); + // Ideally record every work that we enqueue, rather than every work we + // create. + // - at the time of this PR we do not currently enqueue every created work + // - but it is unsafe to steal refs to start/end cuda events from Works that + // may go out of scope before flight recorder has retired them, + // so we must ensure that any work that is initialized via initWork will + // be enqueued + // - initially, moved record() into workEnqueue(), but found that makes it + // hard to get access to profilingTitle, + // inputs, and outputs for metadata recording, and we don't want to attach + // these objects to the Work becuase it has implications for keeping those + // tensors alive longer and adds overhead when copying Work objects + // between threads + } + return r; +} + +// TODO(kwen2501): deprecate +std::vector ProcessGroupNCCL::WorkNCCL::result() { + return *outputs_; +} + +c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: + getFuture() { + return future_; +} + +c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: + getFutureResult() { + return futureWorkResult_; +} + +float ProcessGroupNCCL::WorkNCCL::getDuration() const { + TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled"); + TORCH_CHECK( + ncclStartEvent_, + "getDuration only works if ncclStartEvents_ is populated, true if timing enabled"); + TORCH_CHECK( + ncclEndEvent_, + "getDuration only works if ncclEndEvents_ is populated, which should always be true"); + return ncclStartEvent_->elapsed_time(*ncclEndEvent_); +} + +uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const { + return seq_; +} + +void ProcessGroupNCCL::assignTimeoutToWork( + const c10::intrusive_ptr& work, + const c10::intrusive_ptr& option) { + std::chrono::milliseconds timeout = option->timeout; + std::lock_guard timeoutLock(mtxTimeoutExtension_); + if (ephemeralTimeoutActive_.count() > 0) { + timeout += ephemeralTimeoutActive_; + } + work->opTimeout_ = timeout; + work->ownedEphermeralTimeout_ = + ephemeralTimeoutActive_ - ephemeralTimeoutInflight_; + ephemeralTimeoutInflight_ = ephemeralTimeoutActive_; +} + +void ProcessGroupNCCL::workEnqueue( + const c10::intrusive_ptr& work) { + // in blockingWait_ mode, we don't need watchdog thread, so no need to enqueue + // the work + if (!terminateProcessGroup_.load() && !blockingWait_) { + std::lock_guard lock(workMetaListMutex_); + // Avoid view tensors to be processed in cleanup thread. + // View tensors' destruction invokes autograd_meta, which + // needs to be destructed in user thread. Otherwise will + // get deadlock. Here we enqueue work without outputs_. 
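+    // Worked example for assignTimeoutToWork() above (hypothetical numbers):
+    // with options_->timeout == 10 min and an active ephemeral extension of
+    // 2 min in ephemeralTimeoutActive_, the enqueued work gets
+    //   work->opTimeout_ == 12 min
+    // and ownedEphermeralTimeout_ records the not-yet-inflight share of that
+    // extension (the full 2 min if nothing was inflight).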
+ workMetaList_.emplace_back(*work); + // update the PG status related to the last enqueued work + pgStatus_->lastEnqueuedSeq = static_cast(work->seq_); + pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_); + pgStatus_->lastEnqueuedNumelIn = work->numelIn_; + pgStatus_->lastEnqueuedNumelOut = work->numelOut_; + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + } +} + +ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) + : Backend::Options(NCCL_BACKEND_NAME, kProcessGroupNCCLDefaultTimeout), + is_high_priority_stream(is_high_priority_stream) {} + +static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; + +void ProcessGroupNCCL::startCoalescing() { + // Other collective ops bump seq_ before creating a work. Thus, if coalesced + // ops bump seq_ only after initing a work they will collide with (reuse) the + // seq_ of the last non-coalesced collective. Previously, seq_ was bumped + // inside endCoalescing, but before initWork. Since we now record individual + // ops from a coalesce group into the flight recorder, we want to have the + // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during + // start, which has one minor downside- we burn a seq_ if someone ever does a + // 'start' and 'end' coalescing region without doing an operation inbetween. + + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +// `optype` is for specifying a composite optype, such as ALLGATHER and +// REDUCE_SCATTER +c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. Did you call end_coalescing before start_coalescing?"); + + // `coalescedComm_` should have same set of comms across collectives + auto comm = coalescedComm_; + // `coalescedDevice_` should have same set of devices across collectives + auto device = coalescedDevice_; + + // `getKeyFromDevice` is how we get keys for both collectives and batch P2P + const auto key = getKeyFromDevice(device); + auto ncclStream = ncclStreams_.at(key); + + // Create Work object + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + bool enqueue = + (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None; + auto work = initWork( + device, + rank_, + optype, + coalescing_state_ & CoalP2P, + "nccl:coalesced", + {}, + {}, + enqueue); + work->ncclComm_ = comm; + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams_; + work->store_ = store_; + assignTimeoutToWork(work, options_); + + // Record start before ncclGroupEnd + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + if (useNonblocking()) { + groupEndNonblocking(comm); + } else { + groupEnd(); + } + + // Record end after ncclGroupEnd + // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? 
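+  // Shape of a coalesced region (illustrative): startCoalescing() issued
+  // ncclGroupStart(), each queued op contributed its ncclSend/ncclAllReduce/...
+  // call, and the groupEnd()/groupEndNonblocking() above launches them as one
+  // group, so the single end event recorded below covers every op in the
+  // batch.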
+ work->ncclEndEvent_->record(ncclStream); + + if (avoidRecordStreams_) { + // other functions expect an initialized ptr if avoidRecordStreams_ is set + work->stashed_for_allocator_safety_ = + std::make_shared>(); + } + + if (enqueue) { + workEnqueue(work); + } + + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + + auto device = getDevice(inputs[0]); + // Guard must be created before `currentStreamCaptureStatusMayInitCtx`; + // otherwise, extra CUDA context could be created on device 0. + at::cuda::OptionalCUDAGuard gpuGuard(device); + + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + errorIfCapturingNonCapturableNCCL(capture_status); + + // Bump collective counter + if (!coalescing_state_) { + seqCollective_++; + } + op_id_++; + + const auto key = getKeyFromDevice(device); + std::shared_ptr ncclComm = getNCCLComm(key); + if (ncclComm == nullptr) { + ncclComm = initNCCLComm(key, device, opType); + } + + if (coalescing_state_ & CoalActive) { + if ((coalescing_state_ & CoalColl) == 0) { + // First op in coalesced operations + seqCollective_++; + } + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + bool enqueue = + !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; + auto work = initWork( + device, rank_, opType, false, profilingTitle, inputs, outputs, enqueue); + + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(outputs); + + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); + } + + // Start event should only be recorded before the ncclGroupStart() + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + pre(ncclStream, work); + + ncclComm_t comm = ncclComm->getNcclComm(); + + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. 
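+  // Lifetime sketch for the recordStream() calls below (conceptual): the
+  // caller allocated `inputs` on the current stream; syncStream() above made
+  // ncclStream wait on that allocation stream via an event, and
+  //   CUDACachingAllocator::recordStream(input.storage().data_ptr(), ncclStream)
+  // tells the allocator not to reuse that block until ncclStream's pending
+  // work (the collective) has finished.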
+ if (!avoidRecordStreams) { + for (const auto& input : inputs) { + if (!input.is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + input.values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + input.indices().storage().data_ptr(), ncclStream); + } + } + } + +// Not all collectives have the same signature, e.g, all-reduce take in a Tensor +// as the input and output while all-to-all take in a vector of Tensors as input +// and output. Because we define the signature of the fn to take only single +// tensor as input and output, we need to do a hack to get the first element in +// the vector and pass it to fn. +// TODO: we should clean up this in future (by either entirely removing lambda's +// or removing input and output from lambda's signature). +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(inputs[0], outputs[0], comm, ncclStream), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[0], outputs[0], comm, ncclStream), + comm, + ncclComm->getNcclCommFailureReason()); +#endif // NCCL_HAS_COMM_NONBLOCKING + + post(ncclStream, work); + + // End event should only be recorded after the ncclGroupEnd() + if (!coalescing_state_) { + work->ncclEndEvent_->record(ncclStream); + } + work->ncclComm_ = ncclComm; + + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device as + // multi-device per process is deprecated + work->numelIn_ = 0; + work->numelOut_ = 0; + for (const auto& input : inputs) { + work->numelIn_ += input.numel(); + } + for (const auto& output : outputs) { + work->numelOut_ += output.numel(); + } + + if (enqueue) { + workEnqueue(work); + } + + return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + + // Currently, the API permits one scenario where inputs.size() and + // outputs.size() are > 0. + // 1. If the call was a _coalesced call, all inputs must be on the same + // device. 
+ // The group of nccl calls applies the collective separately to each input, + // but the group as a whole should be efficient, and might even execute as + // a single fused kernel. + auto device = getDevice(inputs[0]); + // Guard must be created before `currentStreamCaptureStatusMayInitCtx`; + // otherwise, extra CUDA context could be created on device 0. + at::cuda::OptionalCUDAGuard gpuGuard(device); + + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + errorIfCapturingNonCapturableNCCL(capture_status); + + // Bump collective counter + seqCollective_++; + + // For coalescingManager collectives, there is no individual c++ call per + // collective so there is no flight record and we increment seqCollective_ and + // op_id_ together. Compare this to startCoalescing/endCoalescing flow where + // we increment either seqP2P_ or seqCollective_ once per group and increment + // op_id_ once per indvidual operation within the group + op_id_++; + + const auto key = getKeyFromDevice(device); + std::shared_ptr ncclComm = getNCCLComm(key); + if (ncclComm == nullptr) { + ncclComm = initNCCLComm(key, device, opType); + } + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + auto work = initWork( + device, + rank_, + opType, + false, + profilingTitle, + inputs, + outputs, + /*record=*/true); + + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(outputs); + + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); + } + + // Start event should only be recorded before the ncclGroupStart() (which + // happens inside AutoNcclGroup guard below) + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + ncclComm_t comm = ncclComm->getNcclComm(); + +// TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL +// upgrade. Once a NCCL version is qualified, this code should not be needed at +// runtime. +#ifdef PGNCCL_ENABLE_HASH + if (enableCollecticeHashDebug_.load()) { + auto numel = getTensorsNumel(inputs); + auto hashValue = hashTensors(inputs); + PRINT_COLLECTIVE_HASH_SIGNATURE( + "input", opTypeToString(opType), numel, hashValue); + } +#endif // PGNCCL_ENABLE_HASH + + { + torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking()); + for (const auto i : c10::irange(inputs.size())) { + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. 
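+      // What the surrounding AutoNcclGroup guard amounts to (illustrative):
+      //   ncclGroupStart();
+      //   fn(inputs[0], outputs[0], comm, ncclStream);      // e.g. ncclAllReduce
+      //   ...
+      //   fn(inputs[n-1], outputs[n-1], comm, ncclStream);
+      //   ncclGroupEnd();   // NCCL may fuse these into fewer kernels
+      // so the per-input calls in this loop are batched rather than launched
+      // one at a time.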
+ if (!avoidRecordStreams) { + if (!inputs[i].is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].indices().storage().data_ptr(), ncclStream); + } + } +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(inputs[i], outputs[i], comm, ncclStream), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[i], outputs[i], comm, ncclStream), + comm, + ncclComm->getNcclCommFailureReason()); +#endif // NCCL_HAS_COMM_NONBLOCKING + } + } + + work->ncclEndEvent_->record(ncclStream); + work->ncclComm_ = ncclComm; + + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device as + // multi-device per process is deprecated + work->numelIn_ = inputs[0].numel(); + work->numelOut_ = outputs[0].numel(); + + /* Note [cuda graph capture and workEnqueue] + + Normal behavior of the C10D watchdog is to query cuda events on work objects. + We disable this event query behavior during graph capture as it is disallowed + during capture under the strictest capture mode setting. + Note that previously recorded events (e.g., before the capture) can be queried + as the watchdog capture mode has been changed to thread-local, but user-side + event queries (from the main thread) via .is_completed() are still disallowed. + TODO(eqy): Is there a path to allowing workEnqueue during graph capture for + watchdog-thread usage only? + + TODO: + - Is our design for flight recorder safe in this context? are we recording + any FR events during cudagraph capture? if so, they won't be safe to poll for + completion status. + */ + if (capture_status == c10::cuda::CaptureStatus::None) { + workEnqueue(work); + } + // TODO(whc) if the work isn't enqueued, I don't feel great about returning + // it, since interactions with it by usercode won't behave normally - they + // won't observe work completion, for instance. Will this lead to silent + // problems during capture? 
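+  // Capture-mode sketch (hypothetical caller, not part of this change):
+  //   at::cuda::CUDAGraph graph;
+  //   graph.capture_begin();
+  //   auto w = pg->allreduce_coalesced(tensors);  // capture_status != None here
+  //   graph.capture_end();
+  // During capture the work created above is returned but not enqueued, so
+  // the watchdog never queries CUDA events that belong to the graph.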
+ return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + PreProcess pre, + PostProcess post, + const char* profilingTitle) { + // avoidRecordStreams_ note: + // send, recv, and irecv should be ok with avoidRecordStreams, + // However, for isend, I don't think the API requires the user + // to wait() on the returned handle, so ProcessGroupNCCL can't know + // when it's safe to release the input back to the allocator, + // and the present call has no way to know it's not an isend. + // Therefore, we warn and fall back to the typical recordStream logic: + if (avoidRecordStreams_) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " + "collectives."); + } + + auto device = getDevice(tensor); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + std::string key; + int p2pRank = 0, p2pTargetRank = 0; + bool isSendRecvSelf = false; + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + if (batchP2P) { + // For batch P2P, we need to treat it like a collective when selecting + // communicator, because other ranks can call into this batch other than my + // rank and my peer + key = getKeyFromDevice(device); + p2pRank = rank_; + p2pTargetRank = peer; + } else { + // For single P2P, preserve the old two-rank behavior (to avoid perf diff) + key = getKeySendRecv(rank_, peer); + p2pRank = rank_ <= peer ? 0 : 1; + isSendRecvSelf = rank_ == peer; + p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; + + if (!coalescing_state_) { + // Bump P2P sequence number. + seqP2P_++; + } + } + + // Bump the logical operation counter regardless of whether this op is + // coalesced or individual + op_id_++; + + std::shared_ptr ncclComm = getNCCLComm(key); + if (ncclComm == nullptr) { + ncclComm = initNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + } + + if (coalescing_state_ & CoalActive) { + // Bump seqP2P_ once per coalesced group, not once per individual op. + if ((coalescing_state_ & CoalP2P) == 0) { + seqP2P_++; + } + coalescing_state_ |= CoalP2P; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + // Work itself will create the CUDA events on all GPUs of tensors + c10::intrusive_ptr work; + if (coalescing_state_) { + // When coalescing, we record events per op that lack timing/state + // information becuase there is no 'work' associated with them, and then + // later in endCoalescing we record a 'coalesced' Work which has + // timing/state updates via watchdog thread, but lacks op metadata such as + // input/output sizes and profilingTitle per-op in the group. 
+ // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get + // their timings/states updated by proxy when the Work obj representing the + // coalesce group gets its update, we could accumulate these trace_ids + // together and ask FlightRecorder to take the update from one Work and + // apply it to multiple entries + } else { + // Store references to outputs to be used by WorkNCCL::result and + // operator<<. Note that these outputs are only valid for recv(), as send() + // does not modify the inputs but we still create these outputs for use + // cases such as profiling. + + work = initWork( + device, + rank_, + opType, + true, + profilingTitle, + {tensor}, + {}, + /*record=*/false); + // This bypasses something in Work() that crashes if {tensor} is given as + // output, not sure what + work->outputs_ = std::make_shared>(); + work->outputs_->push_back(tensor); + // TODO(whc) because we don't pass output {tensor} to initWork, we tell + // initWork to not record, and then we manually call record passing all the + // information it wants. + } + + if (!coalescing_state_) { + // Start event should only be recorded before the ncclGroupStart() + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + pre(ncclStream, work); + } + + // Both send tensor and recv tensor are created on a worker stream and used + // in different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensor.storage().data_ptr(), ncclStream); + + // This part seems common to both p2p and coalesced-p2p usage? + ncclComm_t comm_ = ncclComm->getNcclComm(); + +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(tensor, comm_, ncclStream, p2pTargetRank), + ncclComm->getNcclCommFailureReason()); +#else + // In non-blocking mode, we need to use ncclGroup semantics to ensure that the + // kernel is enqueued for single-P2P ops. Otherwise, the event record below + // may not capture the kernel, leading to data corruption. + ncclGroupStart(); + C10D_NCCL_CHECK_NONBLOCKING( + fn(tensor, comm_, ncclStream, p2pTargetRank), std::nullopt); + C10D_NCCL_CHECK_TIMEOUT_GROUPEND( + ncclGroupEnd(), ncclComm, ncclComm->getNcclCommFailureReason()); +#endif // NCCL_HAS_COMM_NONBLOCKING + + if (!coalescing_state_) { + post(ncclStream); + + // End event should only be recorded after the ncclGroupEnd() + work->ncclEndEvent_->record(ncclStream); + work->ncclComm_ = ncclComm; + work->blockingWait_ = blockingWait_; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device + // as multi-device per process is deprecated + work->numelIn_ = work->numelOut_ = tensor.numel(); + + // Future only needs to be created and marked completed with outputs for + // recv(), but still create future for use cases such as profiling even for + // send(). + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. 
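+    // Recap of the rank remapping done at the top of pointToPoint()
+    // (illustrative numbers): a lone send from global rank 2 to rank 5 uses a
+    // dedicated 2-rank communicator in which rank 2 becomes p2pRank 0 and
+    // rank 5 becomes p2pRank 1, so p2pTargetRank == 1 - p2pRank on each side;
+    // a batched P2P instead reuses the full-size communicator and keeps the
+    // original ranks.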
+ if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + } + + // Enqueue P2P op so that it can be cancelled by NCCL watchdog + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + + if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) { + workEnqueue(work); + } + return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + pre, + post, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle) { + return pointToPoint( + tensor, + fn, + peer, + opType, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&) {}, + profilingTitle); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + // If the nccl branch is not "exp" then we just error + C10_THROW_ERROR( + Error, + "NCCL does not support all_reduce with sparse tensors. 
Please use dense tensors instead."); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( + at::Tensor& tensor, + const char* profilingTitle, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclAllReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::ALLREDUCE, + profilingTitle); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + TORCH_CHECK( + complexViewAsRealAllowed(opts.reduceOp), + "all_reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + + if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) { + using namespace intra_node_comm; + auto algo = intraNodeComm_->selectAllReduceAlgo(tensor); + if (algo != intra_node_comm::AllReduceAlgo::NONE) { + intraNodeComm_->allReduce(tensor, algo); + return c10::make_intrusive(); + } + } + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return allreduce_impl(tensor, "nccl:all_reduce", opts); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + auto total_numel = check_gpu_tensors_same_device(tensors); + TORCH_CHECK( + !isFloat8Type(tensors.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective and assume only one collective + // in coalesed range + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce_coalesced", // collective name + total_numel, // inNelems + total_numel, // outNelems + tensors[0].scalar_type(), // dType + // I'm not sure what in,outSplitSizes mean here. + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. 
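+  // Lowering sketch for the coalesced call below: each tensor in `tensors`
+  // becomes one
+  //   ncclAllReduce(input.data_ptr(), output.data_ptr(), input.numel(),
+  //                 ncclDataType, ncclReduceOp, comm, stream)
+  // issued inside a single ncclGroupStart()/ncclGroupEnd() region by
+  // collectiveCoalesced(), so differently sized tensors are fine as long as
+  // they share a dtype and a device.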
+ return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclAllReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:allreduce_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "broadcast", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclBcast( + input.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + static_cast(root), + comm, + stream.stream()); + }, + OpType::BROADCAST, + "nccl:broadcast", + avoidRecordStreams, + nanCheck); +} + +// _broadcast_oop adds an out-of-place broadcast in PGNCCL +// Custom collectives may be implemented by coalescing broadcast operations +// One use-case is implementing a vector all_gather (all_gather_v) +// where unevenly sized inputs are gathered among participating ranks +// Since all_gather provides an out-of-place API, an all_gather_v +// semantic implemented inside pg_nccl.all_gather also needs to support +// out-of-place, for which an out-of-place broadcast is required to be added +c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclBroadcast( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + static_cast(root), + comm, + stream.stream()); + }, + OpType::BROADCAST, + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + TORCH_CHECK( + 
complexViewAsRealAllowed(opts.reduceOp), + "reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "reduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + static_cast(root), + comm, + stream.stream()); + }, + OpType::REDUCE, + "nccl:reduce"); +} + +// _reduce_oop exposes an out-of-place reduce from PGNCCL +// Custom collectives may be implemented by coalescing reduce operations +// One use-case is implementing a vector reduce_scatter (reduce_scatter_v) +// where inputs are reduced and scattered unevenly among participating ranks +// Since reduce_scatter provides an out-of-place API, a reduce_scatter_v +// semantic implemented inside pg_nccl.reduce_scatter also needs to support +// out-of-place, for which an out-of-place reduce is required to be added +c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _reduce_oop must have the same number of elements "); + } + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; + const auto ncclDataType = getNcclDataType(input.scalar_type()); + const auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + (int)root, + comm, + stream.stream()); + }, + OpType::REDUCE, + "nccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto inputTensor = inputTensors.back(); + check_gpu_single_tensor(inputTensor); + auto outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + 
std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + [](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // avoidRecordStreams_ note: We actually don't need to stash anything + // here. + // - inputTensors is stashed onto work->stashed_for_allocator_safety_ + // in collective(). + // - outputFlattened is stashed onto work->outputs_ in collective(). + // - User-facing outputTensors should be held by the user until after + // waiting on work_, or the call makes no sense. + // So all participating tensors are accounted for, and won't be + // released back to their allocation streams until after work_ is + // waited on. + }, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + at::cuda::CUDAStreamGuard guard(ncclStream); + for (const auto j : c10::irange(outputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), ncclStream); + } + outputTensors_[j].copy_( + outputFlattened[static_cast(j)], true); + } + }, + OpType::ALLGATHER, + "nccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int64_t i : c10::irange(static_cast(num_reduces))) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? 
inputTensor : output; + auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( + std::vector>& /* unused */, + std::vector& /* unused */, + const AllgatherOptions& /* unused */) { + C10_THROW_ERROR( + NotImplementedError, + "ProcessGroupNCCL does not support allgather_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective and assume only one collective + // in coalesed range + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputs, // inputTensors + outputs, // outputTensors + rank_, // rank + "allgather_into_tensor_coalesced", // collective name + getTensorsNumel(inputs), // inNelems + getTensorsNumel(outputs), // outNelems + inputs[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + check_gpu_single_tensor(outputTensor); + auto inputTensors_ = inputTensors.back(); + TORCH_CHECK( + !isFloat8Type(outputTensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. 
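+    // Layout sketch: newLikeFlat() below builds one contiguous
+    // [world_size, ...] buffer shaped like inputTensors_[0]; the pre-hook
+    // copies each per-rank input into its row, and a single ncclReduceScatter
+    // then reduces row r across ranks into rank r's outputTensor
+    // (output.numel() elements per rank).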
+ at::Tensor inputFlattened = newLikeFlat(inputTensors_); + + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + const auto ncclDataType = getNcclDataType(input.scalar_type()); + const auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // We only need to stash inputTensors. + // - inputFlattened is stashed onto + // work->stashed_for_allocator_safety_ + // in collective(). + // - User-facing outputTensors is stashed onto work->outputs_ in + // collective(), + // and should also be held by the user until after waiting on + // work_. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); + } + + // Copy the input tensors to the flattened inputs. + at::cuda::CUDAStreamGuard guard(ncclStream); + for (const auto j : c10::irange(inputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), ncclStream); + } + inputFlattened[static_cast(j)].copy_( + inputTensors_[j], true); + } + }, + [&](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::REDUCE_SCATTER, + "nccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(static_cast(num_reduces))) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + if (inputTensor.dtype() != outputTensor.dtype()) { + C10_THROW_ERROR( + TypeError, "input tensor must be the same type as the output tensor."); + } + + if (inputTensor.numel() != outputTensor.numel() * size_) { + C10_THROW_ERROR( + ValueError, + "input tensor must be the same size as output size times world size"); + } + + const auto& tensor = outputTensor; + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputs and outputs. 
+ // Note 2: for asyncOp = false, we don't want to record streams because we + // know that the NCCL stream will join back to the "current" stream right + // after this op. So we might just as well keep the stream ownership of the + // input/output tensors unchanged. The benefit would be that the + // allocation/free of the tensors would look deterministic to the "current" + // stream so that the caching allocator can reuse memory pool for this stream + // in a clever way. This setting is added for libraries like FSDP which uses + // `reduce_scatter_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::_REDUCE_SCATTER_BASE, + "nccl:_reduce_scatter_base", + avoidRecordStreams); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts) { + TORCH_CHECK( + !isFloat8Type(inputs.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective and assume only one collective + // in coalesed range + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputs, // inputTensors + outputs, // outputTensors + rank_, // rank + "reduce_scatter_tensor_coalesced", // collective name + getTensorsNumel(inputs), // inNelems + getTensorsNumel(outputs), // outNelems + inputs[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:reduce_scatter_tensor_coalesced"); +} + +c10::DeviceIndex ProcessGroupNCCL::guessDeviceId() const { + // 1st choice: don't use this function if your API can take a device_id + // argument. + if (getBoundDeviceId().has_value()) { + // 2nd choice: Use the bound GPU device id if available. + // Bounded device id can be passed to `init_process_group`. + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return getBoundDeviceId().value().index(); + } else if (!usedDeviceIdxs_.empty()) { + // 3rd choice: infer the device id from the used device ids. 
+ return *usedDeviceIdxs_.begin(); + } + // This means there is not yet a NCCL collective being called + // Here we have to use the best guesses and will use a single GPU to call + // allreduce to achieve barrier. + // In case the multiple processes fall into the same node, we use rank to + // ensure that each process is on a different GPU + // Note: it is better to use global rank because the group-local rank can be + // offset wrt the device id if intra-node GPUs are sharded into multiple + // dimensions. + int devIdx = globalRank() % localDeviceCount_; + LOG(WARNING) + << logPrefix() + << c10::str( + " using GPU ", + devIdx, + " as device used by this process is currently unknown. ", + "This can potentially cause a hang if this rank to GPU mapping is incorrect. ", + "You can pecify device_id in init_process_group() to force use of a particular device."); + return static_cast(devIdx); +} + +c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // Device to use for barrier + c10::DeviceIndex barDevIdx = -1; + + // Select device to use for barrier + // 1st choice: Use user defined GPU device ids if provided + if (!opts.device_ids.empty()) { + // Use the first device id because PG NCCL is single-device now + barDevIdx = static_cast(opts.device_ids[0]); + } else { + // 2nd choice: Use the bound or used GPU device id if available. + barDevIdx = guessDeviceId(); + } + + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. "); + auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); + + // Create a dummy tensor on the device + // Note: we use zeros() instead of empty() to prevent barrier from triggering + // alarm when NaN checker is enabled. 
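+  // Sketch of the resulting semantics (illustrative only; `pg` stands for a
+  // hypothetical handle to this backend): the barrier is realized as a
+  // one-element allreduce on a dummy tensor, so
+  //
+  //   auto work = pg->barrier();
+  //   work->wait();  // cannot return before every rank has issued the dummy
+  //                  // allreduce and it has completed on the chosen device
+  //
+  // i.e. a tiny collective stands in for a dedicated NCCL barrier primitive.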
+ at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + // All reduce to achieve the barrier + auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier"); + + // Work will take over barrierTensors + auto ncclWork = dynamic_cast(work.get()); + TORCH_CHECK(ncclWork); + ncclWork->isBarrierOp_ = true; + return work; +} + +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + check_gpu_single_tensor(outputTensor); + check_gpu_single_tensor(inputTensor); + if (outputSplitSizes.empty() && inputSplitSizes.empty()) { + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputTensors. + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + torch::cuda::nccl::all2all_single_equal_split( + input, output, this->getSize(), comm, stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputTensors. + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + std::vector send_lengths(size_); + std::vector recv_lengths(size_); + std::vector send_offsets(size_); + std::vector recv_offsets(size_); + c10d::computeLengthsAndOffsets( + inputSplitSizes, input, &send_lengths, &send_offsets); + c10d::computeLengthsAndOffsets( + outputSplitSizes, output, &recv_lengths, &recv_offsets); + // See [Sync Streams]. 
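+          // (Worked example, assuming the usual c10d convention that split
+          //  sizes count dim-0 slices: with world size 4, an input of shape
+          //  [8, 5] and inputSplitSizes {2, 2, 1, 3}, the helpers above give
+          //  send_lengths {10, 10, 5, 15} and send_offsets {0, 10, 20, 25},
+          //  i.e. per-destination element counts and flat offsets.)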
+ if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + torch::cuda::nccl::all2all_single_unequal_split( + input.data_ptr(), + send_lengths.data(), + send_offsets.data(), + output.data_ptr(), + recv_lengths.data(), + recv_offsets.data(), + input.element_size(), + input.scalar_type(), + comm, + stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); + } +} + +c10::intrusive_ptr ProcessGroupNCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + std::vector inSplitSizes; + std::vector outSplitSizes; + int64_t total_numel = 0; + + auto device = outputTensors[0].device(); + for (const auto r : c10::irange(outputTensors.size())) { + check_gpu_single_tensor(outputTensors[r]); + check_gpu_single_tensor(inputTensors[r]); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + inSplitSizes.push_back(inputTensors[r].numel()); + outSplitSizes.push_back(outputTensors[r].numel()); + total_numel += inputTensors[r].numel(); + } + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + inSplitSizes, // inSplitSizes + outSplitSizes, // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream); + return ncclSuccess; + }, + [&](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // inputTensor0 and outputTensor0 are stashed redundantly by + // collective(), but that's ok. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors.begin(), inputTensors.end()); + v->insert(v->end(), outputTensors.begin(), outputTensors.end()); + } + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::ALLTOALL, + "nccl:all_to_all"); +} + +c10::intrusive_ptr ProcessGroupNCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_gpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqP2P_) + (coalescing_state_ & CoalP2P ? 
0 : 1), + true), // the 1st p2p in coalesced range sets coalescing_state_ and + // bumps seqP2P_ + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + dstRank, // dst rank + "send", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int dst) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + return ncclSend( + input.data_ptr(), + input.numel(), + ncclDataType, + dst, + comm, + stream.stream()); + }, + dstRank, + OpType::SEND, + c10::str("nccl:send ", rank_, "->", dstRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupNCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_gpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqP2P_) + (coalescing_state_ & CoalP2P ? 0 : 1), + true), // the 1st p2p in coalesced range sets coalescing_state_ and + // bumps seqP2P_ + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + srcRank, // src rank + "recv", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int src) { + auto ncclDataType = getNcclDataType(output.scalar_type()); + return ncclRecv( + output.data_ptr(), + output.numel(), + ncclDataType, + src, + comm, + stream.stream()); + }, + srcRank, + OpType::RECV, + c10::str("nccl:recv ", rank_, "<-", srcRank).c_str()); + return ret; +} + +void ProcessGroupNCCL::groupStart() { + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + ++ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEnd() { + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + --ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEndNonblocking( + const std::shared_ptr& comm) { +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); +#else + if (!useNonblocking()) { + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); + } +#endif // NCCL_HAS_COMM_NONBLOCKING + --ncclActiveGroupCounter_; +} + +c10::intrusive_ptr ProcessGroupNCCL::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto inputTensor = inputTensors.back(); + + std::vector outputs; + + if (getRank() == opts.rootRank) { + if (outputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element output list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + 
} else if (outputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect output list size " << outputTensors[0].size() + << ". Output list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = inputTensor.options(); + const auto& sizes = inputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); + outputs = outputTensors[0]; + } else { + // if not in the root rank, initialize outputs as empty list + if (!outputTensors.empty()) { + invalidArgument("requires empty output on non-root"); + } + outputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + outputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * this->getSize(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputs, which == outputTensors[0] on the root rank where it matters. + + auto inputs = std::vector{inputTensor}; + return collective( + inputs, + outputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank; + if (getRank() == root) { + if (!avoidRecordStreams_) { + for (auto const& output : outputs) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + } + torch::cuda::nccl::gather( + inputTensor, outputs, comm, stream, static_cast(root)); + return ncclSuccess; + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::GATHER, + "nccl:gather"); +} + +c10::intrusive_ptr ProcessGroupNCCL::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + + std::vector inputs; + + if (getRank() == opts.rootRank) { + if (inputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element input list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (inputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect input list size " << inputTensors[0].size() + << ". 
Input list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = outputTensor.options(); + const auto& sizes = outputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); + inputs = inputTensors[0]; + } else { + // if not in the root rank, initialize inputTensors as empty place holder + // with an empty list + if (!inputTensors.empty()) { + invalidArgument("requires empty input on non-root"); + } + inputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + inputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash outputTensors and + // inputs, which == inputTensors[0] on the root rank where it matters. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); + + auto outputs = std::vector{outputTensor}; + return collective( + outputs, + inputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (getRank() == root) { + if (!avoidRecordStreams) { + for (auto const& input : inputs) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + } + } + torch::cuda::nccl::scatter( + inputs, outputTensor, comm, stream, static_cast(root)); + return ncclSuccess; + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::SCATTER, + "nccl:scatter", + avoidRecordStreams, + nanCheck); +} + +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( + std::vector& /* unused */, + int /* unused */) { + C10_THROW_ERROR( + NotImplementedError, "ProcessGroupNCCL does not support recvAnysource"); +} + +c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + check_gpu_single_tensor(input_tensor); + check_gpu_single_tensor(output_tensor); + + if (input_tensor.dtype() != output_tensor.dtype()) { + C10_THROW_ERROR( + TypeError, "output tensor must have the same type as input tensor"); + } + + if (input_tensor.numel() * size_ != output_tensor.numel()) { + C10_THROW_ERROR( + ValueError, + "output tensor size must be equal to world_size times input tensor size"); + } + + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // 
inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputs and outputs. + // Note 2: for asyncOp = false, we don't want to record streams because we + // know that the NCCL stream will join back to the "current" stream right + // after this op. So we might just as well keep the stream ownership of the + // input/output tensors unchanged. The benefit would be that the + // allocation/free of the tensors would look deterministic to the "current" + // stream so that the caching allocator can reuse memory pool for this stream + // in a clever way. This setting is added for libraries like FSDP which uses + // `all_gather_into_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + OpType::_ALLGATHER_BASE, + "nccl:_all_gather_base", + avoidRecordStreams); +} + +// Create a memory allocator for NCCL. This allocator is used to allocate memory +// that supports NVLink Sharp functionality. This allocator is later pybinded to +// python, so that users can use it to create MemPool. For example: +// >>> pool = torch.cuda.MemPool(backend.mem_allocator) + +// Allocate function +static void* _ncclMemAlloc(size_t size, int device, void* stream) { +#ifndef NCCL_HAS_MEM_ALLOC + TORCH_CHECK( + false, "NCCL mem allocator is not supported in this NCCL version"); +#else + LOG(INFO) << "NCCL mem allocator: allocating " << size << " bytes"; + at::cuda::OptionalCUDAGuard gpuGuard(device); + void* ptr = nullptr; + TORCH_CHECK(ncclMemAlloc(&ptr, size) == ncclSuccess, "ncclMemAlloc failed"); + return ptr; +#endif // NCCL_HAS_MEM_ALLOC +} + +// Free function +static void _ncclMemFree(void* ptr, size_t size, int device, void* stream) { +#ifndef NCCL_HAS_MEM_ALLOC + TORCH_CHECK( + false, "NCCL mem allocator is not supported in this NCCL version"); +#else + LOG(INFO) << "NCCL mem allocator: freeing " << size << " bytes"; + at::cuda::OptionalCUDAGuard gpuGuard(device); + TORCH_CHECK(ncclMemFree(ptr) == ncclSuccess, "ncclMemFree failed"); +#endif // NCCL_HAS_MEM_ALLOC +} + +// Create a `CUDAPluggableAllocator` that uses the above functions. +std::shared_ptr ProcessGroupNCCL::getMemAllocator() { + C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.getMemAllocator"); + c10::DeviceIndex deviceIdx = guessDeviceId(); + if (!supportsTensorAlloc(deviceIdx)) { + TORCH_CHECK( + false, "NCCL mem allocator is not supported in this NCCL version"); + } + static std::shared_ptr + ncclMemAllocator = + torch::cuda::CUDAPluggableAllocator::createCustomAllocator( + _ncclMemAlloc, _ncclMemFree); + return ncclMemAllocator; +} + +bool ProcessGroupNCCL::supportsTensorAlloc(c10::DeviceIndex deviceIdx) { + // Check if NCCL has `ncclMemAlloc` and `ncclMemFree` functions + int version = 0; + // Rely on link-time versioning + ncclGetVersion(&version); + if (version < NCCL_VERSION(2, 19, 0)) { + return false; + } + + // We do an extra check to see if CUDA driver supports multicast. If not, we + // will return false. 
Although `ncclMemAlloc` will fall back to regular + // `cudaMalloc` and hence not error out, we may still want to avoid creating a + // separate memory pool for NCCL. + return c10d::cuda::deviceSupportsMulticast(deviceIdx); +} + +at::Tensor ProcessGroupNCCL::allocateTensor( + long size, + at::TensorOptions options) { + // Some checks + TORCH_CHECK_VALUE(options.has_device(), "Tensor options must include device"); + auto device = options.device(); + TORCH_CHECK_VALUE( + device.is_cuda(), + "NCCL tensor allocator expects cuda type but got " + c10::str(device)) + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Create memory pool + if (!memPool_) { + // Needs a CUDAAllocator + auto allocator = + reinterpret_cast( + getMemAllocator().get()); + // Pool is created + memPool_ = std::make_unique(allocator); + LOG(INFO) << logPrefix() << "Created memory pool"; + } + + // Allocate tensor under this MemPool's context + auto ctx = c10::cuda::MemPoolContext(memPool_.get()); + c10::cuda::CUDACachingAllocator::beginAllocateToPool( + memPool_->device(), memPool_->id(), [](cudaStream_t) { return true; }); + at::Tensor tensor = at::empty({size}, options); + // Also need to ncclCommRegister the pool in case new segments are created; + // reregistration of old segments will be ignored + registerMemPool(memPool_.get()); + c10::cuda::CUDACachingAllocator::endAllocateToPool( + memPool_->device(), memPool_->id()); + c10::cuda::CUDACachingAllocator::releasePool( + memPool_->device(), memPool_->id()); + LOG(INFO) << logPrefix() << "Allocated tensor of size " << size + << " from memory pool"; + return tensor; +} + +bool ProcessGroupNCCL::dumpDebuggingInfo(bool includeStackTrace) { + return true; +} + +} // namespace torchft diff --git a/csrc/ProcessGroupNCCL.hpp b/csrc/ProcessGroupNCCL.hpp new file mode 100644 index 0000000..48ac78a --- /dev/null +++ b/csrc/ProcessGroupNCCL.hpp @@ -0,0 +1,1348 @@ +#pragma once + +#ifdef USE_C10D_NCCL + +#if defined(__linux__) +#include +#include +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace c10d; + +namespace torchft { + +// Control broadcasting of NCCL uniqueId +static std::vector TORCH_NCCL_BCAST_UNIQUEID = { + "TORCH_NCCL_BCAST_UNIQUEID"}; + +// Control whether to always use high priority streams +static std::vector TORCH_NCCL_HIGH_PRIORITY = { + "TORCH_NCCL_HIGH_PRIORITY"}; + +// Control whether or not wait() is blocking or non-blocking. +static std::vector TORCH_NCCL_BLOCKING_WAIT = { + "TORCH_NCCL_BLOCKING_WAIT", + "NCCL_BLOCKING_WAIT"}; + +// TODO: We want to eventually remove this variable and make users to use +// the default value (3 - SkipCleanUp). +// Control whether or not we perform Async Error Handling with NCCL. +static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + "NCCL_ASYNC_ERROR_HANDLING"}; + +// Control whether dumping debug info on watchdog +// timeout is enabled. This variable must be set together with +// TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0. +static std::vector TORCH_NCCL_DUMP_ON_TIMEOUT = { + "TORCH_NCCL_DUMP_ON_TIMEOUT"}; + +// Control whether to propagate NCCL errors to all ranks through TCPStore. +static std::vector TORCH_NCCL_PROPAGATE_ERROR = { + "TORCH_NCCL_PROPAGATE_ERROR"}; + +// Control whether Desync Debug is enabled. 
This variable must be set +// together with TORCH_NCCL_ASYNC_ERROR_HANDLING. +static std::vector TORCH_NCCL_DESYNC_DEBUG = { + "TORCH_NCCL_DESYNC_DEBUG", + "NCCL_DESYNC_DEBUG"}; + +// Enable recording start-events for all ProcessGroupNCCL collectives, and +// compute accurate collective timing per-collective. (Note: end-events are +// recorded by default. Turn on this flag can increase chances of a watchdog +// hang due to performing a CUDA event query which eventually calls +// cudaEventElapsedTime() API. +static std::vector TORCH_NCCL_ENABLE_TIMING = { + "TORCH_NCCL_ENABLE_TIMING", + "NCCL_ENABLE_TIMING"}; + +// Enable monitoring thread which aborts the process when the ProcessGroupNCCL +// Watchdog thread gets stuck and no heartbeat is detected after +// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL +// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged +// time than necessary tying up cluster resources. +static std::vector TORCH_NCCL_ENABLE_MONITORING = { + "TORCH_NCCL_ENABLE_MONITORING"}; + +// Control the watchdog heartbeat timeout period after which the monitoring +// thread will abort the process. +static std::vector TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = { + "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"}; + +// Whether to rethrow CUDA Errors in the watchdog (default true) +static std::vector TORCH_NCCL_RETHROW_CUDA_ERRORS = { + "TORCH_NCCL_RETHROW_CUDA_ERRORS"}; + +// The maximum number of events we store in the flight recorder's ring buffer. +// (One event could be the start or end of a collective, for example). +static std::vector TORCH_NCCL_TRACE_BUFFER_SIZE = { + "TORCH_NCCL_TRACE_BUFFER_SIZE"}; + +// Control how much extra time we will wait for dumping the debugging info +// before we exit and throws timeout exception. +static std::vector TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = { + "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"}; + +// Control the interval inside the monitoring thread to check the coordinated +// signal from other ranks, e.g. to dump the debugging information. +static std::vector TORCH_NCCL_COORD_CHECK_MILSEC = { + "TORCH_NCCL_COORD_CHECK_MILSEC"}; + +// Whether to log C++ stack traces on unclean shutdown (default true) +static std::vector TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN = { + "TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN"}; + +// Control whether to use CudaEventCache for the collective in watchdog thread. +// We noticed in the past when cuda global lock is held, destroying CudaEvent +// can cause a hang. +static std::vector TORCH_NCCL_CUDA_EVENT_CACHE = { + "TORCH_NCCL_CUDA_EVENT_CACHE"}; + +// Control the number of ranks each root can cover during NCCL comm init. +static std::vector TORCH_NCCL_RANKS_PER_ROOT = { + "TORCH_NCCL_RANKS_PER_ROOT"}; + +static std::vector TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"}; + +constexpr const char* NCCL_BACKEND_NAME = "nccl"; + +constexpr const char* kStoreDumpKey = "exception_dump"; + +constexpr const char* kStoreErrorSignalKey = "remote_error"; + +constexpr const int kWorkStatusUpdatePeriodMs = 30 * 1000; // 30 seconds + +constexpr auto kProcessGroupNCCLDefaultTimeout = + std::chrono::milliseconds(10 * 60 * 1000); + +// NoHandling: do not handle asynchronous NCCL errors +// TearDown: tear down process upon error, see `WorkNCCL::handleException` +// CleanUpOnly: just clean up collectives and abort communicators without +// tearing down process SkipCleanUp: (this is a temporary option and can be +// removed in future) tear down process without cleaning up NCCL communicators. 
+// This should be used as a last resort in case `ncclCommAbort` itself is +// hanging +enum ErrorHandlingMode { + NoHandling = 0, + TearDown = 1, + CleanUpOnly = 2, + SkipCleanUp = 3 +}; + +#define SHOULD_CLEAN_UP(a) (a != NoHandling && a != SkipCleanUp) + +#define SHOULD_TEAR_DOWN(a) (a != NoHandling && a != CleanUpOnly) + +#define PRINT_COLLECTIVE_HASH_SIGNATURE(phase, opType, numel, hashValue) \ + LOG(WARNING) << logPrefix() << "Hash of " << phase << " to NCCL " << opType \ + << " with size " << numel << " is " << hashValue; + +// If set, ProcessGroupNCCL doesn't use recordStream calls to ensure +// caching allocator safety for tensors used on both user-facing and +// internal comm streams. +// Instead, it stashes live references to those tensors until after +// user-facing streams are synced with comm streams. +// See stashed_for_allocator_safety_ below. +static std::vector TORCH_NCCL_AVOID_RECORD_STREAMS = { + "TORCH_NCCL_AVOID_RECORD_STREAMS"}; + +// If set, ProcessGroupNCCL registers postAlloc and preFree hooks to cuda cache +// allocator so that whenever a tensor is allocated or freed, ProcessGroupNCCL +// can register/deregister the tensor on all available NCCL communicators. +static std::vector TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK = + {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK", + "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"}; + +#if defined(__linux__) +struct DumpPipe { + DumpPipe(int rank) { + std::string fileStem = + getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, ""); + if (fileStem.empty() || + getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) { + return; + } + TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_PIPE_FILE is empty"); + std::string filename = c10::str(fileStem, rank, ".pipe"); + TORCH_CHECK( + unlink(filename.c_str()) != -1 || errno == ENOENT, + "Error removing existing named pipe ", + filename); + TORCH_CHECK( + mkfifo(filename.c_str(), 0666) != -1, + "Error creating named pipe ", + filename); + fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK); + LOG(INFO) << "Pipe file " << filename + << " has been opened, write to it to trigger NCCL Debug Dump."; + TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename); + } + bool shouldDump() { + if (fd_ == -1) { + return false; + } + // NOLINTNEXTLINE(*array*) + char buf[128]{}; + // non-blocking from O_NONBLOCK above. + // Ignore EINTR because we already will poll this + // again later. + ssize_t bytesRead = read(fd_, &buf, 128); + return bytesRead > 0; + } + ~DumpPipe() { + if (fd_ != -1) { + close(fd_); + } + } + + private: + int fd_ = -1; +}; +#else +struct DumpPipe { + DumpPipe(int rank) {} + bool shouldDump() { + return false; + } +}; +#endif + +// ProcessGroupNCCL implements NCCL bindings for c10d. +// +// All functions of the class are expected to be called in the same order +// across all processes in the process group. This is the only way that we +// can guarantee to match up the same calls among all processes. +// +// All NCCL functions provided by this class are asynchronous functions. More +// specifically, each NCCL call is scheduled on a separate CUDA stream that is +// different from the current CUDA stream. This is for the purpose of +// achieving potentially concurrency and better performance. As a result, +// it is the callers' responsibility to make sure that the CUDA stream their +// code works on needs to wait for the NCCL operation from +// this class. 
+// +// This can be done by calling: +// +// either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieves the same +// functionality and are synonyms. +// +// Also note that WorkNCCL::finishedGPUExecution() is a helper function only +// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has +// finished execution on the GPU (not just scheduled). +// +// Example on using the NCCL process group +// +// ProcessGroupNCCL pg(store, rank, size); +// std::shared_ptr work = pg.allreduce(tensors); +// +// // At this point, NCCL kernel has already by queued successfully +// // Now, let current stream wait for the NCCL to finish, this function is +// // async operation as well +// +// work->wait() +// +// // Now continue on other work in the current stream. +class TORCH_API ProcessGroupNCCL : public Backend { + public: + class WorkNCCL : public Work, public std::enable_shared_from_this { + public: + friend struct WorkInfo; + + // Constructor takes a list of CUDA devices + WorkNCCL( + std::string pgUID, + std::string pgDesc, + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + bool isP2P = false, + const char* profilingTitle = nullptr, + const std::optional>& inputs = std::nullopt, + bool desyncDebug = false, + bool enableTiming = false, + bool cudaEventCacheEnabled = false, + DebugLevel distDebugLevel = DebugLevel::Off); + // Copy constructor doing partial copy without outputs_. Cleanup thread + // monitors and removes finished works. However it will deadlock when + // destructs outputs_ tensors who are view tensors in autograd graph. + WorkNCCL(const WorkNCCL& w); + + ~WorkNCCL() override = default; + + // Checks if the NCCL kernel has started to execute. + bool isStarted(); + + // Checks if request has completed. In this specific case of NCCL, it checks + // if the NCCL operation has completed on the GPU in its own NCCL stream. + // Non-blocking operation. + bool isCompleted() override; + + bool isSuccess() const override; + + // Same as calling synchronize() for NCCL work if timeout is not set. + // Otherwise, it will block the CPU thread until the NCCL work is completed + // or timed out. If timeout, exception will be thrown. + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; + + void abort() override; + + // Let current stream wait on the completion of the NCCL work + // Throws on exceptions. + void synchronize() override; + + // Synchronize streams by blocking each on the NCCL stream + void synchronizeStream(); + + // Helper function to handle exception (throw if needed). + void handleException(ErrorHandlingMode asyncErrorHandling); + + // Helper function that checks if the NCCL kernels have finished + // execution on the GPUs + bool finishedGPUExecution(); + + // Get a Future object that will be marked as completed internally. + c10::intrusive_ptr getFuture() override; + + // Get a Future result of each work (e.g. success, different error types). + // instead of the tensor output. + c10::intrusive_ptr getFutureResult() override; + + float getDuration() const override; + + uint64_t getSequencenumber() const override; + + const std::string& logPrefix() const; + + // Helper function that sets an exception_ptr on the WorkNCCL object. + void setException(std::exception_ptr exception_ptr); + + // Helper function that returns True if the WorkNCCL object has timed out + // and False otherwise. + // In case of timeout, set exception on the WorkNCCL object. 
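+    // (Illustrative only, not a prescribed pattern: a monitoring loop might
+    //  poll this, e.g.
+    //    if (work->checkTimeout()) { /* abort comms, report the stored
+    //                                    exception */ }
+    //  with no argument it presumably falls back to the work's opTimeout_.)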
+ bool checkTimeout( + std::optional timeout = std::nullopt); + + // Print the traceback of the collective at call time + void printTraceback() const; + + std::vector result() override; + + protected: + // The process group unique id + std::string pgUID_; + + // The process group description + std::string pgDesc_; + + // The cached list of CUDA devices to operate on + at::Device device_; + + // The start CUDA event of NCCL operator tracking this work item. These + // start CUDA events are needed by desync debugging if enabled. + std::shared_ptr ncclStartEvent_; + + // The end CUDA event of NCCL operator tracking this work item. + std::shared_ptr ncclEndEvent_; + + // The NCCL communicator used for this work item. + std::shared_ptr ncclComm_; + + // whether this work is a barrier op + bool isBarrierOp_{false}; + + // Clone of blockingWait_ from ProcessGroupNCCL. + bool blockingWait_{false}; + + // Clone of avoidRecordStreams_ from ProcessGroupNCCL. + bool avoidRecordStreams_{false}; + + // Clone of opTimeout_ from ProcessGroupNCCL. + std::chrono::milliseconds opTimeout_{}; + + // Ephemeral timeouts are owned by exactly one work, + // and reset after that work completes. + // There may be more than one ephemeral timeout active at the same time, + // and this variable is used to track the ownership of ephemeral timeout. + std::chrono::milliseconds ownedEphermeralTimeout_ = + std::chrono::milliseconds(0); + + // Time point representing when the work started. + std::chrono::time_point workStartTime_; + + // Record the sequential number of collective or p2p. + uint64_t seq_; + bool isP2P_; + + // Indicates if the nccl start event has been updated to the store trace. + // This will be used by desync debug. + bool startTraceUpdated_{false}; + + // Record collective sizes for debug. We only record the size on the first + // device as multi-device per process is deprecated + size_t numelIn_ = -1; + size_t numelOut_ = -1; + + // Wrapper method for the static checkForNCCLErrors which can be overridden + // for tests. + virtual std::exception_ptr checkForNCCLErrors(); + + friend std::ostream& operator<<( + std::ostream& output, + const WorkNCCL& workNCCL); + + private: + // Checks for NCCL errors and sets an appropriate exception_ptr. + void checkAndSetException(); + + // Just checks whether GPU execution has started, without modifying + // exception_ptr. + bool startedGPUExecutionInternal() const; + + // Just checks whether GPU execution has completed, without modifying + // exception_ptr. + bool finishedGPUExecutionInternal() const; + + // Reference to the store so that we can write aborted communicators + // to the store. + c10::intrusive_ptr store_; + + // Store a reference to NCCL collective's outputs, used by result and to + // give a more descriptive message when representing the Work as a string. + std::shared_ptr> outputs_; + + // TORCH_NCCL_AVOID_RECORD_STREAMS implementation helper. + // Stores references to participating non-output tensors (ie inputs, + // flattened intermediates). + // We'll clear this list in synchronizeStream, just after user-facing + // stream(s) are synced with the nccl work stream(s). + // By keeping these refs (as well as outputs_) alive until after the + // collective's work rejoins the user-facing streams, we achieve + // caching allocator safety without any recordStream calls. + // For in-place collectives, some refs stashed here may alias outputs_, + // but that doesn't do any harm. 
+ std::shared_ptr> stashed_for_allocator_safety_; + + // The future returned by getFuture. + c10::intrusive_ptr future_; + + // the future result (e.g., success or failure) of the work + c10::intrusive_ptr futureWorkResult_; + + bool timingEnabled_; + // unique id used to tell the trace buffer that this + // work has completed + std::optional trace_id_; + DebugLevel distDebugLevel_; + friend class ProcessGroupNCCL; + }; + + class CUDAEventCache + : public std::enable_shared_from_this { + public: + CUDAEventCache(); + std::shared_ptr create(bool timing); + static std::shared_ptr get( + at::DeviceIndex device); + + private: + std::mutex cacheMutex_; + // NOTE: We intentionally store raw pointers so that + // we do not attempt to destroy the event objects on process exit, + // because cuda may be gone. + std::array, 2> + eventsArray_; // 0 for timing=false, 1 for timing=true + }; + + struct Options : Backend::Options { + // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for + // operations. This is only used when blockingWait_ is enabled. + explicit Options(bool is_high_priority_stream = false); + + // return intrusive_ptr of the object + static c10::intrusive_ptr create( + bool is_high_priority_stream = false) { + return c10::make_intrusive(is_high_priority_stream); + } + + // Schedule NCCL operations on high priority CUDA streams + bool is_high_priority_stream; + +#ifdef NCCL_HAS_CONFIG + // Configure ranks + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; +#endif + + // Optional "parent" backend and color to create communicators from + // via `ncclCommSplit` + std::shared_ptr split_from; + // Color to use for `ncclCommSplit`, values: + // * Non-negative value: in group; + // * NCCL_SPLIT_NOCOLOR (-1): not in group; + // * NCCL_SPLIT_NOCOLOR - 1: uninitialized. + // [Note 1]: the type must be `int` instead of `int64_t` because NCCL API + // accepts int. Otherwise, an implicit conversion may happen at the API call + // and the value may become negative. + // [Note 2]: this member is pybinded to Python, the value passed from Python + // must be within the numerical range of C++ int. Otherwise, Python will + // raise a RuntimeError saying type is incompatible. See also + // `_process_group_color` in `distributed_c10d.py`. +#ifdef NCCL_HAS_COMM_SPLIT + int split_color{NCCL_SPLIT_NOCOLOR - 1}; +#else + // [Note 3]: for older NCCL versions, NCCL_SPLIT_NOCOLOR is not defined. But + // `split_color` is pybinded to Python, so we need to define it. So we use + // the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead. + int split_color{-2}; +#endif + std::vector global_ranks_in_group; + std::string group_name; + }; + + // Helper class related to TORCH_NCCL_DESYNC_DEBUG + class DesyncDebugger { + public: + // Initialize and enable DesyncDebugger + void init(int rank, int size, c10::intrusive_ptr store); + + // Run desync debug. This function is called by watchdog at time of timeout. + void run(); + + // Log work start to store. + void logWorkStart(WorkNCCL& work); + + // Log work end to store. + void logWorkEnd(WorkNCCL& work); + + private: + // Whether desync debug is enabled. + // If false, all functions are no-op. + bool enabled_{false}; + + // From ProcessGroupNCCL + int rank_; + int size_; + + // Reference to the store so that we can log start/end event. + c10::intrusive_ptr store_; + + // The store keys to trace the last NCCL collective kernel CUDA events - + // start event and end event respectively. These are used to do desync root + // cause analysis. 
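+    // (Roughly: on a timeout the debugger can compare, across ranks, which
+    //  sequence numbers have a start event logged under traceKeyStart_ but no
+    //  matching end event under traceKeyEnd_, which points at the rank(s)
+    //  whose kernel never finished.)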
+ std::string traceKeyStart_; + std::string traceKeyEnd_; + }; + + // If you wish to create multiple process groups, each with a potentially + // different rank and size, you can do so by passing a new store instance + // to each one. If you have only a single store object, you can + // use the `c10d::PrefixStore` to derive scoped instances. + // This is also what the Python API in torch.distributed does. + // + // The process group instance keeps a reference to the store because + // it may be used long after the constructor runs. In fact, the constructor + // doesn't create any NCCL communicators. A single NCCL communicator can + // only be used on a specific set of devices, and are therefore created + // on-demand when a collective runs. If another collective is executed later, + // against a different set of devices, the process group creates another NCCL + // communicator. These NCCL communicators are cached and reused if possible. + // + ProcessGroupNCCL( + c10::intrusive_ptr store, + int rank, + int size, + c10::intrusive_ptr options = Options::create()); + + // This constructor includes the deprecated `groupName` argument. + // If you have existing code that uses the `groupName`, you can replace + // it by specifying a `c10d::PrefixStore(groupName, store)` for store. + C10_DEPRECATED ProcessGroupNCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName, + c10::intrusive_ptr options = Options::create()) + : ProcessGroupNCCL(store, rank, size, std::move(options)) {} + + ~ProcessGroupNCCL() override; + + // This function returns a local uid for ProcessGroupNCCL. + uint64_t getUid() { + return static_cast(local_id_); + } + + c10::intrusive_ptr getOptions() { + return options_; + } + + const std::string getBackendName() const override { + return std::string(NCCL_BACKEND_NAME); + } + + bool supportsSplitting() const override { + return true; + } + + bool supportsCoalescing() const override { + return true; + } + + void startCoalescing() override; + + c10::intrusive_ptr endCoalescing() override; + + // For specifying a composite optype, such as ALLGATHER and REDUCE_SCATTER + c10::intrusive_ptr endCoalescing(OpType optype); + + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr _broadcast_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const BroadcastOptions& opts = BroadcastOptions()); + + c10::intrusive_ptr allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override; + + c10::intrusive_ptr _reduce_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const ReduceOptions& opts = ReduceOptions()); + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const 
AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + void groupStart(); + + void groupEnd(); + + void groupEndNonblocking(const std::shared_ptr& comm); + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override; + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override; + + // Unsupported Ops + c10::intrusive_ptr recvAnysource( + std::vector& tensors, + int tag) override; + + // Agrees on an initial sequence number for the whole group by having rank 0 + // create it and broadcast it to other ranks using the store. + void setSequenceNumberForGroup() override; + + // Retrieves the current sequence number for the whole group, which should be + // in sync. If the returned number is not consistent across the group, it + // may indicate that there is some sort of collective desynchronization. + uint64_t getSequenceNumberForGroup() override; + + // Return the total number of splits the communicators held by this process + // group have performed. Counts ncclCommCreateFromRanks() for ncclx v2.21.5+ + uint64_t getCommSplitCounter() const; + + void registerOnCompletionHook( + std::function)>&& hook) override; + void waitForPendingWorks() override; + + void enableCollectivesTiming() override; + + // Helper function for iteratively aborting communicators in the provided map + void abortCommsFromMap( + std::unordered_map>& ncclCommsMap, + const std::optional& abortReason); + + c10::intrusive_ptr initIntraNodeComm(); + + // Destroy (shutdown) this backend -- normal exit. + void shutdown() override; + + // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) + // instead of relying on ProcessGroupNCCL destructor. + void abort() override; + + void eagerConnectSingleDevice(at::Device device) override; + + void performNocolorSplit(at::Device device); + + // If all comms on this PG are fully initialized, return true. 
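+  // (For example, after eagerConnectSingleDevice() one might poll this to
+  //  confirm communicator initialization has finished before issuing the
+  //  first collective; illustrative, not a required pattern.)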
+ bool isInitialized(); + + ErrorType getError() override; + + std::shared_ptr getMemAllocator() override; + + // Allocate tensor from communication-optimized memory pool + at::Tensor allocateTensor(long size, at::TensorOptions options = {}) override; + + // Whether tensor allocation from NCCL memory pool is supported + bool supportsTensorAlloc(c10::DeviceIndex deviceIdx) override; + + // Performs NCCL user buffer registration for all buffers in + // the given MemPool + void registerMemPool(c10::cuda::MemPool* pool); + + // Performs NCCL user buffer de-registration for all buffers in + // the given MemPool + void deregisterMemPool(c10::cuda::MemPool* pool); + + // This method adds a temporary extension for the timeout period, + // applying to all collectives between the calling of this API and + // the completion of the first collective on the GPU. While this feature + // provides flexibility in specific scenarios, it introduces statefulness + // to timeout setting. Therefore, it is advisable to use this API sparingly + // and consider alternative approaches, such as directly setting the timeout + // or utilizing a barrier collective (one can set any timeout to the barrier), + // whenever feasible. + void addEphemeralTimeout(const std::chrono::milliseconds& timeout); + + // This function is only intended for testing purposes because we don't + // want to expose the `WorkNCCL` via pybind. It verifies whether the + // `opTimeout_` of the provided WorkNCCL instance is the same as the specified + // timeout. + bool verifyWorkTimeoutForTest( + const c10::intrusive_ptr& work, + const std::chrono::milliseconds& timeout); + + protected: + // Helper that broadcasts nccl unique ID to all ranks through the store + void broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + bool isSingleP2POp, + const std::string& devicesKey, + int p2pRank); + + // Helper that allgathers nccl unique IDs to all ranks through the store + void allgatherUniqueNCCLIDs( + int rootIdx, + ncclUniqueId* ncclID, + std::vector& ncclIDs); + + // Helper that looks up the cached NCCL communicators only + std::shared_ptr getNCCLComm(const std::string& deviceKey); + + std::shared_ptr initNCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); + + // Wrapper method which can be overridden for tests. + virtual std::exception_ptr checkForNCCLErrors( + std::shared_ptr& ncclComm); + + // Ensure thaht if record is True, the work obj will be enqueued via + // workEnqueue + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + bool isP2P, + const char* profilingTitle = nullptr, + const std::vector& inputs = {}, + const std::vector& outputs = {}, + bool record = false); + + // In the timeout case and we will dump debug info such as the NCCL flight + // recorder to storage. Down the road, if we have more complicated or blocking + // operations, we might need to use a side thread to do it. + bool dumpDebuggingInfo(bool includeStackTrace = true); + + // Abort all communicators on this rank. + bool abortComms(const std::optional& abortReason = std::nullopt); + + // A helper function to check if nonblocking API mode should be used. + // Use this helper instead of directly checking `useNonblocking_` variable. + bool useNonblocking(); + + private: + int globalRankStart; + int globalRankStride; + + // Helper that encapsulates work shared across all collective communication + // primitives. 
The callbacks have the following signatures: + // + // ncclResult_t fn(at::Tensor& input, at::Tensor& output, + // ncclComm_t, at::cuda::CUDAStream&); + // void {pre,post}(std::vector); + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collectiveCoalesced( + std::vector& input, + std::vector& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false); + + // Helper that encapsulates work shared across point-to-point communication + // primitives. It is the same structure as the helper used for collective + // communication primitives. + template + c10::intrusive_ptr pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle = nullptr); + + template + c10::intrusive_ptr pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + PreProcess pre, + PostProcess post, + const char* profilingTitle); + + c10::intrusive_ptr allreduce_impl( + at::Tensor& tensor, + const char* profilingTitle = "nccl:all_reduce", + const AllreduceOptions& opts = AllreduceOptions()); + + // Checks for NCCL errors on each of the communicators and returns an + // appropriate exception_ptr (nullptr if no errors). + static std::exception_ptr checkForNCCLErrorsInternal( + std::shared_ptr& ncclComm); + + // Function that runs as part of a separate thread and checks for errors on + // NCCL communicators. We need a separate thread to check for NCCL errors + // since we can't rely on the user calling certain methods like wait(), + // isCompleted() etc. to detect and remediate errors. In addition to this, we + // need a mechanism to safely abort and remove NCCL communicators from our + // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL + // class. Attempting to modify the communicator cache from the WorkNCCL class + // might run into issues with object lifetime since the ProcessGroupNCCL + // object might get destroyed before the WorkNCCL object. + void ncclCommWatchdog(); + + // Return the CUDA device most likely associated with this backend. + // If we aren't bound to a specific device, there is no strict + // guarantee that this heuristic is the correct assignment of ranks + // to GPUs that Python layers use, but in practice it tends to be. + // Fortunately we don't rely on this for correctness of any tensor + // operations, just for ancillary uses like barriers. + at::Device guessDeviceForRank() const; + + // Destroys initialized NCCL communicators in devNCCLComMap_ given by input + // key. Throws if there are no communicators to destroy. Also removes + // communicators from the cache and clears used device indices. + void destroyNCCLComms(const std::string& devNCCLCommMapKey); + + // Watchdog's inside loop. 
+
+  // Watchdog's inside loop.
+  // Takes care of cleaning up completed work, and aborting upon failure or
+  // timeout.
+  void watchdogHandler();
+
+  void runHookLoop();
+
+  // Generates a prefix that is unique to this process group and rank, for
+  // disambiguating logs
+  std::string createLogPrefix() const;
+
+  // Returns the unique prefix created in createLogPrefix
+  const std::string& logPrefix() const;
+
+  // Returns the global rank of the device. This function assumes that users
+  // always create a default global process group (PG) which includes all
+  // devices. It is called in the constructor of ProcessGroupNCCL, so it always
+  // returns the rank_ of the very first PG created, aka, the default global PG.
+  const int& globalRank() const;
+
+  // Returns the global ranks of a PG.
+  const std::vector<uint64_t>& groupRanks() const;
+
+  // Util function to assign a timeout to each work.
+  void assignTimeoutToWork(
+      const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
+      const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option);
+
+  // Broadcast flight-recorder dump signal
+  void broadcastDumpSignal();
+
+  // A helper function to broadcast a signal (key) from a src rank to all other
+  // ranks using the specified store.
+  void broadcastSignal(
+      c10::intrusive_ptr<Store>& store,
+      const std::string& signal,
+      int srcRank);
+
+  // A helper function to get the src rank of a signal from the Store. This is
+  // a nonblocking function, returning -1 if the signal is not available yet.
+  int getSignalSrcRank(
+      c10::intrusive_ptr<Store>& store,
+      const std::string& signal);
+
+ protected:
+  // Function that runs as part of a separate thread aside from watchdog
+  // thread because we need to check the heartbeat from watchdog thread
+  // so that when we get stuck in some NCCL/CUDA calls,
+  // we can dump the debugging information and abort the process.
+  virtual void heartbeatMonitor();
+
+  // Function that directly triggers std::abort so that the whole process
+  // gets terminated.
+  virtual void terminateProcess(const std::string& errMsg);
+
+  // A helper function to wait for a future to complete or timeout.
+  // Returns true if the future completes before timeout, false otherwise.
+  bool waitForFutureOrTimeout(
+      std::future<bool>& fut,
+      const std::chrono::milliseconds& timeOutMilSec,
+      const std::string& futDescription,
+      ::c10d::C10dLoggingData& debugLog,
+      bool throwException = false);
+
+  std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);
+
+  std::string getNCCLWatchdogTimeoutExitMsg(const std::string& exitReason);
+
+  void checkAndSetRemoteError();
+
+  // A helper function to guess the device id of the current rank, based on
+  // the bound device or the devices used so far. Do not use this function if
+  // you already know the device id to operate on.
+  c10::DeviceIndex guessDeviceId() const;
+
+  static const int64_t kWatchdogThreadSleepMillis;
+
+  // The store is used to broadcast the NCCL unique ID of rank 0. This store
+  // comes with a prefix and is different across ProcessGroupNCCL instances
+  // (aka, different ProcessGroups).
+  c10::intrusive_ptr<Store> store_;
+
+  // Reference to the store without prefix so that keys are the same across all
+  // ProcessGroupNCCL instances and (key, value) pairs written to the store are
+  // global.
+  c10::intrusive_ptr<Store> globalStore_;
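+
+  // Illustrative only: addEphemeralTimeout() (declared above) is exposed to
+  // Python as `_add_ephemeral_timeout` in csrc/init.cpp, so a caller can
+  // temporarily extend the collective timeout roughly like this (Python-side
+  // sketch; the process-group handle `pg` and tensor `t` are assumptions of
+  // the example, not names defined in this codebase):
+  //
+  //   pg._add_ephemeral_timeout(timedelta(seconds=30))
+  //   pg.allreduce(t).wait()  # this work gets its base timeout + 30s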
+
+  // The lock which protects the write/read of
+  // ephemeralTimeoutActive_/ephemeralTimeoutInflight_.
+  // TODO(fduwjj): We need to have an audit on all mutexes we are adding here.
+  // And consolidate them if possible.
+  std::mutex mtxTimeoutExtension_;
+
+  // The ephemeral timeout added on top of the existing timeout for works
+  // issued before the first work finishes.
+  std::chrono::milliseconds ephemeralTimeoutActive_ =
+      std::chrono::milliseconds(0);
+
+  // The ephemeral timeout addition which has already been applied to work.
+  std::chrono::milliseconds ephemeralTimeoutInflight_ =
+      std::chrono::milliseconds(0);
+
+  const c10::intrusive_ptr<Options> options_;
+
+  // The number of NCCL communicators that have been created during
+  // the lifetime of this process group. This sequence number is
+  // used to scope keys used in the store.
+  uint64_t ncclCommCounter_{0};
+
+  // The NCCL communicators that the process group has cached.
+  //
+  // For collective operations:
+  // The key is a list of GPU devices that an operation is operating on.
+  // The GPU devices are stored in a device sequence and the cached NCCL
+  // communicator is associated with this GPU device sequence.
+  //
+  // e.g. If the process group op only uses device 0, then the used device
+  //      string stored (the key of the hashmap) would be "0".
+  //
+  //      If the process group op uses devices 0 - 7 and each tensor of the
+  //      input tensor list is on devices 0, 1, 2, 3, 4, 5, 6, 7 respectively,
+  //      then the used device string (key) stored would be
+  //      "0,1,2,3,4,5,6,7".
+  //
+  //      If the process group op uses devices 0 - 7 and each tensor of the
+  //      input tensor list is on devices 0, 4, 5, 6, 7, 1, 2, 3 respectively,
+  //      then the used device string stored would be
+  //      "0,4,5,6,7,1,2,3".
+  //
+  //      Note that the order of the devices in the tensor list matters.
+  //
+  // For point-to-point operations:
+  // The key is a string of my current rank and the peer process rank.
+  // e.g. If process 1 and process 2 are involved in a point-to-point
+  // communication, the key will be "1:2" on both processes. Note: this is for
+  // the scenario where there is only 1 GPU per process. When it comes to
+  // multiple GPUs per process, this part may need to be redesigned.
+  // TODO: we probably need a separate map for P2P comms
+  std::unordered_map<std::string, std::shared_ptr<NCCLComm>> devNCCLCommMap_;
+
+  // The NCCL communicators currently in the process of being initialized.
+  std::unordered_map<std::string, std::shared_ptr<NCCLComm>>
+      inInitializationCommMap_;
+
+  // Mutex to guard maps like devNCCLCommMap_.
+  std::mutex mutex_;
+
+  // Heartbeat of watchdog thread.
+  std::atomic_uint64_t heartbeat_{};
+
+  // The time interval used for deciding whether there is no watchdog heartbeat.
+  int heartbeatTimeoutInSec_;
+
+  // Timeout for the dump to finish.
+  int waitTimeoutDumpInMilSec_;
+
+  // Promise to coordinate the flight recorder dump.
+  std::promise<bool> promiseFlightRecorderDump_;
+
+  // Interval for checking coordinated signals in ProcessGroupNCCL from other
+  // ranks, e.g., to trigger the dump of the debugging info for timeout when
+  // notified.
+  int coordCheckIntervalMilSec_;
+
+  // Size of the ring buffer where we store NCCL traces for debugging.
+  int traceBufferSize_;
+
+  // We gate the heartbeat monitor thread so that we can roll it out gradually.
+  std::atomic<bool> monitorThreadEnabled_{};
+
+  // We gate the cudaEventCache so that we can roll it out gradually.
+  std::atomic<bool> cudaEventCacheEnabled_{};
+
+  // Monitor thread which checks the heartbeat of the watchdog thread.
+  // If the monitor thread finds there is no heartbeat, it will dump debug info
+  // and then kill the watchdog thread to avoid a hang.
+  std::thread ncclHeartbeatMonitorThread_;
+
+  // Watchdog thread which looks for errors on the cached NCCL communicators.
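+  // Roughly, and for illustration only (see watchdogHandler() for the real
+  // logic; this is pseudo-code, not the actual implementation):
+  //
+  //   while (!terminateProcessGroup_) {
+  //     heartbeat_++;                       // observed by heartbeatMonitor()
+  //     check cached communicators for async NCCL errors;
+  //     clean up completed work from workMetaList_;
+  //     abort / throw on failure or timeout, per asyncErrorHandling_;
+  //     sleep for kWatchdogThreadSleepMillis or until workMetaListCV_ fires;
+  //   }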
+  std::thread ncclCommWatchdogThread_;
+
+  std::thread onCompletionHookThread_;
+
+  // Whether or not we should terminate the watchdog and workCleanup threads.
+  std::atomic<bool> terminateProcessGroup_;
+
+  // Whether or not we should terminate the heartbeat monitoring threads.
+  std::atomic<bool> terminateHeartbeatMonitorThread_;
+
+  // Whether there are hooks pending to be fired
+  std::atomic<bool> hasPendingHooks_{};
+
+  // This is the signal from watchdog threads to indicate whether the monitor
+  // thread should dump. Making it static so that it is accessible from all the
+  // PGs. With this flag, the monitor thread will dump debug info under any one
+  // of the three conditions:
+  //
+  // 1: watchdog thread of any PG detects a collective timeout.
+  // 2: timeout signal is received from other ranks through tcpstore.
+  // 3: current PG's watchdog heartbeat timeout occurs.
+  //
+  // Note that only the monitor thread from PG0 will dump the debug info for
+  // case one and two so that the debug info is only dumped once.
+  static std::atomic<bool> shouldDump_;
+
+  // Mutex to guard workMetaList_
+  std::mutex workMetaListMutex_;
+
+  // Mutex to guard monitorWakeUpCV_
+  std::mutex monitorMutex_;
+
+  bool writeDebugInfo_ = false;
+
+  // Condition variable for watchdog thread sleep
+  std::condition_variable workMetaListCV_;
+
+  // Condition variable for monitor thread to wake up early
+  std::condition_variable monitorWakeUpCV_;
+
+  // List to store in-progress WorkNCCL objects
+  std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
+
+  std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
+
+  // Mutex to guard completedWorkList_
+  std::mutex completedWorkListMutex_;
+
+  // Condition variable for watchdog thread sleep
+  std::condition_variable completedWorkListCV_;
+
+  std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;
+
+  // Add a WorkNCCL object to the work list
+  void workEnqueue(const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>&);
+
+  // The CUDA streams used by NCCL kernels
+  std::unordered_map<std::string, at::cuda::CUDAStream> ncclStreams_;
+
+  // The CUDA events used to sync NCCL streams
+  std::unordered_map<std::string, at::cuda::CUDAEvent> ncclEvents_;
+
+  // Device indexes used for all collectives in this group
+  std::set<int> usedDeviceIdxs_;
+
+  // Flag to denote if a coalescing groupStart/groupEnd block is active
+  int coalescing_state_ = 0;
+
+  // Stores device indexes for all collectives run inside a coalescing block
+  at::Device coalescedDevice_ = at::Device("cuda");
+
+  // Stores communicators for all collectives run inside a coalescing block
+  std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;
+
+  // Whether or not wait() and synchronize() are blocking operations that wait
+  // for the operation to complete.
+  bool blockingWait_ = false;
+
+  // Whether or not to hook the cache allocator to register all allocated
+  // tensors
+  bool useTensorRegisterAllocatorHook_ = false;
+
+  // Whether or not the workCleanupThread is used to perform async error
+  // handling.
+  ErrorHandlingMode asyncErrorHandling_ = NoHandling;
+
+  ErrorType error_ = ErrorType::SUCCESS;
+
+  std::mutex errorMutex_;
+
+  // Whether or not to enable timeout root cause analysis.
+  bool desyncDebug_;
+  DesyncDebugger desyncDebugger_;
+
+  // Whether or not to dump debug info on exception, including both watchdog
+  // timeouts and NCCL errors.
+  bool dumpOnTimeoutOrEx_;
+
+  // Whether or not to propagate detected errors to all ranks in the same PG
+  // through TCPStore.
+  bool propagatePgError_;
+
+  // Whether or not to sleep after an exception is thrown in the watchdog.
+  bool sleepAfterException_{};
+
+  // Whether or not to enable nan check for input tensors to collectives.
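+  // (Illustrative) In upstream PyTorch this flag is read from the
+  // TORCH_NCCL_NAN_CHECK environment variable; assuming this port keeps that
+  // behavior, it can be enabled for a run with, e.g.:
+  //
+  //   TORCH_NCCL_NAN_CHECK=1 torchrun train.py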
+  bool enableNanCheck_;
+
+  // Whether or not to print C++ stack traces to logs on unclean shutdown.
+  bool logCppStackOnUncleanShutdown_;
+
+  // Whether or not to create start CUDAEvent and enable timing for start
+  // and end events. Note that enableTiming_ is always true if desyncDebug_
+  // is set to true.
+  std::atomic<bool> enableTiming_{};
+
+  // Flag to enable printing the hash value of the input/output of collectives,
+  // for verification.
+  std::atomic<bool> enableCollecticeHashDebug_{};
+
+  // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set
+  bool avoidRecordStreams_ = false;
+
+  // Whether the NCCL watchdog should rethrow CUDA errors.
+  bool rethrowCUDAErrors_ = false;
+
+  // The number of active ncclGroupStart() calls. This counter will be increased
+  // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
+  // is called.
+  static thread_local uint64_t ncclActiveGroupCounter_;
+
+  // Counter for the sequential number of NCCL collective calls.
+  // (specifically, how many actual kernels we launched, which differs from
+  // op_id_ when coalescing is enabled)
+  uint64_t seqCollective_{0};
+
+  // Counter for the sequential number of NCCL P2P calls.
+  uint64_t seqP2P_{0};
+
+  // Incrementing counter for logical operations (collective or p2p) issued on
+  // the ProcessGroup
+  uint64_t op_id_{0};
+
+  std::exception_ptr watchDogException_ = nullptr;
+
+  // The number of ProcessGroupNCCL instances created on the current rank.
+  size_t local_id_;
+
+  std::string logPrefix_;
+
+  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
+
+  // Number of devices on this node.
+  int localDeviceCount_{0};
+
+  std::shared_ptr<ProcessGroupStatus> pgStatus_ =
+      std::make_shared<ProcessGroupStatus>();
+
+  // Internal cached value: use NCCL non-blocking API mode or not.
+  // Use the `useNonblocking()` method instead of accessing this variable
+  // directly.
+  std::optional<bool> useNonblocking_{std::nullopt};
+
+  // Communication-optimized memory pool associated with this PG
+  std::unique_ptr<c10::cuda::MemPool> memPool_ = nullptr;
+};
+
+// Dumps the NCCL comm traces and additional information about the Process
+// Group.
+TORCH_API std::string dump_nccl_trace(
+    bool includeCollectives,
+    bool includeStackTraces,
+    bool onlyActive);
+
+// Dumps the NCCL comm traces and additional information about the Process
+// Group in a JSON formatted string.
+// We don't include stack traces in JSON format as it is far too much data.
+TORCH_API std::string dump_nccl_trace_json(
+    bool includeCollectives,
+    bool onlyActive);
+
+// Gets a mutable reference to a global optional function. The heartbeat
+// monitor will use this function to dump traces, if available. Inside fbcode,
+// we store a function here that uses an internal tool for process tracing
+TORCH_API std::optional<
+    std::function<void(std::function<void(const std::string&)>)>>&
+get_cpp_trace_dumper();
+
+// Similar to get_cpp_trace_dumper, this stores a function defined in the
+// torch-python layer that lets us check whether the GIL can be acquired,
+// helpful for instrumenting in cases where a hang was observed.
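+// Illustrative only: since gil_checker_t (below) is a plain function pointer
+// type, the Python binding layer can install a checker with a captureless
+// lambda, e.g.
+//
+//   get_gil_checker() = []() -> bool { return PyGILState_Check() == 1; };
+//
+// How and where torchft installs such a checker is outside this header.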
+typedef bool (*gil_checker_t)();
+
+TORCH_API gil_checker_t& get_gil_checker();
+} // namespace torchft
+
+#endif // USE_C10D_NCCL
diff --git a/csrc/cuda_utils.cpp b/csrc/cuda_utils.cpp
new file mode 100644
index 0000000..7884be5
--- /dev/null
+++ b/csrc/cuda_utils.cpp
@@ -0,0 +1,34 @@
+#include 
+
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+#include 
+#endif
+
+#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
+#define CUDART_SUPPORTS_MULTICAST
+#endif
+
+namespace c10d::cuda {
+
+bool deviceSupportsMulticast(int device_idx) {
+#if defined(CUDART_SUPPORTS_MULTICAST)
+  // Multicast support requirements:
+  // - CUDA Runtime version >= 12030: Checked at compile time using
+  //   CUDART_VERSION.
+  // - Driver version >= 535: Checked at runtime by verifying the existence of
+  //   cuMulticastCreate_.
+  // - Device support: Determined by querying
+  //   CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
+  auto driver_api = c10::cuda::DriverAPI::get();
+  int multicast_supported;
+  C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
+      &multicast_supported,
+      CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
+      device_idx));
+  return driver_api->cuMulticastCreate_ != nullptr && multicast_supported;
+#else
+  return false;
+#endif
+}
+
+} // namespace c10d::cuda
diff --git a/csrc/init.cpp b/csrc/init.cpp
new file mode 100644
index 0000000..56abb14
--- /dev/null
+++ b/csrc/init.cpp
@@ -0,0 +1,171 @@
+#include 
+#include 
+
+#include 
+#include "ProcessGroupNCCL.hpp"
+
+namespace py = pybind11;
+
+namespace {
+// Wrapper to ensure the GIL is released before destructing the wrapped
+// process group
+// TODO: move this somewhere more generally useful
+template <typename T>
+class IntrusivePtrNoGilDestructor {
+  c10::intrusive_ptr<T> impl_{};
+
+ public:
+  IntrusivePtrNoGilDestructor() = default;
+  IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default;
+  IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) noexcept = default;
+  IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) =
+      default;
+  IntrusivePtrNoGilDestructor& operator=(
+      IntrusivePtrNoGilDestructor&&) noexcept = default;
+  /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr<T> impl)
+      : impl_(std::move(impl)) {}
+  // This ctor is very important; see
+  // https://github.com/pybind/pybind11/issues/2957
+  explicit IntrusivePtrNoGilDestructor(T* impl)
+      // NOLINTNEXTLINE(bugprone-exception-escape)
+      : impl_(c10::intrusive_ptr<T>::unsafe_steal_from_new(impl)) {}
+  // NOLINTNEXTLINE(bugprone-exception-escape)
+  ~IntrusivePtrNoGilDestructor() {
+    if (impl_) {
+      if (PyGILState_Check()) {
+        pybind11::gil_scoped_release release;
+        impl_.reset();
+      } else {
+        impl_.reset();
+      }
+    }
+  }
+  T& operator*() const noexcept {
+    return *impl_;
+  }
+  T* operator->() const noexcept {
+    return impl_.get();
+  }
+  [[nodiscard]] T* get() const noexcept {
+    return impl_.get();
+  }
+  void reset() noexcept {
+    impl_.reset();
+  }
+  operator bool() const noexcept {
+    return impl_;
+  }
+};
+
+} // anonymous namespace
+
+PYBIND11_DECLARE_HOLDER_TYPE(T, IntrusivePtrNoGilDestructor, true)
+
+template <typename T>
+using intrusive_ptr_no_gil_destructor_class_ =
+    py::class_<T, IntrusivePtrNoGilDestructor<T>>;
+
+PYBIND11_MODULE(_torchft_cpp, module) {
+  py::object backend =
+      (py::object)py::module_::import("torch._C._distributed_c10d")
+          .attr("Backend");
+
+  auto processGroupNCCL =
+      intrusive_ptr_no_gil_destructor_class_<::torchft::ProcessGroupNCCL>(
+          module, "ProcessGroupNCCL", backend)
+          .def(
+              py::init(
+                  [](const c10::intrusive_ptr<::c10d::Store>& store,
+                     int rank,
+                     int size,
+                     c10::intrusive_ptr<::torchft::ProcessGroupNCCL::Options>
+                         options) {
+                    // gil_scoped_release is not safe as a call_guard in init.
+                    // https://github.com/pybind/pybind11/issues/5473
+                    py::gil_scoped_release nogil{};
+
+                    return c10::make_intrusive<::torchft::ProcessGroupNCCL>(
+                        store, rank, size, std::move(options));
+                  }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::arg("options"),
+              R"(Create a new ProcessGroupNCCL instance.)")
+          .def(
+              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size,
+                          const std::chrono::milliseconds& timeout) {
+                // gil_scoped_release is not safe as a call_guard in init.
+                // https://github.com/pybind/pybind11/issues/5473
+                py::gil_scoped_release nogil{};
+
+                auto options = ::torchft::ProcessGroupNCCL::Options::create();
+                options->is_high_priority_stream = false;
+                options->timeout = timeout;
+                return c10::make_intrusive<::torchft::ProcessGroupNCCL>(
+                    store, rank, size, options);
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::arg("timeout") = ::torchft::kProcessGroupNCCLDefaultTimeout,
+              R"(Create a new ProcessGroupNCCL instance.)")
+          .def("_group_start", &::torchft::ProcessGroupNCCL::groupStart)
+          .def("_group_end", &::torchft::ProcessGroupNCCL::groupEnd)
+          .def(
+              "comm_split_count",
+              &::torchft::ProcessGroupNCCL::getCommSplitCounter)
+          .def(
+              "_set_default_timeout",
+              [](const c10::intrusive_ptr<::torchft::ProcessGroupNCCL>& self,
+                 std::chrono::milliseconds timeout) {
+                self->getOptions()->timeout = timeout;
+              },
+              py::arg("timeout"),
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "_add_ephemeral_timeout",
+              [](const c10::intrusive_ptr<::torchft::ProcessGroupNCCL>& self,
+                 const std::chrono::milliseconds& timeout) {
+                self->addEphemeralTimeout(timeout);
+              },
+              py::arg("timeout"))
+          .def(
+              "_verify_work_timeout",
+              [](const c10::intrusive_ptr<::torchft::ProcessGroupNCCL>& self,
+                 const c10::intrusive_ptr<::c10d::Work>& work,
+                 const std::chrono::milliseconds& timeout) {
+                return self->verifyWorkTimeoutForTest(work, timeout);
+              },
+              py::arg("work"),
+              py::arg("timeout"))
+          .def_property_readonly(
+              "options",
+              &::torchft::ProcessGroupNCCL::getOptions,
+              R"(Return the options used to create this ProcessGroupNCCL instance.)")
+          .def_property_readonly(
+              "uid", &::torchft::ProcessGroupNCCL::getUid, R"(Return the uid.)")
+          .def_property(
+              "bound_device_id",
+              &::torchft::ProcessGroupNCCL::getBoundDeviceId,
+              &::torchft::ProcessGroupNCCL::setBoundDeviceId,
+              R"(Return the bound device id.)")
+          .def(
+              "perform_nocolor_split",
+              &::torchft::ProcessGroupNCCL::performNocolorSplit)
+          .def(
+              "register_mem_pool",
+              &::torchft::ProcessGroupNCCL::registerMemPool)
+          .def(
+              "deregister_mem_pool",
+              &::torchft::ProcessGroupNCCL::deregisterMemPool)
+          .def(
+              "_is_initialized",
+              &::torchft::ProcessGroupNCCL::isInitialized,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "get_error",
+              &::torchft::ProcessGroupNCCL::getError,
+              py::call_guard<py::gil_scoped_release>());
+}
diff --git a/third_party/nccl b/third_party/nccl
new file mode 160000
index 0000000..f44ac75
--- /dev/null
+++ b/third_party/nccl
@@ -0,0 +1 @@
+Subproject commit f44ac759fee12ecb3cc6891e9e739a000f66fd70
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 0b7507d..1197f62 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -23,35 +23,34 @@ from datetime import timedelta
 from multiprocessing.connection import Connection
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
+    cast,
     Dict,
     Generator,
     List,
     Optional,
     Tuple,
+    TYPE_CHECKING,
TypeVar, Union, - cast, ) import torch import torch.distributed as dist import torch.multiprocessing as mp +from torch._C._distributed_c10d import PrefixStore # pyre-fixme[21]: no attribute ProcessGroupNCCL # pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( DeviceMesh, - PrefixStore, + get_rank, + init_device_mesh, ProcessGroup as BaseProcessGroup, ProcessGroupGloo as BaseProcessGroupGloo, - ProcessGroupNCCL as BaseProcessGroupNCCL, Store, TCPStore, - get_rank, - init_device_mesh, ) from torch.distributed.distributed_c10d import ( AllgatherOptions, @@ -67,6 +66,7 @@ from torch.futures import Future from torch.utils._pytree import tree_any +from torchft._torchft_cpp import ProcessGroupNCCL as BaseProcessGroupNCCL from torchft.multiprocessing import _MonitoredPipe if TYPE_CHECKING: