Skip to content

Commit

Permalink
start creating new stardist wrapper in java with appose
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 15, 2025
1 parent d2a65f2 commit 2af780d
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 6 deletions.
55 changes: 49 additions & 6 deletions src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,14 @@
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.model.stardist.StardistConfig;
import io.bioimage.modelrunner.runmode.RunMode;
import io.bioimage.modelrunner.runmode.ops.GenericOp;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.JSONUtils;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
Expand All @@ -72,13 +75,15 @@
*/
public class Stardist2D {

private String modelDir;

private ModelDescriptor descriptor;

private final int channels;

private final float nms_threshold;
private Float nms_threshold;

private final float prob_threshold;
private Float prob_threshold;

private static final List<String> STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});

Expand All @@ -90,11 +95,47 @@ public class Stardist2D {

private static final String STARDIST2D_METHOD_NAME= "stardist_postprocessing";

private Stardist2D() {
private static final String THRES_FNAME = "thresholds.json";

private static final String PROB_THRES_KEY = "thres";

private static final String NMS_THRES_KEY = "thres";

private static final float DEFAULT_NMS_THRES = (float) 0.4;

private static final float DEFAULT_PROB_THRES = (float) 0.5;

private Stardist2D(StardistConfig config, String modelName, String baseDir) {
modelDir = new File(baseDir, modelName).getAbsolutePath();
findWeights();
findThresholds();

this.channels = 1;
// TODO get from config??
this.nms_threshold = 0;
this.prob_threshold = 0;
}

private void findWeights() {

}

private void findThresholds() {
if (new File(modelDir, THRES_FNAME).isFile()) {
try {
Map<String, Object> json = JSONUtils.load(modelDir + File.separator + THRES_FNAME);
if (json.get(PROB_THRES_KEY) != null && json.get(PROB_THRES_KEY) instanceof Number)
prob_threshold = ((Number) json.get(PROB_THRES_KEY)).floatValue();
if (json.get(NMS_THRES_KEY) != null && json.get(NMS_THRES_KEY) instanceof Number)
nms_threshold = ((Number) json.get(NMS_THRES_KEY)).floatValue();
} catch (IOException e) {
}
}
if (nms_threshold == null) {
System.out.println("Nms threshold not defined, using default value: " + DEFAULT_NMS_THRES);
nms_threshold = DEFAULT_NMS_THRES;
}
if (prob_threshold == null) {
System.out.println("Probability threshold not defined, using default value: " + DEFAULT_PROB_THRES);
prob_threshold = DEFAULT_NMS_THRES;
}
}

private Stardist2D(ModelDescriptor descriptor) {
Expand Down Expand Up @@ -226,6 +267,8 @@ RandomAccessibleInterval<T> predict(RandomAccessibleInterval<T> image) throws Mo

Model model = Model.createBioimageioModel(this.descriptor.getModelPath());
model.loadModel();
Processing processing = Processing.init(descriptor);
inputList = processing.preprocess(inputList, false);
model.runModel(inputList, outputList);

return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData())));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package io.bioimage.modelrunner.model.stardist;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import net.imglib2.RandomAccessibleInterval;

