diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java index 49c2d6fa..c6ec50f7 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java @@ -22,14 +22,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.stream.IntStream; +import java.util.stream.LongStream; import io.bioimage.modelrunner.utils.IndexingUtils; -import net.imglib2.img.Img; -import net.imglib2.img.array.ArrayImgFactory; -import net.imglib2.type.NativeType; -import net.imglib2.util.Util; /** * Calculate all the coordinates for the tiles for a specific tensor in the image that wants to @@ -66,9 +62,10 @@ private TileGrid() /** */ - public static TileGrid create(PatchSpec tileSpecs, int[] gridSize) + public static TileGrid create(PatchSpec tileSpecs, long[] imageDims) { TileGrid ps = new TileGrid(); + int[] gridSize = tileSpecs.getPatchGridSize(); ps.tileSize = tileSpecs.getPatchInputSize(); int tileCount = Arrays.stream(gridSize).reduce(1, (a, b) -> a * b); @@ -80,11 +77,12 @@ public static TileGrid create(PatchSpec tileSpecs, int[] gridSize) .map(i -> patchSize[i] - padSize[0][i] - padSize[1][i]).toArray(); ps.roiSize = roiSize; ps.roiPositionsInTile.add(IntStream.range(0, padSize[0].length).mapToLong(i -> (long) padSize[0][i]).toArray()); - long[] roiStart = IntStream.range(0, patchIndex.length) - .map(i -> roiSize[i] * patchIndex[i]).mapToLong(i -> (long) i).toArray(); + long[] roiStart = LongStream.range(0, patchIndex.length) + .map(i -> Math.min(roiSize[(int) i] * patchIndex[(int) i], imageDims[(int) i] - roiSize[(int) i])).toArray(); ps.roiPositionsInImage.add(roiStart); - long[] patchStart = IntStream.range(0, patchIndex.length) - .map(i -> roiSize[i] * patchIndex[i] - padSize[0][i]).mapToLong(i -> (long) i).toArray(); + long[] patchStart = LongStream.range(0, patchIndex.length) + .map(i -> Math.min(roiSize[(int) i] * patchIndex[(int) i] - padSize[0][(int) i], imageDims[(int) i] - roiSize[(int) i])) + .toArray(); ps.tilePostionsInImage.add(patchStart); } return ps;