Skip to content

Commit

Permalink
keep iterating
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 15, 2025
1 parent 2af780d commit c54a950
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.Map;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;

public class AbstractStardist {

Expand All @@ -15,6 +16,8 @@ public class AbstractStardist {

private float nmsThres;

private int[][] tileOverlap;

private void predict_instances(RandomAccessibleInterval img, Map<String, Object> kwargs) {
Map<String, Object> predictKwargs = new HashMap<String, Object>();
if (kwargs.get("predictKwargs") != null && kwargs.get("predictKwargs") instanceof Map)
Expand Down Expand Up @@ -72,7 +75,23 @@ private void predictSparseGenerator(RandomAccessibleInterval img, String axes, N
Float probThresh) {
if (probThresh == null) probThresh = this.probThres;

Map<String, Object> args = predictSetup(img, axes, normalizer, nTiles);
Map<String, Object> returns = predictSetup(img, axes, normalizer, nTiles);

RandomAccessibleInterval x = (RandomAccessibleInterval) returns.get("x");
List<Integer> nTiles = (List<Integer>) returns.get("nTiles");
axes = (String) returns.get("axes");
String axesNet = (String) returns.get("axesNet");
int[] axesNetDivBy = (int[]) returns.get("axesNetDivBy");
int[] grid = (int[]) returns.get("grid");
Map<String, Integer> gridDict = (Map<String, Integer>) returns.get("gridDict");
Resizer resizer = (Resizer) returns.get("resizer");
int channel = (int) returns.get("channel");
int product = nTiles.stream().reduce(1, (a, b) -> a * b);
if (product > 1) {
// TODO
} else {

}

}

Expand All @@ -88,7 +107,7 @@ private Map<String, Object> predictSetup(RandomAccessibleInterval img, String ax
String axesNet = this.config.axes;
// TODO permuteAxes
RandomAccessibleInterval x = null; // TODO
int channel = axesDict(axesNet).get("C");
channel = axesDict(axesNet).get("C");
if (this.config.nChannelIn != x.dimensionsAsLongArray()[channel])
throw new IllegalArgumentException("The number of channels of the image ("
+ x.dimensionsAsLongArray()[channel] + ") should be the same as the model config ("
Expand All @@ -106,6 +125,76 @@ private Map<String, Object> predictSetup(RandomAccessibleInterval img, String ax
}

Resizer resizer = new Resizer(gridDict);
x = resizer.before(x, axesNet, axesNetDivBy);
Map<String, Object> returns = new HashMap<String, Object>();
returns.put("x", x);
returns.put("axes", axes);
returns.put("axesNet", axesNet);
returns.put("axesNetDivBy", axesNetDivBy);
returns.put("permuteAxes", permuteAxes);
returns.put("resizer", resizer);
returns.put("nTiles", nTiles);
returns.put("grid", grid);
returns.put("gridDict", gridDict);
returns.put("channel", channel);
return returns;
}

private int channel;
private Integer[] sh;

private void tilingSetup(RandomAccessibleInterval x, List<Integer> nTiles, int[] axesNetDivBy,
String axesNet, Map<String, Integer> gridDict) {
String tilingAxes = axesNet.replace("C", "");
int[] xTilingAxes = new int[tilingAxes.length()];
int c = 0;
for (String a : tilingAxes.split("")) {
xTilingAxes[c ++] = axesDict(axesNet).get(a);
}
int[] axesNetTileOverlaps = axesTileOverlap(axesNet);
// TODO nTiles = permuteAxes();
sh = new Integer[axesNet.length()];
for (int i = 0; i < axesNet.length(); i ++) {
String a = axesNet.split("")[i];
sh[i] = (int) Math.floorDiv(x.dimensionsAsLongArray()[i],
gridDict.keySet().contains(a) ? gridDict.get(a) : 1);
}
sh[channel] = null;
int[] nBlockOverlaps = new int[axesNetTileOverlaps.length];
for (int i = 0; i < axesNetTileOverlaps.length; i ++) {
nBlockOverlaps[i] = (int) Math.ceil(axesNetTileOverlaps[i] / (double) axesNetDivBy[i]);
}
}

private RandomAccessibleInterval createEmptyOutput(int nChannel) {
sh[channel] = nChannel;
long[] dims = new long[sh.length];
for (int i = 0; i < sh.length; i ++)
dims[i] = sh[i].longValue();
return ArrayImgs.floats(dims);
}

private int[] axesTileOverlap(String queryAxes) {
String[] strs = Utils.axesCheckAndNormalize(queryAxes, null, null);
queryAxes = strs[0];
if (this.tileOverlap != null) {
tileOverlap = computeReceptiveField();
}
int i = 0;
Map<String, Integer> overlap = new HashMap<String, Integer>();
for (String ax : config.axes.split("")) {
if (ax.equals("C"))
continue;
overlap.put(ax, Math.max(tileOverlap[i][0], tileOverlap[i][1]));
}
int[] arr = new int[queryAxes.length()];
i = 0;
for (String ax : queryAxes.split(""))
arr[i ++] = overlap.keySet().contains(ax) ? overlap.get(ax) : 0;
return arr;
}

private int[][] computeReceptiveField() {
return null;
}

Expand Down
35 changes: 30 additions & 5 deletions src/main/java/io/bioimage/modelrunner/model/stardist/Resizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class Resizer {

private final Map<String, Integer> grid;

private Map<String, Integer> pad;
private Map<String, int[]> pad;

private Map<String, Integer> paddedShape;

Expand All @@ -35,17 +35,17 @@ protected Resizer(Map<String, Integer> grid) {

String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null);
axes = strs[0];
pad = new HashMap<String, Integer>();
pad = new HashMap<String, int[]>();
for (int i = 0; i < axes.length(); i ++) {
long val = (axesDivBy[i] - x.dimensionsAsLongArray()[i] % axesDivBy[i]) % axesDivBy[i];
String a = axes.split("")[i];
pad.put(a, (int) val);
pad.put(a, new int[] {0, (int) val});
}
long[] minLim = x.minAsLongArray();
long[] maxLim = x.maxAsLongArray();
int i = 0;
for (String a : axes.split("")) {
maxLim[ i ++] += pad.get(a);
maxLim[ i ++] += pad.get(a)[1];
}
RandomAccessibleInterval<T> xPad = Views.interval(
Views.extendMirrorDouble(x), new FinalInterval( minLim, maxLim ));
Expand All @@ -63,7 +63,32 @@ protected Resizer(Map<String, Integer> grid) {
after(RandomAccessibleInterval<T> x, String axes) {
String[] strs = Utils.axesCheckAndNormalize(axes, x.numDimensions(), null);
axes = strs[0];
RandomAccessibleInterval<T> crop = null;
for (int i = 0; i < axes.length(); i ++) {
String ax = axes.split("")[i];
long s = x.dimensionsAsLongArray()[i];
int g = this.grid.keySet().contains(ax) ? grid.get(ax) : 1;
long sPad = this.paddedShape.keySet().contains(ax) ? paddedShape.get(ax) : s;
if (sPad != s * g)
throw new IllegalArgumentException();
}

int[] end = new int[axes.length()];
for (int i = 0; i < axes.length(); i ++) {
String ax = axes.split("")[i];
int p = pad.keySet().contains(ax) ? pad.get(ax)[1] : 0;
int g = grid.keySet().contains(ax) ? grid.get(ax) : 1;
if (p < g)
end[i] = 0;
else
end[i] = -Math.floorDiv(p, g);
}
long[] minLim = x.minAsLongArray();
long[] maxLim = x.maxAsLongArray();
for (int i = 0; i < maxLim.length; i ++) {
maxLim[i] += end[i];
}
RandomAccessibleInterval<T> crop = Views.interval(
x, new FinalInterval( minLim, maxLim ));
return crop;
}

Expand Down

0 comments on commit c54a950

Please sign in to comment.