Skip to content

Commit

Permalink
improve conversion from imglib2 into tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent e8cc492 commit ef6d4d2
Showing 1 changed file with 52 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
package io.bioimage.modelrunner.tensorflow.v2.api020.tensor;


import io.bioimage.modelrunner.utils.IndexingUtils;
import io.bioimage.modelrunner.tensor.Utils;

import net.imglib2.Cursor;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
Expand All @@ -42,7 +41,7 @@
import org.tensorflow.types.family.TType;

/**
* A {@link Img} builder for TensorFlow {@link Tensor} objects.
* A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects.
* Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor})
* from Tensorflow 2 {@link Tensor}
*
Expand All @@ -58,169 +57,131 @@ private ImgLib2Builder()
}

/**
* Creates a {@link Img} from a given {@link Tensor} tensor
* Creates a {@link RandomAccessibleInterval} from a given {@link Tensor} tensor
* @param <T>
* the possible ImgLib2 datatypes of the image
* @param tensor
* The {@link Tensor} tensor data of datatype belonging to {@link TType} is read from.
* @return The {@link Img} built from the {@link TType} tensor.
* @return The {@link RandomAccessibleInterval} built from the {@link TType} tensor.
* @throws IllegalArgumentException If the {@link TType} tensor type is not supported.
*/
@SuppressWarnings("unchecked")
public static <T extends Type<T>> Img<T> build(Tensor<? extends TType> tensor) throws IllegalArgumentException
public static <T extends Type<T>> RandomAccessibleInterval<T> build(Tensor<? extends TType> tensor) throws IllegalArgumentException
{
switch (tensor.dataType().name())
{
case TUint8.NAME:
return (Img<T>) buildFromTensorUByte((Tensor<TUint8>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorUByte((Tensor<TUint8>) tensor);
case TInt32.NAME:
return (Img<T>) buildFromTensorInt((Tensor<TInt32>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorInt((Tensor<TInt32>) tensor);
case TFloat32.NAME:
return (Img<T>) buildFromTensorFloat((Tensor<TFloat32>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorFloat((Tensor<TFloat32>) tensor);
case TFloat64.NAME:
return (Img<T>) buildFromTensorDouble((Tensor<TFloat64>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorDouble((Tensor<TFloat64>) tensor);
case TInt64.NAME:
return (Img<T>) buildFromTensorLong((Tensor<TInt64>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorLong((Tensor<TInt64>) tensor);
default:
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
}
}

/**
* Builds a {@link Img} from a unsigned byte-typed {@link TUint8} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor.
*
* @param tensor
* The {@link TUint8} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link UnsignedByteType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}.
*/
private static Img<UnsignedByteType> buildFromTensorUByte(Tensor<TUint8> tensor)
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(Tensor<TUint8> tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< UnsignedByteType > factory = new ArrayImgFactory<>( new UnsignedByteType() );
final Img< UnsignedByteType > outputImg = factory.create(tensorShape);
Cursor<UnsignedByteType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
byte[] flatArr = new byte[totalSize];
tensor.rawData().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
byte val = flatArr[flatPos];
if (val < 0)
tensorCursor.get().set(256 + (int) val);
else
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<UnsignedByteType> rai = ArrayImgs.unsignedBytes(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned int32-typed {@link TInt32} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor.
*
* @param tensor
* The {@link TInt32} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link IntType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}.
*/
private static Img<IntType> buildFromTensorInt(Tensor<TInt32> tensor)
private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<TInt32> tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< IntType > factory = new ArrayImgFactory<>( new IntType() );
final Img< IntType > outputImg = factory.create(tensorShape);
Cursor<IntType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
int[] flatArr = new int[totalSize];
tensor.rawData().asInts().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
int val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<IntType> rai = ArrayImgs.ints(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned float32-typed {@link TFloat32} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor.
*
* @param tensor
* The {@link TFloat32} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link FloatType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}.
*/
private static Img<FloatType> buildFromTensorFloat(Tensor<TFloat32> tensor)
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<TFloat32> tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > outputImg = factory.create(tensorShape);
Cursor<FloatType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
float[] flatArr = new float[totalSize];
tensor.rawData().asFloats().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
float val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned float64-typed {@link TFloat64} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor.
*
* @param tensor
* The {@link TFloat64} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link DoubleType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}.
*/
private static Img<DoubleType> buildFromTensorDouble(Tensor<TFloat64> tensor)
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor<TFloat64> tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< DoubleType > factory = new ArrayImgFactory<>( new DoubleType() );
final Img< DoubleType > outputImg = factory.create(tensorShape);
Cursor<DoubleType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
double[] flatArr = new double[totalSize];
tensor.rawData().asDoubles().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
double val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<DoubleType> rai = ArrayImgs.doubles(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned int64-typed {@link TInt64} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor.
*
* @param tensor
* The {@link TInt64} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link LongType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}.
*/
private static Img<LongType> buildFromTensorLong(Tensor<TInt64> tensor)
private static RandomAccessibleInterval<LongType> buildFromTensorLong(Tensor<TInt64> tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< LongType > factory = new ArrayImgFactory<>( new LongType() );
final Img< LongType > outputImg = factory.create(tensorShape);
Cursor<LongType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
long[] flatArr = new long[totalSize];
tensor.rawData().asLongs().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
long val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<LongType> rai = ArrayImgs.longs(flatArr, tensorShape);
return Utils.transpose(rai);
}
}

0 comments on commit ef6d4d2

Please sign in to comment.