From d7febcb64c8bc9a10324b10bdd300f2dc0f528d4 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 28 Nov 2023 14:23:42 +0100 Subject: [PATCH] increase robustness creating pytorch tensors --- .../javacpp/tensor/JavaCPPTensorBuilder.java | 60 +++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java index d964d85..e515f68 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java @@ -20,16 +20,22 @@ */ package io.bioimage.modelrunner.pytorch.javacpp.tensor; +import java.nio.ByteBuffer; +import java.util.Arrays; + import io.bioimage.modelrunner.tensor.Tensor; import io.bioimage.modelrunner.tensor.Utils; +import net.imglib2.Cursor; import net.imglib2.RandomAccessibleInterval; import net.imglib2.blocks.PrimitiveBlocks; import net.imglib2.type.Type; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.UnsignedByteType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Util; +import net.imglib2.view.Views; /** * Class that manages the creation of JAvaCPP Pytorch tensors @@ -101,8 +107,10 @@ public static > org.bytedeco.pytorch.Tensor build(RandomAccess private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval tensor) { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -110,7 +118,13 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -126,8 +140,10 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval tensor) { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -135,8 +151,14 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -151,8 +173,10 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval tensor) { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -160,8 +184,14 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -176,8 +206,10 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval tensor) { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -185,8 +217,14 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } }