diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index d22d1f4..90cccdf 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -29,7 +29,11 @@ import org.bytedeco.pytorch.Tensor; -import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; /** * A utility class that converts {@link Tensor}s into {@link SharedMemoryArray}s for @@ -79,7 +83,7 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws if (CommonUtils.int32Overflows(arrayShape, 1)) throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new ByteType(), false, true); shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); if (PlatformDetection.isWindows()) shma.close(); } @@ -90,7 +94,7 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws if (CommonUtils.int32Overflows(arrayShape, 4)) throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); if (PlatformDetection.isWindows()) shma.close(); } @@ -101,7 +105,7 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw if (CommonUtils.int32Overflows(arrayShape, 4)) throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true); shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); if (PlatformDetection.isWindows()) shma.close(); } @@ -112,7 +116,7 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro if (CommonUtils.int32Overflows(arrayShape, 8)) throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true); shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); if (PlatformDetection.isWindows()) shma.close(); } @@ -123,7 +127,7 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws if (CommonUtils.int32Overflows(arrayShape, 8)) throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true); shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); if (PlatformDetection.isWindows()) shma.close(); }