From ab0c7d192415b3be366efbf15125c9c2d94b8972 Mon Sep 17 00:00:00 2001 From: Miguel Date: Tue, 16 Jul 2024 16:00:23 -0400 Subject: [PATCH 01/11] Add pooling safe api --- src/cudnn/result.rs | 53 +++++++++++++++++++++++++ src/cudnn/safe/mod.rs | 2 + src/cudnn/safe/pooling.rs | 81 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 src/cudnn/safe/pooling.rs diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 0269d361..6276040e 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -779,3 +779,56 @@ pub unsafe fn reduce_tensor( ) .result() } + +pub fn create_pooling_descriptor() -> Result { + let mut desc = MaybeUninit::uninit(); + unsafe { lib().cudnnCreatePoolingDescriptor(desc.as_mut_ptr()).result()?; + Ok(desc.assume_init()) + } +} + +pub fn set_pooling_descriptor( + desc: sys::cudnnPoolingDescriptor_t, + mode: sys::cudnnPoolingMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + dims: std::ffi::c_int, + input: &[std::ffi::c_int], + pads: &[std::ffi::c_int], + strides: &[std::ffi::c_int], +) -> Result<(), CudnnError> { + unsafe { + lib().cudnnSetPoolingNdDescriptor( + desc, + mode, + nan_propagation, + dims, + input.as_ptr(), + pads.as_ptr(), + strides.as_ptr() + ).result() + } +} + +pub fn pooling_forward( + handle: sys::cudnnHandle_t, + pooling_desc: sys::cudnnPoolingDescriptor_t, + alpha: *const ::core::ffi::c_void, + x_desc: sys::cudnnTensorDescriptor_t, + x: *const ::core::ffi::c_void, + beta: *const ::core::ffi::c_void, + y_desc: sys::cudnnTensorDescriptor_t, + y: *mut ::core::ffi::c_void, +) -> Result<(), CudnnError> { + unsafe { + lib().cudnnPoolingForward( + handle, + pooling_desc, + alpha, + x_desc, + x, + beta, + y_desc, + y, + ).result() + } +} \ No newline at end of file diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index ff7fc828..e30ee5aa 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -19,6 +19,7 @@ mod conv; mod core; mod reduce; +mod pooling; #[allow(deprecated)] pub use self::conv::{ @@ -34,6 +35,7 @@ pub use self::conv::{ ConvForward, FilterDescriptor, }; +pub use self::pooling::{PoolingForward, PoolingDescriptor}; pub use self::core::{Cudnn, CudnnDataType, TensorDescriptor}; pub use self::reduce::{FlatIndices, NoIndices, ReduceTensor, ReductionDescriptor}; pub use super::result::CudnnError; diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs new file mode 100644 index 00000000..a603af8e --- /dev/null +++ b/src/cudnn/safe/pooling.rs @@ -0,0 +1,81 @@ +use crate::{ + cudnn::{result::CudnnError, sys}, + driver::{DevicePtr, DevicePtrMut}, +}; + +use std::{marker::PhantomData, sync::Arc}; +use crate::cudnn::{result, Cudnn, CudnnDataType, TensorDescriptor}; + +pub struct PoolingDescriptor { + desc: sys::cudnnPoolingDescriptor_t, + #[allow(unused)] + handle: Arc, + marker: PhantomData, +} + +impl Cudnn { + pub fn create_poolingnd( + self: &Arc, + input: &[std::ffi::c_int], + pads: &[std::ffi::c_int], + strides: &[std::ffi::c_int], + mode: sys::cudnnPoolingMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + ) -> Result, CudnnError> { + let desc = result::create_pooling_descriptor()?; + let desc = PoolingDescriptor { + desc, + handle: self.clone(), + marker: PhantomData, + }; + + result::set_pooling_descriptor( + desc.desc, + mode, + nan_propagation, + input.len() as std::ffi::c_int, + input, + pads, + strides + )?; + + Ok(desc) + } +} + +pub struct PoolingForward<'a, P, X, Y> { + pooling: &'a PoolingDescriptor

