Skip to content

Commit ac8ba42

Browse files
authored
Add configs to detect target pipeline (#242)
* Create ultralytics config class * Add configs to target detection pipeline * Organize YAML and improve parsing * Reorder args * Update integration test
1 parent 335518a commit ac8ba42

File tree

8 files changed

+184
-85
lines changed

8 files changed

+184
-85
lines changed

config.yaml

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,29 @@ video_input:
2222

2323
detect_target:
2424
worker_count: 1
25-
option: 0 # 0 is for Ultralytics (from detect_target_factory.py)
26-
device: 0
27-
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
25+
option: 0 # 0 is for Ultralytics and 1 is for brightspot (from detect_target_factory.py)
2826
save_prefix: "log_comp"
29-
30-
detect_brightspot:
31-
brightspot_percentile_threshold: 99.9
32-
filter_by_color: True
33-
blob_color: 255
34-
filter_by_circularity: False
35-
min_circularity: 0.01
36-
max_circularity: 1
37-
filter_by_inertia: True
38-
min_inertia_ratio: 0.2
39-
max_inertia_ratio: 1
40-
filter_by_convexity: False
41-
min_convexity: 0.01
42-
max_convexity: 1
43-
filter_by_area: True
44-
min_area_pixels: 50
45-
max_area_pixels: 640
27+
# Ultralytics config (enum 0)
28+
config:
29+
device: 0
30+
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
31+
# Brightspot config (enum 1)
32+
# config:
33+
# brightspot_percentile_threshold: 99.9
34+
# filter_by_color: True
35+
# blob_color: 255
36+
# filter_by_circularity: False
37+
# min_circularity: 0.01
38+
# max_circularity: 1
39+
# filter_by_inertia: True
40+
# min_inertia_ratio: 0.2
41+
# max_inertia_ratio: 1
42+
# filter_by_convexity: False
43+
# min_convexity: 0.01
44+
# max_convexity: 1
45+
# filter_by_area: True
46+
# min_area_pixels: 50
47+
# max_area_pixels: 640
4648

4749
flight_interface:
4850
# Port 5762 connects directly to the simulated auto pilot, which is more realistic

main_2024.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
from modules.common.modules.camera import camera_opencv
1717
from modules.common.modules.camera import camera_picamera2
1818
from modules.communications import communications_worker
19+
from modules.detect_target import detect_target_brightspot
1920
from modules.detect_target import detect_target_factory
2021
from modules.detect_target import detect_target_worker
22+
from modules.detect_target import detect_target_ultralytics
2123
from modules.flight_interface import flight_interface_worker
2224
from modules.video_input import video_input_worker
2325
from modules.data_merge import data_merge_worker
@@ -109,13 +111,26 @@ def main() -> int:
109111
DETECT_TARGET_OPTION = detect_target_factory.DetectTargetOption(
110112
config["detect_target"]["option"]
111113
)
112-
DETECT_TARGET_DEVICE = "cpu" if args.cpu else config["detect_target"]["device"]
113-
DETECT_TARGET_MODEL_PATH = config["detect_target"]["model_path"]
114-
DETECT_TARGET_OVERRIDE_FULL_PRECISION = args.full
115114
DETECT_TARGET_SAVE_PREFIX = str(
116115
pathlib.Path(logging_path, config["detect_target"]["save_prefix"])
117116
)
118117
DETECT_TARGET_SHOW_ANNOTATED = args.show_annotated
118+
match DETECT_TARGET_OPTION:
119+
case detect_target_factory.DetectTargetOption.ML_ULTRALYTICS:
120+
DETECT_TARGET_CONFIG = detect_target_ultralytics.DetectTargetUltralyticsConfig(
121+
config["detect_target"]["config"]["device"],
122+
config["detect_target"]["config"]["model_path"],
123+
args.full,
124+
)
125+
case detect_target_factory.DetectTargetOption.CV_BRIGHTSPOT:
126+
DETECT_TARGET_CONFIG = detect_target_brightspot.DetectTargetBrightspotConfig(
127+
**config["detect_target"]["config"]
128+
)
129+
case _:
130+
main.logger.error(
131+
f"Inputted an invalid detect target option: {DETECT_TARGET_OPTION}", True
132+
)
133+
return -1
119134

120135
FLIGHT_INTERFACE_ADDRESS = config["flight_interface"]["address"]
121136
FLIGHT_INTERFACE_TIMEOUT = config["flight_interface"]["timeout"]
@@ -244,12 +259,10 @@ def main() -> int:
244259
count=DETECT_TARGET_WORKER_COUNT,
245260
target=detect_target_worker.detect_target_worker,
246261
work_arguments=(
247-
DETECT_TARGET_OPTION,
248-
DETECT_TARGET_DEVICE,
249-
DETECT_TARGET_MODEL_PATH,
250-
DETECT_TARGET_OVERRIDE_FULL_PRECISION,
251-
DETECT_TARGET_SHOW_ANNOTATED,
252262
DETECT_TARGET_SAVE_PREFIX,
263+
DETECT_TARGET_SHOW_ANNOTATED,
264+
DETECT_TARGET_OPTION,
265+
DETECT_TARGET_CONFIG,
253266
),
254267
input_queues=[video_input_to_detect_target_queue],
255268
output_queues=[detect_target_to_data_merge_queue],

modules/detect_target/detect_target_factory.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,29 @@ class DetectTargetOption(enum.Enum):
2020

2121

2222
def create_detect_target(
23+
save_name: str,
24+
show_annotations: bool,
2325
detect_target_option: DetectTargetOption,
24-
device: "str | int",
25-
model_path: str,
26-
override_full: bool,
26+
config: (
27+
detect_target_brightspot.DetectTargetBrightspotConfig
28+
| detect_target_ultralytics.DetectTargetUltralyticsConfig
29+
),
2730
local_logger: logger.Logger,
28-
show_annotations: bool,
29-
save_name: str,
3031
) -> tuple[bool, base_detect_target.BaseDetectTarget | None]:
3132
"""
3233
Construct detect target class at runtime.
3334
"""
3435
match detect_target_option:
3536
case DetectTargetOption.ML_ULTRALYTICS:
3637
return True, detect_target_ultralytics.DetectTargetUltralytics(
37-
device,
38-
model_path,
39-
override_full,
38+
config,
4039
local_logger,
4140
show_annotations,
4241
save_name,
4342
)
4443
case DetectTargetOption.CV_BRIGHTSPOT:
4544
return True, detect_target_brightspot.DetectTargetBrightspot(
45+
config,
4646
local_logger,
4747
show_annotations,
4848
save_name,

modules/detect_target/detect_target_ultralytics.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,37 @@
1313
from ..common.modules.logger import logger
1414

1515

16-
class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
16+
class DetectTargetUltralyticsConfig:
1717
"""
18-
Contains the YOLOv8 model for prediction.
18+
Configuration for DetectTargetUltralytics.
1919
"""
2020

2121
def __init__(
2222
self,
2323
device: "str | int",
2424
model_path: str,
2525
override_full: bool,
26+
) -> None:
27+
"""
28+
Initializes the configuration for DetectTargetUltralytics.
29+
30+
device: name of target device to run inference on (i.e. "cpu" or cuda device 0, 1, 2, 3).
31+
model_path: path to the YOLOv8 model.
32+
override_full: Force full precision floating point calculations.
33+
"""
34+
self.device = device
35+
self.model_path = model_path
36+
self.override_full = override_full
37+
38+
39+
class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
40+
"""
41+
Contains the YOLOv8 model for prediction.
42+
"""
43+
44+
def __init__(
45+
self,
46+
config: DetectTargetUltralyticsConfig,
2647
local_logger: logger.Logger,
2748
show_annotations: bool = False,
2849
save_name: str = "",
@@ -34,14 +55,14 @@ def __init__(
3455
show_annotations: Display annotated images.
3556
save_name: filename prefix for logging detections and annotated images.
3657
"""
37-
self.__device = device
38-
self.__model = ultralytics.YOLO(model_path)
39-
self.__counter = 0
58+
self.__device = config.device
4059
self.__enable_half_precision = not self.__device == "cpu"
60+
self.__model = ultralytics.YOLO(config.model_path)
61+
if config.override_full:
62+
self.__enable_half_precision = False
63+
self.__counter = 0
4164
self.__local_logger = local_logger
4265
self.__show_annotations = show_annotations
43-
if override_full:
44-
self.__enable_half_precision = False
4566
self.__filename_prefix = ""
4667
if save_name != "":
4768
self.__filename_prefix = save_name + "_" + str(int(time.time())) + "_"

modules/detect_target/detect_target_worker.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,20 @@
88
from modules import image_and_time
99
from utilities.workers import queue_proxy_wrapper
1010
from utilities.workers import worker_controller
11+
from . import detect_target_brightspot
1112
from . import detect_target_factory
13+
from . import detect_target_ultralytics
1214
from ..common.modules.logger import logger
1315

1416

1517
def detect_target_worker(
16-
detect_target_option: detect_target_factory.DetectTargetOption,
17-
device: "str | int",
18-
model_path: str,
19-
override_full: bool,
20-
show_annotations: bool,
2118
save_name: str,
19+
show_annotations: bool,
20+
detect_target_option: detect_target_factory.DetectTargetOption,
21+
config: (
22+
detect_target_brightspot.DetectTargetBrightspotConfig
23+
| detect_target_ultralytics.DetectTargetUltralyticsConfig
24+
),
2225
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
2326
output_queue: queue_proxy_wrapper.QueueProxyWrapper,
2427
controller: worker_controller.WorkerController,
@@ -44,13 +47,11 @@ def detect_target_worker(
4447
local_logger.info("Logger initialized", True)
4548

4649
result, detector = detect_target_factory.create_detect_target(
50+
save_name,
51+
show_annotations,
4752
detect_target_option,
48-
device,
49-
model_path,
50-
override_full,
53+
config,
5154
local_logger,
52-
show_annotations,
53-
save_name,
5455
)
5556

5657
if not result:

0 commit comments

Comments
 (0)