diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java index 9e46072..8fbcdb6 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java @@ -141,7 +141,7 @@ private static Tensor buildUByte( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); Tensor ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer); return ndarray; @@ -170,7 +170,7 @@ private static Tensor buildInt( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); IntDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); Tensor ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray; @@ -200,7 +200,7 @@ private static Tensor buildLong( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); LongDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); Tensor ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray; @@ -230,7 +230,7 @@ private static Tensor buildFloat( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); Tensor ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray; @@ -260,7 +260,7 @@ private static Tensor buildDouble( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); Tensor ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray;