diff --git a/tagger/interrogator.py b/tagger/interrogator.py index b644a5f..12e3869 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -241,10 +241,15 @@ def load(self) -> None: # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime # TODO: remove old package when the environment changes? from launch import is_installed, run_pip + from platform import system + package_name = "onnxruntime-gpu" + if system() == "Darwin": + package_name = "onnxruntime-silicon" + if not is_installed('onnxruntime'): package = os.environ.get( 'ONNXRUNTIME_PACKAGE', - 'onnxruntime-gpu' + package_name ) run_pip(f'install {package}', 'onnxruntime')