
Commit

correct support for shm to tensor creation
carlosuc3m committed Oct 3, 2024
1 parent 050116e commit 41c0519
Showing 2 changed files with 31 additions and 10 deletions.
File 1 of 2 (writes pytorch output tensors into shared-memory arrays):
@@ -25,6 +25,7 @@
import io.bioimage.modelrunner.utils.CommonUtils;

import java.io.IOException;
+ import java.nio.ByteBuffer;
import java.util.Arrays;

import org.bytedeco.pytorch.Tensor;
@@ -80,7 +81,8 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
- tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
+ ByteBuffer buff = shma.getDataBufferNoHeader();
+ buff = tensor.asByteBuffer();
if (PlatformDetection.isWindows()) shma.close();
}

@@ -91,7 +93,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
- tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
+ ByteBuffer buff = shma.getDataBufferNoHeader();
+ buff = tensor.asByteBuffer();
if (PlatformDetection.isWindows()) shma.close();
}

@@ -102,7 +105,8 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
- tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
+ ByteBuffer buff = shma.getDataBufferNoHeader();
+ buff = tensor.asByteBuffer();
if (PlatformDetection.isWindows()) shma.close();
}

@@ -113,7 +117,8 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
- tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
+ ByteBuffer buff = shma.getDataBufferNoHeader();
+ buff = tensor.asByteBuffer();
if (PlatformDetection.isWindows()) shma.close();
}

@@ -124,7 +129,8 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
- tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
+ ByteBuffer buff = shma.getDataBufferNoHeader();
+ buff = tensor.asByteBuffer();
if (PlatformDetection.isWindows()) shma.close();
}
}
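
For reference, the buffer returned by getDataBufferNoHeader() is normally a direct ByteBuffer with no backing array, so the tensor bytes have to be copied into it explicitly. A minimal sketch of that copy, assuming a contiguous CPU tensor whose byte size matches the buffer capacity (variable names taken from the methods above):

    ByteBuffer buff = shma.getDataBufferNoHeader();
    byte[] flat = new byte[buff.capacity()];
    tensor.data_ptr_byte().get(flat);   // bulk-read the tensor's raw bytes
    buff.put(flat);                     // write them into the shared-memory segment
    buff.rewind();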
File 2 of 2 (builds pytorch tensors from shared-memory arrays):
@@ -85,7 +85,10 @@ private static Tensor buildByte(SharedMemoryArray shmArray)
if (!shmArray.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = shmArray.getDataBufferNoHeader();
- Tensor ndarray = Tensor.create(buff.array(), ogShape);
+ byte[] flat = new byte[buff.capacity()];
+ buff.get(flat);
+ buff.rewind();
+ Tensor ndarray = Tensor.create(flat, ogShape);
return ndarray;
}

@@ -99,7 +102,10 @@ private static Tensor buildInt(SharedMemoryArray shmaArray)
if (!shmaArray.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = shmaArray.getDataBufferNoHeader();
- Tensor ndarray = Tensor.create(buff.asIntBuffer().array(), ogShape);
+ int[] flat = new int[buff.capacity() / 4];
+ buff.asIntBuffer().get(flat);
+ buff.rewind();
+ Tensor ndarray = Tensor.create(flat, ogShape);
return ndarray;
}

@@ -113,7 +119,10 @@ private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray shmArray)
if (!shmArray.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = shmArray.getDataBufferNoHeader();
- Tensor ndarray = Tensor.create(buff.asLongBuffer().array(), ogShape);
+ long[] flat = new long[buff.capacity() / 8];
+ buff.asLongBuffer().get(flat);
+ buff.rewind();
+ Tensor ndarray = Tensor.create(flat, ogShape);
return ndarray;
}

@@ -127,7 +136,10 @@ private static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray shmArray)
if (!shmArray.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = shmArray.getDataBufferNoHeader();
- Tensor ndarray = Tensor.create(buff.asFloatBuffer().array(), ogShape);
+ float[] flat = new float[buff.capacity() / 4];
+ buff.asFloatBuffer().get(flat);
+ buff.rewind();
+ Tensor ndarray = Tensor.create(flat, ogShape);
return ndarray;
}

@@ -141,7 +153,10 @@ private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray shmArra)
if (!shmArray.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = shmArray.getDataBufferNoHeader();
- Tensor ndarray = Tensor.create(buff.asDoubleBuffer().array(), ogShape);
+ double[] flat = new double[buff.capacity() / 8];
+ buff.asDoubleBuffer().get(flat);
+ buff.rewind();
+ Tensor ndarray = Tensor.create(flat, ogShape);
return ndarray;
}
}
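
The flat-array copies above work around a java.nio constraint: a direct ByteBuffer, which is what a shared-memory segment is typically exposed as, has no backing array, so calling array() (or asFloatBuffer().array(), etc.) throws UnsupportedOperationException, whereas a bulk get() into a plain array works for any buffer. A small self-contained illustration of the same pattern (shape and sizes are arbitrary, using the java.nio.ByteBuffer and org.bytedeco.pytorch.Tensor types already imported above):

    ByteBuffer direct = ByteBuffer.allocateDirect(16);   // direct buffer, like a shm mapping
    boolean backed = direct.hasArray();                   // false: no on-heap array to return
    // direct.array() would throw UnsupportedOperationException here
    float[] flat = new float[direct.capacity() / 4];
    direct.asFloatBuffer().get(flat);                     // bulk copy works regardless of backing
    long[] shape = {2, 2};
    Tensor ndarray = Tensor.create(flat, shape);          // same call shape as buildFloat above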
