diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index aff776d..0a1b16b 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -111,14 +111,15 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true); - /*long flatSize = 1; + long flatSize = 1; for (long l : arrayShape) {flatSize *= l;} - ByteBuffer byteBuffer = ByteBuffer.allocate((int) (flatSize * Float.BYTES)); - tensor.data_ptr_float().get(byteBuffer.asFloatBuffer().array()); + float[] flat = new float[(int) flatSize]; + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Float.BYTES)); + FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); + tensor.data_ptr_float().get(flat); + floatBuffer.put(flatSize); + byteBuffer.rewind(); shma.getDataBufferNoHeader().put(byteBuffer); - */ - RandomAccessibleInterval rai = shma.getSharedRAI(); - rai = ImgLib2Builder.build(tensor); if (PlatformDetection.isWindows()) shma.close(); }