Skip to content

Commit

Permalink
improve the robustness of the patching strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 17, 2023
1 parent bc17ca2 commit db1b825
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
17 changes: 16 additions & 1 deletion src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -49,6 +51,9 @@
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
Expand Down Expand Up @@ -503,9 +508,19 @@ else if (descriptor == null)
descriptor = ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
PatchGridCalculator tileGrid = PatchGridCalculator.build(descriptor, inputImgs);
Map<String, PatchSpec> specs = tileGrid.get();
specs.get("").getPatchInputSize()
//specs.get("").getPatchInputSize()
return null;
}

public static void main(String[] args) throws IOException {
String mm = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models\\StarDist H&E Nuclei Segmentation_06092023_020924\\";
Img<FloatType> im = ArrayImgs.floats(new long[] {1, 512, 512, 1});
Map<String, Object> l = new HashMap<String, Object>();
l.put("input", im);
PatchGridCalculator tileGrid = PatchGridCalculator.build(mm, l);
tileGrid.get();
System.out.println(false);
}

/**
* Get the EngineClassLoader created by the DeepLearning Model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.bioimage.modelrunner.utils.Constants;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.RealType;
Expand All @@ -44,11 +45,11 @@
*
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
*/
public class PatchGridCalculator
public class PatchGridCalculator<T extends RealType<T> & NativeType<T>>
{

private ModelDescriptor descriptor;
private Map<String, Object> inputValuesMap;
private Map<String, Tensor<T>> inputValuesMap;
/**
* MAp containing the {@link PatchSpec} for each of the tensors defined in the rdf.yaml specs file
*/
Expand All @@ -63,18 +64,15 @@ public class PatchGridCalculator
* map containing the input images associated to their input tensors
* @throws IllegalArgumentException if the {@link #inputValuesMap}
*/
private PatchGridCalculator(ModelDescriptor descriptor, Map<String, Object> inputValuesMap)
private PatchGridCalculator(ModelDescriptor descriptor, Map<String, Tensor<T>> inputValuesMap)
throws IllegalArgumentException
{
for (TensorSpec tt : descriptor.getInputTensors()) {
if (tt.isImage() && inputValuesMap.get(tt.getName()) == null)
throw new IllegalArgumentException("Model input tensor '" + tt.getName() + "' is specified in the rdf.yaml specs file "
+ "but cannot be found in the model inputs map provided.");
// TODO change isImage() by isTensor()
if (tt.isImage()
&& !(inputValuesMap.get(tt.getName()) instanceof RandomAccessibleInterval)
&& !(inputValuesMap.get(tt.getName()) instanceof IterableInterval)
&& !(inputValuesMap.get(tt.getName()) instanceof Tensor))
if (tt.isImage() && !(inputValuesMap.get(tt.getName()) instanceof Tensor))
throw new IllegalArgumentException("Model input tensor '" + tt.getName() + "' is specified in the rdf.yaml specs file "
+ "as a tensor but. JDLL needs tensor to be specified either as JDLL tensors (io.bioimage.tensor.Tensor) "
+ "or ImgLib2 Imgs (net.imglib2.img.Img), ImgLib2 RandomAccessibleIntervals (net.imglib2.RandomAccessibleInterval) "
Expand All @@ -96,15 +94,15 @@ private PatchGridCalculator(ModelDescriptor descriptor, Map<String, Object> inpu
* @throws IOException if it is not possible to read the rdf.yaml file of the model or it
* does not exist
*/
public static PatchGridCalculator build(String modelFolder, Map<String, Object> inputValuesMap) throws IOException {
public static <T extends RealType<T> & NativeType<T>> PatchGridCalculator<T> build(String modelFolder, Map<String, Tensor<T>> inputValuesMap) throws IOException {
ModelDescriptor descriptor;
try {
descriptor =
ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME, false);
} catch (Exception ex) {
throw new IOException("Unable to process the rf.yaml specifications file.", ex);
}
return new PatchGridCalculator(descriptor, inputValuesMap);
return new PatchGridCalculator<T>(descriptor, inputValuesMap);
}

/**
Expand All @@ -117,9 +115,10 @@ public static PatchGridCalculator build(String modelFolder, Map<String, Object>
* @throws IllegalArgumentException if the inputs provided in the input values map does not correspond
* to the inputs defined in the inputs field of the rdf.yaml specs file.
*/
public static PatchGridCalculator build(ModelDescriptor model, LinkedHashMap<String, Object> inputValuesMap)
public static <T extends RealType<T> & NativeType<T>>
PatchGridCalculator<T> build(ModelDescriptor model, LinkedHashMap<String, Tensor<T>> inputValuesMap)
throws IllegalArgumentException {
return new PatchGridCalculator(model, inputValuesMap);
return new PatchGridCalculator<T>(model, inputValuesMap);
}

/**
Expand All @@ -135,7 +134,7 @@ public static PatchGridCalculator build(ModelDescriptor model, LinkedHashMap<Str
* @return the object that creates a list of patch specs for each tensor
*/
public static <T extends NumericType<T> & RealType<T>>
PatchGridCalculator build(ModelDescriptor model, List<RandomAccessibleInterval<T>> inputImagesList) {
PatchGridCalculator<T> build(ModelDescriptor model, List<Tensor<T>> inputImagesList) {
LinkedHashMap<String, Object> map = new LinkedHashMap<String, Object>();
if (inputImagesList.size() != model.getInputTensors().size())
throw new IllegalArgumentException("The size of the list containing the model input RandomAccessibleIntervals"
Expand All @@ -144,7 +143,7 @@ PatchGridCalculator build(ModelDescriptor model, List<RandomAccessibleInterval<T
int c = 0;
for (TensorSpec tt : model.getInputTensors())
map.put(tt.getName(), inputImagesList.get(c ++));
return new PatchGridCalculator(model, map);
return new PatchGridCalculator<T>(model, map);
}

/**
Expand All @@ -161,7 +160,11 @@ public LinkedHashMap<String, PatchSpec> get() throws IllegalArgumentException
return psMap;
List<TensorSpec> inputTensors = findInputImageTensorSpec();
List<Object> inputImages = inputTensors.stream()
.map(k -> this.inputValuesMap.get(k)).collect(Collectors.toList());
.filter(k -> this.inputValuesMap.get(k.getName()) != null)
.map(k -> this.inputValuesMap.get(k.getName())).collect(Collectors.toList());
if (inputImages.size() == 0)
throw new IllegalArgumentException("No inputs have been provided that match the "
+ "specified input tensors specified in the rdf.yaml file.");
LinkedHashMap<String, PatchSpec> specsMap = computePatchSpecsForEveryTensor(inputTensors, inputImages);
// Check that the obtained patch specs are not going to cause errors
checkPatchSpecs(specsMap);
Expand Down Expand Up @@ -280,7 +283,7 @@ private <T extends Type<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSp
* input patch to the model
* @return an object containing the specs needed to perform patching for the particular tensor
*/
private <T extends Type<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSpec, RandomAccessibleInterval<T> inputSequence)
private <T extends NumericType<T> & RealType<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSpec, RandomAccessibleInterval<T> inputSequence)
{
String processingAxesOrder = "xyczb";
int[] inputPatchSize = arrayToWantedAxesOrderAddOnes(inputTensorSpec.getProcessingPatch(),
Expand Down

0 comments on commit db1b825

Please sign in to comment.