Skip to content

Commit 0ab7239

Browse files
committed
Added ParseEvaluate Layer.
1 parent 940f923 commit 0ab7239

File tree

3 files changed

+153
-1
lines changed

3 files changed

+153
-1
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#ifndef CAFFE_PARSE_EVALUATE_LAYER_HPP_
2+
#define CAFFE_PARSE_EVALUATE_LAYER_HPP_
3+
4+
#include <set>
5+
#include <vector>
6+
7+
#include "caffe/blob.hpp"
8+
#include "caffe/layer.hpp"
9+
#include "caffe/proto/caffe.pb.h"
10+
11+
namespace caffe {
12+
13+
/**
14+
* @brief Count the prediction and ground truth statistics for each datum.
15+
*
16+
* NOTE: This does not implement Backwards operation.
17+
*/
18+
template <typename Dtype>
19+
class ParseEvaluateLayer : public Layer<Dtype> {
20+
public:
21+
/**
22+
* @param param provides ParseEvaluateParameter parse_evaluate_param,
23+
* with ParseEvaluateLayer options:
24+
* - num_labels (\b optional int32.).
25+
* number of labels. must provide!!
26+
* - ignore_label (\b repeated int32).
27+
* If any, ignore evaluating the corresponding label for each
28+
* image.
29+
*/
30+
explicit ParseEvaluateLayer(const LayerParameter& param)
31+
: Layer<Dtype>(param) {}
32+
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
33+
const vector<Blob<Dtype>*>& top);
34+
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
35+
const vector<Blob<Dtype>*>& top);
36+
37+
virtual inline const char* type() const { return "ParseEvaluate"; }
38+
virtual inline int ExactNumBottomBlobs() const { return 2; }
39+
virtual inline int ExactNumTopBlobs() const { return 1; }
40+
41+
protected:
42+
/**
43+
* @param bottom input Blob vector (length 2)
44+
* -# @f$ (N \times 1 \times H \times W) @f$
45+
* the prediction label @f$ x @f$
46+
* -# @f$ (N \times 1 \times H \times W) @f$
47+
* the ground truth label @f$ x @f$
48+
* @param top output Blob vector (length 1)
49+
* -# @f$ (N \times C \times 1 \times 3) @f$
50+
* the counts for different class @f$
51+
*/
52+
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
53+
const vector<Blob<Dtype>*>& top);
54+
/// @brief Not implemented (non-differentiable function)
55+
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
56+
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
57+
NOT_IMPLEMENTED;
58+
}
59+
60+
// number of total labels
61+
int num_labels_;
62+
// store ignored labels
63+
std::set<Dtype> ignore_labels_;
64+
};
65+
66+
} // namespace caffe
67+
68+
#endif // CAFFE_PARSE_EVALUATE_LAYER_HPP_
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#include <algorithm>
2+
#include <functional>
3+
#include <utility>
4+
#include <vector>
5+
6+
#include "caffe/layer.hpp"
7+
#include "caffe/layers/parse_evaluate_layer.hpp"
8+
#include "caffe/util/io.hpp"
9+
#include "caffe/util/math_functions.hpp"
10+
11+
namespace caffe {
12+
13+
template <typename Dtype>
14+
void ParseEvaluateLayer<Dtype>::LayerSetUp(
15+
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
16+
const ParseEvaluateParameter& parse_evaluate_param =
17+
this->layer_param_.parse_evaluate_param();
18+
CHECK(parse_evaluate_param.has_num_labels()) << "Must have num_labels!!";
19+
num_labels_ = parse_evaluate_param.num_labels();
20+
ignore_labels_.clear();
21+
int num_ignore_label = parse_evaluate_param.ignore_label().size();
22+
for (int i = 0; i < num_ignore_label; ++i) {
23+
ignore_labels_.insert(parse_evaluate_param.ignore_label(i));
24+
}
25+
}
26+
27+
template <typename Dtype>
28+
void ParseEvaluateLayer<Dtype>::Reshape(
29+
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
30+
CHECK_EQ(bottom[0]->num(), bottom[1]->num())
31+
<< "The data and label should have the same number.";
32+
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
33+
CHECK_EQ(bottom[0]->channels(), 1);
34+
CHECK_EQ(bottom[0]->height(), bottom[1]->height());
35+
CHECK_GE(bottom[0]->width(), bottom[1]->width());
36+
top[0]->Reshape(1, num_labels_, 1, 3);
37+
}
38+
39+
template <typename Dtype>
40+
void ParseEvaluateLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
41+
const vector<Blob<Dtype>*>& top) {
42+
CHECK_EQ(bottom[0]->num(), bottom[1]->num());
43+
CHECK_EQ(bottom[0]->count(), bottom[1]->count());
44+
const Dtype* bottom_pred = bottom[0]->cpu_data();
45+
const Dtype* bottom_gt = bottom[1]->cpu_data();
46+
Dtype* top_data = top[0]->mutable_cpu_data();
47+
caffe_set(top[0]->count(), Dtype(0), top_data);
48+
int num = bottom[0]->num();
49+
int spatial_dim = bottom[0]->height() * bottom[0]->width();
50+
for (int i = 0; i < num; ++i) {
51+
// count the number of ground truth labels, the predicted labels, and
52+
// predicted labels happens to be ground truth labels
53+
for (int j = 0; j < spatial_dim; ++j) {
54+
int gt_label = bottom_gt[j];
55+
int pred_label = bottom_pred[j];
56+
CHECK_LT(pred_label, num_labels_);
57+
if (ignore_labels_.find(gt_label) != ignore_labels_.end()) {
58+
continue;
59+
}
60+
if (gt_label == pred_label) {
61+
top_data[gt_label * 3]++;
62+
}
63+
top_data[gt_label * 3 + 1]++;
64+
top_data[pred_label * 3 + 2]++;
65+
}
66+
bottom_pred += bottom[0]->offset(1);
67+
bottom_gt += bottom[1]->offset(1);
68+
}
69+
// ParseEvaluate layer should not be used as a loss function.
70+
}
71+
72+
INSTANTIATE_CLASS(ParseEvaluateLayer);
73+
REGISTER_LAYER_CLASS(ParseEvaluate);
74+
75+
} // namespace caffe

