From e8d0cb65fcc26489c1d2e7f9f3d45bea5e169025 Mon Sep 17 00:00:00 2001
From: carlosuc3m <100329787@alumnos.uc3m.es>
Date: Mon, 25 Mar 2024 18:57:36 +0100
Subject: [PATCH] update to new JDLL
---
pom.xml | 2 +-
.../javacpp/PytorchJavaCPPInterface.java | 57 ++++++++++++-------
.../javacpp/shm/NDArrayShmBuilder.java | 7 ++-
3 files changed, 41 insertions(+), 25 deletions(-)
diff --git a/pom.xml b/pom.xml
index 0d44223..12b63ca 100644
--- a/pom.xml
+++ b/pom.xml
@@ -126,7 +126,7 @@
sign,deploy-to-scijava
2.0.1-1.5.9
- 0.5.1
+ 0.5.6-SNAPSHOT
11.8-8.6-1.5.8
2023.1-1.5.9
diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java
index 8bedd98..fffd3cf 100644
--- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java
+++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java
@@ -30,6 +30,7 @@
import java.net.URL;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
+import java.nio.file.FileAlreadyExistsException;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.HashMap;
@@ -59,6 +60,8 @@
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
+import net.imglib2.util.Cast;
+import net.imglib2.util.Util;
/**
* This class implements an interface that allows the main plugin to interact in
@@ -96,8 +99,9 @@ public class PytorchJavaCPPInterface implements DeepLearningEngineInterface
private boolean interprocessing = true;
private Process process;
-
- private List shmaList = new ArrayList();
+
+ private List shmaInputList = new ArrayList();
+ private List shmaOutputList = new ArrayList();
private List shmaNamesList = new ArrayList();
@@ -179,8 +183,10 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a
for (int i = 1; i < args.length; i ++) {
HashMap map = gson.fromJson(args[i], mapType);
if ((boolean) map.get(IS_INPUT_KEY)) {
- RandomAccessibleInterval rai = SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA((String) map.get(MEM_NAME_KEY));
- inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai)));
+ SharedMemoryArray shma = SharedMemoryArray.read((String) map.get(MEM_NAME_KEY));
+ RandomAccessibleInterval rai = shma.getSharedRAI();
+ inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai)));
+ if (PlatformDetection.isWindows()) shma.close();
}
}
// Run model
@@ -199,10 +205,11 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a
for (int i = 1; i < args.length; i ++) {
HashMap map = gson.fromJson(args[i], mapType);
if (!((boolean) map.get(IS_INPUT_KEY))) {
- NDArrayShmBuilder.buildShma(outputTensorVector.get(c), (String) map.get(MEM_NAME_KEY));
+ SharedMemoryArray shma = NDArrayShmBuilder.buildShma(outputTensorVector.get(c), (String) map.get(MEM_NAME_KEY));
outputTensorVector.get(c).close();
outputTensorVector.get(c).deallocate();
c ++;
+ if (PlatformDetection.isWindows()) shma.close();
}
}
outputTensorVector.close();
@@ -335,7 +342,8 @@ public static void fillOutputTensors(TensorVector tensorVector, List>
* @throws RunModelException if there is any issue running the model
*/
public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException {
- shmaList = new ArrayList();
+ shmaInputList = new ArrayList();
+ shmaOutputList = new ArrayList();
try {
List args = getProcessCommandsWithoutArgs();
List encIns = encodeInputs(inputTensors);
@@ -355,7 +363,14 @@ public void runInterprocessing(List> inputTensors, List> out
process = null;
for (int i = 0; i < outputTensors.size(); i ++) {
String name = (String) decodeString(encOuts.get(i)).get(MEM_NAME_KEY);
- outputTensors.get(i).setData(SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA(name));
+ SharedMemoryArray shm = shmaOutputList.stream()
+ .filter(ss -> ss.getName().equals(name)).findFirst().orElse(null);
+ if (shm == null) {
+ shm = SharedMemoryArray.read(name);
+ shmaOutputList.add(shm);
+ }
+ RandomAccessibleInterval> rai = shm.getSharedRAI();
+ outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
}
closeShmas();
} catch (Exception e) {
@@ -366,13 +381,14 @@ public void runInterprocessing(List> inputTensors, List> out
}
private void closeShmas() {
- shmaList.forEach(shm -> {
+ shmaInputList.forEach(shm -> {
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
});
- // TODO add methos imilar to Python's shared_memory.SharedMemory(name="") in SharedArrays class in JDLL
- this.shmaNamesList.forEach(shm -> {
- try { SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA(shm); } catch (Exception e1) {}
+ shmaInputList = null;
+ shmaOutputList.forEach(shm -> {
+ try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
});
+ shmaOutputList = null;
}
private static List modifyForWinCmd(List ins){
@@ -385,26 +401,25 @@ private static List modifyForWinCmd(List ins){
}
- private List encodeInputs(List> inputTensors) {
- int i = 0;
+ private List encodeInputs(List> inputTensors) throws FileAlreadyExistsException {
List encodedInputTensors = new ArrayList();
Gson gson = new Gson();
for (Tensor> tt : inputTensors) {
- shmaList.add(SharedMemoryArray.buildNumpyLikeSHMA(tt.getData()));
+ SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
+ shmaInputList.add(shma);
HashMap map = new HashMap();
map.put(NAME_KEY, tt.getName());
map.put(SHAPE_KEY, tt.getShape());
map.put(DTYPE_KEY, CommonUtils.getDataType(tt.getData()));
map.put(IS_INPUT_KEY, true);
- map.put(MEM_NAME_KEY, shmaList.get(i).getName());
+ map.put(MEM_NAME_KEY, shma.getName());
encodedInputTensors.add(gson.toJson(map));
- i ++;
}
return encodedInputTensors;
}
- private List encodeOutputs(List> outputTensors) {
+ private List encodeOutputs(List> outputTensors) throws FileAlreadyExistsException {
Gson gson = new Gson();
List encodedOutputTensors = new ArrayList();
for (Tensor> tt : outputTensors) {
@@ -414,13 +429,13 @@ private List encodeOutputs(List> outputTensors) {
if (!tt.isEmpty()) {
map.put(SHAPE_KEY, tt.getShape());
map.put(DTYPE_KEY, CommonUtils.getDataType(tt.getData()));
- SharedMemoryArray shma = SharedMemoryArray.buildNumpyLikeSHMA(tt.getData());
- shmaList.add(shma);
+ SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
+ shmaOutputList.add(shma);
map.put(MEM_NAME_KEY, shma.getName());
} else if (PlatformDetection.isWindows()){
String memName = SharedMemoryArray.createShmName();
- SharedMemoryArray shma = SharedMemoryArray.buildSHMA(memName, null);
- shmaList.add(shma);
+ SharedMemoryArray shma = SharedMemoryArray.create(0);
+ shmaOutputList.add(shma);
map.put(MEM_NAME_KEY, memName);
} else {
String memName = SharedMemoryArray.createShmName();
diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java
index e02f5e9..a8b44fe 100644
--- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java
+++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java
@@ -33,17 +33,18 @@
* @author Carlos Garcia Lopez de Haro
*/
public class NDArrayShmBuilder {
+
/**
* Build a shared memory segment from a Pytorch tensor
* @param tensor
* the Pytorch tensor created using JavaCPP
* @param memoryName
- * the sahred memory region name
+ * the shared memory region name
* @return the {@link SharedMemoryArray} object created
* @throws IOException if there is any error creating the shared memory segment
*/
- public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException {
- return SharedMemoryArray.buildNumpyLikeSHMA(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor)));
+ public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException {
+ return SharedMemoryArray.createSHMAFromRAI(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor)), false, true);
}
}