-
Notifications
You must be signed in to change notification settings - Fork 1
/
compute_anchors_main.py
120 lines (98 loc) · 4.06 KB
/
compute_anchors_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Compute anchors for YOLOv7.
Code is adapted from
https://github.com/jinfagang/yolov7_d2/blob/main/tools/compute_anchors.py.
Original warning: these anchors are only useful when the YOLO input is not
force-resized, i.e., the input image is padded at the bottom and right without
any distortion. Otherwise these anchors are WRONG, because we do not use a
forced resize to a fixed input size such as 608 or 512; we use the original
image size!
"""
from __future__ import annotations
from typing import Any
import detectron2
import numpy as np
import torch
from detectron2.data import build_detection_train_loader
from tqdm import tqdm
import adv_patch_bench.dataloaders.detectron.util as data_util
from adv_patch_bench.dataloaders.detectron import mtsd_dataset_mapper
from adv_patch_bench.utils.argparse import reap_args_parser, setup_detectron_cfg
# Number of anchors to produce; used as `k` for the k-means clustering in main().
NUM_CLUSTERS = 9
def compute_iou(box, clusters):
    """Compute the IoU between one box and each cluster centroid.

    Boxes and clusters are compared by size only (two side lengths, no
    position): the intersection is the product of the element-wise minimum
    of each side.

    Args:
        box: Array-like of length 2 with the two side lengths of one box.
        clusters: Array-like of shape (k, 2) with the side lengths of k
            cluster centroids.

    Returns:
        Float array of shape (k,) with the IoU of ``box`` against each
        cluster. Pairs whose union is zero get IoU 0 instead of NaN, and a
        degenerate cluster only affects its own entry.
    """
    box = np.asarray(box, dtype=float)
    clusters = np.asarray(clusters, dtype=float)
    if box[0] == 0 or box[1] == 0:
        # Diagnostic kept from the original; the IoU below is then 0
        # against every cluster because the intersection is 0.
        print("Box has no area")
    x = np.minimum(clusters[:, 0], box[0])
    y = np.minimum(clusters[:, 1], box[1])
    intersection = x * y
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]
    union = box_area + cluster_area - intersection
    # Guard the 0/0 case (degenerate box vs. degenerate cluster) per pair.
    return np.divide(
        intersection, union, out=np.zeros_like(union), where=union > 0
    )
def avg_iou(boxes, clusters):
    """Return the mean over all boxes of each box's best IoU with any cluster.

    Args:
        boxes: Array of box sizes; one box per row (indexed along axis 0).
        clusters: Cluster centroids, passed unchanged to ``compute_iou``.

    Returns:
        The average of the per-box maximum IoU values.
    """
    best_per_box = []
    n_boxes = boxes.shape[0]
    for idx in range(n_boxes):
        # Score each box by its closest (highest-IoU) cluster.
        best_per_box.append(np.max(compute_iou(boxes[idx], clusters)))
    return np.mean(best_per_box)
def run_kmeans_ious(boxes, k, dist=np.median, seed: int | None = None):
    """Cluster box sizes with k-means using ``1 - IoU`` as the distance.

    Args:
        boxes: Array-like of shape (n, 2) with the two side lengths of each
            box.
        k: Number of clusters; must satisfy ``1 <= k <= n``.
        dist: Reduction used to recompute a centroid from its assigned
            boxes, called as ``dist(members, axis=0)``. Defaults to the
            median.
        seed: Optional seed for picking the initial centroids. ``None`` (the
            default) draws fresh OS entropy, so results vary between runs.

    Returns:
        Float array of shape (k, 2) with the cluster centroids.

    Raises:
        ValueError: If ``k`` is not between 1 and the number of boxes.
    """
    boxes = np.asarray(boxes, dtype=float)
    rows = boxes.shape[0]
    if not 1 <= k <= rows:
        raise ValueError(f"Cannot form {k} clusters from {rows} boxes.")

    # Local generator: does not reseed NumPy's global RNG.
    rng = np.random.default_rng(seed)
    # Start from k distinct boxes. Fancy indexing copies, so updating
    # `clusters` below never mutates `boxes`.
    clusters = boxes[rng.choice(rows, k, replace=False)]

    distances = np.empty((rows, k))
    # -1 is never a valid cluster index, so the first assignment can never be
    # mistaken for convergence (a zero-filled start would stop early whenever
    # every box is first assigned to cluster 0).
    last_clusters = np.full(rows, -1)
    while True:
        for row in range(rows):
            distances[row] = 1 - compute_iou(boxes[row], clusters)
        nearest_clusters = np.argmin(distances, axis=1)
        if np.array_equal(last_clusters, nearest_clusters):
            break
        for cluster in range(k):
            members = boxes[nearest_clusters == cluster]
            # An empty cluster would give a NaN centroid (median of an empty
            # array); keep its previous centroid instead.
            if members.size:
                clusters[cluster] = dist(members, axis=0)
        last_clusters = nearest_clusters
    return clusters
def _collect_bbox_sizes(data_loader) -> np.ndarray:
    """Return an (n, 2) float array of [height, width] per ground-truth box.

    Makes one pass over ``data_loader``; each yielded batch is iterated as a
    sequence of samples whose ``"instances"`` carry ``gt_boxes``. Width and
    height are taken as ``tensor[:, 2] - tensor[:, 0]`` and
    ``tensor[:, 3] - tensor[:, 1]``.

    Raises:
        ValueError: If any box has a non-positive width or height.
    """
    sizes = []
    for batch in tqdm(data_loader):
        for sample in batch:
            box_tensor = sample["instances"].gt_boxes.tensor
            width = box_tensor[:, 2] - box_tensor[:, 0]
            height = box_tensor[:, 3] - box_tensor[:, 1]
            # Explicit check rather than `assert`, which `python -O` strips.
            if not ((width > 0).all() and (height > 0).all()):
                raise ValueError(
                    "Found a ground-truth box with non-positive width or height."
                )
            # Rows are [height, width], same as stack(dim=0).T.
            sizes.extend(torch.stack([height, width], dim=1).tolist())
    # reshape keeps the (n, 2) shape even when no boxes were found.
    return np.array(sizes, dtype=float).reshape(-1, 2)


def main():
    """Compute YOLO anchors by k-means over training-set box sizes.

    Builds the config, registers the dataset, and loads the
    ``<dataset>_train`` split through the MTSD mapper. It clusters the
    [height, width] of every ground-truth box into ``NUM_CLUSTERS`` anchors
    and prints them ordered by decreasing area.

    Raises:
        ValueError: If fewer boxes than ``NUM_CLUSTERS`` are found.
    """
    config: dict[str, dict[str, Any]] = reap_args_parser(
        True, is_gen_patch=False, is_train=True
    )
    config_base = config["base"]
    cfg = setup_detectron_cfg(config, is_train=True)

    # Register data. This has to be called by every process for some reason.
    data_util.register_dataset(config_base)
    data_dicts = detectron2.data.DatasetCatalog.get(
        f"{config_base['dataset']}_train"
    )

    # NOTE(review): InferenceSampler is presumably used so the loader makes a
    # single finite pass over the training set; confirm with detectron2 docs.
    # pylint: disable=missing-kwoa,too-many-function-args
    data_loader = build_detection_train_loader(
        cfg,
        mapper=mtsd_dataset_mapper.MtsdDatasetMapper(
            cfg,
            is_train=True,
            config_base=config_base,
            img_size=config_base["img_size"],
        ),
        sampler=detectron2.data.samplers.InferenceSampler(len(data_dicts)),
    )

    bbox_sizes = _collect_bbox_sizes(data_loader)
    print(f"Total number of bounding boxes: {len(bbox_sizes)}")
    if len(bbox_sizes) < NUM_CLUSTERS:
        raise ValueError(
            f"Need at least {NUM_CLUSTERS} boxes to compute anchors, "
            f"got {len(bbox_sizes)}."
        )

    # NOTE(review): anchors come out as (height, width) pairs; YOLO anchor
    # configs usually list (width, height) -- confirm the expected order.
    anchors = run_kmeans_ious(bbox_sizes, k=NUM_CLUSTERS)
    print(f"Boxes:\n {anchors}")
    print(f"Accuracy: {avg_iou(bbox_sizes, anchors) * 100:.2f}%")

    areas = np.around(anchors[:, 0] * anchors[:, 1], decimals=2)
    print(f"Area before sorted:\n {areas.tolist()}")
    # Largest area first; print the areas in that same order so the log
    # matches `final_anchors`.
    idx = np.argsort(areas)[::-1]
    print(f"Area after sorted:\n {areas[idx].tolist()}")
    final_anchors = anchors[idx].round().astype(np.int32).tolist()
    print(f"Final anchor: {final_anchors}")
if __name__ == "__main__":
    # Script entry point: only runs when the file is executed directly.
    main()