Skip to content

Commit

Permalink
finish the scafold for directly running models from the bioimage.io
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2023
1 parent f080bb1 commit 2a023e2
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,11 @@ public void runModel( List< Tensor < ? > > inTensors, List< Tensor < ? > > outTe
* @param inputTensors
* @return
* @throws ValidationException
* @throws RunModelException
* @throws Exception
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors) throws ValidationException {
List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors) throws ValidationException, RunModelException {
if (descriptor == null && modelFolder == null)
throw new IllegalArgumentException("");
else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile()))
Expand All @@ -521,7 +522,7 @@ else if (descriptor == null)

@SuppressWarnings("unchecked")
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) {
void runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) throws RunModelException {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
List<Tensor<T>> outputTensors = new ArrayList<Tensor<T>>();
Expand All @@ -538,7 +539,7 @@ void runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) {
}

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, PatchGridCalculator<R> tileGrid) {
void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, PatchGridCalculator<R> tileGrid) throws RunModelException {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
Map<Object, TileGrid> inTileGrids = inTileSpecs.entrySet().stream()
Expand All @@ -551,7 +552,7 @@ void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, Patch

for (int j = 0; j < nTiles; j ++) {
int tileCount = j + 0;
IntStream.range(0, inputTensors.size()).mapToObj(i -> {
List<Tensor<?>> inputTileList = IntStream.range(0, inputTensors.size()).mapToObj(i -> {
if (!inputTensors.get(i).isImage())
return inputTensors.get(i);
RandomAccessibleInterval<R> tileRai = Views.interval(
Expand All @@ -564,7 +565,19 @@ void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, Patch
Intervals.expand(inputTensors.get(i).getData(), 50));
*/
return Tensor.build(inputTensors.get(i).getName(), inputTensors.get(i).getAxesOrderString(), tileRai);
});
}).collect(Collectors.toList());

List<Tensor<?>> outputTileList = IntStream.range(0, outputTensors.size()).mapToObj(i -> {
if (!outputTensors.get(i).isImage())
return outputTensors.get(i);
RandomAccessibleInterval<T> tileRai = Views.interval(
Views.extendBorder(outputTensors.get(i).getData()),
outTileGrids.get(outputTensors.get(i).getName()).getTilePostionsInImage().get(tileCount),
(long[]) outTileGrids.get(outputTensors.get(i).getName()).getTileSize());
return Tensor.build(outputTensors.get(i).getName(), outputTensors.get(i).getAxesOrderString(), tileRai);
}).collect(Collectors.toList());

this.runModel(inputTileList, outputTileList);
}

}
Expand Down

0 comments on commit 2a023e2

Please sign in to comment.