diff --git a/machine-learning/app/models/facial_recognition/recognition.py b/machine-learning/app/models/facial_recognition/recognition.py index a7a02b0b24cd4..c060bdd61634f 100644 --- a/machine-learning/app/models/facial_recognition/recognition.py +++ b/machine-learning/app/models/facial_recognition/recognition.py @@ -26,7 +26,7 @@ def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) def _load(self) -> ModelSession: session = self._make_session(self.model_path) - if self.batch and session.get_inputs()[0].shape[0] != "batch": + if self.batch and str(session.get_inputs()[0].shape[0]) != "batch": self._add_batch_axis(self.model_path) session = self._make_session(self.model_path) self.model = ArcFaceONNX(