public class AbstractStardist {

private StardistConfig config;

private float probThres;

private float nmsThres;

private void predict_instances(RandomAccessibleInterval img, Map<String, Object> kwargs) {
Map<String, Object> predictKwargs = new HashMap<String, Object>();
if (kwargs.get("predictKwargs") != null && kwargs.get("predictKwargs") instanceof Map)
predictKwargs = (Map<String, Object>) kwargs.get("predictKwargs");
Map<String, Object> nmsKwargs = new HashMap<String, Object>();
if (kwargs.get("nmsKwargs") != null && kwargs.get("nmsKwargs") instanceof Map)
nmsKwargs = (Map<String, Object>) kwargs.get("nmsKwargs");
boolean returnPredict = false;
boolean sparse = true;
if (kwargs.get("returnPredict") != null && kwargs.get("returnPredict") instanceof Boolean)
returnPredict = (boolean) kwargs.get("returnPredict");
if (kwargs.get("sparse") != null && kwargs.get("sparse") instanceof Boolean)
sparse = (boolean) kwargs.get("sparse");
if (returnPredict && sparse)
sparse = false;
String axes = "";
if (kwargs.get("axes") != null && kwargs.get("axes") instanceof String)
axes = (String) kwargs.get("axes");

String _axes = normalizeAxes(img, axes);
String axesNet = config.axes;
String permuteAxes = permuteAxes(_axes, axesNet);
int[] shapeInst;

Number scale = null;
if (kwargs.get("scale") != null && kwargs.get("scale") instanceof Number)
scale = (Number) kwargs.get("scale");

if (scale != null) {
// TODO
for (String ax : _axes.split("")) {
}
}

Normalizer normalizer = null;
if (kwargs.get("normalizer") != null && kwargs.get("normalizer") instanceof Normalizer)
normalizer = (Normalizer) kwargs.get("normalizer");
List<Integer> nTiles = null;
if (kwargs.get("nTiles") != null && kwargs.get("nTiles") instanceof List)
nTiles = ((List<Integer>) kwargs.get("nTiles"));
Float probThresh = null;
if (kwargs.get("probThresh") != null && kwargs.get("probThresh") instanceof Number)
probThresh = ((Number) kwargs.get("probThresh")).floatValue();

if (sparse) {
predictSparseGenerator(img, axes, normalizer, nTiles, probThresh);
} else {

}


}

private void predictSparseGenerator(RandomAccessibleInterval img, String axes, Normalizer normalizer, List<Integer> nTiles,
Float probThresh) {
if (probThresh == null) probThresh = this.probThres;

Map<String, Object> args = predictSetup(img, axes, normalizer, nTiles);

}

private Map<String, Object> predictSetup(RandomAccessibleInterval img, String axes, Normalizer normalizer, List<Integer> nTiles) {
if (nTiles == null) {
nTiles = new ArrayList<Integer>();
for (int i = 0; i < img.dimensionsAsLongArray().length; i ++) nTiles.add(1);
}
if (nTiles.size() != img.dimensionsAsLongArray().length)
throw new IllegalArgumentException("The number of image dimensions (" + img.dimensionsAsLongArray().length
+ ") should be the same as the tile list lenght (" + nTiles.size() + ").");
axes = normalizeAxes(img, axes);
String axesNet = this.config.axes;
// TODO permuteAxes
RandomAccessibleInterval x = null; // TODO
int channel = axesDict(axesNet).get("C");
if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel])
throw new IllegalArgumentException("The number of channels of the image ("
+ x.dimensionsAsLongArray()[channel] + ") should be the same as the model config ("
+ config.nChannelIn + ").");
int[] axesNetDivBy = axesDivBy(axesNet);
int[] grid = config.grid;
if (grid.length != axesNet.length() - 1)
throw new IllegalArgumentException();
Map<String, Integer> gridDict = new HashMap<String, Integer>();
int i = 0;
for (String a : axesNet.toUpperCase().split("")) {
if (a == "C")
continue;
gridDict.put(a, grid[i ++]);
}

Resizer resizer = new Resizer(gridDict);
return null;
}

private int[] axesDivBy(String queryAxes) {
if (this.config.backbone.equals("unet"))
throw new IllegalArgumentException("Backbone '" + config.backbone + "' not implemented.");
String[] strs = Utils.axesCheckAndNormalize(queryAxes, null, null);
queryAxes = strs[0];

if (config.unet_pool.length != config.grid.length)
throw new IllegalArgumentException();
int i = 0;
Map<String, Integer> divBy = new HashMap<String, Integer>();
for (String a : config.axes.split("")) {
if (a.toUpperCase().equals("C"))
continue;
int val = (int) (Math.pow(config.unet_pool[i], config.unet_n_depth) * config.grid[i]);
divBy.put(a.toUpperCase(), val);
i ++;
}
int[] arr = new int[queryAxes.length()];
i = 0;
for (String a : queryAxes.split("")) {
arr[i ++] = divBy.keySet().contains(a) ? divBy.get(a) : 1;
}
return arr;
}

