From a937f90511704cfb117a05fd1ec47683fb9af3f6 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Sat, 23 Nov 2024 13:53:09 +0100 Subject: [PATCH] corerct ultramegabug that was avoiding copying from tensor to shm --- .../pytorch/javacpp/PytorchJavaCPPInterface.java | 3 +++ .../modelrunner/pytorch/javacpp/shm/ShmBuilder.java | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 8bc16e6..87a283a 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -253,8 +253,10 @@ protected void runFromShmas(List inputs, List outputs) throws IO IValue output = model.forward(inputsVector); TensorVector outputTensorVector = null; if (output.isTensorList()) { + System.out.println("entered 1"); outputTensorVector = output.toTensorVector(); } else { + System.out.println("entered 2"); outputTensorVector = new TensorVector(); outputTensorVector.put(output.toTensor()); } @@ -262,6 +264,7 @@ protected void runFromShmas(List inputs, List outputs) throws IO // Fill the agnostic output tensors list with data from the inference result int c = 0; for (String ee : outputs) { + System.out.println(ee); Map decoded = Types.decode(ee); ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY)); } diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index 90cccdf..6d29a1f 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -25,6 +25,7 @@ import io.bioimage.modelrunner.utils.CommonUtils; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import org.bytedeco.pytorch.Tensor; @@ -106,7 +107,11 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true); - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + ByteBuffer byteBuffer = ByteBuffer.allocate((int) (flatSize * Float.BYTES)); + tensor.data_ptr_float().get(byteBuffer.asFloatBuffer().array()); + shma.getDataBufferNoHeader().put(byteBuffer); if (PlatformDetection.isWindows()) shma.close(); }