Skip to content

Commit faecbf9

Browse files
committed
try to clean the patch creation api
1 parent db1b825 commit faecbf9

File tree

1 file changed

+9
-22
lines changed

1 file changed

+9
-22
lines changed

src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,8 @@
3232
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
3333
import io.bioimage.modelrunner.tensor.Tensor;
3434
import io.bioimage.modelrunner.utils.Constants;
35-
import net.imglib2.IterableInterval;
3635
import net.imglib2.RandomAccessibleInterval;
3736
import net.imglib2.type.NativeType;
38-
import net.imglib2.type.Type;
39-
import net.imglib2.type.numeric.NumericType;
4037
import net.imglib2.type.numeric.RealType;
4138

4239
/**
@@ -45,7 +42,7 @@
4542
*
4643
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
4744
*/
48-
public class PatchGridCalculator<T extends RealType<T> & NativeType<T>>
45+
public class PatchGridCalculator <T extends RealType<T> & NativeType<T>>
4946
{
5047

5148
private ModelDescriptor descriptor;
@@ -133,9 +130,9 @@ PatchGridCalculator<T> build(ModelDescriptor model, LinkedHashMap<String, Tensor
133130
* to the first input, second image to second output and so on.
134131
* @return the object that creates a list of patch specs for each tensor
135132
*/
136-
public static <T extends NumericType<T> & RealType<T>>
133+
public static <T extends NativeType<T> & RealType<T>>
137134
PatchGridCalculator<T> build(ModelDescriptor model, List<Tensor<T>> inputImagesList) {
138-
LinkedHashMap<String, Object> map = new LinkedHashMap<String, Object>();
135+
LinkedHashMap<String, Tensor<T>> map = new LinkedHashMap<String, Tensor<T>>();
139136
if (inputImagesList.size() != model.getInputTensors().size())
140137
throw new IllegalArgumentException("The size of the list containing the model input RandomAccessibleIntervals"
141138
+ " was not the same size (" + inputImagesList.size() + ") as the number of "
@@ -159,7 +156,7 @@ public LinkedHashMap<String, PatchSpec> get() throws IllegalArgumentException
159156
if (psMap != null)
160157
return psMap;
161158
List<TensorSpec> inputTensors = findInputImageTensorSpec();
162-
List<Object> inputImages = inputTensors.stream()
159+
List<Tensor<T>> inputImages = inputTensors.stream()
163160
.filter(k -> this.inputValuesMap.get(k.getName()) != null)
164161
.map(k -> this.inputValuesMap.get(k.getName())).collect(Collectors.toList());
165162
if (inputImages.size() == 0)
@@ -235,7 +232,7 @@ private List<TensorSpec> findInputImageTensorSpec()
235232
* @return the LinkedHashMap where the key corresponds to the name of the tensor and the value is its
236233
* patch specifications
237234
*/
238-
private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<TensorSpec> tensors, List<Object> images){
235+
private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<TensorSpec> tensors, List<Tensor<T>> images){
239236
LinkedHashMap<String, PatchSpec> patchInfoList = new LinkedHashMap<String, PatchSpec>();
240237
for (int i = 0; i < tensors.size(); i ++)
241238
patchInfoList.put(tensors.get(i).getName(), computePatchSpecs(tensors.get(i), images.get(i)));
@@ -254,19 +251,9 @@ private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<Te
254251
* @throws IllegalArgumentException if the JAva type of the input object is not among the
255252
* allowed ones ({@link Sequence}, {@link NDArray} or {@link Tensor})
256253
*/
257-
private <T extends Type<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSpec, Object inputObject)
254+
private <T extends NativeType<T> & RealType<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSpec, Tensor<T> inputObject)
258255
throws IllegalArgumentException {
259-
if (inputObject instanceof RandomAccessibleInterval) {
260-
return computePatchSpecs(inputTensorSpec, (RandomAccessibleInterval<T>) inputObject);
261-
} else if(inputObject instanceof Tensor ) {
262-
return computePatchSpecs(inputTensorSpec, ((Tensor) inputObject).getData());
263-
} else {
264-
throw new IllegalArgumentException("Input tensor '" + inputTensorSpec.getName()
265-
+ "' is not from a supported a Java class (" + inputObject.getClass().toString()
266-
+ ") that JDLL can pass to a model. JDLL can only handle inputs to the model as:" + System.lineSeparator()
267-
+ "- " + RandomAccessibleInterval.class.toString() + System.lineSeparator()
268-
+ "- " + Tensor.class.toString() + System.lineSeparator());
269-
}
256+
return computePatchSpecs(inputTensorSpec, inputObject.getData());
270257
}
271258

272259
/**
@@ -283,9 +270,9 @@ private <T extends Type<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSp
283270
* input patch to the model
284271
* @return an object containing the specs needed to perform patching for the particular tensor
285272
*/
286-
private <T extends NumericType<T> & RealType<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSpec, RandomAccessibleInterval<T> inputSequence)
273+
private PatchSpec computePatchSpecs(TensorSpec inputTensorSpec, RandomAccessibleInterval<T> inputSequence)
287274
{
288-
String processingAxesOrder = "xyczb";
275+
String processingAxesOrder = inputTensorSpec.getAxesOrder();
289276
int[] inputPatchSize = arrayToWantedAxesOrderAddOnes(inputTensorSpec.getProcessingPatch(),
290277
inputTensorSpec.getAxesOrder(),
291278
processingAxesOrder);

0 commit comments

Comments
 (0)