Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

main_2024.py detect target #156

Merged
merged 36 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2dbe3c9
added annotated image option
Ethan118 Nov 13, 2023
10c489e
integrated in main_2024
Ethan118 Nov 13, 2023
bfa1978
added cli option
Ethan118 Nov 27, 2023
ad9ac08
added window closing
Ethan118 Nov 27, 2023
f23f62d
fixed unit tests
Ethan118 Nov 28, 2023
dd6e5a1
fixed image annotation and modified test cases
Ethan118 Nov 28, 2023
3ee9905
converted unit tests to use bounding boxes
Ethan118 Nov 30, 2023
d7ca078
added device to predict arg
Ethan118 Nov 30, 2023
8819c3c
added annotated image option
Ethan118 Nov 13, 2023
d08e556
integrated in main_2024
Ethan118 Nov 13, 2023
bd2a59d
added cli option
Ethan118 Nov 27, 2023
a49c660
added window closing
Ethan118 Nov 27, 2023
20e4018
fixed unit tests
Ethan118 Nov 28, 2023
0da5e01
fixed image annotation and modified test cases
Ethan118 Nov 28, 2023
37a37bc
converted unit tests to use bounding boxes
Ethan118 Nov 30, 2023
cb186f4
added device to predict arg
Ethan118 Nov 30, 2023
724de57
remove annotated image from function return
Ethan118 Dec 4, 2023
5e6c8e7
added hardcoded values for testing target detection
Ethan118 Jan 16, 2024
0133d97
fixed conflicts
Ethan118 Jan 16, 2024
479a361
Changed tests to use absolute error tolerance, Added checks for numbe…
Ethan118 Jan 27, 2024
864bcaa
Merge branch 'main' into detect_target-integration
Ethan118 Jan 27, 2024
0b2f4b7
edited paths to use pathlib
Ethan118 Feb 5, 2024
ba05393
Fixed tests to use detections and time object
Ethan118 Feb 5, 2024
108c7c6
merged with main
Ethan118 Feb 5, 2024
ca549ba
fixed path errors in build
Ethan118 Feb 5, 2024
0ab48c9
Offloaded fixture setup to a separate function and added comments
Ethan118 Feb 6, 2024
435992c
Cleaned up test_detect_target and generate_expected
Ethan118 Feb 7, 2024
73a28a8
Fixed formatting
Ethan118 Feb 7, 2024
69be32a
Fixed renamed variables
Ethan118 Feb 7, 2024
5117781
Changed tolerance
Ethan118 Feb 8, 2024
dea615a
Merge branch 'main' into detect_target-integration
Ethan118 Feb 8, 2024
3080db0
fixed create detections
Ethan118 Feb 8, 2024
6b00c7e
Set tolerance to 0 and fixed formatting
Ethan118 Feb 9, 2024
ef80b03
Changed formatting mostly multiline
Ethan118 Feb 14, 2024
68fab82
Final formatting changes, Updated test_detect_target_worker to use de…
Ethan118 Feb 17, 2024
f8359ec
Some refactoring of main_2024 and final formatting changes
Ethan118 Feb 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions main_2024.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove line 12 import (no longer being used).

Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main() -> int:
DETECT_TARGET_MODEL_PATH = config["detect_target"]["model_path"]
DETECT_TARGET_OVERRIDE_FULL_PRECISION = args.full
DETECT_TARGET_SAVE_PREFIX = config["detect_target"]["save_prefix"]
DETECT_TARGET_ANNOTATE = args.show_annotated
DETECT_TARGET_SHOW_ANNOTATED = args.show_annotated

FLIGHT_INTERFACE_ADDRESS = config["flight_interface"]["address"]
FLIGHT_INTERFACE_TIMEOUT = config["flight_interface"]["timeout"]
Expand Down Expand Up @@ -108,7 +108,7 @@ def main() -> int:
DETECT_TARGET_MODEL_PATH,
DETECT_TARGET_OVERRIDE_FULL_PRECISION,
DETECT_TARGET_SAVE_PREFIX,
DETECT_TARGET_ANNOTATE,
DETECT_TARGET_SHOW_ANNOTATED,
video_input_to_detect_target_queue,
detect_target_to_main_queue,
controller,
Expand Down Expand Up @@ -139,6 +139,14 @@ def main() -> int:
except queue.Empty:
detections = None

if detections is not None:
print("timestamp: " + str(detections.timestamp))
print("detections: " + str(len(detections.detections)))
for detection in detections.detections:
print(" label: " + str(detection.label))
print(" confidence: " + str(detection.confidence))
print("")

odometry_and_time = flight_interface_to_main_queue.queue.get()

if odometry_and_time is not None:
Expand All @@ -151,7 +159,8 @@ def main() -> int:
print("pitch: " + str(odometry_and_time.odometry_data.orientation.pitch))
print("")

