Skip to content

Commit

Permalink
improve creation of tensors from imgglib2
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent 77156fa commit 993cc5c
Showing 1 changed file with 42 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@
package io.bioimage.modelrunner.pytorch.javacpp.tensor;

import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.IndexingUtils;
import net.imglib2.Cursor;
import io.bioimage.modelrunner.tensor.Utils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;

/**
* Class that manages the creation of JAvaCPP Pytorch tensors
Expand Down Expand Up @@ -102,25 +100,16 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
{
long[] tensorShape = tensor.dimensionsAsLongArray();
Cursor<ByteType> tensorCursor;
if (tensor instanceof IntervalView)
tensorCursor = ((IntervalView<ByteType>) tensor).cursor();
else if (tensor instanceof Img)
tensorCursor = ((Img<ByteType>) tensor).cursor();
else
throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ "to be an instance of " + Img.class + " or " + IntervalView.class);
long flatSize = 1;
for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;}
byte[] flatArr = new byte[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
byte val = tensorCursor.get().getByte();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final byte[] flatArr = new byte[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
return ndarray;
}
Expand All @@ -135,25 +124,16 @@ else if (tensor instanceof Img)
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
{
long[] tensorShape = tensor.dimensionsAsLongArray();
Cursor<IntType> tensorCursor;
if (tensor instanceof IntervalView)
tensorCursor = ((IntervalView<IntType>) tensor).cursor();
else if (tensor instanceof Img)
tensorCursor = ((Img<IntType>) tensor).cursor();
else
throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ "to be an instance of " + Img.class + " or " + IntervalView.class);
long flatSize = 1;
for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;}
int[] flatArr = new int[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
int val = tensorCursor.get().getInteger();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final int[] flatArr = new int[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
return ndarray;
}
Expand All @@ -168,25 +148,16 @@ else if (tensor instanceof Img)
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
{
long[] tensorShape = tensor.dimensionsAsLongArray();
Cursor<FloatType> tensorCursor;
if (tensor instanceof IntervalView)
tensorCursor = ((IntervalView<FloatType>) tensor).cursor();
else if (tensor instanceof Img)
tensorCursor = ((Img<FloatType>) tensor).cursor();
else
throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ "to be an instance of " + Img.class + " or " + IntervalView.class);
long flatSize = 1;
for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;}
float[] flatArr = new float[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
float val = tensorCursor.get().getRealFloat();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final float[] flatArr = new float[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
return ndarray;
}
Expand All @@ -201,25 +172,16 @@ else if (tensor instanceof Img)
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
{
long[] tensorShape = tensor.dimensionsAsLongArray();
Cursor<DoubleType> tensorCursor;
if (tensor instanceof IntervalView)
tensorCursor = ((IntervalView<DoubleType>) tensor).cursor();
else if (tensor instanceof Img)
tensorCursor = ((Img<DoubleType>) tensor).cursor();
else
throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ "to be an instance of " + Img.class + " or " + IntervalView.class);
long flatSize = 1;
for (long dd : tensor.dimensionsAsLongArray()) { flatSize *= dd;}
double[] flatArr = new double[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
double val = tensorCursor.get().getRealDouble();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final double[] flatArr = new double[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
return ndarray;
}
Expand Down

0 comments on commit 993cc5c

Please sign in to comment.