From 034e8f4e40e09e84bbda9e1bb79a39d0ddc7ad5c Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Wed, 11 Sep 2024 20:35:25 -0400 Subject: [PATCH 1/2] fix has_batch_axis --- .../app/models/facial_recognition/recognition.py | 3 +-- machine-learning/app/sessions/__init__.py | 5 ----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/machine-learning/app/models/facial_recognition/recognition.py b/machine-learning/app/models/facial_recognition/recognition.py index d9ceb12b6d590..a7a02b0b24cd4 100644 --- a/machine-learning/app/models/facial_recognition/recognition.py +++ b/machine-learning/app/models/facial_recognition/recognition.py @@ -13,7 +13,6 @@ from app.models.base import InferenceModel from app.models.transforms import decode_cv2 from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType -from app.sessions import has_batch_axis class FaceRecognizer(InferenceModel): @@ -27,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) def _load(self) -> ModelSession: session = self._make_session(self.model_path) - if self.batch and not has_batch_axis(session): + if self.batch and session.get_inputs()[0].shape[0] != "batch": self._add_batch_axis(self.model_path) session = self._make_session(self.model_path) self.model = ArcFaceONNX( diff --git a/machine-learning/app/sessions/__init__.py b/machine-learning/app/sessions/__init__.py index e0c00ea4a0472..e69de29bb2d1d 100644 --- a/machine-learning/app/sessions/__init__.py +++ b/machine-learning/app/sessions/__init__.py @@ -1,5 +0,0 @@ -from app.schemas import ModelSession - - -def has_batch_axis(session: ModelSession) -> bool: - return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0 From 861d394864345370eca9a0e0550d7f6e9597c782 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Wed, 11 Sep 2024 21:30:28 -0400 Subject: [PATCH 2/2] fix typing --- machine-learning/app/models/facial_recognition/recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/machine-learning/app/models/facial_recognition/recognition.py b/machine-learning/app/models/facial_recognition/recognition.py index a7a02b0b24cd4..c060bdd61634f 100644 --- a/machine-learning/app/models/facial_recognition/recognition.py +++ b/machine-learning/app/models/facial_recognition/recognition.py @@ -26,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) def _load(self) -> ModelSession: session = self._make_session(self.model_path) - if self.batch and session.get_inputs()[0].shape[0] != "batch": + if self.batch and str(session.get_inputs()[0].shape[0]) != "batch": self._add_batch_axis(self.model_path) session = self._make_session(self.model_path) self.model = ArcFaceONNX(