1
1
import tvm
2
2
from tvm import relay
3
3
from tvm .contrib .download import download
4
+ from tvm .runtime .vm import VirtualMachine
4
5
from tvm .relay .dataflow_pattern import *
5
6
6
7
import numpy as np
@@ -23,7 +24,12 @@ def do_trace(model, inp):
23
24
24
25
def dict_to_tuple (out_dict ):
25
26
if "masks" in out_dict .keys ():
26
- return out_dict ["boxes" ], out_dict ["scores" ], out_dict ["labels" ], out_dict ["masks" ]
27
+ return (
28
+ out_dict ["boxes" ],
29
+ out_dict ["scores" ],
30
+ out_dict ["labels" ],
31
+ out_dict ["masks" ],
32
+ )
27
33
return out_dict ["boxes" ], out_dict ["scores" ], out_dict ["labels" ]
28
34
29
35
@@ -40,7 +46,8 @@ def forward(self, inp):
40
46
def get_input ():
41
47
img_path = "test_street_small.jpg"
42
48
img_url = (
43
- "https://raw.githubusercontent.com/dmlc/web-data/" "master/gluoncv/detection/street_small.jpg"
49
+ "https://raw.githubusercontent.com/dmlc/web-data/"
50
+ "master/gluoncv/detection/street_small.jpg"
44
51
)
45
52
download (img_url , img_path )
46
53
@@ -52,12 +59,7 @@ def get_input():
52
59
return img
53
60
54
61
55
- def batched_nms_pattern ():
56
- # exprs I want to extract
57
- boxes = wildcard ()
58
- scores = wildcard ()
59
- idxs = wildcard ()
60
-
62
+ def batched_nms_pattern (boxes , scores , idxs , iou_threshold ):
61
63
one = is_constant ()
62
64
zero = is_constant ()
63
65
@@ -90,24 +92,117 @@ def batched_nms_pattern():
90
92
expand_dims = is_op ("expand_dims" )(concat )
91
93
92
94
# %1842 = vision.get_valid_counts(%1841, -1f, meta[relay.attrs.GetValidCountsAttrs][1]);
93
- return is_op ("vision.get_valid_counts" )(expand_dims , is_constant (), wildcard ())
95
+ get_valid_counts_out = is_op ("vision.get_valid_counts" )(expand_dims , is_constant ())
96
+ data = is_tuple_get_item (get_valid_counts_out , 1 )
97
+ valid_counts = is_tuple_get_item (get_valid_counts_out , 0 )
98
+ indices = is_tuple_get_item (get_valid_counts_out , 2 )
99
+ return is_op ("vision.non_max_suppression" )(
100
+ data , valid_counts , indices , is_constant (), iou_threshold
101
+ )
94
102
95
103
96
- model_func = torchvision .models .detection .maskrcnn_resnet50_fpn
97
- model = TraceWrapper (model_func (pretrained = True ))
98
- model .eval ()
104
+ def convert_batched_nms (boxes , scores , idxs , iou_thres ):
105
+ scores = relay .expand_dims (scores , axis = - 1 , num_newaxis = 1 )
106
+ idxs = relay .expand_dims (idxs , axis = - 1 , num_newaxis = 1 )
107
+ idxs = relay .cast (idxs , "float32" )
108
+ data = relay .concatenate ([idxs , scores , boxes ], - 1 )
109
+ data = relay .expand_dims (data , 0 , 1 )
110
+ ct , data , indices = relay .op .vision .get_valid_counts (
111
+ data , score_threshold = - 1.0 , id_index = 0 , score_index = 1
112
+ )
113
+ top_k = max_out_size = - 1
114
+ out = relay .op .vision .non_max_suppression (
115
+ data = data ,
116
+ valid_count = ct ,
117
+ indices = indices ,
118
+ max_output_size = max_out_size ,
119
+ iou_threshold = iou_thres ,
120
+ force_suppress = True ,
121
+ top_k = top_k ,
122
+ coord_start = 1 ,
123
+ score_index = 1 ,
124
+ id_index = 0 ,
125
+ return_indices = True ,
126
+ invalid_to_bottom = False ,
127
+ )
128
+ return out .tuple_value
99
129
100
- img = get_input ()
101
- inp = torch .from_numpy (img )
102
130
103
- with torch .no_grad ():
104
- out = model (inp )
105
- script_module = do_trace (model , inp )
131
+ class NMSRewrite (DFPatternCallback ):
132
+ def __init__ (self ):
133
+ super ().__init__ ()
134
+ # exprs I want to extract
135
+ self .boxes = wildcard ()
136
+ self .scores = wildcard ()
137
+ self .idxs = wildcard ()
138
+ self .iou_threshold = wildcard ()
139
+
140
+ self .pattern = batched_nms_pattern (
141
+ self .boxes , self .scores , self .idxs , self .iou_threshold
142
+ )
143
+
144
+ def callback (self , pre , post , node_map ):
145
+ print ("matched" )
146
+ boxes = node_map [self .boxes ][0 ]
147
+ scores = node_map [self .scores ][0 ]
148
+ idxs = node_map [self .idxs ][0 ]
149
+ iou_thres = node_map [self .iou_threshold ][0 ]
150
+ return convert_batched_nms (boxes , scores , idxs , iou_thres )
151
+
106
152
107
153
input_name = "input0"
108
- shape_list = [(input_name , input_shape )]
109
- mod , params = relay .frontend .from_pytorch (script_module , shape_list )
154
+ img = get_input ()
155
+
156
+ # model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
157
+ # model = TraceWrapper(model_func(pretrained=True))
158
+ # model.eval()
159
+
160
+ # inp = torch.from_numpy(img)
161
+
162
+ # with torch.no_grad():
163
+ # out = model(inp)
164
+ # script_module = do_trace(model, inp)
165
+
166
+
167
+ # shape_list = [(input_name, input_shape)]
168
+ # mod, params = relay.frontend.from_pytorch(script_module, shape_list)
169
+
170
+ # with open("maskrcnn_mod.json", "w") as fo:
171
+ # fo.write(tvm.ir.save_json(mod))
172
+ # with open("maskrcnn.params", "wb") as fo:
173
+ # fo.write(relay.save_param_dict(params))
174
+ # print(mod["main"])
175
+ with open ("maskrcnn_mod.json" , "r" ) as fi :
176
+ mod = tvm .ir .load_json (fi .read ())
177
+ with open ("maskrcnn.params" , "rb" ) as fi :
178
+ params = relay .load_param_dict (fi .read ())
179
+
180
+ mod ["main" ] = rewrite (NMSRewrite (), mod ["main" ])
110
181
# print(mod["main"])
111
182
112
- pat = batched_nms_pattern ()
113
- print (pat .match (mod ["main" ].body ))
183
+ target = "llvm"
184
+
185
+ with tvm .transform .PassContext (opt_level = 3 , disabled_pass = ["FoldScaleAxis" ]):
186
+ vm_exec = relay .vm .compile (mod , target = target , params = params )
187
+
188
+ ######################################################################
189
+ # Inference with Relay VM
190
+ # -----------------------
191
+ ctx = tvm .cpu ()
192
+ vm = VirtualMachine (vm_exec , ctx )
193
+ vm .set_input ("main" , ** {input_name : img })
194
+ tvm_res = vm .run ()
195
+
196
+ ######################################################################
197
+ # Get boxes with score larger than 0.9
198
+ # ------------------------------------
199
+ score_threshold = 0.9
200
+ boxes = tvm_res [0 ].asnumpy ().tolist ()
201
+ valid_boxes = []
202
+ for i , score in enumerate (tvm_res [1 ].asnumpy ().tolist ()):
203
+ if score > score_threshold :
204
+ valid_boxes .append (boxes [i ])
205
+ else :
206
+ break
207
+
208
+ print ("Get {} valid boxes" .format (len (valid_boxes )))
0 commit comments