diff --git a/LICENSE b/LICENSE index 3316218f5..4c057d39d 100644 --- a/LICENSE +++ b/LICENSE @@ -51,3 +51,29 @@ CONTRIBUTION AGREEMENT By contributing to the BVLC/caffe repository through pull-request, comment, or otherwise, the contributor releases their content to the license and copyright terms herein. + +************************************************************************ + +Faster R-CNN + +The MIT License (MIT) + +Copyright (c) 2015 Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/include/caffe/gen_anchors.hpp b/include/caffe/gen_anchors.hpp new file mode 100644 index 000000000..3b9973274 --- /dev/null +++ b/include/caffe/gen_anchors.hpp @@ -0,0 +1,75 @@ +/* +All modification made by Intel Corporation: © 2017 Intel Corporation + +All contributions by the University of California: +Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, 2015, the respective contributors +All rights reserved. +For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md + + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Intel Corporation nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef GEN_ANCHORS +#define GEN_ANCHORS + +#include + +using namespace std; + +namespace caffe { + +/** + * @brief Type of faster-rcnn anchor + */ +struct anchor { + float start_x; + float start_y; + float end_x; + float end_y; + + anchor() {} + + anchor(float s_x, float s_y, float e_x, float e_y) + { + start_x = s_x; + start_y = s_y; + end_x = e_x; + end_y = e_y; + } +}; + + +/** + * @brief Generates a vector of anchors based on a size, list of ratios and list of scales + */ +void GenerateAnchors(unsigned int base_size, const vector& ratios, const vector scales, // input + anchor *anchors); // output +} + +#endif diff --git a/include/caffe/layers/dropout_layer.hpp b/include/caffe/layers/dropout_layer.hpp index 5e35e13a3..d4b327ed9 100644 --- a/include/caffe/layers/dropout_layer.hpp +++ b/include/caffe/layers/dropout_layer.hpp @@ -103,13 +103,14 @@ class DropoutLayer : public NeuronLayer { virtual void Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); - /// when divided by UINT_MAX, the randomly generated values @f$u\sim U(0,1)@f$ + /// when divided by uint_MAX, the randomly generated values @f$u\sim U(0,1)@f$ Blob rand_vec_; /// the probability @f$ p @f$ of dropping any input Dtype threshold_; /// the scale for undropped inputs at train time @f$ 1 / (1 - p) @f$ Dtype scale_; unsigned int uint_thres_; + bool scale_train_; }; } // namespace caffe diff --git a/include/caffe/layers/fast_rcnn_layers.hpp b/include/caffe/layers/fast_rcnn_layers.hpp new file mode 100644 index 000000000..8448efbc9 --- /dev/null +++ b/include/caffe/layers/fast_rcnn_layers.hpp @@ -0,0 +1,207 @@ +// ------------------------------------------------------------------ +// Fast R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see LICENSE for details] +// Written by Ross Girshick +// ------------------------------------------------------------------ + +#ifndef CAFFE_FAST_RCNN_LAYERS_HPP_ +#define CAFFE_FAST_RCNN_LAYERS_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +//#include "caffe/loss_layers.hpp" +#include "caffe/layers/accuracy_layer.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/gen_anchors.hpp" + +namespace caffe { + +/* ROIPoolingLayer - Region of Interest Pooling Layer +*/ +template +class ROIPoolingLayer : public Layer { + public: + explicit ROIPoolingLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "ROIPooling"; } + + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 2; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline int MaxTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + Dtype spatial_scale_; + Blob max_idx_; +}; + +template +class SmoothL1LossLayer : public LossLayer { + public: + explicit SmoothL1LossLayer(const LayerParameter& param) + : LossLayer(param), diff_() {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "SmoothL1Loss"; } + + virtual inline int ExactNumBottomBlobs() const { return -1; } + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 3; } + + /** + * Unlike most loss layers, in the SmoothL1LossLayer we can backpropagate + * to both inputs -- override to return true and always allow force_backward. + */ + virtual inline bool AllowForceBackward(const int bottom_index) const { + return true; + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob diff_; + Blob errors_; + bool has_weights_; +}; + +/* SimplerNMSLayer - N Mini-batch Sampling Layer +*/ +template +class SimplerNMSLayer : public Layer { +public: + SimplerNMSLayer(const LayerParameter& param) :Layer(param), + max_proposals_(500), + prob_threshold_(0.5f), + iou_threshold_(0.7f), + min_bbox_size_(16), + feat_stride_(16), + pre_nms_topN_(6000), + post_nms_topN_(300) { + }; + + ~SimplerNMSLayer() { + } + + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual void Reshape(const vector*>& bottom, + const vector*>& top) { + //top[0]->Reshape(std::vector{ 1, 1, max_proposals_, 5 }); + top[0]->Reshape(vector{ (int)post_nms_topN_, 5 }); + } + + virtual inline const char* type() const { return "SimplerNMS"; } + +protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + +private: + int max_proposals_; + float prob_threshold_; + // TODO: add to proto + float iou_threshold_; + int min_bbox_size_; + int feat_stride_; + int pre_nms_topN_; + int post_nms_topN_; + + // relative to center point, + Blob anchors_blob_; + + //TODO: clamp is part of std as of c++17... + constexpr static inline const Dtype clamp_v(const Dtype v, const Dtype v_min, const Dtype v_max) + { + return std::max(v_min, std::min(v, v_max)); + } + struct simpler_nms_roi_t + { + Dtype x0, y0, x1, y1; + + Dtype area() const { return std::max(0, y1 - y0 + 1) * std::max(0, x1 - x0 + 1); } + simpler_nms_roi_t intersect (simpler_nms_roi_t other) const + { + return + { + std::max(x0, other.x0), + std::max(y0, other.y0), + std::min(x1, other.x1), + std::min(y1, other.y1) + }; + } + simpler_nms_roi_t clamp (simpler_nms_roi_t other) const + { + return + { + clamp_v(x0, other.x0, other.x1), + clamp_v(y0, other.y0, other.y1), + clamp_v(x1, other.x0, other.x1), + clamp_v(y1, other.y0, other.y1) + }; + } + }; + + struct simpler_nms_delta_t { Dtype shift_x, shift_y, log_w, log_h; }; + struct simpler_nms_proposal_t { simpler_nms_roi_t roi; Dtype confidence; size_t ord; }; + + static std::vector simpler_nms_perform_nms( + const std::vector& proposals, + float iou_threshold, + size_t top_n); + + static void sort_and_keep_at_most_top_n( + std::vector& proposals, + size_t top_n); + + static simpler_nms_roi_t simpler_nms_gen_bbox( + const anchor& box, + const simpler_nms_delta_t& delta, + int anchor_shift_x, + int anchor_shift_y); +}; + +} // namespace caffe + +#endif // CAFFE_FAST_RCNN_LAYERS_HPP_ diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index c23c583de..1a0bd6be2 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -37,6 +37,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // TODO (sergeyk): effect should not be dependent on phase. wasted memcpy. +#include #include #include "caffe/layers/dropout_layer.hpp" @@ -52,7 +53,11 @@ void DropoutLayer::LayerSetUp(const vector*>& bottom, DCHECK(threshold_ > 0.); DCHECK(threshold_ < 1.); scale_ = 1. / (1. - threshold_); - uint_thres_ = static_cast(UINT_MAX * threshold_); + uint_thres_ = + static_cast(static_cast + (std::numeric_limits::max()) + * static_cast(threshold_)); + scale_train_ = this->layer_param_.dropout_param().scale_train(); } template @@ -74,14 +79,28 @@ void DropoutLayer::Forward_cpu(const vector*>& bottom, if (this->phase_ == TRAIN) { // Create random numbers caffe_rng_bernoulli(count, 1. - threshold_, mask); + if (scale_train_) { #ifdef _OPENMP - #pragma omp parallel for + #pragma omp parallel for +#endif + for (unsigned int i = 0; i < count; ++i) { + top_data[i] = bottom_data[i] * mask[i] * scale_; + } + } + else { +#ifdef _OPENMP + #pragma omp parallel for #endif - for (int i = 0; i < count; ++i) { - top_data[i] = bottom_data[i] * mask[i] * scale_; + for (unsigned int i = 0; i < count; ++i) { + top_data[i] = bottom_data[i] * mask[i]; } - } else { + } + } + else { caffe_copy(bottom[0]->count(), bottom_data, top_data); + if (!scale_train_) { + caffe_scal(count, 1. / scale_, top_data); + } } } @@ -94,19 +113,29 @@ void DropoutLayer::Backward_cpu(const vector*>& top, Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); if (this->phase_ == TRAIN) { const unsigned int* mask = rand_vec_.cpu_data(); - const int count = bottom[0]->count(); + const unsigned int count = bottom[0]->count(); + if (scale_train_) { + for (unsigned int i = 0; i < count; ++i) { + bottom_diff[i] = top_diff[i] * mask[i] * scale_; + } + } + else { #ifdef _OPENMP #pragma omp parallel for #endif - for (int i = 0; i < count; ++i) { - bottom_diff[i] = top_diff[i] * mask[i] * scale_; - } - } else { + for (unsigned int i = 0; i < count; ++i) { + bottom_diff[i] = top_diff[i] * mask[i]; + } + } + } + else { caffe_copy(top[0]->count(), top_diff, bottom_diff); + if (!scale_train_) { + caffe_scal(top[0]->count(), 1. / scale_, bottom_diff); } } } - +} #ifdef CPU_ONLY STUB_GPU(DropoutLayer); diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu index 186c10ca4..96fd9f160 100644 --- a/src/caffe/layers/dropout_layer.cu +++ b/src/caffe/layers/dropout_layer.cu @@ -5,62 +5,84 @@ namespace caffe { -template + +#ifdef USE_CUDA +template __global__ void DropoutForward(const int n, const Dtype* in, - const unsigned int* mask, const unsigned int threshold, const float scale, - Dtype* out) { + const unsigned int* mask, + const unsigned int threshold, const float scale, + Dtype* out) { CUDA_KERNEL_LOOP(index, n) { out[index] = in[index] * (mask[index] > threshold) * scale; } } +#endif // USE_CUDA -template +template void DropoutLayer::Forward_gpu(const vector*>& bottom, - const vector*>& top) { + const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); const int count = bottom[0]->count(); - if (this->phase_ == TRAIN) { - unsigned int* mask = - static_cast(rand_vec_.mutable_gpu_data()); - caffe_gpu_rng_uniform(count, mask); - // set thresholds - // NOLINT_NEXT_LINE(whitespace/operators) - DropoutForward<<>>( - count, bottom_data, mask, uint_thres_, scale_, top_data); - CUDA_POST_KERNEL_CHECK; - } else { - caffe_copy(count, bottom_data, top_data); + + if (this->device_->backend() == BACKEND_CUDA) { +#ifdef USE_CUDA + if (this->phase_ == TRAIN) { + unsigned int* mask = + static_cast(rand_vec_.mutable_gpu_data()); + caffe_gpu_rng_uniform(count, (unsigned intc*) (mask)); // NOLINT + // set thresholds + // NOLINT_NEXT_LINE(whitespace/operators) + float scale_val = scale_train_ ? scale_ : 1.f; + DropoutForward CUDA_KERNEL(CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS)( + count, bottom_data, mask, uint_thres_, scale_val, top_data); + CUDA_POST_KERNEL_CHECK; + } else { + caffe_copy(count, bottom_data, top_data); + if (! scale_train_) { + caffe_scal(count, 1. / scale_, top_data); + } + } +#endif // USE_CUDA } } -template +#ifdef USE_CUDA +template __global__ void DropoutBackward(const int n, const Dtype* in_diff, - const unsigned int* mask, const unsigned int threshold, const float scale, - Dtype* out_diff) { + const unsigned int* mask, + const unsigned int threshold, const float scale, + Dtype* out_diff) { CUDA_KERNEL_LOOP(index, n) { out_diff[index] = in_diff[index] * scale * (mask[index] > threshold); } } +#endif // USE_CUDA -template +template void DropoutLayer::Backward_gpu(const vector*>& top, - const vector& propagate_down, - const vector*>& bottom) { + const vector& propagate_down, + const vector*>& bottom) { if (propagate_down[0]) { const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); - if (this->phase_ == TRAIN) { - const unsigned int* mask = - static_cast(rand_vec_.gpu_data()); - const int count = bottom[0]->count(); - // NOLINT_NEXT_LINE(whitespace/operators) - DropoutBackward<<>>( - count, top_diff, mask, uint_thres_, scale_, bottom_diff); - CUDA_POST_KERNEL_CHECK; - } else { - caffe_copy(top[0]->count(), top_diff, bottom_diff); + + if (this->device_->backend() == BACKEND_CUDA) { +#ifdef USE_CUDA + if (this->phase_ == TRAIN) { + const unsigned int* mask = static_cast(rand_vec_ + .gpu_data()); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutBackward CUDA_KERNEL(CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS)( + count, top_diff, mask, uint_thres_, scale_, bottom_diff); + CUDA_POST_KERNEL_CHECK; + } else { + caffe_copy(top[0]->count(), top_diff, bottom_diff); + } +#endif // USE_CUDA } } } diff --git a/src/caffe/layers/roi_pooling_layer.cpp b/src/caffe/layers/roi_pooling_layer.cpp new file mode 100755 index 000000000..bfc7a45cc --- /dev/null +++ b/src/caffe/layers/roi_pooling_layer.cpp @@ -0,0 +1,165 @@ +// ------------------------------------------------------------------ +// Fast R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see LICENSE for details] +// Written by Ross Girshick +// ------------------------------------------------------------------ + +#include + +#include "caffe/layers/fast_rcnn_layers.hpp" + +using std::max; +using std::min; +using std::floor; +using std::ceil; + +namespace caffe { + +template +void ROIPoolingLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + ROIPoolingParameter roi_pool_param = this->layer_param_.roi_pooling_param(); + CHECK_GT(roi_pool_param.pooled_h(), 0) + << "pooled_h must be > 0"; + CHECK_GT(roi_pool_param.pooled_w(), 0) + << "pooled_w must be > 0"; + pooled_height_ = roi_pool_param.pooled_h(); + pooled_width_ = roi_pool_param.pooled_w(); + spatial_scale_ = roi_pool_param.spatial_scale(); + LOG(INFO) << "Spatial scale: " << spatial_scale_; +} + +template +void ROIPoolingLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + top[0]->Reshape(bottom[1]->num(), channels_, pooled_height_, + pooled_width_); + max_idx_.Reshape(bottom[1]->num(), channels_, pooled_height_, + pooled_width_); +} + +template +void ROIPoolingLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* bottom_rois = bottom[1]->cpu_data(); + // Number of ROIs + int num_rois = bottom[1]->num(); + int batch_size = bottom[0]->num(); + int top_count = top[0]->count(); + Dtype* top_data = top[0]->mutable_cpu_data(); + caffe_set(top_count, Dtype(-FLT_MAX), top_data); + int* argmax_data = max_idx_.mutable_cpu_data(); + caffe_set(top_count, -1, argmax_data); + + // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R + for (int n = 0; n < num_rois; ++n) { + int roi_batch_ind = bottom_rois[0]; + int roi_start_w = round(bottom_rois[1] * spatial_scale_); + int roi_start_h = round(bottom_rois[2] * spatial_scale_); + int roi_end_w = round(bottom_rois[3] * spatial_scale_); + int roi_end_h = round(bottom_rois[4] * spatial_scale_); + CHECK_GE(roi_batch_ind, 0); + CHECK_LT(roi_batch_ind, batch_size); + + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + + const Dtype* batch_data = bottom_data + bottom[0]->offset(roi_batch_ind); + + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + // Compute pooling region for this output unit: + // start (included) = floor(ph * roi_height / pooled_height_) + // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) + + // The following computation of hstart, wstart, hend, wend is + // done with integers due to floating precision errors. + // As the floating point computing on GPU is not identical to CPU, + // integer computing is used as a workaround. + // The following approach also works but requires a rigorous + // analysis: + // int hstart = static_cast(floor((static_cast(ph) + // * static_cast(roi_height)) + // / static_cast(pooled_height_))); + // int wstart = static_cast(floor((static_cast(pw) + // * static_cast(roi_width)) + // / static_cast(pooled_width_))); + // int hend = static_cast(ceil((static_cast(ph + 1) + // * static_cast(roi_height)) + // / static_cast(pooled_height_))); + // int wend = static_cast(ceil((static_cast(pw + 1) + // * static_cast(roi_width)) + // / static_cast(pooled_width_))); + + int hstart = (ph * roi_height) / pooled_height_; + if ( (hstart * pooled_height_) > (ph * roi_height) ) { + --hstart; + } + int wstart = (pw * roi_width) / pooled_width_; + if ( (wstart * pooled_width_) > (pw * roi_width) ) { + --wstart; + } + int hend = ((ph + 1) * roi_height) / pooled_height_; + if ( (hend * pooled_height_) < ((ph + 1) * roi_height) ) { + ++hend; + } + int wend = ((pw + 1) * roi_width) / pooled_width_; + if ( (wend * pooled_width_) < ((pw + 1) * roi_width) ) { + ++wend; + } + + hstart = min(max(hstart + roi_start_h, 0), height_); + hend = min(max(hend + roi_start_h, 0), height_); + wstart = min(max(wstart + roi_start_w, 0), width_); + wend = min(max(wend + roi_start_w, 0), width_); + + bool is_empty = (hend <= hstart) || (wend <= wstart); + + const int pool_index = ph * pooled_width_ + pw; + if (is_empty) { + top_data[pool_index] = 0; + argmax_data[pool_index] = -1; + } + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int index = h * width_ + w; + if (batch_data[index] > top_data[pool_index]) { + top_data[pool_index] = batch_data[index]; + argmax_data[pool_index] = index; + } + } + } + } + } + // Increment all data pointers by one channel + batch_data += bottom[0]->offset(0, 1); + top_data += top[0]->offset(0, 1); + argmax_data += max_idx_.offset(0, 1); + } + // Increment ROI data pointer + bottom_rois += bottom[1]->offset(1); + } +} + +template +void ROIPoolingLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + //NOT_IMPLEMENTED; +} + + +#ifdef CPU_ONLY +STUB_GPU(ROIPoolingLayer); +#endif + +INSTANTIATE_CLASS(ROIPoolingLayer); +REGISTER_LAYER_CLASS(ROIPooling); + +} // namespace caffe diff --git a/src/caffe/layers/roi_pooling_layer.cu b/src/caffe/layers/roi_pooling_layer.cu new file mode 100644 index 000000000..98666dd4b --- /dev/null +++ b/src/caffe/layers/roi_pooling_layer.cu @@ -0,0 +1,190 @@ +// ------------------------------------------------------------------ +// Fast R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see LICENSE for details] +// Written by Ross Girshick +// ------------------------------------------------------------------ + +#include + +#include "caffe/layers/fast_rcnn_layers.hpp" + +using std::max; +using std::min; + +namespace caffe { + +#ifdef USE_CUDA +template +__global__ void ROIPoolForward(const int nthreads, const Dtype* bottom_data, + const Dtype spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const Dtype* bottom_rois, Dtype* top_data, int* argmax_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + int roi_start_w = round(bottom_rois[1] * spatial_scale); + int roi_start_h = round(bottom_rois[2] * spatial_scale); + int roi_end_w = round(bottom_rois[3] * spatial_scale); + int roi_end_h = round(bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + Dtype bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + Dtype bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + Dtype maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + bottom_data += (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h * width + w; + if (bottom_data[bottom_index] > maxval) { + maxval = bottom_data[bottom_index]; + maxidx = bottom_index; + } + } + } + top_data[index] = maxval; + argmax_data[index] = maxidx; + } +} + +template +void ROIPoolingLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + const Dtype* bottom_rois = bottom[1]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + int* argmax_data = max_idx_.mutable_gpu_data(); + int count = top[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + ROIPoolForward<<>>( + count, bottom_data, spatial_scale_, channels_, height_, width_, + pooled_height_, pooled_width_, bottom_rois, top_data, argmax_data); + CUDA_POST_KERNEL_CHECK; +} + +template +__global__ void ROIPoolBackward(const int nthreads, const Dtype* top_diff, + const int* argmax_data, const int num_rois, const Dtype spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, Dtype* bottom_diff, + const Dtype* bottom_rois) { + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, h, w) coords in bottom data + int w = index % width; + int h = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + + Dtype gradient = 0; + // Accumulate gradient over all ROIs that pooled this element + for (int roi_n = 0; roi_n < num_rois; ++roi_n) { + const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + // Skip if ROI's batch index doesn't match n + if (n != roi_batch_ind) { + continue; + } + + int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Skip if ROI doesn't include (h, w) + const bool in_roi = (w >= roi_start_w && w <= roi_end_w && + h >= roi_start_h && h <= roi_end_h); + if (!in_roi) { + continue; + } + + int offset = (roi_n * channels + c) * pooled_height * pooled_width; + const Dtype* offset_top_diff = top_diff + offset; + const int* offset_argmax_data = argmax_data + offset; + + // Compute feasible set of pooled units that could have pooled + // this bottom unit + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + + Dtype bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + Dtype bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int phstart = floor(static_cast(h - roi_start_h) / bin_size_h); + int phend = ceil(static_cast(h - roi_start_h + 1) / bin_size_h); + int pwstart = floor(static_cast(w - roi_start_w) / bin_size_w); + int pwend = ceil(static_cast(w - roi_start_w + 1) / bin_size_w); + + phstart = min(max(phstart, 0), pooled_height); + phend = min(max(phend, 0), pooled_height); + pwstart = min(max(pwstart, 0), pooled_width); + pwend = min(max(pwend, 0), pooled_width); + + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (offset_argmax_data[ph * pooled_width + pw] == (h * width + w)) { + gradient += offset_top_diff[ph * pooled_width + pw]; + } + } + } + } + bottom_diff[index] = gradient; + } +} + +template +void ROIPoolingLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + const Dtype* bottom_rois = bottom[1]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int count = bottom[0]->count(); + caffe_gpu_set(count, Dtype(0.), bottom_diff); + const int* argmax_data = max_idx_.gpu_data(); + // NOLINT_NEXT_LINE(whitespace/operators) + ROIPoolBackward<<>>( + count, top_diff, argmax_data, top[0]->num(), spatial_scale_, channels_, + height_, width_, pooled_height_, pooled_width_, bottom_diff, bottom_rois); + CUDA_POST_KERNEL_CHECK; +} + +INSTANTIATE_LAYER_GPU_FUNCS(ROIPoolingLayer); +#else +#endif +} // namespace caffe diff --git a/src/caffe/layers/simpler_nms_layer.cpp b/src/caffe/layers/simpler_nms_layer.cpp new file mode 100644 index 000000000..8b19d3053 --- /dev/null +++ b/src/caffe/layers/simpler_nms_layer.cpp @@ -0,0 +1,275 @@ +/* +All modification made by Intel Corporation: © 2016 Intel Corporation + +All contributions by the University of California: +Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, 2015, the respective contributors +All rights reserved. +For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md + + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Intel Corporation nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include +#include +#include +#include "caffe/gen_anchors.hpp" +#include "caffe/layers/fast_rcnn_layers.hpp" + + +namespace caffe +{ +template +void SimplerNMSLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) +{ + const SimplerNMSParameter& nms_param = this->layer_param_.simpler_nms_param(); + //if (this->layer_param_.has_simpler_nms_param()) //TODO: Check why the function returns false + //{ + max_proposals_ = nms_param.max_num_proposals(); + prob_threshold_ = nms_param.cls_threshold(); + iou_threshold_ = nms_param.iou_threshold(); + min_bbox_size_ = nms_param.min_bbox_size(); + //TODO: handle feat_stride + CHECK(nms_param.feat_stride() == 16) << this->type() << " layer currently doesn't support other feat_stride value than 16."; + feat_stride_ = nms_param.feat_stride(); + pre_nms_topN_ = nms_param.pre_nms_topn(); + post_nms_topN_ = nms_param.post_nms_topn(); + + vector scales(nms_param.scale_size()); + + for (int i = 0 ; i < nms_param.scale_size() ; i++) { + scales[i] = nms_param.scale(i); + } + + vector default_ratios(3); + default_ratios[0] = 0.5f; + default_ratios[1] = 1.0f; + default_ratios[2] = 2.0f; + + unsigned int default_size = 16; + + anchors_blob_.Reshape(default_ratios.size(), scales.size(), sizeof(anchor) / sizeof(float), 1); + anchor *anchors = (anchor*) anchors_blob_.mutable_cpu_data(); + GenerateAnchors(default_size, default_ratios, scales, anchors); +} + +template +void SimplerNMSLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) +{ + int anchors_num = anchors_blob_.shape(0) * anchors_blob_.shape(1); + const anchor* anchors = (anchor*) anchors_blob_.cpu_data(); + + // feat map sizes + int fm_w = bottom[0]->shape(3); + int fm_h = bottom[0]->shape(2); + int fm_sz = fm_w * fm_h; + + // original input image to the graph (after possible scaling etc.) so that coordinates are valid for it + int img_w = (int)bottom[2]->cpu_data()[1]; + int img_h = (int)bottom[2]->cpu_data()[0]; + + //TODO(ruv): what is it being multipied by, here?? + int scaled_min_bbox_size = min_bbox_size_ * (int)bottom[2]->cpu_data()[2]; + + const Dtype* bottom_cls_scores = bottom[0]->cpu_data(); + const Dtype* bottom_delta_pred = bottom[1]->cpu_data(); + Dtype * top_data = top[0]->mutable_cpu_data(); + + std::vector sorted_proposals_confidence; + for (unsigned y = 0; y < fm_h; ++y) + { + int anchor_shift_y = y * feat_stride_; + + for (unsigned x = 0; x < fm_w; ++x) + { + int anchor_shift_x = x * feat_stride_; + int location_index = y * fm_w + x; + + // we assume proposals are grouped by window location + for (int anchor_index = 0; anchor_index < anchors_num ; anchor_index++) + { + Dtype dx0 = bottom_delta_pred[location_index + fm_sz * (anchor_index * 4 + 0)]; + Dtype dy0 = bottom_delta_pred[location_index + fm_sz * (anchor_index * 4 + 1)]; + Dtype dx1 = bottom_delta_pred[location_index + fm_sz * (anchor_index * 4 + 2)]; + Dtype dy1 = bottom_delta_pred[location_index + fm_sz * (anchor_index * 4 + 3)]; + simpler_nms_delta_t bbox_delta { dx0, dy0, dx1, dy1 }; + + Dtype proposal_confidence = + bottom_cls_scores[location_index + fm_sz * (anchor_index + anchors_num * 1)]; + + simpler_nms_roi_t tmp_roi = simpler_nms_gen_bbox(anchors[anchor_index], bbox_delta, anchor_shift_x, anchor_shift_y); + simpler_nms_roi_t roi = tmp_roi.clamp({ 0, 0, Dtype(img_w - 1), Dtype(img_h - 1) }); + + int bbox_w = roi.x1 - roi.x0 + 1; + int bbox_h = roi.y1 - roi.y0 + 1; + + if (bbox_w >= scaled_min_bbox_size && bbox_h >= scaled_min_bbox_size) + { + simpler_nms_proposal_t proposal { roi, proposal_confidence, sorted_proposals_confidence.size() }; + sorted_proposals_confidence.push_back(proposal); + } + } + } + } + + sort_and_keep_at_most_top_n(sorted_proposals_confidence, pre_nms_topN_); + auto res = simpler_nms_perform_nms(sorted_proposals_confidence, iou_threshold_, post_nms_topN_); + + size_t res_num_rois = res.size(); + for (size_t i = 0; i < res_num_rois; ++i) + { + top_data[5 * i + 0] = 0; // roi_batch_ind, always zero on test time + top_data[5 * i + 1] = res[i].x0; + top_data[5 * i + 2] = res[i].y0; + top_data[5 * i + 3] = res[i].x1; + top_data[5 * i + 4] = res[i].y1; + } + + top[0]->Reshape(vector{ (int)res_num_rois, 5 }); +} + +template +std::vector< typename SimplerNMSLayer::simpler_nms_roi_t > +SimplerNMSLayer::simpler_nms_perform_nms( + const std::vector& proposals, + float iou_threshold, + size_t top_n) +{ +//TODO(ruv): can I mark the 1st arg, proposals as const? ifndef DONT_PRECALC_AREA, i can +//TODO(ruv): is it better to do the precalc or not? since we need to fetch the floats from memory anyway for - +// intersect calc, it's only a question of whether it's faster to do (f-f)*(f-f) or fetch another val +#define DONT_PRECALC_AREA + +#ifndef DONT_PRECALC_AREA + std::vector areas; + areas.reserve(proposals.size()); + std::transform(proposals.begin(), proposals.end(), areas.begin(), [](const simpler_nms_proposals_t>& v) + { + return v.roi.area(); + }); +#endif + + std::vector res; + res.reserve(top_n); +#ifdef DONT_PRECALC_AREA + for (const auto & prop : proposals) + { + const auto bbox = prop.roi; + const Dtype area = bbox.area(); +#else + size_t proposal_count = proposals.size(); + for (size_t proposalIndex = 0; proposalIndex < proposal_count; ++proposalIndex) + { + const auto & bbox = proposals[proposalIndex].roi; +#endif + + // For any realistic WL, this condition is true for all top_n values anyway + if (prop.confidence > 0) + { + bool overlaps = std::any_of(res.begin(), res.end(), [&](const simpler_nms_roi_t& res_bbox) + { + Dtype interArea = bbox.intersect(res_bbox).area(); +#ifdef DONT_PRECALC_AREA + Dtype unionArea = res_bbox.area() + area - interArea; +#else + Dtype unionArea = res_bbox.area() + areas[proposalIndex] - interArea; +#endif + + return interArea > iou_threshold * unionArea; + }); + + if (! overlaps) + { + res.push_back(bbox); + if (res.size() == top_n) break; + } + } + } + + return res; +} + +template +inline void SimplerNMSLayer::sort_and_keep_at_most_top_n( + std::vector& proposals, + size_t top_n) +{ + const auto cmp_fn = [](const simpler_nms_proposal_t& a, + const simpler_nms_proposal_t& b) + { + return a.confidence > b.confidence || (a.confidence == b.confidence && a.ord > b.ord); + }; + + if (proposals.size() > top_n) + { + std::partial_sort(proposals.begin(), proposals.begin() + top_n, proposals.end(), cmp_fn); + proposals.resize(top_n); + } + else + std::sort(proposals.begin(), proposals.end(), cmp_fn); +} + +template +inline typename SimplerNMSLayer::simpler_nms_roi_t +SimplerNMSLayer::simpler_nms_gen_bbox( + const anchor& box, + const simpler_nms_delta_t& delta, + int anchor_shift_x, + int anchor_shift_y) +{ + auto anchor_w = box.end_x - box.start_x + 1; + auto anchor_h = box.end_y - box.start_y + 1; + auto center_x = box.start_x + anchor_w * .5f; + auto center_y = box.start_y + anchor_h *.5f; + + Dtype pred_center_x = delta.shift_x * anchor_w + center_x + anchor_shift_x; + Dtype pred_center_y = delta.shift_y * anchor_h + center_y + anchor_shift_y; + Dtype half_pred_w = std::exp(delta.log_w) * anchor_w * .5f; + Dtype half_pred_h = std::exp(delta.log_h) * anchor_h * .5f; + + return { pred_center_x - half_pred_w, + pred_center_y - half_pred_h, + pred_center_x + half_pred_w, + pred_center_y + half_pred_h }; +} + +template +void SimplerNMSLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) +{ + //NOT_IMPLEMENTED; +}; + +#ifdef CPU_ONLY + STUB_GPU(SimplerNMSLayer); +#endif + +INSTANTIATE_CLASS(SimplerNMSLayer); +REGISTER_LAYER_CLASS(SimplerNMS); + +} diff --git a/src/caffe/layers/simpler_nms_layer.cu b/src/caffe/layers/simpler_nms_layer.cu new file mode 100644 index 000000000..a28be91cf --- /dev/null +++ b/src/caffe/layers/simpler_nms_layer.cu @@ -0,0 +1,44 @@ +/* +All modification made by Intel Corporation: © 2016 Intel Corporation + +All contributions by the University of California: +Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, 2015, the respective contributors +All rights reserved. +For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md + + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Intel Corporation nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include + +#include "caffe/layer.hpp" +#include "caffe/layers/fast_rcnn_layers.hpp" +#include "caffe/util/math_functions.hpp" + +// Not implemented \ No newline at end of file diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index c4c5228e5..1d79d5624 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -576,6 +576,8 @@ message LayerParameter { optional ReshapeParameter reshape_param = 133; optional ScaleParameter scale_param = 142; optional SigmoidParameter sigmoid_param = 124; + optional ROIPoolingParameter roi_pooling_param = 8266711; + optional SimplerNMSParameter simpler_nms_param = 8266712; optional SoftmaxParameter softmax_param = 125; optional SPPParameter spp_param = 132; optional SplitParameter split_param = 147; @@ -1142,6 +1144,7 @@ message DetectionOutputParameter { message DropoutParameter { optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase } // DummyDataLayer fills any number of arbitrarily shaped blobs with random @@ -1628,6 +1631,17 @@ message ReLUParameter { optional Engine engine = 2 [default = DEFAULT]; } +// Message that stores parameters used by ROIPoolingLayer +message ROIPoolingParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; +} + message ReshapeParameter { // Specify the output dimensions. If some of the dimensions are set to 0, // the corresponding dimension from the bottom layer is used (unchanged). @@ -1738,6 +1752,18 @@ message SigmoidParameter { optional Engine engine = 1 [default = DEFAULT]; } +// Message that stores parameters used by SimplerNMSLayer +message SimplerNMSParameter { + optional float cls_threshold = 1 [default = 0.5]; + optional uint32 max_num_proposals = 2 [default = 300]; + optional float iou_threshold = 3 [default = 0.7]; + optional uint32 min_bbox_size = 4 [default = 16]; + optional uint32 feat_stride = 5 [default = 16]; + optional uint32 pre_nms_topn = 6 [default = 6000]; + optional uint32 post_nms_topn = 7 [default = 300]; + repeated float scale = 8; +} + message SliceParameter { // The axis along which to slice -- may be negative to index from the end // (e.g., -1 for the last axis). diff --git a/src/caffe/util/gen_anchors.cpp b/src/caffe/util/gen_anchors.cpp new file mode 100644 index 000000000..71b7de0dd --- /dev/null +++ b/src/caffe/util/gen_anchors.cpp @@ -0,0 +1,135 @@ +/* +All modification made by Intel Corporation: © 2017 Intel Corporation + +All contributions by the University of California: +Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, 2015, the respective contributors +All rights reserved. +For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md + + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Intel Corporation nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include +#include + +#include "caffe/gen_anchors.hpp" + +namespace caffe { + + +static void CalcBasicParams(const anchor& base_anchor, // input + float& width, float& height, float& x_center, float& y_center) // output +{ + width = base_anchor.end_x - base_anchor.start_x + 1.0f; + height = base_anchor.end_y - base_anchor.start_y + 1.0f; + + x_center = base_anchor.start_x + 0.5f * (width - 1.0f); + y_center = base_anchor.start_y + 0.5f * (height - 1.0f); +} + + +static void MakeAnchors(const vector& ws, const vector& hs, float x_center, float y_center, // input + vector& anchors) // output +{ + int len = ws.size(); + anchors.clear(); + anchors.resize(len); + + for (unsigned int i = 0 ; i < len ; i++) { + // transpose to create the anchor + anchors[i].start_x = x_center - 0.5f * (ws[i] - 1.0f); + anchors[i].start_y = y_center - 0.5f * (hs[i] - 1.0f); + anchors[i].end_x = x_center + 0.5f * (ws[i] - 1.0f); + anchors[i].end_y = y_center + 0.5f * (hs[i] - 1.0f); + } +} + + +static void CalcAnchors(const anchor& base_anchor, const vector& scales, // input + vector& anchors) // output +{ + float width = 0.0f, height = 0.0f, x_center = 0.0f, y_center = 0.0f; + + CalcBasicParams(base_anchor, width, height, x_center, y_center); + + int num_scales = scales.size(); + vector ws(num_scales), hs(num_scales); + + for (unsigned int i = 0 ; i < num_scales ; i++) { + ws[i] = width * scales[i]; + hs[i] = height * scales[i]; + } + + MakeAnchors(ws, hs, x_center, y_center, anchors); +} + + +static void CalcRatioAnchors(const anchor& base_anchor, const vector& ratios, // input + vector& ratio_anchors) // output +{ + float width = 0.0f, height = 0.0f, x_center = 0.0f, y_center = 0.0f; + + CalcBasicParams(base_anchor, width, height, x_center, y_center); + + float size = width * height; + + int num_ratios = ratios.size(); + + vector ws(num_ratios), hs(num_ratios); + + for (unsigned int i = 0 ; i < num_ratios ; i++) { + float new_size = size / ratios[i]; + ws[i] = round(sqrt(new_size)); + hs[i] = round(ws[i] * ratios[i]); + } + + MakeAnchors(ws, hs, x_center, y_center, ratio_anchors); +} + +void GenerateAnchors(unsigned int base_size, const vector& ratios, const vector scales, // input + anchor *anchors) // output +{ + float end = (float)(base_size - 1); // because we start at zero + + anchor base_anchor(0.0f, 0.0f, end, end); + + vector ratio_anchors; + CalcRatioAnchors(base_anchor, ratios, ratio_anchors); + + for (int i = 0, index = 0; i < ratio_anchors.size() ; i++) { + vector temp_anchors; + CalcAnchors(ratio_anchors[i], scales, temp_anchors); + + for (int j = 0 ; j < temp_anchors.size() ; j++) { + anchors[index++] = temp_anchors[j]; + } + } +} + +} // namespace caffe