Skip to content

Commit

Permalink
Merge pull request #11 from ksugar/dev
Browse files Browse the repository at this point in the history
0.4.0 -> 0.4.1
  • Loading branch information
ksugar authored Oct 17, 2023
2 parents 51b4513 + 882a426 commit 9997b67
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 1 deletion.
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,45 @@ The weights file is downloaded from the URL and registered on the server. After
| Name | The SAM weights name to register. It needs to be unique in the same SAM type. |
| URL | The URL to the SAM weights file. |

#### SAM weights catalog

Here is a list of SAM weights that you can register from the URL.

<table>
<thead>
<tr>
<th>Type</th>
<th>Name (customizable)</th>
<th>URL</th>
<th>Citation</th>
</tr>
</thead>
<tbody>
<tr>
<td>vit_h</td>
<td>vit_h_lm</td>
<td><a href="https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1">https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1</a></td>
<td rowspan="4">Archit, A. et al. <a href="https://doi.org/10.1101/2023.08.21.554208">Segment Anything for
Microscopy.</a> bioRxiv 2023. doi:10.1101/2023.08.21.554208<br><br><a href="https://github.com/computational-cell-analytics/micro-sam">https://github.com/computational-cell-analytics/micro-sam</a></td>
</tr>
<tr>
<td>vit_b</td>
<td>vit_b_lm</td>
<td><a href="https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1">https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1</a></td>
</tr>
<tr>
<td>vit_h</td>
<td>vit_h_em</td>
<td><a href="https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1">https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1</a></td>
</tr>
<tr>
<td>vit_b</td>
<td>vit_b_em</td>
<td><a href="https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1">https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1</a></td>
</tr>
</tbody>
</table>

### Tips

