diff --git a/README.md b/README.md
index 3562fed..d0a33c9 100644
--- a/README.md
+++ b/README.md
@@ -120,6 +120,45 @@ The weights file is downloaded from the URL and registered on the server. After
| Name | The SAM weights name to register. It needs to be unique in the same SAM type. |
| URL | The URL to the SAM weights file. |
+#### SAM weights catalog
+
+Here is a list of SAM weights that you can register from the URL.
+
+
+
### Tips
If you select a class in `Auto set` in the Annotations tab, it is used for a new annotation generated by SAM.
diff --git a/build.gradle b/build.gradle
index 4e233cd..9e42d32 100644
--- a/build.gradle
+++ b/build.gradle
@@ -11,7 +11,7 @@ plugins {
ext.moduleName = 'org.elephant.sam.qupath'
// TODO: Define the extension version & provide a short description
-version = "0.4.0"
+version = "0.4.1"
description = 'QuPath extension for Segment Anything Model (SAM)'
// TODO: Specify the QuPath version, compatible with the extension.
diff --git a/src/main/java/org/elephant/sam/commands/SAMMainCommand.java b/src/main/java/org/elephant/sam/commands/SAMMainCommand.java
index 03f5f5f..4bde8eb 100644
--- a/src/main/java/org/elephant/sam/commands/SAMMainCommand.java
+++ b/src/main/java/org/elephant/sam/commands/SAMMainCommand.java
@@ -575,6 +575,7 @@ public void runAutoMask() {
.cropNPointsDownscaleFactor(cropNPointsDownscaleFactorProperty.get())
.minMaskRegionArea(minMaskRegionAreaProperty.get())
.includeImageEdge(includeImageEdgeProperty.get())
+ .checkpointUrl(selectedWeightsProperty.get().getUrl())
.build();
submitTask(task);
}
@@ -730,6 +731,7 @@ private void submitDetectionTask(List foregroundObjects, List {
List detected = task.getValue();
diff --git a/src/main/java/org/elephant/sam/parameters/SAMAutoMaskParameters.java b/src/main/java/org/elephant/sam/parameters/SAMAutoMaskParameters.java
index 22e4a51..f4c0056 100644
--- a/src/main/java/org/elephant/sam/parameters/SAMAutoMaskParameters.java
+++ b/src/main/java/org/elephant/sam/parameters/SAMAutoMaskParameters.java
@@ -39,6 +39,8 @@ public class SAMAutoMaskParameters {
private String output_type;
@SuppressWarnings("unused")
private boolean include_image_edge;
+ @SuppressWarnings("unused")
+ private String checkpoint_url;
private SAMAutoMaskParameters(final Builder builder) {
Objects.requireNonNull(builder.type, "Model type must be specified");
@@ -58,6 +60,7 @@ private SAMAutoMaskParameters(final Builder builder) {
this.min_mask_region_area = builder.minMaskRegionArea;
this.output_type = builder.outputType;
this.include_image_edge = builder.includeImageEdge;
+ this.checkpoint_url = builder.checkpointUrl;
}
/**
@@ -87,6 +90,7 @@ public static class Builder {
private int minMaskRegionArea;
private String outputType;
private boolean includeImageEdge;
+ private String checkpointUrl;
private Builder(final SAMType model) {
this.type = model.modelName();
@@ -253,6 +257,17 @@ public Builder includeImageEdge(final boolean includeImageEdge) {
return this;
}
+ /**
+ * If specified, use the specified checkpoint.
+ *
+ * @param checkpointUrl
+ * @return this builder
+ */
+ public Builder checkpointUrl(final String checkpointUrl) {
+ this.checkpointUrl = checkpointUrl;
+ return this;
+ }
+
/**
* Build the parameters.
*
diff --git a/src/main/java/org/elephant/sam/parameters/SAMPromptParameters.java b/src/main/java/org/elephant/sam/parameters/SAMPromptParameters.java
index 0f0fa68..4db4a98 100644
--- a/src/main/java/org/elephant/sam/parameters/SAMPromptParameters.java
+++ b/src/main/java/org/elephant/sam/parameters/SAMPromptParameters.java
@@ -26,6 +26,8 @@ public class SAMPromptParameters {
private int[] point_labels;
@SuppressWarnings("unused")
private boolean multimask_output;
+ @SuppressWarnings("unused")
+ private String checkpoint_url;
private SAMPromptParameters(final Builder builder) {
Objects.requireNonNull(builder.type, "Model type must be specified");
@@ -53,6 +55,7 @@ private SAMPromptParameters(final Builder builder) {
ind++;
}
}
+ this.checkpoint_url = builder.checkpointUrl;
}
/**
@@ -72,6 +75,7 @@ public static class Builder {
private String b64img;
private String b64mask;
private boolean multimask_output = false;
+ private String checkpointUrl;
private Collection foreground = new LinkedHashSet<>();
private Collection background = new LinkedHashSet<>();
@@ -151,6 +155,17 @@ public Builder multimaskOutput(boolean doMultimask) {
return this;
}
+ /**
+ * URL to a checkpoint file (optional).
+ *
+ * @param checkpointUrl
+ * @return this builder
+ */
+ public Builder checkpointUrl(String checkpointUrl) {
+ this.checkpointUrl = checkpointUrl;
+ return this;
+ }
+
/**
* Build the prompt.
*
diff --git a/src/main/java/org/elephant/sam/tasks/SAMAutoMaskTask.java b/src/main/java/org/elephant/sam/tasks/SAMAutoMaskTask.java
index 109e2ea..b66865d 100644
--- a/src/main/java/org/elephant/sam/tasks/SAMAutoMaskTask.java
+++ b/src/main/java/org/elephant/sam/tasks/SAMAutoMaskTask.java
@@ -88,6 +88,8 @@ public class SAMAutoMaskTask extends Task> {
private final boolean includeImageEdge;
+ private final String checkpointUrl;
+
private SAMAutoMaskTask(Builder builder) {
this.serverURL = builder.serverURL;
Objects.requireNonNull(serverURL, "Server must not be null!");
@@ -133,6 +135,7 @@ private SAMAutoMaskTask(Builder builder) {
this.cropNPointsDownscaleFactor = builder.cropNPointsDownscaleFactor;
this.minMaskRegionArea = builder.minMaskRegionArea;
this.includeImageEdge = builder.includeImageEdge;
+ this.checkpointUrl = builder.checkpointUrl;
}
@Override
@@ -179,6 +182,7 @@ private List detectObjects()
.minMaskRegionArea(minMaskRegionArea)
.outputType(outputType.toString())
.includeImageEdge(includeImageEdge)
+ .checkpointUrl(checkpointUrl)
.build();
if (isCancelled())
@@ -270,6 +274,7 @@ public static class Builder {
private int cropNPointsDownscaleFactor = 1;
private int minMaskRegionArea = 0;
private boolean includeImageEdge = false;
+ private String checkpointUrl = null;
private Builder(QuPathViewer viewer) {
this.viewer = viewer;
@@ -503,6 +508,17 @@ public Builder includeImageEdge(final boolean includeImageEdge) {
return this;
}
+ /**
+ * If specified, use the specified checkpoint.
+ *
+ * @param checkpointUrl
+ * @return this builder
+ */
+ public Builder checkpointUrl(final String checkpointUrl) {
+ this.checkpointUrl = checkpointUrl;
+ return this;
+ }
+
/**
* Build the detection task.
*
diff --git a/src/main/java/org/elephant/sam/tasks/SAMDetectionTask.java b/src/main/java/org/elephant/sam/tasks/SAMDetectionTask.java
index 29430b3..d5ccaf9 100644
--- a/src/main/java/org/elephant/sam/tasks/SAMDetectionTask.java
+++ b/src/main/java/org/elephant/sam/tasks/SAMDetectionTask.java
@@ -71,6 +71,8 @@ public class SAMDetectionTask extends Task> {
private final SAMType model;
+ private final String checkpointUrl;
+
private SAMDetectionTask(Builder builder) {
this.serverURL = builder.serverURL;
Objects.requireNonNull(serverURL, "Server must not be null!");
@@ -106,6 +108,7 @@ private SAMDetectionTask(Builder builder) {
this.outputType = builder.outputType;
this.setName = builder.setName;
this.setRandomColor = builder.setRandomColor;
+ this.checkpointUrl = builder.checkpointUrl;
}
@Override
@@ -132,6 +135,7 @@ private List detectObjects(PathObject foregroundObject, List exten
throws InterruptedException, IOException {
SAMPromptParameters.Builder promptBuilder = SAMPromptParameters.builder(model)
+ .checkpointUrl(checkpointUrl)
.multimaskOutput(outputType != SAMOutput.SINGLE_MASK);
// Determine which part of the image we need & set foreground prompts
@@ -254,6 +258,7 @@ public static class Builder {
private SAMOutput outputType = SAMOutput.SINGLE_MASK;
private boolean setRandomColor = true;
private boolean setName = true;
+ private String checkpointUrl;
private Builder(QuPathViewer viewer) {
this.viewer = viewer;
@@ -356,6 +361,17 @@ public Builder setName(final boolean setName) {
return this;
}
+ /**
+ * Specify the checkpoint URL.
+ *
+ * @param checkpointUrl
+ * @return this builder
+ */
+ public Builder checkpointUrl(final String checkpointUrl) {
+ this.checkpointUrl = checkpointUrl;
+ return this;
+ }
+
/**
* Build the detection task.
*
diff --git a/src/main/resources/scripts/bboxSAM.groovy b/src/main/resources/scripts/bboxSAM.groovy
new file mode 100644
index 0000000..d17da40
--- /dev/null
+++ b/src/main/resources/scripts/bboxSAM.groovy
@@ -0,0 +1,42 @@
+import org.elephant.sam.entities.SAMType;
+import org.elephant.sam.entities.SAMOutput;
+import org.elephant.sam.tasks.SAMDetectionTask;
+import qupath.lib.roi.RectangleROI;
+
+// Given that you have the following data in .csv
+def data = [
+ ['id', 'x', 'y', 'w', 'h'],
+ ['0', '411', '84', '31', '33'],
+ ['1', '413', '170', '32', '37']
+]
+def fileOut = File.createTempFile('bbox', '.csv');
+fileOut.text = data*.join(',').join(System.lineSeparator());
+
+// Read the data from the file
+def fileIn = new File(fileOut.path);
+def rows = fileIn.readLines().tail()*.split(',');
+def plane = getCurrentViewer().getImagePlane();
+def bboxes = rows.collect {
+ PathObjects.createAnnotationObject(
+ ROIs.createRectangleROI(
+ it[1] as Double,
+ it[2] as Double,
+ it[3] as Double,
+ it[4] as Double,
+ plane
+ )
+ )
+}
+
+def task = SAMDetectionTask.builder(getCurrentViewer())
+ .serverURL("http://localhost:8000/sam/")
+ .addForegroundPrompts(bboxes)
+ .addBackgroundPrompts(Collections.emptyList())
+ .model(SAMType.VIT_L)
+ .outputType(SAMOutput.MULTI_SMALLEST)
+ .setName(true)
+ .setRandomColor(true)
+ .build();
+Platform.runLater(task);
+def annotations = task.get();
+addObjects(annotations);
\ No newline at end of file
diff --git a/src/main/resources/scripts/runAutomask.groovy b/src/main/resources/scripts/runAutomask.groovy
new file mode 100644
index 0000000..38f81a2
--- /dev/null
+++ b/src/main/resources/scripts/runAutomask.groovy
@@ -0,0 +1,25 @@
+import org.elephant.sam.entities.SAMType;
+import org.elephant.sam.entities.SAMOutput;
+import org.elephant.sam.tasks.SAMAutoMaskTask;
+
+def task = SAMAutoMaskTask.builder(getCurrentViewer())
+ .serverURL("http://localhost:8000/sam/")
+ .model(SAMType.VIT_L)
+ .outputType(SAMOutput.MULTI_SMALLEST)
+ .setName(true)
+ .clearCurrentObjects(true)
+ .setRandomColor(true)
+ .pointsPerSide(16)
+ .pointsPerBatch(64)
+ .predIoUThresh(0.88)
+ .stabilityScoreThresh(0.95)
+ .stabilityScoreOffset(1.0)
+ .boxNmsThresh(0.2)
+ .cropNLayers(0)
+ .cropNmsThresh(0.7)
+ .cropOverlapRatio(512 / 1500)
+ .cropNPointsDownscaleFactor(1)
+ .minMaskRegionArea(0)
+ .includeImageEdge(true)
+ .build();
+Platform.runLater(task);