|
25 | 25 | import java.lang.reflect.InvocationTargetException;
|
26 | 26 | import java.lang.reflect.Method;
|
27 | 27 | import java.lang.reflect.Parameter;
|
| 28 | +import java.util.ArrayList; |
| 29 | +import java.util.Collections; |
28 | 30 | import java.util.List;
|
29 | 31 | import java.util.Map;
|
30 | 32 |
|
|
34 | 36 | import io.bioimage.modelrunner.transformations.BinarizeTransformation;
|
35 | 37 | import net.imglib2.type.NativeType;
|
36 | 38 | import net.imglib2.type.numeric.RealType;
|
| 39 | +import net.imglib2.util.Cast; |
37 | 40 |
|
38 | 41 | /**
|
39 | 42 | * Class that creates an instance able to run the corresponding Bioimage.io processing routine
|
@@ -116,13 +119,43 @@ List<Tensor<R>> run(Tensor<T> tensor) throws RuntimeException {
|
116 | 119 | List<Tensor<R>> run(Tensor<T> tensor, boolean inplace) throws RuntimeException {
|
117 | 120 | Method m;
|
118 | 121 | try {
|
119 |
| - if (inplace) |
120 |
| - m = cls.getMethod(RUN_INPLACE_NAME, Tensor.class); |
121 |
| - else |
122 |
| - m = cls.getMethod(RUN_NAME, Tensor.class); |
123 |
| - |
124 |
| - m.invoke(this.instance, tensor); |
125 |
| - return null; |
| 122 | + if (inplace) { |
| 123 | + m = cls.getMethod(RUN_INPLACE_NAME, Tensor.class); |
| 124 | + m.invoke(this.instance, tensor); |
| 125 | + return Collections.singletonList(Cast.unchecked(tensor)); |
| 126 | + } else { |
| 127 | + m = cls.getMethod(RUN_NAME, Tensor.class); |
| 128 | + Object result = m.invoke(this.instance, tensor); |
| 129 | + |
| 130 | + // Handle different possible return types |
| 131 | + if (result == null) { |
| 132 | + return null; |
| 133 | + } else if (result instanceof List<?>) { |
| 134 | + // Cast and verify each element is a Tensor<R> |
| 135 | + List<?> resultList = (List<?>) result; |
| 136 | + List<Tensor<R>> outputList = new ArrayList<>(); |
| 137 | + |
| 138 | + for (Object item : resultList) { |
| 139 | + if (item instanceof Tensor<?>) { |
| 140 | + @SuppressWarnings("unchecked") |
| 141 | + Tensor<R> tensorItem = (Tensor<R>) item; |
| 142 | + outputList.add(tensorItem); |
| 143 | + } else { |
| 144 | + throw new RuntimeException("Invalid return type: Expected Tensor but got " + |
| 145 | + (item != null ? item.getClass().getName() : "null")); |
| 146 | + } |
| 147 | + } |
| 148 | + return outputList; |
| 149 | + } else if (result instanceof Tensor<?>) { |
| 150 | + // Single Tensor result |
| 151 | + @SuppressWarnings("unchecked") |
| 152 | + Tensor<R> tensorResult = (Tensor<R>) result; |
| 153 | + return Collections.singletonList(tensorResult); |
| 154 | + } else { |
| 155 | + throw new RuntimeException("Unexpected return type: " + |
| 156 | + (result != null ? result.getClass().getName() : "null")); |
| 157 | + } |
| 158 | + } |
126 | 159 | } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException
|
127 | 160 | | NoSuchMethodException | SecurityException e) {
|
128 | 161 | throw new RuntimeException(Types.stackTrace(e));
|
|
0 commit comments