Skip to content

Commit 63c1c23

Browse files
committed
batched nms rewrite working
1 parent 57543f5 commit 63c1c23

File tree

2 files changed

+116
-21
lines changed

2 files changed

+116
-21
lines changed

nms-perf-issue/maskrcnn_test.py

Lines changed: 116 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tvm
22
from tvm import relay
33
from tvm.contrib.download import download
4+
from tvm.runtime.vm import VirtualMachine
45
from tvm.relay.dataflow_pattern import *
56

67
import numpy as np
@@ -23,7 +24,12 @@ def do_trace(model, inp):
2324

2425
def dict_to_tuple(out_dict):
    """Flatten a detection-output mapping into a plain tuple.

    Parameters
    ----------
    out_dict : mapping
        Must provide the keys "boxes", "scores" and "labels"; may also
        provide "masks".

    Returns
    -------
    tuple
        ``(boxes, scores, labels, masks)`` when "masks" is present,
        otherwise ``(boxes, scores, labels)``.

    Raises
    ------
    KeyError
        If one of the required keys is missing.
    """
    field_names = ["boxes", "scores", "labels"]
    # Membership is tested on the mapping itself; no need to build a keys view.
    if "masks" in out_dict:
        field_names.append("masks")
    return tuple(out_dict[name] for name in field_names)
2834

2935

@@ -40,7 +46,8 @@ def forward(self, inp):
4046
def get_input():
4147
img_path = "test_street_small.jpg"
4248
img_url = (
43-
"https://raw.githubusercontent.com/dmlc/web-data/" "master/gluoncv/detection/street_small.jpg"
49+
"https://raw.githubusercontent.com/dmlc/web-data/"
50+
"master/gluoncv/detection/street_small.jpg"
4451
)
4552
download(img_url, img_path)
4653

@@ -52,12 +59,7 @@ def get_input():
5259
return img
5360

5461

55-
def batched_nms_pattern():
56-
# exprs I want to extract
57-
boxes = wildcard()
58-
scores = wildcard()
59-
idxs = wildcard()
60-
62+
def batched_nms_pattern(boxes, scores, idxs, iou_threshold):
6163
one = is_constant()
6264
zero = is_constant()
6365

@@ -90,24 +92,117 @@ def batched_nms_pattern():
9092
expand_dims = is_op("expand_dims")(concat)
9193

9294
# %1842 = vision.get_valid_counts(%1841, -1f, meta[relay.attrs.GetValidCountsAttrs][1]);
93-
return is_op("vision.get_valid_counts")(expand_dims, is_constant(), wildcard())
95+
get_valid_counts_out = is_op("vision.get_valid_counts")(expand_dims, is_constant())
96+
data = is_tuple_get_item(get_valid_counts_out, 1)
97+
valid_counts = is_tuple_get_item(get_valid_counts_out, 0)
98+
indices = is_tuple_get_item(get_valid_counts_out, 2)
99+
return is_op("vision.non_max_suppression")(
100+
data, valid_counts, indices, is_constant(), iou_threshold
101+
)
94102

95103

96-
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
97-
model = TraceWrapper(model_func(pretrained=True))
98-
model.eval()
104+
def convert_batched_nms(boxes, scores, idxs, iou_thres):
    """Express batched NMS as one ``get_valid_counts`` + ``non_max_suppression``.

    The three inputs are packed into a single tensor whose last axis is laid
    out as ``[class_id, score, box coordinates...]`` and a leading batch axis
    of size 1 is added before the tensor is handed to the Relay vision ops.

    Parameters
    ----------
    boxes : relay.Expr
        Box coordinates.
        NOTE(review): assumed to be (N, 4) — confirm against the matched graph.
    scores : relay.Expr
        Per-box scores; a trailing axis is appended so it can be concatenated
        with ``boxes``.
    idxs : relay.Expr
        Per-box class ids; a trailing axis is appended and the values are cast
        to float32 so they can share a tensor with scores and boxes.
    iou_thres : relay.Expr or float
        IoU threshold forwarded to ``non_max_suppression``.

    Returns
    -------
    relay.Expr
        The tuple-valued expression behind the ``non_max_suppression`` result
        (``return_indices=True``).
    """
    # Column indices of the packed layout [class_id, score, coordinates...].
    id_index = 0
    score_index = 1
    coord_start = 2

    scores = relay.expand_dims(scores, axis=-1, num_newaxis=1)
    idxs = relay.expand_dims(idxs, axis=-1, num_newaxis=1)
    idxs = relay.cast(idxs, "float32")
    data = relay.concatenate([idxs, scores, boxes], -1)
    data = relay.expand_dims(data, 0, 1)

    ct, data, indices = relay.op.vision.get_valid_counts(
        data, score_threshold=-1.0, id_index=id_index, score_index=score_index
    )

    # -1 disables the top-k pre-filter and the cap on the number of outputs.
    top_k = max_out_size = -1
    out = relay.op.vision.non_max_suppression(
        data=data,
        valid_count=ct,
        indices=indices,
        max_output_size=max_out_size,
        iou_threshold=iou_thres,
        # Batched NMS must only suppress boxes that share a class id, so the
        # id column has to be honoured rather than ignored.
        force_suppress=False,
        top_k=top_k,
        # Coordinates follow the id and score columns.
        coord_start=coord_start,
        score_index=score_index,
        id_index=id_index,
        return_indices=True,
        invalid_to_bottom=False,
    )
    return out.tuple_value