private Map<String, Integer> axesDict(String axes) {
String[] strs = Utils.axesCheckAndNormalize(axes, null, null);
axes = strs[0];
String allowed = strs[1];
Map<String, Integer> map = new HashMap<String, Integer>();
for (String a : allowed.split(""))
map.put(a, axes.indexOf(a) == -1 ? null : axes.indexOf(a));
return map;
}

private String normalizeAxes(RandomAccessibleInterval img, String axes) {
return null;
}

private String permuteAxes(String axes, String axesNet) {
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.bioimage.modelrunner.model.stardist;

public class Normalizer {

}
73 changes: 73 additions & 0 deletions src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.bioimage.modelrunner.model.stardist;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import net.imglib2.FinalInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.Views;

public class Resizer {


private final Map<String, Integer> grid;

private Map<String, Integer> pad;

private Map<String, Integer> paddedShape;

protected Resizer(Map<String, Integer> grid) {
this.grid = grid;
}

protected <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T>
before(RandomAccessibleInterval<T> x, String axes, int[] axesDivBy) {
for (int i = 0; i < axes.length(); i ++) {
String ax = axes.split("")[i];
int g = grid.keySet().contains(ax) ? grid.get(ax) : 1;
int a = axesDivBy[i];
if (a % g != 0)
throw new IllegalArgumentException();
}

String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null);
axes = strs[0];
pad = new HashMap<String, Integer>();
for (int i = 0; i < axes.length(); i ++) {
long val = (axesDivBy[i] - x.dimensionsAsLongArray()[i] % axesDivBy[i]) % axesDivBy[i];
String a = axes.split("")[i];
pad.put(a, (int) val);
}
long[] minLim = x.minAsLongArray();
long[] maxLim = x.maxAsLongArray();
int i = 0;
for (String a : axes.split("")) {
maxLim[ i ++] += pad.get(a);
}
RandomAccessibleInterval<T> xPad = Views.interval(
Views.extendMirrorDouble(x), new FinalInterval( minLim, maxLim ));
paddedShape = new HashMap<String, Integer>();
for (int j = 0; j < axes.length(); j ++) {
String ax = axes.split("")[j];
if (ax.toUpperCase().equals("C"))
continue;
paddedShape.put(ax.toUpperCase(), (int) xPad.dimensionsAsLongArray()[j]);
}
return xPad;
}

protected <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T>
after(RandomAccessibleInterval<T> x, String axes) {
String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null);
axes = strs[0];
RandomAccessibleInterval<T> crop = null;
return crop;
}

protected void filterPoints() {

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*-
* #%L
* Use deep learning frameworks from Java in an agnostic and isolated way.
* %%
* Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers.
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package io.bioimage.modelrunner.model.stardist;


/**
* Implementation of the Stardist Config instance in Java.
*
*@author Carlos Garcia
*/
public class StardistConfig {

public int[] unet_pool;
public String axes;
public double unet_n_depth;
public int[] grid;
public int nChannelIn;
public String backbone;


}
21 changes: 21 additions & 0 deletions src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.bioimage.modelrunner.model.stardist;

public class Utils {

protected static String[] axesCheckAndNormalize(String axes, Integer length, String disallowed) {
String allowed = "STCZYX";
if (axes == null)
throw new IllegalArgumentException("Axes cannot be null");
axes = axes.toUpperCase();
for (String a : axes.split("")) {
if (!allowed.contains(a))
throw new IllegalArgumentException("Invalid axis: " + a + ", it must be one of " + allowed);
if (axes.replace(a, "").length() + 1 != axes.length())
throw new IllegalArgumentException("Invalid axis: " + a + " can only appear once.");
}
if (length != null && axes.length() != length)
throw new IllegalArgumentException("Axes (" + axes + ") must have length " + length);
return new String[] {axes, allowed};
}

}

0 comments on commit 2af780d

Please sign in to comment.