@@ -686,6 +686,10 @@ def post_process(self,
686
686
scale_factor ,
687
687
infer_shape = [640 , 640 ],
688
688
rescale = True ):
689
+
690
+ if getattr (self , "export_mode" , False ):
691
+ return self .export_post_process (head_outs , im_shape , scale_factor , infer_shape , rescale )
692
+
689
693
pred_scores , pred_dist , pred_mask_coeffs , mask_feat , anchor_points , stride_tensor = head_outs
690
694
691
695
pred_bboxes = batch_distance2bbox (anchor_points , pred_dist )
@@ -740,4 +744,55 @@ def post_process(self,
740
744
bbox_pred = paddle .zeros ([bbox_num , 6 ])
741
745
mask_pred = paddle .zeros ([bbox_num , int (ori_h ), int (ori_w )])
742
746
747
+ return bbox_pred , bbox_num , mask_pred , keep_idxs
748
+
749
+ def export_post_process (self ,
750
+ head_outs ,
751
+ im_shape ,
752
+ scale_factor ,
753
+ infer_shape = [640 , 640 ],
754
+ rescale = True ):
755
+ pred_scores , pred_dist , pred_mask_coeffs , mask_feat , anchor_points , stride_tensor = head_outs
756
+
757
+ pred_bboxes = batch_distance2bbox (anchor_points , pred_dist )
758
+ pred_bboxes *= stride_tensor
759
+
760
+ if self .exclude_post_process :
761
+ return paddle .concat ([
762
+ pred_bboxes ,
763
+ pred_scores .transpose ([0 , 2 , 1 ]),
764
+ pred_mask_coeffs .transpose ([0 , 2 , 1 ])
765
+ ],
766
+ axis = - 1 ), mask_feat , None
767
+ # [1, 8400, 4+80+32], [1, 32, 160, 160]
768
+
769
+ bbox_pred , bbox_num , keep_idxs = self .nms (pred_bboxes , pred_scores )
770
+
771
+ pred_mask_coeffs = pred_mask_coeffs .transpose ([0 , 2 , 1 ])
772
+ mask_coeffs = paddle .gather (
773
+ pred_mask_coeffs .reshape ([- 1 , self .num_masks ]), keep_idxs )
774
+
775
+ mask_logits = process_mask_upsample (mask_feat [0 ], mask_coeffs ,
776
+ bbox_pred [:, 2 :6 ], infer_shape )
777
+ if rescale :
778
+ ori_h , ori_w = im_shape [0 ] / scale_factor [0 ]
779
+ mask_logits = F .interpolate (
780
+ mask_logits .unsqueeze (0 ),
781
+ size = [
782
+ int (paddle .round (mask_logits .shape [- 2 ] /
783
+ scale_factor [0 ][0 ])),
784
+ int (paddle .round (mask_logits .shape [- 1 ] /
785
+ scale_factor [0 ][1 ]))
786
+ ],
787
+ mode = 'bilinear' ,
788
+ align_corners = False )
789
+ mask_logits = mask_logits [..., :int (ori_h ), :int (ori_w )]
790
+
791
+ masks = mask_logits .squeeze (0 )
792
+ mask_pred = masks > self .mask_thr_binary
793
+
794
+ # scale bbox to origin
795
+ scale_factor = scale_factor .flip (- 1 ).tile ([1 , 2 ])
796
+ bbox_pred [:, 2 :6 ] /= scale_factor
797
+
743
798
return bbox_pred , bbox_num , mask_pred , keep_idxs
0 commit comments