If you select a class in `Auto set` in the Annotations tab, it is used for a new annotation generated by SAM.
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ plugins {
ext.moduleName = 'org.elephant.sam.qupath'

// TODO: Define the extension version & provide a short description
version = "0.4.0"
version = "0.4.1"
description = 'QuPath extension for Segment Anything Model (SAM)'

// TODO: Specify the QuPath version, compatible with the extension.
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/elephant/sam/commands/SAMMainCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ public void runAutoMask() {
.cropNPointsDownscaleFactor(cropNPointsDownscaleFactorProperty.get())
.minMaskRegionArea(minMaskRegionAreaProperty.get())
.includeImageEdge(includeImageEdgeProperty.get())
.checkpointUrl(selectedWeightsProperty.get().getUrl())
.build();
submitTask(task);
}
Expand Down Expand Up @@ -730,6 +731,7 @@ private void submitDetectionTask(List<PathObject> foregroundObjects, List<PathOb
.outputType(outputTypeProperty.get())
.setName(setNamesProperty.get())
.setRandomColor(useRandomColorsProperty.get())
.checkpointUrl(selectedWeightsProperty.get().getUrl())
.build();
task.setOnSucceeded(event -> {
List<PathObject> detected = task.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public class SAMAutoMaskParameters {
private String output_type;
@SuppressWarnings("unused")
private boolean include_image_edge;
@SuppressWarnings("unused")
private String checkpoint_url;

private SAMAutoMaskParameters(final Builder builder) {
Objects.requireNonNull(builder.type, "Model type must be specified");
Expand All @@ -58,6 +60,7 @@ private SAMAutoMaskParameters(final Builder builder) {
this.min_mask_region_area = builder.minMaskRegionArea;
this.output_type = builder.outputType;
this.include_image_edge = builder.includeImageEdge;
this.checkpoint_url = builder.checkpointUrl;
}

/**
Expand Down Expand Up @@ -87,6 +90,7 @@ public static class Builder {
private int minMaskRegionArea;
private String outputType;
private boolean includeImageEdge;
private String checkpointUrl;

private Builder(final SAMType model) {
this.type = model.modelName();
Expand Down Expand Up @@ -253,6 +257,17 @@ public Builder includeImageEdge(final boolean includeImageEdge) {
return this;
}

/**
* If specified, use the specified checkpoint.
*
* @param checkpointUrl
* @return this builder
*/
public Builder checkpointUrl(final String checkpointUrl) {
this.checkpointUrl = checkpointUrl;
return this;
}

/**
* Build the parameters.
*
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/elephant/sam/parameters/SAMPromptParameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public class SAMPromptParameters {
private int[] point_labels;
@SuppressWarnings("unused")
private boolean multimask_output;
@SuppressWarnings("unused")
private String checkpoint_url;

private SAMPromptParameters(final Builder builder) {
Objects.requireNonNull(builder.type, "Model type must be specified");
Expand Down Expand Up @@ -53,6 +55,7 @@ private SAMPromptParameters(final Builder builder) {
ind++;
}
}
this.checkpoint_url = builder.checkpointUrl;
}

/**
Expand All @@ -72,6 +75,7 @@ public static class Builder {
private String b64img;
private String b64mask;
private boolean multimask_output = false;
private String checkpointUrl;

private Collection<Coordinate> foreground = new LinkedHashSet<>();
private Collection<Coordinate> background = new LinkedHashSet<>();
Expand Down Expand Up @@ -151,6 +155,17 @@ public Builder multimaskOutput(boolean doMultimask) {
return this;
}

/**
* URL to a checkpoint file (optional).
*
* @param checkpointUrl
* @return this builder
*/
public Builder checkpointUrl(String checkpointUrl) {
this.checkpointUrl = checkpointUrl;
return this;
}

/**
* Build the prompt.
*
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/elephant/sam/tasks/SAMAutoMaskTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public class SAMAutoMaskTask extends Task<List<PathObject>> {

private final boolean includeImageEdge;

private final String checkpointUrl;

private SAMAutoMaskTask(Builder builder) {
this.serverURL = builder.serverURL;
Objects.requireNonNull(serverURL, "Server must not be null!");
Expand Down Expand Up @@ -133,6 +135,7 @@ private SAMAutoMaskTask(Builder builder) {
this.cropNPointsDownscaleFactor = builder.cropNPointsDownscaleFactor;
this.minMaskRegionArea = builder.minMaskRegionArea;
this.includeImageEdge = builder.includeImageEdge;
this.checkpointUrl = builder.checkpointUrl;
}

@Override
Expand Down Expand Up @@ -179,6 +182,7 @@ private List<PathObject> detectObjects()
.minMaskRegionArea(minMaskRegionArea)
.outputType(outputType.toString())
.includeImageEdge(includeImageEdge)
.checkpointUrl(checkpointUrl)
.build();

if (isCancelled())
Expand Down Expand Up @@ -270,6 +274,7 @@ public static class Builder {
private int cropNPointsDownscaleFactor = 1;
private int minMaskRegionArea = 0;
private boolean includeImageEdge = false;
private String checkpointUrl = null;

private Builder(QuPathViewer viewer) {
this.viewer = viewer;
Expand Down Expand Up @@ -503,6 +508,17 @@ public Builder includeImageEdge(final boolean includeImageEdge) {
return this;
}

/**
* If specified, use the specified checkpoint.
*
* @param checkpointUrl
* @return this builder
*/
public Builder checkpointUrl(final String checkpointUrl) {
this.checkpointUrl = checkpointUrl;
return this;
}

/**
* Build the detection task.
*
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/elephant/sam/tasks/SAMDetectionTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ public class SAMDetectionTask extends Task<List<PathObject>> {

private final SAMType model;

private final String checkpointUrl;

private SAMDetectionTask(Builder builder) {
this.serverURL = builder.serverURL;
Objects.requireNonNull(serverURL, "Server must not be null!");
Expand Down Expand Up @@ -106,6 +108,7 @@ private SAMDetectionTask(Builder builder) {
this.outputType = builder.outputType;
this.setName = builder.setName;
this.setRandomColor = builder.setRandomColor;
this.checkpointUrl = builder.checkpointUrl;
}

@Override
Expand All @@ -132,6 +135,7 @@ private List<PathObject> detectObjects(PathObject foregroundObject, List<? exten
throws InterruptedException, IOException {

SAMPromptParameters.Builder promptBuilder = SAMPromptParameters.builder(model)
.checkpointUrl(checkpointUrl)
.multimaskOutput(outputType != SAMOutput.SINGLE_MASK);

// Determine which part of the image we need & set foreground prompts
Expand Down Expand Up @@ -254,6 +258,7 @@ public static class Builder {
private SAMOutput outputType = SAMOutput.SINGLE_MASK;
private boolean setRandomColor = true;
private boolean setName = true;
private String checkpointUrl;

private Builder(QuPathViewer viewer) {
this.viewer = viewer;
Expand Down Expand Up @@ -356,6 +361,17 @@ public Builder setName(final boolean setName) {
return this;
}

/**
* Specify the checkpoint URL.
*
* @param checkpointUrl
* @return this builder
*/
public Builder checkpointUrl(final String checkpointUrl) {
this.checkpointUrl = checkpointUrl;
return this;
}

/**
* Build the detection task.
*
Expand Down
42 changes: 42 additions & 0 deletions src/main/resources/scripts/bboxSAM.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import org.elephant.sam.entities.SAMType;
import org.elephant.sam.entities.SAMOutput;
import org.elephant.sam.tasks.SAMDetectionTask;
import qupath.lib.roi.RectangleROI;

// Given that you have the following data in .csv
def data = [
['id', 'x', 'y', 'w', 'h'],
['0', '411', '84', '31', '33'],
['1', '413', '170', '32', '37']
]
def fileOut = File.createTempFile('bbox', '.csv');
fileOut.text = data*.join(',').join(System.lineSeparator());

// Read the data from the file
def fileIn = new File(fileOut.path);
def rows = fileIn.readLines().tail()*.split(',');
def plane = getCurrentViewer().getImagePlane();
def bboxes = rows.collect {
PathObjects.createAnnotationObject(
ROIs.createRectangleROI(
it[1] as Double,
it[2] as Double,
it[3] as Double,
it[4] as Double,
plane
)
)
}

def task = SAMDetectionTask.builder(getCurrentViewer())
.serverURL("http://localhost:8000/sam/")
.addForegroundPrompts(bboxes)
.addBackgroundPrompts(Collections.emptyList())
.model(SAMType.VIT_L)
.outputType(SAMOutput.MULTI_SMALLEST)
.setName(true)
.setRandomColor(true)
.build();
Platform.runLater(task);
def annotations = task.get();
addObjects(annotations);
25 changes: 25 additions & 0 deletions src/main/resources/scripts/runAutomask.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import org.elephant.sam.entities.SAMType;
import org.elephant.sam.entities.SAMOutput;
import org.elephant.sam.tasks.SAMAutoMaskTask;

def task = SAMAutoMaskTask.builder(getCurrentViewer())
.serverURL("http://localhost:8000/sam/")
.model(SAMType.VIT_L)
.outputType(SAMOutput.MULTI_SMALLEST)
.setName(true)
.clearCurrentObjects(true)
.setRandomColor(true)
.pointsPerSide(16)
.pointsPerBatch(64)
.predIoUThresh(0.88)
.stabilityScoreThresh(0.95)
.stabilityScoreOffset(1.0)
.boxNmsThresh(0.2)
.cropNLayers(0)
.cropNmsThresh(0.7)
.cropOverlapRatio(512 / 1500)
.cropNPointsDownscaleFactor(1)
.minMaskRegionArea(0)
.includeImageEdge(true)
.build();
Platform.runLater(task);

0 comments on commit 9997b67

Please sign in to comment.