6
6
import torch
7
7
from torch .hub import load_state_dict_from_url
8
8
from torch .utils .data import DataLoader
9
- import torchvision .transforms as transforms
9
+ import torchvision .transforms as transforms
10
10
11
11
import numpy as np
12
12
from PIL import Image
13
13
from tqdm import tqdm
14
14
import supervision as sv
15
15
import os
16
16
import wget
17
+ import cv2
18
+
19
+ class ResizeIfSmaller :
20
+ def __init__ (self , min_size , interpolation = Image .BILINEAR ):
21
+ self .min_size = min_size
22
+ self .interpolation = interpolation
23
+
24
+ def __call__ (self , img ):
25
+ if isinstance (img , np .ndarray ):
26
+ img = Image .fromarray (cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB ))
27
+ assert isinstance (img , Image .Image ), "Image should be a PIL Image"
28
+ width , height = img .size
29
+ if height < self .min_size or width < self .min_size :
30
+ ratio = max (self .min_size / height , self .min_size / width )
31
+ new_height = int (height * ratio )
32
+ new_width = int (width * ratio )
33
+ img = img .resize ((new_width , new_height ), self .interpolation )
34
+ return img
17
35
18
36
class HerdNet (BaseDetector ):
19
37
"""
@@ -60,6 +78,7 @@ def __init__(self, weights=None, device="cpu", version='general' ,url="https://z
60
78
61
79
if not transform :
62
80
self .transforms = transforms .Compose ([
81
+ ResizeIfSmaller (512 ),
63
82
transforms .ToTensor (),
64
83
transforms .Normalize (mean = self .img_mean , std = self .img_std )
65
84
])
@@ -90,7 +109,6 @@ def _load_model(self, weights=None, device="cpu", url=None):
90
109
else :
91
110
weights = os .path .join (torch .hub .get_dir (), "checkpoints" , filename )
92
111
checkpoint = torch .load (weights , map_location = torch .device (device ))
93
- #checkpoint = load_state_dict_from_url(url, map_location=torch.device(self.device)) # NOTE: This function is not used in the current implementation
94
112
else :
95
113
raise Exception ("Need weights for inference." )
96
114
@@ -112,13 +130,15 @@ def _load_model(self, weights=None, device="cpu", url=None):
112
130
113
131
print (f"Model loaded from { weights } " )
114
132
115
- def results_generation (self , preds , img_id , id_strip = None ):
133
+ def results_generation (self , preds , img = None , img_id = None , id_strip = None ):
116
134
"""
117
135
Generate results for detection based on model predictions.
118
136
119
137
Args:
120
138
preds (numpy.ndarray):
121
139
Model predictions.
140
+ img (numpy.ndarray, optional):
141
+ Image for inference. Defaults to None.
122
142
img_id (str):
123
143
Image identifier.
124
144
id_strip (str, optional):
@@ -127,7 +147,13 @@ def results_generation(self, preds, img_id, id_strip=None):
127
147
Returns:
128
148
dict: Dictionary containing image ID, detections, and labels.
129
149
"""
130
- results = {"img_id" : str (img_id ).strip (id_strip ) if id_strip else str (img_id )}
150
+ assert img is not None or img_id is not None , "Either img or img_id should be provided."
151
+ if img_id is not None :
152
+ img_id = str (img_id ).strip (id_strip ) if id_strip else str (img_id )
153
+ results = {"img_id" : img_id }
154
+ elif img is not None :
155
+ results = {"img" : img }
156
+
131
157
results ["detections" ] = sv .Detections (
132
158
xyxy = preds [:, :4 ],
133
159
confidence = preds [:, 4 ],
@@ -157,7 +183,7 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
157
183
158
184
Returns:
159
185
dict: Detection results for the image.
160
- """
186
+ """
161
187
if isinstance (img , str ):
162
188
img_path = img_path or img
163
189
img = np .array (Image .open (img_path ).convert ("RGB" ))
@@ -168,8 +194,11 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
168
194
heatmap , clsmap = preds [:,:1 ,:,:], preds [:,1 :,:,:]
169
195
counts , locs , labels , scores , dscores = self .lmds ((heatmap , clsmap ))
170
196
preds_array = self .process_lmds_results (counts , locs , labels , scores , dscores , det_conf_thres , clf_conf_thres )
171
- return self .results_generation (preds_array , img_path , id_strip = id_strip )
172
-
197
+ if img_path :
198
+ results_dict = self .results_generation (preds_array , img_id = img_path , id_strip = id_strip )
199
+ else :
200
+ results_dict = self .results_generation (preds_array , img = img )
201
+ return results_dict
173
202
174
203
def batch_image_detection (self , data_path , det_conf_thres = 0.2 , clf_conf_thres = 0.2 , batch_size = 1 , id_strip = None ):
175
204
"""
@@ -207,7 +236,7 @@ def batch_image_detection(self, data_path, det_conf_thres=0.2, clf_conf_thres=0.
207
236
heatmap , clsmap = predictions [:,:1 ,:,:], predictions [:,1 :,:,:]
208
237
counts , locs , labels , scores , dscores = self .lmds ((heatmap , clsmap ))
209
238
preds_array = self .process_lmds_results (counts , locs , labels , scores , dscores , det_conf_thres , clf_conf_thres )
210
- results_dict = self .results_generation (preds_array , paths [0 ], id_strip = id_strip )
239
+ results_dict = self .results_generation (preds_array , img_id = paths [0 ], id_strip = id_strip )
211
240
pbar .update (1 )
212
241
sizes = sizes .numpy ()
213
242
normalized_coords = [[x1 / sizes [0 ][0 ], y1 / sizes [0 ][1 ], x2 / sizes [0 ][0 ], y2 / sizes [0 ][1 ]] for x1 , y1 , x2 , y2 in preds_array [:, :4 ]] # TODO: Check if this is correct due to xy swapping
0 commit comments