Skip to content

Commit

Permalink
Updating curand api to use device ptr to accept views (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Jan 21, 2025
1 parent 8d7e45c commit 2470079
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions src/curand/safe.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Safe abstractions around [crate::curand::result] with [CudaRng].
use super::{result, sys};
use crate::driver::{CudaDevice, CudaSlice, DeviceSlice};
use crate::driver::{CudaDevice, DevicePtrMut};
use std::sync::Arc;

/// Host side RNG that can fill [CudaSlice] with random values.
Expand Down Expand Up @@ -51,38 +51,55 @@ impl CudaRng {
}

/// Fill the [CudaSlice] with data from a `Uniform` distribution
pub fn fill_with_uniform<T>(&self, t: &mut CudaSlice<T>) -> Result<(), result::CurandError>
pub fn fill_with_uniform<T, Dst: DevicePtrMut<T>>(
&self,
dst: &mut Dst,
) -> Result<(), result::CurandError>
where
sys::curandGenerator_t: result::UniformFill<T>,
{
unsafe { result::UniformFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len()) }
unsafe { result::UniformFill::fill(self.gen, *dst.device_ptr_mut() as *mut T, dst.len()) }
}

/// Fill the [CudaSlice] with data from a `Normal(mean, std)` distribution.
pub fn fill_with_normal<T>(
pub fn fill_with_normal<T, Dst: DevicePtrMut<T>>(
&self,
t: &mut CudaSlice<T>,
dst: &mut Dst,
mean: T,
std: T,
) -> Result<(), result::CurandError>
where
sys::curandGenerator_t: result::NormalFill<T>,
{
unsafe { result::NormalFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len(), mean, std) }
unsafe {
result::NormalFill::fill(
self.gen,
*dst.device_ptr_mut() as *mut T,
dst.len(),
mean,
std,
)
}
}

/// Fill the `CudaRc` with data from a `LogNormal(mean, std)` distribution.
pub fn fill_with_log_normal<T>(
pub fn fill_with_log_normal<T, Dst: DevicePtrMut<T>>(
&self,
t: &mut CudaSlice<T>,
dst: &mut Dst,
mean: T,
std: T,
) -> Result<(), result::CurandError>
where
sys::curandGenerator_t: result::LogNormalFill<T>,
{
unsafe {
result::LogNormalFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len(), mean, std)
result::LogNormalFill::fill(
self.gen,
*dst.device_ptr_mut() as *mut T,
dst.len(),
mean,
std,
)
}
}
}
Expand Down

0 comments on commit 2470079

Please sign in to comment.