diff --git a/src/curand/result.rs b/src/curand/result.rs index 496233b..0716441 100644 --- a/src/curand/result.rs +++ b/src/curand/result.rs @@ -156,6 +156,21 @@ pub mod generate { lib().curandGenerate(gen, out, num).result() } + /// Fills `out` with `num` u64 values with all bits random. + /// + /// See [cuRAND docs](https://docs.nvidia.com/cuda/curand/group__HOST.html#group__HOST_1gd8e23d144e88b8e638139db05ff798b3) + /// + /// # Safety + /// 1. generator must have been allocated and not freed. + /// 2. `out` point to `num` values + pub unsafe fn uniform_u64( + gen: sys::curandGenerator_t, + out: *mut u64, + num: usize, + ) -> Result<(), CurandError> { + lib().curandGenerateLongLong(gen, out, num).result() + } + /// Fills `out` with `num` f32 values from a normal distribution /// parameterized by `mean` and `std`. /// @@ -279,6 +294,12 @@ impl UniformFill for sys::curandGenerator_t { } } +impl UniformFill for sys::curandGenerator_t { + unsafe fn fill(self, out: *mut u64, num: usize) -> Result<(), CurandError> { + generate::uniform_u64(self, out, num) + } +} + /// Fill with normally distributed numbers of type `T`. pub trait NormalFill { /// # Safety