diff --git a/src/cudnn/result.rs b/src/cudnn/result.rs index 0269d361..b9d37369 100644 --- a/src/cudnn/result.rs +++ b/src/cudnn/result.rs @@ -447,6 +447,51 @@ pub unsafe fn convolution_forward( .result() } +#[allow(clippy::too_many_arguments)] +pub unsafe fn convolution_bias_activation_forward( + handle: sys::cudnnHandle_t, + alpha1: *const ::core::ffi::c_void, + x_desc: sys::cudnnTensorDescriptor_t, + x: *const ::core::ffi::c_void, + w_desc: sys::cudnnFilterDescriptor_t, + w: *const ::core::ffi::c_void, + conv_desc: sys::cudnnConvolutionDescriptor_t, + algo: sys::cudnnConvolutionFwdAlgo_t, + work_space: *mut ::core::ffi::c_void, + work_space_size_in_bytes: usize, + alpha2: *const ::core::ffi::c_void, + z_desc: sys::cudnnTensorDescriptor_t, + z: *const ::core::ffi::c_void, + bias_desc: sys::cudnnTensorDescriptor_t, + bias: *const ::core::ffi::c_void, + activation_desc: sys::cudnnActivationDescriptor_t, + y_desc: sys::cudnnTensorDescriptor_t, + y: *mut ::core::ffi::c_void, +) -> Result<(), CudnnError> { + lib() + .cudnnConvolutionBiasActivationForward( + handle, + alpha1, + x_desc, + x, + w_desc, + w, + conv_desc, + algo, + work_space, + work_space_size_in_bytes, + alpha2, + z_desc, + z, + bias_desc, + bias, + activation_desc, + y_desc, + y, + ) + .result() +} + /// See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetConvolutionBackwardDataAlgorithm_v7) /// /// # Safety @@ -779,3 +824,94 @@ pub unsafe fn reduce_tensor( ) .result() } + +pub fn create_pooling_descriptor() -> Result { + let mut desc = MaybeUninit::uninit(); + unsafe { + lib() + .cudnnCreatePoolingDescriptor(desc.as_mut_ptr()) + .result()?; + Ok(desc.assume_init()) + } +} + +pub fn set_pooling_descriptor( + desc: sys::cudnnPoolingDescriptor_t, + mode: sys::cudnnPoolingMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + nb_dims: std::ffi::c_int, + window_dims: &[std::ffi::c_int], + pads: &[std::ffi::c_int], + strides: &[std::ffi::c_int], +) -> Result<(), CudnnError> 
{ + unsafe { + lib() + .cudnnSetPoolingNdDescriptor( + desc, + mode, + nan_propagation, + nb_dims, + window_dims.as_ptr(), + pads.as_ptr(), + strides.as_ptr(), + ) + .result() + } +} + +pub fn pooling_forward( + handle: sys::cudnnHandle_t, + pooling_desc: sys::cudnnPoolingDescriptor_t, + alpha: *const ::core::ffi::c_void, + x_desc: sys::cudnnTensorDescriptor_t, + x: *const ::core::ffi::c_void, + beta: *const ::core::ffi::c_void, + y_desc: sys::cudnnTensorDescriptor_t, + y: *mut ::core::ffi::c_void, +) -> Result<(), CudnnError> { + unsafe { + lib() + .cudnnPoolingForward(handle, pooling_desc, alpha, x_desc, x, beta, y_desc, y) + .result() + } +} + +pub fn create_activation_descriptor() -> Result { + let mut desc = MaybeUninit::uninit(); + unsafe { + lib() + .cudnnCreateActivationDescriptor(desc.as_mut_ptr()) + .result()?; + Ok(desc.assume_init()) + } +} + +pub fn set_activation_descriptor( + desc: sys::cudnnActivationDescriptor_t, + mode: sys::cudnnActivationMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + coef: f64, +) -> Result<(), CudnnError> { + unsafe { + lib() + .cudnnSetActivationDescriptor(desc, mode, nan_propagation, coef) + .result() + } +} + +pub fn activation_forward( + handle: sys::cudnnHandle_t, + activation_desc: sys::cudnnActivationDescriptor_t, + alpha: *const ::core::ffi::c_void, + x_desc: sys::cudnnTensorDescriptor_t, + x: *const ::core::ffi::c_void, + beta: *const ::core::ffi::c_void, + y_desc: sys::cudnnTensorDescriptor_t, + y: *mut ::core::ffi::c_void, +) -> Result<(), CudnnError> { + unsafe { + lib() + .cudnnActivationForward(handle, activation_desc, alpha, x_desc, x, beta, y_desc, y) + .result() + } +} diff --git a/src/cudnn/safe/activation.rs b/src/cudnn/safe/activation.rs new file mode 100644 index 00000000..a1c6af23 --- /dev/null +++ b/src/cudnn/safe/activation.rs @@ -0,0 +1,79 @@ +use crate::cudnn::{result, sys, Cudnn, CudnnDataType, CudnnError, TensorDescriptor}; +use crate::driver::{DevicePtr, DevicePtrMut}; +use 
core::marker::PhantomData;
+use std::sync::Arc;
+
+/// A descriptor of the activation operation. Create with [`Cudnn::create_activation()`]
+#[derive(Debug)]
+pub struct ActivationDescriptor<A> {
+    pub(crate) desc: sys::cudnnActivationDescriptor_t,
+    #[allow(unused)]
+    pub(crate) handle: Arc<Cudnn>,
+    pub(crate) marker: PhantomData<A>,
+}
+
+impl Cudnn {
+    pub fn create_activation<A: CudnnDataType>(
+        self: &Arc<Cudnn>,
+        mode: sys::cudnnActivationMode_t,
+        nan_propagation: sys::cudnnNanPropagation_t,
+        coef: f64,
+    ) -> Result<ActivationDescriptor<A>, CudnnError> {
+        let desc = result::create_activation_descriptor()?;
+        let desc = ActivationDescriptor {
+            desc,
+            handle: self.clone(),
+            marker: PhantomData,
+        };
+        result::set_activation_descriptor(desc.desc, mode, nan_propagation, coef)?;
+        Ok(desc)
+    }
+}
+
+/// The activation forward operation. Pass in references to descriptors
+/// directly, and then call [`ActivationForward::launch()`].
+pub struct ActivationForward<'a, A: CudnnDataType, X: CudnnDataType, Y: CudnnDataType> {
+    /// Activation function.
+    pub act: &'a ActivationDescriptor<A>,
+    pub x: &'a TensorDescriptor<X>,
+    pub y: &'a TensorDescriptor<Y>,
+}
+
+impl<'a, A, X, Y> ActivationForward<'a, A, X, Y>
+where
+    A: CudnnDataType,
+    X: CudnnDataType,
+    Y: CudnnDataType,
+{
+    /// Launches the operation.
+    ///
+    /// - `src` is the input tensor
+    /// - `y` is the output
+    ///
+    /// # Safety
+    /// The arguments must match the data type/layout specified in the
+    /// descriptors in `self`.
+ pub unsafe fn launch( + &self, + (alpha, beta): (Y, Y), + x: &Src, + y: &mut Dst, + ) -> Result<(), CudnnError> + where + Src: DevicePtr, + Dst: DevicePtrMut, + { + let alpha = alpha.into_scaling_parameter(); + let beta = beta.into_scaling_parameter(); + result::activation_forward( + self.act.handle.handle, + self.act.desc, + (&alpha) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *x.device_ptr() as *const X as *const std::ffi::c_void, + (&beta) as *const Y::Scalar as *const std::ffi::c_void, + self.y.desc, + *y.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + ) + } +} diff --git a/src/cudnn/safe/conv.rs b/src/cudnn/safe/conv.rs index a6d2aba5..1bab263a 100644 --- a/src/cudnn/safe/conv.rs +++ b/src/cudnn/safe/conv.rs @@ -4,6 +4,7 @@ use crate::{ driver::{DevicePtr, DevicePtrMut}, }; +use crate::cudnn::safe::activation::ActivationDescriptor; use std::{marker::PhantomData, sync::Arc}; /// A descriptor of the filters for conv operation. Create with [`Cudnn::create_4d_filter()`] @@ -525,3 +526,123 @@ impl<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardFilte ) } } + +/// The bias + convolution + activation forward operation. +/// The full computation follows the equation `y = act (alpha1 * conv(x) + alpha2 * z + bias)`. +/// Pass in references to descriptors directly, and then call: +/// 1. [`ConvForward::pick_algorithm()`] to use cudnn heuristics to select the algorithm +/// 2. [`ConvForward::get_workspace_size()`] to get required workspace size. +/// 3. [`ConvForward::launch()`] to execute it +#[derive(Debug)] +pub struct ConvBiasActivationForward< + 'a, + X: CudnnDataType, + C: CudnnDataType, + A: CudnnDataType, + Y: CudnnDataType, +> { + /// Conv parameters. + pub conv: &'a ConvDescriptor, + /// Activation function. 
+ pub act: &'a ActivationDescriptor, + /// Input tensor descriptor + pub x: &'a TensorDescriptor, + /// Filter descriptor + pub w: &'a FilterDescriptor, + /// Z descriptor + pub z: &'a TensorDescriptor, + /// Bias descriptor + pub bias: &'a TensorDescriptor, + /// Output tensor descriptor + pub y: &'a TensorDescriptor, +} + +impl<'a, X, C, A, Y> ConvBiasActivationForward<'a, X, C, A, Y> +where + X: CudnnDataType, + C: CudnnDataType, + A: CudnnDataType, + Y: CudnnDataType, +{ + /// Picks the fastest algorithm from all available cuDNN algorithms based on cudnn heuristics. + pub fn pick_algorithm(&self) -> Result { + let conv = ConvForward { + conv: self.conv, + x: self.x, + w: self.w, + y: self.y, + }; + conv.pick_algorithm() + } + + /// Returns size in **bytes** to execute the selected algorithm. + pub fn get_workspace_size( + &self, + algo: sys::cudnnConvolutionFwdAlgo_t, + ) -> Result { + let conv = ConvForward { + conv: self.conv, + x: self.x, + w: self.w, + y: self.y, + }; + conv.get_workspace_size(algo) + } + + /// Launches the operation. + /// + /// - `src` is the input tensor + /// - `filter` is the convolution kernels + /// - `y` is the output + /// + /// # Safety + /// The src/filter/y arguments must match the data type/layout specified in the + /// descriptors in `self. 
+ pub unsafe fn launch( + &self, + algo: sys::cudnnConvolutionFwdAlgo_t, + workspace: Option<&mut Workspace>, + (alpha1, alpha2): (Y, Y), + src: &Src, + filter: &Filter, + z: &Src, + bias: &Src, + y: &mut Dst, + ) -> Result<(), CudnnError> + where + Workspace: DevicePtrMut, + Src: DevicePtr, + Filter: DevicePtr, + Dst: DevicePtrMut, + { + let (num_bytes, workspace_ptr) = match workspace { + Some(w) => ( + w.num_bytes(), + *w.device_ptr_mut() as *mut u8 as *mut std::ffi::c_void, + ), + None => (0, std::ptr::null_mut()), + }; + let alpha1 = alpha1.into_scaling_parameter(); + let alpha2 = alpha2.into_scaling_parameter(); + result::convolution_bias_activation_forward( + self.conv.handle.handle, + (&alpha1) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *src.device_ptr() as *const X as *const std::ffi::c_void, + self.w.desc, + *filter.device_ptr() as *const X as *const std::ffi::c_void, + self.conv.desc, + algo, + workspace_ptr, + num_bytes, + (&alpha2) as *const Y::Scalar as *const std::ffi::c_void, + self.z.desc, + *z.device_ptr() as *const X as *const std::ffi::c_void, + self.bias.desc, + *bias.device_ptr() as *const X as *const std::ffi::c_void, + self.act.desc, + self.y.desc, + *y.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + ) + } +} diff --git a/src/cudnn/safe/mod.rs b/src/cudnn/safe/mod.rs index ff7fc828..66e2ee69 100644 --- a/src/cudnn/safe/mod.rs +++ b/src/cudnn/safe/mod.rs @@ -16,8 +16,10 @@ //! //! 
# Reductions +mod activation; mod conv; mod core; +mod pooling; mod reduce; #[allow(deprecated)] @@ -30,13 +32,16 @@ pub use self::conv::{ // Current APIs ConvBackwardData, ConvBackwardFilter, + ConvBiasActivationForward, ConvDescriptor, ConvForward, FilterDescriptor, }; pub use self::core::{Cudnn, CudnnDataType, TensorDescriptor}; +pub use self::pooling::{PoolingDescriptor, PoolingForward}; pub use self::reduce::{FlatIndices, NoIndices, ReduceTensor, ReductionDescriptor}; pub use super::result::CudnnError; +pub use activation::{ActivationDescriptor, ActivationForward}; #[cfg(test)] mod tests { @@ -288,4 +293,168 @@ mod tests { assert_eq!(c_host.len(), 1); assert_eq!(c_host[0], 21.0); } + + #[test] + fn test_conv_bias_activation() -> Result<(), CudnnError> { + let dev = CudaDevice::new(0).unwrap(); + let cudnn = Cudnn::new(dev.clone())?; + + let conv = cudnn.create_convnd::( + &[0; 3], + &[1; 3], + &[1; 3], + cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION, + )?; + + // Create input, filter and output tensors + let x = dev.htod_copy(vec![1.0f32; 32 * 3 * 64 * 64 * 64]).unwrap(); + let x_desc = cudnn.create_nd_tensor::( + &[32, 3, 64, 64, 64], + &[3 * 64 * 64 * 64, 64 * 64 * 64, 64 * 64, 64, 1], + )?; + let filter = dev.htod_copy(vec![1.0f32; 32 * 3 * 4 * 4 * 4]).unwrap(); + let filter_desc = cudnn.create_nd_filter::( + cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, + &[32, 3, 4, 4, 4], + )?; + let bias = dev.htod_copy(vec![1.0f32; 32]).unwrap(); + let bias_desc = cudnn.create_nd_tensor::(&[1, 32, 1, 1, 1], &[32, 1, 1, 1, 1])?; + let activation_desc = cudnn.create_activation::( + cudnn::sys::cudnnActivationMode_t::CUDNN_ACTIVATION_RELU, + cudnn::sys::cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, + f64::MAX, + )?; + let z = dev.htod_copy(vec![0.0f32; 32 * 32 * 61 * 61 * 61]).unwrap(); + let z_desc = cudnn.create_nd_tensor::( + &[32, 32, 61, 61, 61], + &[32 * 61 * 61 * 61, 61 * 61 * 61, 61 * 61, 61, 1], + )?; + let mut y = dev.alloc_zeros::(32 * 32 
* 61 * 61 * 61).unwrap(); + let y_desc = cudnn.create_nd_tensor::( + &[32, 32, 61, 61, 61], + &[32 * 61 * 61 * 61, 61 * 61 * 61, 61 * 61, 61, 1], + )?; + + { + let op = ConvBiasActivationForward { + conv: &conv, + act: &activation_desc, + x: &x_desc, + w: &filter_desc, + y: &y_desc, + z: &z_desc, + bias: &bias_desc, + }; + + // Pick algorithm + let algo = op.pick_algorithm()?; + + // Get workspace size + let workspace_size = op.get_workspace_size(algo)?; + let mut workspace = dev.alloc_zeros::(workspace_size).unwrap(); + + // Launch conv operation + unsafe { + op.launch( + algo, + Some(&mut workspace), + (1.0, 0.0), + &x, + &filter, + &z, + &bias, + &mut y, + )?; + } + + let y_host = dev.sync_reclaim(y).unwrap(); + assert_eq!(y_host.len(), 32 * 32 * 61 * 61 * 61); + assert_eq!(y_host[0], 3.0 * 4.0 * 4.0 * 4.0 + 1.0); + } + + Ok(()) + } + + #[test] + fn test_pooling() -> Result<(), CudnnError> { + let dev = CudaDevice::new(0).unwrap(); + let cudnn = Cudnn::new(dev.clone())?; + + let pooling = cudnn.create_poolingnd::( + &[2, 2], + &[0, 0], + &[2, 2], + cudnn::sys::cudnnPoolingMode_t::CUDNN_POOLING_MAX, + cudnn::sys::cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN, + )?; + + // Create input, filter and output tensors + let x = dev + .htod_copy(vec![ + 1.0, 1.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, + ]) + .unwrap(); + let x_desc = cudnn.create_nd_tensor::(&[32, 3, 4, 4], &[32 * 3 * 4, 3 * 4, 4, 1])?; + let mut y = dev.alloc_zeros::(32 * 3 * 2 * 2).unwrap(); + let y_desc = cudnn.create_nd_tensor::(&[32, 3, 2, 2], &[3 * 2 * 2, 2 * 2, 2, 1])?; + + { + let op = PoolingForward { + pooling: &pooling, + x: &x_desc, + y: &y_desc, + }; + + // Launch conv operation + unsafe { + op.launch((1.0, 0.0), &x, &mut y)?; + } + + let y_host = dev.sync_reclaim(y).unwrap(); + assert_eq!(y_host.len(), 32 * 3 * 2 * 2); + assert_eq!(y_host[0], 6.0); + } + + Ok(()) + } + + #[test] + fn test_activation() -> Result<(), CudnnError> { + let dev = 
CudaDevice::new(0).unwrap(); + let cudnn = Cudnn::new(dev.clone())?; + + let act = cudnn.create_activation::( + cudnn::sys::cudnnActivationMode_t::CUDNN_ACTIVATION_RELU, + cudnn::sys::cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, + f64::MAX, + )?; + + // Create input, filter and output tensors + let x = dev.htod_copy(vec![-1.0, 2.0, -3.0, 100.0]).unwrap(); + let x_desc = cudnn.create_nd_tensor::(&[1, 1, 2, 2], &[2 * 2, 2 * 2, 2, 1])?; + let mut y = dev.alloc_zeros::(4).unwrap(); + let y_desc = cudnn.create_nd_tensor::(&[1, 1, 2, 2], &[2 * 2, 2 * 2, 2, 1])?; + + { + let op = ActivationForward { + act: &act, + x: &x_desc, + y: &y_desc, + }; + + // Launch conv operation + unsafe { + op.launch((1.0, 0.0), &x, &mut y)?; + } + + let y_host = dev.sync_reclaim(y).unwrap(); + assert_eq!(y_host.len(), 2 * 2); + assert_eq!(y_host[0], 0.0); + assert_eq!(y_host[1], 2.0); + assert_eq!(y_host[2], 0.0); + assert_eq!(y_host[3], 100.0); + } + + Ok(()) + } } diff --git a/src/cudnn/safe/pooling.rs b/src/cudnn/safe/pooling.rs new file mode 100644 index 00000000..089c6e36 --- /dev/null +++ b/src/cudnn/safe/pooling.rs @@ -0,0 +1,93 @@ +use crate::{ + cudnn::{result::CudnnError, sys}, + driver::{DevicePtr, DevicePtrMut}, +}; + +use crate::cudnn::{result, Cudnn, CudnnDataType, TensorDescriptor}; +use std::{marker::PhantomData, sync::Arc}; + +/// A descriptor of the window for pooling operation. Create with [`Cudnn::create_poolingnd()`] +pub struct PoolingDescriptor { + desc: sys::cudnnPoolingDescriptor_t, + #[allow(unused)] + handle: Arc, + marker: PhantomData, +} + +impl Cudnn { + /// Create a window nd descriptor. 
+ pub fn create_poolingnd( + self: &Arc, + window: &[std::ffi::c_int], + pads: &[std::ffi::c_int], + strides: &[std::ffi::c_int], + mode: sys::cudnnPoolingMode_t, + nan_propagation: sys::cudnnNanPropagation_t, + ) -> Result, CudnnError> { + let desc = result::create_pooling_descriptor()?; + let desc = PoolingDescriptor { + desc, + handle: self.clone(), + marker: PhantomData, + }; + + result::set_pooling_descriptor( + desc.desc, + mode, + nan_propagation, + window.len() as std::ffi::c_int, + window, + pads, + strides, + )?; + + Ok(desc) + } +} + +/// The pooling forward operation. Pass in references to descriptors +/// directly, and then call [`PoolingForward::launch()`]. +pub struct PoolingForward<'a, P, X, Y> { + pub pooling: &'a PoolingDescriptor
<P>
, + pub x: &'a TensorDescriptor, + pub y: &'a TensorDescriptor, +} + +impl<'a, P, X, Y> PoolingForward<'a, P, X, Y> +where + P: CudnnDataType, + X: CudnnDataType, + Y: CudnnDataType, +{ + /// Launches the operation. + /// + /// - `src` is the input tensor + /// - `y` is the output + /// + /// # Safety + /// The arguments must match the data type/layout specified in the + /// descriptors in `self. + pub unsafe fn launch( + &self, + (alpha, beta): (Y, Y), + src: &Src, + y: &mut Dst, + ) -> Result<(), CudnnError> + where + Src: DevicePtr, + Dst: DevicePtrMut, + { + let alpha = alpha.into_scaling_parameter(); + let beta = beta.into_scaling_parameter(); + result::pooling_forward( + self.pooling.handle.handle, + self.pooling.desc, + (&alpha) as *const Y::Scalar as *const std::ffi::c_void, + self.x.desc, + *src.device_ptr() as *const X as *const std::ffi::c_void, + (&beta) as *const Y::Scalar as *const std::ffi::c_void, + self.y.desc, + *y.device_ptr_mut() as *mut Y as *mut std::ffi::c_void, + ) + } +}