diff --git a/configs/v3det/README.md b/configs/v3det/README.md index 36879316f4f..cfac5d1dc63 100644 --- a/configs/v3det/README.md +++ b/configs/v3det/README.md @@ -61,17 +61,17 @@ data/ ## Results and Models | Backbone | Model | Lr schd | box AP | Config | Download | -| :------: | :-------------: | :-----: | :----: | :----------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | -| R-50 | Faster R-CNN | 2x | 25.4 | [config](./faster_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//faster_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | -| R-50 | Cascade R-CNN | 2x | 31.6 | [config](./cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | -| R-50 | FCOS | 2x | 9.4 | [config](./fcos_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//fcos_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | -| R-50 | Deformable-DETR | 50e | 34.4 | [config](./deformable-detr-refine-twostage_r50_8xb4_sample1e-3_v3det_50e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/Deformable_DETR_V3Det_R50) | -| R-50 | DINO | 36e | 33.5 | [config](./dino-4scale_r50_8xb2_sample1e-3_v3det_36e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/DINO_V3Det_R50) | -| Swin-B | Faster R-CNN | 2x | 37.6 | [config](./faster_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//faster_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | -| Swin-B | Cascade R-CNN | 2x | 42.5 | [config](./cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | -| Swin-B | FCOS | 2x | 21.0 | [config](./fcos_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//fcos_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | -| Swin-B | Deformable-DETR | 50e | 42.5 | [config](./deformable-detr-refine-twostage_swin_16xb2_sample1e-3_v3det_50e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/Deformable_DETR_V3Det_SwinB) | -| Swin-B | DINO | 36e | 42.0 | [config](./dino-4scale_swin_16xb1_sample1e-3_v3det_36e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/DINO_V3Det_SwinB) | +| :------: | :-------------: | :-----: |:------:| :----------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | +| R-50 | Faster R-CNN | 2x | 25.8 | [config](./faster_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//faster_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | +| R-50 | Cascade R-CNN | 2x | 32.1 | [config](./cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | +| R-50 | FCOS | 2x | 9.6 | [config](./fcos_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | 
[model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//fcos_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | +| R-50 | Deformable-DETR | 50e | 35.0 | [config](./deformable-detr-refine-twostage_r50_8xb4_sample1e-3_v3det_50e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/Deformable_DETR_V3Det_R50) | +| R-50 | DINO | 36e | 34.0 | [config](./dino-4scale_r50_8xb2_sample1e-3_v3det_36e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/DINO_V3Det_R50) | +| Swin-B | Faster R-CNN | 2x | 38.2 | [config](./faster_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//faster_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | +| Swin-B | Cascade R-CNN | 2x | 43.2 | [config](./cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | +| Swin-B | FCOS | 2x | 21.5 | [config](./fcos_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//fcos_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | +| Swin-B | Deformable-DETR | 50e | 43.1 | [config](./deformable-detr-refine-twostage_swin_16xb2_sample1e-3_v3det_50e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/Deformable_DETR_V3Det_SwinB) | +| Swin-B | DINO | 36e | 42.6 | [config](./dino-4scale_swin_16xb1_sample1e-3_v3det_36e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/DINO_V3Det_SwinB) | ## Citation diff --git a/mmdet/datasets/api_wrappers/cocoeval_mp.py b/mmdet/datasets/api_wrappers/cocoeval_mp.py index b3673ea7a7e..592e61bc3a3 100644 --- a/mmdet/datasets/api_wrappers/cocoeval_mp.py +++ b/mmdet/datasets/api_wrappers/cocoeval_mp.py @@ -1,18 +1,64 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import itertools +import multiprocessing as mp import time from collections import defaultdict +import mmengine import numpy as np -import torch.multiprocessing as mp from mmengine.logging import MMLogger -from pycocotools.cocoeval import COCOeval +from pycocotools.cocoeval import COCOeval, Params from tqdm import tqdm class COCOevalMP(COCOeval): + def __init__(self, + cocoGt=None, + cocoDt=None, + iouType='bbox', + num_proc=8, + tree_ann_path='data/V3Det/annotations/v3det_' + '2023_v1_category_tree.json', + ignore_parent_child_gts=True): + """Initialize CocoEval using coco APIs for gt and dt. + + :param cocoGt: coco object with ground truth annotations + :param cocoDt: coco object with detection results + :return: None + """ + if not iouType: + print('iouType not specified. 
use default iouType bbox')
+        self.cocoGt = cocoGt  # ground truth COCO API
+        self.cocoDt = cocoDt  # detections COCO API
+        self.evalImgs = defaultdict(
+            list)  # per-image per-category evaluation results [KxAxI] elements
+        self.eval = {}  # accumulated evaluation results
+        self._gts = defaultdict(list)  # gt for evaluation
+        self._dts = defaultdict(list)  # dt for evaluation
+        self.params = Params(iouType=iouType)  # parameters
+        self._paramsEval = {}  # parameters for evaluation
+        self.stats = []  # result summarization
+        self.ious = {}  # ious between all gts and dts
+        self.num_proc = num_proc  # number of worker processes
+        self.tree_ann_path = tree_ann_path
+        self.ignore_parent_child_gts = ignore_parent_child_gts
+        if not mmengine.exists(tree_ann_path):
+            print(f'{tree_ann_path} does not exist')
+            raise FileNotFoundError
+        if cocoGt is not None:
+            self.params.imgIds = sorted(cocoGt.getImgIds())
+            self.params.catIds = sorted(cocoGt.getCatIds())
+            # split category ids into base and novel
+            self.base_inds = []
+            self.novel_inds = []
+            for i, c in enumerate(self.cocoGt.dataset['categories']):
+                if c['novel']:
+                    self.novel_inds.append(i)
+                else:
+                    self.base_inds.append(i)
+
     def _prepare(self):
         '''
         Prepare ._gts and ._dts for evaluation based on params
@@ -39,10 +85,6 @@ def _toMask(anns, coco):
                 if (dt['category_id'] in cat_ids) and (dt['image_id'] in img_ids):
                     dts.append(dt)
-        # gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) # noqa
-        # dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) # noqa
-        # gts=self.cocoGt.dataset['annotations']
-        # dts=self.cocoDt.dataset['annotations']
         else:
             gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
             dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
@@ -63,6 +105,44 @@ def _toMask(anns, coco):
             self._gts[gt['image_id'], gt['category_id']].append(gt)
         for dt in dts:
             self._dts[dt['image_id'], dt['category_id']].append(dt)
+
+        if self.ignore_parent_child_gts:
+            # for each category, collect its descendant categories
+            cat_tree = mmengine.load(self.tree_ann_path)
+            catid2treeid = cat_tree['categoryid2treeid']
+            treeid2catid = {v: k for k, v in catid2treeid.items()}
+            ori_ancestor2descendant = cat_tree['ancestor2descendant']
+            ancestor2descendant = dict()
+            for k, v in ori_ancestor2descendant.items():
+                if k in treeid2catid:
+                    ancestor2descendant[k] = v
+            ancestor2descendant_catid = defaultdict(set)
+            for tree_id in ancestor2descendant:
+                cat_id = treeid2catid[tree_id]
+                descendant_ids = ancestor2descendant[tree_id]
+                for descendant_id in descendant_ids:
+                    if descendant_id not in treeid2catid:
+                        continue
+                    descendant_catid = treeid2catid[descendant_id]
+                    ancestor2descendant_catid[int(cat_id)].add(
+                        int(descendant_catid))
+            self.ancestor2descendant_catid = ancestor2descendant_catid
+            # If a gt has a descendant category whose dts appear in this
+            # image, add an ignored copy of the gt under that category
+            for gt in gts:
+                ignore_cats = []
+                for child_cat_id in self.ancestor2descendant_catid[
+                        gt['category_id']]:
+                    if len(self._dts[gt['image_id'], child_cat_id]) > 0:
+                        ignore_cats.append(child_cat_id)
+                if len(ignore_cats) == 0:
+                    continue
+                ignore_gt = copy.deepcopy(gt)
+                ignore_gt['category_id'] = ignore_cats
+                ignore_gt['ignore'] = 1
+                for child_cat_id in ignore_cats:
+                    self._gts[gt['image_id'], child_cat_id].append(ignore_gt)
+
         self.evalImgs = defaultdict(
             list)  # per-image per-category evaluation results
         self.eval = {}  # accumulated evaluation results
@@ -91,7 +171,7 @@ def evaluate(self):
         # loop through images, area range, max 
detection number catIds = p.catIds if p.useCats else [-1] - nproc = 8 + nproc = self.num_proc split_size = len(catIds) // nproc mp_params = [] for i in range(nproc): @@ -102,7 +182,7 @@ def evaluate(self): mp_params.append((catIds[begin:end], )) MMLogger.get_current_instance().info( - 'start multi processing evaluation ...') + f'start multi processing evaluation with nproc: {nproc}...') with mp.Pool(nproc) as pool: self.evalImgs = pool.starmap(self._evaluateImg, mp_params) @@ -116,14 +196,12 @@ def _evaluateImg(self, catids_chunk): self._prepare() p = self.params maxDet = max(p.maxDets) - all_params = [] - for catId in catids_chunk: - for areaRng in p.areaRng: - for imgId in p.imgIds: - all_params.append((catId, areaRng, imgId)) + all_params = itertools.product(catids_chunk, p.areaRng, p.imgIds) + all_params_len = len(catids_chunk) * len(p.areaRng) * len(p.imgIds) evalImgs = [ self.evaluateImg(imgId, catId, areaRng, maxDet) - for catId, areaRng, imgId in tqdm(all_params) + for catId, areaRng, imgId in tqdm( + all_params, total=all_params_len) ] return evalImgs @@ -209,7 +287,7 @@ def evaluateImg(self, imgId, catId, aRng, maxDet): 'dtIgnore': dtIg, } - def summarize(self): + def summarize(self, is_ovd=False): """Compute and display summary metrics for evaluation results. Note this function can *only* be applied on the default parameter @@ -272,6 +350,102 @@ def _summarizeDets(): stats = np.array(stats) return stats + def _summarizeOVDs(): + + def _summarize(ap=1, + iouThr=None, + areaRng='all', + maxDets=100, + cat_kind=None): + assert cat_kind in ('Base', 'Novel') + if cat_kind == 'Novel': + cat_inds = self.novel_inds + else: + cat_inds = self.base_inds + p = self.params + iStr = (' {:<18} {} @[ IoU={:<9} | area={:>6s} | ' + 'maxDets={:>3d} ] = {:0.3f}') # noqa + titleStr = f'{cat_kind} Average Precision' if ap == 1 \ + else f'{cat_kind} Average Recall' + typeStr = '(AP)' if ap == 1 else '(AR)' + iouStr = '{:0.2f}:{:0.2f}'.format( + p.iouThrs[0], p.iouThrs[-1]) if ( + iouThr is None) else '{:0.2f}'.format(iouThr) + + aind = [ + i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng + ] + mind = [ + i for i, mDet in enumerate(p.maxDets) if mDet == maxDets + ] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = self.eval['precision'] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, cat_inds, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = self.eval['recall'] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, cat_inds, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print( + iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, + mean_s)) + return mean_s + + stats = [] + for cat_kind in ('Base', 'Novel'): + print(f'\nSummarize {cat_kind} Classes:') + stats.append( + _summarize( + 1, maxDets=self.params.maxDets[-1], cat_kind=cat_kind)) + stats.append( + _summarize( + 1, + iouThr=.5, + maxDets=self.params.maxDets[-1], + cat_kind=cat_kind)) + stats.append( + _summarize( + 1, + iouThr=.75, + maxDets=self.params.maxDets[-1], + cat_kind=cat_kind)) + for area_rng in ('small', 'medium', 'large'): + stats.append( + _summarize( + 1, + areaRng=area_rng, + maxDets=self.params.maxDets[-1], + cat_kind=cat_kind)) + for max_det in self.params.maxDets: + stats.append( + _summarize(0, maxDets=max_det, cat_kind=cat_kind)) + for area_rng in ('small', 'medium', 'large'): + stats.append( + _summarize( + 0, + areaRng=area_rng, + maxDets=self.params.maxDets[-1], 
+ cat_kind=cat_kind)) + stats = np.array(stats) + + print() + print('-' * 45) + print(f'Compute OVD AP: (bAP + 3 * nAP) / 4 ' + f'= {(stats[0] + 3 * stats[10]) / 4.:.4f}') + print('-' * 45) + return stats + def _summarizeKps(): stats = np.zeros((10, )) stats[0] = _summarize(1, maxDets=20) @@ -293,4 +467,7 @@ def _summarizeKps(): summarize = _summarizeDets elif iouType == 'keypoints': summarize = _summarizeKps + if is_ovd: + assert iouType == 'bbox' + summarize = _summarizeOVDs self.stats = summarize()
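
Reviewer note: a minimal, hypothetical usage sketch of the patched `COCOevalMP`, independent of the MMDetection metric that normally wraps it and not part of this diff. Paths, the result file, and the `maxDets` choice are placeholders; the ground-truth file is assumed to carry a boolean `novel` field on every category (required by the base/novel split in `__init__`), and the category-tree JSON is assumed to exist at `tree_ann_path`.

```python
from pycocotools.coco import COCO

from mmdet.datasets.api_wrappers.cocoeval_mp import COCOevalMP

# Placeholder paths -- adjust to the actual V3Det layout.
gt_path = 'data/V3Det/annotations/v3det_2023_v1_val.json'
dt_path = 'work_dirs/v3det/results.bbox.json'  # COCO-format detection results

coco_gt = COCO(gt_path)             # every category entry needs a 'novel' flag
coco_dt = coco_gt.loadRes(dt_path)

coco_eval = COCOevalMP(
    coco_gt,
    coco_dt,
    iouType='bbox',
    num_proc=8,  # worker processes used by evaluate()
    tree_ann_path='data/V3Det/annotations/'
    'v3det_2023_v1_category_tree.json',
    ignore_parent_child_gts=True)  # ignore parent gts matched by child dts

# Illustrative choice: the OVD-AP printout in summarize() reads Novel AP from
# stats[10], which lines up when maxDets holds exactly one entry.
coco_eval.params.maxDets = [300]

coco_eval.evaluate()                # multiprocessing over category chunks
coco_eval.accumulate()              # inherited unchanged from COCOeval
coco_eval.summarize(is_ovd=True)    # base/novel (OVD) summary, bbox only
```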
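The category-tree file itself is only described implicitly by how `_prepare` reads it. The toy dictionary below sketches the expected shape with made-up ids; the real `v3det_2023_v1_category_tree.json` is far larger.

```python
# Toy stand-in for the category tree read in _prepare(); all ids are invented.
cat_tree = {
    # category id (stored as a string) -> tree node id
    'categoryid2treeid': {'1': 'n01', '2': 'n02', '3': 'n03'},
    # tree node id -> ids of all of its descendant tree nodes
    'ancestor2descendant': {'n01': ['n02', 'n03'], 'n02': ['n03']},
}

# _prepare() inverts categoryid2treeid, drops tree nodes without a category,
# and resolves descendants back to integer category ids, so for this toy tree
# self.ancestor2descendant_catid == {1: {2, 3}, 2: {3}}.  A gt box of
# category 1 then gets an ignored duplicate under category 2 (and 3) whenever
# the image has detections for that child category, presumably so that a more
# specific (descendant-class) detection is not counted as a false positive.
```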