diff --git a/demo/demo_multi_model.py b/demo/demo_multi_model.py new file mode 100644 index 00000000000..f7935de6f90 --- /dev/null +++ b/demo/demo_multi_model.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Support for multi-model fusion, and currently only the Weighted Box Fusion +(WBF) fusion method is supported. + +References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion + +Example: + + python demo/demo_multi_model.py demo/demo.jpg \ + ./configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_1x_coco.py \ + ./configs/retinanet/retinanet_r50-caffe_fpn_1x_coco.py \ + --checkpoints \ + https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth \ # noqa + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth \ + --weights 1 2 +""" + +import argparse +import os.path as osp + +import mmcv +import mmengine +from mmengine.fileio import isdir, join_path, list_dir_or_file +from mmengine.logging import print_log +from mmengine.structures import InstanceData + +from mmdet.apis import DetInferencer +from mmdet.models.utils import weighted_boxes_fusion +from mmdet.registry import VISUALIZERS +from mmdet.structures import DetDataSample + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', + '.tiff', '.webp') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDetection multi-model inference demo') + parser.add_argument( + 'inputs', type=str, help='Input image file or folder path.') + parser.add_argument( + 'config', + type=str, + nargs='*', + help='Config file(s), support receive multiple files') + parser.add_argument( + '--checkpoints', + type=str, + nargs='*', + help='Checkpoint file(s), support receive multiple files, ' + 'remember to correspond to the above config', + ) + parser.add_argument( + '--weights', + type=float, + nargs='*', + default=None, + help='weights for each model, remember to ' + 'correspond to the above config') + parser.add_argument( + '--fusion-iou-thr', + type=float, + default=0.55, + help='IoU value for boxes to be a match in wbf') + parser.add_argument( + '--skip-box-thr', + type=float, + default=0.0, + help='exclude boxes with score lower than this variable in wbf') + parser.add_argument( + '--conf-type', + type=str, + default='avg', # avg, max, box_and_model_avg, absent_model_aware_avg + help='how to calculate confidence in weighted boxes in wbf') + parser.add_argument( + '--out-dir', + type=str, + default='outputs', + help='Output directory of images or prediction results.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--pred-score-thr', + type=float, + default=0.3, + help='bbox score threshold') + parser.add_argument( + '--batch-size', type=int, default=1, help='Inference batch size.') + parser.add_argument( + '--show', + action='store_true', + help='Display the image in a popup window.') + parser.add_argument( + '--no-save-vis', + action='store_true', + help='Do not save detection vis results') + parser.add_argument( + '--no-save-pred', + action='store_true', + help='Do not save detection json results') + parser.add_argument( + '--palette', + default='none', + choices=['coco', 'voc', 'citys', 'random', 'none'], + help='Color palette used for visualization') + + args = parser.parse_args() + + if args.no_save_vis and args.no_save_pred: + args.out_dir = '' + + return args + + +def main(): + args = parse_args() + + results = [] + cfg_visualizer = None + dataset_meta = None + + inputs = [] + filename_list = [] + if isdir(args.inputs): + dir = list_dir_or_file( + args.inputs, list_dir=False, suffix=IMG_EXTENSIONS) + for filename in dir: + img = mmcv.imread(join_path(args.inputs, filename)) + inputs.append(img) + filename_list.append(filename) + else: + img = mmcv.imread(args.inputs) + inputs.append(img) + img_name = osp.basename(args.inputs) + filename_list.append(img_name) + + for i, (config, + checkpoint) in enumerate(zip(args.config, args.checkpoints)): + inferencer = DetInferencer( + config, checkpoint, device=args.device, palette=args.palette) + + result_raw = inferencer( + inputs=inputs, + batch_size=args.batch_size, + no_save_vis=True, + pred_score_thr=args.pred_score_thr) + + if i == 0: + cfg_visualizer = inferencer.cfg.visualizer + dataset_meta = inferencer.model.dataset_meta + results = [{ + 'bboxes_list': [], + 'scores_list': [], + 'labels_list': [] + } for _ in range(len(result_raw['predictions']))] + + for res, raw in zip(results, result_raw['predictions']): + res['bboxes_list'].append(raw['bboxes']) + res['scores_list'].append(raw['scores']) + res['labels_list'].append(raw['labels']) + + visualizer = VISUALIZERS.build(cfg_visualizer) + visualizer.dataset_meta = dataset_meta + + for i in range(len(results)): + bboxes, scores, labels = weighted_boxes_fusion( + results[i]['bboxes_list'], + results[i]['scores_list'], + results[i]['labels_list'], + weights=args.weights, + iou_thr=args.fusion_iou_thr, + skip_box_thr=args.skip_box_thr, + conf_type=args.conf_type) + + pred_instances = InstanceData() + pred_instances.bboxes = bboxes + pred_instances.scores = scores + pred_instances.labels = labels + + fusion_result = DetDataSample(pred_instances=pred_instances) + + img_name = filename_list[i] + + if not args.no_save_pred: + out_json_path = ( + args.out_dir + '/preds/' + img_name.split('.')[0] + '.json') + mmengine.dump( + { + 'labels': labels.tolist(), + 'scores': scores.tolist(), + 'bboxes': bboxes.tolist() + }, out_json_path) + + out_file = osp.join(args.out_dir, 'vis', + img_name) if not args.no_save_vis else None + + visualizer.add_datasample( + img_name, + inputs[i][..., ::-1], + data_sample=fusion_result, + show=args.show, + draw_gt=False, + wait_time=0, + pred_score_thr=args.pred_score_thr, + out_file=out_file) + + if not args.no_save_vis: + print_log(f'results have been saved at {args.out_dir}') + + +if __name__ == '__main__': + main() diff --git a/docs/en/user_guides/useful_tools.md b/docs/en/user_guides/useful_tools.md index eb626624f6e..8a79f0c2f1b 100644 --- a/docs/en/user_guides/useful_tools.md +++ b/docs/en/user_guides/useful_tools.md @@ -111,6 +111,80 @@ python tools/analysis_tools/analyze_results.py \ --show-score-thr 0.3 ``` +## Fusing results from multiple models + +`tools/analysis_tools/fusion_results.py` can fusing predictions using Weighted Boxes Fusion(WBF) from different object detection models. (Currently support coco format only) + +**Usage** + +```shell +python tools/analysis_tools/fuse_results.py \ + ${PRED_RESULTS} \ + [--annotation ${ANNOTATION}] \ + [--weights ${WEIGHTS}] \ + [--fusion-iou-thr ${FUSION_IOU_THR}] \ + [--skip-box-thr ${SKIP_BOX_THR}] \ + [--conf-type ${CONF_TYPE}] \ + [--eval-single ${EVAL_SINGLE}] \ + [--save-fusion-results ${SAVE_FUSION_RESULTS}] \ + [--out-dir ${OUT_DIR}] +``` + +Description of all arguments: + +- `pred-results`: Paths of detection results from different models.(Currently support coco format only) +- `--annotation`: Path of ground-truth. +- `--weights`: List of weights for each model. Default: `None`, which means weight == 1 for each model. +- `--fusion-iou-thr`: IoU value for boxes to be a match。Default: `0.55`。 +- `--skip-box-thr`: The confidence threshold that needs to be excluded in the WBF algorithm. bboxes whose confidence is less than this value will be excluded.。Default: `0`。 +- `--conf-type`: How to calculate confidence in weighted boxes. + - `avg`: average value,default. + - `max`: maximum value. + - `box_and_model_avg`: box and model wise hybrid weighted average. + - `absent_model_aware_avg`: weighted average that takes into account the absent model. +- `--eval-single`: Whether evaluate every single model. Default: `False`. +- `--save-fusion-results`: Whether save fusion results. Default: `False`. +- `--out-dir`: Path of fusion results. + +**Examples**: +Assume that you have got 3 result files from corresponding models through `tools/test.py`, which paths are './faster-rcnn_r50-caffe_fpn_1x_coco.json', './retinanet_r50-caffe_fpn_1x_coco.json', './cascade-rcnn_r50-caffe_fpn_1x_coco.json' respectively. The ground-truth file path is './annotation.json'. + +1. Fusion of predictions from three models and evaluation of their effectiveness + +```shell +python tools/analysis_tools/fuse_results.py \ + ./faster-rcnn_r50-caffe_fpn_1x_coco.json \ + ./retinanet_r50-caffe_fpn_1x_coco.json \ + ./cascade-rcnn_r50-caffe_fpn_1x_coco.json \ + --annotation ./annotation.json \ + --weights 1 2 3 \ +``` + +2. Simultaneously evaluate each single model and fusion results + +```shell +python tools/analysis_tools/fuse_results.py \ + ./faster-rcnn_r50-caffe_fpn_1x_coco.json \ + ./retinanet_r50-caffe_fpn_1x_coco.json \ + ./cascade-rcnn_r50-caffe_fpn_1x_coco.json \ + --annotation ./annotation.json \ + --weights 1 2 3 \ + --eval-single +``` + +3. Fusion of prediction results from three models and save + +```shell +python tools/analysis_tools/fuse_results.py \ + ./faster-rcnn_r50-caffe_fpn_1x_coco.json \ + ./retinanet_r50-caffe_fpn_1x_coco.json \ + ./cascade-rcnn_r50-caffe_fpn_1x_coco.json \ + --annotation ./annotation.json \ + --weights 1 2 3 \ + --save-fusion-results \ + --out-dir outputs/fusion +``` + ## Visualization ### Visualize Datasets diff --git a/docs/zh_cn/user_guides/useful_tools.md b/docs/zh_cn/user_guides/useful_tools.md index 00ed06321ef..8416472c90e 100644 --- a/docs/zh_cn/user_guides/useful_tools.md +++ b/docs/zh_cn/user_guides/useful_tools.md @@ -109,6 +109,80 @@ python tools/analysis_tools/analyze_results.py \ --show-score-thr 0.3 ``` +## 多模型检测结果融合 + +`tools/analysis_tools/fuse_results.py` 可使用 Weighted Boxes Fusion(WBF) 方法将多个模型的检测结果进行融合。(当前仅支持 COCO 格式) + +**使用方法** + +```shell +python tools/analysis_tools/fuse_results.py \ + ${PRED_RESULTS} \ + [--annotation ${ANNOTATION}] \ + [--weights ${WEIGHTS}] \ + [--fusion-iou-thr ${FUSION_IOU_THR}] \ + [--skip-box-thr ${SKIP_BOX_THR}] \ + [--conf-type ${CONF_TYPE}] \ + [--eval-single ${EVAL_SINGLE}] \ + [--save-fusion-results ${SAVE_FUSION_RESULTS}] \ + [--out-dir ${OUT_DIR}] +``` + +各个参数选项的作用: + +- `pred-results`: 多模型测试结果的保存路径。(目前仅支持 json 格式) +- `--annotation`: 真实标注框的保存路径。 +- `--weights`: 模型融合权重。默认设置下,每个模型的权重均为1。 +- `--fusion-iou-thr`: 在WBF算法中,匹配成功的 IoU 阈值,默认值为`0.55`。 +- `--skip-box-thr`: WBF算法中需剔除的置信度阈值,置信度小于该值的 bbox 会被剔除,默认值为`0`。 +- `--conf-type`: 如何计算融合后 bbox 的置信度。有以下四种选项: + - `avg`: 取平均值,默认为此选项。 + - `max`: 取最大值。 + - `box_and_model_avg`: box和模型尺度的加权平均值。 + - `absent_model_aware_avg`: 考虑缺失模型的加权平均值。 +- `--eval-single`: 是否评估每个单一模型,默认值为`False`。 +- `--save-fusion-results`: 是否保存融合结果,默认值为`False`。 +- `--out-dir`: 融合结果保存的路径。 + +**样例**: +假设你已经通过 `tools/test.py` 得到了3个模型的 json 格式的结果文件,路径分别为 './faster-rcnn_r50-caffe_fpn_1x_coco.json', './retinanet_r50-caffe_fpn_1x_coco.json', './cascade-rcnn_r50-caffe_fpn_1x_coco.json',真实标注框的文件路径为'./annotation.json'。 + +1. 融合三个模型的预测结果并评估其效果 + +```shell +python tools/analysis_tools/fuse_results.py \ + ./faster-rcnn_r50-caffe_fpn_1x_coco.json \ + ./retinanet_r50-caffe_fpn_1x_coco.json \ + ./cascade-rcnn_r50-caffe_fpn_1x_coco.json \ + --annotation ./annotation.json \ + --weights 1 2 3 \ +``` + +2. 同时评估每个单一模型与融合结果 + +```shell +python tools/analysis_tools/fuse_results.py \ + ./faster-rcnn_r50-caffe_fpn_1x_coco.json \ + ./retinanet_r50-caffe_fpn_1x_coco.json \ + ./cascade-rcnn_r50-caffe_fpn_1x_coco.json \ + --annotation ./annotation.json \ + --weights 1 2 3 \ + --eval-single +``` + +3. 融合三个模型的预测结果并保存 + +```shell +python tools/analysis_tools/fuse_results.py \ + ./faster-rcnn_r50-caffe_fpn_1x_coco.json \ + ./retinanet_r50-caffe_fpn_1x_coco.json \ + ./cascade-rcnn_r50-caffe_fpn_1x_coco.json \ + --annotation ./annotation.json \ + --weights 1 2 3 \ + --save-fusion-results \ + --out-dir outputs/fusion +``` + ## 可视化 ### 可视化数据集 diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py index 81bef2ccf5e..a00d9a37f33 100644 --- a/mmdet/models/utils/__init__.py +++ b/mmdet/models/utils/__init__.py @@ -19,6 +19,7 @@ from .point_sample import (get_uncertain_point_coords_with_randomness, get_uncertainty) from .vlfuse_helper import BertEncoderLayer, VLFuse, permute_and_flatten +from .wbf import weighted_boxes_fusion __all__ = [ 'gaussian_radius', 'gen_gaussian_target', 'make_divisible', @@ -32,5 +33,5 @@ 'samplelist_boxtype2tensor', 'filter_gt_instances', 'rename_loss_dict', 'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear', 'unfold_wo_center', 'imrenormalize', 'VLFuse', 'permute_and_flatten', - 'BertEncoderLayer', 'align_tensor' + 'BertEncoderLayer', 'align_tensor', 'weighted_boxes_fusion' ] diff --git a/mmdet/models/utils/wbf.py b/mmdet/models/utils/wbf.py new file mode 100644 index 00000000000..b26a2c669a5 --- /dev/null +++ b/mmdet/models/utils/wbf.py @@ -0,0 +1,250 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import warnings +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor + + +# References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion +def weighted_boxes_fusion( + bboxes_list: list, + scores_list: list, + labels_list: list, + weights: list = None, + iou_thr: float = 0.55, + skip_box_thr: float = 0.0, + conf_type: str = 'avg', + allows_overflow: bool = False) -> Tuple[Tensor, Tensor, Tensor]: + """weighted boxes fusion is a method for + fusing predictions from different object detection models, which utilizes + confidence scores of all proposed bounding boxes to construct averaged + boxes. + + Args: + bboxes_list(list): list of boxes predictions from each model, + each box is 4 numbers. + scores_list(list): list of scores for each model + labels_list(list): list of labels for each model + weights: list of weights for each model. + Default: None, which means weight == 1 for each model + iou_thr: IoU value for boxes to be a match + skip_box_thr: exclude boxes with score lower than this variable. + conf_type: how to calculate confidence in weighted boxes. + 'avg': average value, + 'max': maximum value, + 'box_and_model_avg': box and model wise hybrid weighted average, + 'absent_model_aware_avg': weighted average that takes into + account the absent model. + allows_overflow: false if we want confidence score not exceed 1.0. + + Returns: + bboxes(Tensor): boxes coordinates (Order of boxes: x1, y1, x2, y2). + scores(Tensor): confidence scores + labels(Tensor): boxes labels + """ + + if weights is None: + weights = np.ones(len(bboxes_list)) + if len(weights) != len(bboxes_list): + print('Warning: incorrect number of weights {}. Must be: ' + '{}. Set weights equal to 1.'.format( + len(weights), len(bboxes_list))) + weights = np.ones(len(bboxes_list)) + weights = np.array(weights) + + if conf_type not in [ + 'avg', 'max', 'box_and_model_avg', 'absent_model_aware_avg' + ]: + print('Unknown conf_type: {}. Must be "avg", ' + '"max" or "box_and_model_avg", ' + 'or "absent_model_aware_avg"'.format(conf_type)) + exit() + + filtered_boxes = prefilter_boxes(bboxes_list, scores_list, labels_list, + weights, skip_box_thr) + if len(filtered_boxes) == 0: + return torch.Tensor(), torch.Tensor(), torch.Tensor() + + overall_boxes = [] + + for label in filtered_boxes: + boxes = filtered_boxes[label] + new_boxes = [] + weighted_boxes = np.empty((0, 8)) + + # Clusterize boxes + for j in range(0, len(boxes)): + index, best_iou = find_matching_box_fast(weighted_boxes, boxes[j], + iou_thr) + + if index != -1: + new_boxes[index].append(boxes[j]) + weighted_boxes[index] = get_weighted_box( + new_boxes[index], conf_type) + else: + new_boxes.append([boxes[j].copy()]) + weighted_boxes = np.vstack((weighted_boxes, boxes[j].copy())) + + # Rescale confidence based on number of models and boxes + for i in range(len(new_boxes)): + clustered_boxes = new_boxes[i] + if conf_type == 'box_and_model_avg': + clustered_boxes = np.array(clustered_boxes) + # weighted average for boxes + weighted_boxes[i, 1] = weighted_boxes[i, 1] * len( + clustered_boxes) / weighted_boxes[i, 2] + # identify unique model index by model index column + _, idx = np.unique(clustered_boxes[:, 3], return_index=True) + # rescale by unique model weights + weighted_boxes[i, 1] = weighted_boxes[i, 1] * clustered_boxes[ + idx, 2].sum() / weights.sum() + elif conf_type == 'absent_model_aware_avg': + clustered_boxes = np.array(clustered_boxes) + # get unique model index in the cluster + models = np.unique(clustered_boxes[:, 3]).astype(int) + # create a mask to get unused model weights + mask = np.ones(len(weights), dtype=bool) + mask[models] = False + # absent model aware weighted average + weighted_boxes[ + i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / ( + weighted_boxes[i, 2] + weights[mask].sum()) + elif conf_type == 'max': + weighted_boxes[i, 1] = weighted_boxes[i, 1] / weights.max() + elif not allows_overflow: + weighted_boxes[i, 1] = weighted_boxes[i, 1] * min( + len(weights), len(clustered_boxes)) / weights.sum() + else: + weighted_boxes[i, 1] = weighted_boxes[i, 1] * len( + clustered_boxes) / weights.sum() + overall_boxes.append(weighted_boxes) + overall_boxes = np.concatenate(overall_boxes, axis=0) + overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]] + + bboxes = torch.Tensor(overall_boxes[:, 4:]) + scores = torch.Tensor(overall_boxes[:, 1]) + labels = torch.Tensor(overall_boxes[:, 0]).int() + + return bboxes, scores, labels + + +def prefilter_boxes(boxes, scores, labels, weights, thr): + + new_boxes = dict() + + for t in range(len(boxes)): + + if len(boxes[t]) != len(scores[t]): + print('Error. Length of boxes arrays not equal to ' + 'length of scores array: {} != {}'.format( + len(boxes[t]), len(scores[t]))) + exit() + + if len(boxes[t]) != len(labels[t]): + print('Error. Length of boxes arrays not equal to ' + 'length of labels array: {} != {}'.format( + len(boxes[t]), len(labels[t]))) + exit() + + for j in range(len(boxes[t])): + score = scores[t][j] + if score < thr: + continue + label = int(labels[t][j]) + box_part = boxes[t][j] + x1 = float(box_part[0]) + y1 = float(box_part[1]) + x2 = float(box_part[2]) + y2 = float(box_part[3]) + + # Box data checks + if x2 < x1: + warnings.warn('X2 < X1 value in box. Swap them.') + x1, x2 = x2, x1 + if y2 < y1: + warnings.warn('Y2 < Y1 value in box. Swap them.') + y1, y2 = y2, y1 + if (x2 - x1) * (y2 - y1) == 0.0: + warnings.warn('Zero area box skipped: {}.'.format(box_part)) + continue + + # [label, score, weight, model index, x1, y1, x2, y2] + b = [ + int(label), + float(score) * weights[t], weights[t], t, x1, y1, x2, y2 + ] + + if label not in new_boxes: + new_boxes[label] = [] + new_boxes[label].append(b) + + # Sort each list in dict by score and transform it to numpy array + for k in new_boxes: + current_boxes = np.array(new_boxes[k]) + new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]] + + return new_boxes + + +def get_weighted_box(boxes, conf_type='avg'): + + box = np.zeros(8, dtype=np.float32) + conf = 0 + conf_list = [] + w = 0 + for b in boxes: + box[4:] += (b[1] * b[4:]) + conf += b[1] + conf_list.append(b[1]) + w += b[2] + box[0] = boxes[0][0] + if conf_type in ('avg', 'box_and_model_avg', 'absent_model_aware_avg'): + box[1] = conf / len(boxes) + elif conf_type == 'max': + box[1] = np.array(conf_list).max() + box[2] = w + box[3] = -1 + box[4:] /= conf + + return box + + +def find_matching_box_fast(boxes_list, new_box, match_iou): + + def bb_iou_array(boxes, new_box): + # bb intersection over union + xA = np.maximum(boxes[:, 0], new_box[0]) + yA = np.maximum(boxes[:, 1], new_box[1]) + xB = np.minimum(boxes[:, 2], new_box[2]) + yB = np.minimum(boxes[:, 3], new_box[3]) + + interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0) + + # compute the area of both the prediction and ground-truth rectangles + boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1]) + + iou = interArea / (boxAArea + boxBArea - interArea) + + return iou + + if boxes_list.shape[0] == 0: + return -1, match_iou + + boxes = boxes_list + + ious = bb_iou_array(boxes[:, 4:], new_box[4:]) + + ious[boxes[:, 0] != new_box[0]] = -1 + + best_idx = np.argmax(ious) + best_iou = ious[best_idx] + + if best_iou <= match_iou: + best_iou = match_iou + best_idx = -1 + + return best_idx, best_iou diff --git a/tools/analysis_tools/fuse_results.py b/tools/analysis_tools/fuse_results.py new file mode 100644 index 00000000000..1f35123cbbb --- /dev/null +++ b/tools/analysis_tools/fuse_results.py @@ -0,0 +1,142 @@ +import argparse + +from mmengine.fileio import dump, load +from mmengine.logging import print_log +from mmengine.utils import ProgressBar +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from mmdet.models.utils import weighted_boxes_fusion + + +def parse_args(): + parser = argparse.ArgumentParser(description='Fusion image \ + prediction results using Weighted \ + Boxes Fusion from multiple models.') + parser.add_argument( + 'pred-results', + type=str, + nargs='+', + help='files of prediction results \ + from multiple models, json format.') + parser.add_argument('--annotation', type=str, help='annotation file path') + parser.add_argument( + '--weights', + type=float, + nargs='*', + default=None, + help='weights for each model, ' + 'remember to correspond to the above prediction path.') + parser.add_argument( + '--fusion-iou-thr', + type=float, + default=0.55, + help='IoU value for boxes to be a match in wbf.') + parser.add_argument( + '--skip-box-thr', + type=float, + default=0.0, + help='exclude boxes with score lower than this variable in wbf.') + parser.add_argument( + '--conf-type', + type=str, + default='avg', + help='how to calculate confidence in weighted boxes in wbf.') + parser.add_argument( + '--eval-single', + action='store_true', + help='whether evaluate each single model result.') + parser.add_argument( + '--save-fusion-results', + action='store_true', + help='whether save fusion result') + parser.add_argument( + '--out-dir', + type=str, + default='outputs', + help='Output directory of images or prediction results.') + + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + + assert len(args.models_name) == len(args.pred_results), \ + 'the quantities of model names and prediction results are not equal' + + cocoGT = COCO(args.annotation) + + predicts_raw = [] + + models_name = ['model_' + str(i) for i in range(len(args.pred_results))] + + for model_name, path in \ + zip(models_name, args.pred_results): + pred = load(path) + predicts_raw.append(pred) + + if args.eval_single: + print_log(f'Evaluate {model_name}...') + cocoDt = cocoGT.loadRes(pred) + coco_eval = COCOeval(cocoGT, cocoDt, iouType='bbox') + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + predict = { + str(image_id): { + 'bboxes_list': [[] for _ in range(len(predicts_raw))], + 'scores_list': [[] for _ in range(len(predicts_raw))], + 'labels_list': [[] for _ in range(len(predicts_raw))] + } + for image_id in cocoGT.getImgIds() + } + + for i, pred_single in enumerate(predicts_raw): + for pred in pred_single: + p = predict[str(pred['image_id'])] + p['bboxes_list'][i].append(pred['bbox']) + p['scores_list'][i].append(pred['score']) + p['labels_list'][i].append(pred['category_id']) + + result = [] + prog_bar = ProgressBar(len(predict)) + for image_id, res in predict.items(): + bboxes, scores, labels = weighted_boxes_fusion( + res['bboxes_list'], + res['scores_list'], + res['labels_list'], + weights=args.weights, + iou_thr=args.fusion_iou_thr, + skip_box_thr=args.skip_box_thr, + conf_type=args.conf_type) + + for bbox, score, label in zip(bboxes, scores, labels): + result.append({ + 'bbox': bbox.numpy().tolist(), + 'category_id': int(label), + 'image_id': int(image_id), + 'score': float(score) + }) + + prog_bar.update() + + if args.save_fusion_results: + out_file = args.out_dir + '/fusion_results.json' + dump(result, file=out_file) + print_log( + f'Fusion results have been saved to {out_file}.', logger='current') + + print_log('Evaluate fusion results using wbf...') + cocoDt = cocoGT.loadRes(result) + coco_eval = COCOeval(cocoGT, cocoDt, iouType='bbox') + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + +if __name__ == '__main__': + main()