Skip to content

Commit

Permalink
Implement local model support, see #32
Browse files Browse the repository at this point in the history
  • Loading branch information
alanocallaghan committed Oct 30, 2023
1 parent f84c688 commit fb829cb
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 28 deletions.
2 changes: 0 additions & 2 deletions src/main/java/qupath/ext/wsinfer/WSInferExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ public class WSInferExtension implements QuPathExtension, GitHubProject {
private final BooleanProperty enableExtensionProperty = PathPrefs.createPersistentPreference(
"enableExtension", true);



@Override
public void installExtension(QuPathGUI qupath) {
if (isInstalled) {
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/qupath/ext/wsinfer/models/WSInferModel.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**
/**
* Copyright 2023 University of Edinburgh
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -42,10 +42,10 @@ public class WSInferModel {
private WSInferModelConfiguration configuration;

@SerializedName("hf_repo_id")
private String hfRepoId;
String hfRepoId;

@SerializedName("hf_revision")
private String hfRevision;
String hfRevision;

public String getName() {
return hfRepoId;
Expand Down Expand Up @@ -118,7 +118,7 @@ private File getFile(String f) {
return Paths.get(getModelDirectory().toString(), f).toFile();
}

private File getModelDirectory() {
File getModelDirectory() {
return Paths.get(WSInferPrefs.modelDirectoryProperty().get(), hfRepoId, hfRevision).toFile();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ public class WSInferModelCollection {
* @return
*/
public Map<String, WSInferModel> getModels() {
return Collections.unmodifiableMap(models);
return Collections.synchronizedMap(models);
}
}
35 changes: 35 additions & 0 deletions src/main/java/qupath/ext/wsinfer/models/WSInferModelLocal.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package qupath.ext.wsinfer.models;

import qupath.ext.wsinfer.ui.WSInferPrefs;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

public class WSInferModelLocal extends WSInferModel {

private final File modelDirectory;

public WSInferModelLocal(File modelDirectory) {
this.modelDirectory = modelDirectory;
this.hfRepoId = modelDirectory.getName();
// todo: load any files, populate fields from them.
}

@Override
File getModelDirectory() {
return this.modelDirectory;
}

@Override
public boolean isValid() {
return getTSFile().exists() && getConfiguration() != null;
}

@Override
public synchronized void downloadModel() {}

@Override
public synchronized void removeCache() {}
}
19 changes: 19 additions & 0 deletions src/main/java/qupath/ext/wsinfer/models/WSInferUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;

/**
* Utility class to help with working with WSInfer models.
Expand Down Expand Up @@ -64,9 +65,27 @@ public static WSInferModelCollection getModelCollection() {
cachedModelCollection = downloadModelCollection();
}
}
String localModelDirectory = WSInferPrefs.localDirectoryProperty().get();
if (localModelDirectory != null) {
addLocalModels(cachedModelCollection, localModelDirectory);
}
return cachedModelCollection;
}

private static void addLocalModels(WSInferModelCollection cachedModelCollection, String localModelDirectory) {
File modelDir = new File(localModelDirectory);
if (!modelDir.exists() || !modelDir.isDirectory()) {
return;
}
for (var model: Objects.requireNonNull(modelDir.listFiles())) {
var localModel = new WSInferModelLocal(model);
System.out.println(model);
System.out.println(localModel.getName());
System.out.println(cachedModelCollection.getModels());
cachedModelCollection.getModels().put(localModel.getName(), localModel);
}
}

/**
* Download the model collection from the hugging face repo.
* This replaces any previously cached version.
Expand Down
37 changes: 24 additions & 13 deletions src/main/java/qupath/ext/wsinfer/ui/WSInferController.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import javafx.beans.property.ObjectProperty;
import javafx.beans.property.SimpleIntegerProperty;
import javafx.beans.property.SimpleObjectProperty;
import javafx.beans.value.ObservableBooleanValue;
import javafx.beans.value.ObservableValue;
import javafx.concurrent.Task;
import javafx.concurrent.Worker;
Expand All @@ -51,10 +50,10 @@
import qupath.ext.wsinfer.models.WSInferModel;
import qupath.ext.wsinfer.models.WSInferModelCollection;
import qupath.ext.wsinfer.models.WSInferUtils;
import qupath.fx.dialogs.Dialogs;
import qupath.lib.common.ThreadTools;
import qupath.lib.gui.QuPathGUI;
import qupath.lib.gui.commands.Commands;
import qupath.fx.dialogs.Dialogs;
import qupath.lib.images.ImageData;
import qupath.lib.objects.PathAnnotationObject;
import qupath.lib.objects.PathObject;
Expand Down Expand Up @@ -83,7 +82,7 @@ public class WSInferController {
private static final Logger logger = LoggerFactory.getLogger(WSInferController.class);

public QuPathGUI qupath;
private ObjectProperty<ImageData<BufferedImage>> imageDataProperty = new SimpleObjectProperty<>();
private final ObjectProperty<ImageData<BufferedImage>> imageDataProperty = new SimpleObjectProperty<>();
private MessageTextHelper messageTextHelper;

@FXML
Expand Down Expand Up @@ -112,6 +111,8 @@ public class WSInferController {
private Spinner<Integer> spinnerNumWorkers;
@FXML
private TextField tfModelDirectory;
@FXML
private TextField localModelDirectory;

private final static ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");

Expand Down Expand Up @@ -242,6 +243,8 @@ private void configureActionToggleButton(Action action, ToggleButton button) {

private void configureModelDirectory() {
tfModelDirectory.textProperty().bindBidirectional(WSInferPrefs.modelDirectoryProperty());
localModelDirectory.textProperty().bindBidirectional(WSInferPrefs.localDirectoryProperty());
localModelDirectory.textProperty().addListener((v, o, n) -> configureModelChoices());
}

private void configureNumWorkers() {
Expand Down Expand Up @@ -391,7 +394,7 @@ private WSInferTask(ImageData<BufferedImage> imageData, WSInferModel model) {
}

@Override
protected Void call() throws Exception {
protected Void call() {
try {
// Ensure PyTorch engine is available
if (!PytorchManager.hasPyTorchEngine()) {
Expand Down Expand Up @@ -428,7 +431,7 @@ protected Void call() throws Exception {
*/
private class MessageTextHelper {

private SelectedObjectCounter selectedObjectCounter;
private final SelectedObjectCounter selectedObjectCounter;

/**
* Text to display a warning (because inference can't be run)
Expand Down Expand Up @@ -516,14 +519,14 @@ private String getWarningText() {
*/
private static class SelectedObjectCounter {

private ObjectProperty<ImageData<?>> imageDataProperty = new SimpleObjectProperty<>();
private final ObjectProperty<ImageData<?>> imageDataProperty = new SimpleObjectProperty<>();

private PathObjectSelectionListener selectionListener = this::selectedPathObjectChanged;
private final PathObjectSelectionListener selectionListener = this::selectedPathObjectChanged;

private ObservableValue<PathObjectHierarchy> hierarchyProperty;
private final ObservableValue<PathObjectHierarchy> hierarchyProperty;

private IntegerProperty numSelectedAnnotations = new SimpleIntegerProperty();
private IntegerProperty numSelectedDetections = new SimpleIntegerProperty();
private final IntegerProperty numSelectedAnnotations = new SimpleIntegerProperty();
private final IntegerProperty numSelectedDetections = new SimpleIntegerProperty();

SelectedObjectCounter(ObservableValue<ImageData<BufferedImage>> imageDataProperty) {
this.imageDataProperty.bind(imageDataProperty);
Expand Down Expand Up @@ -561,16 +564,24 @@ private void updateSelectedObjectCounts() {
numSelectedDetections.set(0);
} else {
var selected = hierarchy.getSelectionModel().getSelectedObjects();
numSelectedAnnotations.set((int)selected.stream().filter(p -> p.isAnnotation()).count());
numSelectedDetections.set((int)selected.stream().filter(p -> p.isDetection()).count());
numSelectedAnnotations.set(
(int)selected
.stream().filter(PathObject::isAnnotation)
.count()
);
numSelectedDetections.set(
(int)selected
.stream().filter(PathObject::isDetection)
.count()
);
}
}

}

private static class ModelStringConverter extends StringConverter<WSInferModel> {

private WSInferModelCollection models;
private final WSInferModelCollection models;

private ModelStringConverter(WSInferModelCollection models) {
Objects.requireNonNull(models, "Models cannot be null");
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package qupath.ext.wsinfer.ui;

import javafx.beans.property.BooleanProperty;
import javafx.beans.property.Property;
import javafx.beans.property.SimpleStringProperty;
import javafx.beans.property.StringProperty;
import qupath.lib.gui.prefs.PathPrefs;
import qupath.lib.gui.UserDirectoryManager;
Expand Down Expand Up @@ -44,6 +46,9 @@ public class WSInferPrefs {
"wsinfer.numWorkers",
1
).asObject();
private static StringProperty localDirectoryProperty = PathPrefs.createPersistentPreference(
"wsinfer.localDirectory",
null);

/**
* String storing the preferred directory to cache models.
Expand All @@ -66,9 +71,18 @@ public static Property<Integer> numWorkersProperty() {
return numWorkersProperty;
}

/**
* String storing the preferred directory for user-supplied model folders.
*/
public static StringProperty localDirectoryProperty() {
return localDirectoryProperty;
}

private static Path getUserDir() {
Path userPath = UserDirectoryManager.getInstance().getUserPath();
Path cachePath = Paths.get(System.getProperty("user.dir"), ".cache", "QuPath");
return userPath == null || userPath.toString().isEmpty() ? cachePath : userPath;
}


}
4 changes: 3 additions & 1 deletion src/main/resources/qupath/ext/wsinfer/ui/strings.properties
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ ui.options.device = Preferred device:
ui.options.device.tooltip = Select the preferred device for model running (choose CPU if other options are not available)
ui.options.directory = Model directory:
ui.options.directory.tooltip = Choose the directory where models should be stored
ui.options.localModelDirectory = User model directory:
ui.options.localModelDirectory.tooltip = Choose the directory where user-created models can be loaded from
ui.options.pworkers = Number of parallel workers:
ui.options.pworkers.label = Choose the desired number of threads used to request tiles for inference
ui.options.pworkers.tooltip = Choose the desired number of threads used to request tiles for inference

##Other Windows
#Processing Window and progress pop-ups
Expand Down
24 changes: 17 additions & 7 deletions src/main/resources/qupath/ext/wsinfer/ui/wsinfer_control.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,28 @@
</children>
</HBox>
<VBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.directory" />
<TextField fx:id="tfModelDirectory">
<tooltip><Tooltip text="%ui.options.directory.tooltip" /></tooltip>
</TextField>
</children>
<VBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.directory" />
<TextField fx:id="tfModelDirectory">
<tooltip><Tooltip text="%ui.options.directory.tooltip" /></tooltip>
</TextField>
</children>
</VBox>
<VBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.localModelDirectory" />
<TextField fx:id="localModelDirectory">
<tooltip><Tooltip text="%ui.options.localModelDirectory.tooltip" /></tooltip>
</TextField>
</children>
</VBox>
</VBox>
<HBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.pworkers" />
<Spinner fx:id="spinnerNumWorkers" prefWidth="75.0">
<tooltip><Tooltip text="%ui.options.pworkers.label" /></tooltip>
<tooltip><Tooltip text="%ui.options.pworkers.tooltip" /></tooltip>
<valueFactory>
<SpinnerValueFactory.IntegerSpinnerValueFactory initialValue="1" max="8" min="1" />
</valueFactory>
Expand Down

0 comments on commit fb829cb

Please sign in to comment.