66import  torch 
77from  torch .hub  import  load_state_dict_from_url 
88from  torch .utils .data  import  DataLoader 
9- import  torchvision .transforms  as  transforms    
9+ import  torchvision .transforms  as  transforms 
1010
1111import  numpy  as  np 
1212from  PIL  import  Image 
1313from  tqdm  import  tqdm 
1414import  supervision  as  sv 
1515import  os 
1616import  wget 
17+ import  cv2 
18+   
class ResizeIfSmaller:
    """Transform that upscales an image whose height or width is below ``min_size``.

    The aspect ratio is preserved: both sides are scaled by the same factor,
    chosen so that the shorter side reaches ``min_size``. Images that already
    have both sides at least ``min_size`` are returned without resizing.

    Args:
        min_size (int):
            Minimum side length, in pixels, of the returned image.
        interpolation:
            Resampling filter passed to ``PIL.Image.Image.resize``.
            Defaults to ``Image.BILINEAR``.
    """

    def __init__(self, min_size, interpolation=Image.BILINEAR):
        self.min_size = min_size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return ``img`` as a PIL image, upscaled if either side is too small.

        Args:
            img (PIL.Image.Image or numpy.ndarray):
                Input image. A numpy array is converted to a PIL image first.

        Returns:
            PIL.Image.Image: The (possibly resized) image.

        Raises:
            TypeError: If ``img`` is neither a numpy array nor a PIL image.
        """
        if isinstance(img, np.ndarray):
            # NOTE(review): this conversion assumes the array is in BGR channel
            # order (OpenCV convention). An array that is already RGB would
            # have its channels swapped here -- confirm against the callers.
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let invalid input through to `img.size`.
        if not isinstance(img, Image.Image):
            raise TypeError("Image should be a PIL Image")

        width, height = img.size
        if height < self.min_size or width < self.min_size:
            ratio = max(self.min_size / height, self.min_size / width)
            # Clamp before truncating: floating-point rounding can make
            # `side * ratio` land just below `min_size` (for example
            # 49 * (512 / 49) == 511.99999999999994), and a bare `int()`
            # would then yield a side one pixel short of the minimum.
            new_height = int(max(height * ratio, self.min_size))
            new_width = int(max(width * ratio, self.min_size))
            img = img.resize((new_width, new_height), self.interpolation)
        return img
1735
1836class  HerdNet (BaseDetector ):
1937    """ 
@@ -60,6 +78,7 @@ def __init__(self, weights=None, device="cpu", version='general' ,url="https://z
6078
6179        if  not  transform :
6280            self .transforms  =  transforms .Compose ([
81+                 ResizeIfSmaller (512 ),
6382                transforms .ToTensor (),
6483                transforms .Normalize (mean = self .img_mean , std = self .img_std )  
6584                ]) 
@@ -90,7 +109,6 @@ def _load_model(self, weights=None, device="cpu", url=None):
90109            else :
91110                weights  =  os .path .join (torch .hub .get_dir (), "checkpoints" , filename )
92111            checkpoint  =  torch .load (weights , map_location = torch .device (device ))
93-             #checkpoint = load_state_dict_from_url(url, map_location=torch.device(self.device)) # NOTE: This function is not used in the current implementation 
94112        else :
95113            raise  Exception ("Need weights for inference." )
96114
@@ -112,13 +130,15 @@ def _load_model(self, weights=None, device="cpu", url=None):
112130
113131        print (f"Model loaded from { weights }  )
114132
115-     def  results_generation (self , preds , img_id , id_strip = None ):
133+     def  results_generation (self , preds , img = None ,  img_id = None , id_strip = None ):
116134        """ 
117135        Generate results for detection based on model predictions. 
118136         
119137        Args: 
120138            preds (numpy.ndarray):  
121139                Model predictions. 
140+             img (numpy.ndarray, optional): 
141+                 Image for inference. Defaults to None. 
122142            img_id (str):  
123143                Image identifier. 
124144            id_strip (str, optional):  
@@ -127,7 +147,13 @@ def results_generation(self, preds, img_id, id_strip=None):
127147        Returns: 
128148            dict: Dictionary containing image ID, detections, and labels. 
129149        """ 
130-         results  =  {"img_id" : str (img_id ).strip (id_strip ) if  id_strip  else  str (img_id )}
150+         assert  img  is  not None  or  img_id  is  not None , "Either img or img_id should be provided." 
151+         if  img_id  is  not None :
152+             img_id  =  str (img_id ).strip (id_strip ) if  id_strip  else  str (img_id )
153+             results  =  {"img_id" : img_id }
154+         elif  img  is  not None :
155+             results  =  {"img" : img }
156+ 
131157        results ["detections" ] =  sv .Detections (
132158            xyxy = preds [:, :4 ],
133159            confidence = preds [:, 4 ],
@@ -157,7 +183,7 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
157183
158184        Returns: 
159185            dict: Detection results for the image. 
160-         """    
186+         """ 
161187        if  isinstance (img , str ):  
162188            img_path  =  img_path  or  img   
163189            img  =  np .array (Image .open (img_path ).convert ("RGB" ))  
@@ -168,8 +194,11 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
168194        heatmap , clsmap  =  preds [:,:1 ,:,:], preds [:,1 :,:,:]  
169195        counts , locs , labels , scores , dscores  =  self .lmds ((heatmap , clsmap ))
170196        preds_array  =  self .process_lmds_results (counts , locs , labels , scores , dscores , det_conf_thres , clf_conf_thres )
171-         return  self .results_generation (preds_array , img_path , id_strip = id_strip )  
172- 
197+         if  img_path :
198+             results_dict  =  self .results_generation (preds_array , img_id = img_path , id_strip = id_strip )
199+         else :
200+             results_dict  =  self .results_generation (preds_array , img = img )
201+         return  results_dict 
173202
174203    def  batch_image_detection (self , data_path , det_conf_thres = 0.2 , clf_conf_thres = 0.2 , batch_size = 1 , id_strip = None ):
175204        """ 
@@ -207,7 +236,7 @@ def batch_image_detection(self, data_path, det_conf_thres=0.2, clf_conf_thres=0.
207236                heatmap , clsmap  =  predictions [:,:1 ,:,:], predictions [:,1 :,:,:]
208237                counts , locs , labels , scores , dscores  =  self .lmds ((heatmap , clsmap ))
209238                preds_array  =  self .process_lmds_results (counts , locs , labels , scores , dscores , det_conf_thres , clf_conf_thres ) 
210-                 results_dict  =  self .results_generation (preds_array , paths [0 ], id_strip = id_strip )
239+                 results_dict  =  self .results_generation (preds_array , img_id = paths [0 ], id_strip = id_strip )
211240                pbar .update (1 )
212241                sizes  =  sizes .numpy ()
213242                normalized_coords  =  [[x1  /  sizes [0 ][0 ], y1  /  sizes [0 ][1 ], x2  /  sizes [0 ][0 ], y2  /  sizes [0 ][1 ]] for  x1 , y1 , x2 , y2  in  preds_array [:, :4 ]] # TODO: Check if this is correct due to xy swapping  
0 commit comments