@@ -29,8 +29,8 @@ namespace caffe {
29
29
template <typename Dtype>
30
30
class CuDNNConvolutionLayer : public ConvolutionLayer <Dtype> {
31
31
public:
32
- explicit CuDNNConvolutionLayer (const LayerParameter& param)
33
- : ConvolutionLayer<Dtype>(param), handles_setup_( false ) {}
32
+ explicit CuDNNConvolutionLayer (const LayerParameter& param);
33
+
34
34
virtual void LayerSetUp (const vector<Blob<Dtype>*>& bottom,
35
35
const vector<Blob<Dtype>*>& top);
36
36
virtual void Reshape (const vector<Blob<Dtype>*>& bottom,
@@ -43,49 +43,32 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
43
43
virtual void Backward_gpu (const vector<Blob<Dtype>*>& top,
44
44
const vector<bool >& propagate_down, const vector<Blob<Dtype>*>& bottom);
45
45
46
-
47
46
bool handles_setup_;
48
47
49
48
#ifdef USE_MIOPEN
50
- miopenHandle_t* handle_;
51
- hipStream_t* stream_;
52
49
53
50
// algorithms for forward and backwards convolutions
54
- miopenConvFwdAlgorithm_t* fwd_algo_;
55
- miopenConvBwdWeightsAlgorithm_t* bwd_weight_algo_;
56
- miopenConvBwdDataAlgorithm_t* bwd_data_algo_;
51
+ vector< miopenConvFwdAlgorithm_t> fwd_algo_;
52
+ vector< miopenConvBwdWeightsAlgorithm_t> bwd_weight_algo_;
53
+ vector< miopenConvBwdDataAlgorithm_t> bwd_data_algo_;
57
54
58
55
vector<miopenTensorDescriptor_t> bottom_descs_, top_descs_;
59
56
miopenTensorDescriptor_t bias_desc_;
60
57
miopenTensorDescriptor_t filter_desc_;
61
58
vector<miopenConvolutionDescriptor_t> conv_descs_;
62
59
63
60
int N_, C_, W_, H_;
64
- #endif
65
-
66
- #ifdef USE_CUDNN
67
- cudnnHandle_t* handle_;
68
- cudaStream_t* stream_;
69
-
70
- // algorithms for forward and backwards convolutions
71
- cudnnConvolutionFwdAlgo_t *fwd_algo_;
72
- cudnnConvolutionBwdFilterAlgo_t *bwd_filter_algo_;
73
- cudnnConvolutionBwdDataAlgo_t *bwd_data_algo_;
74
-
75
- vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
76
- cudnnTensorDescriptor_t bias_desc_;
77
- cudnnFilterDescriptor_t filter_desc_;
78
- vector<cudnnConvolutionDescriptor_t> conv_descs_;
61
+ miopenHandle_t handle_;
79
62
#endif
80
63
81
64
int bottom_offset_, top_offset_, bias_offset_;
82
65
83
- size_t * workspace_fwd_sizes_;
84
- size_t *workspace_bwd_data_sizes_ ;
85
- size_t *workspace_bwd_filter_sizes_ ;
66
+ vector< size_t > workspace_fwd_sizes_;
67
+ vector< size_t > workspace_bwd_filter_sizes_ ;
68
+ vector< size_t > workspace_bwd_data_sizes_ ;
86
69
size_t workspaceSizeInBytes; // size of underlying storage
87
70
void *workspaceData; // underlying storage
88
- void ** workspace; // aliases into workspaceData
71
+ vector< void *> workspace; // aliases into workspaceData
89
72
};
90
73
#endif
91
74
0 commit comments