From c54a95049321de17193838d9b95b6019312bb9e4 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 15 Jan 2025 14:56:40 +0100 Subject: [PATCH] keep iterating --- .../model/stardist/AbstractStardist.java | 93 ++++++++++++++++++- .../modelrunner/model/stardist/Resizer.java | 35 ++++++- 2 files changed, 121 insertions(+), 7 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java b/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java index 1f5aa4b1..700652f9 100644 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/AbstractStardist.java @@ -6,6 +6,7 @@ import java.util.Map; import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.array.ArrayImgs; public class AbstractStardist { @@ -15,6 +16,8 @@ public class AbstractStardist { private float nmsThres; + private int[][] tileOverlap; + private void predict_instances(RandomAccessibleInterval img, Map kwargs) { Map predictKwargs = new HashMap(); if (kwargs.get("predictKwargs") != null && kwargs.get("predictKwargs") instanceof Map) @@ -72,7 +75,23 @@ private void predictSparseGenerator(RandomAccessibleInterval img, String axes, N Float probThresh) { if (probThresh == null) probThresh = this.probThres; - Map args = predictSetup(img, axes, normalizer, nTiles); + Map returns = predictSetup(img, axes, normalizer, nTiles); + + RandomAccessibleInterval x = (RandomAccessibleInterval) returns.get("x"); + List nTiles = (List) returns.get("nTiles"); + axes = (String) returns.get("axes"); + String axesNet = (String) returns.get("axesNet"); + int[] axesNetDivBy = (int[]) returns.get("axesNetDivBy"); + int[] grid = (int[]) returns.get("grid"); + Map gridDict = (Map) returns.get("gridDict"); + Resizer resizer = (Resizer) returns.get("resizer"); + int channel = (int) returns.get("channel"); + int product = nTiles.stream().reduce(1, (a, b) -> a * b); + if (product > 1) { + // TODO + } else { + + } } @@ -88,7 +107,7 @@ private Map predictSetup(RandomAccessibleInterval img, String ax String axesNet = this.config.axes; // TODO permuteAxes RandomAccessibleInterval x = null; // TODO - int channel = axesDict(axesNet).get("C"); + channel = axesDict(axesNet).get("C"); if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel]) throw new IllegalArgumentException("The number of channels of the image (" + x.dimensionsAsLongArray()[channel] + ") should be the same as the model config (" @@ -106,6 +125,76 @@ private Map predictSetup(RandomAccessibleInterval img, String ax } Resizer resizer = new Resizer(gridDict); + x = resizer.before(x, axesNet, axesNetDivBy); + Map returns = new HashMap(); + returns.put("x", x); + returns.put("axes", axes); + returns.put("axesNet", axesNet); + returns.put("axesNetDivBy", axesNetDivBy); + returns.put("permuteAxes", permuteAxes); + returns.put("resizer", resizer); + returns.put("nTiles", nTiles); + returns.put("grid", grid); + returns.put("gridDict", gridDict); + returns.put("channel", channel); + return returns; + } + + private int channel; + private Integer[] sh; + + private void tilingSetup(RandomAccessibleInterval x, List nTiles, int[] axesNetDivBy, + String axesNet, Map gridDict) { + String tilingAxes = axesNet.replace("C", ""); + int[] xTilingAxes = new int[tilingAxes.length()]; + int c = 0; + for (String a : tilingAxes.split("")) { + xTilingAxes[c ++] = axesDict(axesNet).get(a); + } + int[] axesNetTileOverlaps = axesTileOverlap(axesNet); + // TODO nTiles = permuteAxes(); + sh = new Integer[axesNet.length()]; + for (int i = 0; i < axesNet.length(); i ++) { + String a = axesNet.split("")[i]; + sh[i] = (int) Math.floorDiv(x.dimensionsAsLongArray()[i], + gridDict.keySet().contains(a) ? gridDict.get(a) : 1); + } + sh[channel] = null; + int[] nBlockOverlaps = new int[axesNetTileOverlaps.length]; + for (int i = 0; i < axesNetTileOverlaps.length; i ++) { + nBlockOverlaps[i] = (int) Math.ceil(axesNetTileOverlaps[i] / (double) axesNetDivBy[i]); + } + } + + private RandomAccessibleInterval createEmptyOutput(int nChannel) { + sh[channel] = nChannel; + long[] dims = new long[sh.length]; + for (int i = 0; i < sh.length; i ++) + dims[i] = sh[i].longValue(); + return ArrayImgs.floats(dims); + } + + private int[] axesTileOverlap(String queryAxes) { + String[] strs = Utils.axesCheckAndNormalize(queryAxes, null, null); + queryAxes = strs[0]; + if (this.tileOverlap != null) { + tileOverlap = computeReceptiveField(); + } + int i = 0; + Map overlap = new HashMap(); + for (String ax : config.axes.split("")) { + if (ax.equals("C")) + continue; + overlap.put(ax, Math.max(tileOverlap[i][0], tileOverlap[i][1])); + } + int[] arr = new int[queryAxes.length()]; + i = 0; + for (String ax : queryAxes.split("")) + arr[i ++] = overlap.keySet().contains(ax) ? overlap.get(ax) : 0; + return arr; + } + + private int[][] computeReceptiveField() { return null; } diff --git a/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java b/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java index fdb5bccb..1ba2611d 100644 --- a/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java +++ b/src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java @@ -15,7 +15,7 @@ public class Resizer { private final Map grid; - private Map pad; + private Map pad; private Map paddedShape; @@ -35,17 +35,17 @@ protected Resizer(Map grid) { String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null); axes = strs[0]; - pad = new HashMap(); + pad = new HashMap(); for (int i = 0; i < axes.length(); i ++) { long val = (axesDivBy[i] - x.dimensionsAsLongArray()[i] % axesDivBy[i]) % axesDivBy[i]; String a = axes.split("")[i]; - pad.put(a, (int) val); + pad.put(a, new int[] {0, (int) val}); } long[] minLim = x.minAsLongArray(); long[] maxLim = x.maxAsLongArray(); int i = 0; for (String a : axes.split("")) { - maxLim[ i ++] += pad.get(a); + maxLim[ i ++] += pad.get(a)[1]; } RandomAccessibleInterval xPad = Views.interval( Views.extendMirrorDouble(x), new FinalInterval( minLim, maxLim )); @@ -63,7 +63,32 @@ protected Resizer(Map grid) { after(RandomAccessibleInterval x, String axes) { String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null); axes = strs[0]; - RandomAccessibleInterval crop = null; + for (int i = 0; i < axes.length(); i ++) { + String ax = axes.split("")[i]; + long s = x.dimensionsAsLongArray()[i]; + int g = this.grid.keySet().contains(ax) ? grid.get(ax) : 1; + long sPad = this.paddedShape.keySet().contains(ax) ? paddedShape.get(ax) : s; + if (sPad != s * g) + throw new IllegalArgumentException(); + } + + int[] end = new int[axes.length()]; + for (int i = 0; i < axes.length(); i ++) { + String ax = axes.split("")[i]; + int p = pad.keySet().contains(ax) ? pad.get(ax)[1] : 0; + int g = grid.keySet().contains(ax) ? grid.get(ax) : 1; + if (p < g) + end[i] = 0; + else + end[i] = -Math.floorDiv(p, g); + } + long[] minLim = x.minAsLongArray(); + long[] maxLim = x.maxAsLongArray(); + for (int i = 0; i < maxLim.length; i ++) { + maxLim[i] += end[i]; + } + RandomAccessibleInterval crop = Views.interval( + x, new FinalInterval( minLim, maxLim )); return crop; }