Skip to content

Commit

Permalink
stable working version
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 23, 2024
1 parent e6f5bd5 commit f50a2ad
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,15 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
IValue output = model.forward(inputsVector);
TensorVector outputTensorVector = null;
if (output.isTensorList()) {
System.out.println("entered 1");
outputTensorVector = output.toTensorVector();
} else {
System.out.println("entered 2");
outputTensorVector = new TensorVector();
outputTensorVector.put(output.toTensor());
}

// Fill the agnostic output tensors list with data from the inference result
int c = 0;
for (String ee : outputs) {
System.out.println(ee);
Map<String, Object> decoded = Types.decode(ee);
ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@
*/
package io.bioimage.modelrunner.pytorch.javacpp.shm;

import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;

import org.bytedeco.pytorch.Tensor;

import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
Expand Down Expand Up @@ -88,7 +89,14 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new ByteType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
byte[] flat = new byte[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize));
tensor.data_ptr_byte().get(flat);
byteBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -99,8 +107,15 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
RandomAccessibleInterval<?> rai = shma.getSharedRAI();
rai = ImgLib2Builder.build(tensor);
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
int[] flat = new int[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES));
IntBuffer floatBuffer = byteBuffer.asIntBuffer();
tensor.data_ptr_int().get(flat);
floatBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}

Expand Down Expand Up @@ -130,7 +145,15 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
double[] flat = new double[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES));
DoubleBuffer floatBuffer = byteBuffer.asDoubleBuffer();
tensor.data_ptr_double().get(flat);
floatBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -141,7 +164,15 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
long[] flat = new long[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES));
LongBuffer floatBuffer = byteBuffer.asLongBuffer();
tensor.data_ptr_long().get(flat);
floatBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}
}

0 comments on commit f50a2ad

Please sign in to comment.