, + x: &'a TensorDescriptor, + y: &'a TensorDescriptor, +} + +impl<'a, P, X, Y> PoolingForward<'a, P, X, Y> +where + P: CudnnDataType, + X: CudnnDataType, + Y: CudnnDataType, +{ + pub fn launch( + &self, + (alpha, beta): (Y, Y), + input: &Input, + output: &mut Output, + ) -> Result<(), CudnnError> + where + Input: DevicePtr, + Output: DevicePtrMut, + { + let alpha = alpha.into_scaling_parameter(); + let beta = beta.into_scaling_parameter(); + result::pooling_forward( + self.pooling.handle.handle, + self.pooling.desc, + (&alpha) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *input.device_ptr() as *const X as *const std::ffi::c_void, + (&beta) as *const Y::Scalar as *const std::ffi::c_void, + self.y.desc, + *output.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + ) + } +} From 415ce82f2755d159de240cd9da77a5afe89e602a Mon Sep 17 00:00:00 2001 From: Miguel Date: Tue, 16 Jul 2024 16:05:23 -0400 Subject: [PATCH 02/11] Run fmt --- src/cudnn/result.rs | 40 +++++++++++++++++++-------------------- src/cudnn/safe/mod.rs | 4 ++-- src/cudnn/safe/pooling.rs | 32 +++++++++++++++---------------- 3 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 6276040e..f7bb696a 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -782,7 +782,10 @@ pub unsafe fn reduce_tensor( pub fn create_pooling_descriptor() -> Result { let mut desc = MaybeUninit::uninit(); - unsafe { lib().cudnnCreatePoolingDescriptor(desc.as_mut_ptr()).result()?; + unsafe { + lib() + .cudnnCreatePoolingDescriptor(desc.as_mut_ptr()) + .result()?; Ok(desc.assume_init()) } } @@ -797,15 +800,17 @@ pub fn set_pooling_descriptor( strides: &[std::ffi::c_int], ) -> Result<(), CudnnError> { unsafe { - lib().cudnnSetPoolingNdDescriptor( - desc, - mode, - nan_propagation, - dims, - input.as_ptr(), - pads.as_ptr(), - strides.as_ptr() - ).result() + lib() + .cudnnSetPoolingNdDescriptor( + desc, + mode, + nan_propagation, + dims, + input.as_ptr(), + pads.as_ptr(), + strides.as_ptr(), + ) + .result() } } @@ -820,15 +825,8 @@ pub fn pooling_forward( y: *mut ::core::ffi::c_void, ) -> Result<(), CudnnError> { unsafe { - lib().cudnnPoolingForward( - handle, - pooling_desc, - alpha, - x_desc, - x, - beta, - y_desc, - y, - ).result() + lib() + .cudnnPoolingForward(handle, pooling_desc, alpha, x_desc, x, beta, y_desc, y) + .result() } -} \ No newline at end of file +} diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index e30ee5aa..872971ea 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -18,8 +18,8 @@ mod conv; mod core; -mod reduce; mod pooling; +mod reduce; #[allow(deprecated)] pub use self::conv::{ @@ -35,8 +35,8 @@ pub use self::conv::{ ConvForward, FilterDescriptor, }; -pub use self::pooling::{PoolingForward, PoolingDescriptor}; pub use self::core::{Cudnn, CudnnDataType, TensorDescriptor}; +pub use self::pooling::{PoolingDescriptor, PoolingForward}; pub use self::reduce::{FlatIndices, NoIndices, ReduceTensor, ReductionDescriptor}; pub use super::result::CudnnError; diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs index a603af8e..fcae0570 100644 --- a/src/cudnn/safe/pooling.rs +++ b/src/cudnn/safe/pooling.rs @@ -3,8 +3,8 @@ use crate::{ driver::{DevicePtr, DevicePtrMut}, }; -use std::{marker::PhantomData, sync::Arc}; use crate::cudnn::{result, Cudnn, CudnnDataType, TensorDescriptor}; +use std::{marker::PhantomData, sync::Arc}; pub struct PoolingDescriptor { desc: sys::cudnnPoolingDescriptor_t, @@ -36,7 +36,7 @@ impl Cudnn { input.len() as std::ffi::c_int, input, pads, - strides + strides, )?; Ok(desc) @@ -64,18 +64,18 @@ where where Input: DevicePtr, Output: DevicePtrMut, - { - let alpha = alpha.into_scaling_parameter(); - let beta = beta.into_scaling_parameter(); - result::pooling_forward( - self.pooling.handle.handle, - self.pooling.desc, - (&alpha) as *const Y::Scalar as *const std::ffi::c_void, - self.x.desc, - *input.device_ptr() as *const X as *const std::ffi::c_void, - (&beta) as *const Y::Scalar as *const std::ffi::c_void, - self.y.desc, - *output.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, - ) - } + { + let alpha = alpha.into_scaling_parameter(); + let beta = beta.into_scaling_parameter(); + result::pooling_forward( + self.pooling.handle.handle, + self.pooling.desc, + (&alpha) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *input.device_ptr() as *const X as *const std::ffi::c_void, + (&beta) as *const Y::Scalar as *const std::ffi::c_void, + self.y.desc, + *output.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + ) + } } From 44df13dc930730fe5cf382b7b1ad8821fa8c4acf Mon Sep 17 00:00:00 2001 From: Miguel Date: Wed, 17 Jul 2024 19:31:16 -0400 Subject: [PATCH 03/11] Make fields in PoolingFoward public --- src/cudnn/safe/pooling.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs index fcae0570..39c4c9f4 100644 --- a/src/cudnn/safe/pooling.rs +++ b/src/cudnn/safe/pooling.rs @@ -44,9 +44,9 @@ impl Cudnn { } pub struct PoolingForward<'a, P, X, Y> { - pooling: &'a PoolingDescriptor

, - x: &'a TensorDescriptor, - y: &'a TensorDescriptor, + pub pooling: &'a PoolingDescriptor

, + pub x: &'a TensorDescriptor, + pub y: &'a TensorDescriptor, } impl<'a, P, X, Y> PoolingForward<'a, P, X, Y> From 2ae4ca7f519b026e51fceb18c839a2e5a898cce4 Mon Sep 17 00:00:00 2001 From: Miguel Date: Thu, 18 Jul 2024 14:58:16 -0400 Subject: [PATCH 04/11] Add conv fused with activation and bias --- src/cudnn/result.rs | 68 +++++++++++++++++++++ src/cudnn/safe/conv.rs | 131 +++++++++++++++++++++++++++++++++++++++++ src/cudnn/safe/mod.rs | 1 + 3 files changed, 200 insertions(+) diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index f7bb696a..4fb9dbfd 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -447,6 +447,51 @@ pub unsafe fn convolution_forward( .result() } +#[allow(clippy::too_many_arguments)] +pub unsafe fn convolution_bias_activation_forward( + handle: sys::cudnnHandle_t, + alpha1: *const ::core::ffi::c_void, + x_desc: sys::cudnnTensorDescriptor_t, + x: *const ::core::ffi::c_void, + w_desc: sys::cudnnFilterDescriptor_t, + w: *const ::core::ffi::c_void, + conv_desc: sys::cudnnConvolutionDescriptor_t, + algo: sys::cudnnConvolutionFwdAlgo_t, + work_space: *mut ::core::ffi::c_void, + work_space_size_in_bytes: usize, + alpha2: *const ::core::ffi::c_void, + z_desc: sys::cudnnTensorDescriptor_t, + z: *const ::core::ffi::c_void, + bias_desc: sys::cudnnTensorDescriptor_t, + bias: *const ::core::ffi::c_void, + activation_desc: sys::cudnnActivationDescriptor_t, + y_desc: sys::cudnnTensorDescriptor_t, + y: *mut ::core::ffi::c_void, +) -> Result<(), CudnnError> { + lib() + .cudnnConvolutionBiasActivationForward( + handle, + alpha1, + x_desc, + x, + w_desc, + w, + conv_desc, + algo, + work_space, + work_space_size_in_bytes, + alpha2, + z_desc, + z, + bias_desc, + bias, + activation_desc, + y_desc, + y, + ) + .result() +} + /// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetConvolutionBackwardDataAlgorithm_v7) /// /// # Safety @@ -830,3 +875,26 @@ pub fn pooling_forward( .result() } } + +pub fn create_activation_descriptor() -> Result { + let mut desc = MaybeUninit::uninit(); + unsafe { + lib() + .cudnnCreateActivationDescriptor(desc.as_mut_ptr()) + .result()?; + Ok(desc.assume_init()) + } +} + +pub fn set_activation_descriptor( + desc: sys::cudnnActivationDescriptor_t, + mode: sys::cudnnActivationMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + coef: f64, +) -> Result<(), CudnnError> { + unsafe { + lib() + .cudnnSetActivationDescriptor(desc, mode, nan_propagation, coef) + .result() + } +} diff --git a/src/cudnn/safe/conv.rs b/src/cudnn/safe/conv.rs index a6d2aba5..b1f8488d 100644 --- a/src/cudnn/safe/conv.rs +++ b/src/cudnn/safe/conv.rs @@ -525,3 +525,134 @@ impl<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardFilte ) } } + +#[derive(Debug)] +pub struct ConvBiasActivationForward< + 'a, + X: CudnnDataType, + C: CudnnDataType, + A: CudnnDataType, + Y: CudnnDataType, +> { + /// Conv parameters. + pub conv: &'a ConvDescriptor, + /// Activation function. + pub act: &'a ActivationDescriptor, + /// Input tensor descriptor + pub x: &'a TensorDescriptor, + /// Filter descriptor + pub w: &'a FilterDescriptor, + /// Z descriptor + pub z: &'a TensorDescriptor, + /// Bias descriptor + pub bias: &'a TensorDescriptor, + /// Output tensor descriptor + pub y: &'a TensorDescriptor, +} + +impl<'a, X, C, A, Y> ConvBiasActivationForward<'a, X, C, A, Y> +where + X: CudnnDataType, + C: CudnnDataType, + A: CudnnDataType, + Y: CudnnDataType, +{ + /// Picks the fastest algorithm from all available cuDNN algorithms based on cudnn heuristics. + pub fn pick_algorithm(&self) -> Result { + let conv = ConvForward { + conv: self.conv, + x: self.x, + w: self.w, + y: self.y, + }; + conv.pick_algorithm() + } + + /// Returns size in **bytes** to execute the selected algorithm. + pub fn get_workspace_size( + &self, + algo: sys::cudnnConvolutionFwdAlgo_t, + ) -> Result { + let conv = ConvForward { + conv: self.conv, + x: self.x, + w: self.w, + y: self.y, + }; + conv.get_workspace_size(algo) + } + + pub unsafe fn launch( + &self, + algo: sys::cudnnConvolutionFwdAlgo_t, + workspace: Option<&mut Workspace>, + (alpha1, alpha2): (Y, Y), + src: &Src, + filter: &Filter, + z: &Src, + bias: &Src, + y: &mut Dst, + ) -> Result<(), CudnnError> + where + Workspace: DevicePtrMut, + Src: DevicePtr, + Filter: DevicePtr, + Dst: DevicePtrMut, + { + let (num_bytes, workspace_ptr) = match workspace { + Some(w) => ( + w.num_bytes(), + *w.device_ptr_mut() as *mut u8 as *mut std::ffi::c_void, + ), + None => (0, std::ptr::null_mut()), + }; + let alpha1 = alpha1.into_scaling_parameter(); + let alpha2 = alpha2.into_scaling_parameter(); + result::convolution_bias_activation_forward( + self.conv.handle.handle, + (&alpha1) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *src.device_ptr() as *const X as *const std::ffi::c_void, + self.w.desc, + *filter.device_ptr() as *const X as *const std::ffi::c_void, + self.conv.desc, + algo, + workspace_ptr, + num_bytes, + (&alpha2) as *const Y::Scalar as *const std::ffi::c_void, + self.z.desc, + *z.device_ptr() as *const X as *const std::ffi::c_void, + self.bias.desc, + *bias.device_ptr() as *const X as *const std::ffi::c_void, + self.act.desc, + self.y.desc, + *y.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + ) + } +} + +#[derive(Debug)] +pub struct ActivationDescriptor { + pub(crate) desc: sys::cudnnActivationDescriptor_t, + #[allow(unused)] + pub(crate) handle: Arc, + pub(crate) marker: PhantomData, +} + +impl Cudnn { + pub fn create_activation( + self: &Arc, + mode: sys::cudnnActivationMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + coef: f64, + ) -> Result, CudnnError> { + let desc = result::create_activation_descriptor()?; + let desc = ActivationDescriptor { + desc, + handle: self.clone(), + marker: PhantomData, + }; + result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?; + Ok(desc) + } +} diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index 872971ea..c1f0f876 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -31,6 +31,7 @@ pub use self::conv::{ // Current APIs ConvBackwardData, ConvBackwardFilter, + ConvBiasActivationForward, ConvDescriptor, ConvForward, FilterDescriptor, From 7511aff081f93c5b6f2a00e77f62bff9ae3e8f33 Mon Sep 17 00:00:00 2001 From: Miguel Date: Fri, 19 Jul 2024 17:28:37 -0400 Subject: [PATCH 05/11] Add activation forward --- src/cudnn/result.rs | 17 ++++++++++ src/cudnn/safe/activation.rs | 66 ++++++++++++++++++++++++++++++++++++ src/cudnn/safe/conv.rs | 27 +-------------- src/cudnn/safe/mod.rs | 2 ++ 4 files changed, 86 insertions(+), 26 deletions(-) create mode 100644 src/cudnn/safe/activation.rs diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 4fb9dbfd..36b69d8d 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -898,3 +898,20 @@ pub fn set_activation_descriptor( .result() } } + +pub fn activation_forward( + handle: sys::cudnnHandle_t, + activation_desc: sys::cudnnActivationDescriptor_t, + alpha: *const ::core::ffi::c_void, + x_desc: sys::cudnnTensorDescriptor_t, + x: *const ::core::ffi::c_void, + beta: *const ::core::ffi::c_void, + y_desc: sys::cudnnTensorDescriptor_t, + y: *mut ::core::ffi::c_void, +) -> Result<(), CudnnError> { + unsafe { + lib() + .cudnnActivationForward(handle, activation_desc, alpha, x_desc, x, beta, y_desc, y) + .result() + } +} diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs new file mode 100644 index 00000000..9234f8f6 --- /dev/null +++ b/src/cudnn/safe/activation.rs @@ -0,0 +1,66 @@ +use crate::cudnn::{result, sys, Cudnn, CudnnDataType, CudnnError}; +use crate::driver::{DevicePtr, DevicePtrMut}; +use core::marker::PhantomData; +use std::sync::Arc; + +pub struct ActivationForward<'a, A: CudnnDataType> { + /// Activation function. + pub act: &'a ActivationDescriptor, +} + +impl<'a, T> ActivationForward<'a, T> +where + T: CudnnDataType, +{ + pub fn launch( + &self, + (alpha, beta): (T, T), + x_desc: sys::cudnnTensorDescriptor_t, + x: &Src, + y_desc: sys::cudnnTensorDescriptor_t, + y: &mut Dst, + ) -> Result<(), CudnnError> + where + Src: DevicePtr, + Dst: DevicePtrMut, + { + let alpha = alpha.into_scaling_parameter(); + let beta = beta.into_scaling_parameter(); + result::activation_forward( + self.act.handle.handle, + self.act.desc, + (&alpha) as *const T::Scalar as *const std::ffi::c_void, + x_desc, + *x.device_ptr() as *const T as *const std::ffi::c_void, + (&beta) as *const T::Scalar as *const std::ffi::c_void, + y_desc, + *y.device_ptr_mut() as *mut T as *mut std::ffi::c_void, + ) + } +} + +#[derive(Debug)] +pub struct ActivationDescriptor { + pub(crate) desc: sys::cudnnActivationDescriptor_t, + #[allow(unused)] + pub(crate) handle: Arc, + pub(crate) marker: PhantomData, +} + +impl Cudnn { + pub fn create_activation( + self: &Arc, + mode: sys::cudnnActivationMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + coef: f64, + ) -> Result, CudnnError> { + let desc = result::create_activation_descriptor()?; + let desc = ActivationDescriptor { + desc, + handle: self.clone(), + marker: PhantomData, + }; + result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?; + Ok(desc) + } +} diff --git a/src/cudnn/safe/conv.rs b/src/cudnn/safe/conv.rs index b1f8488d..56073636 100644 --- a/src/cudnn/safe/conv.rs +++ b/src/cudnn/safe/conv.rs @@ -4,6 +4,7 @@ use crate::{ driver::{DevicePtr, DevicePtrMut}, }; +use crate::cudnn::safe::activation::ActivationDescriptor; use std::{marker::PhantomData, sync::Arc}; /// A descriptor of the filters for conv operation. Create with [`Cudnn::create_4d_filter()`] @@ -630,29 +631,3 @@ where ) } } - -#[derive(Debug)] -pub struct ActivationDescriptor { - pub(crate) desc: sys::cudnnActivationDescriptor_t, - #[allow(unused)] - pub(crate) handle: Arc, - pub(crate) marker: PhantomData, -} - -impl Cudnn { - pub fn create_activation( - self: &Arc, - mode: sys::cudnnActivationMode_t, - nan_propagation: sys::cudnnNanPropagation_t, - coef: f64, - ) -> Result, CudnnError> { - let desc = result::create_activation_descriptor()?; - let desc = ActivationDescriptor { - desc, - handle: self.clone(), - marker: PhantomData, - }; - result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?; - Ok(desc) - } -} diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index c1f0f876..1d6a9858 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -16,6 +16,7 @@ //! //! # Reductions +mod activation; mod conv; mod core; mod pooling; @@ -40,6 +41,7 @@ pub use self::core::{Cudnn, CudnnDataType, TensorDescriptor}; pub use self::pooling::{PoolingDescriptor, PoolingForward}; pub use self::reduce::{FlatIndices, NoIndices, ReduceTensor, ReductionDescriptor}; pub use super::result::CudnnError; +pub use activation::{ActivationDescriptor, ActivationForward}; #[cfg(test)] mod tests { From 2ac8dbf1a809ac4fe1dff2142ae64c3bafbd848f Mon Sep 17 00:00:00 2001 From: Miguel Date: Fri, 19 Jul 2024 18:12:04 -0400 Subject: [PATCH 06/11] Add descriptors to forward object --- src/cudnn/safe/activation.rs | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs index 9234f8f6..bf83be5c 100644 --- a/src/cudnn/safe/activation.rs +++ b/src/cudnn/safe/activation.rs @@ -1,40 +1,42 @@ -use crate::cudnn::{result, sys, Cudnn, CudnnDataType, CudnnError}; +use crate::cudnn::{result, sys, Cudnn, CudnnDataType, CudnnError, TensorDescriptor}; use crate::driver::{DevicePtr, DevicePtrMut}; use core::marker::PhantomData; use std::sync::Arc; -pub struct ActivationForward<'a, A: CudnnDataType> { +pub struct ActivationForward<'a, A: CudnnDataType, X: CudnnDataType, Y: CudnnDataType> { /// Activation function. pub act: &'a ActivationDescriptor, + pub x: &'a TensorDescriptor, + pub y: &'a TensorDescriptor, } -impl<'a, T> ActivationForward<'a, T> +impl<'a, A, X, Y> ActivationForward<'a, A, X, Y> where - T: CudnnDataType, + A: CudnnDataType, + X: CudnnDataType, + Y: CudnnDataType, { pub fn launch( &self, - (alpha, beta): (T, T), - x_desc: sys::cudnnTensorDescriptor_t, + (alpha, beta): (Y, Y), x: &Src, - y_desc: sys::cudnnTensorDescriptor_t, y: &mut Dst, ) -> Result<(), CudnnError> where - Src: DevicePtr, - Dst: DevicePtrMut, + Src: DevicePtr, + Dst: DevicePtrMut, { let alpha = alpha.into_scaling_parameter(); let beta = beta.into_scaling_parameter(); result::activation_forward( self.act.handle.handle, self.act.desc, - (&alpha) as *const T::Scalar as *const std::ffi::c_void, - x_desc, - *x.device_ptr() as *const T as *const std::ffi::c_void, - (&beta) as *const T::Scalar as *const std::ffi::c_void, - y_desc, - *y.device_ptr_mut() as *mut T as *mut std::ffi::c_void, + (&alpha) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *x.device_ptr() as *const X as *const std::ffi::c_void, + (&beta) as *const Y::Scalar as *const std::ffi::c_void, + self.y.desc, + *y.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, ) } } From 5581b94d744189754f988c22c68f05462954da17 Mon Sep 17 00:00:00 2001 From: Miguel Date: Wed, 7 Aug 2024 16:47:28 -0400 Subject: [PATCH 07/11] Keep names consistent --- src/cudnn/safe/activation.rs | 52 ++++++++++++++++++------------------ src/cudnn/safe/pooling.rs | 20 +++++++------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs index bf83be5c..1e2a6e47 100644 --- a/src/cudnn/safe/activation.rs +++ b/src/cudnn/safe/activation.rs @@ -3,6 +3,32 @@ use crate::driver::{DevicePtr, DevicePtrMut}; use core::marker::PhantomData; use std::sync::Arc; +#[derive(Debug)] +pub struct ActivationDescriptor { + pub(crate) desc: sys::cudnnActivationDescriptor_t, + #[allow(unused)] + pub(crate) handle: Arc, + pub(crate) marker: PhantomData, +} + +impl Cudnn { + pub fn create_activation( + self: &Arc, + mode: sys::cudnnActivationMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + coef: f64, + ) -> Result, CudnnError> { + let desc = result::create_activation_descriptor()?; + let desc = ActivationDescriptor { + desc, + handle: self.clone(), + marker: PhantomData, + }; + result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?; + Ok(desc) + } +} + pub struct ActivationForward<'a, A: CudnnDataType, X: CudnnDataType, Y: CudnnDataType> { /// Activation function. pub act: &'a ActivationDescriptor, @@ -40,29 +66,3 @@ where ) } } - -#[derive(Debug)] -pub struct ActivationDescriptor { - pub(crate) desc: sys::cudnnActivationDescriptor_t, - #[allow(unused)] - pub(crate) handle: Arc, - pub(crate) marker: PhantomData, -} - -impl Cudnn { - pub fn create_activation( - self: &Arc, - mode: sys::cudnnActivationMode_t, - nan_propagation: sys::cudnnNanPropagation_t, - coef: f64, - ) -> Result, CudnnError> { - let desc = result::create_activation_descriptor()?; - let desc = ActivationDescriptor { - desc, - handle: self.clone(), - marker: PhantomData, - }; - result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?; - Ok(desc) - } -} diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs index 39c4c9f4..956ffe98 100644 --- a/src/cudnn/safe/pooling.rs +++ b/src/cudnn/safe/pooling.rs @@ -16,7 +16,7 @@ pub struct PoolingDescriptor { impl Cudnn { pub fn create_poolingnd( self: &Arc, - input: &[std::ffi::c_int], + filter: &[std::ffi::c_int], pads: &[std::ffi::c_int], strides: &[std::ffi::c_int], mode: sys::cudnnPoolingMode_t, @@ -33,8 +33,8 @@ impl Cudnn { desc.desc, mode, nan_propagation, - input.len() as std::ffi::c_int, - input, + filter.len() as std::ffi::c_int, + filter, pads, strides, )?; @@ -55,15 +55,15 @@ where X: CudnnDataType, Y: CudnnDataType, { - pub fn launch( + pub fn launch( &self, (alpha, beta): (Y, Y), - input: &Input, - output: &mut Output, + src: &Src, + y: &mut Dst, ) -> Result<(), CudnnError> where - Input: DevicePtr, - Output: DevicePtrMut, + Src: DevicePtr, + Dst: DevicePtrMut, { let alpha = alpha.into_scaling_parameter(); let beta = beta.into_scaling_parameter(); @@ -72,10 +72,10 @@ where self.pooling.desc, (&alpha) as *const Y::Scalar as *const std::ffi::c_void, self.x.desc, - *input.device_ptr() as *const X as *const std::ffi::c_void, + *src.device_ptr() as *const X as *const std::ffi::c_void, (&beta) as *const Y::Scalar as *const std::ffi::c_void, self.y.desc, - *output.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + *y.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, ) } } From 2007812d9ec37b9b92838cf838812699c193d193 Mon Sep 17 00:00:00 2001 From: Miguel Date: Tue, 3 Dec 2024 16:01:53 -0500 Subject: [PATCH 08/11] Add docs --- src/cudnn/result.rs | 8 ++++---- src/cudnn/safe/activation.rs | 13 ++++++++++++- src/cudnn/safe/conv.rs | 15 +++++++++++++++ src/cudnn/safe/pooling.rs | 20 ++++++++++++++++---- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 36b69d8d..b9d37369 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -839,8 +839,8 @@ pub fn set_pooling_descriptor( desc: sys::cudnnPoolingDescriptor_t, mode: sys::cudnnPoolingMode_t, nan_propagation: sys::cudnnNanPropagation_t, - dims: std::ffi::c_int, - input: &[std::ffi::c_int], + nb_dims: std::ffi::c_int, + window_dims: &[std::ffi::c_int], pads: &[std::ffi::c_int], strides: &[std::ffi::c_int], ) -> Result<(), CudnnError> { @@ -850,8 +850,8 @@ pub fn set_pooling_descriptor( desc, mode, nan_propagation, - dims, - input.as_ptr(), + nb_dims, + window_dims.as_ptr(), pads.as_ptr(), strides.as_ptr(), ) diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs index 1e2a6e47..18e47d7e 100644 --- a/src/cudnn/safe/activation.rs +++ b/src/cudnn/safe/activation.rs @@ -1,8 +1,9 @@ -use crate::cudnn::{result, sys, Cudnn, CudnnDataType, CudnnError, TensorDescriptor}; +use crate::cudnn::{result, sys, ConvForward, Cudnn, CudnnDataType, CudnnError, TensorDescriptor}; use crate::driver::{DevicePtr, DevicePtrMut}; use core::marker::PhantomData; use std::sync::Arc; +/// A descriptor of the activation operation. Create with [`Cudnn::create_activation()`] #[derive(Debug)] pub struct ActivationDescriptor { pub(crate) desc: sys::cudnnActivationDescriptor_t, @@ -29,6 +30,8 @@ impl Cudnn { } } +/// The activation forward operation. Pass in references to descriptors +/// directly, and then call [`ConvForward::launch()`] . pub struct ActivationForward<'a, A: CudnnDataType, X: CudnnDataType, Y: CudnnDataType> { /// Activation function. pub act: &'a ActivationDescriptor, @@ -42,6 +45,14 @@ where X: CudnnDataType, Y: CudnnDataType, { + /// Launches the operation. + /// + /// - `src` is the input tensor + /// - `y` is the output + /// + /// # Safety + /// The arguments must match the data type/layout specified in the + /// descriptors in `self. pub fn launch( &self, (alpha, beta): (Y, Y), diff --git a/src/cudnn/safe/conv.rs b/src/cudnn/safe/conv.rs index 56073636..1bab263a 100644 --- a/src/cudnn/safe/conv.rs +++ b/src/cudnn/safe/conv.rs @@ -527,6 +527,12 @@ impl<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardFilte } } +/// The bias + convolution + activation forward operation. +/// The full computation follows the equation `y = act (alpha1 * conv(x) + alpha2 * z + bias)`. +/// Pass in references to descriptors directly, and then call: +/// 1. [`ConvForward::pick_algorithm()`] to use cudnn heuristics to select the algorithm +/// 2. [`ConvForward::get_workspace_size()`] to get required workspace size. +/// 3. [`ConvForward::launch()`] to execute it #[derive(Debug)] pub struct ConvBiasActivationForward< 'a, @@ -583,6 +589,15 @@ where conv.get_workspace_size(algo) } + /// Launches the operation. + /// + /// - `src` is the input tensor + /// - `filter` is the convolution kernels + /// - `y` is the output + /// + /// # Safety + /// The src/filter/y arguments must match the data type/layout specified in the + /// descriptors in `self. pub unsafe fn launch( &self, algo: sys::cudnnConvolutionFwdAlgo_t, diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs index 956ffe98..6d858ce6 100644 --- a/src/cudnn/safe/pooling.rs +++ b/src/cudnn/safe/pooling.rs @@ -3,9 +3,10 @@ use crate::{ driver::{DevicePtr, DevicePtrMut}, }; -use crate::cudnn::{result, Cudnn, CudnnDataType, TensorDescriptor}; +use crate::cudnn::{result, ConvForward, Cudnn, CudnnDataType, TensorDescriptor}; use std::{marker::PhantomData, sync::Arc}; +/// A descriptor of the window for pooling operation. Create with [`Cudnn::create_poolingnd()`] pub struct PoolingDescriptor { desc: sys::cudnnPoolingDescriptor_t, #[allow(unused)] @@ -14,9 +15,10 @@ pub struct PoolingDescriptor { } impl Cudnn { + /// Create a window nd descriptor. pub fn create_poolingnd( self: &Arc, - filter: &[std::ffi::c_int], + window: &[std::ffi::c_int], pads: &[std::ffi::c_int], strides: &[std::ffi::c_int], mode: sys::cudnnPoolingMode_t, @@ -33,8 +35,8 @@ impl Cudnn { desc.desc, mode, nan_propagation, - filter.len() as std::ffi::c_int, - filter, + window.len() as std::ffi::c_int, + window, pads, strides, )?; @@ -43,6 +45,8 @@ impl Cudnn { } } +/// The pooling forward operation. Pass in references to descriptors +/// directly, and then call [`PoolingForward::launch()`]. pub struct PoolingForward<'a, P, X, Y> { pub pooling: &'a PoolingDescriptor

