@@ -325,9 +325,15 @@ void Solver<Dtype>::Solve(const char* resume_file) {
325325template <typename Dtype>
326326void Solver<Dtype>::TestAll() {
327327 for (int test_net_id = 0 ;
328- test_net_id < test_nets_.size () && !requested_early_exit_;
329- ++test_net_id) {
330- Test (test_net_id);
328+ test_net_id < test_nets_.size () && !requested_early_exit_;
329+ ++test_net_id) {
330+ if (param_.eval_type () == " classification" ) {
331+ Test (test_net_id);
332+ } else if (param_.eval_type () == " segmentation" ) {
333+ TestSegmentation (test_net_id);
334+ } else {
335+ LOG (FATAL) << " Unknown evaluation type: " << param_.eval_type ();
336+ }
331337 }
332338}
333339
@@ -406,6 +412,94 @@ void Solver<Dtype>::Test(const int test_net_id) {
406412 }
407413}
408414
415+ template <typename Dtype>
416+ void Solver<Dtype>::TestSegmentation(const int test_net_id) {
417+ LOG (INFO) << " Iteration " << iter_
418+ << " , Testing net (#" << test_net_id << " )" ;
419+ CHECK_NOTNULL (test_nets_[test_net_id].get ())->
420+ ShareTrainedLayersWith (net_.get ());
421+ vector<shared_ptr<Blob<Dtype> > > label_stats;
422+ vector<Blob<Dtype>*> bottom_vec;
423+ const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
424+ Dtype loss = 0 ;
425+ for (int i = 0 ; i < param_.test_iter (test_net_id); ++i) {
426+ Dtype iter_loss;
427+ const vector<Blob<Dtype>*>& result =
428+ test_net->Forward (bottom_vec, &iter_loss);
429+ if (param_.test_compute_loss ()) {
430+ loss += iter_loss;
431+ }
432+ if (result.size () == 0 ) {
433+ continue ;
434+ }
435+ if (i == 0 ) {
436+ for (int j = 0 ; j < result.size (); ++j) {
437+ shared_ptr<Blob<Dtype> > label_stat (new Blob<Dtype>());
438+ label_stats.push_back (label_stat);
439+ label_stat->Reshape (1 , result[j]->channels (),
440+ result[j]->height (), result[j]->width ());
441+ // copy the result
442+ caffe_copy (result[j]->count (), result[j]->cpu_data (),
443+ label_stat->mutable_cpu_data ());
444+ }
445+ } else {
446+ // add the result
447+ for (int j = 0 ; j < result.size (); ++j) {
448+ caffe_axpy (result[j]->count (), Dtype (1 ), result[j]->cpu_data (),
449+ label_stats[j]->mutable_cpu_data ());
450+ }
451+ }
452+ }
453+ if (param_.test_compute_loss ()) {
454+ loss /= param_.test_iter (test_net_id);
455+ LOG (INFO) << " Test loss: " << loss;
456+ }
457+ for (int i = 0 ; i < label_stats.size (); ++i) {
458+ const int output_blob_index = test_net->output_blob_indices ()[i];
459+ const string& output_name = test_net->blob_names ()[output_blob_index];
460+ const Dtype* label_stat_data = label_stats[i]->cpu_data ();
461+ const int channels = label_stats[i]->channels ();
462+ // get sum infomation
463+ Dtype sum_gtpred = 0 ;
464+ Dtype sum_gt = 0 ;
465+ for (int c = 0 ; c < channels; ++c) {
466+ sum_gtpred += label_stat_data[c*3 ];
467+ sum_gt += label_stat_data[c*3 +1 ];
468+ }
469+ if (sum_gt > 0 ) {
470+ // compute accuracy for segmentation
471+ Dtype per_pixel_acc = sum_gtpred / sum_gt;
472+ Dtype per_label_acc = 0 ;
473+ Dtype iou, iou_acc = 0 , weighted_iou_acc = 0 ;
474+ int num_valid_labels = 0 ;
475+ for (int c = 0 ; c < channels; ++c) {
476+ if (label_stat_data[1 ] != 0 ) {
477+ per_label_acc += label_stat_data[0 ] / label_stat_data[1 ];
478+ ++num_valid_labels;
479+ }
480+ if (label_stat_data[1 ] + label_stat_data[2 ] != 0 ) {
481+ iou = label_stat_data[0 ] / (label_stat_data[1 ] + label_stat_data[2 ]
482+ - label_stat_data[0 ]);
483+ iou_acc += iou;
484+ weighted_iou_acc += iou * label_stat_data[1 ] / sum_gt;
485+ }
486+ label_stat_data += label_stats[i]->offset (0 , 1 );
487+ }
488+ LOG (INFO) << " Test net output #" << i << " " << output_name
489+ << " : per_pixel_acc = " << per_pixel_acc;
490+ LOG (INFO) << " Test net output #" << i << " " << output_name
491+ << " : per_label_acc = " << per_label_acc / num_valid_labels;
492+ LOG (INFO) << " Test net output #" << i << " " << output_name
493+ << " : iou_acc = " << iou_acc / num_valid_labels;
494+ LOG (INFO) << " Test net output #" << i << " " << output_name
495+ << " : weighted_iou_acc = " << weighted_iou_acc;
496+ } else {
497+ LOG (INFO) << " Test net output #" << i << " " << output_name
498+ << " : no valid labels!" ;
499+ }
500+ }
501+ }
502+
409503template <typename Dtype>
410504void Solver<Dtype>::Snapshot() {
411505 CHECK (Caffe::root_solver ());
0 commit comments