Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configs to detect target pipeline #242

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ video_input:

detect_target:
worker_count: 1
option: 0 # 0 is for Ultralytics (from detect_target_factory.py)
option: 0 # 0 is for Ultralytics and 1 is for brightspot (from detect_target_factory.py)
save_prefix: "log_comp"

detect_ultralytics:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this under detect_target and call it config, like it is for video_input?

device: 0
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
save_prefix: "log_comp"

detect_brightspot:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry just noticed this now, can we move this section under detect_target and call it config, like it is for video_input? In this case, you will comment out this whole section because we are using the ultralytics config

brightspot_percentile_threshold: 99.9
Expand Down
75 changes: 68 additions & 7 deletions main_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from modules.common.modules.camera import camera_opencv
from modules.common.modules.camera import camera_picamera2
from modules.communications import communications_worker
from modules.detect_target import detect_target_brightspot
from modules.detect_target import detect_target_factory
from modules.detect_target import detect_target_worker
from modules.detect_target import detect_target_ultralytics
from modules.flight_interface import flight_interface_worker
from modules.video_input import video_input_worker
from modules.data_merge import data_merge_worker
Expand Down Expand Up @@ -105,17 +107,48 @@ def main() -> int:
config["video_input"]["image_name"] if config["video_input"]["log_images"] else None
)

DETECT_TARGET_WORKER_COUNT = config["detect_target"]["worker_count"]
DETECT_TARGET_OPTION = detect_target_factory.DetectTargetOption(
config["detect_target"]["option"]
)
DETECT_TARGET_DEVICE = "cpu" if args.cpu else config["detect_target"]["device"]
DETECT_TARGET_MODEL_PATH = config["detect_target"]["model_path"]
DETECT_TARGET_OVERRIDE_FULL_PRECISION = args.full
DETECT_TARGET_SAVE_PREFIX = str(
pathlib.Path(logging_path, config["detect_target"]["save_prefix"])
)
DETECT_TARGET_SHOW_ANNOTATED = args.show_annotated
DETECT_TARGET_WORKER_COUNT = config["detect_target"]["worker_count"]

DETECT_TARGET_ULTRALYTICS_DEVICE = (
"cpu" if args.cpu else config["detect_ultralytics"]["device"]
)
DETECT_TARGET_ULTRALYTICS_MODEL_PATH = config["detect_ultralytics"]["model_path"]
DETECT_TARGET_ULTRALYTICS_OVERRIDE_FULL_PRECISION = args.full

