Skip to content

Commit

Permalink
Add configs to target detection pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhp1 committed Jan 22, 2025
1 parent 9f21e4b commit 67862df
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 21 deletions.
6 changes: 4 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ video_input:

detect_target:
worker_count: 1
option: 0 # 0 is for Ultralytics (from detect_target_factory.py)
option: 0 # 0 is for Ultralytics and 1 is for brightspot (from detect_target_factory.py)
save_prefix: "log_comp"

detect_ultralytics:
device: 0
model_path: "tests/model_example/yolov8s_ultralytics_pretrained_default.pt" # See autonomy OneDrive for latest model
save_prefix: "log_comp"

detect_brightspot:
brightspot_percentile_threshold: 99.9
Expand Down
75 changes: 68 additions & 7 deletions main_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from modules.common.modules.camera import camera_opencv
from modules.common.modules.camera import camera_picamera2
from modules.communications import communications_worker
from modules.detect_target import detect_target_brightspot
from modules.detect_target import detect_target_factory
from modules.detect_target import detect_target_worker
from modules.detect_target import detect_target_ultralytics
from modules.flight_interface import flight_interface_worker
from modules.video_input import video_input_worker
from modules.data_merge import data_merge_worker
Expand Down Expand Up @@ -105,17 +107,48 @@ def main() -> int:
config["video_input"]["image_name"] if config["video_input"]["log_images"] else None
)

DETECT_TARGET_WORKER_COUNT = config["detect_target"]["worker_count"]
DETECT_TARGET_OPTION = detect_target_factory.DetectTargetOption(
config["detect_target"]["option"]
)
DETECT_TARGET_DEVICE = "cpu" if args.cpu else config["detect_target"]["device"]
DETECT_TARGET_MODEL_PATH = config["detect_target"]["model_path"]
DETECT_TARGET_OVERRIDE_FULL_PRECISION = args.full
DETECT_TARGET_SAVE_PREFIX = str(
pathlib.Path(logging_path, config["detect_target"]["save_prefix"])
)
DETECT_TARGET_SHOW_ANNOTATED = args.show_annotated
DETECT_TARGET_WORKER_COUNT = config["detect_target"]["worker_count"]

DETECT_TARGET_ULTRALYTICS_DEVICE = (
"cpu" if args.cpu else config["detect_ultralytics"]["device"]
)
DETECT_TARGET_ULTRALYTICS_MODEL_PATH = config["detect_ultralytics"]["model_path"]
DETECT_TARGET_ULTRALYTICS_OVERRIDE_FULL_PRECISION = args.full

DETECT_TARGET_BRIGHTSPOT_PERCENTILE_THRESHOLD = config["detect_brightspot"][
"brightspot_percentile_threshold"
]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_COLOR = config["detect_brightspot"]["filter_by_color"]
DETECT_TARGET_BRIGHTSPOT_BLOB_COLOR = config["detect_brightspot"]["blob_color"]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CIRCULARITY = config["detect_brightspot"][
"filter_by_circularity"
]
DETECT_TARGET_BRIGHTSPOT_MIN_CIRCULARITY = config["detect_brightspot"]["min_circularity"]
DETECT_TARGET_BRIGHTSPOT_MAX_CIRCULARITY = config["detect_brightspot"]["max_circularity"]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_INERTIA = config["detect_brightspot"][
"filter_by_inertia"
]
DETECT_TARGET_BRIGHTSPOT_MIN_INERTIA_RATIO = config["detect_brightspot"][
"min_inertia_ratio"
]
DETECT_TARGET_BRIGHTSPOT_MAX_INERTIA_RATIO = config["detect_brightspot"][
"max_inertia_ratio"
]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CONVEXITY = config["detect_brightspot"][
"filter_by_convexity"
]
DETECT_TARGET_BRIGHTSPOT_MIN_CONVEXITY = config["detect_brightspot"]["min_convexity"]
DETECT_TARGET_BRIGHTSPOT_MAX_CONVEXITY = config["detect_brightspot"]["max_convexity"]
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_AREA = config["detect_brightspot"]["filter_by_area"]
DETECT_TARGET_BRIGHTSPOT_MIN_AREA_PIXELS = config["detect_brightspot"]["min_area_pixels"]
DETECT_TARGET_BRIGHTSPOT_MAX_AREA_PIXELS = config["detect_brightspot"]["max_area_pixels"]