99129

100-
img = get_input()
101-
inp = torch.from_numpy(img)
102130

103-
with torch.no_grad():
104-
out = model(inp)
105-
script_module = do_trace(model, inp)
131+
class NMSRewrite(DFPatternCallback):
    """Pattern callback that replaces a matched batched-NMS subgraph.

    The pattern comes from ``batched_nms_pattern``; every match is replaced by
    the expression returned from ``convert_batched_nms``.
    """

    def __init__(self):
        super().__init__()
        # One wildcard per sub-expression that must be recovered from a match.
        self.boxes, self.scores, self.idxs, self.iou_threshold = (
            wildcard() for _ in range(4)
        )
        self.pattern = batched_nms_pattern(
            boxes=self.boxes,
            scores=self.scores,
            idxs=self.idxs,
            iou_threshold=self.iou_threshold,
        )

    def callback(self, pre, post, node_map):
        """Return the replacement expression for one match.

        ``node_map`` is indexed by the wildcard patterns created in
        ``__init__``; the first entry stored under each wildcard is used.
        """
        print("matched")
        placeholders = (self.boxes, self.scores, self.idxs, self.iou_threshold)
        captured = [node_map[placeholder][0] for placeholder in placeholders]
        return convert_batched_nms(*captured)
151+
106152

107153
input_name = "input0"
img = get_input()

# Reference only: these steps presumably produced the two files loaded below
# (they trace the torchvision model, convert it to Relay and save the result).
#
# model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
# model = TraceWrapper(model_func(pretrained=True))
# model.eval()
#
# inp = torch.from_numpy(img)
#
# with torch.no_grad():
#     out = model(inp)
#     script_module = do_trace(model, inp)
#
# shape_list = [(input_name, input_shape)]
# mod, params = relay.frontend.from_pytorch(script_module, shape_list)
#
# with open("maskrcnn_mod.json", "w") as fo:
#     fo.write(tvm.ir.save_json(mod))
# with open("maskrcnn.params", "wb") as fo:
#     fo.write(relay.save_param_dict(params))
# print(mod["main"])

# Load the serialized Relay module and its parameters from disk.
with open("maskrcnn_mod.json", "r") as fi:
    mod_json = fi.read()
    mod = tvm.ir.load_json(mod_json)
with open("maskrcnn.params", "rb") as fi:
    param_bytes = fi.read()
    params = relay.load_param_dict(param_bytes)

# Replace every batched-NMS subgraph matched by NMSRewrite.
mod["main"] = rewrite(NMSRewrite(), mod["main"])
# print(mod["main"])

target = "llvm"

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

######################################################################
# Inference with Relay VM
# -----------------------
ctx = tvm.cpu()
vm = VirtualMachine(vm_exec, ctx)
vm.set_input("main", **{input_name: img})
tvm_res = vm.run()

######################################################################
# Get boxes with score larger than 0.9
# ------------------------------------
score_threshold = 0.9
boxes = tvm_res[0].asnumpy().tolist()
all_scores = tvm_res[1].asnumpy().tolist()

# Count the leading run of scores above the threshold; scanning stops at the
# first score that does not exceed it.
num_valid = 0
while num_valid < len(all_scores) and all_scores[num_valid] > score_threshold:
    num_valid += 1
valid_boxes = [boxes[k] for k in range(num_valid)]

print("Get {} valid boxes".format(len(valid_boxes)))

nms-perf-issue/test_street_small.jpg

116 KB
Loading

0 commit comments

Comments
 (0)