Skip to content

Commit

Permalink
correct small bug and ad setters and getters
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2023
1 parent 0dde1d1 commit 5d36645
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ public LinkedHashMap<String, PatchSpec> getInputTensorsTileSpecs() throws Illega
if (this.inputTilesSpecs != null)
return inputTilesSpecs;
List<TensorSpec> inputTensors = findInputImageTensorSpec();
List<TensorSpec> outputTensors = findOutputImageTensorSpec();
List<Tensor<T>> inputImages = inputTensors.stream()
.filter(k -> this.inputValuesMap.get(k.getName()) != null)
.map(k -> this.inputValuesMap.get(k.getName())).collect(Collectors.toList());
Expand Down Expand Up @@ -235,17 +234,17 @@ private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<Te
return patchInfoList;
}

public LinkedHashMap<String, PatchSpec> getOutputTensorsTileSpecs(List<TensorSpec> outTensors,
Map<String, PatchSpec> inSpecs) throws IllegalArgumentException {
public LinkedHashMap<String, PatchSpec> getOutputTensorsTileSpecs() throws IllegalArgumentException {
if (this.inputTilesSpecs == null)
throw new IllegalArgumentException("Please first calculate the tile specs for the input tensors. Call: "
+ "getInputTensorsTileSpecs()");
if (this.outputTilesSpecs != null)
return outputTilesSpecs;
List<TensorSpec> outTensors = findOutputImageTensorSpec();
LinkedHashMap<String, PatchSpec> patchInfoList = new LinkedHashMap<String, PatchSpec>();
for (int i = 0; i < outTensors.size(); i ++) {
String refTensor = outTensors.get(i).getShape().getReferenceInput();
PatchSpec refSpec = refTensor == null ? inSpecs.values().stream().findFirst().get() : inSpecs.get(refTensor);
if (refSpec == null)
throw new IllegalArgumentException("Please first calculate the tile specs for th einput tensors. Call: "
+ "getInputTensorsTileSpecs()");
PatchSpec refSpec = refTensor == null ? inputTilesSpecs.values().stream().findFirst().get() : inputTilesSpecs.get(refTensor);
patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec));
}
outputTilesSpecs = patchInfoList;
Expand Down Expand Up @@ -336,29 +335,29 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T>
return PatchSpec.create(spec.getName(), tileSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray());
}

private PatchSpec computePatchSpecsForOutputTensor(TensorSpec spec, PatchSpec refSpec)
private PatchSpec computePatchSpecsForOutputTensor(TensorSpec tensorSpec, PatchSpec refTilesSpec)
{
int[] inputTileGrid = refSpec.getPatchGridSize();
int[] inputTileGrid = refTilesSpec.getPatchGridSize();
// REgard that the input halo represents the output halo + offset
// and must be divisible by 0.5.
int[][] paddingSize = refSpec.getPatchPaddingSize();
int[][] paddingSize = refTilesSpec.getPatchPaddingSize();
int[] tileSize;
long[] shapeLong;
if (spec.getShape().getReferenceInput() == null) {
tileSize = spec.getShape().getPatchRecomendedSize();
shapeLong = LongStream.range(0, spec.getAxesOrder().length())
if (tensorSpec.getShape().getReferenceInput() == null) {
tileSize = tensorSpec.getShape().getPatchRecomendedSize();
shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length())
.map(i -> (tileSize[(int) i] - paddingSize[0][(int) i] - paddingSize[0][(int) i]) * inputTileGrid[(int) i])
.toArray();
} else {
tileSize = IntStream.range(0, spec.getAxesOrder().length())
.map(i -> (int) (refSpec.getPatchInputSize()[i] * spec.getShape().getScale()[i] + 2 * spec.getShape().getOffset()[i]))
tileSize = IntStream.range(0, tensorSpec.getAxesOrder().length())
.map(i -> (int) (refTilesSpec.getPatchInputSize()[i] * tensorSpec.getShape().getScale()[i] + 2 * tensorSpec.getShape().getOffset()[i]))
.toArray();
shapeLong = LongStream.range(0, spec.getAxesOrder().length())
.map(i -> (int) (refSpec.getPatchInputSize()[(int) i] * spec.getShape().getScale()[(int) i]
+ 2 * spec.getShape().getOffset()[(int) i])).toArray();
shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length())
.map(i -> (int) (refTilesSpec.getTensorDims()[(int) i] * tensorSpec.getShape().getScale()[(int) i]
+ 2 * tensorSpec.getShape().getOffset()[(int) i])).toArray();
}

return PatchSpec.create(spec.getName(), tileSize, inputTileGrid, paddingSize, shapeLong);
return PatchSpec.create(tensorSpec.getName(), tileSize, inputTileGrid, paddingSize, shapeLong);
}

/**
Expand Down
28 changes: 27 additions & 1 deletion src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ private TileGrid()

/**
*/
public static TileGrid create(PatchSpec tileSpecs, long[] imageDims)
public static TileGrid create(PatchSpec tileSpecs)
{
TileGrid ps = new TileGrid();
ps.tensorName = tileSpecs.getTensorName();
long[] imageDims = tileSpecs.getTensorDims();
int[] gridSize = tileSpecs.getPatchGridSize();
ps.tileSize = tileSpecs.getPatchInputSize();
int tileCount = Arrays.stream(gridSize).reduce(1, (a, b) -> a * b);
Expand All @@ -87,6 +89,30 @@ public static TileGrid create(PatchSpec tileSpecs, long[] imageDims)
}
return ps;
}

public String getTensorName() {
return tensorName;
}

public int[] getTileSize() {
return this.tileSize;
}

public int[] getRoiSize() {
return this.roiSize;
}

public List<long[]> getTilePostionsInImage() {
return this.tilePostionsInImage;
}

public List<long[]> getRoiPositionsInTile() {
return this.roiPositionsInTile;
}

public List<long[]> getRoiPostionsInImage() {
return this.roiPositionsInImage;
}

@Override
public String toString()
Expand Down

0 comments on commit 5d36645

Please sign in to comment.