FLIGHT_INTERFACE_ADDRESS = config["flight_interface"]["address"]
FLIGHT_INTERFACE_TIMEOUT = config["flight_interface"]["timeout"]
Expand Down Expand Up @@ -240,14 +273,42 @@ def main() -> int:
# Get Pylance to stop complaining
assert video_input_worker_properties is not None

detect_target_brightspot_config = detect_target_brightspot.DetectTargetBrightspotConfig(
DETECT_TARGET_BRIGHTSPOT_PERCENTILE_THRESHOLD,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_COLOR,
DETECT_TARGET_BRIGHTSPOT_BLOB_COLOR,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CIRCULARITY,
DETECT_TARGET_BRIGHTSPOT_MIN_CIRCULARITY,
DETECT_TARGET_BRIGHTSPOT_MAX_CIRCULARITY,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_INERTIA,
DETECT_TARGET_BRIGHTSPOT_MIN_INERTIA_RATIO,
DETECT_TARGET_BRIGHTSPOT_MAX_INERTIA_RATIO,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_CONVEXITY,
DETECT_TARGET_BRIGHTSPOT_MIN_CONVEXITY,
DETECT_TARGET_BRIGHTSPOT_MAX_CONVEXITY,
DETECT_TARGET_BRIGHTSPOT_FILTER_BY_AREA,
DETECT_TARGET_BRIGHTSPOT_MIN_AREA_PIXELS,
DETECT_TARGET_BRIGHTSPOT_MAX_AREA_PIXELS,
)

detect_target_ultralytics_config = detect_target_ultralytics.DetectTargetUltralyticsConfig(
DETECT_TARGET_ULTRALYTICS_DEVICE,
DETECT_TARGET_ULTRALYTICS_MODEL_PATH,
DETECT_TARGET_ULTRALYTICS_OVERRIDE_FULL_PRECISION,
)

match DETECT_TARGET_OPTION:
case detect_target_factory.DetectTargetOption.ML_ULTRALYTICS:
detect_target_config = detect_target_ultralytics_config
case detect_target_factory.DetectTargetOption.CV_BRIGHTSPOT:
detect_target_config = detect_target_brightspot_config

result, detect_target_worker_properties = worker_manager.WorkerProperties.create(
count=DETECT_TARGET_WORKER_COUNT,
target=detect_target_worker.detect_target_worker,
work_arguments=(
DETECT_TARGET_OPTION,
DETECT_TARGET_DEVICE,
DETECT_TARGET_MODEL_PATH,
DETECT_TARGET_OVERRIDE_FULL_PRECISION,
detect_target_config,
DETECT_TARGET_SHOW_ANNOTATED,
DETECT_TARGET_SAVE_PREFIX,
),
Expand Down
12 changes: 6 additions & 6 deletions modules/detect_target/detect_target_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ class DetectTargetOption(enum.Enum):

def create_detect_target(
detect_target_option: DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
local_logger: logger.Logger,
show_annotations: bool,
save_name: str,
Expand All @@ -34,15 +35,14 @@ def create_detect_target(
match detect_target_option:
case DetectTargetOption.ML_ULTRALYTICS:
return True, detect_target_ultralytics.DetectTargetUltralytics(
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
)
case DetectTargetOption.CV_BRIGHTSPOT:
return True, detect_target_brightspot.DetectTargetBrightspot(
config,
local_logger,
show_annotations,
save_name,
Expand Down
13 changes: 7 additions & 6 deletions modules/detect_target/detect_target_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@

from utilities.workers import queue_proxy_wrapper
from utilities.workers import worker_controller
from . import detect_target_brightspot
from . import detect_target_factory
from . import detect_target_ultralytics
from ..common.modules.logger import logger


def detect_target_worker(
detect_target_option: detect_target_factory.DetectTargetOption,
device: "str | int",
model_path: str,
override_full: bool,
config: (
detect_target_brightspot.DetectTargetBrightspotConfig
| detect_target_ultralytics.DetectTargetUltralyticsConfig
),
show_annotations: bool,
save_name: str,
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
Expand Down Expand Up @@ -44,9 +47,7 @@ def detect_target_worker(

result, detector = detect_target_factory.create_detect_target(
detect_target_option,
device,
model_path,
override_full,
config,
local_logger,
show_annotations,
save_name,
Expand Down

0 comments on commit 67862df

Please sign in to comment.