Skip to content

Commit

Permalink
(pose_estimation) allways use ~/data/pytorch, when just model provided
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthijsBurgh committed Feb 20, 2024
1 parent 9464d2e commit b32788a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion image_recognition_pose_estimation/scripts/detect_poses
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Detect poses in an image", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--pose_model", help="What pose model to use", default="~/data/pytorch_models/yolov8n-pose.pt")
parser.add_argument("--pose_model", help="What pose model to use", default="yolov8n-pose.pt")
parser.add_argument("--conf", help="Minimal confidence level for detection", default=0.25, type=float)
parser.add_argument(
"--verbose",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from pathlib import Path
from typing import List, Tuple

import numpy as np
Expand Down Expand Up @@ -48,14 +49,20 @@
]


YOLO_POSE_PATTERN = re.compile(r"^.*yolov8(?:([nsml])|(x))-pose(?(2)-p6|)?.pt$")
YOLO_POSE_PATTERN = re.compile(r"^yolov8(?:([nsml])|(x))-pose(?(2)-p6|)?.pt$")


class YoloPoseWrapper:
def __init__(self, model_name: str = "yolov8n-pose.pt", verbose: bool = False):
if not YOLO_POSE_PATTERN.match(model_name):
# Validate model name
model_name_path = Path(model_name)
model_basename = model_name_path.name
if not YOLO_POSE_PATTERN.match(model_basename):
raise ValueError(f"Model name '{model_name}' does not match pattern '{YOLO_POSE_PATTERN.pattern}'")

if len(model_name_path.parts) == 1:
model_name = str(Path.home() / "data" / "pytorch_models" / model_name)

self._model = YOLO(model=model_name, task="pose", verbose=verbose)

def detect_poses(self, image: np.ndarray, conf: float = 0.25) -> Tuple[List[Recognition], np.ndarray]:
Expand Down

0 comments on commit b32788a

Please sign in to comment.