#include <algorithm>
#include <vector>

#include "caffe/common_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  BatchNormParameter param = this->layer_param_.batch_norm_param();
  moving_average_fraction_ = param.moving_average_fraction();
  use_global_stats_ = this->phase_ == TEST;
  if (param.has_use_global_stats())
    use_global_stats_ = param.use_global_stats();
  if (bottom[0]->num_axes() == 1)
    channels_ = 1;
  else
    channels_ = bottom[0]->shape(1);
  eps_ = param.eps();
  if (this->blobs_.size() > 0) {
    LOG(INFO) << "Skipping parameter initialization";
  } else {
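    // Three internal parameter blobs: blobs_[0] stores the per-channel
    // running mean, blobs_[1] the per-channel running second-moment
    // statistic (the squared mean is subtracted from it in Forward_cpu to
    // obtain the variance), and blobs_[2] a single scalar that accumulates
    // the moving-average weight used to renormalize the stored statistics.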
    this->blobs_.resize(3);
    vector<int> sz;
    sz.push_back(channels_);
    this->blobs_[0].reset(new Blob<Dtype>(sz));
    this->blobs_[1].reset(new Blob<Dtype>(sz));
    sz[0] = 1;
    this->blobs_[2].reset(new Blob<Dtype>(sz));
    for (int i = 0; i < 3; ++i) {
      caffe_set(this->blobs_[i]->count(), Dtype(0),
                this->blobs_[i]->mutable_cpu_data());
    }
  }
}

template <typename Dtype>
void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  if (bottom[0]->num_axes() >= 1)
    CHECK_EQ(bottom[0]->shape(1), channels_);
  top[0]->ReshapeLike(*bottom[0]);

  vector<int> sz;
  sz.push_back(channels_);
  mean_.Reshape(sz);
  variance_.Reshape(sz);
  temp_.ReshapeLike(*bottom[0]);
  x_norm_.ReshapeLike(*bottom[0]);
  sz[0] = bottom[0]->shape(0);
  batch_sum_multiplier_.Reshape(sz);

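  // batch_sum_multiplier_ (length num) and spatial_sum_multiplier_ (length
  // spatial_dim) are vectors of ones; multiplying by them via gemv/gemm sums
  // over, or broadcasts across, the batch and spatial dimensions.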
  int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0));
  if (spatial_sum_multiplier_.num_axes() == 0 ||
      spatial_sum_multiplier_.shape(0) != spatial_dim) {
    sz[0] = spatial_dim;
    spatial_sum_multiplier_.Reshape(sz);
    Dtype* multiplier_data = spatial_sum_multiplier_.mutable_cpu_data();
    caffe_set(spatial_sum_multiplier_.count(), Dtype(1), multiplier_data);
  }

  int numbychans = channels_*bottom[0]->shape(0);
  if (num_by_chans_.num_axes() == 0 ||
      num_by_chans_.shape(0) != numbychans) {
    sz[0] = numbychans;
    num_by_chans_.Reshape(sz);
    caffe_set(batch_sum_multiplier_.count(), Dtype(1),
        batch_sum_multiplier_.mutable_cpu_data());
  }
}

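// Forward pass: compute per-channel mean and variance (either from the
// current minibatch or from the stored running estimates), then produce
// top = (bottom - mean) / sqrt(var + eps). The normalized result is also
// cached in x_norm_ for use by Backward_cpu.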
template <typename Dtype>
void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = top[0]->mutable_cpu_data();
  int num = bottom[0]->shape(0);
  int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);

  // elementwise square
  caffe_powx(bottom[0]->count(), bottom_data, Dtype(2),
      temp_.mutable_cpu_data());

  if (use_global_stats_) {
    // use the stored mean/variance estimates. TODO(cdoersch): allow an option
    // to use an unbiased variance estimate, like the paper does.
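    // blobs_[2] holds the accumulated moving-average weight, so dividing the
    // accumulated sums in blobs_[0]/blobs_[1] by it yields normalized mean
    // and second-moment estimates.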
    const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0];
    caffe_cpu_scale(variance_.count(), scale_factor,
        this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data());
    caffe_cpu_scale(variance_.count(), scale_factor,
        this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data());
  } else {
    // computes variance using var(X) = E(X^2) - (EX)^2
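    // Each per-channel mean is a two-stage reduction: the first gemv sums
    // over the spatial dimensions (yielding a num*channels vector), the
    // second sums over the batch; the 1/(num*spatial_dim) factor turns the
    // sum into a mean. The same pattern is applied to the squared input
    // (temp_) to obtain E(X^2).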
    caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
        1. / (num * spatial_dim), bottom_data,
        spatial_sum_multiplier_.cpu_data(), 0.,
        num_by_chans_.mutable_cpu_data());
    caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
        num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
        mean_.mutable_cpu_data());
    caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
        1. / (num * spatial_dim), temp_.cpu_data(),
        spatial_sum_multiplier_.cpu_data(), 0.,
        num_by_chans_.mutable_cpu_data());
    caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
        num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
        variance_.mutable_cpu_data());
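    // Update the moving averages: blobs_[2] accumulates the geometric series
    // 1 + lambda + lambda^2 + ... (lambda = moving_average_fraction_), which
    // is the normalizer used above at test time; blobs_[0] and blobs_[1]
    // accumulate the correspondingly discounted statistics. The factor
    // m/(m-1) is the usual bias (Bessel) correction toward an unbiased
    // variance estimate.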
    this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
    this->blobs_[2]->mutable_cpu_data()[0] += 1;
    caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
        moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
    Dtype m = Dtype(bottom[0]->count()/channels_);
    caffe_cpu_axpby(variance_.count(), m/(m-1), variance_.cpu_data(),
        moving_average_fraction_, this->blobs_[1]->mutable_cpu_data());
  }
  // elementwise square of mean
  caffe_powx(mean_.count(), mean_.cpu_data(), Dtype(2),
      temp_.mutable_cpu_data());

  caffe_sub(mean_.count(), variance_.cpu_data(), temp_.cpu_data(),
      variance_.mutable_cpu_data());  // variance = E(X^2) - (EX)^2

  // add eps and take the square root: variance_ now holds sqrt(var(X) + eps)
  caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data());
  caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5),
      variance_.mutable_cpu_data());

  // do mean and variance normalization
  if (bottom[0] != top[0]) {
    caffe_copy(bottom[0]->count(), bottom_data, top_data);
  }
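  // Broadcast the per-channel statistics with two gemms against the ones
  // vectors: the first expands a channels-length vector to num*channels
  // rows, the second expands each row across spatial_dim columns while
  // adding into (or overwriting) the destination.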
  // subtract mean
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
      spatial_dim, 1, -1, num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 1., top_data);
  // replicate sqrt(var(X) + eps) to the input size
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data());
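  // temp_ now holds sqrt(var(X) + eps) broadcast to the full input shape;
  // dividing the mean-subtracted data by it completes the normalization.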
  caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data);
  // TODO(cdoersch): The caching is only needed because later in-place layers
  // might clobber the data. Can we skip this if they won't?
  caffe_copy(x_norm_.count(), top_data,
      x_norm_.mutable_cpu_data());
}

template <typename Dtype>
void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  CHECK(!use_global_stats_);
  const Dtype* top_diff;
  if (bottom[0] != top[0]) {
    top_diff = top[0]->cpu_diff();
  } else {
    caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff());
    top_diff = x_norm_.cpu_diff();
  }
  const Dtype* top_data = x_norm_.cpu_data();
  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
  int num = bottom[0]->shape()[0];
  int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
  // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
  //
  // dE(Y)/dX =
  //   (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
  //     ./ sqrt(var(X) + eps)
  //
  // where \cdot and ./ are Hadamard product and elementwise division,
  // respectively, dE/dY is the top diff, and mean/var/sum are all computed
  // along all dimensions except the channels dimension. In the above
  // equation, the operations allow for expansion (i.e. broadcast) along all
  // dimensions except the channels dimension where required.

  // sum(dE/dY \cdot Y)
  caffe_mul(temp_.count(), top_data, top_diff, bottom_diff);
  caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
      bottom_diff, spatial_sum_multiplier_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
      num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
      mean_.mutable_cpu_data());

  // reshape (broadcast) the above
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 0., bottom_diff);

  // sum(dE/dY \cdot Y) \cdot Y
  caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff);

  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
  caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
      top_diff, spatial_sum_multiplier_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
      num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
      mean_.mutable_cpu_data());
  // reshape (broadcast) the above to make
  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num * channels_,
      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 1., bottom_diff);

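  // bottom_diff currently holds sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y; the
  // axpby below both negates it and scales by 1/(num*spatial_dim), turning
  // the per-channel sums into the means required by the formula above.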
  // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y
  caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff,
      Dtype(-1. / (num * spatial_dim)), bottom_diff);

  // note: temp_ still contains sqrt(var(X)+eps), computed during the forward
  // pass.
  caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff);
}


#ifdef CPU_ONLY
STUB_GPU(BatchNormLayer);
#endif

INSTANTIATE_CLASS(BatchNormLayer);
REGISTER_LAYER_CLASS(BatchNorm);
}  // namespace caffe