if cv2.waitKey(1) & 0xFF == ord("q"):
if cv2.waitKey(1) == ord("q"):
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
print("Exiting main loop")
break

# Teardown
Expand Down
1 change: 0 additions & 1 deletion modules/detect_target/detect_target.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert to when you were displaying the image in run() . Return only the detections.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amy mentioned above that this causes issues. Is there a work around?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work on you end? If it does then maybe I missed something in the set up?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try adding cv2.waitKey(1) after the imshow() ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try adding cv2.waitKey(1) after the imshow() ?

Oh this works! @Ethan118 you can do this instead.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove line 7 import (no longer being used).

Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def run(self, data: image_and_time.ImageAndTime) -> "tuple[bool, detections_and_

if self.__show_annotations:
cv2.imshow("Annotated", image_annotated)
cv2.waitKey(1)

return True, detections

Expand Down
24 changes: 15 additions & 9 deletions tests/model_example/generate_expected.py
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@
import cv2
import numpy as np
import ultralytics
import pathlib
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved


# Downloaded from: https://github.com/ultralytics/assets/releases
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
MODEL_PATH = "tests/model_example/yolov8s_ultralytics_pretrained_default.pt"
MODEL_PATH = pathlib.Path("tests/model_example/yolov8s_ultralytics_pretrained_default.pt")

BUS_IMAGE_PATH = pathlib.Path("tests/model_example/bus.jpg")
ZIDANE_IMAGE_PATH = pathlib.Path("tests/model_example/zidane.jpg")

SAVE_PATH = pathlib.Path("tests/model_example")
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
model = ultralytics.YOLO(MODEL_PATH)
image_bus = cv2.imread("tests/model_example/bus.jpg")
image_zidane = cv2.imread("tests/model_example/zidane.jpg")
image_bus = cv2.imread(BUS_IMAGE_PATH)
image_zidane = cv2.imread(ZIDANE_IMAGE_PATH)
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved

# ultralytics saves as .jpg , bad for testing reproducibility
results_bus = model.predict(
Expand All @@ -35,15 +41,15 @@
image_zidane_annotated = results_zidane[0].plot(conf=True)

# Generate expected
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
bus_expected = results_bus[0].boxes.xyxy.detach().cpu().numpy()
zidane_expected = results_zidane[0].boxes.xyxy.detach().cpu().numpy()
bounding_box_bus = results_bus[0].boxes.xyxy.detach().cpu().numpy()
bounding_box_zidane = results_zidane[0].boxes.xyxy.detach().cpu().numpy()

# Save image
cv2.imwrite("tests/model_example/bus_annotated.png", image_bus_annotated)
cv2.imwrite("tests/model_example/zidane_annotated.png", image_zidane_annotated)
cv2.imwrite(f"{SAVE_PATH}/bus_annotated.png", image_bus_annotated)
cv2.imwrite(f"{SAVE_PATH}/zidane_annotated.png", image_zidane_annotated)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary space characters.

# Save expected to text file
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
np.savetxt("tests/model_example/bus_expected.txt", bus_expected)
np.savetxt("tests/model_example/zidane_expected.txt", zidane_expected)
np.savetxt(f"{SAVE_PATH}/bounding_box_bus.txt", bounding_box_bus)
np.savetxt(f"{SAVE_PATH}/bounding_box_zidane.txt", bounding_box_zidane)
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved

print("Done!")
102 changes: 49 additions & 53 deletions tests/test_detect_target.py
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import pytest
import torch
import ultralytics

from modules.detect_target import detect_target
from modules import image_and_time
Expand All @@ -19,12 +18,10 @@
OVERRIDE_FULL = False # Tests are able to handle both full and half precision.
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
IMAGE_BUS_PATH = "tests/model_example/bus.jpg"
IMAGE_BUS_ANNOTATED_PATH = "tests/model_example/bus_annotated.png"
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
BOUNDING_BOX_BUS_PATH = "tests/model_example/bounding_box_bus.txt"
IMAGE_ZIDANE_PATH = "tests/model_example/zidane.jpg"
IMAGE_ZIDANE_ANNOTATED_PATH = "tests/model_example/zidane_annotated.png"
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved

model = ultralytics.YOLO(MODEL_PATH)
expected_bus = np.loadtxt("tests/model_example/bus_expected.txt")
expected_zidane = np.loadtxt("tests/model_example/zidane_expected.txt")
BOUNDING_BOX_ZIDANE_PATH = "tests/model_example/bounding_box_zidane.txt"

@pytest.fixture()
def detector():
Expand Down Expand Up @@ -58,85 +55,83 @@ def image_zidane():
assert zidane_image is not None
yield zidane_image

@pytest.fixture()
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
def expected_bus():
"""
Load expected bus image.
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
"""
expected_bus = np.loadtxt(BOUNDING_BOX_BUS_PATH)
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename this to something that is not the function name (e.g. expected ).

assert expected_bus is not None
yield expected_bus
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved

def rmse(actual: np.ndarray,
expected: np.ndarray) -> float:
"""
Helper function to compute root mean squared error.
"""
mean_squared_error = np.square(actual - expected).mean()

return np.sqrt(mean_squared_error)


def test_rmse():
"""
Root mean squared error.
"""
# Setup
sample_actual = np.array([1, 2, 3, 4, 5])
sample_expected = np.array([1.6, 2.5, 2.9, 3, 4.1])
EXPECTED_ERROR = np.sqrt(0.486)

# Run
actual_error = rmse(sample_actual, sample_expected)

# Test
np.testing.assert_almost_equal(actual_error, EXPECTED_ERROR)

@pytest.fixture()
def expected_zidane():
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

"""
Load expected Zidane image.
"""
expected_zidane = np.loadtxt(BOUNDING_BOX_ZIDANE_PATH)
assert expected_zidane is not None
yield expected_zidane

class TestDetector:
"""
Tests `DetectTarget.run()` .
"""

__IMAGE_DIFFERENCE_TOLERANCE = 1
__BOUNDING_BOX_TOLERANCE = 1e-7

def test_single_bus_image(self,
detector: detect_target.DetectTarget,
image_bus: image_and_time.ImageAndTime):
image_bus: image_and_time.ImageAndTime,
expected_bus: np.ndarray):
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
"""
Bus image.
"""
# Run
result, actual = detector.run(image_bus)
detections = actual.detections

