diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java index 7e8bb5ca..246ae3ac 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java @@ -48,11 +48,14 @@ import io.bioimage.modelrunner.exceptions.LoadEngineException; import io.bioimage.modelrunner.exceptions.LoadModelException; import io.bioimage.modelrunner.exceptions.RunModelException; +import io.bioimage.modelrunner.model.processing.Processing; +import io.bioimage.modelrunner.model.stardist.StardistConfig; import io.bioimage.modelrunner.runmode.RunMode; import io.bioimage.modelrunner.runmode.ops.GenericOp; import io.bioimage.modelrunner.tensor.Tensor; import io.bioimage.modelrunner.tensor.Utils; import io.bioimage.modelrunner.utils.Constants; +import io.bioimage.modelrunner.utils.JSONUtils; import io.bioimage.modelrunner.versionmanagement.InstalledEngines; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.array.ArrayImgs; @@ -72,13 +75,15 @@ */ public class Stardist2D { + private String modelDir; + private ModelDescriptor descriptor; private final int channels; - private final float nms_threshold; + private Float nms_threshold; - private final float prob_threshold; + private Float prob_threshold; private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); @@ -90,11 +95,47 @@ public class Stardist2D { private static final String STARDIST2D_METHOD_NAME= "stardist_postprocessing"; - private Stardist2D() { + private static final String THRES_FNAME = "thresholds.json"; + + private static final String PROB_THRES_KEY = "thres"; + + private static final String NMS_THRES_KEY = "thres"; + + private static final float DEFAULT_NMS_THRES = (float) 0.4; + + private static final float DEFAULT_PROB_THRES = (float) 0.5; + + private Stardist2D(StardistConfig config, String modelName, String baseDir) { + modelDir = new File(baseDir, modelName).getAbsolutePath(); + findWeights(); + findThresholds(); + this.channels = 1; - // TODO get from config?? - this.nms_threshold = 0; - this.prob_threshold = 0; + } + + private void findWeights() { + + } + + private void findThresholds() { + if (new File(modelDir, THRES_FNAME).isFile()) { + try { + Map json = JSONUtils.load(modelDir + File.separator + THRES_FNAME); + if (json.get(PROB_THRES_KEY) != null && json.get(PROB_THRES_KEY) instanceof Number) + prob_threshold = ((Number) json.get(PROB_THRES_KEY)).floatValue(); + if (json.get(NMS_THRES_KEY) != null && json.get(NMS_THRES_KEY) instanceof Number) + nms_threshold = ((Number) json.get(NMS_THRES_KEY)).floatValue(); + } catch (IOException e) { + } + } + if (nms_threshold == null) { + System.out.println("Nms threshold not defined, using default value: " + DEFAULT_NMS_THRES); + nms_threshold = DEFAULT_NMS_THRES; + } + if (prob_threshold == null) { + System.out.println("Probability threshold not defined, using default value: " + DEFAULT_PROB_THRES); + prob_threshold = DEFAULT_NMS_THRES; + } } private Stardist2D(ModelDescriptor descriptor) { @@ -226,6 +267,8 @@ RandomAccessibleInterval predict(RandomAccessibleInterval image) throws Mo Model model = Model.createBioimageioModel(this.descriptor.getModelPath()); model.loadModel(); + Processing processing = Processing.init(descriptor); + inputList = processing.preprocess(inputList, false); model.runModel(inputList, outputList); return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData()))); diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java b/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java new file mode 100644 index 00000000..1f5aa4b1 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java @@ -0,0 +1,155 @@ +package io.bioimage.modelrunner.model.stardist; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import net.imglib2.RandomAccessibleInterval; + +public class AbstractStardist { + + private StardistConfig config; + + private float probThres; + + private float nmsThres; + + private void predict_instances(RandomAccessibleInterval img, Map kwargs) { + Map predictKwargs = new HashMap(); + if (kwargs.get("predictKwargs") != null && kwargs.get("predictKwargs") instanceof Map) + predictKwargs = (Map) kwargs.get("predictKwargs"); + Map nmsKwargs = new HashMap(); + if (kwargs.get("nmsKwargs") != null && kwargs.get("nmsKwargs") instanceof Map) + nmsKwargs = (Map) kwargs.get("nmsKwargs"); + boolean returnPredict = false; + boolean sparse = true; + if (kwargs.get("returnPredict") != null && kwargs.get("returnPredict") instanceof Boolean) + returnPredict = (boolean) kwargs.get("returnPredict"); + if (kwargs.get("sparse") != null && kwargs.get("sparse") instanceof Boolean) + sparse = (boolean) kwargs.get("sparse"); + if (returnPredict && sparse) + sparse = false; + String axes = ""; + if (kwargs.get("axes") != null && kwargs.get("axes") instanceof String) + axes = (String) kwargs.get("axes"); + + String _axes = normalizeAxes(img, axes); + String axesNet = config.axes; + String permuteAxes = permuteAxes(_axes, axesNet); + int[] shapeInst; + + Number scale = null; + if (kwargs.get("scale") != null && kwargs.get("scale") instanceof Number) + scale = (Number) kwargs.get("scale"); + + if (scale != null) { + // TODO + for (String ax : _axes.split("")) { + } + } + + Normalizer normalizer = null; + if (kwargs.get("normalizer") != null && kwargs.get("normalizer") instanceof Normalizer) + normalizer = (Normalizer) kwargs.get("normalizer"); + List nTiles = null; + if (kwargs.get("nTiles") != null && kwargs.get("nTiles") instanceof List) + nTiles = ((List) kwargs.get("nTiles")); + Float probThresh = null; + if (kwargs.get("probThresh") != null && kwargs.get("probThresh") instanceof Number) + probThresh = ((Number) kwargs.get("probThresh")).floatValue(); + + if (sparse) { + predictSparseGenerator(img, axes, normalizer, nTiles, probThresh); + } else { + + } + + + } + + private void predictSparseGenerator(RandomAccessibleInterval img, String axes, Normalizer normalizer, List nTiles, + Float probThresh) { + if (probThresh == null) probThresh = this.probThres; + + Map args = predictSetup(img, axes, normalizer, nTiles); + + } + + private Map predictSetup(RandomAccessibleInterval img, String axes, Normalizer normalizer, List nTiles) { + if (nTiles == null) { + nTiles = new ArrayList(); + for (int i = 0; i < img.dimensionsAsLongArray().length; i ++) nTiles.add(1); + } + if (nTiles.size() != img.dimensionsAsLongArray().length) + throw new IllegalArgumentException("The number of image dimensions (" + img.dimensionsAsLongArray().length + + ") should be the same as the tile list lenght (" + nTiles.size() + ")."); + axes = normalizeAxes(img, axes); + String axesNet = this.config.axes; + // TODO permuteAxes + RandomAccessibleInterval x = null; // TODO + int channel = axesDict(axesNet).get("C"); + if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel]) + throw new IllegalArgumentException("The number of channels of the image (" + + x.dimensionsAsLongArray()[channel] + ") should be the same as the model config (" + + config.nChannelIn + ")."); + int[] axesNetDivBy = axesDivBy(axesNet); + int[] grid = config.grid; + if (grid.length != axesNet.length() - 1) + throw new IllegalArgumentException(); + Map gridDict = new HashMap(); + int i = 0; + for (String a : axesNet.toUpperCase().split("")) { + if (a == "C") + continue; + gridDict.put(a, grid[i ++]); + } + + Resizer resizer = new Resizer(gridDict); + return null; + } + + private int[] axesDivBy(String queryAxes) { + if (this.config.backbone.equals("unet")) + throw new IllegalArgumentException("Backbone '" + config.backbone + "' not implemented."); + String[] strs = Utils.axesCheckAndNormalize(queryAxes, null, null); + queryAxes = strs[0]; + + if (config.unet_pool.length != config.grid.length) + throw new IllegalArgumentException(); + int i = 0; + Map divBy = new HashMap(); + for (String a : config.axes.split("")) { + if (a.toUpperCase().equals("C")) + continue; + int val = (int) (Math.pow(config.unet_pool[i], config.unet_n_depth) * config.grid[i]); + divBy.put(a.toUpperCase(), val); + i ++; + } + int[] arr = new int[queryAxes.length()]; + i = 0; + for (String a : queryAxes.split("")) { + arr[i ++] = divBy.keySet().contains(a) ? divBy.get(a) : 1; + } + return arr; + } + + private Map axesDict(String axes) { + String[] strs = Utils.axesCheckAndNormalize(axes, null, null); + axes = strs[0]; + String allowed = strs[1]; + Map map = new HashMap(); + for (String a : allowed.split("")) + map.put(a, axes.indexOf(a) == -1 ? null : axes.indexOf(a)); + return map; + } + + private String normalizeAxes(RandomAccessibleInterval img, String axes) { + return null; + } + + private String permuteAxes(String axes, String axesNet) { + return null; + } + +} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java b/src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java new file mode 100644 index 00000000..1988cd04 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/Normalizer.java @@ -0,0 +1,5 @@ +package io.bioimage.modelrunner.model.stardist; + +public class Normalizer { + +} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java b/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java new file mode 100644 index 00000000..fdb5bccb --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java @@ -0,0 +1,73 @@ +package io.bioimage.modelrunner.model.stardist; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import net.imglib2.FinalInterval; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.view.Views; + +public class Resizer { + + + private final Map grid; + + private Map pad; + + private Map paddedShape; + + protected Resizer(Map grid) { + this.grid = grid; + } + + protected & RealType> RandomAccessibleInterval + before(RandomAccessibleInterval x, String axes, int[] axesDivBy) { + for (int i = 0; i < axes.length(); i ++) { + String ax = axes.split("")[i]; + int g = grid.keySet().contains(ax) ? grid.get(ax) : 1; + int a = axesDivBy[i]; + if (a % g != 0) + throw new IllegalArgumentException(); + } + + String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null); + axes = strs[0]; + pad = new HashMap(); + for (int i = 0; i < axes.length(); i ++) { + long val = (axesDivBy[i] - x.dimensionsAsLongArray()[i] % axesDivBy[i]) % axesDivBy[i]; + String a = axes.split("")[i]; + pad.put(a, (int) val); + } + long[] minLim = x.minAsLongArray(); + long[] maxLim = x.maxAsLongArray(); + int i = 0; + for (String a : axes.split("")) { + maxLim[ i ++] += pad.get(a); + } + RandomAccessibleInterval xPad = Views.interval( + Views.extendMirrorDouble(x), new FinalInterval( minLim, maxLim )); + paddedShape = new HashMap(); + for (int j = 0; j < axes.length(); j ++) { + String ax = axes.split("")[j]; + if (ax.toUpperCase().equals("C")) + continue; + paddedShape.put(ax.toUpperCase(), (int) xPad.dimensionsAsLongArray()[j]); + } + return xPad; + } + + protected & RealType> RandomAccessibleInterval + after(RandomAccessibleInterval x, String axes) { + String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null); + axes = strs[0]; + RandomAccessibleInterval crop = null; + return crop; + } + + protected void filterPoints() { + + } +} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java b/src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java new file mode 100644 index 00000000..ce5241fb --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/StardistConfig.java @@ -0,0 +1,38 @@ +/*- + * #%L + * Use deep learning frameworks from Java in an agnostic and isolated way. + * %% + * Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.model.stardist; + + +/** + * Implementation of the Stardist Config instance in Java. + * + *@author Carlos Garcia + */ +public class StardistConfig { + + public int[] unet_pool; + public String axes; + public double unet_n_depth; + public int[] grid; + public int nChannelIn; + public String backbone; + + +} diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java b/src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java new file mode 100644 index 00000000..0b10a384 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/Utils.java @@ -0,0 +1,21 @@ +package io.bioimage.modelrunner.model.stardist; + +public class Utils { + + protected static String[] axesCheckAndNormalize(String axes, Integer length, String disallowed) { + String allowed = "STCZYX"; + if (axes == null) + throw new IllegalArgumentException("Axes cannot be null"); + axes = axes.toUpperCase(); + for (String a : axes.split("")) { + if (!allowed.contains(a)) + throw new IllegalArgumentException("Invalid axis: " + a + ", it must be one of " + allowed); + if (axes.replace(a, "").length() + 1 != axes.length()) + throw new IllegalArgumentException("Invalid axis: " + a + " can only appear once."); + } + if (length != null && axes.length() != length) + throw new IllegalArgumentException("Axes (" + axes + ") must have length " + length); + return new String[] {axes, allowed}; + } + +}