diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a4cdadd..3133ee9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,4 +31,5 @@ jobs: python -m unittest tests/test_detector2cvat.py python -m unittest tests/test_miniscene2behavior.py python -m unittest tests/test_player.py - python -m unittest tests/test_tracks_extractor.py \ No newline at end of file + python -m unittest tests/test_tracks_extractor.py + python -m unittest tests/utils/test_yolo.py \ No newline at end of file diff --git a/README.md b/README.md index 6228bec..b9279f6 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ You may use [YOLO](https://docs.ultralytics.com/) to automatically perform detec Detect objects with Ultralytics YOLO detections, apply SORT tracking and convert tracks to CVAT format. ``` -detector2cvat --video path_to_videos --save path_to_save [--imshow] +detector2cvat --video path_to_videos --save path_to_save [--target_labels path_to_labels] [--label_map path_to_map] [--imshow] ``` diff --git a/ethogram/classes.json b/data/classes.json similarity index 100% rename from ethogram/classes.json rename to data/classes.json diff --git a/ethogram/label2index.json b/data/label2index.json similarity index 100% rename from ethogram/label2index.json rename to data/label2index.json diff --git a/ethogram/old2new.json b/data/old2new.json similarity index 100% rename from ethogram/old2new.json rename to data/old2new.json diff --git a/data/yolo_equiv.json b/data/yolo_equiv.json new file mode 100644 index 0000000..00501c6 --- /dev/null +++ b/data/yolo_equiv.json @@ -0,0 +1 @@ +{"horse": "zebra"} \ No newline at end of file diff --git a/data/yolo_labels.json b/data/yolo_labels.json new file mode 100644 index 0000000..8327914 --- /dev/null +++ b/data/yolo_labels.json @@ -0,0 +1 @@ +["zebra", "horse", "giraffe"] \ No newline at end of file diff --git a/src/kabr_tools/detector2cvat.py b/src/kabr_tools/detector2cvat.py index 206da44..f675278 100644 --- a/src/kabr_tools/detector2cvat.py +++ b/src/kabr_tools/detector2cvat.py @@ -1,4 +1,5 @@ import os +import json import argparse import cv2 from tqdm import tqdm @@ -8,13 +9,16 @@ from kabr_tools.utils.draw import Draw -def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None: +def detector2cvat(path_to_videos: str, path_to_save: str, + target_labels: list, label_map: dict, show: bool) -> None: """ Detect objects with Ultralytics YOLO detections, apply SORT tracking and convert tracks to CVAT format. Parameters: path_to_videos - str. Path to the folder containing videos. path_to_save - str. Path to the folder to save output xml & mp4 files. + target_labels - list. List of target labels to detect. + label_map - dict. Dictionary to rename labels. show - bool. Flag to display detector's visualization. """ videos = [] @@ -29,7 +33,7 @@ def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None: videos.append(f"{root}/{file}") - yolo = YOLOv8(weights="yolov8x.pt", imgsz=3840, conf=0.5) + yolo = YOLOv8(weights="yolov8x.pt", imgsz=3840, conf=0.5, target_labels=target_labels, label_map=label_map) for i, video in enumerate(videos): try: @@ -120,6 +124,16 @@ def parse_args() -> argparse.Namespace: help="path to save output xml & mp4 files", required=True ) + local_parser.add_argument( + "--target_labels", + type=str, + help="path to target labels json" + ) + local_parser.add_argument( + "--label_map", + type=str, + help="path to label map json" + ) local_parser.add_argument( "--imshow", action="store_true", @@ -127,10 +141,17 @@ def parse_args() -> argparse.Namespace: ) return local_parser.parse_args() +def load_json(file: str) -> dict: + if file: + with open(file, mode="r", encoding="utf-8") as file: + return json.load(file) + return None def main() -> None: args = parse_args() - detector2cvat(args.video, args.save, args.imshow) + target_labels = load_json(args.target_labels) + label_map = load_json(args.label_map) + detector2cvat(args.video, args.save, target_labels, label_map, args.imshow) if __name__ == "__main__": diff --git a/src/kabr_tools/utils/yolo.py b/src/kabr_tools/utils/yolo.py index 7383318..f5f0cff 100644 --- a/src/kabr_tools/utils/yolo.py +++ b/src/kabr_tools/utils/yolo.py @@ -3,11 +3,23 @@ class YOLOv8: - def __init__(self, weights="yolov8x.pt", imgsz=640, conf=0.5): + def __init__(self, weights="yolov8x.pt", + imgsz=640, conf=0.5, + target_labels=None, label_map=None): self.conf = conf self.imgsz = imgsz self.model = YOLO(weights) - self.names = self.model.names + self.names: dict = self.model.names + + if target_labels: + self.target_labels = target_labels + else: + self.target_labels = ["zebra", "horse", "giraffe"] + + if label_map: + self.label_map = label_map + else: + self.label_map = {"horse" : "zebra"} def forward(self, image): width = image.shape[1] @@ -18,7 +30,7 @@ def forward(self, image): for box, label, confidence in zip(boxes.xyxyn.numpy(), boxes.cls.numpy(), boxes.conf.numpy()): if confidence > self.conf: - if self.names[label] in ["zebra", "horse", "giraffe"]: + if self.names[label] in self.target_labels: box[0] = int(box[0] * width) box[1] = int(box[1] * height) box[2] = int(box[2] * width) @@ -26,10 +38,10 @@ def forward(self, image): box = box.astype(np.int32) confidence = float(f"{confidence:.2f}") - if self.names[label] == "horse": - label = "Zebra" - else: - label = self.names[label].capitalize() + label = self.names[label] + if label in self.label_map: + label = self.label_map[label] + label = label.capitalize() filtered.append(([box[0], box[1], box[2], box[3]], confidence, label)) diff --git a/tests/utils.py b/tests/helpers.py similarity index 100% rename from tests/utils.py rename to tests/helpers.py diff --git a/tests/test_cvat2slowfast.py b/tests/test_cvat2slowfast.py index 674837c..68cbe0e 100644 --- a/tests/test_cvat2slowfast.py +++ b/tests/test_cvat2slowfast.py @@ -2,7 +2,7 @@ import sys import os from kabr_tools import cvat2slowfast -from tests.utils import ( +from tests.helpers import ( get_behavior, del_dir, del_file @@ -35,8 +35,8 @@ def setUp(self): self.tool = "cvat2slowfast.py" self.miniscene = TestCvat2Slowfast.dir self.dataset = "tests/slowfast" - self.classes = "ethogram/classes.json" - self.old2new = "ethogram/old2new.json" + self.classes = "data/classes.json" + self.old2new = "data/old2new.json" def tearDown(self): # delete outputs diff --git a/tests/test_cvat2ultralytics.py b/tests/test_cvat2ultralytics.py index 881245a..e32669d 100644 --- a/tests/test_cvat2ultralytics.py +++ b/tests/test_cvat2ultralytics.py @@ -2,7 +2,7 @@ import sys import os from kabr_tools import cvat2ultralytics -from tests.utils import ( +from tests.helpers import ( del_dir, del_file, get_detection @@ -33,7 +33,7 @@ def setUp(self): self.annotation = TestCvat2Ultralytics.dir self.dataset = "tests/ultralytics" self.skip = "5" - self.label2index = "ethogram/label2index.json" + self.label2index = "data/label2index.json" def tearDown(self): # delete outputs diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py index 675e7bc..3637b94 100644 --- a/tests/test_detector2cvat.py +++ b/tests/test_detector2cvat.py @@ -2,7 +2,7 @@ import sys import os from kabr_tools import detector2cvat -from tests.utils import ( +from tests.helpers import ( del_dir, del_file, get_detection @@ -33,6 +33,8 @@ def setUp(self): self.tool = "detector2cvat.py" self.video = TestDetector2Cvat.dir self.save = "tests/detector2cvat" + self.target_labels = "data/yolo_labels.json" + self.label_map = "data/yolo_equiv.json" def tearDown(self): # delete outputs @@ -55,6 +57,10 @@ def test_parse_arg_min(self): # check parsed argument values self.assertEqual(args.video, self.video) self.assertEqual(args.save, self.save) + + # check default argument values + self.assertEqual(args.target_labels, None) + self.assertEqual(args.label_map, None) self.assertEqual(args.imshow, False) def test_parse_arg_full(self): @@ -62,10 +68,14 @@ def test_parse_arg_full(self): sys.argv = [self.tool, "--video", self.video, "--save", self.save, + "--target_labels", self.target_labels, + "--label_map", self.label_map, "--imshow"] args = detector2cvat.parse_args() # check parsed argument values self.assertEqual(args.video, self.video) self.assertEqual(args.save, self.save) + self.assertEqual(args.target_labels, self.target_labels) + self.assertEqual(args.label_map, self.label_map) self.assertEqual(args.imshow, True) diff --git a/tests/test_miniscene2behavior.py b/tests/test_miniscene2behavior.py index 7875e2d..95f6d1d 100644 --- a/tests/test_miniscene2behavior.py +++ b/tests/test_miniscene2behavior.py @@ -12,7 +12,7 @@ tracks_extractor ) from kabr_tools.miniscene2behavior import annotate_miniscene -from tests.utils import ( +from tests.helpers import ( del_file, del_dir, get_detection diff --git a/tests/test_player.py b/tests/test_player.py index 0a2d3fb..f15eecb 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -3,7 +3,7 @@ import os from unittest.mock import patch from kabr_tools import player -from tests.utils import ( +from tests.helpers import ( del_file, del_dir, get_behavior diff --git a/tests/test_tracks_extractor.py b/tests/test_tracks_extractor.py index 8f61a1e..fde26e0 100644 --- a/tests/test_tracks_extractor.py +++ b/tests/test_tracks_extractor.py @@ -2,7 +2,7 @@ import sys from unittest.mock import patch from kabr_tools import tracks_extractor -from tests.utils import ( +from tests.helpers import ( get_detection, del_dir, del_file diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_yolo.py b/tests/utils/test_yolo.py new file mode 100644 index 0000000..8ed9e53 --- /dev/null +++ b/tests/utils/test_yolo.py @@ -0,0 +1,139 @@ +import unittest +from unittest.mock import MagicMock, patch +from collections import OrderedDict +import numpy as np +import torch +from ultralytics import YOLO +from kabr_tools.utils.yolo import YOLOv8 + +# from yolov8x.pt +LABELS = {"zebra": 22, "horse": 17, "giraffe": 23, "bear": 21} + +def rescale(box, width, height): + return [box[0] * width, box[1] * height, box[2] * width, box[3] * height] + + +class MockBox: + def __init__(self, box=[[0, 0, 0, 0]], cls=["zebra"], conf=[0.95]): + self.xyxyn = None + self.cls = None + self.conf = None + + def mock(self, boxes, classes, confs): + self.xyxyn = torch.Tensor(boxes) + self.cls = torch.Tensor([LABELS[cls] for cls in classes]) + self.conf = torch.Tensor(confs) + return self + + +class TestYolo(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.im = np.zeros((100, 101, 3), dtype=np.uint8) + cls.box = OrderedDict([("x1", 10), ("y1", 20), ("x2", 30), ("y2", 40)]) + cls.box_values = list(cls.box.values()) + + @patch("kabr_tools.utils.yolo.YOLO") + def test_forward(self, yolo_mock): + im = TestYolo.im + yolo_model = MagicMock() + yolo_model.predict.return_value.__getitem__ = lambda x, _: x + yolo_model.names = YOLO("yolov8x.pt").names + yolo_mock.return_value = yolo_model + + # horse -> zebra + points = [[0] * 4] * 3 + labels = ["zebra", "horse", "giraffe"] + expect_labels = ["Zebra", "Zebra", "Giraffe"] + probs = [0.7, 0.8, 0.9] + yolo_boxes = MockBox().mock(points, labels, probs) + yolo_model.predict.return_value.boxes.cpu.return_value = yolo_boxes + + yolo = YOLOv8() + preds = yolo.forward(im) + + self.assertEqual(len(preds), 3) + for i, pred in enumerate(preds): + self.assertEqual(preds[i][0], points[i]) + self.assertEqual(preds[i][1], probs[i]) + self.assertEqual(preds[i][2], expect_labels[i]) + + # bear -> filtered + points = [[0] * 4] * 3 + labels = ["bear", "horse", "giraffe"] + expect_labels = [None, "Zebra", "Giraffe"] + probs = [0.9, 0.8, 0.9] + yolo_boxes = MockBox().mock(points, labels, probs) + yolo_model.predict.return_value.boxes.cpu.return_value = yolo_boxes + + yolo = YOLOv8() + preds = yolo.forward(im) + + self.assertEqual(len(preds), 2) + index = 0 + for pred in preds: + while expect_labels[index] is None: + index += 1 + self.assertEqual(pred[0], rescale(points[index], im.shape[1], im.shape[0])) + self.assertEqual(pred[1], probs[index]) + self.assertEqual(pred[2], expect_labels[index]) + index += 1 + + # low prob -> filtered + points = [[i] * 4 for i in range(8)] + labels = ["bear", "horse", "zebra", "giraffe", "bear", "horse", "zebra", "giraffe"] + expect_labels = [None, "Zebra", None, "Giraffe", None, "Zebra", None, None] + probs = [0.5, 0.9, 0.4, 0.8, 0.7, 0.6, 0.3, 0.5] + yolo_boxes = MockBox().mock(points, labels, probs) + yolo_model.predict.return_value.boxes.cpu.return_value = yolo_boxes + + yolo = YOLOv8() + preds = yolo.forward(im) + + self.assertEqual(len(preds), 3) + index = 0 + for pred in preds: + while expect_labels[index] is None: + index += 1 + self.assertEqual(pred[0], rescale(points[index], im.shape[1], im.shape[0])) + self.assertEqual(pred[1], probs[index]) + self.assertEqual(pred[2], expect_labels[index]) + index += 1 + + @patch("kabr_tools.utils.yolo.YOLO") + def test_yolo_with_params(self, yolo_mock): + im = TestYolo.im + yolo_model = MagicMock() + yolo_model.predict.return_value.__getitem__ = lambda x, _: x + yolo_model.names = YOLO("yolov8x.pt").names + yolo_mock.return_value = yolo_model + + points = [[i] * 4 for i in range(8)] + labels = ["bear", "horse", "zebra", "giraffe", "bear", "horse", "zebra", "giraffe"] + expect_labels = ["Panda", "Fish", None, None, None, None, None, "Giraffe"] + probs = [0.91, 0.99, 0.92, 0.55, 0.9, 0.89, 0.85, 0.93] + yolo_boxes = MockBox().mock(points, labels, probs) + yolo_model.predict.return_value.boxes.cpu.return_value = yolo_boxes + + yolo = YOLOv8(weights="yolov8x.pt", + imgsz=640, conf=0.9, + target_labels=["bear", "horse", "giraffe"], + label_map={"bear": "panda", "horse": "fish"}) + preds = yolo.forward(im) + + self.assertEqual(len(preds), 3) + index = 0 + for pred in preds: + while expect_labels[index] is None: + index += 1 + self.assertEqual(pred[0], rescale(points[index], im.shape[1], im.shape[0])) + self.assertEqual(pred[1], probs[index]) + self.assertEqual(pred[2], expect_labels[index]) + index += 1 + + def test_get_centroid(self): + box = TestYolo.box + box_values = TestYolo.box_values + x, y = YOLOv8.get_centroid(box_values) + self.assertEqual(x, (box["x1"] + box["x2"]) // 2) + self.assertEqual(y, (box["y1"] + box["y2"]) // 2)