From 77156fab1608083c7911c76dfdb3eb56caa24639 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 26 Oct 2023 18:52:40 +0200 Subject: [PATCH] improve efficiency of creating imglib2 rais from tensors --- .../javacpp/tensor/ImgLib2Builder.java | 176 +++++++----------- 1 file changed, 67 insertions(+), 109 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/ImgLib2Builder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/ImgLib2Builder.java index e7a51f3..8118ec7 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/ImgLib2Builder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/ImgLib2Builder.java @@ -21,10 +21,13 @@ package io.bioimage.modelrunner.pytorch.javacpp.tensor; +import io.bioimage.modelrunner.tensor.Utils; import io.bioimage.modelrunner.utils.IndexingUtils; import net.imglib2.Cursor; +import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.Img; import net.imglib2.img.array.ArrayImgFactory; +import net.imglib2.img.array.ArrayImgs; import net.imglib2.type.Type; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; @@ -34,7 +37,7 @@ import net.imglib2.type.numeric.real.FloatType; /** -* A {@link Img} builder for JAvaCPP Pytorch {@link org.bytedeco.pytorch.Tensor} objects. +* A {@link RandomAccessibleInterval} builder for JAvaCPP Pytorch {@link org.bytedeco.pytorch.Tensor} objects. * Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor}) * from Pytorch {@link org.bytedeco.pytorch.Tensor} * @@ -43,197 +46,152 @@ public class ImgLib2Builder { /** - * Creates a {@link Img} from a given {@link org.bytedeco.pytorch.Tensor} + * Creates a {@link RandomAccessibleInterval} from a given {@link org.bytedeco.pytorch.Tensor} * * @param - * the ImgLib2 data type that the {@link Img} can have + * the ImgLib2 data type that the {@link RandomAccessibleInterval} can have * @param tensor * the {@link org.bytedeco.pytorch.Tensor} that wants to be converted - * @return the {@link Img} that resulted from the {@link org.bytedeco.pytorch.Tensor} + * @return the {@link RandomAccessibleInterval} that resulted from the {@link org.bytedeco.pytorch.Tensor} * @throws IllegalArgumentException if the dataype of the {@link org.bytedeco.pytorch.Tensor} * is not supported */ - public static > Img build(org.bytedeco.pytorch.Tensor tensor) throws IllegalArgumentException + @SuppressWarnings("unchecked") + public static > RandomAccessibleInterval build(org.bytedeco.pytorch.Tensor tensor) throws IllegalArgumentException { if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte) || tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) { - return (Img) buildFromTensorByte(tensor); + return (RandomAccessibleInterval) buildFromTensorByte(tensor); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) { - return (Img) buildFromTensorInt(tensor); + return (RandomAccessibleInterval) buildFromTensorInt(tensor); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) { - return (Img) buildFromTensorFloat(tensor); + return (RandomAccessibleInterval) buildFromTensorFloat(tensor); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) { - return (Img) buildFromTensorDouble(tensor); + return (RandomAccessibleInterval) buildFromTensorDouble(tensor); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) { - return (Img) buildFromTensorLong(tensor); + return (RandomAccessibleInterval) buildFromTensorLong(tensor); } else { throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type()); } } /** - * Builds a {@link Img} from a unsigned byte-typed {@link org.bytedeco.pytorch.Tensor}. + * Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link org.bytedeco.pytorch.Tensor}. * * @param tensor * The {@link org.bytedeco.pytorch.Tensor} data is read from. - * @return The {@link Img} built from the tensor of type {@link UnsignedByteType}. + * @return The {@link RandomAccessibleInterval} built from the tensor of type {@link UnsignedByteType}. */ - public static Img buildFromTensorUByte(org.bytedeco.pytorch.Tensor tensor) { - long[] tensorShape = tensor.shape(); - final ArrayImgFactory factory = new ArrayImgFactory<>(new UnsignedByteType()); - final Img outputImg = factory.create(tensorShape); - Cursor tensorCursor = outputImg.cursor(); + public static RandomAccessibleInterval buildFromTensorUByte(org.bytedeco.pytorch.Tensor tensor) { + long[] arrayShape = tensor.shape(); + long[] tensorShape = new long[arrayShape.length]; + for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; long flatSize = 1; for (long l : tensorShape) {flatSize *= l;} byte[] flatArr = new byte[(int) flatSize]; tensor.data_ptr_byte().get(flatArr); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, - tensorShape); - byte val = flatArr[flatPos]; - if (val < 0) - tensorCursor.get().set(256 + (int) val); - else - tensorCursor.get().set(val); - } - return outputImg; + RandomAccessibleInterval rai = ArrayImgs.unsignedBytes(flatArr, tensorShape); + return Utils.transpose(rai); } /** - * Builds a {@link Img} from a signed byte-typed {@link org.bytedeco.pytorch.Tensor}. + * Builds a {@link RandomAccessibleInterval} from a signed byte-typed {@link org.bytedeco.pytorch.Tensor}. * * @param tensor * The {@link org.bytedeco.pytorch.Tensor} data is read from. - * @return The {@link Img} built from the tensor of type {@link ByteType}. + * @return The {@link RandomAccessibleInterval} built from the tensor of type {@link ByteType}. */ - public static Img buildFromTensorByte(org.bytedeco.pytorch.Tensor tensor) + public static RandomAccessibleInterval buildFromTensorByte(org.bytedeco.pytorch.Tensor tensor) { - long[] tensorShape = tensor.shape(); - final ArrayImgFactory< ByteType > factory = new ArrayImgFactory<>( new ByteType() ); - final Img< ByteType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - long flatSize = 1; + long[] arrayShape = tensor.shape(); + long[] tensorShape = new long[arrayShape.length]; + for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; + long flatSize = 1; for (long l : tensorShape) {flatSize *= l;} byte[] flatArr = new byte[(int) flatSize]; tensor.data_ptr_byte().get(flatArr); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - byte val = flatArr[flatPos]; - tensorCursor.get().set(val); - } - return outputImg; + RandomAccessibleInterval rai = ArrayImgs.bytes(flatArr, tensorShape); + return Utils.transpose(rai); } /** - * Builds a {@link Img} from a signed integer-typed {@link org.bytedeco.pytorch.Tensor}. + * Builds a {@link RandomAccessibleInterval} from a signed integer-typed {@link org.bytedeco.pytorch.Tensor}. * * @param tensor * The {@link org.bytedeco.pytorch.Tensor} data is read from. - * @return The {@link Img} built from the tensor of type {@link IntType}. + * @return The {@link RandomAccessibleInterval} built from the tensor of type {@link IntType}. */ - public static Img buildFromTensorInt(org.bytedeco.pytorch.Tensor tensor) + public static RandomAccessibleInterval buildFromTensorInt(org.bytedeco.pytorch.Tensor tensor) { - long[] tensorShape = tensor.shape(); - final ArrayImgFactory< IntType > factory = new ArrayImgFactory<>( new IntType() ); - final Img< IntType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - long flatSize = 1; + long[] arrayShape = tensor.shape(); + long[] tensorShape = new long[arrayShape.length]; + for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; + long flatSize = 1; for (long l : tensorShape) {flatSize *= l;} int[] flatArr = new int[(int) flatSize]; tensor.data_ptr_int().get(flatArr); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - int val = flatArr[flatPos]; - tensorCursor.get().set(val); - } - return outputImg; + RandomAccessibleInterval rai = ArrayImgs.ints(flatArr, tensorShape); + return Utils.transpose(rai); } /** - * Builds a {@link Img} from a signed float-typed {@link org.bytedeco.pytorch.Tensor}. + * Builds a {@link RandomAccessibleInterval} from a signed float-typed {@link org.bytedeco.pytorch.Tensor}. * * @param tensor * The {@link org.bytedeco.pytorch.Tensor} data is read from. - * @return The {@link Img} built from the tensor of type {@link FloatType}. + * @return The {@link RandomAccessibleInterval} built from the tensor of type {@link FloatType}. */ - public static Img buildFromTensorFloat(org.bytedeco.pytorch.Tensor tensor) + public static RandomAccessibleInterval buildFromTensorFloat(org.bytedeco.pytorch.Tensor tensor) { - long[] tensorShape = tensor.shape(); - final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>( new FloatType() ); - final Img< FloatType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - long flatSize = 1; + long[] arrayShape = tensor.shape(); + long[] tensorShape = new long[arrayShape.length]; + for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; + long flatSize = 1; for (long l : tensorShape) {flatSize *= l;} float[] flatArr = new float[(int) flatSize]; tensor.data_ptr_float().get(flatArr); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - float val = flatArr[flatPos]; - tensorCursor.get().set(val); - } - return outputImg; + RandomAccessibleInterval rai = ArrayImgs.floats(flatArr, tensorShape); + return Utils.transpose(rai); } /** - * Builds a {@link Img} from a signed double-typed {@link org.bytedeco.pytorch.Tensor}. + * Builds a {@link RandomAccessibleInterval} from a signed double-typed {@link org.bytedeco.pytorch.Tensor}. * * @param tensor * The {@link org.bytedeco.pytorch.Tensor} data is read from. - * @return The {@link Img} built from the tensor of type {@link DoubleType}. + * @return The {@link RandomAccessibleInterval} built from the tensor of type {@link DoubleType}. */ - public static Img buildFromTensorDouble(org.bytedeco.pytorch.Tensor tensor) + public static RandomAccessibleInterval buildFromTensorDouble(org.bytedeco.pytorch.Tensor tensor) { - long[] tensorShape = tensor.shape(); - final ArrayImgFactory< DoubleType > factory = new ArrayImgFactory<>( new DoubleType() ); - final Img< DoubleType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - long flatSize = 1; + long[] arrayShape = tensor.shape(); + long[] tensorShape = new long[arrayShape.length]; + for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; + long flatSize = 1; for (long l : tensorShape) {flatSize *= l;} double[] flatArr = new double[(int) flatSize]; tensor.data_ptr_double().get(flatArr); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - double val = flatArr[flatPos]; - tensorCursor.get().set(val); - } - return outputImg; + RandomAccessibleInterval rai = ArrayImgs.doubles(flatArr, tensorShape); + return Utils.transpose(rai); } /** - * Builds a {@link Img} from a signed long-typed {@link org.bytedeco.pytorch.Tensor}. + * Builds a {@link RandomAccessibleInterval} from a signed long-typed {@link org.bytedeco.pytorch.Tensor}. * * @param tensor * The {@link org.bytedeco.pytorch.Tensor} data is read from. - * @return The {@link Img} built from the tensor of type {@link LongType}. + * @return The {@link RandomAccessibleInterval} built from the tensor of type {@link LongType}. */ - public static Img buildFromTensorLong(org.bytedeco.pytorch.Tensor tensor) + public static RandomAccessibleInterval buildFromTensorLong(org.bytedeco.pytorch.Tensor tensor) { - long[] tensorShape = tensor.shape(); - final ArrayImgFactory< LongType > factory = new ArrayImgFactory<>( new LongType() ); - final Img< LongType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - long flatSize = 1; + long[] arrayShape = tensor.shape(); + long[] tensorShape = new long[arrayShape.length]; + for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; + long flatSize = 1; for (long l : tensorShape) {flatSize *= l;} long[] flatArr = new long[(int) flatSize]; tensor.data_ptr_long().get(flatArr); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - long val = flatArr[flatPos]; - tensorCursor.get().set(val); - } - return outputImg; + RandomAccessibleInterval rai = ArrayImgs.longs(flatArr, tensorShape); + return Utils.transpose(rai); } }