Skip to content

Commit

Permalink
fix has_batch_axis
Browse files Browse the repository at this point in the history
  • Loading branch information
mertalev committed Sep 12, 2024
1 parent ad58d7e commit 034e8f4
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from app.models.base import InferenceModel
from app.models.transforms import decode_cv2
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
from app.sessions import has_batch_axis


class FaceRecognizer(InferenceModel):
Expand All @@ -27,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any)

def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if self.batch and not has_batch_axis(session):
if self.batch and session.get_inputs()[0].shape[0] != "batch":
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(
Expand Down
5 changes: 0 additions & 5 deletions machine-learning/app/sessions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from app.schemas import ModelSession


def has_batch_axis(session: ModelSession) -> bool:
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0

0 comments on commit 034e8f4

Please sign in to comment.