Skip to content

Commit

Permalink
improve efficiency of creating imglib2 rais from tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent ecfbda2 commit 77156fa
Showing 1 changed file with 67 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
package io.bioimage.modelrunner.pytorch.javacpp.tensor;


import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.IndexingUtils;
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
Expand All @@ -34,7 +37,7 @@
import net.imglib2.type.numeric.real.FloatType;

/**
* A {@link Img} builder for JAvaCPP Pytorch {@link org.bytedeco.pytorch.Tensor} objects.
* A {@link RandomAccessibleInterval} builder for JAvaCPP Pytorch {@link org.bytedeco.pytorch.Tensor} objects.
* Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor})
* from Pytorch {@link org.bytedeco.pytorch.Tensor}
*
Expand All @@ -43,197 +46,152 @@
public class ImgLib2Builder {

/**
* Creates a {@link Img} from a given {@link org.bytedeco.pytorch.Tensor}
* Creates a {@link RandomAccessibleInterval} from a given {@link org.bytedeco.pytorch.Tensor}
*
* @param <T>
* the ImgLib2 data type that the {@link Img} can have
* the ImgLib2 data type that the {@link RandomAccessibleInterval} can have
* @param tensor
* the {@link org.bytedeco.pytorch.Tensor} that wants to be converted
* @return the {@link Img} that resulted from the {@link org.bytedeco.pytorch.Tensor}
* @return the {@link RandomAccessibleInterval} that resulted from the {@link org.bytedeco.pytorch.Tensor}
* @throws IllegalArgumentException if the dataype of the {@link org.bytedeco.pytorch.Tensor}
* is not supported
*/
public static <T extends Type<T>> Img<T> build(org.bytedeco.pytorch.Tensor tensor) throws IllegalArgumentException
@SuppressWarnings("unchecked")
public static <T extends Type<T>> RandomAccessibleInterval<T> build(org.bytedeco.pytorch.Tensor tensor) throws IllegalArgumentException
{
if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte)
|| tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) {
return (Img<T>) buildFromTensorByte(tensor);
return (RandomAccessibleInterval<T>) buildFromTensorByte(tensor);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) {
return (Img<T>) buildFromTensorInt(tensor);
return (RandomAccessibleInterval<T>) buildFromTensorInt(tensor);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) {
return (Img<T>) buildFromTensorFloat(tensor);
return (RandomAccessibleInterval<T>) buildFromTensorFloat(tensor);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) {
return (Img<T>) buildFromTensorDouble(tensor);
return (RandomAccessibleInterval<T>) buildFromTensorDouble(tensor);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) {
return (Img<T>) buildFromTensorLong(tensor);
return (RandomAccessibleInterval<T>) buildFromTensorLong(tensor);
} else {
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type());
}
}

