Skip to content

Commit

Permalink
craete the scaffold for tiling tensors and runnning a model
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2023
1 parent e0e09c4 commit 00ef328
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
60 changes: 51 additions & 9 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,20 @@
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import javax.xml.bind.ValidationException;

import io.bioimage.modelrunner.bioimageio.bioengine.BioEngineAvailableModels;
import io.bioimage.modelrunner.bioimageio.bioengine.BioengineInterface;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
import io.bioimage.modelrunner.engine.EngineInfo;
Expand All @@ -47,6 +50,7 @@
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tiling.PatchGridCalculator;
import io.bioimage.modelrunner.tiling.PatchSpec;
import io.bioimage.modelrunner.tiling.TileGrid;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.RandomAccessibleInterval;
Expand Down Expand Up @@ -493,32 +497,70 @@ public void runModel( List< Tensor < ? > > inTensors, List< Tensor < ? > > outTe
* @param <T>
* ImgLib2 data type of the output images
* @param <R>
* @param inputImgs
* @param inputTensors
* @return
* @throws ValidationException
* @throws Exception
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Img<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputImgs) throws ValidationException {
List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors) throws ValidationException {
if (descriptor == null && modelFolder == null)
throw new IllegalArgumentException("");
else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile()))
throw new IllegalArgumentException("");
else if (descriptor == null)
descriptor = ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
PatchGridCalculator<R> tileGrid = PatchGridCalculator.build(descriptor, inputImgs);
Map<String, PatchSpec> specs = tileGrid.get();
//specs.get("").getPatchInputSize()
PatchGridCalculator<R> tileGrid = PatchGridCalculator.build(descriptor, inputTensors);
runTiling(inputTensors, tileGrid);
return null;
}

@SuppressWarnings("unchecked")
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
List<Tensor<T>> outputTensors = new ArrayList<Tensor<T>>();
for (TensorSpec tt : descriptor.getOutputTensors()) {
if (outTileSpecs.get(tt.getName()) == null)
outputTensors.add(Tensor.buildEmptyTensor(tt.getName(), tt.getAxesOrder()));
else
outputTensors.add((Tensor<T>) Tensor.buildBlankTensor(tt.getName(),
tt.getAxesOrder(),
outTileSpecs.get(tt.getName()).getTensorDims(),
new FloatType()));
}
doTiling(inputTensors, outputTensors, tileGrid);
}

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, PatchGridCalculator<R> tileGrid) {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
Map<Object, TileGrid> inTileGrids = inTileSpecs.entrySet().stream()
.collect(Collectors.toMap(entry -> entry.getKey(), entry -> TileGrid.create(entry.getValue())));
Map<Object, TileGrid> outTileGrids = outTileSpecs.entrySet().stream()
.collect(Collectors.toMap(entry -> entry.getKey(), entry -> TileGrid.create(entry.getValue())));
int[] tilesPerAxis = inTileSpecs.values().stream().findFirst().get().getPatchGridSize();
int nTiles = 1;
for (int i : tilesPerAxis) nTiles *= i;

for (int i = 0; i < nTiles; i ++) {

}

}

public static <T extends NativeType<T> & RealType<T>> void main(String[] args) throws IOException {
String mm = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models\\StarDist H&E Nuclei Segmentation_06092023_020924\\";
Img<FloatType> im = ArrayImgs.floats(new long[] {1, 512, 512, 4});
Map<String, Tensor<T>> l = new HashMap<String, Tensor<T>>();
l.put("input", (Tensor<T>) Tensor.build("input", "bxyc", im));
Img<FloatType> im = ArrayImgs.floats(new long[] {1, 511, 512, 3});
List<Tensor<T>> l = new ArrayList<Tensor<T>>();
l.add((Tensor<T>) Tensor.build("input", "bxyc", im));
PatchGridCalculator<T> tileGrid = PatchGridCalculator.build(mm, l);
LinkedHashMap<String, PatchSpec> tileSpecs = tileGrid.get();
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
TileGrid aa = TileGrid.create(inTileSpecs.get("input"));
TileGrid bb = TileGrid.create(outTileSpecs.get("output"));
System.out.println(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ public static <T extends RealType<T> & NativeType<T>> PatchGridCalculator<T> bui
try {
descriptor =
ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME, false);
descriptor.getInputTensors().get(0).setTileSizeForTensorAndImageSize(new int[]{1, 256, 256, 3}, new int[]{1, 512, 512, 3});
} catch (Exception ex) {
throw new IOException("Unable to process the rf.yaml specifications file.", ex);
}
Expand Down

0 comments on commit 00ef328

Please sign in to comment.