Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configs to detect target pipeline #242

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,29 @@ video_input:

detect_target:
worker_count: 1
option: 0 # 0 is for Ultralytics (from detect_target_factory.py)
device: 0
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
option: 0 # 0 is for Ultralytics and 1 is for brightspot (from detect_target_factory.py)
save_prefix: "log_comp"

detect_brightspot:
brightspot_percentile_threshold: 99.9
filter_by_color: True
blob_color: 255
filter_by_circularity: False
min_circularity: 0.01
max_circularity: 1
filter_by_inertia: True
min_inertia_ratio: 0.2
max_inertia_ratio: 1
filter_by_convexity: False
min_convexity: 0.01
max_convexity: 1
filter_by_area: True
min_area_pixels: 50
max_area_pixels: 640
# Ultralytics config (enum 0)
config:
device: 0
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
# Brightspot config (enum 1)
# config:
# brightspot_percentile_threshold: 99.9
# filter_by_color: True
# blob_color: 255
# filter_by_circularity: False
# min_circularity: 0.01
# max_circularity: 1
# filter_by_inertia: True
# min_inertia_ratio: 0.2
# max_inertia_ratio: 1
# filter_by_convexity: False
# min_convexity: 0.01
# max_convexity: 1
# filter_by_area: True
# min_area_pixels: 50
# max_area_pixels: 640

flight_interface:
# Port 5762 connects directly to the simulated auto pilot, which is more realistic
Expand Down
29 changes: 21 additions & 8 deletions main_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from modules.common.modules.camera import camera_opencv
from modules.common.modules.camera import camera_picamera2
from modules.communications import communications_worker
from modules.detect_target import detect_target_brightspot
from modules.detect_target import detect_target_factory
from modules.detect_target import detect_target_worker
from modules.detect_target import detect_target_ultralytics
from modules.flight_interface import flight_interface_worker
from modules.video_input import video_input_worker
from modules.data_merge import data_merge_worker
Expand Down Expand Up @@ -109,13 +111,26 @@ def main() -> int:
DETECT_TARGET_OPTION = detect_target_factory.DetectTargetOption(
config["detect_target"]["option"]
)
DETECT_TARGET_DEVICE = "cpu" if args.cpu else config["detect_target"]["device"]
DETECT_TARGET_MODEL_PATH = config["detect_target"]["model_path"]
DETECT_TARGET_OVERRIDE_FULL_PRECISION = args.full
DETECT_TARGET_SAVE_PREFIX = str(
pathlib.Path(logging_path, config["detect_target"]["save_prefix"])
)
DETECT_TARGET_SHOW_ANNOTATED = args.show_annotated
match DETECT_TARGET_OPTION:
case detect_target_factory.DetectTargetOption.ML_ULTRALYTICS:
DETECT_TARGET_CONFIG = detect_target_ultralytics.DetectTargetUltralyticsConfig(
config["detect_target"]["config"]["device"],
config["detect_target"]["config"]["model_path"],
args.full,
)
case detect_target_factory.DetectTargetOption.CV_BRIGHTSPOT:
DETECT_TARGET_CONFIG = detect_target_brightspot.DetectTargetBrightspotConfig(
**config["detect_target"]["config"]
)
case _:
main.logger.error(
f"Inputted an invalid detect target option: {DETECT_TARGET_OPTION}", True
)
return -1

FLIGHT_INTERFACE_ADDRESS = config["flight_interface"]["address"]
FLIGHT_INTERFACE_TIMEOUT = config["flight_interface"]["timeout"]
Expand Down Expand Up @@ -244,12 +259,10 @@ def main() -> int:
count=DETECT_TARGET_WORKER_COUNT,
target=detect_target_worker.detect_target_worker,
work_arguments=(
DETECT_TARGET_OPTION,
DETECT_TARGET_DEVICE,
DETECT_TARGET_MODEL_PATH,
DETECT_TARGET_OVERRIDE_FULL_PRECISION,
DETECT_TARGET_SHOW_ANNOTATED,
DETECT_TARGET_SAVE_PREFIX,
DETECT_TARGET_SHOW_ANNOTATED,
DETECT_TARGET_OPTION,
DETECT_TARGET_CONFIG,
),
input_queues=[video_input_to_detect_target_queue],
output_queues=[detect_target_to_data_merge_queue],
Expand Down
16 changes: 8 additions & 8 deletions modules/detect_target/detect_target_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,29 @@ class DetectTargetOption(enum.Enum):


def create_detect_target(
save_name: str,
show_annotations: bool,
detect_target_option: DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
local_logger: logger.Logger,
show_annotations: bool,
save_name: str,
) -> tuple[bool, base_detect_target.BaseDetectTarget | None]:
"""
Construct detect target class at runtime.
"""
match detect_target_option:
case DetectTargetOption.ML_ULTRALYTICS:
return True, detect_target_ultralytics.DetectTargetUltralytics(
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
)
case DetectTargetOption.CV_BRIGHTSPOT:
return True, detect_target_brightspot.DetectTargetBrightspot(
config,
local_logger,
show_annotations,
save_name,
Expand Down
35 changes: 28 additions & 7 deletions modules/detect_target/detect_target_ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,37 @@
from ..common.modules.logger import logger


class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
class DetectTargetUltralyticsConfig:
"""
Contains the YOLOv8 model for prediction.
Configuration for DetectTargetUltralytics.
"""

def __init__(
self,
device: "str | int",
model_path: str,
override_full: bool,
) -> None:
"""
Initializes the configuration for DetectTargetUltralytics.

device: name of target device to run inference on (i.e. "cpu" or cuda device 0, 1, 2, 3).
model_path: path to the YOLOv8 model.
override_full: Force full precision floating point calculations.
"""
self.device = device
self.model_path = model_path
self.override_full = override_full


class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
"""
Contains the YOLOv8 model for prediction.
"""

def __init__(
self,
config: DetectTargetUltralyticsConfig,
local_logger: logger.Logger,
show_annotations: bool = False,
save_name: str = "",
Expand All @@ -34,14 +55,14 @@ def __init__(
show_annotations: Display annotated images.
save_name: filename prefix for logging detections and annotated images.
"""
self.__device = device
self.__model = ultralytics.YOLO(model_path)
self.__counter = 0
self.__device = config.device
self.__enable_half_precision = not self.__device == "cpu"
self.__model = ultralytics.YOLO(config.model_path)
if config.override_full:
self.__enable_half_precision = False
self.__counter = 0
self.__local_logger = local_logger
self.__show_annotations = show_annotations
if override_full:
self.__enable_half_precision = False
self.__filename_prefix = ""
if save_name != "":
self.__filename_prefix = save_name + "_" + str(int(time.time())) + "_"
Expand Down
21 changes: 11 additions & 10 deletions modules/detect_target/detect_target_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@

from utilities.workers import queue_proxy_wrapper
from utilities.workers import worker_controller
from . import detect_target_brightspot
from . import detect_target_factory
from . import detect_target_ultralytics
from ..common.modules.logger import logger


def detect_target_worker(
detect_target_option: detect_target_factory.DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
show_annotations: bool,
save_name: str,
show_annotations: bool,
detect_target_option: detect_target_factory.DetectTargetOption,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
output_queue: queue_proxy_wrapper.QueueProxyWrapper,
controller: worker_controller.WorkerController,
Expand All @@ -43,13 +46,11 @@ def detect_target_worker(
local_logger.info("Logger initialized", True)

result, detector = detect_target_factory.create_detect_target(
save_name,
show_annotations,
detect_target_option,
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
)

if not result:
Expand Down
Loading
Loading