, pub x: &'a TensorDescriptor, @@ -55,6 +59,14 @@ where X: CudnnDataType, Y: CudnnDataType, { + /// Launches the operation. + /// + /// - `src` is the input tensor + /// - `y` is the output + /// + /// # Safety + /// The arguments must match the data type/layout specified in the + /// descriptors in `self. pub fn launch( &self, (alpha, beta): (Y, Y), From 42787883ba842d601425584f1668d7c5797b3a45 Mon Sep 17 00:00:00 2001 From: Miguel Date: Tue, 3 Dec 2024 17:11:50 -0500 Subject: [PATCH 09/11] Add test for the conv + bias + act operation --- src/cudnn/safe/activation.rs | 2 +- src/cudnn/safe/mod.rs | 81 ++++++++++++++++++++++++++++++++++++ src/cudnn/safe/pooling.rs | 2 +- 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs index 18e47d7e..48bc0c83 100644 --- a/src/cudnn/safe/activation.rs +++ b/src/cudnn/safe/activation.rs @@ -1,4 +1,4 @@ -use crate::cudnn::{result, sys, ConvForward, Cudnn, CudnnDataType, CudnnError, TensorDescriptor}; +use crate::cudnn::{result, sys, Cudnn, CudnnDataType, CudnnError, TensorDescriptor}; use crate::driver::{DevicePtr, DevicePtrMut}; use core::marker::PhantomData; use std::sync::Arc; diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index 1d6a9858..a668168f 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -293,4 +293,85 @@ mod tests { assert_eq!(c_host.len(), 1); assert_eq!(c_host[0], 21.0); } + + #[test] + fn test_conv_bias_activation() -> Result<(), CudnnError> { + let dev = CudaDevice::new(0).unwrap(); + let cudnn = Cudnn::new(dev.clone())?; + + let conv = cudnn.create_convnd::( + &[0; 3], + &[1; 3], + &[1; 3], + cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION, + )?; + + // Create input, filter and output tensors + let x = dev.htod_copy(vec![1.0f32; 32 * 3 * 64 * 64 * 64]).unwrap(); + let x_desc = cudnn.create_nd_tensor::( + &[32, 3, 64, 64, 64], + &[3 * 64 * 64 * 64, 64 * 64 * 64, 64 * 64, 64, 1], + )?; + let filter = dev.htod_copy(vec![1.0f32; 32 * 3 * 4 * 4 * 4]).unwrap(); + let filter_desc = cudnn.create_nd_filter::( + cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, + &[32, 3, 4, 4, 4], + )?; + let bias = dev.htod_copy(vec![1.0f32; 32]).unwrap(); + let bias_desc = cudnn.create_nd_tensor::(&[1, 32, 1, 1, 1], &[32, 1, 1, 1, 1])?; + let activation_desc = cudnn.create_activation::( + cudnn::sys::cudnnActivationMode_t::CUDNN_ACTIVATION_RELU, + cudnn::sys::cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, + f64::MAX, + )?; + let z = dev.htod_copy(vec![0.0f32; 32 * 32 * 61 * 61 * 61]).unwrap(); + let z_desc = cudnn.create_nd_tensor::( + &[32, 32, 61, 61, 61], + &[32 * 61 * 61 * 61, 61 * 61 * 61, 61 * 61, 61, 1], + )?; + let mut y = dev.alloc_zeros::(32 * 32 * 61 * 61 * 61).unwrap(); + let y_desc = cudnn.create_nd_tensor::( + &[32, 32, 61, 61, 61], + &[32 * 61 * 61 * 61, 61 * 61 * 61, 61 * 61, 61, 1], + )?; + + { + let op = ConvBiasActivationForward { + conv: &conv, + act: &activation_desc, + x: &x_desc, + w: &filter_desc, + y: &y_desc, + z: &z_desc, + bias: &bias_desc, + }; + + // Pick algorithm + let algo = op.pick_algorithm()?; + + // Get workspace size + let workspace_size = op.get_workspace_size(algo)?; + let mut workspace = dev.alloc_zeros::(workspace_size).unwrap(); + + // Launch conv operation + unsafe { + op.launch( + algo, + Some(&mut workspace), + (1.0, 0.0), + &x, + &filter, + &z, + &bias, + &mut y, + )?; + } + + let y_host = dev.sync_reclaim(y).unwrap(); + assert_eq!(y_host.len(), 32 * 32 * 61 * 61 * 61); + assert_eq!(y_host[0], 3.0 * 4.0 * 4.0 * 4.0 + 1.0); + } + + Ok(()) + } } diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs index 6d858ce6..07612fe2 100644 --- a/src/cudnn/safe/pooling.rs +++ b/src/cudnn/safe/pooling.rs @@ -3,7 +3,7 @@ use crate::{ driver::{DevicePtr, DevicePtrMut}, }; -use crate::cudnn::{result, ConvForward, Cudnn, CudnnDataType, TensorDescriptor}; +use crate::cudnn::{result, Cudnn, CudnnDataType, TensorDescriptor}; use std::{marker::PhantomData, sync::Arc}; /// A descriptor of the window for pooling operation. Create with [`Cudnn::create_poolingnd()`] From bf86b4c2e51fec59197d66a945fa5703d08d66c6 Mon Sep 17 00:00:00 2001 From: Miguel Date: Tue, 3 Dec 2024 17:49:15 -0500 Subject: [PATCH 10/11] Add test for the pooling operation --- src/cudnn/safe/mod.rs | 43 +++++++++++++++++++++++++++++++++++++++ src/cudnn/safe/pooling.rs | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index a668168f..67a80480 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -374,4 +374,47 @@ mod tests { Ok(()) } + + #[test] + fn test_pooling() -> Result<(), CudnnError> { + let dev = CudaDevice::new(0).unwrap(); + let cudnn = Cudnn::new(dev.clone())?; + + let pooling = cudnn.create_poolingnd::( + &[2, 2], + &[0, 0], + &[2, 2], + cudnn::sys::cudnnPoolingMode_t::CUDNN_POOLING_MAX, + cudnn::sys::cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN, + )?; + + // Create input, filter and output tensors + let x = dev + .htod_copy(vec![ + 1.0, 1.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, + ]) + .unwrap(); + let x_desc = cudnn.create_nd_tensor::(&[32, 3, 4, 4], &[32 * 3 * 4, 3 * 4, 4, 1])?; + let mut y = dev.alloc_zeros::(32 * 3 * 2 * 2).unwrap(); + let y_desc = cudnn.create_nd_tensor::(&[32, 3, 2, 2], &[3 * 2 * 2, 2 * 2, 2, 1])?; + + { + let op = PoolingForward { + pooling: &pooling, + x: &x_desc, + y: &y_desc, + }; + + // Launch conv operation + unsafe { + op.launch((1.0, 0.0), &x, &mut y)?; + } + + let y_host = dev.sync_reclaim(y).unwrap(); + assert_eq!(y_host.len(), 32 * 3 * 2 * 2); + assert_eq!(y_host[0], 6.0); + } + + Ok(()) + } } diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs index 07612fe2..089c6e36 100644 --- a/src/cudnn/safe/pooling.rs +++ b/src/cudnn/safe/pooling.rs @@ -67,7 +67,7 @@ where /// # Safety /// The arguments must match the data type/layout specified in the /// descriptors in `self. - pub fn launch( + pub unsafe fn launch( &self, (alpha, beta): (Y, Y), src: &Src, From 802a1065b15e490fea219af91f6007bf05dd6ad0 Mon Sep 17 00:00:00 2001 From: Miguel Date: Tue, 3 Dec 2024 18:02:53 -0500 Subject: [PATCH 11/11] Add test for the activation operation --- src/cudnn/safe/activation.rs | 2 +- src/cudnn/safe/mod.rs | 40 ++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs index 48bc0c83..a1c6af23 100644 --- a/src/cudnn/safe/activation.rs +++ b/src/cudnn/safe/activation.rs @@ -53,7 +53,7 @@ where /// # Safety /// The arguments must match the data type/layout specified in the /// descriptors in `self. - pub fn launch( + pub unsafe fn launch( &self, (alpha, beta): (Y, Y), x: &Src, diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index 67a80480..66e2ee69 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -417,4 +417,44 @@ mod tests { Ok(()) } + + #[test] + fn test_activation() -> Result<(), CudnnError> { + let dev = CudaDevice::new(0).unwrap(); + let cudnn = Cudnn::new(dev.clone())?; + + let act = cudnn.create_activation::( + cudnn::sys::cudnnActivationMode_t::CUDNN_ACTIVATION_RELU, + cudnn::sys::cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, + f64::MAX, + )?; + + // Create input, filter and output tensors + let x = dev.htod_copy(vec![-1.0, 2.0, -3.0, 100.0]).unwrap(); + let x_desc = cudnn.create_nd_tensor::(&[1, 1, 2, 2], &[2 * 2, 2 * 2, 2, 1])?; + let mut y = dev.alloc_zeros::(4).unwrap(); + let y_desc = cudnn.create_nd_tensor::(&[1, 1, 2, 2], &[2 * 2, 2 * 2, 2, 1])?; + + { + let op = ActivationForward { + act: &act, + x: &x_desc, + y: &y_desc, + }; + + // Launch conv operation + unsafe { + op.launch((1.0, 0.0), &x, &mut y)?; + } + + let y_host = dev.sync_reclaim(y).unwrap(); + assert_eq!(y_host.len(), 2 * 2); + assert_eq!(y_host[0], 0.0); + assert_eq!(y_host[1], 2.0); + assert_eq!(y_host[2], 0.0); + assert_eq!(y_host[3], 100.0); + } + + Ok(()) + } }