src/caffe/proto/caffe.proto

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ message ParamSpec {
306306
// NOTE
307307
// Update the next available ID when you add a new LayerParameter field.
308308
//
309-
// LayerParameter next available layer-specific ID: 146 (last added: unpooling_param)
309+
// LayerParameter next available layer-specific ID: 147 (last added: parse_evaluate_param)
310310
message LayerParameter {
311311
optional string name = 1; // the layer name
312312
optional string type = 2; // the layer type
@@ -380,6 +380,7 @@ message LayerParameter {
380380
optional MemoryDataParameter memory_data_param = 119;
381381
optional MVNParameter mvn_param = 120;
382382
optional NormalizeParameter norm_param = 144;
383+
optional ParseEvaluateParameter parse_evaluate_param = 146;
383384
optional PoolingParameter pooling_param = 121;
384385
optional PowerParameter power_param = 122;
385386
optional PReLUParameter prelu_param = 131;
@@ -862,6 +863,14 @@ message NormalizeParameter {
862863
optional float eps = 4 [default = 1e-10];
863864
}
864865

866+
// Message that stores parameters used by ParseEvaluateLayer
867+
message ParseEvaluateParameter {
868+
// Number of total labels. Must provide.
869+
optional int32 num_labels = 1;
870+
// Ignore evaluating following labels.
871+
repeated int32 ignore_label = 2;
872+
}
873+
865874
message PoolingParameter {
866875
enum PoolMethod {
867876
MAX = 0;

0 commit comments

Comments
 (0)