Skip to content

Commit

Permalink
Add configs to detect target pipeline (#242)
Browse files Browse the repository at this point in the history
* Create ultralytics config class

* Add configs to target detection pipeline

* Organize YAML and improve parsing

* Reorder args

* Update integration test
  • Loading branch information
siddhp1 authored Jan 24, 2025
1 parent 335518a commit ac8ba42
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 85 deletions.
42 changes: 22 additions & 20 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,29 @@ video_input:

detect_target:
worker_count: 1
option: 0 # 0 is for Ultralytics (from detect_target_factory.py)
device: 0
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
option: 0 # 0 is for Ultralytics and 1 is for brightspot (from detect_target_factory.py)
save_prefix: "log_comp"

detect_brightspot:
brightspot_percentile_threshold: 99.9
filter_by_color: True
blob_color: 255
filter_by_circularity: False
min_circularity: 0.01
max_circularity: 1
filter_by_inertia: True
min_inertia_ratio: 0.2
max_inertia_ratio: 1
filter_by_convexity: False
min_convexity: 0.01
max_convexity: 1
filter_by_area: True
min_area_pixels: 50
max_area_pixels: 640
# Ultralytics config (enum 0)
config:
device: 0
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
# Brightspot config (enum 1)
# config:
# brightspot_percentile_threshold: 99.9
# filter_by_color: True
# blob_color: 255
# filter_by_circularity: False
# min_circularity: 0.01
# max_circularity: 1
# filter_by_inertia: True
# min_inertia_ratio: 0.2
# max_inertia_ratio: 1
# filter_by_convexity: False
# min_convexity: 0.01
# max_convexity: 1
# filter_by_area: True
# min_area_pixels: 50
# max_area_pixels: 640

flight_interface:
# Port 5762 connects directly to the simulated auto pilot, which is more realistic
Expand Down
29 changes: 21 additions & 8 deletions main_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from modules.common.modules.camera import camera_opencv
from modules.common.modules.camera import camera_picamera2
from modules.communications import communications_worker
from modules.detect_target import detect_target_brightspot
from modules.detect_target import detect_target_factory
from modules.detect_target import detect_target_worker
from modules.detect_target import detect_target_ultralytics
from modules.flight_interface import flight_interface_worker
from modules.video_input import video_input_worker
from modules.data_merge import data_merge_worker
Expand Down Expand Up @@ -109,13 +111,26 @@ def main() -> int:
DETECT_TARGET_OPTION = detect_target_factory.DetectTargetOption(
config["detect_target"]["option"]
)
DETECT_TARGET_DEVICE = "cpu" if args.cpu else config["detect_target"]["device"]
DETECT_TARGET_MODEL_PATH = config["detect_target"]["model_path"]
DETECT_TARGET_OVERRIDE_FULL_PRECISION = args.full
DETECT_TARGET_SAVE_PREFIX = str(
pathlib.Path(logging_path, config["detect_target"]["save_prefix"])
)
DETECT_TARGET_SHOW_ANNOTATED = args.show_annotated
match DETECT_TARGET_OPTION:
case detect_target_factory.DetectTargetOption.ML_ULTRALYTICS:
DETECT_TARGET_CONFIG = detect_target_ultralytics.DetectTargetUltralyticsConfig(
config["detect_target"]["config"]["device"],
config["detect_target"]["config"]["model_path"],
args.full,
)
case detect_target_factory.DetectTargetOption.CV_BRIGHTSPOT:
DETECT_TARGET_CONFIG = detect_target_brightspot.DetectTargetBrightspotConfig(
**config["detect_target"]["config"]
)
case _:
main.logger.error(
f"Inputted an invalid detect target option: {DETECT_TARGET_OPTION}", True
)
return -1

FLIGHT_INTERFACE_ADDRESS = config["flight_interface"]["address"]
FLIGHT_INTERFACE_TIMEOUT = config["flight_interface"]["timeout"]
Expand Down Expand Up @@ -244,12 +259,10 @@ def main() -> int:
count=DETECT_TARGET_WORKER_COUNT,
target=detect_target_worker.detect_target_worker,
work_arguments=(
DETECT_TARGET_OPTION,
DETECT_TARGET_DEVICE,
DETECT_TARGET_MODEL_PATH,
DETECT_TARGET_OVERRIDE_FULL_PRECISION,
DETECT_TARGET_SHOW_ANNOTATED,
DETECT_TARGET_SAVE_PREFIX,
DETECT_TARGET_SHOW_ANNOTATED,
DETECT_TARGET_OPTION,
DETECT_TARGET_CONFIG,
),
input_queues=[video_input_to_detect_target_queue],
output_queues=[detect_target_to_data_merge_queue],
Expand Down
16 changes: 8 additions & 8 deletions modules/detect_target/detect_target_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,29 @@ class DetectTargetOption(enum.Enum):


def create_detect_target(
save_name: str,
show_annotations: bool,
detect_target_option: DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
local_logger: logger.Logger,
show_annotations: bool,
save_name: str,
) -> tuple[bool, base_detect_target.BaseDetectTarget | None]:
"""
Construct detect target class at runtime.
"""
match detect_target_option:
case DetectTargetOption.ML_ULTRALYTICS:
return True, detect_target_ultralytics.DetectTargetUltralytics(
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
)
case DetectTargetOption.CV_BRIGHTSPOT:
return True, detect_target_brightspot.DetectTargetBrightspot(
config,
local_logger,
show_annotations,
save_name,
Expand Down
35 changes: 28 additions & 7 deletions modules/detect_target/detect_target_ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,37 @@
from ..common.modules.logger import logger


class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
class DetectTargetUltralyticsConfig:
"""
Contains the YOLOv8 model for prediction.
Configuration for DetectTargetUltralytics.
"""

def __init__(
self,
device: "str | int",
model_path: str,
override_full: bool,
) -> None:
"""
Initializes the configuration for DetectTargetUltralytics.
device: name of target device to run inference on (i.e. "cpu" or cuda device 0, 1, 2, 3).
model_path: path to the YOLOv8 model.
override_full: Force full precision floating point calculations.
"""
self.device = device
self.model_path = model_path
self.override_full = override_full


class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
"""
Contains the YOLOv8 model for prediction.
"""

def __init__(
self,
config: DetectTargetUltralyticsConfig,
local_logger: logger.Logger,
show_annotations: bool = False,
save_name: str = "",
Expand All @@ -34,14 +55,14 @@ def __init__(
show_annotations: Display annotated images.
save_name: filename prefix for logging detections and annotated images.
"""
self.__device = device
self.__model = ultralytics.YOLO(model_path)
self.__counter = 0
self.__device = config.device
self.__enable_half_precision = not self.__device == "cpu"
self.__model = ultralytics.YOLO(config.model_path)
if config.override_full:
self.__enable_half_precision = False
self.__counter = 0
self.__local_logger = local_logger
self.__show_annotations = show_annotations
if override_full:
self.__enable_half_precision = False
self.__filename_prefix = ""
if save_name != "":
self.__filename_prefix = save_name + "_" + str(int(time.time())) + "_"
Expand Down
21 changes: 11 additions & 10 deletions modules/detect_target/detect_target_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
from modules import image_and_time
from utilities.workers import queue_proxy_wrapper
from utilities.workers import worker_controller
from . import detect_target_brightspot
from . import detect_target_factory
from . import detect_target_ultralytics
from ..common.modules.logger import logger


def detect_target_worker(
detect_target_option: detect_target_factory.DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
show_annotations: bool,
save_name: str,
show_annotations: bool,
detect_target_option: detect_target_factory.DetectTargetOption,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
output_queue: queue_proxy_wrapper.QueueProxyWrapper,
controller: worker_controller.WorkerController,
Expand All @@ -44,13 +47,11 @@ def detect_target_worker(
local_logger.info("Logger initialized", True)

result, detector = detect_target_factory.create_detect_target(
save_name,
show_annotations,
detect_target_option,
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
)

if not result:
Expand Down
Loading

0 comments on commit ac8ba42

Please sign in to comment.