diff --git a/src/main/java/qupath/ext/wsinfer/WSInferExtension.java b/src/main/java/qupath/ext/wsinfer/WSInferExtension.java index da80687..02a4f16 100644 --- a/src/main/java/qupath/ext/wsinfer/WSInferExtension.java +++ b/src/main/java/qupath/ext/wsinfer/WSInferExtension.java @@ -49,8 +49,6 @@ public class WSInferExtension implements QuPathExtension, GitHubProject { private final BooleanProperty enableExtensionProperty = PathPrefs.createPersistentPreference( "enableExtension", true); - - @Override public void installExtension(QuPathGUI qupath) { if (isInstalled) { diff --git a/src/main/java/qupath/ext/wsinfer/models/WSInferModel.java b/src/main/java/qupath/ext/wsinfer/models/WSInferModel.java index cddaaec..9ca3154 100644 --- a/src/main/java/qupath/ext/wsinfer/models/WSInferModel.java +++ b/src/main/java/qupath/ext/wsinfer/models/WSInferModel.java @@ -1,4 +1,4 @@ -/** + /** * Copyright 2023 University of Edinburgh * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -42,10 +42,10 @@ public class WSInferModel { private WSInferModelConfiguration configuration; @SerializedName("hf_repo_id") - private String hfRepoId; + String hfRepoId; @SerializedName("hf_revision") - private String hfRevision; + String hfRevision; public String getName() { return hfRepoId; @@ -118,7 +118,7 @@ private File getFile(String f) { return Paths.get(getModelDirectory().toString(), f).toFile(); } - private File getModelDirectory() { + File getModelDirectory() { return Paths.get(WSInferPrefs.modelDirectoryProperty().get(), hfRepoId, hfRevision).toFile(); } diff --git a/src/main/java/qupath/ext/wsinfer/models/WSInferModelCollection.java b/src/main/java/qupath/ext/wsinfer/models/WSInferModelCollection.java index 4a991c9..626cd76 100644 --- a/src/main/java/qupath/ext/wsinfer/models/WSInferModelCollection.java +++ b/src/main/java/qupath/ext/wsinfer/models/WSInferModelCollection.java @@ -33,6 +33,6 @@ public class WSInferModelCollection { * @return */ public Map getModels() { - return Collections.unmodifiableMap(models); + return Collections.synchronizedMap(models); } } diff --git a/src/main/java/qupath/ext/wsinfer/models/WSInferModelLocal.java b/src/main/java/qupath/ext/wsinfer/models/WSInferModelLocal.java new file mode 100644 index 0000000..e469af5 --- /dev/null +++ b/src/main/java/qupath/ext/wsinfer/models/WSInferModelLocal.java @@ -0,0 +1,35 @@ +package qupath.ext.wsinfer.models; + +import qupath.ext.wsinfer.ui.WSInferPrefs; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + +public class WSInferModelLocal extends WSInferModel { + + private final File modelDirectory; + + public WSInferModelLocal(File modelDirectory) { + this.modelDirectory = modelDirectory; + this.hfRepoId = modelDirectory.getName(); + // todo: load any files, populate fields from them. + } + + @Override + File getModelDirectory() { + return this.modelDirectory; + } + + @Override + public boolean isValid() { + return getTSFile().exists() && getConfiguration() != null; + } + + @Override + public synchronized void downloadModel() {} + + @Override + public synchronized void removeCache() {} +} diff --git a/src/main/java/qupath/ext/wsinfer/models/WSInferUtils.java b/src/main/java/qupath/ext/wsinfer/models/WSInferUtils.java index 05f72c0..c24b721 100644 --- a/src/main/java/qupath/ext/wsinfer/models/WSInferUtils.java +++ b/src/main/java/qupath/ext/wsinfer/models/WSInferUtils.java @@ -33,6 +33,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Objects; /** * Utility class to help with working with WSInfer models. @@ -64,9 +65,27 @@ public static WSInferModelCollection getModelCollection() { cachedModelCollection = downloadModelCollection(); } } + String localModelDirectory = WSInferPrefs.localDirectoryProperty().get(); + if (localModelDirectory != null) { + addLocalModels(cachedModelCollection, localModelDirectory); + } return cachedModelCollection; } + private static void addLocalModels(WSInferModelCollection cachedModelCollection, String localModelDirectory) { + File modelDir = new File(localModelDirectory); + if (!modelDir.exists() || !modelDir.isDirectory()) { + return; + } + for (var model: Objects.requireNonNull(modelDir.listFiles())) { + var localModel = new WSInferModelLocal(model); + System.out.println(model); + System.out.println(localModel.getName()); + System.out.println(cachedModelCollection.getModels()); + cachedModelCollection.getModels().put(localModel.getName(), localModel); + } + } + /** * Download the model collection from the hugging face repo. * This replaces any previously cached version. diff --git a/src/main/java/qupath/ext/wsinfer/ui/WSInferController.java b/src/main/java/qupath/ext/wsinfer/ui/WSInferController.java index 70866b1..0dce1bc 100644 --- a/src/main/java/qupath/ext/wsinfer/ui/WSInferController.java +++ b/src/main/java/qupath/ext/wsinfer/ui/WSInferController.java @@ -25,7 +25,6 @@ import javafx.beans.property.ObjectProperty; import javafx.beans.property.SimpleIntegerProperty; import javafx.beans.property.SimpleObjectProperty; -import javafx.beans.value.ObservableBooleanValue; import javafx.beans.value.ObservableValue; import javafx.concurrent.Task; import javafx.concurrent.Worker; @@ -51,10 +50,10 @@ import qupath.ext.wsinfer.models.WSInferModel; import qupath.ext.wsinfer.models.WSInferModelCollection; import qupath.ext.wsinfer.models.WSInferUtils; +import qupath.fx.dialogs.Dialogs; import qupath.lib.common.ThreadTools; import qupath.lib.gui.QuPathGUI; import qupath.lib.gui.commands.Commands; -import qupath.fx.dialogs.Dialogs; import qupath.lib.images.ImageData; import qupath.lib.objects.PathAnnotationObject; import qupath.lib.objects.PathObject; @@ -83,7 +82,7 @@ public class WSInferController { private static final Logger logger = LoggerFactory.getLogger(WSInferController.class); public QuPathGUI qupath; - private ObjectProperty> imageDataProperty = new SimpleObjectProperty<>(); + private final ObjectProperty> imageDataProperty = new SimpleObjectProperty<>(); private MessageTextHelper messageTextHelper; @FXML @@ -112,6 +111,8 @@ public class WSInferController { private Spinner spinnerNumWorkers; @FXML private TextField tfModelDirectory; + @FXML + private TextField localModelDirectory; private final static ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings"); @@ -242,6 +243,8 @@ private void configureActionToggleButton(Action action, ToggleButton button) { private void configureModelDirectory() { tfModelDirectory.textProperty().bindBidirectional(WSInferPrefs.modelDirectoryProperty()); + localModelDirectory.textProperty().bindBidirectional(WSInferPrefs.localDirectoryProperty()); + localModelDirectory.textProperty().addListener((v, o, n) -> configureModelChoices()); } private void configureNumWorkers() { @@ -391,7 +394,7 @@ private WSInferTask(ImageData imageData, WSInferModel model) { } @Override - protected Void call() throws Exception { + protected Void call() { try { // Ensure PyTorch engine is available if (!PytorchManager.hasPyTorchEngine()) { @@ -428,7 +431,7 @@ protected Void call() throws Exception { */ private class MessageTextHelper { - private SelectedObjectCounter selectedObjectCounter; + private final SelectedObjectCounter selectedObjectCounter; /** * Text to display a warning (because inference can't be run) @@ -516,14 +519,14 @@ private String getWarningText() { */ private static class SelectedObjectCounter { - private ObjectProperty> imageDataProperty = new SimpleObjectProperty<>(); + private final ObjectProperty> imageDataProperty = new SimpleObjectProperty<>(); - private PathObjectSelectionListener selectionListener = this::selectedPathObjectChanged; + private final PathObjectSelectionListener selectionListener = this::selectedPathObjectChanged; - private ObservableValue hierarchyProperty; + private final ObservableValue hierarchyProperty; - private IntegerProperty numSelectedAnnotations = new SimpleIntegerProperty(); - private IntegerProperty numSelectedDetections = new SimpleIntegerProperty(); + private final IntegerProperty numSelectedAnnotations = new SimpleIntegerProperty(); + private final IntegerProperty numSelectedDetections = new SimpleIntegerProperty(); SelectedObjectCounter(ObservableValue> imageDataProperty) { this.imageDataProperty.bind(imageDataProperty); @@ -561,8 +564,16 @@ private void updateSelectedObjectCounts() { numSelectedDetections.set(0); } else { var selected = hierarchy.getSelectionModel().getSelectedObjects(); - numSelectedAnnotations.set((int)selected.stream().filter(p -> p.isAnnotation()).count()); - numSelectedDetections.set((int)selected.stream().filter(p -> p.isDetection()).count()); + numSelectedAnnotations.set( + (int)selected + .stream().filter(PathObject::isAnnotation) + .count() + ); + numSelectedDetections.set( + (int)selected + .stream().filter(PathObject::isDetection) + .count() + ); } } @@ -570,7 +581,7 @@ private void updateSelectedObjectCounts() { private static class ModelStringConverter extends StringConverter { - private WSInferModelCollection models; + private final WSInferModelCollection models; private ModelStringConverter(WSInferModelCollection models) { Objects.requireNonNull(models, "Models cannot be null"); diff --git a/src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java b/src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java index d8a4648..b3c27d8 100644 --- a/src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java +++ b/src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java @@ -16,7 +16,9 @@ package qupath.ext.wsinfer.ui; +import javafx.beans.property.BooleanProperty; import javafx.beans.property.Property; +import javafx.beans.property.SimpleStringProperty; import javafx.beans.property.StringProperty; import qupath.lib.gui.prefs.PathPrefs; import qupath.lib.gui.UserDirectoryManager; @@ -44,6 +46,9 @@ public class WSInferPrefs { "wsinfer.numWorkers", 1 ).asObject(); + private static StringProperty localDirectoryProperty = PathPrefs.createPersistentPreference( + "wsinfer.localDirectory", + null); /** * String storing the preferred directory to cache models. @@ -66,9 +71,18 @@ public static Property numWorkersProperty() { return numWorkersProperty; } + /** + * String storing the preferred directory for user-supplied model folders. + */ + public static StringProperty localDirectoryProperty() { + return localDirectoryProperty; + } + private static Path getUserDir() { Path userPath = UserDirectoryManager.getInstance().getUserPath(); Path cachePath = Paths.get(System.getProperty("user.dir"), ".cache", "QuPath"); return userPath == null || userPath.toString().isEmpty() ? cachePath : userPath; } + + } diff --git a/src/main/resources/qupath/ext/wsinfer/ui/strings.properties b/src/main/resources/qupath/ext/wsinfer/ui/strings.properties index 29963d8..42ce315 100644 --- a/src/main/resources/qupath/ext/wsinfer/ui/strings.properties +++ b/src/main/resources/qupath/ext/wsinfer/ui/strings.properties @@ -52,8 +52,10 @@ ui.options.device = Preferred device: ui.options.device.tooltip = Select the preferred device for model running (choose CPU if other options are not available) ui.options.directory = Model directory: ui.options.directory.tooltip = Choose the directory where models should be stored +ui.options.localModelDirectory = User model directory: +ui.options.localModelDirectory.tooltip = Choose the directory where user-created models can be loaded from ui.options.pworkers = Number of parallel workers: -ui.options.pworkers.label = Choose the desired number of threads used to request tiles for inference +ui.options.pworkers.tooltip = Choose the desired number of threads used to request tiles for inference ##Other Windows #Processing Window and progress pop-ups diff --git a/src/main/resources/qupath/ext/wsinfer/ui/wsinfer_control.fxml b/src/main/resources/qupath/ext/wsinfer/ui/wsinfer_control.fxml index 770a055..ef1890c 100644 --- a/src/main/resources/qupath/ext/wsinfer/ui/wsinfer_control.fxml +++ b/src/main/resources/qupath/ext/wsinfer/ui/wsinfer_control.fxml @@ -137,18 +137,28 @@ - - + + + + + + + +