diff --git a/include/aluminum/base.hpp b/include/aluminum/base.hpp
index b072abc3..7bfd144c 100644
--- a/include/aluminum/base.hpp
+++ b/include/aluminum/base.hpp
@@ -70,7 +70,7 @@ class al_exception : public std::exception {
 
 /** Predefined reduction operations. */
 enum class ReductionOperator {
-  sum, prod, min, max, lor, land, lxor, bor, band, bxor
+  sum, prod, min, max, lor, land, lxor, bor, band, bxor, avg
 };
 
 } // namespace Al
diff --git a/include/aluminum/mpi_comm_and_stream_wrapper.hpp b/include/aluminum/mpi_comm_and_stream_wrapper.hpp
index 2b60b3bd..3bc8e948 100644
--- a/include/aluminum/mpi_comm_and_stream_wrapper.hpp
+++ b/include/aluminum/mpi_comm_and_stream_wrapper.hpp
@@ -56,6 +56,10 @@ class MPICommAndStreamWrapper {
     MPI_Comm_rank(local_comm, &rank_in_local_comm);
     MPI_Comm_size(local_comm, &size_of_local_comm);
   }
+  /** Construct without MPI, from an externally provided rank and size. */
+  MPICommAndStreamWrapper(int rank_, int size_, Stream stream_) :
+    stream(stream_), rank_in_comm(rank_), size_of_comm(size_), mpi_disabled(true) {
+  }
   /** Cannot copy this. */
   MPICommAndStreamWrapper(const MPICommAndStreamWrapper& other) = delete;
   /** Default move constructor. */
@@ -67,11 +71,13 @@ class MPICommAndStreamWrapper {
 
   /** Destroy the underlying MPI_Comm. */
   ~MPICommAndStreamWrapper() {
-    int finalized;
-    MPI_Finalized(&finalized);
-    if (!finalized) {
-      MPI_Comm_free(&comm);
-      MPI_Comm_free(&local_comm);
+    if (!mpi_disabled) {
+      int finalized;
+      MPI_Finalized(&finalized);
+      if (!finalized) {
+        MPI_Comm_free(&comm);
+        MPI_Comm_free(&local_comm);
+      }
     }
   }
 
@@ -113,6 +119,8 @@ class MPICommAndStreamWrapper {
   int rank_in_local_comm;
   /** Size of the local communicator. */
  int size_of_local_comm;
+  /** Whether MPI is disabled for this wrapper. */
+  bool mpi_disabled = false;
 };
 
 } // namespace internal
diff --git a/include/aluminum/nccl_impl.hpp b/include/aluminum/nccl_impl.hpp
index 4dad3e51..450ff371 100644
--- a/include/aluminum/nccl_impl.hpp
+++ b/include/aluminum/nccl_impl.hpp
@@ -93,6 +93,8 @@ class NCCLCommunicator : public internal::MPICommAndStreamWrapper
   NCCLCommunicator() : NCCLCommunicator(MPI_COMM_WORLD, 0) {}
   /** Use a particular MPI communicator and stream. */
   NCCLCommunicator(MPI_Comm comm_, cudaStream_t stream_ = 0);
+  /** Construct without MPI, from an externally provided rank, size, and NCCL unique ID. */
+  NCCLCommunicator(int rank_, int size_, ncclUniqueId nccl_id_, cudaStream_t stream_ = 0);
   /** Cannot copy this. */
   NCCLCommunicator(const NCCLCommunicator& other) = delete;
   /** Default move constructor. */
@@ -108,6 +110,11 @@ class NCCLCommunicator : public internal::MPICommAndStreamWrapper
     return NCCLCommunicator(get_comm(), stream);
   }
 
+  /** Abort any uncompleted NCCL operations on this communicator. */
+  void abort() {
+    AL_CHECK_NCCL(ncclCommAbort(m_nccl_comm));
+  }
+
 private:
   /** Raw NCCL communicator. */
   ncclComm_t m_nccl_comm;
@@ -132,6 +139,8 @@ inline ncclRedOp_t ReductionOperator2ncclRedOp(ReductionOperator op) {
     return ncclMin;
   case ReductionOperator::max:
     return ncclMax;
+  case ReductionOperator::avg:
+    return ncclAvg;
   default:
     throw_al_exception("Reduction operator not supported");
   }
diff --git a/src/nccl_impl.cpp b/src/nccl_impl.cpp
index 7895377a..5dbf4d5c 100644
--- a/src/nccl_impl.cpp
+++ b/src/nccl_impl.cpp
@@ -44,6 +44,12 @@ NCCLCommunicator::NCCLCommunicator(MPI_Comm comm_, cudaStream_t stream_) :
   AL_CHECK_NCCL(ncclCommInitRank(&m_nccl_comm, size(), nccl_id, rank()));
 }
 
+NCCLCommunicator::NCCLCommunicator(int rank_, int size_, ncclUniqueId nccl_id_, cudaStream_t stream_) :
+  MPICommAndStreamWrapper(rank_, size_, stream_) {
+  // This uses the current CUDA device.
+  AL_CHECK_NCCL(ncclCommInitRank(&m_nccl_comm, size_, nccl_id_, rank_));
+}
+
 NCCLCommunicator::~NCCLCommunicator() {
   int d;
   // Only destroy resources if the driver is still loaded.
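For context, a minimal usage sketch of the MPI-free path introduced by this patch (not part of the diff): the NCCL unique ID must be created on one rank and distributed through some out-of-band channel, since MPI is no longer available for the exchange. The exchange_unique_id() helper and the include lines below are illustrative assumptions, not Aluminum API.

#include <cuda_runtime.h>
#include <nccl.h>
#include <Al.hpp>  // assumed entry point; provides Al::NCCLCommunicator when the NCCL backend is enabled

// Hypothetical out-of-band broadcast of rank 0's ID (e.g. a TCP store or the job launcher).
void exchange_unique_id(ncclUniqueId* id, int my_rank);

void setup(int my_rank, int world_size, cudaStream_t stream) {
  ncclUniqueId id;
  if (my_rank == 0) {
    ncclGetUniqueId(&id);  // rank 0 generates the ID
  }
  exchange_unique_id(&id, my_rank);  // all ranks end up with rank 0's ID

  // New constructor from this patch; it uses the current CUDA device,
  // so select the device (cudaSetDevice) before this point.
  Al::NCCLCommunicator comm(my_rank, world_size, id, stream);

  // ... run collectives on comm ...

  // On an unrecoverable error, in-flight NCCL work can be torn down with
  // the new abort():
  // comm.abort();
}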