|
| 1 | +#include <algorithm> |
| 2 | +#include <functional> |
| 3 | +#include <utility> |
| 4 | +#include <vector> |
| 5 | + |
| 6 | +#include "caffe/layer.hpp" |
| 7 | +#include "caffe/layers/parse_evaluate_layer.hpp" |
| 8 | +#include "caffe/util/io.hpp" |
| 9 | +#include "caffe/util/math_functions.hpp" |
| 10 | + |
| 11 | +namespace caffe { |
| 12 | + |
| 13 | +template <typename Dtype> |
| 14 | +void ParseEvaluateLayer<Dtype>::LayerSetUp( |
| 15 | + const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { |
| 16 | + const ParseEvaluateParameter& parse_evaluate_param = |
| 17 | + this->layer_param_.parse_evaluate_param(); |
| 18 | + CHECK(parse_evaluate_param.has_num_labels()) << "Must have num_labels!!"; |
| 19 | + num_labels_ = parse_evaluate_param.num_labels(); |
| 20 | + ignore_labels_.clear(); |
| 21 | + int num_ignore_label = parse_evaluate_param.ignore_label().size(); |
| 22 | + for (int i = 0; i < num_ignore_label; ++i) { |
| 23 | + ignore_labels_.insert(parse_evaluate_param.ignore_label(i)); |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +template <typename Dtype> |
| 28 | +void ParseEvaluateLayer<Dtype>::Reshape( |
| 29 | + const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { |
| 30 | + CHECK_EQ(bottom[0]->num(), bottom[1]->num()) |
| 31 | + << "The data and label should have the same number."; |
| 32 | + CHECK_EQ(bottom[0]->channels(), bottom[1]->channels()); |
| 33 | + CHECK_EQ(bottom[0]->channels(), 1); |
| 34 | + CHECK_EQ(bottom[0]->height(), bottom[1]->height()); |
| 35 | + CHECK_GE(bottom[0]->width(), bottom[1]->width()); |
| 36 | + top[0]->Reshape(1, num_labels_, 1, 3); |
| 37 | +} |
| 38 | + |
| 39 | +template <typename Dtype> |
| 40 | +void ParseEvaluateLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, |
| 41 | + const vector<Blob<Dtype>*>& top) { |
| 42 | + CHECK_EQ(bottom[0]->num(), bottom[1]->num()); |
| 43 | + CHECK_EQ(bottom[0]->count(), bottom[1]->count()); |
| 44 | + const Dtype* bottom_pred = bottom[0]->cpu_data(); |
| 45 | + const Dtype* bottom_gt = bottom[1]->cpu_data(); |
| 46 | + Dtype* top_data = top[0]->mutable_cpu_data(); |
| 47 | + caffe_set(top[0]->count(), Dtype(0), top_data); |
| 48 | + int num = bottom[0]->num(); |
| 49 | + int spatial_dim = bottom[0]->height() * bottom[0]->width(); |
| 50 | + for (int i = 0; i < num; ++i) { |
| 51 | + // count the number of ground truth labels, the predicted labels, and |
| 52 | + // predicted labels happens to be ground truth labels |
| 53 | + for (int j = 0; j < spatial_dim; ++j) { |
| 54 | + int gt_label = bottom_gt[j]; |
| 55 | + int pred_label = bottom_pred[j]; |
| 56 | + CHECK_LT(pred_label, num_labels_); |
| 57 | + if (ignore_labels_.find(gt_label) != ignore_labels_.end()) { |
| 58 | + continue; |
| 59 | + } |
| 60 | + if (gt_label == pred_label) { |
| 61 | + top_data[gt_label * 3]++; |
| 62 | + } |
| 63 | + top_data[gt_label * 3 + 1]++; |
| 64 | + top_data[pred_label * 3 + 2]++; |
| 65 | + } |
| 66 | + bottom_pred += bottom[0]->offset(1); |
| 67 | + bottom_gt += bottom[1]->offset(1); |
| 68 | + } |
| 69 | + // ParseEvaluate layer should not be used as a loss function. |
| 70 | +} |
| 71 | + |
// Caffe helper macros (defined outside this file). Presumably the first emits
// the explicit template instantiations of ParseEvaluateLayer and the second
// registers the layer under the type name "ParseEvaluate" -- confirm against
// the macro definitions in the Caffe headers.
| 72 | +INSTANTIATE_CLASS(ParseEvaluateLayer); |
| 73 | +REGISTER_LAYER_CLASS(ParseEvaluate); |
| 74 | + |
| 75 | +} // namespace caffe |
0 commit comments