From 993cc5c89db1b2fb0684b6c566afa1c1f549659b Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 26 Oct 2023 19:15:41 +0200 Subject: [PATCH] improve creation of tensors from imgglib2 --- .../javacpp/tensor/JavaCPPTensorBuilder.java | 122 ++++++------------ 1 file changed, 42 insertions(+), 80 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java index 321f289..e6153a2 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java @@ -21,17 +21,15 @@ package io.bioimage.modelrunner.pytorch.javacpp.tensor; import io.bioimage.modelrunner.tensor.Tensor; -import io.bioimage.modelrunner.utils.IndexingUtils; -import net.imglib2.Cursor; +import io.bioimage.modelrunner.tensor.Utils; import net.imglib2.RandomAccessibleInterval; -import net.imglib2.img.Img; +import net.imglib2.blocks.PrimitiveBlocks; import net.imglib2.type.Type; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Util; -import net.imglib2.view.IntervalView; /** * Class that manages the creation of JAvaCPP Pytorch tensors @@ -102,25 +100,16 @@ public static > org.bytedeco.pytorch.Tensor build(RandomAccess */ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval tensor) { - long[] tensorShape = tensor.dimensionsAsLongArray(); - Cursor tensorCursor; - if (tensor instanceof IntervalView) - tensorCursor = ((IntervalView) tensor).cursor(); - else if (tensor instanceof Img) - tensorCursor = ((Img) tensor).cursor(); - else - throw new IllegalArgumentException("The data of the " + Tensor.class + " has " - + "to be an instance of " + Img.class + " or " + IntervalView.class); - long flatSize = 1; - for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;} - byte[] flatArr = new byte[(int) flatSize]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - byte val = tensorCursor.get().getByte(); - flatArr[flatPos] = val; - } + tensor = Utils.transpose(tensor); + PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor ); + long[] tensorShape = tensor.dimensionsAsLongArray(); + int size = 1; + for (long ll : tensorShape) size *= ll; + final byte[] flatArr = new byte[size]; + int[] sArr = new int[tensorShape.length]; + for (int i = 0; i < sArr.length; i ++) + sArr[i] = (int) tensorShape[i]; + blocks.copy( new long[tensorShape.length], flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); return ndarray; } @@ -135,25 +124,16 @@ else if (tensor instanceof Img) */ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval tensor) { - long[] tensorShape = tensor.dimensionsAsLongArray(); - Cursor tensorCursor; - if (tensor instanceof IntervalView) - tensorCursor = ((IntervalView) tensor).cursor(); - else if (tensor instanceof Img) - tensorCursor = ((Img) tensor).cursor(); - else - throw new IllegalArgumentException("The data of the " + Tensor.class + " has " - + "to be an instance of " + Img.class + " or " + IntervalView.class); - long flatSize = 1; - for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;} - int[] flatArr = new int[(int) flatSize]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - int val = tensorCursor.get().getInteger(); - flatArr[flatPos] = val; - } + tensor = Utils.transpose(tensor); + PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor ); + long[] tensorShape = tensor.dimensionsAsLongArray(); + int size = 1; + for (long ll : tensorShape) size *= ll; + final int[] flatArr = new int[size]; + int[] sArr = new int[tensorShape.length]; + for (int i = 0; i < sArr.length; i ++) + sArr[i] = (int) tensorShape[i]; + blocks.copy( new long[tensorShape.length], flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); return ndarray; } @@ -168,25 +148,16 @@ else if (tensor instanceof Img) */ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval tensor) { - long[] tensorShape = tensor.dimensionsAsLongArray(); - Cursor tensorCursor; - if (tensor instanceof IntervalView) - tensorCursor = ((IntervalView) tensor).cursor(); - else if (tensor instanceof Img) - tensorCursor = ((Img) tensor).cursor(); - else - throw new IllegalArgumentException("The data of the " + Tensor.class + " has " - + "to be an instance of " + Img.class + " or " + IntervalView.class); - long flatSize = 1; - for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;} - float[] flatArr = new float[(int) flatSize]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - float val = tensorCursor.get().getRealFloat(); - flatArr[flatPos] = val; - } + tensor = Utils.transpose(tensor); + PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor ); + long[] tensorShape = tensor.dimensionsAsLongArray(); + int size = 1; + for (long ll : tensorShape) size *= ll; + final float[] flatArr = new float[size]; + int[] sArr = new int[tensorShape.length]; + for (int i = 0; i < sArr.length; i ++) + sArr[i] = (int) tensorShape[i]; + blocks.copy( new long[tensorShape.length], flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); return ndarray; } @@ -201,25 +172,16 @@ else if (tensor instanceof Img) */ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval tensor) { - long[] tensorShape = tensor.dimensionsAsLongArray(); - Cursor tensorCursor; - if (tensor instanceof IntervalView) - tensorCursor = ((IntervalView) tensor).cursor(); - else if (tensor instanceof Img) - tensorCursor = ((Img) tensor).cursor(); - else - throw new IllegalArgumentException("The data of the " + Tensor.class + " has " - + "to be an instance of " + Img.class + " or " + IntervalView.class); - long flatSize = 1; - for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;} - double[] flatArr = new double[(int) flatSize]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - long[] cursorPos = tensorCursor.positionAsLongArray(); - int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); - double val = tensorCursor.get().getRealDouble(); - flatArr[flatPos] = val; - } + tensor = Utils.transpose(tensor); + PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor ); + long[] tensorShape = tensor.dimensionsAsLongArray(); + int size = 1; + for (long ll : tensorShape) size *= ll; + final double[] flatArr = new double[size]; + int[] sArr = new int[tensorShape.length]; + for (int i = 0; i < sArr.length; i ++) + sArr[i] = (int) tensorShape[i]; + blocks.copy( new long[tensorShape.length], flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); return ndarray; }