/**
* Builds a {@link Img} from a unsigned byte-typed {@link org.bytedeco.pytorch.Tensor}.
* Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link org.bytedeco.pytorch.Tensor}.
*
* @param tensor
* The {@link org.bytedeco.pytorch.Tensor} data is read from.
* @return The {@link Img} built from the tensor of type {@link UnsignedByteType}.
* @return The {@link RandomAccessibleInterval} built from the tensor of type {@link UnsignedByteType}.
*/
public static Img<UnsignedByteType> buildFromTensorUByte(org.bytedeco.pytorch.Tensor tensor) {
long[] tensorShape = tensor.shape();
final ArrayImgFactory<UnsignedByteType> factory = new ArrayImgFactory<>(new UnsignedByteType());
final Img<UnsignedByteType> outputImg = factory.create(tensorShape);
Cursor<UnsignedByteType> tensorCursor = outputImg.cursor();
public static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(org.bytedeco.pytorch.Tensor tensor) {
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
for (long l : tensorShape) {flatSize *= l;}
byte[] flatArr = new byte[(int) flatSize];
tensor.data_ptr_byte().get(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
byte val = flatArr[flatPos];
if (val < 0)
tensorCursor.get().set(256 + (int) val);
else
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<UnsignedByteType> rai = ArrayImgs.unsignedBytes(flatArr, tensorShape);
return Utils.transpose(rai);
}


/**
* Builds a {@link Img} from a signed byte-typed {@link org.bytedeco.pytorch.Tensor}.
* Builds a {@link RandomAccessibleInterval} from a signed byte-typed {@link org.bytedeco.pytorch.Tensor}.
*
* @param tensor
* The {@link org.bytedeco.pytorch.Tensor} data is read from.
* @return The {@link Img} built from the tensor of type {@link ByteType}.
* @return The {@link RandomAccessibleInterval} built from the tensor of type {@link ByteType}.
*/
public static Img<ByteType> buildFromTensorByte(org.bytedeco.pytorch.Tensor tensor)
public static RandomAccessibleInterval<ByteType> buildFromTensorByte(org.bytedeco.pytorch.Tensor tensor)
{
long[] tensorShape = tensor.shape();
final ArrayImgFactory< ByteType > factory = new ArrayImgFactory<>( new ByteType() );
final Img< ByteType > outputImg = (Img<ByteType>) factory.create(tensorShape);
Cursor<ByteType> tensorCursor= outputImg.cursor();
long flatSize = 1;
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
for (long l : tensorShape) {flatSize *= l;}
byte[] flatArr = new byte[(int) flatSize];
tensor.data_ptr_byte().get(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
byte val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<ByteType> rai = ArrayImgs.bytes(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a signed integer-typed {@link org.bytedeco.pytorch.Tensor}.
* Builds a {@link RandomAccessibleInterval} from a signed integer-typed {@link org.bytedeco.pytorch.Tensor}.
*
* @param tensor
* The {@link org.bytedeco.pytorch.Tensor} data is read from.
* @return The {@link Img} built from the tensor of type {@link IntType}.
* @return The {@link RandomAccessibleInterval} built from the tensor of type {@link IntType}.
*/
public static Img<IntType> buildFromTensorInt(org.bytedeco.pytorch.Tensor tensor)
public static RandomAccessibleInterval<IntType> buildFromTensorInt(org.bytedeco.pytorch.Tensor tensor)
{
long[] tensorShape = tensor.shape();
final ArrayImgFactory< IntType > factory = new ArrayImgFactory<>( new IntType() );
final Img< IntType > outputImg = (Img<IntType>) factory.create(tensorShape);
Cursor<IntType> tensorCursor= outputImg.cursor();
long flatSize = 1;
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
for (long l : tensorShape) {flatSize *= l;}
int[] flatArr = new int[(int) flatSize];
tensor.data_ptr_int().get(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
int val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<IntType> rai = ArrayImgs.ints(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a signed float-typed {@link org.bytedeco.pytorch.Tensor}.
* Builds a {@link RandomAccessibleInterval} from a signed float-typed {@link org.bytedeco.pytorch.Tensor}.
*
* @param tensor
* The {@link org.bytedeco.pytorch.Tensor} data is read from.
* @return The {@link Img} built from the tensor of type {@link FloatType}.
* @return The {@link RandomAccessibleInterval} built from the tensor of type {@link FloatType}.
*/
public static Img<FloatType> buildFromTensorFloat(org.bytedeco.pytorch.Tensor tensor)
public static RandomAccessibleInterval<FloatType> buildFromTensorFloat(org.bytedeco.pytorch.Tensor tensor)
{
long[] tensorShape = tensor.shape();
final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > outputImg = (Img<FloatType>) factory.create(tensorShape);
Cursor<FloatType> tensorCursor= outputImg.cursor();
long flatSize = 1;
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
for (long l : tensorShape) {flatSize *= l;}
float[] flatArr = new float[(int) flatSize];
tensor.data_ptr_float().get(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
float val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a signed double-typed {@link org.bytedeco.pytorch.Tensor}.
* Builds a {@link RandomAccessibleInterval} from a signed double-typed {@link org.bytedeco.pytorch.Tensor}.
*
* @param tensor
* The {@link org.bytedeco.pytorch.Tensor} data is read from.
* @return The {@link Img} built from the tensor of type {@link DoubleType}.
* @return The {@link RandomAccessibleInterval} built from the tensor of type {@link DoubleType}.
*/
public static Img<DoubleType> buildFromTensorDouble(org.bytedeco.pytorch.Tensor tensor)
public static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(org.bytedeco.pytorch.Tensor tensor)
{
long[] tensorShape = tensor.shape();
final ArrayImgFactory< DoubleType > factory = new ArrayImgFactory<>( new DoubleType() );
final Img< DoubleType > outputImg = (Img<DoubleType>) factory.create(tensorShape);
Cursor<DoubleType> tensorCursor= outputImg.cursor();
long flatSize = 1;
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
for (long l : tensorShape) {flatSize *= l;}
double[] flatArr = new double[(int) flatSize];
tensor.data_ptr_double().get(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
double val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<DoubleType> rai = ArrayImgs.doubles(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a signed long-typed {@link org.bytedeco.pytorch.Tensor}.
* Builds a {@link RandomAccessibleInterval} from a signed long-typed {@link org.bytedeco.pytorch.Tensor}.
*
* @param tensor
* The {@link org.bytedeco.pytorch.Tensor} data is read from.
* @return The {@link Img} built from the tensor of type {@link LongType}.
* @return The {@link RandomAccessibleInterval} built from the tensor of type {@link LongType}.
*/
public static Img<LongType> buildFromTensorLong(org.bytedeco.pytorch.Tensor tensor)
public static RandomAccessibleInterval<LongType> buildFromTensorLong(org.bytedeco.pytorch.Tensor tensor)
{
long[] tensorShape = tensor.shape();
final ArrayImgFactory< LongType > factory = new ArrayImgFactory<>( new LongType() );
final Img< LongType > outputImg = (Img<LongType>) factory.create(tensorShape);
Cursor<LongType> tensorCursor= outputImg.cursor();
long flatSize = 1;
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
for (long l : tensorShape) {flatSize *= l;}
long[] flatArr = new long[(int) flatSize];
tensor.data_ptr_long().get(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
long val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<LongType> rai = ArrayImgs.longs(flatArr, tensorShape);
return Utils.transpose(rai);
}
}

0 comments on commit 77156fa

Please sign in to comment.