From f50a2ad72d21ba8a6be873410096591ed86cc2c0 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Sat, 23 Nov 2024 14:37:21 +0100 Subject: [PATCH] stable working version --- .../javacpp/PytorchJavaCPPInterface.java | 3 -- .../pytorch/javacpp/shm/ShmBuilder.java | 45 ++++++++++++++++--- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 87a283a..8bc16e6 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -253,10 +253,8 @@ protected void runFromShmas(List inputs, List outputs) throws IO IValue output = model.forward(inputsVector); TensorVector outputTensorVector = null; if (output.isTensorList()) { - System.out.println("entered 1"); outputTensorVector = output.toTensorVector(); } else { - System.out.println("entered 2"); outputTensorVector = new TensorVector(); outputTensorVector.put(output.toTensor()); } @@ -264,7 +262,6 @@ protected void runFromShmas(List inputs, List outputs) throws IO // Fill the agnostic output tensors list with data from the inference result int c = 0; for (String ee : outputs) { - System.out.println(ee); Map decoded = Types.decode(ee); ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY)); } diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index c31281c..8165914 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -20,21 +20,22 @@ */ package io.bioimage.modelrunner.pytorch.javacpp.shm; -import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder; import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.utils.CommonUtils; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; import java.util.Arrays; import org.bytedeco.pytorch.Tensor; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.integer.LongType; -import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; @@ -88,7 +89,14 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new ByteType(), false, true); - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + byte[] flat = new byte[(int) flatSize]; + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize)); + tensor.data_ptr_byte().get(flat); + byteBuffer.put(flat); + byteBuffer.rewind(); + shma.getDataBufferNoHeader().put(byteBuffer); if (PlatformDetection.isWindows()) shma.close(); } @@ -99,8 +107,15 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); - RandomAccessibleInterval rai = shma.getSharedRAI(); - rai = ImgLib2Builder.build(tensor); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + int[] flat = new int[(int) flatSize]; + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES)); + IntBuffer floatBuffer = byteBuffer.asIntBuffer(); + tensor.data_ptr_int().get(flat); + floatBuffer.put(flat); + byteBuffer.rewind(); + shma.getDataBufferNoHeader().put(byteBuffer); if (PlatformDetection.isWindows()) shma.close(); } @@ -130,7 +145,15 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true); - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + double[] flat = new double[(int) flatSize]; + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES)); + DoubleBuffer floatBuffer = byteBuffer.asDoubleBuffer(); + tensor.data_ptr_double().get(flat); + floatBuffer.put(flat); + byteBuffer.rewind(); + shma.getDataBufferNoHeader().put(byteBuffer); if (PlatformDetection.isWindows()) shma.close(); } @@ -141,7 +164,15 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true); - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + long[] flat = new long[(int) flatSize]; + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES)); + LongBuffer floatBuffer = byteBuffer.asLongBuffer(); + tensor.data_ptr_long().get(flat); + floatBuffer.put(flat); + byteBuffer.rewind(); + shma.getDataBufferNoHeader().put(byteBuffer); if (PlatformDetection.isWindows()) shma.close(); } }