|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# coding: utf-8 |
| 3 | + |
| 4 | +from __future__ import absolute_import |
| 5 | +from __future__ import division |
| 6 | +from __future__ import print_function |
| 7 | + |
| 8 | +import os |
| 9 | +import pickle |
| 10 | +import random |
| 11 | +import time |
| 12 | +from pathlib import Path |
| 13 | + |
| 14 | +import cv2 |
| 15 | +import numpy as np |
| 16 | +from keras import backend as K |
| 17 | +from keras.layers import Input |
| 18 | +from keras.models import Model |
| 19 | +from matplotlib import pyplot as plt |
| 20 | +from sklearn.metrics import average_precision_score |
| 21 | +from sklearn.metrics.pairwise import cosine_similarity |
| 22 | +from tqdm import tqdm |
| 23 | + |
| 24 | +from faster_rcnn import FasterRCNN |
| 25 | +from faster_rcnn import Config |
| 26 | +from faster_rcnn import iou |
| 27 | +from imgs_to_roi_features import ( |
| 28 | + format_img_channels, |
| 29 | + format_img_size, |
| 30 | + imgs_to_roi_features, |
| 31 | +) |
| 32 | +from create_retrieval_db import best_bbox |
| 33 | + |
| 34 | + |
| 35 | +def get_map(pred, gt): |
| 36 | + T = {} |
| 37 | + P = {} |
| 38 | + |
| 39 | + for bbox in gt: |
| 40 | + bbox["bbox_matched"] = False |
| 41 | + |
| 42 | + pred_probs = np.array([s["prob"] for s in pred]) |
| 43 | + box_idx_sorted_by_prob = np.argsort(pred_probs)[::-1] |
| 44 | + |
| 45 | + for box_idx in box_idx_sorted_by_prob: |
| 46 | + pred_box = pred[box_idx] |
| 47 | + pred_class = pred_box["class"] |
| 48 | + pred_x1 = pred_box["x1"] |
| 49 | + pred_x2 = pred_box["x2"] |
| 50 | + pred_y1 = pred_box["y1"] |
| 51 | + pred_y2 = pred_box["y2"] |
| 52 | + pred_prob = pred_box["prob"] |
| 53 | + if pred_class not in P: |
| 54 | + P[pred_class] = [] |
| 55 | + T[pred_class] = [] |
| 56 | + P[pred_class].append(pred_prob) |
| 57 | + found_match = False |
| 58 | + |
| 59 | + for gt_box in gt: |
| 60 | + gt_class = gt_box["class"] |
| 61 | + gt_x1 = gt_box["x1"] |
| 62 | + gt_x2 = gt_box["x2"] |
| 63 | + gt_y1 = gt_box["y1"] |
| 64 | + gt_y2 = gt_box["y2"] |
| 65 | + gt_seen = gt_box["bbox_matched"] |
| 66 | + if gt_class != pred_class: |
| 67 | + continue |
| 68 | + if gt_seen: |
| 69 | + continue |
| 70 | + iou_map = iou( |
| 71 | + (pred_x1, pred_y1, pred_x2, pred_y2), (gt_x1, gt_y1, gt_x2, gt_y2) |
| 72 | + ) |
| 73 | + if iou_map >= 0.5: |
| 74 | + found_match = True |
| 75 | + gt_box["bbox_matched"] = True |
| 76 | + break |
| 77 | + else: |
| 78 | + continue |
| 79 | + |
| 80 | + T[pred_class].append(int(found_match)) |
| 81 | + |
| 82 | + for gt_box in gt: |
| 83 | + if not gt_box["bbox_matched"]: # and not gt_box['difficult']: |
| 84 | + if gt_box["class"] not in P: |
| 85 | + P[gt_box["class"]] = [] |
| 86 | + T[gt_box["class"]] = [] |
| 87 | + print(f'Some gt box has not been associated to {gt_box["path"]}') |
| 88 | + T[gt_box["class"]].append(1) |
| 89 | + P[gt_box["class"]].append(0) |
| 90 | + return T, P |
| 91 | + |
| 92 | + |
| 93 | +def format_img_map(img, C): |
| 94 | + """Format image for mAP. Resize original image to C.im_size (300 in here) |
| 95 | +
|
| 96 | + Args: |
| 97 | + img: cv2 image |
| 98 | + C: config |
| 99 | +
|
| 100 | + Returns: |
| 101 | + img: Scaled and normalized image with expanding dimension |
| 102 | + fx: ratio for width scaling |
| 103 | + fy: ratio for height scaling |
| 104 | + """ |
| 105 | + img, ratio, fx, fy = format_img_size(img, C) |
| 106 | + img = format_img_channels(img, C) |
| 107 | + return img, fx, fy |
| 108 | + |
| 109 | + |
| 110 | +def data_to_dict(l): |
| 111 | + l = l.strip().split(",") |
| 112 | + return { |
| 113 | + "path": l[0], |
| 114 | + "x1": int(l[1]), |
| 115 | + "y1": int(l[2]), |
| 116 | + "x2": int(l[3]), |
| 117 | + "y2": int(l[4]), |
| 118 | + "class": l[5], |
| 119 | + } |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == "__main__": |
| 123 | + |
| 124 | + config_output_filename = "data/instre_monuments/model_vgg_config.pickle" |
| 125 | + |
| 126 | + with open(config_output_filename, "rb") as f_in: |
| 127 | + C = pickle.load(f_in) |
| 128 | + |
| 129 | + test_path = ( |
| 130 | + "data/instre_monuments/annotations_test.txt" |
| 131 | + ) # Test data (annotation file) |
| 132 | + |
| 133 | + with open(test_path) as f: |
| 134 | + test_imgs = map(data_to_dict, f.readlines()) |
| 135 | + |
| 136 | + T = {} |
| 137 | + P = {} |
| 138 | + mAPs = [] |
| 139 | + |
| 140 | + imgs_paths = list(map(lambda img_data: img_data["path"], test_imgs)) |
| 141 | + with tqdm(total=len(imgs_paths)) as pbar: |
| 142 | + feats = imgs_to_roi_features(imgs_paths, C, 0.7, on_each_iter=pbar.update) |
| 143 | + |
| 144 | + for idx, img_data in enumerate(test_imgs): |
| 145 | + # img_data = (path, (x1,y1,x2,y2), class) |
| 146 | + |
| 147 | + t, p = {}, {} |
| 148 | + |
| 149 | + result = None |
| 150 | + if img_data["path"] in feats: |
| 151 | + result = feats[img_data["path"]] |
| 152 | + |
| 153 | + jk = best_bbox(result) |
| 154 | + |
| 155 | + x1, y1, x2, y2 = result[0][jk] |
| 156 | + prob = result[1][jk][0] |
| 157 | + key = result[1][jk][1] |
| 158 | + |
| 159 | + det = {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "class": key, "prob": prob} |
| 160 | + t, p = get_map([det], [img_data]) |
| 161 | + |
| 162 | + else: |
| 163 | + t, p = get_map([], [img_data]) |
| 164 | + |
| 165 | + for key in t.keys(): |
| 166 | + if key not in T: |
| 167 | + T[key] = [] |
| 168 | + P[key] = [] |
| 169 | + T[key].extend(t[key]) |
| 170 | + P[key].extend(p[key]) |
| 171 | + all_aps = [] |
| 172 | + for key in T.keys(): |
| 173 | + ap = average_precision_score(T[key], P[key]) |
| 174 | + print("{} AP: {}".format(key, ap)) |
| 175 | + all_aps.append(ap) |
| 176 | + print("mAP = {}".format(np.mean(np.array(all_aps)))) |
| 177 | + mAPs.append(np.mean(np.array(all_aps))) |
| 178 | + # print(T) |
| 179 | + # print(P) |
| 180 | + |
| 181 | + print() |
| 182 | + print("mean average precision:", np.mean(np.array(mAPs))) |
| 183 | + |
| 184 | + mAP = [mAP for mAP in mAPs if str(mAP) != "nan"] |
| 185 | + mean_average_prec = round(np.mean(np.array(mAP)), 3) |
| 186 | + print(f"The mean average precision is {mean_average_prec}") |
0 commit comments