Commit

improve interprocess communication
carlosuc3m committed Oct 2, 2024
1 parent 65b4ed6 commit b4f0158
Showing 2 changed files with 13 additions and 60 deletions.
@@ -25,12 +25,10 @@
 import io.bioimage.modelrunner.utils.CommonUtils;
 
 import java.io.IOException;
-import java.nio.ByteBuffer;
 import java.util.Arrays;
 
 import org.bytedeco.pytorch.Tensor;
 
-import net.imglib2.type.numeric.integer.IntType;
 import net.imglib2.type.numeric.integer.UnsignedByteType;
 
 /**
@@ -81,12 +79,8 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
 if (CommonUtils.int32Overflows(arrayShape, 1))
 throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
 + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
-long flatSize = 1;
-for (long l : arrayShape) {flatSize *= l;}
-byte[] flatArr = new byte[(int) flatSize];
-tensor.data_ptr_byte().get(flatArr);
 SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
-shma.setBuffer(ByteBuffer.wrap(flatArr));
+tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
 if (PlatformDetection.isWindows()) shma.close();
 }

@@ -96,14 +90,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
 if (CommonUtils.int32Overflows(arrayShape, 4))
 throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
 + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
-long flatSize = 1;
-for (long l : arrayShape) {flatSize *= l;}
-int[] flatArr = new int[(int) flatSize];
-tensor.data_ptr_int().get(flatArr);
-SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
-ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Integer.BYTES);
-byteBuffer.asIntBuffer().put(flatArr);
-shma.setBuffer(byteBuffer);
+SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
+tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
 if (PlatformDetection.isWindows()) shma.close();
 }

@@ -113,15 +101,8 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
 if (CommonUtils.int32Overflows(arrayShape, 4))
 throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
 + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
-long flatSize = 1;
-for (long l : arrayShape) {flatSize *= l;}
-float[] flatArr = new float[(int) flatSize];
-// TODO check what we get with tensor.data_ptr_byte().get(flatArr);
-tensor.data_ptr_float().get(flatArr);
-SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
-ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Float.BYTES);
-byteBuffer.asFloatBuffer().put(flatArr);
-shma.setBuffer(byteBuffer);
+SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
+tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
 if (PlatformDetection.isWindows()) shma.close();
 }

@@ -131,14 +112,8 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
 if (CommonUtils.int32Overflows(arrayShape, 8))
 throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
 + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
-long flatSize = 1;
-for (long l : arrayShape) {flatSize *= l;}
-double[] flatArr = new double[(int) flatSize];
-tensor.data_ptr_double().get(flatArr);
-SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
-ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Double.BYTES);
-byteBuffer.asDoubleBuffer().put(flatArr);
-shma.setBuffer(byteBuffer);
+SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
+tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
 if (PlatformDetection.isWindows()) shma.close();
 }

@@ -148,14 +123,8 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
 if (CommonUtils.int32Overflows(arrayShape, 8))
 throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
 + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
-long flatSize = 1;
-for (long l : arrayShape) {flatSize *= l;}
-long[] flatArr = new long[(int) flatSize];
-tensor.data_ptr_long().get(flatArr);
-SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
-ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Long.BYTES);
-byteBuffer.asLongBuffer().put(flatArr);
-shma.setBuffer(byteBuffer);
+SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
+tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
 if (PlatformDetection.isWindows()) shma.close();
 }
 }
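The hunks above are the write side (PyTorch Tensor into shared memory): instead of flattening the tensor into a primitive array and wrapping it in a fresh ByteBuffer, each builder now drains the native data pointer straight into the array backing the shared-memory buffer. A minimal sketch of that pattern for the ubyte case, written as if it sat next to the changed methods so their imports apply; the method name, the explicit arrayShape parameter, and the throws clause are illustrative, not part of the commit:

    private static void copyUbyteTensorToShm(Tensor tensor, long[] arrayShape, String memoryName) throws IOException {
        // Same guard as the changed methods: the flat copy below needs an int-sized length.
        if (CommonUtils.int32Overflows(arrayShape, 1))
            throw new IllegalArgumentException("Tensor of shape " + Arrays.toString(arrayShape) + " is too large.");
        // Map (or create) the named shared-memory block for this shape and element type.
        SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
        // Copy the tensor's native bytes directly into the array backing the shared buffer,
        // skipping the old intermediate byte[] and ByteBuffer.wrap(...) round trip.
        tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
        // On Windows the writer releases its handle once the data has been published.
        if (PlatformDetection.isWindows()) shma.close();
    }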
@@ -26,10 +26,6 @@
 import net.imglib2.util.Cast;
 
 import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
-import java.nio.FloatBuffer;
-import java.nio.IntBuffer;
-import java.nio.LongBuffer;
 import java.util.Arrays;
 
 import org.bytedeco.pytorch.Tensor;
@@ -103,10 +99,7 @@ private static Tensor buildInt(SharedMemoryArray shmaArray)
 if (!shmaArray.isNumpyFormat())
 throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
 ByteBuffer buff = shmaArray.getDataBufferNoHeader();
-IntBuffer intBuff = buff.asIntBuffer();
-int[] intArray = new int[intBuff.capacity()];
-intBuff.get(intArray);
-Tensor ndarray = Tensor.create(intBuff.array(), ogShape);
+Tensor ndarray = Tensor.create(buff.asIntBuffer().array(), ogShape);
 return ndarray;
 }

@@ -120,10 +113,7 @@ private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray shmArray)
 if (!shmArray.isNumpyFormat())
 throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
 ByteBuffer buff = shmArray.getDataBufferNoHeader();
-LongBuffer longBuff = buff.asLongBuffer();
-long[] longArray = new long[longBuff.capacity()];
-longBuff.get(longArray);
-Tensor ndarray = Tensor.create(longBuff.array(), ogShape);
+Tensor ndarray = Tensor.create(buff.asLongBuffer().array(), ogShape);
 return ndarray;
 }

@@ -137,10 +127,7 @@ private static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray shmArray
 if (!shmArray.isNumpyFormat())
 throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
 ByteBuffer buff = shmArray.getDataBufferNoHeader();
-FloatBuffer floatBuff = buff.asFloatBuffer();
-float[] floatArray = new float[floatBuff.capacity()];
-floatBuff.get(floatArray);
-Tensor ndarray = Tensor.create(floatBuff.array(), ogShape);
+Tensor ndarray = Tensor.create(buff.asFloatBuffer().array(), ogShape);
 return ndarray;
 }

@@ -154,10 +141,7 @@ private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray shmArra
 if (!shmArray.isNumpyFormat())
 throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
 ByteBuffer buff = shmArray.getDataBufferNoHeader();
-DoubleBuffer doubleBuff = buff.asDoubleBuffer();
-double[] doubleArray = new double[doubleBuff.capacity()];
-doubleBuff.get(doubleArray);
-Tensor ndarray = Tensor.create(doubleBuff.array(), ogShape);
+Tensor ndarray = Tensor.create(buff.asDoubleBuffer().array(), ogShape);
 return ndarray;
 }
 }
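These hunks are the read side (shared memory back into a PyTorch Tensor): each builder now creates the Tensor from a typed view of the shared-memory ByteBuffer in a single step, and the now-unused java.nio view-buffer imports are dropped. A minimal sketch of the same idea for the int case, again assuming it sits next to the changed methods; the method name and the ogShape parameter are illustrative stand-ins, and the view is drained into an int[] before calling Tensor.create because a view buffer over a ByteBuffer does not expose a backing primitive array:

    private static Tensor buildIntTensorFromShm(SharedMemoryArray shmArray, long[] ogShape) {
        // The shared block is expected to hold a numpy-formatted array, as in the changed methods.
        if (!shmArray.isNumpyFormat())
            throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
        // Raw data region of the shared block, without the numpy header.
        ByteBuffer buff = shmArray.getDataBufferNoHeader();
        // Drain the int view into a primitive array: IntBuffer views over a ByteBuffer
        // have no accessible backing int[], so Tensor.create needs an explicit copy.
        int[] data = new int[buff.remaining() / Integer.BYTES];
        buff.asIntBuffer().get(data);
        return Tensor.create(data, ogShape);
    }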
