From 76cd73f520766ce42dd5ad69ca75c9ee7bc1d9de Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Mon, 20 Jan 2025 22:10:34 -0500 Subject: [PATCH] Updating curand api to use device ptr to accept views --- src/curand/safe.rs | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/src/curand/safe.rs b/src/curand/safe.rs index db680288..3af362c9 100644 --- a/src/curand/safe.rs +++ b/src/curand/safe.rs @@ -1,7 +1,7 @@ //! Safe abstractions around [crate::curand::result] with [CudaRng]. use super::{result, sys}; -use crate::driver::{CudaDevice, CudaSlice, DeviceSlice}; +use crate::driver::{CudaDevice, DevicePtrMut}; use std::sync::Arc; /// Host side RNG that can fill [CudaSlice] with random values. @@ -51,30 +51,41 @@ impl CudaRng { } /// Fill the [CudaSlice] with data from a `Uniform` distribution - pub fn fill_with_uniform(&self, t: &mut CudaSlice) -> Result<(), result::CurandError> + pub fn fill_with_uniform>( + &self, + dst: &mut Dst, + ) -> Result<(), result::CurandError> where sys::curandGenerator_t: result::UniformFill, { - unsafe { result::UniformFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len()) } + unsafe { result::UniformFill::fill(self.gen, *dst.device_ptr_mut() as *mut T, dst.len()) } } /// Fill the [CudaSlice] with data from a `Normal(mean, std)` distribution. - pub fn fill_with_normal( + pub fn fill_with_normal>( &self, - t: &mut CudaSlice, + dst: &mut Dst, mean: T, std: T, ) -> Result<(), result::CurandError> where sys::curandGenerator_t: result::NormalFill, { - unsafe { result::NormalFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len(), mean, std) } + unsafe { + result::NormalFill::fill( + self.gen, + *dst.device_ptr_mut() as *mut T, + dst.len(), + mean, + std, + ) + } } /// Fill the `CudaRc` with data from a `LogNormal(mean, std)` distribution. - pub fn fill_with_log_normal( + pub fn fill_with_log_normal>( &self, - t: &mut CudaSlice, + dst: &mut Dst, mean: T, std: T, ) -> Result<(), result::CurandError> @@ -82,7 +93,13 @@ impl CudaRng { sys::curandGenerator_t: result::LogNormalFill, { unsafe { - result::LogNormalFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len(), mean, std) + result::LogNormalFill::fill( + self.gen, + *dst.device_ptr_mut() as *mut T, + dst.len(), + mean, + std, + ) } } }