Skip to content

Commit

Permalink
improve interprocess communication
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 2, 2024
1 parent a4eeaf9 commit 90e4910
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.nio.ByteBuffer;
import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
Expand Down Expand Up @@ -101,13 +100,8 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 1;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -119,13 +113,8 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 4;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -137,13 +126,8 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 4;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -155,13 +139,8 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 8;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -174,13 +153,8 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws


SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 8;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
if (PlatformDetection.isWindows()) shma.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ private static TInt32 buildInt(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
IntBuffer intBuff = buff.asIntBuffer();
int[] intArray = new int[intBuff.capacity()];
intBuff.get(intArray);
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false);
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intBuff.array(), false);
TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape),
dataBuffer);
return ndarray;
Expand All @@ -137,9 +135,7 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
LongBuffer longBuff = buff.asLongBuffer();
long[] longArray = new long[longBuff.capacity()];
longBuff.get(longArray);
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false);
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longBuff.array(), false);
TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape),
dataBuffer);
return ndarray;
Expand All @@ -156,9 +152,7 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
FloatBuffer floatBuff = buff.asFloatBuffer();
float[] floatArray = new float[floatBuff.capacity()];
floatBuff.get(floatArray);
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false);
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
Expand All @@ -173,10 +167,8 @@ private static TFloat64 buildDouble(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
double[] doubleArray = new double[doubleBuff.capacity()];
doubleBuff.get(doubleArray);
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false);
DoubleBuffer floatBuff = buff.asDoubleBuffer();
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
Expand Down

0 comments on commit 90e4910

Please sign in to comment.