Skip to content

Commit

Permalink
improve robustness creating tf tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 28, 2023
1 parent 23a84b4 commit e8faabd
Showing 1 changed file with 56 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package io.bioimage.modelrunner.tensorflow.v2.api020.tensor;

import io.bioimage.modelrunner.tensor.Utils;

import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.img.Img;
Expand All @@ -33,6 +33,10 @@
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

import java.nio.ByteBuffer;
import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
Expand All @@ -47,6 +51,7 @@
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.UInt8;
import org.tensorflow.types.family.TType;

/**
Expand Down Expand Up @@ -132,16 +137,24 @@ private static Tensor<TUint8> buildUByte(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< UnsignedByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final byte[] flatArr = new byte[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<UnsignedByteType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().getByte();
}
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
Tensor<TUint8> ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -161,16 +174,24 @@ private static Tensor<TInt32> buildInt(
RandomAccessibleInterval<IntType> tensor) throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final int[] flatArr = new int[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<IntType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
IntDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
Tensor<TInt32> ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -191,16 +212,24 @@ private static Tensor<TInt64> buildLong(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< LongType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final long[] flatArr = new long[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<LongType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
LongDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
Tensor<TInt64> ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -221,16 +250,24 @@ private static Tensor<TFloat32> buildFloat(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final float[] flatArr = new float[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<FloatType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
Tensor<TFloat32> ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -251,16 +288,24 @@ private static Tensor<TFloat64> buildDouble(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final double[] flatArr = new double[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<DoubleType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
Tensor<TFloat64> ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand Down

0 comments on commit e8faabd

Please sign in to comment.