Skip to content

Commit

Permalink
fix(ml): batch axis not being added for recognition model (#12588)
Browse files Browse the repository at this point in the history
* fix has_batch_axis

* fix typing
  • Loading branch information
mertalev committed Sep 12, 2024
1 parent fa095c3 commit 22dc9bc
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from app.models.base import InferenceModel
from app.models.transforms import decode_cv2
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
from app.sessions import has_batch_axis


class FaceRecognizer(InferenceModel):
Expand All @@ -27,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any)

def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if self.batch and not has_batch_axis(session):
if self.batch and str(session.get_inputs()[0].shape[0]) != "batch":
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(
Expand Down
5 changes: 0 additions & 5 deletions machine-learning/app/sessions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from app.schemas import ModelSession


def has_batch_axis(session: ModelSession) -> bool:
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0

0 comments on commit 22dc9bc

Please sign in to comment.