From 192f12190c37384d70e60c9c9f84f53aacd01b5c Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 25 Oct 2023 16:16:41 +0200 Subject: [PATCH] start finding the path specs for the output images --- .../tiling/PatchGridCalculator.java | 81 ++++++++++++++++++- .../modelrunner/tiling/PatchSpec.java | 48 ++++++++++- 2 files changed, 126 insertions(+), 3 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java index 12a6b780..faed09f6 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java @@ -65,7 +65,7 @@ private PatchGridCalculator(ModelDescriptor descriptor, List> tensorLi throws IllegalArgumentException { for (TensorSpec tt : descriptor.getInputTensors()) { - if (tt.isImage() && Tensor.getTensorByNameFromList(tensorList, tt.getName()) != null) + if (tt.isImage() && Tensor.getTensorByNameFromList(tensorList, tt.getName()) == null) throw new IllegalArgumentException("Model input tensor '" + tt.getName() + "' is specified in the rdf.yaml specs file " + "but cannot be found in the model inputs map provided."); // TODO change isImage() by isTensor() @@ -136,6 +136,7 @@ public LinkedHashMap get() throws IllegalArgumentException if (psMap != null) return psMap; List inputTensors = findInputImageTensorSpec(); + List outputTensors = findOutputImageTensorSpec(); List> inputImages = inputTensors.stream() .filter(k -> this.inputValuesMap.get(k.getName()) != null) .map(k -> this.inputValuesMap.get(k.getName())).collect(Collectors.toList()); @@ -143,6 +144,7 @@ public LinkedHashMap get() throws IllegalArgumentException throw new IllegalArgumentException("No inputs have been provided that match the " + "specified input tensors specified in the rdf.yaml file."); LinkedHashMap specsMap = computePatchSpecsForEveryTensor(inputTensors, inputImages); + LinkedHashMap outSpecsMap = computePatchSpecsForEveryOutputTensor(outputTensors, specsMap); // Check that the obtained patch specs are not going to cause errors checkPatchSpecs(specsMap); psMap = specsMap; @@ -201,6 +203,16 @@ private List findInputImageTensorSpec() return this.descriptor.getInputTensors().stream().filter(tr -> tr.isImage()) .collect(Collectors.toList()); } + + /** + * Get the output tensors that correspond to images + * @return list of tensor specs corresponding to each of the output image tensors + */ + private List findOutputImageTensorSpec() + { + return this.descriptor.getOutputTensors().stream().filter(tr -> tr.isImage()) + .collect(Collectors.toList()); + } /** * Create list of patch specifications for every tensor aking into account the @@ -218,6 +230,27 @@ private LinkedHashMap computePatchSpecsForEveryTensor(List computePatchSpecsForEveryOutputTensor(List outTensors, + Map inSpecs){ + LinkedHashMap patchInfoList = new LinkedHashMap(); + for (int i = 0; i < outTensors.size(); i ++) { + String refTensor = outTensors.get(i).getShape().getReferenceInput(); + PatchSpec refSpec = refTensor == null ? inSpecs.values().stream().findFirst().get() : inSpecs.get(refTensor); + patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec)); + } + return patchInfoList; + } /** * Compute the patch details needed to perform the tiling strategy. The calculations @@ -309,7 +342,51 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval .map(i -> (int) Math.max( paddingSize[1][i], inputPatchSize[i] - inputSequenceSize[i] - paddingSize[0][i])).toArray(); - return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize); + return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray()); + } + + private PatchSpec computePatchSpecsForOutputTensor(TensorSpec spec, PatchSpec refSpec) + { + String processingAxesOrder = spec.getAxesOrder(); + int[] inputPatchSize = arrayToWantedAxesOrderAddOnes(tileSize, spec.getAxesOrder(), + processingAxesOrder); + int[][] paddingSize = new int[2][5]; + // REgard that the input halo represents the output halo + offset + // and must be divisible by 0.5. + float[] halo = arrayToWantedAxesOrderAddZeros(spec.getHalo(), + spec.getAxesOrder(), + processingAxesOrder); + if (!descriptor.isPyramidal() && spec.getTiling()) { + // In the case that padding is asymmetrical, the left upper padding has the extra pixel + for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);} + // In the case that padding is asymmetrical, the right bottom padding has one pixel less + for (int i = 0; i < halo.length; i ++) {paddingSize[1][i] = (int) Math.floor(halo[i]);} + + } + long[] shapeLong = rai.dimensionsAsLongArray(); + int[] shapeInt = new int[shapeLong.length]; + for (int i = 0; i < shapeInt.length; i ++) {shapeInt[i] = (int) shapeLong[i];} + int[] inputSequenceSize = arrayToWantedAxesOrderAddOnes(shapeInt, + spec.getAxesOrder(), + processingAxesOrder); + int[] patchGridSize = new int[] {1, 1, 1, 1, 1}; + if (descriptor.isTilingAllowed()) { + patchGridSize = IntStream.range(0, inputPatchSize.length) + .map(i -> (int) Math.ceil((double) inputSequenceSize[i] / ((double) inputPatchSize[i] - halo[i] * 2))) + .toArray(); + } + // For the cases when the patch is bigger than the image size, share the + // padding between both sides of the image + paddingSize[0] = IntStream.range(0, inputPatchSize.length) + .map(i -> + (int) Math.max(paddingSize[0][i], + Math.ceil( (double) (inputPatchSize[i] - inputSequenceSize[i]) / 2)) + ).toArray(); + paddingSize[1] = IntStream.range(0, inputPatchSize.length) + .map(i -> (int) Math.max( paddingSize[1][i], + inputPatchSize[i] - inputSequenceSize[i] - paddingSize[0][i])).toArray(); + + return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray()); } /** diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java index d1f76999..1658cacd 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; +import java.util.Map; /** * Patch specification providing information about the patch size and patch grid size. @@ -29,6 +30,10 @@ */ public class PatchSpec { + /** + * Size of the tensor that is going to be tiled + */ + private long[] tensorDims; /** * Size of the input patch. Following "xyczb" axes order */ @@ -63,13 +68,14 @@ public class PatchSpec * @return The create patch specification. */ public static PatchSpec create(String tensorName, int[] patchInputSize, int[] patchGridSize, - int[][] patchPaddingSize) + int[][] patchPaddingSize, long[] tensorDims) { PatchSpec ps = new PatchSpec(); ps.patchInputSize = patchInputSize; ps.patchGridSize = patchGridSize; ps.patchPaddingSize = patchPaddingSize; ps.tensorName = tensorName; + ps.tensorDims = tensorDims; return ps; } @@ -78,6 +84,10 @@ private PatchSpec() } /** + * TODO this method should be per image, not in total?? + * TODO this method should be per image, not in total?? + * TODO this method should be per image, not in total?? + * TODO this method should be per image, not in total?? * Obtain the number of patches in each axes for a list of input patch specs. * When tiling is allowed, only one patch grid is permitted. If among the tensors * there are one or more that do not allow tiling, then two patch sizes are allowed, @@ -101,6 +111,34 @@ public static int[] getGridSize(List patches) { return grid; } + /** + * TODO this method should be per image, not in total?? + * TODO this method should be per image, not in total?? + * TODO this method should be per image, not in total?? + * TODO this method should be per image, not in total?? + * Obtain the number of patches in each axes for a list of input patch specs. + * When tiling is allowed, only one patch grid is permitted. If among the tensors + * there are one or more that do not allow tiling, then two patch sizes are allowed, + * the one for the tensors that allow tiling and the one for the ones that not (that will + * just be 1s in every axes). + * In the case there exist tensors that allow tiling, the grid size for those will be the + * one returned + * @param patches + * map containing tiling specs per tensor + * @return the number of patches in each axes + */ + public static int[] getGridSize(Map patches) { + // The minimum possible grid is just one patch in every direction. This is the + // grid if no tiling is allowed + int[] grid = new int[]{1, 1, 1, 1, 1}; + // If there is any different grid, that will be the absolute one + for (PatchSpec pp : patches.values()) { + if (!PatchGridCalculator.compareTwoArrays(grid, pp.getPatchGridSize())) + return pp.getPatchGridSize(); + } + return grid; + } + /** * Return the PatchSpec corresponding to the tensor called by the name defined * @param specs @@ -120,6 +158,14 @@ public static PatchSpec getPatchSpecFromListByName(List specs, String public String getTensorName() { return tensorName; } + + /** + * The dimensions of the tensor + * @return the dimensions of the tensor that is going to be tiled + */ + public long[] getTensorDims() { + return tensorDims; + } /** * @return Input patch size. The patch taken from the input sequence including the halo.