Skip to content

Commit

Permalink
improve robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Mar 25, 2024
1 parent 57ec6bf commit f0c2e91
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


import io.bioimage.modelrunner.tensor.Utils;

import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.Type;
Expand All @@ -32,6 +32,8 @@
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;

import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
Expand Down Expand Up @@ -95,6 +97,9 @@ public static <T extends Type<T>> RandomAccessibleInterval<T> build(Tensor<? ext
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(Tensor<TUint8> tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double ubyte tensor supported: " + Integer.MAX_VALUE / 1);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -115,6 +120,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<TInt32> tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -135,6 +143,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<TInt3
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<TFloat32> tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -155,6 +166,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<T
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor<TFloat64> tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -175,6 +189,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor
private static RandomAccessibleInterval<LongType> buildFromTensorLong(Tensor<TInt64> tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ private static Tensor<TUint8> buildUByte(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 1))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -172,9 +172,9 @@ private static Tensor<TInt32> buildInt(
RandomAccessibleInterval<IntType> tensor) throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 4))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -210,9 +210,9 @@ private static Tensor<TInt64> buildLong(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 8))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per long tensor supported: " + Integer.MAX_VALUE / 8);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -248,9 +248,9 @@ private static Tensor<TFloat32> buildFloat(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 4))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -286,9 +286,9 @@ private static Tensor<TFloat64> buildDouble(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 8))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down

0 comments on commit f0c2e91

Please sign in to comment.