
Commit dfaf20d

Issues #553 and #554 solved: images can now be NumPy arrays, and images with a dimension smaller than 512 are scaled up to HerdNet's default input size.
1 parent 6d9bc9f commit dfaf20d

File tree

2 files changed: +57 additions, -16 deletions

PytorchWildlife/models/detection/herdnet/herdnet.py
PytorchWildlife/utils/post_process.py
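In practice, the change lets callers hand HerdNet a raw NumPy frame instead of a file path. A minimal sketch, assuming the usual PytorchWildlife import layout and that the default weights download succeeds (the image file name is hypothetical):

    import cv2
    from PytorchWildlife.models import detection as pw_detection

    model = pw_detection.HerdNet(device="cpu")     # default 'general' weights
    frame = cv2.imread("herd.jpg")                 # BGR numpy array, any size
    results = model.single_image_detection(frame)  # no file path needed
    # results carries the frame under "img" plus "detections" and "labels"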

PytorchWildlife/models/detection/herdnet/herdnet.py

Lines changed: 37 additions & 8 deletions
@@ -6,14 +6,32 @@
 import torch
 from torch.hub import load_state_dict_from_url
 from torch.utils.data import DataLoader
-import torchvision.transforms as transforms
+import torchvision.transforms as transforms
 
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
 import supervision as sv
 import os
 import wget
+import cv2
+
+class ResizeIfSmaller:
+    def __init__(self, min_size, interpolation=Image.BILINEAR):
+        self.min_size = min_size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        if isinstance(img, np.ndarray):
+            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+        assert isinstance(img, Image.Image), "Image should be a PIL Image"
+        width, height = img.size
+        if height < self.min_size or width < self.min_size:
+            ratio = max(self.min_size / height, self.min_size / width)
+            new_height = int(height * ratio)
+            new_width = int(width * ratio)
+            img = img.resize((new_width, new_height), self.interpolation)
+        return img
 
 class HerdNet(BaseDetector):
     """
@@ -60,6 +78,7 @@ def __init__(self, weights=None, device="cpu", version='general' ,url="https://z
 
         if not transform:
             self.transforms = transforms.Compose([
+                ResizeIfSmaller(512),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=self.img_mean, std=self.img_std)
                 ])
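With ResizeIfSmaller(512) prepended, the default pipeline now guarantees a minimum spatial size before normalization. A sketch; the mean/std values below are placeholders, not the model's actual img_mean/img_std:

    import numpy as np
    import torchvision.transforms as transforms

    pipeline = transforms.Compose([
        ResizeIfSmaller(512),                                  # upscale only if needed
        transforms.ToTensor(),                                 # PIL -> float CHW tensor
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),   # placeholder stats
    ])
    tensor = pipeline(np.zeros((300, 300, 3), dtype=np.uint8))
    assert tensor.shape[1:] == (512, 512)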
@@ -90,7 +109,6 @@ def _load_model(self, weights=None, device="cpu", url=None):
             else:
                 weights = os.path.join(torch.hub.get_dir(), "checkpoints", filename)
             checkpoint = torch.load(weights, map_location=torch.device(device))
-            #checkpoint = load_state_dict_from_url(url, map_location=torch.device(self.device)) # NOTE: This function is not used in the current implementation
         else:
             raise Exception("Need weights for inference.")
 
@@ -112,13 +130,15 @@ def _load_model(self, weights=None, device="cpu", url=None):
 
         print(f"Model loaded from {weights}")
 
-    def results_generation(self, preds, img_id, id_strip=None):
+    def results_generation(self, preds, img=None, img_id=None, id_strip=None):
         """
         Generate results for detection based on model predictions.
 
         Args:
             preds (numpy.ndarray):
                 Model predictions.
+            img (numpy.ndarray, optional):
+                Image for inference. Defaults to None.
             img_id (str):
                 Image identifier.
             id_strip (str, optional):
@@ -127,7 +147,13 @@ def results_generation(self, preds, img_id, id_strip=None):
         Returns:
             dict: Dictionary containing image ID, detections, and labels.
         """
-        results = {"img_id": str(img_id).strip(id_strip) if id_strip else str(img_id)}
+        assert img is not None or img_id is not None, "Either img or img_id should be provided."
+        if img_id is not None:
+            img_id = str(img_id).strip(id_strip) if id_strip else str(img_id)
+            results = {"img_id": img_id}
+        elif img is not None:
+            results = {"img": img}
+
         results["detections"] = sv.Detections(
             xyxy=preds[:, :4],
             confidence=preds[:, 4],
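The reworked results_generation keys its output by whichever identity it receives. A sketch, assuming model is a loaded HerdNet and frame is a NumPy image; the six-column prediction layout (x1, y1, x2, y2, confidence, class id) is inferred from the slicing above:

    import numpy as np

    preds_array = np.array([[10.0, 20.0, 110.0, 220.0, 0.9, 0.0]])  # hypothetical
    by_path = model.results_generation(preds_array, img_id="IMG_0001.JPG")
    by_array = model.results_generation(preds_array, img=frame)
    assert "img_id" in by_path and "img" in by_array  # never both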
@@ -157,7 +183,7 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
 
         Returns:
             dict: Detection results for the image.
-        """
+        """
         if isinstance(img, str):
             img_path = img_path or img
             img = np.array(Image.open(img_path).convert("RGB"))
@@ -168,8 +194,11 @@
         heatmap, clsmap = preds[:,:1,:,:], preds[:,1:,:,:]
         counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap))
         preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres)
-        return self.results_generation(preds_array, img_path, id_strip=id_strip)
-
+        if img_path:
+            results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip)
+        else:
+            results_dict = self.results_generation(preds_array, img=img)
+        return results_dict
 
     def batch_image_detection(self, data_path, det_conf_thres=0.2, clf_conf_thres=0.2, batch_size=1, id_strip=None):
         """
@@ -207,7 +236,7 @@ def batch_image_detection(self, data_path, det_conf_thres=0.2, clf_conf_thres=0.
             heatmap, clsmap = predictions[:,:1,:,:], predictions[:,1:,:,:]
             counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap))
             preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres)
-            results_dict = self.results_generation(preds_array, paths[0], id_strip=id_strip)
+            results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip)
             pbar.update(1)
             sizes = sizes.numpy()
             normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] # TODO: Check if this is correct due to xy swapping
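Taken together, the two entry points now differ only in how their results are keyed. A sketch with hypothetical paths (batch_image_detection is assumed to return one results dict per image):

    single = model.single_image_detection("images/IMG_0001.JPG")  # keyed by "img_id"
    batch = model.batch_image_detection("images/", batch_size=4)  # each dict keyed by "img_id"
    raw = model.single_image_detection(frame)                     # keyed by "img" instead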

PytorchWildlife/utils/post_process.py

Lines changed: 20 additions & 8 deletions
@@ -98,10 +98,17 @@ def save_detection_images_dots(results, output_dir, input_dir = None, overwrite=
 
     with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
         if isinstance(results, list):
-            for entry in results:
+            for i, entry in enumerate(results):
+                if "img_id" in entry:
+                    scene = np.array(Image.open(entry["img_id"]).convert("RGB"))
+                    image_name = os.path.basename(entry["img_id"])
+                else:
+                    scene = entry["img"]
+                    image_name = f"output_image_{i}.jpg" # default name if no image id is provided
+
                 annotated_img = lab_annotator.annotate(
                     scene=dot_annotator.annotate(
-                        scene=np.array(Image.open(entry["img_id"]).convert("RGB")),
+                        scene=scene,
                         detections=entry["detections"],
                     ),
                     detections=entry["detections"],
@@ -111,23 +118,28 @@ def save_detection_images_dots(results, output_dir, input_dir = None, overwrite=
                     relative_path = os.path.relpath(entry["img_id"], input_dir)
                     save_path = os.path.join(output_dir, relative_path)
                     os.makedirs(os.path.dirname(save_path), exist_ok=True)
-                    image_name = relative_path
-                else:
-                    image_name = os.path.basename(entry["img_id"])
+                    image_name = relative_path
                 sink.save_image(
                     image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name
                 )
         else:
+            if "img_id" in results:
+                scene = np.array(Image.open(results["img_id"]).convert("RGB"))
+                image_name = os.path.basename(results["img_id"])
+            else:
+                scene = results["img"]
+                image_name = "output_image.jpg" # default name if no image id is provided
+
             annotated_img = lab_annotator.annotate(
                 scene=dot_annotator.annotate(
-                    scene=np.array(Image.open(results["img_id"]).convert("RGB")),
+                    scene=scene,
                     detections=results["detections"],
                 ),
                 detections=results["detections"],
                 labels=results["labels"],
-                )
+            )
             sink.save_image(
-                image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=os.path.basename(results["img_id"])
+                image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name
             )
 
 
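With the fallback names in place, results that carry only a raw array can still be written out. A sketch, assuming save_detection_images_dots is re-exported from PytorchWildlife.utils and model/frame are set up as above:

    from PytorchWildlife.utils import save_detection_images_dots

    results = model.single_image_detection(frame)   # keyed by "img", no path known
    save_detection_images_dots(results, "output/")  # saved as output/output_image.jpg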
