diff --git a/src/main/java/qupath/ext/wsinfer/ui/PytorchController.java b/src/main/java/qupath/ext/wsinfer/ui/PytorchController.java deleted file mode 100644 index 699b085..0000000 --- a/src/main/java/qupath/ext/wsinfer/ui/PytorchController.java +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2023 University of Edinburgh - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package qupath.ext.wsinfer.ui; - -import javafx.event.ActionEvent; - -/** - * Controller for the PyTorch Error interface, used with WSInfer extension - */ -public class PytorchController { - public void pytorchDownload(ActionEvent actionEvent) { - } -} diff --git a/src/main/java/qupath/ext/wsinfer/ui/PytorchManager.java b/src/main/java/qupath/ext/wsinfer/ui/PytorchManager.java index acb4458..7fb5e76 100644 --- a/src/main/java/qupath/ext/wsinfer/ui/PytorchManager.java +++ b/src/main/java/qupath/ext/wsinfer/ui/PytorchManager.java @@ -17,12 +17,15 @@ package qupath.ext.wsinfer.ui; import ai.djl.engine.Engine; +import ai.djl.engine.EngineException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import qupath.lib.common.GeneralTools; +import java.io.IOException; import java.util.Collection; import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; import java.util.concurrent.Callable; @@ -39,26 +42,28 @@ class PytorchManager { */ static Collection getAvailableDevices() { Set availableDevices = new LinkedHashSet<>(); - boolean includesMPS = false; // Don't add MPS twice - var engine = getEngineOffline(); - if (engine != null) { - // This is expected to return GPUs if available, or CPU otherwise - for (var device : engine.getDevices()) { - String name = device.getDeviceType(); - availableDevices.add(name); - if (name.toLowerCase().startsWith("mps")) - includesMPS = true; + try { + + var engine = getEngineOffline(); + if (engine != null) { + // This is expected to return GPUs if available, or CPU otherwise + for (var device : engine.getDevices()) { + String name = device.getDeviceType(); + availableDevices.add(name); + } } - } - // CPU should always be available - if (!availableDevices.contains("cpu")) + // CPU should always be available availableDevices.add("cpu"); - // If we could use MPS, but don't have it already, add it - if (!includesMPS && GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) { - availableDevices.add("mps"); + // If we could use MPS, but don't have it already, add it + if (GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) { + availableDevices.add("mps"); + } + return availableDevices; + } catch (Exception e) { + logger.info("Unable to load engine", e); + return List.of(); } - return availableDevices; } /** @@ -76,6 +81,9 @@ static boolean hasPyTorchEngine() { static Engine getEngineOffline() { try { return callOffline(() -> Engine.getEngine("PyTorch")); + } catch (EngineException | IOException e) { + logger.info("Unable to load PyTorch", e); + return null; } catch (Exception e) { logger.error(e.getMessage(), e); return null; @@ -97,13 +105,13 @@ static Engine getEngineOnline() { /** * Call a function with the "offline" property set to true (to block automatic downloads). - * @param callable + * @param callable Function that'll be called offline. * @return * @param - * @throws Exception + * @throws EngineException If the engine can't be loaded (probably because it hasn't been downloaded) */ private static T callOffline(Callable callable) throws Exception { - return callWithTempProperty("offline", "true", callable); + return callWithTempProperty("ai.djl.offline", "true", callable); } /** @@ -114,7 +122,7 @@ private static T callOffline(Callable callable) throws Exception { * @throws Exception */ private static T callOnline(Callable callable) throws Exception { - return callWithTempProperty("offline", "false", callable); + return callWithTempProperty("ai.djl.offline", "false", callable); }