Skip to content

Commit

Permalink
add model loaded check
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent 2a023e2 commit 6d617fd
Showing 1 changed file with 35 additions and 11 deletions.
46 changes: 35 additions & 11 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

import javax.xml.bind.ValidationException;

import ij.IJ;
import ij.ImagePlus;
import io.bioimage.modelrunner.bioimageio.bioengine.BioEngineAvailableModels;
import io.bioimage.modelrunner.bioimageio.bioengine.BioengineInterface;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
Expand All @@ -59,6 +61,7 @@
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
Expand All @@ -73,6 +76,10 @@
*/
public class Model
{
/**
* Whether the model is loaded or not
*/
boolean loaded = false;
/**
* ClassLoader containing all the classes needed to use the corresponding
* Deep Learning framework (engine).
Expand Down Expand Up @@ -458,6 +465,7 @@ public void loadModel() throws LoadModelException
if (engineClassLoader.isBioengine())
((BioengineInterface) engineInstance).addServer(engineInfo.getServer());
engineClassLoader.setBaseClassLoader();
loaded = true;
}

/**
Expand All @@ -473,6 +481,7 @@ public void closeModel()
engineInstance = null;
engineClassLoader.setBaseClassLoader();
engineClassLoader = null;
loaded = false;
}

/**
Expand Down Expand Up @@ -509,20 +518,21 @@ public void runModel( List< Tensor < ? > > inTensors, List< Tensor < ? > > outTe
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors) throws ValidationException, RunModelException {
if (!this.isLoaded())
throw new RunModelException("Please first load the model.");
if (descriptor == null && modelFolder == null)
throw new IllegalArgumentException("");
else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile()))
throw new IllegalArgumentException("");
else if (descriptor == null)
descriptor = ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
PatchGridCalculator<R> tileGrid = PatchGridCalculator.build(descriptor, inputTensors);
runTiling(inputTensors, tileGrid);
return null;
return runTiling(inputTensors, tileGrid);
}

@SuppressWarnings("unchecked")
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) throws RunModelException {
List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) throws RunModelException {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
List<Tensor<T>> outputTensors = new ArrayList<Tensor<T>>();
Expand All @@ -536,6 +546,7 @@ void runTiling(List<Tensor<R>> inputTensors, PatchGridCalculator<R> tileGrid) th
new FloatType()));
}
doTiling(inputTensors, outputTensors, tileGrid);
return outputTensors;
}

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
Expand Down Expand Up @@ -579,19 +590,24 @@ void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, Patch

this.runModel(inputTileList, outputTileList);
}

}

public static <T extends NativeType<T> & RealType<T>> void main(String[] args) throws IOException {
public static <T extends NativeType<T> & RealType<T>> void main(String[] args) throws IOException, ValidationException, LoadEngineException, RunModelException {
String mm = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models\\StarDist H&E Nuclei Segmentation_06092023_020924\\";
Img<FloatType> im = ArrayImgs.floats(new long[] {1, 511, 512, 3});
ImagePlus imp = IJ.openImage(mm + File.separator + "sample_input_0.tif");
imp.show();
RandomAccessibleInterval<FloatType> wrapImg = ImageJFunctions.convertFloat(imp);
wrapImg = (RandomAccessibleInterval<FloatType>) Views.addDimension(wrapImg, 0, 0);
wrapImg = (RandomAccessibleInterval<FloatType>) Views.permute(wrapImg, 2, 3);
wrapImg = (RandomAccessibleInterval<FloatType>) Views.permute(wrapImg, 1, 2);
wrapImg = (RandomAccessibleInterval<FloatType>) Views.permute(wrapImg, 0, 1);
List<Tensor<T>> l = new ArrayList<Tensor<T>>();
l.add((Tensor<T>) Tensor.build("input", "bxyc", im));
PatchGridCalculator<T> tileGrid = PatchGridCalculator.build(mm, l);
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
TileGrid aa = TileGrid.create(inTileSpecs.get("input"));
TileGrid bb = TileGrid.create(outTileSpecs.get("output"));
l.add((Tensor<T>) Tensor.build("input", "bxyc", wrapImg));
Model model = createBioimageioModel(mm);
model.loadModel();
List<Tensor<T>> out = model.runBioimageioModelOnImgLib2WithTiling(l);
ImageJFunctions.show(Views.dropSingletonDimensions(out.get(0).getData()));
System.out.println(false);
}

Expand Down Expand Up @@ -652,4 +668,12 @@ public boolean isBioengine() {
public EngineInfo getEngineInfo() {
return engineInfo;
}

/**
* Whether the model is loaded or not
* @return
*/
public boolean isLoaded() {
return loaded;
}
}

0 comments on commit 6d617fd

Please sign in to comment.