32
32
import io .bioimage .modelrunner .bioimageio .description .TensorSpec ;
33
33
import io .bioimage .modelrunner .tensor .Tensor ;
34
34
import io .bioimage .modelrunner .utils .Constants ;
35
- import net .imglib2 .IterableInterval ;
36
35
import net .imglib2 .RandomAccessibleInterval ;
37
36
import net .imglib2 .type .NativeType ;
38
- import net .imglib2 .type .Type ;
39
- import net .imglib2 .type .numeric .NumericType ;
40
37
import net .imglib2 .type .numeric .RealType ;
41
38
42
39
/**
45
42
*
46
43
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
47
44
*/
48
- public class PatchGridCalculator <T extends RealType <T > & NativeType <T >>
45
+ public class PatchGridCalculator <T extends RealType <T > & NativeType <T >>
49
46
{
50
47
51
48
private ModelDescriptor descriptor ;
@@ -133,9 +130,9 @@ PatchGridCalculator<T> build(ModelDescriptor model, LinkedHashMap<String, Tensor
133
130
* to the first input, second image to second output and so on.
134
131
* @return the object that creates a list of patch specs for each tensor
135
132
*/
136
- public static <T extends NumericType <T > & RealType <T >>
133
+ public static <T extends NativeType <T > & RealType <T >>
137
134
PatchGridCalculator <T > build (ModelDescriptor model , List <Tensor <T >> inputImagesList ) {
138
- LinkedHashMap <String , Object > map = new LinkedHashMap <String , Object >();
135
+ LinkedHashMap <String , Tensor < T >> map = new LinkedHashMap <String , Tensor < T > >();
139
136
if (inputImagesList .size () != model .getInputTensors ().size ())
140
137
throw new IllegalArgumentException ("The size of the list containing the model input RandomAccessibleIntervals"
141
138
+ " was not the same size (" + inputImagesList .size () + ") as the number of "
@@ -159,7 +156,7 @@ public LinkedHashMap<String, PatchSpec> get() throws IllegalArgumentException
159
156
if (psMap != null )
160
157
return psMap ;
161
158
List <TensorSpec > inputTensors = findInputImageTensorSpec ();
162
- List <Object > inputImages = inputTensors .stream ()
159
+ List <Tensor < T > > inputImages = inputTensors .stream ()
163
160
.filter (k -> this .inputValuesMap .get (k .getName ()) != null )
164
161
.map (k -> this .inputValuesMap .get (k .getName ())).collect (Collectors .toList ());
165
162
if (inputImages .size () == 0 )
@@ -235,7 +232,7 @@ private List<TensorSpec> findInputImageTensorSpec()
235
232
* @return the LinkedHashMap where the key corresponds to the name of the tensor and the value is its
236
233
* patch specifications
237
234
*/
238
- private LinkedHashMap <String , PatchSpec > computePatchSpecsForEveryTensor (List <TensorSpec > tensors , List <Object > images ){
235
+ private LinkedHashMap <String , PatchSpec > computePatchSpecsForEveryTensor (List <TensorSpec > tensors , List <Tensor < T > > images ){
239
236
LinkedHashMap <String , PatchSpec > patchInfoList = new LinkedHashMap <String , PatchSpec >();
240
237
for (int i = 0 ; i < tensors .size (); i ++)
241
238
patchInfoList .put (tensors .get (i ).getName (), computePatchSpecs (tensors .get (i ), images .get (i )));
@@ -254,19 +251,9 @@ private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<Te
254
251
* @throws IllegalArgumentException if the JAva type of the input object is not among the
255
252
* allowed ones ({@link Sequence}, {@link NDArray} or {@link Tensor})
256
253
*/
257
- private <T extends Type <T >> PatchSpec computePatchSpecs (TensorSpec inputTensorSpec , Object inputObject )
254
+ private <T extends NativeType <T > & RealType < T >> PatchSpec computePatchSpecs (TensorSpec inputTensorSpec , Tensor < T > inputObject )
258
255
throws IllegalArgumentException {
259
- if (inputObject instanceof RandomAccessibleInterval ) {
260
- return computePatchSpecs (inputTensorSpec , (RandomAccessibleInterval <T >) inputObject );
261
- } else if (inputObject instanceof Tensor ) {
262
- return computePatchSpecs (inputTensorSpec , ((Tensor ) inputObject ).getData ());
263
- } else {
264
- throw new IllegalArgumentException ("Input tensor '" + inputTensorSpec .getName ()
265
- + "' is not from a supported a Java class (" + inputObject .getClass ().toString ()
266
- + ") that JDLL can pass to a model. JDLL can only handle inputs to the model as:" + System .lineSeparator ()
267
- + "- " + RandomAccessibleInterval .class .toString () + System .lineSeparator ()
268
- + "- " + Tensor .class .toString () + System .lineSeparator ());
269
- }
256
+ return computePatchSpecs (inputTensorSpec , inputObject .getData ());
270
257
}
271
258
272
259
/**
@@ -283,9 +270,9 @@ private <T extends Type<T>> PatchSpec computePatchSpecs(TensorSpec inputTensorSp
283
270
* input patch to the model
284
271
* @return an object containing the specs needed to perform patching for the particular tensor
285
272
*/
286
- private < T extends NumericType < T > & RealType < T >> PatchSpec computePatchSpecs (TensorSpec inputTensorSpec , RandomAccessibleInterval <T > inputSequence )
273
+ private PatchSpec computePatchSpecs (TensorSpec inputTensorSpec , RandomAccessibleInterval <T > inputSequence )
287
274
{
288
- String processingAxesOrder = "xyczb" ;
275
+ String processingAxesOrder = inputTensorSpec . getAxesOrder () ;
289
276
int [] inputPatchSize = arrayToWantedAxesOrderAddOnes (inputTensorSpec .getProcessingPatch (),
290
277
inputTensorSpec .getAxesOrder (),
291
278
processingAxesOrder );
0 commit comments