Skip to content

Commit

Permalink
Issues #553 and #554 solved. Now images can be numpy arrays and images…
Browse files Browse the repository at this point in the history
… with dimension less than 512 are scaled to use default herdnet values
  • Loading branch information
idchacon28 committed Jan 25, 2025
1 parent 6d9bc9f commit dfaf20d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
45 changes: 37 additions & 8 deletions PytorchWildlife/models/detection/herdnet/herdnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,32 @@
import torch
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.transforms as transforms

import numpy as np
from PIL import Image
from tqdm import tqdm
import supervision as sv
import os
import wget
import cv2

class ResizeIfSmaller:
    """Transform that upscales an image when either side is below ``min_size``.

    The aspect ratio is preserved: both sides are multiplied by the same
    factor, chosen so that the shorter side reaches ``min_size``. Images whose
    height and width are both already at least ``min_size`` are returned
    without resizing.

    Args:
        min_size (int):
            Minimum allowed length, in pixels, for both height and width.
        interpolation (int, optional):
            PIL resampling filter passed to ``Image.resize``.
            Defaults to ``Image.BILINEAR``.
    """

    def __init__(self, min_size, interpolation=Image.BILINEAR):
        self.min_size = min_size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return ``img`` as a PIL image, upscaled if it is too small.

        Args:
            img (PIL.Image.Image or numpy.ndarray):
                Input image. A numpy array is converted to a PIL image first.

        Returns:
            PIL.Image.Image: The (possibly resized) image.

        Raises:
            TypeError: If ``img`` is neither a numpy array nor a PIL image.
        """
        if isinstance(img, np.ndarray):
            # NOTE(review): this swaps the first and third channels, i.e. it
            # assumes the array is in OpenCV's BGR order. Arrays produced via
            # np.array(PIL_image.convert("RGB")) are already RGB -- confirm
            # what callers actually pass before relying on the colours.
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(img, Image.Image):
            raise TypeError("Image should be a PIL Image")

        width, height = img.size
        if height < self.min_size or width < self.min_size:
            ratio = max(self.min_size / height, self.min_size / width)
            # Round instead of truncating, and clamp to min_size: float
            # round-off can make `side * ratio` land just below the target
            # (e.g. 512 / 49 * 49 == 511.99999999999994), which int() would
            # turn into 511 and leave the image still smaller than min_size.
            new_height = max(self.min_size, round(height * ratio))
            new_width = max(self.min_size, round(width * ratio))
            img = img.resize((new_width, new_height), self.interpolation)
        return img

class HerdNet(BaseDetector):
"""
Expand Down Expand Up @@ -60,6 +78,7 @@ def __init__(self, weights=None, device="cpu", version='general' ,url="https://z

if not transform:
self.transforms = transforms.Compose([
ResizeIfSmaller(512),
transforms.ToTensor(),
transforms.Normalize(mean=self.img_mean, std=self.img_std)
])
Expand Down Expand Up @@ -90,7 +109,6 @@ def _load_model(self, weights=None, device="cpu", url=None):
else:
weights = os.path.join(torch.hub.get_dir(), "checkpoints", filename)
checkpoint = torch.load(weights, map_location=torch.device(device))
#checkpoint = load_state_dict_from_url(url, map_location=torch.device(self.device)) # NOTE: This function is not used in the current implementation
else:
raise Exception("Need weights for inference.")

Expand All @@ -112,13 +130,15 @@ def _load_model(self, weights=None, device="cpu", url=None):

print(f"Model loaded from {weights}")

def results_generation(self, preds, img_id, id_strip=None):
def results_generation(self, preds, img=None, img_id=None, id_strip=None):
"""
Generate results for detection based on model predictions.
Args:
preds (numpy.ndarray):
Model predictions.
img (numpy.ndarray, optional):
Image for inference. Defaults to None.
img_id (str):
Image identifier.
id_strip (str, optional):
Expand All @@ -127,7 +147,13 @@ def results_generation(self, preds, img_id, id_strip=None):
Returns:
dict: Dictionary containing image ID, detections, and labels.
"""
results = {"img_id": str(img_id).strip(id_strip) if id_strip else str(img_id)}
assert img is not None or img_id is not None, "Either img or img_id should be provided."
if img_id is not None:
img_id = str(img_id).strip(id_strip) if id_strip else str(img_id)
results = {"img_id": img_id}
elif img is not None:
results = {"img": img}

results["detections"] = sv.Detections(
xyxy=preds[:, :4],
confidence=preds[:, 4],
Expand Down Expand Up @@ -157,7 +183,7 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
Returns:
dict: Detection results for the image.
"""
"""
if isinstance(img, str):
img_path = img_path or img
img = np.array(Image.open(img_path).convert("RGB"))
Expand All @@ -168,8 +194,11 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_con
heatmap, clsmap = preds[:,:1,:,:], preds[:,1:,:,:]
counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap))
preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres)
return self.results_generation(preds_array, img_path, id_strip=id_strip)

if img_path:
results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip)
else:
results_dict = self.results_generation(preds_array, img=img)
return results_dict

def batch_image_detection(self, data_path, det_conf_thres=0.2, clf_conf_thres=0.2, batch_size=1, id_strip=None):
"""
Expand Down Expand Up @@ -207,7 +236,7 @@ def batch_image_detection(self, data_path, det_conf_thres=0.2, clf_conf_thres=0.
heatmap, clsmap = predictions[:,:1,:,:], predictions[:,1:,:,:]
counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap))
preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres)
results_dict = self.results_generation(preds_array, paths[0], id_strip=id_strip)
results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip)
pbar.update(1)
sizes = sizes.numpy()
normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] # TODO: Check if this is correct due to xy swapping
Expand Down
28 changes: 20 additions & 8 deletions PytorchWildlife/utils/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,17 @@ def save_detection_images_dots(results, output_dir, input_dir = None, overwrite=

with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
if isinstance(results, list):
for entry in results:
for i, entry in enumerate(results):
if "img_id" in entry:
scene = np.array(Image.open(entry["img_id"]).convert("RGB"))
image_name = os.path.basename(entry["img_id"])
else:
scene = entry["img"]
image_name = f"output_image_{i}.jpg" # default name if no image id is provided

annotated_img = lab_annotator.annotate(
scene=dot_annotator.annotate(
scene=np.array(Image.open(entry["img_id"]).convert("RGB")),
scene=scene,
detections=entry["detections"],
),
detections=entry["detections"],
Expand All @@ -111,23 +118,28 @@ def save_detection_images_dots(results, output_dir, input_dir = None, overwrite=
relative_path = os.path.relpath(entry["img_id"], input_dir)
save_path = os.path.join(output_dir, relative_path)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
image_name = relative_path
else:
image_name = os.path.basename(entry["img_id"])
image_name = relative_path
sink.save_image(
image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name
)
else:
if "img_id" in results:
scene = np.array(Image.open(results["img_id"]).convert("RGB"))
image_name = os.path.basename(results["img_id"])
else:
scene = results["img"]
image_name = "output_image.jpg" # default name if no image id is provided

annotated_img = lab_annotator.annotate(
scene=dot_annotator.annotate(
scene=np.array(Image.open(results["img_id"]).convert("RGB")),
scene=scene,
detections=results["detections"],
),
detections=results["detections"],
labels=results["labels"],
)
)
sink.save_image(
image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=os.path.basename(results["img_id"])
image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name
)


Expand Down

0 comments on commit dfaf20d

Please sign in to comment.