Skip to content

Commit

Permalink
#290 Fixing Option usage in nccl::Comm::broadcast/reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Sep 6, 2024
1 parent 7956461 commit 9d5906b
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/nccl/safe.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{result, sys};
use crate::driver::{CudaDevice, DevicePtr, DevicePtrMut};
use std::io::BufRead;

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11060)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11060)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11050)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11050)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11080)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11080)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11040)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11070)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11070)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-11040)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12010)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12010)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12040)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12040)

unused import: `std::io::BufRead`

Check failure on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / clippy

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12000)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12000)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12050)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12050)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12060)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12060)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12020)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12020)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12030)

unused import: `std::io::BufRead`

Check warning on line 3 in src/nccl/safe.rs

View workflow job for this annotation

GitHub Actions / cargo-check (cuda-12030)

unused import: `std::io::BufRead`
use std::mem::MaybeUninit;
use std::ptr;
use std::{sync::Arc, vec, vec::Vec};
Expand Down Expand Up @@ -239,17 +240,23 @@ impl Comm {
}
}

/// Broadcasts a value from `root` rank to every other ranks `recvbuff`.
/// sendbuff is ignored on ranks other than `root`, so you can pass `None`
/// on non-root ranks.
///
/// sendbuff must be Some on root rank!
pub fn broadcast<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
&self,
sendbuff: &Option<S>,
sendbuff: Option<&S>,
recvbuff: &mut R,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
debug_assert!(sendbuff.is_some() || self.rank != root as usize);
let send_ptr = match sendbuff {
Some(buffer) => *buffer.device_ptr() as *mut _,
None => ptr::null(),
};
unsafe {
let send_ptr = match sendbuff {
Some(buffer) => *buffer.device_ptr() as *mut _,
None => ptr::null(),
};
result::broadcast(
send_ptr,
*recvbuff.device_ptr_mut() as *mut _,
Expand Down Expand Up @@ -298,17 +305,26 @@ impl Comm {
}
}

/// Reduces the sendbuff from all ranks into the recvbuff on the
/// `root` rank.
///
/// recvbuff must be Some on the root rank!
pub fn reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
&self,
sendbuff: &S,
recvbuff: &mut R,
recvbuff: Option<&mut R>,
reduce_op: &ReduceOp,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
debug_assert!(recvbuff.is_some() || self.rank != root as usize);
let recv_ptr = match recvbuff {
Some(buffer) => *buffer.device_ptr_mut() as *mut _,
None => ptr::null_mut(),
};
unsafe {
result::reduce(
*sendbuff.device_ptr() as *mut _,
*recvbuff.device_ptr_mut() as *mut _,
recv_ptr,
sendbuff.len(),
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
Expand Down

0 comments on commit 9d5906b

Please sign in to comment.