diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/ImgLib2Builder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/ImgLib2Builder.java index 4ecc2b7..11e6148 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/ImgLib2Builder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/ImgLib2Builder.java @@ -22,7 +22,7 @@ import io.bioimage.modelrunner.tensor.Utils; - +import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.array.ArrayImgs; import net.imglib2.type.Type; @@ -32,6 +32,8 @@ import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; +import java.util.Arrays; + import org.tensorflow.Tensor; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -95,6 +97,9 @@ public static > RandomAccessibleInterval build(Tensor buildFromTensorUByte(Tensor tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 1)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per double ubyte tensor supported: " + Integer.MAX_VALUE / 1); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -115,6 +120,9 @@ private static RandomAccessibleInterval buildFromTensorUByte(T private static RandomAccessibleInterval buildFromTensorInt(Tensor tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -135,6 +143,9 @@ private static RandomAccessibleInterval buildFromTensorInt(Tensor buildFromTensorFloat(Tensor tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -155,6 +166,9 @@ private static RandomAccessibleInterval buildFromTensorFloat(Tensor buildFromTensorDouble(Tensor tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -175,6 +189,9 @@ private static RandomAccessibleInterval buildFromTensorDouble(Tensor private static RandomAccessibleInterval buildFromTensorLong(Tensor tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java index a34696d..680b290 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java @@ -135,9 +135,9 @@ private static Tensor buildUByte( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 1)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -172,9 +172,9 @@ private static Tensor buildInt( RandomAccessibleInterval tensor) throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 4)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -210,9 +210,9 @@ private static Tensor buildLong( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 8)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per long tensor supported: " + Integer.MAX_VALUE / 8); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -248,9 +248,9 @@ private static Tensor buildFloat( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 4)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -286,9 +286,9 @@ private static Tensor buildDouble( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 8)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1;