diff --git a/machine-learning/app/models/facial_recognition/recognition.py b/machine-learning/app/models/facial_recognition/recognition.py index d9ceb12b6d590..c060bdd61634f 100644 --- a/machine-learning/app/models/facial_recognition/recognition.py +++ b/machine-learning/app/models/facial_recognition/recognition.py @@ -13,7 +13,6 @@ from app.models.base import InferenceModel from app.models.transforms import decode_cv2 from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType -from app.sessions import has_batch_axis class FaceRecognizer(InferenceModel): @@ -27,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) def _load(self) -> ModelSession: session = self._make_session(self.model_path) - if self.batch and not has_batch_axis(session): + if self.batch and str(session.get_inputs()[0].shape[0]) != "batch": self._add_batch_axis(self.model_path) session = self._make_session(self.model_path) self.model = ArcFaceONNX( diff --git a/machine-learning/app/sessions/__init__.py b/machine-learning/app/sessions/__init__.py index e0c00ea4a0472..e69de29bb2d1d 100644 --- a/machine-learning/app/sessions/__init__.py +++ b/machine-learning/app/sessions/__init__.py @@ -1,5 +0,0 @@ -from app.schemas import ModelSession - - -def has_batch_axis(session: ModelSession) -> bool: - return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0