Skip to content

Commit

Permalink
Adds u64 support for curand::CudaRng::fill_with_uniform
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jan 21, 2025
1 parent dd389c5 commit e7686bd
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/curand/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ pub mod generate {
lib().curandGenerate(gen, out, num).result()
}

/// Fills `out` with `num` u64 values with all bits random.
///
/// See [cuRAND docs](https://docs.nvidia.com/cuda/curand/group__HOST.html#group__HOST_1gd8e23d144e88b8e638139db05ff798b3)
///
/// # Safety
/// 1. generator must have been allocated and not freed.
/// 2. `out` point to `num` values
pub unsafe fn uniform_u64(
gen: sys::curandGenerator_t,
out: *mut u64,
num: usize,
) -> Result<(), CurandError> {
lib().curandGenerateLongLong(gen, out, num).result()
}

/// Fills `out` with `num` f32 values from a normal distribution
/// parameterized by `mean` and `std`.
///
Expand Down Expand Up @@ -279,6 +294,12 @@ impl UniformFill<u32> for sys::curandGenerator_t {
}
}

impl UniformFill<u64> for sys::curandGenerator_t {
unsafe fn fill(self, out: *mut u64, num: usize) -> Result<(), CurandError> {
generate::uniform_u64(self, out, num)
}
}

/// Fill with normally distributed numbers of type `T`.
pub trait NormalFill<T> {
/// # Safety
Expand Down

0 comments on commit e7686bd

Please sign in to comment.