Skip to content

Commit

Permalink
Add safe methods set_pointer_mode and get_pointer_mode to `CudaBl…
Browse files Browse the repository at this point in the history
…as`, with test
  • Loading branch information
MathisWellmann committed Sep 5, 2024
1 parent 7956461 commit b1ef34e
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/cublas/safe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,37 @@ impl CudaBlas {
None => result::set_stream(self.handle, self.device.stream as *mut _),
}
}

/// Set the handle's pointer mode.
/// ref: <https://docs.nvidia.com/cuda/cublas/#cublassetpointermode>
///
/// Some cublas functions require the pointer mode to be set to `cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE`
/// when passing a device memory result buffer into the function, such as `cublas<t>asum()`.
/// Otherwise the operation will panic with `SIGSEGV: invalid memory reference`,
/// or one has to use a host memory reference, which has performance implications.
pub fn set_pointer_mode(
&self,
pointer_mode: sys::cublasPointerMode_t,
) -> Result<(), CublasError> {
unsafe {
sys::lib()
.cublasSetPointerMode_v2(self.handle, pointer_mode)
.result()?;
}
Ok(())
}

/// Get the handle's current pointer mode.
/// ref: <https://docs.nvidia.com/cuda/cublas/#cublasgetpointermode>
pub fn get_pointer_mode(&self) -> Result<sys::cublasPointerMode_t, CublasError> {
unsafe {
let mut mode = ::core::mem::MaybeUninit::uninit();
sys::lib()
.cublasGetPointerMode_v2(self.handle, mode.as_mut_ptr())
.result()?;
return Ok(mode.assume_init());

Check failure on line 81 in src/cublas/safe.rs

View workflow job for this annotation

GitHub Actions / clippy

unneeded `return` statement
}
}
}

impl Drop for CudaBlas {
Expand Down Expand Up @@ -865,4 +896,23 @@ mod tests {
}
}
}

#[test]
fn cublas_pointer_mode() {
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlas::new(dev.clone()).unwrap();
assert_eq!(
blas.get_pointer_mode().unwrap(),
sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST,
"The default pointer mode uses host pointers"
);

blas.set_pointer_mode(sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE)
.unwrap();
assert_eq!(
blas.get_pointer_mode().unwrap(),
sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
"We have set the mode to use device pointers"
);
}
}

0 comments on commit b1ef34e

Please sign in to comment.