Skip to content

Commit

Permalink
Merge pull request #36 from alanocallaghan/java-17
Browse files Browse the repository at this point in the history
Support for QuPath 0.5 and local models
  • Loading branch information
petebankhead authored Nov 8, 2023
2 parents 281c08b + 6fb6187 commit 5874d13
Show file tree
Hide file tree
Showing 15 changed files with 334 additions and 104 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## v0.3.0

* Add support for QuPath v0.5
* Enable auto-updating (see https://github.com/qupath/qupath-extension-wsinfer/issues/40)
* Add support for local models (see https://github.com/qupath/qupath-extension-wsinfer/issues/32)

## v0.2.1

* Fix CPU/GPU support (https://github.com/qupath/qupath-extension-wsinfer/issues/41)
Expand All @@ -14,4 +20,4 @@

## v0.1.0

* First release!
* First release!
19 changes: 11 additions & 8 deletions build.gradle
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
plugins {
// Main gradle plugin for building a Java library
id 'java-library'
// To create a shadow/fat jar that bundle up all dependencies
id 'com.github.johnrengelman.shadow' version '7.1.2'
// Include this plugin to avoid downloading JavaCPP dependencies for all platforms
id 'org.bytedeco.gradle-javacpp-platform'
// Main gradle plugin for building a Java library
id 'java-library'
// To create a shadow/fat jar that bundle up all dependencies
id 'com.github.johnrengelman.shadow' version '7.1.2'
// Include this plugin to avoid downloading JavaCPP dependencies for all platforms
id 'org.bytedeco.gradle-javacpp-platform'
id 'org.openjfx.javafxplugin' version '0.1.0'
}

ext.moduleName = 'io.github.qupath.extension.wsinfer'

version = "0.2.1"
version = "0.3.0"
description = 'An extension to run WSInfer in QuPath'

// The default 'gradle.ext.qupathVersion' reads this from settings.gradle.
ext.qupathVersion = gradle.ext.qupathVersion

// Generally 11 for QuPath v0.4.3, but will be 17 for QuPath v0.5.0
ext.qupathJavaVersion = 11
ext.qupathJavaVersion = 17

def djlVersion = libs.versions.deepJavaLibrary.get()

Expand All @@ -36,6 +37,8 @@ dependencies {
// Main QuPath user interface jar.
// Automatically includes other QuPath jars as subdependencies.
implementation "io.github.qupath:qupath-gui-fx:${qupathVersion}"
implementation 'io.github.qupath:qupath-fxtras:0.1.0'
implementation 'org.commonmark:commonmark:0.21.0'

// For logging - the version comes from QuPath's version catalog at
// https://github.com/qupath/qupath/blob/main/gradle/libs.versions.toml
Expand Down
4 changes: 2 additions & 2 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pluginManagement {

rootProject.name = 'qupath-extension-wsinfer'

gradle.ext.qupathVersion = "0.4.4"
gradle.ext.qupathVersion = "0.5.0-SNAPSHOT"

dependencyResolutionManagement {

Expand All @@ -31,4 +31,4 @@ dependencyResolutionManagement {
}

}
}
}
4 changes: 2 additions & 2 deletions src/main/java/qupath/ext/wsinfer/WSInfer.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import qupath.ext.wsinfer.models.WSInferTransform;
import qupath.ext.wsinfer.models.WSInferUtils;
import qupath.ext.wsinfer.ui.WSInferPrefs;
import qupath.lib.gui.dialogs.Dialogs;
import qupath.lib.gui.tools.GuiTools;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.PixelCalibration;
Expand Down Expand Up @@ -151,7 +151,7 @@ public static void runInference(ImageData<BufferedImage> imageData, WSInferModel
public static void runInference(ImageData<BufferedImage> imageData, WSInferModel wsiModel, ProgressListener progressListener) throws InterruptedException, ModelNotFoundException, MalformedModelException, IOException, TranslateException {
Objects.requireNonNull(wsiModel, "Model cannot be null");
if (imageData == null) {
Dialogs.showNoImageError(resources.getString("title"));
GuiTools.showNoImageError(resources.getString("title"));
}

// Try to get some tiles we can use
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/qupath/ext/wsinfer/WSInferExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import qupath.ext.wsinfer.ui.WSInferCommand;
import qupath.lib.common.Version;
import qupath.lib.gui.QuPathGUI;
import qupath.lib.gui.extensions.GitHubProject;
import qupath.lib.gui.extensions.QuPathExtension;
import qupath.lib.gui.prefs.PathPrefs;

Expand All @@ -32,7 +33,7 @@
* QuPath extension to run patch-based deep learning inference with WSInfer.
* See https://wsinfer.readthedocs.io for more info.
*/
public class WSInferExtension implements QuPathExtension {
public class WSInferExtension implements QuPathExtension, GitHubProject {
private final static ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");

private final static Logger logger = LoggerFactory.getLogger(WSInferExtension.class);
Expand All @@ -48,8 +49,6 @@ public class WSInferExtension implements QuPathExtension {
private final BooleanProperty enableExtensionProperty = PathPrefs.createPersistentPreference(
"enableExtension", true);



@Override
public void installExtension(QuPathGUI qupath) {
if (isInstalled) {
Expand Down Expand Up @@ -84,4 +83,8 @@ public Version getQuPathVersion() {
return EXTENSION_QUPATH_VERSION;
}

@Override
public GitHubRepo getRepository() {
return GitHubRepo.create(getName(), "qupath", "qupath-extension-wsinfer");
}
}
53 changes: 33 additions & 20 deletions src/main/java/qupath/ext/wsinfer/models/WSInferModel.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**
/**
* Copyright 2023 University of Edinburgh
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,22 +16,21 @@

package qupath.ext.wsinfer.models;

import com.google.gson.annotations.SerializedName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.wsinfer.ui.WSInferPrefs;
import qupath.lib.io.GsonTools;

import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import com.google.gson.annotations.SerializedName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.wsinfer.ui.WSInferPrefs;
import qupath.lib.io.GsonTools;

import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

// Equivalent to config.json files from hugging face
public class WSInferModel {
Expand All @@ -42,10 +41,10 @@ public class WSInferModel {
private WSInferModelConfiguration configuration;

@SerializedName("hf_repo_id")
private String hfRepoId;
String hfRepoId;

@SerializedName("hf_revision")
private String hfRevision;
String hfRevision;

public String getName() {
return hfRepoId;
Expand Down Expand Up @@ -86,6 +85,15 @@ public File getCFFile() {
return getFile("config.json");
}


/**
* Get the configuration file. Note that it is not guaranteed that the model has been downloaded.
* @return path to model config file in cache dir
*/
public File getREADMEFile() {
return getFile("README.md");
}

/**
* Check if the model files exist and are valid.
* @return true if the files exist and the SHA matches, and the config is valid.
Expand Down Expand Up @@ -118,7 +126,7 @@ private File getFile(String f) {
return Paths.get(getModelDirectory().toString(), f).toFile();
}

private File getModelDirectory() {
File getModelDirectory() {
return Paths.get(WSInferPrefs.modelDirectoryProperty().get(), hfRepoId, hfRevision).toFile();
}

Expand Down Expand Up @@ -177,11 +185,16 @@ public synchronized void downloadModel() throws IOException {
}
downloadFileToCacheDir("torchscript_model.pt");
downloadFileToCacheDir("config.json");
downloadFileToCacheDir("README.md");

// this downloads the LFS pointer, not the actual .pt file
// the LFS pointer contains a SHA256 checksum
URL url = new URL(String.format("https://huggingface.co/%s/raw/%s/torchscript_model.pt", hfRepoId, hfRevision));
WSInferUtils.downloadURLToFile(url, getPointerFile());
if (!isValid() || !checkSHAMatches()) {
throw new IOException("Error downloading model files");
}

}

private void downloadFileToCacheDir(String file) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ public class WSInferModelCollection {
* @return
*/
public Map<String, WSInferModel> getModels() {
return Collections.unmodifiableMap(models);
return Collections.synchronizedMap(models);
}
}
69 changes: 69 additions & 0 deletions src/main/java/qupath/ext/wsinfer/models/WSInferModelLocal.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* Copyright 2023 University of Edinburgh
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package qupath.ext.wsinfer.models;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.ResourceBundle;

public class WSInferModelLocal extends WSInferModel {

private final File modelDirectory;
private static final ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");

/**
* Try to create a WSInfer model from a user directory.
* @param modelDirectory A user directory containing a
* torchscript.pt file and a config.json file
* @return A {@link WSInferModel} if the directory supplied is valid,
* otherwise nothing.
*/
public static WSInferModelLocal createInstance(File modelDirectory) throws IOException {
return new WSInferModelLocal(modelDirectory);
}

private WSInferModelLocal(File modelDirectory) throws IOException {
this.modelDirectory = modelDirectory;
this.hfRepoId = modelDirectory.getName();
List<File> files = Arrays.asList(Objects.requireNonNull(modelDirectory.listFiles()));
if (!files.contains(getCFFile())) {
throw new IOException(resources.getString("error.localModel") + ": " + getCFFile().toString());
}
if (!files.contains(getTSFile())) {
throw new IOException(resources.getString("error.localModel") + ": " + getTSFile().toString());
}
}

@Override
File getModelDirectory() {
return this.modelDirectory;
}

@Override
public boolean isValid() {
return getTSFile().exists() && getConfiguration() != null;
}

@Override
public synchronized void downloadModel() {}

@Override
public synchronized void removeCache() {}
}
26 changes: 25 additions & 1 deletion src/main/java/qupath/ext/wsinfer/models/WSInferUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package qupath.ext.wsinfer.models;

import org.apache.commons.compress.utils.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.wsinfer.ui.WSInferPrefs;
import qupath.fx.dialogs.Dialogs;
import qupath.lib.io.GsonTools;

import java.io.File;
Expand All @@ -33,13 +33,16 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import java.util.ResourceBundle;

/**
* Utility class to help with working with WSInfer models.
*/
public class WSInferUtils {

private static final Logger logger = LoggerFactory.getLogger(WSInferUtils.class);
private static final ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");

private static WSInferModelCollection cachedModelCollection;

Expand All @@ -64,9 +67,30 @@ public static WSInferModelCollection getModelCollection() {
cachedModelCollection = downloadModelCollection();
}
}
String localModelDirectory = WSInferPrefs.localDirectoryProperty().get();
if (localModelDirectory != null) {
addLocalModels(cachedModelCollection, localModelDirectory);
}
return cachedModelCollection;
}

private static void addLocalModels(WSInferModelCollection cachedModelCollection, String localModelDirectory) {
File modelDir = new File(localModelDirectory);
if (!modelDir.exists() || !modelDir.isDirectory()) {
return;
}
for (var model: Objects.requireNonNull(modelDir.listFiles())) {
if (model.isDirectory()) {
try {
var localModel = WSInferModelLocal.createInstance(model);
cachedModelCollection.getModels().put(localModel.getName(), localModel);
} catch (IOException e) {
Dialogs.showErrorNotification(resources.getString("title"), e);
}
}
}
}

/**
* Download the model collection from the hugging face repo.
* This replaces any previously cached version.
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/qupath/ext/wsinfer/ui/WSInferCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import org.slf4j.LoggerFactory;
import qupath.lib.gui.ExtensionClassLoader;
import qupath.lib.gui.QuPathGUI;
import qupath.lib.gui.dialogs.Dialogs;
import qupath.fx.dialogs.Dialogs;


import java.io.IOException;
import java.net.URL;
Expand Down Expand Up @@ -72,7 +73,7 @@ private Stage createStage() throws IOException {

// We need to use the ExtensionClassLoader to load the FXML, since it's in a different module
var loader = new FXMLLoader(url, resources);
loader.setClassLoader(QuPathGUI.getExtensionClassLoader());
loader.setClassLoader(this.getClass().getClassLoader());
VBox root = loader.load();

// There's probably a better approach... but wrapping in a border pane
Expand Down
Loading

0 comments on commit 5874d13

Please sign in to comment.