Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating curand safe api to accept DevicePtrMut #318

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/curand/safe.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Safe abstractions around [crate::curand::result] with [CudaRng].

use super::{result, sys};
use crate::driver::{CudaDevice, CudaSlice, DeviceSlice};
use crate::driver::{CudaDevice, DevicePtrMut};
use std::sync::Arc;

/// Host side RNG that can fill [CudaSlice] with random values.
Expand Down Expand Up @@ -51,38 +51,55 @@ impl CudaRng {
}

/// Fill the [CudaSlice] with data from a `Uniform` distribution
pub fn fill_with_uniform<T>(&self, t: &mut CudaSlice<T>) -> Result<(), result::CurandError>
pub fn fill_with_uniform<T, Dst: DevicePtrMut<T>>(
&self,
dst: &mut Dst,
) -> Result<(), result::CurandError>
where
sys::curandGenerator_t: result::UniformFill<T>,
{
unsafe { result::UniformFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len()) }
unsafe { result::UniformFill::fill(self.gen, *dst.device_ptr_mut() as *mut T, dst.len()) }
}

/// Fill the [CudaSlice] with data from a `Normal(mean, std)` distribution.
pub fn fill_with_normal<T>(
pub fn fill_with_normal<T, Dst: DevicePtrMut<T>>(
&self,
t: &mut CudaSlice<T>,
dst: &mut Dst,
mean: T,
std: T,
) -> Result<(), result::CurandError>
where
sys::curandGenerator_t: result::NormalFill<T>,
{
unsafe { result::NormalFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len(), mean, std) }
unsafe {
result::NormalFill::fill(
self.gen,
*dst.device_ptr_mut() as *mut T,
dst.len(),
mean,
std,
)
}
}

/// Fill the `CudaRc` with data from a `LogNormal(mean, std)` distribution.
pub fn fill_with_log_normal<T>(
pub fn fill_with_log_normal<T, Dst: DevicePtrMut<T>>(
&self,
t: &mut CudaSlice<T>,
dst: &mut Dst,
mean: T,
std: T,
) -> Result<(), result::CurandError>
where
sys::curandGenerator_t: result::LogNormalFill<T>,
{
unsafe {
result::LogNormalFill::fill(self.gen, t.cu_device_ptr as *mut T, t.len(), mean, std)
result::LogNormalFill::fill(
self.gen,
*dst.device_ptr_mut() as *mut T,
dst.len(),
mean,
std,
)
}
}
}
Expand Down
Loading