Skip to content

Commit

Permalink
start moving towards interprocessing in jvacpp
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 14, 2023
1 parent efdf162 commit 7ea020e
Showing 1 changed file with 10 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

Expand All @@ -54,19 +57,9 @@ public class JavaCPPTensorBuilder {
* @return The {@link org.bytedeco.pytorch.Tensor} built from the {@link Tensor}.
* @throws IllegalArgumentException if the tensor type is not supported
*/
public static org.bytedeco.pytorch.Tensor build(Tensor tensor) throws IllegalArgumentException
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor build(Tensor<T> tensor) throws IllegalArgumentException
{
if (Util.getTypeFromInterval(tensor.getData()) instanceof ByteType) {
return buildFromTensorByte( tensor.getData());
} else if (Util.getTypeFromInterval(tensor.getData()) instanceof IntType) {
return buildFromTensorInt( tensor.getData());
} else if (Util.getTypeFromInterval(tensor.getData()) instanceof FloatType) {
return buildFromTensorFloat( tensor.getData());
} else if (Util.getTypeFromInterval(tensor.getData()) instanceof DoubleType) {
return buildFromTensorDouble( tensor.getData());
} else {
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.getDataType());
}
return build(tensor.getData());
}

/**
Expand All @@ -79,16 +72,16 @@ public static org.bytedeco.pytorch.Tensor build(Tensor tensor) throws IllegalArg
* @return The {@link org.bytedeco.pytorch.Tensor} built from the {@link RandomAccessibleInterval}.
* @throws IllegalArgumentException if the {@link RandomAccessibleInterval} is not supported
*/
public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccessibleInterval<T> tensor) throws IllegalArgumentException
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor build(RandomAccessibleInterval<T> tensor) throws IllegalArgumentException
{
if (Util.getTypeFromInterval(tensor) instanceof ByteType) {
return buildFromTensorByte( (RandomAccessibleInterval<ByteType>) tensor);
return buildFromTensorByte(Cast.unchecked(tensor));
} else if (Util.getTypeFromInterval(tensor) instanceof IntType) {
return buildFromTensorInt( (RandomAccessibleInterval<IntType>) tensor);
return buildFromTensorInt(Cast.unchecked(tensor));
} else if (Util.getTypeFromInterval(tensor) instanceof ByteType) {
return buildFromTensorFloat( (RandomAccessibleInterval<FloatType>) tensor);
return buildFromTensorFloat(Cast.unchecked(tensor));
} else if (Util.getTypeFromInterval(tensor) instanceof DoubleType) {
return buildFromTensorDouble( (RandomAccessibleInterval<DoubleType>) tensor);
return buildFromTensorDouble(Cast.unchecked(tensor));
} else {
throw new IllegalArgumentException("Unsupported tensor type: " + Util.getTypeFromInterval(tensor).getClass().toString());
}
Expand Down

0 comments on commit 7ea020e

Please sign in to comment.