Fix unsafe usage in cudnn & clippy warnings/errors (#310)
coreylowman authored Jan 8, 2025
1 parent 89d3ce8 commit 594feeb
Showing 7 changed files with 66 additions and 57 deletions.
73 changes: 41 additions & 32 deletions src/cudnn/result.rs
@@ -458,6 +458,9 @@ pub unsafe fn convolution_forward(
.result()
}

+/// # Safety
+/// Make sure the handle is valid, all data are associated with the handle, and no pointers are null
+/// unless explicitly accepted by the underlying apis.
#[allow(clippy::too_many_arguments)]
pub unsafe fn convolution_bias_activation_forward(
handle: sys::cudnnHandle_t,
@@ -846,7 +849,10 @@ pub fn create_pooling_descriptor() -> Result<sys::cudnnPoolingDescriptor_t, Cudn
}
}

-pub fn set_pooling_descriptor(
+/// # Safety
+/// Make sure the handle is valid, all data are associated with the handle, and no pointers are null
+/// unless explicitly accepted by the underlying apis.
+pub unsafe fn set_pooling_descriptor(
desc: sys::cudnnPoolingDescriptor_t,
mode: sys::cudnnPoolingMode_t,
nan_propagation: sys::cudnnNanPropagation_t,
@@ -855,22 +861,24 @@ pub fn set_pooling_descriptor(
pads: &[std::ffi::c_int],
strides: &[std::ffi::c_int],
) -> Result<(), CudnnError> {
-unsafe {
-lib()
-.cudnnSetPoolingNdDescriptor(
-desc,
-mode,
-nan_propagation,
-nb_dims,
-window_dims.as_ptr(),
-pads.as_ptr(),
-strides.as_ptr(),
-)
-.result()
-}
+lib()
+.cudnnSetPoolingNdDescriptor(
+desc,
+mode,
+nan_propagation,
+nb_dims,
+window_dims.as_ptr(),
+pads.as_ptr(),
+strides.as_ptr(),
+)
+.result()
}

-pub fn pooling_forward(
+/// # Safety
+/// Make sure the handle is valid, all data are associated with the handle, and no pointers are null
+/// unless explicitly accepted by the underlying apis.
+#[allow(clippy::too_many_arguments)]
+pub unsafe fn pooling_forward(
handle: sys::cudnnHandle_t,
pooling_desc: sys::cudnnPoolingDescriptor_t,
alpha: *const ::core::ffi::c_void,
@@ -880,11 +888,9 @@ pub fn pooling_forward(
y_desc: sys::cudnnTensorDescriptor_t,
y: *mut ::core::ffi::c_void,
) -> Result<(), CudnnError> {
-unsafe {
-lib()
-.cudnnPoolingForward(handle, pooling_desc, alpha, x_desc, x, beta, y_desc, y)
-.result()
-}
+lib()
+.cudnnPoolingForward(handle, pooling_desc, alpha, x_desc, x, beta, y_desc, y)
+.result()
}

pub fn create_activation_descriptor() -> Result<sys::cudnnActivationDescriptor_t, CudnnError> {
@@ -897,20 +903,25 @@ pub fn create_activation_descriptor() -> Result<sys::cudnnActivationDescriptor_t
}
}

-pub fn set_activation_descriptor(
+/// # Safety
+/// Make sure the handle is valid, all data are associated with the handle, and no pointers are null
+/// unless explicitly accepted by the underlying apis.
+pub unsafe fn set_activation_descriptor(
desc: sys::cudnnActivationDescriptor_t,
mode: sys::cudnnActivationMode_t,
nan_propagation: sys::cudnnNanPropagation_t,
coef: f64,
) -> Result<(), CudnnError> {
-unsafe {
-lib()
-.cudnnSetActivationDescriptor(desc, mode, nan_propagation, coef)
-.result()
-}
+lib()
+.cudnnSetActivationDescriptor(desc, mode, nan_propagation, coef)
+.result()
}

-pub fn activation_forward(
+/// # Safety
+/// Make sure the handle is valid, all data are associated with the handle, and no pointers are null
+/// unless explicitly accepted by the underlying apis.
+#[allow(clippy::too_many_arguments)]
+pub unsafe fn activation_forward(
handle: sys::cudnnHandle_t,
activation_desc: sys::cudnnActivationDescriptor_t,
alpha: *const ::core::ffi::c_void,
@@ -920,9 +931,7 @@ pub fn activation_forward(
y_desc: sys::cudnnTensorDescriptor_t,
y: *mut ::core::ffi::c_void,
) -> Result<(), CudnnError> {
-unsafe {
-lib()
-.cudnnActivationForward(handle, activation_desc, alpha, x_desc, x, beta, y_desc, y)
-.result()
-}
+lib()
+.cudnnActivationForward(handle, activation_desc, alpha, x_desc, x, beta, y_desc, y)
+.result()
}
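The hunks above all follow one pattern: each raw FFI wrapper that takes raw pointers or descriptor handles is now declared `unsafe fn` with a documented `# Safety` contract (what clippy's `missing_safety_doc` lint asks for), and the `unsafe { ... }` block inside the body is dropped, since the body of an `unsafe fn` is already an unsafe context and an inner block would trip the `unused_unsafe` warning. A minimal self-contained sketch of the same pattern, with hypothetical names standing in for the cuDNN bindings:

```rust
use core::ffi::c_int;

// Hypothetical stand-in for a raw cuDNN entry point such as
// cudnnSetPoolingNdDescriptor.
unsafe extern "C" fn ffi_set_nd_descriptor(nb_dims: c_int, dims: *const c_int) -> c_int {
    if dims.is_null() {
        1 // pretend a nonzero status is an error
    } else {
        let _ = nb_dims;
        0
    }
}

/// # Safety
/// `dims` must hold at least `nb_dims` valid elements.
pub unsafe fn set_nd_descriptor(nb_dims: c_int, dims: &[c_int]) -> Result<(), c_int> {
    // No inner `unsafe { ... }` here: the whole body of an `unsafe fn` is an
    // unsafe context, so an extra block would be flagged as `unused_unsafe`.
    match ffi_set_nd_descriptor(nb_dims, dims.as_ptr()) {
        0 => Ok(()),
        status => Err(status),
    }
}

fn main() {
    let dims = [2, 3, 4];
    // Callers now spell out the obligation at the call site.
    unsafe { set_nd_descriptor(dims.len() as c_int, &dims) }.unwrap();
}
```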
4 changes: 2 additions & 2 deletions src/cudnn/safe/activation.rs
@@ -25,7 +25,7 @@ impl Cudnn {
handle: self.clone(),
marker: PhantomData,
};
-result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?;
+unsafe { result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef) }?;
Ok(desc)
}
}
@@ -39,7 +39,7 @@ pub struct ActivationForward<'a, A: CudnnDataType, X: CudnnDataType, Y: CudnnDat
pub y: &'a TensorDescriptor<Y>,
}

-impl<'a, A, X, Y> ActivationForward<'a, A, X, Y>
+impl<A, X, Y> ActivationForward<'_, A, X, Y>
where
A: CudnnDataType,
X: CudnnDataType,
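The `impl<'a, ...> ... <'a, ...>` to `impl<...> ... <'_, ...>` rewrites here (and in `conv.rs`, `pooling.rs`, and `reduce.rs` below) fix clippy's `needless_lifetimes` warning, which in recent toolchains also flags impl headers whose lifetime parameter is only ever mentioned in the implemented type. A tiny sketch of that rewrite:

```rust
struct ActivationForward<'a> {
    x: &'a [f32],
}

// Before (clippy warns that `'a` can be elided):
// impl<'a> ActivationForward<'a> {
//     fn len(&self) -> usize { self.x.len() }
// }

// After: `'_` names the lifetime anonymously, so the header stays in sync
// even if the struct's lifetime parameters are later renamed.
impl ActivationForward<'_> {
    fn len(&self) -> usize {
        self.x.len()
    }
}

fn main() {
    let data = [1.0_f32, 2.0, 3.0];
    let fwd = ActivationForward { x: &data };
    assert_eq!(fwd.len(), 3);
}
```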
9 changes: 5 additions & 4 deletions src/cudnn/safe/conv.rs
@@ -208,7 +208,7 @@ pub struct ConvForward<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType>
#[deprecated(note = "use ConvForward instead. This will be removed in future versions")]
pub type Conv2dForward<'a, X, C, Y> = ConvForward<'a, X, C, Y>;

-impl<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvForward<'a, X, C, Y> {
+impl<X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvForward<'_, X, C, Y> {
/// Picks the fastest algorithm from all available cuDNN algorithms based on cudnn heuristics.
pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionFwdAlgo_t, CudnnError> {
const NUM_ALGOS: usize = 8;
@@ -322,7 +322,7 @@ pub struct ConvBackwardData<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnData
#[deprecated(note = "use ConvBackwardData instead. This will be removed in future versions")]
pub type Conv2dBackwardData<'a, X, C, Y> = ConvBackwardData<'a, X, C, Y>;

-impl<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardData<'a, X, C, Y> {
+impl<X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardData<'_, X, C, Y> {
/// Picks the fastest algorithm from all available cuDNN algorithms based on cudnn heuristics.
pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionBwdDataAlgo_t, CudnnError> {
const NUM_ALGOS: usize = 6;
@@ -436,7 +436,7 @@ pub struct ConvBackwardFilter<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDa
#[deprecated(note = "use ConvBackwardFilter instead. This will be removed in future versions")]
pub type Conv2dBackwardFilter<'a, X, C, Y> = ConvBackwardFilter<'a, X, C, Y>;

-impl<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardFilter<'a, X, C, Y> {
+impl<X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardFilter<'_, X, C, Y> {
/// Picks the fastest algorithm from all available cuDNN algorithms based on cudnn heuristics.
pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionBwdFilterAlgo_t, CudnnError> {
const NUM_ALGOS: usize = 7;
@@ -557,7 +557,7 @@ pub struct ConvBiasActivationForward<
pub y: &'a TensorDescriptor<Y>,
}

-impl<'a, X, C, A, Y> ConvBiasActivationForward<'a, X, C, A, Y>
+impl<X, C, A, Y> ConvBiasActivationForward<'_, X, C, A, Y>
where
X: CudnnDataType,
C: CudnnDataType,
@@ -598,6 +598,7 @@ where
/// # Safety
/// The src/filter/y arguments must match the data type/layout specified in the
/// descriptors in `self.
+#[allow(clippy::too_many_arguments)]
pub unsafe fn launch<Workspace, Src, Filter, Dst>(
&self,
algo: sys::cudnnConvolutionFwdAlgo_t,
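The added `#[allow(clippy::too_many_arguments)]` attributes (here and in `result.rs` above) acknowledge clippy's default limit of seven parameters: these launch wrappers mirror the underlying C signatures one-to-one, so bundling the arguments into a struct would only be unpacked again at the FFI boundary. A trivial sketch of the attribute in use:

```rust
// Sketch: a wrapper whose parameter list mirrors a C API one-to-one.
// Grouping the arguments into a struct would just be repacked at the FFI
// boundary, so the lint is silenced rather than obeyed.
#[allow(clippy::too_many_arguments)]
fn launch(a: i32, b: i32, c: i32, d: i32, e: i32, f: i32, g: i32, h: i32) -> i32 {
    // eight parameters: one past clippy's default threshold of seven
    a + b + c + d + e + f + g + h
}

fn main() {
    assert_eq!(launch(1, 2, 3, 4, 5, 6, 7, 8), 36);
}
```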
22 changes: 12 additions & 10 deletions src/cudnn/safe/pooling.rs
@@ -31,15 +31,17 @@ impl Cudnn {
marker: PhantomData,
};

-result::set_pooling_descriptor(
-desc.desc,
-mode,
-nan_propagation,
-window.len() as std::ffi::c_int,
-window,
-pads,
-strides,
-)?;
+unsafe {
+result::set_pooling_descriptor(
+desc.desc,
+mode,
+nan_propagation,
+window.len() as std::ffi::c_int,
+window,
+pads,
+strides,
+)
+}?;

Ok(desc)
}
@@ -53,7 +55,7 @@ pub struct PoolingForward<'a, P, X, Y> {
pub y: &'a TensorDescriptor<Y>,
}

-impl<'a, P, X, Y> PoolingForward<'a, P, X, Y>
+impl<P, X, Y> PoolingForward<'_, P, X, Y>
where
P: CudnnDataType,
X: CudnnDataType,
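On the safe side of the API the commit moves the `unsafe` keyword to the call site: the safe constructor is now the place that justifies the call, since it created the descriptor itself and derives `nb_dims` from `window.len()`. A minimal sketch of that wrapper shape, with hypothetical names:

```rust
/// Hypothetical raw-layer function shaped like `result::set_pooling_descriptor`.
///
/// # Safety
/// `dims` must point to at least `nb_dims` valid elements.
unsafe fn set_descriptor_raw(nb_dims: i32, dims: *const i32) -> Result<(), i32> {
    if dims.is_null() || nb_dims < 0 {
        Err(1)
    } else {
        Ok(())
    }
}

/// Safe wrapper: it establishes exactly the invariants the raw call
/// documents, so the `unsafe` block is a local, auditable justification.
fn set_descriptor(window: &[i32]) -> Result<(), i32> {
    // `window.len()` is the number of valid elements behind
    // `window.as_ptr()`, which is the contract `set_descriptor_raw` asks for.
    unsafe { set_descriptor_raw(window.len() as i32, window.as_ptr()) }
}

fn main() {
    set_descriptor(&[2, 2]).unwrap();
}
```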
8 changes: 4 additions & 4 deletions src/cudnn/safe/reduce.rs
@@ -103,7 +103,7 @@ pub struct ReduceTensor<'a, T: CudnnDataType, Idx> {
pub c: &'a TensorDescriptor<T>,
}

-impl<'a, T: CudnnDataType> ReduceTensor<'a, T, FlatIndices> {
+impl<T: CudnnDataType> ReduceTensor<'_, T, FlatIndices> {
/// Get's the size of the indices tensor required for this operation.
///
/// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetReductionIndicesSize).
@@ -119,7 +119,7 @@ impl<'a, T: CudnnDataType> ReduceTensor<'a, T, FlatIndices> {
}
}

-impl<'a, T: CudnnDataType, Idx> ReduceTensor<'a, T, Idx> {
+impl<T: CudnnDataType, Idx> ReduceTensor<'_, T, Idx> {
/// Gets the size of the workspace for this operation.
///
/// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetReductionWorkspaceSize)
@@ -135,7 +135,7 @@ impl<'a, T: CudnnDataType, Idx> ReduceTensor<'a, T, Idx> {
}
}

-impl<'a, T: CudnnDataType> ReduceTensor<'a, T, FlatIndices> {
+impl<T: CudnnDataType> ReduceTensor<'_, T, FlatIndices> {
/// Launches the operation with indices.
///
/// # Safety
@@ -172,7 +172,7 @@ impl<'a, T: CudnnDataType> ReduceTensor<'a, T, FlatIndices> {
}
}

-impl<'a, T: CudnnDataType> ReduceTensor<'a, T, NoIndices> {
+impl<T: CudnnDataType> ReduceTensor<'_, T, NoIndices> {
/// Launches the operation with no indices.
///
/// # Safety
3 changes: 0 additions & 3 deletions src/nccl/safe.rs
@@ -108,7 +108,6 @@ impl Comm {
/// let mut slice_receive = dev.alloc_zeros::<f32>(n).unwrap();
/// comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
/// .unwrap();
-/// });
/// group_start().unwrap();
/// ```
@@ -170,9 +169,7 @@
/// let mut slice_receive = dev.alloc_zeros::<f32>(n).unwrap();
/// comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
/// .unwrap();
-/// let out = dev.dtoh_sync_copy(&slice_receive).unwrap();
-/// assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
/// ```
pub fn from_rank(
4 changes: 2 additions & 2 deletions src/runtime/result.rs
@@ -994,8 +994,8 @@ pub mod external_memory {
size,
..Default::default()
};
-lib()
-.cudaImportExternalMemory(external_memory.as_mut_ptr(), &handle_description)
+lib()
+.cudaImportExternalMemory(external_memory.as_mut_ptr(), &handle_description)
.result()?;
Ok(external_memory.assume_init())
}
