increase robustness creating pytorch tensors
carlosuc3m committed Nov 28, 2023
1 parent 0b09e00 commit d7febcb
Showing 1 changed file with 49 additions and 11 deletions.
@@ -20,16 +20,22 @@
*/
package io.bioimage.modelrunner.pytorch.javacpp.tensor;

+import java.nio.ByteBuffer;
+import java.util.Arrays;

import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.Utils;
+import io.bioimage.modelrunner.utils.CommonUtils;
+import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
+import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
+import net.imglib2.view.Views;

/**
* Class that manages the creation of JavaCPP Pytorch tensors
@@ -101,16 +107,24 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
{
    long[] ogShape = tensor.dimensionsAsLongArray();
+   if (CommonUtils.int32Overflows(ogShape))
+       throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+           + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
    tensor = Utils.transpose(tensor);
-   PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
    long[] tensorShape = tensor.dimensionsAsLongArray();
    int size = 1;
    for (long ll : tensorShape) size *= ll;
    final byte[] flatArr = new byte[size];
    int[] sArr = new int[tensorShape.length];
    for (int i = 0; i < sArr.length; i ++)
        sArr[i] = (int) tensorShape[i];
-   blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
+
+   Cursor<ByteType> cursor = Views.flatIterable(tensor).cursor();
+   int i = 0;
+   while (cursor.hasNext()) {
+       cursor.fwd();
+       flatArr[i ++] = cursor.get().get();
+   }
    org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
    return ndarray;
}
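Why the new guard: Java arrays are indexed with int, so the flat copy above can never hold more than Integer.MAX_VALUE elements, and the check fails fast with a clear message instead of letting the allocation of flatArr break later. A minimal sketch of what a check like CommonUtils.int32Overflows presumably does (its real implementation is not part of this diff, so treat this as an assumption):

    // Assumption: returns true when the element count of the given shape
    // cannot fit in a Java int, i.e. when a flat primitive array of that
    // size cannot be allocated.
    public static boolean int32Overflows(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            count *= dim;
            // The product is already too large for an int index (the second
            // test also catches overflow of the long product itself).
            if (count > Integer.MAX_VALUE || count < 0)
                return true;
        }
        return false;
    }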
@@ -126,17 +140,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
{
    long[] ogShape = tensor.dimensionsAsLongArray();
+   if (CommonUtils.int32Overflows(ogShape))
+       throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+           + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
    tensor = Utils.transpose(tensor);
-   PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
    long[] tensorShape = tensor.dimensionsAsLongArray();
    int size = 1;
    for (long ll : tensorShape) size *= ll;
    final int[] flatArr = new int[size];
    int[] sArr = new int[tensorShape.length];
    for (int i = 0; i < sArr.length; i ++)
        sArr[i] = (int) tensorShape[i];
-   blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
-   org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
+
+   Cursor<IntType> cursor = Views.flatIterable(tensor).cursor();
+   int i = 0;
+   while (cursor.hasNext()) {
+       cursor.fwd();
+       flatArr[i ++] = cursor.get().get();
+   }
+   org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
    return ndarray;
}
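The robustness gain comes from dropping PrimitiveBlocks: PrimitiveBlocks.of(...) only handles images it can decompose into primitive blocks and can fail for other RandomAccessibleInterval implementations, whereas Views.flatIterable(...).cursor() works for any RandomAccessibleInterval, at the cost of a slower element-by-element walk. A self-contained sketch of the same copy pattern, assuming only ImgLib2 and the bytedeco PyTorch bindings on the classpath:

    import net.imglib2.Cursor;
    import net.imglib2.img.Img;
    import net.imglib2.img.array.ArrayImgs;
    import net.imglib2.type.numeric.integer.IntType;
    import net.imglib2.view.Views;

    public class FlattenSketch {
        public static void main(String[] args) {
            // Small 2x3 test image; any RandomAccessibleInterval<IntType> would do.
            Img<IntType> img = ArrayImgs.ints(2, 3);
            int v = 0;
            for (IntType px : img) px.set(v++);

            // Same pattern as the diff: walk the image in flat order...
            int[] flat = new int[6];
            Cursor<IntType> cursor = Views.flatIterable(img).cursor();
            int i = 0;
            while (cursor.hasNext()) {
                cursor.fwd();
                flat[i++] = cursor.get().get();
            }
            // ...then hand the flat array and the shape to the JavaCPP binding,
            // exactly as the builders here do with create(flatArr, ogShape).
            org.bytedeco.pytorch.Tensor t =
                org.bytedeco.pytorch.Tensor.create(flat, new long[] { 2, 3 });
            System.out.println("elements copied: " + flat.length);
        }
    }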

@@ -151,17 +173,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
{
    long[] ogShape = tensor.dimensionsAsLongArray();
+   if (CommonUtils.int32Overflows(ogShape))
+       throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+           + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
    tensor = Utils.transpose(tensor);
-   PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
    long[] tensorShape = tensor.dimensionsAsLongArray();
    int size = 1;
    for (long ll : tensorShape) size *= ll;
    final float[] flatArr = new float[size];
    int[] sArr = new int[tensorShape.length];
    for (int i = 0; i < sArr.length; i ++)
        sArr[i] = (int) tensorShape[i];
-   blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
-   org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
+
+   Cursor<FloatType> cursor = Views.flatIterable(tensor).cursor();
+   int i = 0;
+   while (cursor.hasNext()) {
+       cursor.fwd();
+       flatArr[i ++] = cursor.get().get();
+   }
+   org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
    return ndarray;
}
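A subtlety the transpose hides: ImgLib2's flat iteration order varies the first dimension fastest (F order), while PyTorch tensors are C order, with the last dimension fastest. That is presumably why each builder calls Utils.transpose(tensor) before flattening yet passes the untransposed ogShape to create. A small sketch of the two orders, assuming Utils.transpose reverses the axis order (for a 2-d image that is equivalent to Views.permute(img, 0, 1)):

    import net.imglib2.RandomAccessibleInterval;
    import net.imglib2.img.Img;
    import net.imglib2.img.array.ArrayImgs;
    import net.imglib2.type.numeric.real.FloatType;
    import net.imglib2.view.Views;

    public class OrderSketch {
        public static void main(String[] args) {
            // dims {2, 3}: in ImgLib2 the first dimension (size 2) varies fastest.
            Img<FloatType> img = ArrayImgs.floats(2, 3);
            int v = 0;
            for (FloatType px : img) px.set(v++); // fills 0..5 in flat order

            printFlat(img);                       // 0 1 2 3 4 5 (F order)
            printFlat(Views.permute(img, 0, 1));  // 0 2 4 1 3 5 (C order for shape {2, 3})
        }

        static void printFlat(RandomAccessibleInterval<FloatType> rai) {
            StringBuilder sb = new StringBuilder();
            for (FloatType px : Views.flatIterable(rai))
                sb.append((int) px.get()).append(' ');
            System.out.println(sb);
        }
    }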

@@ -176,17 +206,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
{
    long[] ogShape = tensor.dimensionsAsLongArray();
+   if (CommonUtils.int32Overflows(ogShape))
+       throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+           + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
    tensor = Utils.transpose(tensor);
-   PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
    long[] tensorShape = tensor.dimensionsAsLongArray();
    int size = 1;
    for (long ll : tensorShape) size *= ll;
    final double[] flatArr = new double[size];
    int[] sArr = new int[tensorShape.length];
    for (int i = 0; i < sArr.length; i ++)
        sArr[i] = (int) tensorShape[i];
-   blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
-   org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
+
+   Cursor<DoubleType> cursor = Views.flatIterable(tensor).cursor();
+   int i = 0;
+   while (cursor.hasNext()) {
+       cursor.fwd();
+       flatArr[i ++] = cursor.get().get();
+   }
+   org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
    return ndarray;
}
}
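With the four private builders fixed, the public build(...) entry point shown in the second hunk still dispatches on the pixel type, so caller code is unchanged. A usage sketch; the enclosing class name is not visible in this diff, so JavaCPPTensorBuilder below is a placeholder:

    import net.imglib2.RandomAccessibleInterval;
    import net.imglib2.img.array.ArrayImgs;
    import net.imglib2.type.numeric.real.FloatType;

    public class BuildSketch {
        public static void main(String[] args) {
            // After this commit any RandomAccessibleInterval is accepted, not only
            // images that PrimitiveBlocks could decompose.
            RandomAccessibleInterval<FloatType> img = ArrayImgs.floats(1, 3, 64, 64);
            // Placeholder name for the edited class in
            // io.bioimage.modelrunner.pytorch.javacpp.tensor.
            org.bytedeco.pytorch.Tensor t = JavaCPPTensorBuilder.build(img);
            System.out.println("built tensor from a " + img.numDimensions() + "-d image");
        }
    }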