# Test
assert result
assert actual is not None
assert detections is not None

detections = actual.detections

error = 0
assert len(detections) == expected_bus.shape[0]

for i in range(0, len(detections)):
error += rmse([detections[i].x1, detections[i].y1, detections[i].x2, detections[i].y2], expected_bus[i])
assert (error / len(detections)) < self.__IMAGE_DIFFERENCE_TOLERANCE
errors = np.full((1, expected_bus.shape[1]), 1)

for i in range(0, expected_bus.shape[0]):
errors = np.abs(np.array([detections[i].x1, detections[i].y1, detections[i].x2, detections[i].y2]) - expected_bus[i])
assert all(e < self.__BOUNDING_BOX_TOLERANCE for e in errors)

def test_single_zidane_image(self,
detector: detect_target.DetectTarget,
image_zidane: image_and_time.ImageAndTime):
image_zidane: image_and_time.ImageAndTime,
expected_zidane: np.ndarray):
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
"""
Zidane image.
"""
# Run
result, actual = detector.run(image_zidane)
detections = actual.detections

# Test
assert result
assert actual is not None
assert detections is not None

error = 0
detections = actual.detections

assert len(detections) == expected_zidane.shape[0]

for i in range(0, len(detections)):
error += rmse([detections[i].x1, detections[i].y1, detections[i].x2, detections[i].y2], expected_zidane[i])
assert (error / len(detections)) < self.__IMAGE_DIFFERENCE_TOLERANCE
errors = np.full((1, expected_zidane.shape[0]), 1)

for i in range(0, expected_zidane.shape[0]):
errors = np.abs(np.array([detections[i].x1, detections[i].y1, detections[i].x2, detections[i].y2]) - expected_zidane[i])
assert all(e < self.__BOUNDING_BOX_TOLERANCE for e in errors)

def test_multiple_zidane_image(self,
detector: detect_target.DetectTarget,
image_zidane: image_and_time.ImageAndTime):
image_zidane: image_and_time.ImageAndTime,
expected_zidane: np.ndarray):
Xierumeng marked this conversation as resolved.
Show resolved Hide resolved
"""
Multiple Zidane images.
"""
Expand All @@ -157,15 +152,16 @@ def test_multiple_zidane_image(self,
for i in range(0, IMAGE_COUNT):
output: "tuple[bool, detections_and_time.DetectionsAndTime | None]" = outputs[i]
result, actual = output

detections = actual.detections

assert result
assert actual is not None
assert detections is not None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary space characters.

detections = actual.detections

assert len(detections) == expected_zidane.shape[0]

error = 0
errors = np.full((1, expected_zidane.shape[0]), 1)

for i in range(0, len(detections)):
error += rmse([detections[i].x1, detections[i].y1, detections[i].x2, detections[i].y2], expected_zidane[i])
assert (error / len(detections)) < self.__IMAGE_DIFFERENCE_TOLERANCE
for i in range(0, expected_zidane.shape[0]):
errors = np.abs(np.array([detections[i].x1, detections[i].y1, detections[i].x2, detections[i].y2]) - expected_zidane[i])
assert all(e < self.__BOUNDING_BOX_TOLERANCE for e in errors)
Loading