diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index 91eb8d7..dea53ae 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -25,6 +25,7 @@ import io.bioimage.modelrunner.utils.CommonUtils; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import org.bytedeco.pytorch.Tensor; @@ -80,7 +81,8 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); - tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array()); + ByteBuffer buff = shma.getDataBufferNoHeader(); + buff = tensor.asByteBuffer(); if (PlatformDetection.isWindows()) shma.close(); } @@ -91,7 +93,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); - tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array()); + ByteBuffer buff = shma.getDataBufferNoHeader(); + buff = tensor.asByteBuffer(); if (PlatformDetection.isWindows()) shma.close(); } @@ -102,7 +105,8 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); - tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array()); + ByteBuffer buff = shma.getDataBufferNoHeader(); + buff = tensor.asByteBuffer(); if (PlatformDetection.isWindows()) shma.close(); } @@ -113,7 +117,8 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); - tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array()); + ByteBuffer buff = shma.getDataBufferNoHeader(); + buff = tensor.asByteBuffer(); if (PlatformDetection.isWindows()) shma.close(); } @@ -124,7 +129,8 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); - tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array()); + ByteBuffer buff = shma.getDataBufferNoHeader(); + buff = tensor.asByteBuffer(); if (PlatformDetection.isWindows()) shma.close(); } } diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java index a013289..0204410 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java @@ -85,7 +85,10 @@ private static Tensor buildByte(SharedMemoryArray shmArray) if (!shmArray.isNumpyFormat()) throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); ByteBuffer buff = shmArray.getDataBufferNoHeader(); - Tensor ndarray = Tensor.create(buff.array(), ogShape); + byte[] flat = new byte[buff.capacity()]; + buff.get(flat); + buff.rewind(); + Tensor ndarray = Tensor.create(flat, ogShape); return ndarray; } @@ -99,7 +102,10 @@ private static Tensor buildInt(SharedMemoryArray shmaArray) if (!shmaArray.isNumpyFormat()) throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); ByteBuffer buff = shmaArray.getDataBufferNoHeader(); - Tensor ndarray = Tensor.create(buff.asIntBuffer().array(), ogShape); + int[] flat = new int[buff.capacity() / 4]; + buff.asIntBuffer().get(flat); + buff.rewind(); + Tensor ndarray = Tensor.create(flat, ogShape); return ndarray; } @@ -113,7 +119,10 @@ private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray shmArray) if (!shmArray.isNumpyFormat()) throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); ByteBuffer buff = shmArray.getDataBufferNoHeader(); - Tensor ndarray = Tensor.create(buff.asLongBuffer().array(), ogShape); + long[] flat = new long[buff.capacity() / 8]; + buff.asLongBuffer().get(flat); + buff.rewind(); + Tensor ndarray = Tensor.create(flat, ogShape); return ndarray; } @@ -127,7 +136,10 @@ private static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray shmArray if (!shmArray.isNumpyFormat()) throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); ByteBuffer buff = shmArray.getDataBufferNoHeader(); - Tensor ndarray = Tensor.create(buff.asFloatBuffer().array(), ogShape); + float[] flat = new float[buff.capacity() / 4]; + buff.asFloatBuffer().get(flat); + buff.rewind(); + Tensor ndarray = Tensor.create(flat, ogShape); return ndarray; } @@ -141,7 +153,10 @@ private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray shmArra if (!shmArray.isNumpyFormat()) throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); ByteBuffer buff = shmArray.getDataBufferNoHeader(); - Tensor ndarray = Tensor.create(buff.asDoubleBuffer().array(), ogShape); + double[] flat = new double[buff.capacity() / 8]; + buff.asDoubleBuffer().get(flat); + buff.rewind(); + Tensor ndarray = Tensor.create(flat, ogShape); return ndarray; } }