DETECT_TARGET_BRIGHTSPOT_PERCENTILE_THRESHOLD = config["detect_brightspot"][
"brightspot_percentile_threshold"
]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_COLOR = config["detect_brightspot"]["filter_by_color"]
DETECT_TARGET_BRIGHTSPOT_BLOB_COLOR = config["detect_brightspot"]["blob_color"]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CIRCULARITY = config["detect_brightspot"][
"filter_by_circularity"
]
DETECT_TARGET_BRIGHTSPOT_MIN_CIRCULARITY = config["detect_brightspot"]["min_circularity"]
DETECT_TARGET_BRIGHTSPOT_MAX_CIRCULARITY = config["detect_brightspot"]["max_circularity"]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_INERTIA = config["detect_brightspot"][
"filter_by_inertia"
]
DETECT_TARGET_BRIGHTSPOT_MIN_INERTIA_RATIO = config["detect_brightspot"][
"min_inertia_ratio"
]
DETECT_TARGET_BRIGHTSPOT_MAX_INERTIA_RATIO = config["detect_brightspot"][
"max_inertia_ratio"
]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CONVEXITY = config["detect_brightspot"][
"filter_by_convexity"
]
DETECT_TARGET_BRIGHTSPOT_MIN_CONVEXITY = config["detect_brightspot"]["min_convexity"]
DETECT_TARGET_BRIGHTSPOT_MAX_CONVEXITY = config["detect_brightspot"]["max_convexity"]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_AREA = config["detect_brightspot"]["filter_by_area"]
DETECT_TARGET_BRIGHTSPOT_MIN_AREA_PIXELS = config["detect_brightspot"]["min_area_pixels"]
DETECT_TARGET_BRIGHTSPOT_MAX_AREA_PIXELS = config["detect_brightspot"]["max_area_pixels"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having this many variables, just have one DETECT_TARGET_CONFIG and use match case here (see VIDEO_INPUT_CAMERA_CONFIG above


FLIGHT_INTERFACE_ADDRESS = config["flight_interface"]["address"]
FLIGHT_INTERFACE_TIMEOUT = config["flight_interface"]["timeout"]
Expand Down Expand Up @@ -240,14 +273,42 @@ def main() -> int:
# Get Pylance to stop complaining
assert video_input_worker_properties is not None

detect_target_brightspot_config = detect_target_brightspot.DetectTargetBrightspotConfig(
DETECT_TARGET_BRIGHTSPOT_PERCENTILE_THRESHOLD,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_COLOR,
DETECT_TARGET_BRIGHTSPOT_BLOB_COLOR,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CIRCULARITY,
DETECT_TARGET_BRIGHTSPOT_MIN_CIRCULARITY,
DETECT_TARGET_BRIGHTSPOT_MAX_CIRCULARITY,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_INERTIA,
DETECT_TARGET_BRIGHTSPOT_MIN_INERTIA_RATIO,
DETECT_TARGET_BRIGHTSPOT_MAX_INERTIA_RATIO,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CONVEXITY,
DETECT_TARGET_BRIGHTSPOT_MIN_CONVEXITY,
DETECT_TARGET_BRIGHTSPOT_MAX_CONVEXITY,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_AREA,
DETECT_TARGET_BRIGHTSPOT_MIN_AREA_PIXELS,
DETECT_TARGET_BRIGHTSPOT_MAX_AREA_PIXELS,
)

detect_target_ultralytics_config = detect_target_ultralytics.DetectTargetUltralyticsConfig(
DETECT_TARGET_ULTRALYTICS_DEVICE,
DETECT_TARGET_ULTRALYTICS_MODEL_PATH,
DETECT_TARGET_ULTRALYTICS_OVERRIDE_FULL_PRECISION,
)

match DETECT_TARGET_OPTION:
case detect_target_factory.DetectTargetOption.ML_ULTRALYTICS:
detect_target_config = detect_target_ultralytics_config
case detect_target_factory.DetectTargetOption.CV_BRIGHTSPOT:
detect_target_config = detect_target_brightspot_config

result, detect_target_worker_properties = worker_manager.WorkerProperties.create(
count=DETECT_TARGET_WORKER_COUNT,
target=detect_target_worker.detect_target_worker,
work_arguments=(
DETECT_TARGET_OPTION,
DETECT_TARGET_DEVICE,
DETECT_TARGET_MODEL_PATH,
DETECT_TARGET_OVERRIDE_FULL_PRECISION,
detect_target_config,
DETECT_TARGET_SHOW_ANNOTATED,
DETECT_TARGET_SAVE_PREFIX,
),
Expand Down
12 changes: 6 additions & 6 deletions modules/detect_target/detect_target_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ class DetectTargetOption(enum.Enum):

def create_detect_target(
detect_target_option: DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
local_logger: logger.Logger,
show_annotations: bool,
save_name: str,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we re-order these so that it goes save_name, show_annotations, config, local_logger?

Expand All @@ -34,15 +35,14 @@ def create_detect_target(
match detect_target_option:
case DetectTargetOption.ML_ULTRALYTICS:
return True, detect_target_ultralytics.DetectTargetUltralytics(
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
)
case DetectTargetOption.CV_BRIGHTSPOT:
return True, detect_target_brightspot.DetectTargetBrightspot(
config,
local_logger,
show_annotations,
save_name,
Expand Down
35 changes: 28 additions & 7 deletions modules/detect_target/detect_target_ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,37 @@
from ..common.modules.logger import logger


class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
class DetectTargetUltralyticsConfig:
"""
Contains the YOLOv8 model for prediction.
Configuration for DetectTargetUltralytics.
"""

def __init__(
self,
device: "str | int",
model_path: str,
override_full: bool,
) -> None:
"""
Initializes the configuration for DetectTargetUltralytics.

device: name of target device to run inference on (i.e. "cpu" or cuda device 0, 1, 2, 3).
model_path: path to the YOLOv8 model.
override_full: Force full precision floating point calculations.
"""
self.device = device
self.model_path = model_path
self.override_full = override_full


class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
"""
Contains the YOLOv8 model for prediction.
"""

def __init__(
self,
config: DetectTargetUltralyticsConfig,
local_logger: logger.Logger,
show_annotations: bool = False,
save_name: str = "",
Expand All @@ -34,14 +55,14 @@ def __init__(
show_annotations: Display annotated images.
save_name: filename prefix for logging detections and annotated images.
"""
self.__device = device
self.__model = ultralytics.YOLO(model_path)
self.__counter = 0
self.__device = config.device
self.__enable_half_precision = not self.__device == "cpu"
self.__model = ultralytics.YOLO(config.model_path)
if config.override_full:
self.__enable_half_precision = False
self.__counter = 0
self.__local_logger = local_logger
self.__show_annotations = show_annotations
if override_full:
self.__enable_half_precision = False
self.__filename_prefix = ""
if save_name != "":
self.__filename_prefix = save_name + "_" + str(int(time.time())) + "_"
Expand Down
13 changes: 7 additions & 6 deletions modules/detect_target/detect_target_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@

from utilities.workers import queue_proxy_wrapper
from utilities.workers import worker_controller
from . import detect_target_brightspot
from . import detect_target_factory
from . import detect_target_ultralytics
from ..common.modules.logger import logger


def detect_target_worker(
detect_target_option: detect_target_factory.DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
show_annotations: bool,
save_name: str,
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
Expand Down Expand Up @@ -44,9 +47,7 @@ def detect_target_worker(

result, detector = detect_target_factory.create_detect_target(
detect_target_option,
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_detect_target_ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,13 @@ def detector() -> detect_target_ultralytics.DetectTargetUltralytics: # type: ig
assert result
assert test_logger is not None

detection = detect_target_ultralytics.DetectTargetUltralytics(
DEVICE, str(MODEL_PATH), OVERRIDE_FULL, test_logger
config = detect_target_ultralytics.DetectTargetUltralyticsConfig(
DEVICE,
str(MODEL_PATH),
OVERRIDE_FULL,
)

detection = detect_target_ultralytics.DetectTargetUltralytics(config, test_logger)
yield detection # type: ignore


Expand Down
Loading