Skip to content

Commit b9b47e3

Browse files
committed
Added eval_type to solver.
1 parent 0ab7239 commit b9b47e3

File tree

3 files changed

+103
-5
lines changed

3 files changed

+103
-5
lines changed

include/caffe/solver.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class Solver {
103103
// The test routine
104104
void TestAll();
105105
void Test(const int test_net_id = 0);
106+
void TestSegmentation(const int test_net_id = 0);
106107
virtual void SnapshotSolverState(const string& model_filename) = 0;
107108
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
108109
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;

src/caffe/proto/caffe.proto

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ message NetParameter {
9898
// NOTE
9999
// Update the next available ID when you add a new SolverParameter field.
100100
//
101-
// SolverParameter next available ID: 41 (last added: type)
101+
// SolverParameter next available ID: 42 (last added: eval_type)
102102
message SolverParameter {
103103
//////////////////////////////////////////////////////////////////////////////
104104
// Specifying the train and test networks
@@ -228,6 +228,9 @@ message SolverParameter {
228228
// If false, don't save a snapshot after training finishes.
229229
optional bool snapshot_after_train = 28 [default = true];
230230

231+
// Evaluation type
232+
optional string eval_type = 41 [default = "classification"];
233+
231234
// DEPRECATED: old solver enum types, use string instead
232235
enum SolverType {
233236
SGD = 0;
@@ -675,7 +678,7 @@ message EltwiseParameter {
675678
// Message that stores parameters used by ELULayer
676679
message ELUParameter {
677680
// Described in:
678-
// Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
681+
// Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
679682
// Deep Network Learning by Exponential Linear Units (ELUs). arXiv
680683
optional float alpha = 1 [default = 1];
681684
}

src/caffe/solver.cpp

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,15 @@ void Solver<Dtype>::Solve(const char* resume_file) {
325325
template <typename Dtype>
326326
void Solver<Dtype>::TestAll() {
327327
for (int test_net_id = 0;
328-
test_net_id < test_nets_.size() && !requested_early_exit_;
329-
++test_net_id) {
330-
Test(test_net_id);
328+
test_net_id < test_nets_.size() && !requested_early_exit_;
329+
++test_net_id) {
330+
if (param_.eval_type() == "classification") {
331+
Test(test_net_id);
332+
} else if (param_.eval_type() == "segmentation") {
333+
TestSegmentation(test_net_id);
334+
} else {
335+
LOG(FATAL) << "Unknown evaluation type: " << param_.eval_type();
336+
}
331337
}
332338
}
333339

@@ -406,6 +412,94 @@ void Solver<Dtype>::Test(const int test_net_id) {
406412
}
407413
}
408414

415+
template <typename Dtype>
416+
void Solver<Dtype>::TestSegmentation(const int test_net_id) {
417+
LOG(INFO) << "Iteration " << iter_
418+
<< ", Testing net (#" << test_net_id << ")";
419+
CHECK_NOTNULL(test_nets_[test_net_id].get())->
420+
ShareTrainedLayersWith(net_.get());
421+
vector<shared_ptr<Blob<Dtype> > > label_stats;
422+
vector<Blob<Dtype>*> bottom_vec;
423+
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
424+
Dtype loss = 0;
425+
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
426+
Dtype iter_loss;
427+
const vector<Blob<Dtype>*>& result =
428+
test_net->Forward(bottom_vec, &iter_loss);
429+
if (param_.test_compute_loss()) {
430+
loss += iter_loss;
431+
}
432+
if (result.size() == 0) {
433+
continue;
434+
}
435+
if (i == 0) {
436+
for (int j = 0; j < result.size(); ++j) {
437+
shared_ptr<Blob<Dtype> > label_stat(new Blob<Dtype>());
438+
label_stats.push_back(label_stat);
439+
label_stat->Reshape(1, result[j]->channels(),
440+
result[j]->height(), result[j]->width());
441+
// copy the result
442+
caffe_copy(result[j]->count(), result[j]->cpu_data(),
443+
label_stat->mutable_cpu_data());
444+
}
445+
} else {
446+
// add the result
447+
for (int j = 0; j < result.size(); ++j) {
448+
caffe_axpy(result[j]->count(), Dtype(1), result[j]->cpu_data(),
449+
label_stats[j]->mutable_cpu_data());
450+
}
451+
}
452+
}
453+
if (param_.test_compute_loss()) {
454+
loss /= param_.test_iter(test_net_id);
455+
LOG(INFO) << "Test loss: " << loss;
456+
}
457+
for (int i = 0; i < label_stats.size(); ++i) {
458+
const int output_blob_index = test_net->output_blob_indices()[i];
459+
const string& output_name = test_net->blob_names()[output_blob_index];
460+
const Dtype* label_stat_data = label_stats[i]->cpu_data();
461+
const int channels = label_stats[i]->channels();
462+
// get sum infomation
463+
Dtype sum_gtpred = 0;
464+
Dtype sum_gt = 0;
465+
for (int c = 0; c < channels; ++c) {
466+
sum_gtpred += label_stat_data[c*3];
467+
sum_gt += label_stat_data[c*3+1];
468+
}
469+
if (sum_gt > 0) {
470+
// compute accuracy for segmentation
471+
Dtype per_pixel_acc = sum_gtpred / sum_gt;
472+
Dtype per_label_acc = 0;
473+
Dtype iou, iou_acc = 0, weighted_iou_acc = 0;
474+
int num_valid_labels = 0;
475+
for (int c = 0; c < channels; ++c) {
476+
if (label_stat_data[1] != 0) {
477+
per_label_acc += label_stat_data[0] / label_stat_data[1];
478+
++num_valid_labels;
479+
}
480+
if (label_stat_data[1] + label_stat_data[2] != 0) {
481+
iou = label_stat_data[0] / (label_stat_data[1] + label_stat_data[2]
482+
- label_stat_data[0]);
483+
iou_acc += iou;
484+
weighted_iou_acc += iou * label_stat_data[1] / sum_gt;
485+
}
486+
label_stat_data += label_stats[i]->offset(0, 1);
487+
}
488+
LOG(INFO) << " Test net output #" << i << " " << output_name
489+
<< ": per_pixel_acc = " << per_pixel_acc;
490+
LOG(INFO) << " Test net output #" << i << " " << output_name
491+
<< ": per_label_acc = " << per_label_acc / num_valid_labels;
492+
LOG(INFO) << " Test net output #" << i << " " << output_name
493+
<< ": iou_acc = " << iou_acc / num_valid_labels;
494+
LOG(INFO) << " Test net output #" << i << " " << output_name
495+
<< ": weighted_iou_acc = " << weighted_iou_acc;
496+
} else {
497+
LOG(INFO) << " Test net output #" << i << " " << output_name
498+
<< ": no valid labels!";
499+
}
500+
}
501+
}
502+
409503
template <typename Dtype>
410504
void Solver<Dtype>::Snapshot() {
411505
CHECK(Caffe::root_solver());

0 commit comments

Comments
 (0)