From 9d5906bc81af77ae4f1d820eaa82ed62d53eb647 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 6 Sep 2024 14:23:54 -0400 Subject: [PATCH 1/2] #290 Fixing Option usage in nccl::Comm::broadcast/reduce --- src/nccl/safe.rs | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/nccl/safe.rs b/src/nccl/safe.rs index 596c075..0c46bdf 100644 --- a/src/nccl/safe.rs +++ b/src/nccl/safe.rs @@ -1,5 +1,6 @@ use super::{result, sys}; use crate::driver::{CudaDevice, DevicePtr, DevicePtrMut}; +use std::io::BufRead; use std::mem::MaybeUninit; use std::ptr; use std::{sync::Arc, vec, vec::Vec}; @@ -239,17 +240,23 @@ impl Comm { } } + /// Broadcasts a value from `root` rank to every other ranks `recvbuff`. + /// sendbuff is ignored on ranks other than `root`, so you can pass `None` + /// on non-root ranks. + /// + /// sendbuff must be Some on root rank! pub fn broadcast, R: DevicePtrMut, T: NcclType>( &self, - sendbuff: &Option, + sendbuff: Option<&S>, recvbuff: &mut R, root: i32, ) -> Result { + debug_assert!(sendbuff.is_some() || self.rank != root as usize); + let send_ptr = match sendbuff { + Some(buffer) => *buffer.device_ptr() as *mut _, + None => ptr::null(), + }; unsafe { - let send_ptr = match sendbuff { - Some(buffer) => *buffer.device_ptr() as *mut _, - None => ptr::null(), - }; result::broadcast( send_ptr, *recvbuff.device_ptr_mut() as *mut _, @@ -298,17 +305,26 @@ impl Comm { } } + /// Reduces the sendbuff from all ranks into the recvbuff on the + /// `root` rank. + /// + /// recvbuff must be Some on the root rank! pub fn reduce, R: DevicePtrMut, T: NcclType>( &self, sendbuff: &S, - recvbuff: &mut R, + recvbuff: Option<&mut R>, reduce_op: &ReduceOp, root: i32, ) -> Result { + debug_assert!(recvbuff.is_some() || self.rank != root as usize); + let recv_ptr = match recvbuff { + Some(buffer) => *buffer.device_ptr_mut() as *mut _, + None => ptr::null_mut(), + }; unsafe { result::reduce( *sendbuff.device_ptr() as *mut _, - *recvbuff.device_ptr_mut() as *mut _, + recv_ptr, sendbuff.len(), T::as_nccl_type(), convert_to_nccl_reduce_op(reduce_op), From e4f22dc6a50617f2f4e946926c431f43a3845d21 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 6 Sep 2024 14:25:20 -0400 Subject: [PATCH 2/2] Remove extra add --- src/nccl/safe.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nccl/safe.rs b/src/nccl/safe.rs index 0c46bdf..831611d 100644 --- a/src/nccl/safe.rs +++ b/src/nccl/safe.rs @@ -1,6 +1,5 @@ use super::{result, sys}; use crate::driver::{CudaDevice, DevicePtr, DevicePtrMut}; -use std::io::BufRead; use std::mem::MaybeUninit; use std::ptr; use std::{sync::Arc, vec, vec::Vec};