Skip to content

Commit 2c04835

Browse files
committed
correct more errors related to pre and post processing
1 parent 49ff319 commit 2c04835

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

src/main/java/io/bioimage/modelrunner/model/processing/Processing.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ private void buildPreprocessing() throws IllegalArgumentException, RuntimeExcept
7575
for (TransformSpec transformation : preprocessing) {
7676
list.add(TransformationInstance.create(transformation));
7777
}
78+
preMap.put(tt.getName(), list);
7879
}
7980
}
8081

@@ -86,6 +87,7 @@ private void buildPostprocessing() throws IllegalArgumentException, RuntimeExcep
8687
for (TransformSpec transformation : preprocessing) {
8788
list.add(TransformationInstance.create(transformation));
8889
}
90+
postMap.put(tt.getName(), list);
8991
}
9092
}
9193

@@ -139,6 +141,8 @@ List<Tensor<R>> preprocess(List<Tensor<T>> tensorList, boolean inplace) {
139141
Tensor<T> tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null);
140142
if (tt == null)
141143
continue;
144+
if (ee.getValue().size() == 0)
145+
outputs.add(Cast.unchecked(tt));
142146
for (TransformationInstance trans : ee.getValue()) {
143147
List<Tensor<R>> outList = trans.run(tt, inplace);
144148
outputs.addAll(outList);
@@ -184,6 +188,8 @@ List<Tensor<R>> postprocess(List<Tensor<T>> tensorList, boolean inplace) {
184188
Tensor<T> tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null);
185189
if (tt == null)
186190
continue;
191+
if (ee.getValue().size() == 0)
192+
outputs.add(Cast.unchecked(tt));
187193
for (TransformationInstance trans : ee.getValue()) {
188194
List<Tensor<R>> outList = trans.run(tt, inplace);
189195
outputs.addAll(outList);

src/main/java/io/bioimage/modelrunner/model/processing/TransformationInstance.java

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.lang.reflect.InvocationTargetException;
2626
import java.lang.reflect.Method;
2727
import java.lang.reflect.Parameter;
28+
import java.util.ArrayList;
29+
import java.util.Collections;
2830
import java.util.List;
2931
import java.util.Map;
3032

@@ -34,6 +36,7 @@
3436
import io.bioimage.modelrunner.transformations.BinarizeTransformation;
3537
import net.imglib2.type.NativeType;
3638
import net.imglib2.type.numeric.RealType;
39+
import net.imglib2.util.Cast;
3740

3841
/**
3942
* Class that creates an instance able to run the corresponding Bioimage.io processing routine
@@ -116,13 +119,43 @@ List<Tensor<R>> run(Tensor<T> tensor) throws RuntimeException {
116119
List<Tensor<R>> run(Tensor<T> tensor, boolean inplace) throws RuntimeException {
117120
Method m;
118121
try {
119-
if (inplace)
120-
m = cls.getMethod(RUN_INPLACE_NAME, Tensor.class);
121-
else
122-
m = cls.getMethod(RUN_NAME, Tensor.class);
123-
124-
m.invoke(this.instance, tensor);
125-
return null;
122+
if (inplace) {
123+
m = cls.getMethod(RUN_INPLACE_NAME, Tensor.class);
124+
m.invoke(this.instance, tensor);
125+
return Collections.singletonList(Cast.unchecked(tensor));
126+
} else {
127+
m = cls.getMethod(RUN_NAME, Tensor.class);
128+
Object result = m.invoke(this.instance, tensor);
129+
130+
// Handle different possible return types
131+
if (result == null) {
132+
return null;
133+
} else if (result instanceof List<?>) {
134+
// Cast and verify each element is a Tensor<R>
135+
List<?> resultList = (List<?>) result;
136+
List<Tensor<R>> outputList = new ArrayList<>();
137+
138+
for (Object item : resultList) {
139+
if (item instanceof Tensor<?>) {
140+
@SuppressWarnings("unchecked")
141+
Tensor<R> tensorItem = (Tensor<R>) item;
142+
outputList.add(tensorItem);
143+
} else {
144+
throw new RuntimeException("Invalid return type: Expected Tensor but got " +
145+
(item != null ? item.getClass().getName() : "null"));
146+
}
147+
}
148+
return outputList;
149+
} else if (result instanceof Tensor<?>) {
150+
// Single Tensor result
151+
@SuppressWarnings("unchecked")
152+
Tensor<R> tensorResult = (Tensor<R>) result;
153+
return Collections.singletonList(tensorResult);
154+
} else {
155+
throw new RuntimeException("Unexpected return type: " +
156+
(result != null ? result.getClass().getName() : "null"));
157+
}
158+
}
126159
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException
127160
| NoSuchMethodException | SecurityException e) {
128161
throw new RuntimeException(Types.stackTrace(e));

src/main/java/io/bioimage/modelrunner/transformations/AbstractTensorTransformation.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ protected < R extends RealType< R > & NativeType< R > > Tensor< FloatType > make
6767
final ImgFactory< FloatType > factory = Util.getArrayOrCellImgFactory( input.getData(), new FloatType() );
6868
final Img< FloatType > outputImg = factory.create( input.getData() );
6969
RealTypeConverters.copyFromTo(input.getData(), outputImg);
70-
final Tensor< FloatType > output = Tensor.build( getName() + '_' + input.getName(), input.getAxesOrderString(), outputImg );
70+
// TODO what name final Tensor< FloatType > output = Tensor.build( getName() + '_' + input.getName(), input.getAxesOrderString(), outputImg );
71+
final Tensor< FloatType > output = Tensor.build( input.getName(), input.getAxesOrderString(), outputImg );
7172
return output;
7273
}
7374
}

0 commit comments

Comments
 (0)