From e8d0cb65fcc26489c1d2e7f9f3d45bea5e169025 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Mon, 25 Mar 2024 18:57:36 +0100 Subject: [PATCH] update to new JDLL --- pom.xml | 2 +- .../javacpp/PytorchJavaCPPInterface.java | 57 ++++++++++++------- .../javacpp/shm/NDArrayShmBuilder.java | 7 ++- 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/pom.xml b/pom.xml index 0d44223..12b63ca 100644 --- a/pom.xml +++ b/pom.xml @@ -126,7 +126,7 @@ sign,deploy-to-scijava 2.0.1-1.5.9 - 0.5.1 + 0.5.6-SNAPSHOT 11.8-8.6-1.5.8 2023.1-1.5.9 diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 8bedd98..fffd3cf 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -30,6 +30,7 @@ import java.net.URL; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; +import java.nio.file.FileAlreadyExistsException; import java.security.ProtectionDomain; import java.util.ArrayList; import java.util.HashMap; @@ -59,6 +60,8 @@ import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; +import net.imglib2.util.Util; /** * This class implements an interface that allows the main plugin to interact in @@ -96,8 +99,9 @@ public class PytorchJavaCPPInterface implements DeepLearningEngineInterface private boolean interprocessing = true; private Process process; - - private List shmaList = new ArrayList(); + + private List shmaInputList = new ArrayList(); + private List shmaOutputList = new ArrayList(); private List shmaNamesList = new ArrayList(); @@ -179,8 +183,10 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a for (int i = 1; i < args.length; i ++) { HashMap map = gson.fromJson(args[i], mapType); if ((boolean) map.get(IS_INPUT_KEY)) { - RandomAccessibleInterval rai = SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA((String) map.get(MEM_NAME_KEY)); - inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai))); + SharedMemoryArray shma = SharedMemoryArray.read((String) map.get(MEM_NAME_KEY)); + RandomAccessibleInterval rai = shma.getSharedRAI(); + inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai))); + if (PlatformDetection.isWindows()) shma.close(); } } // Run model @@ -199,10 +205,11 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a for (int i = 1; i < args.length; i ++) { HashMap map = gson.fromJson(args[i], mapType); if (!((boolean) map.get(IS_INPUT_KEY))) { - NDArrayShmBuilder.buildShma(outputTensorVector.get(c), (String) map.get(MEM_NAME_KEY)); + SharedMemoryArray shma = NDArrayShmBuilder.buildShma(outputTensorVector.get(c), (String) map.get(MEM_NAME_KEY)); outputTensorVector.get(c).close(); outputTensorVector.get(c).deallocate(); c ++; + if (PlatformDetection.isWindows()) shma.close(); } } outputTensorVector.close(); @@ -335,7 +342,8 @@ public static void fillOutputTensors(TensorVector tensorVector, List> * @throws RunModelException if there is any issue running the model */ public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { - shmaList = new ArrayList(); + shmaInputList = new ArrayList(); + shmaOutputList = new ArrayList(); try { List args = getProcessCommandsWithoutArgs(); List encIns = encodeInputs(inputTensors); @@ -355,7 +363,14 @@ public void runInterprocessing(List> inputTensors, List> out process = null; for (int i = 0; i < outputTensors.size(); i ++) { String name = (String) decodeString(encOuts.get(i)).get(MEM_NAME_KEY); - outputTensors.get(i).setData(SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA(name)); + SharedMemoryArray shm = shmaOutputList.stream() + .filter(ss -> ss.getName().equals(name)).findFirst().orElse(null); + if (shm == null) { + shm = SharedMemoryArray.read(name); + shmaOutputList.add(shm); + } + RandomAccessibleInterval rai = shm.getSharedRAI(); + outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai)))); } closeShmas(); } catch (Exception e) { @@ -366,13 +381,14 @@ public void runInterprocessing(List> inputTensors, List> out } private void closeShmas() { - shmaList.forEach(shm -> { + shmaInputList.forEach(shm -> { try { shm.close(); } catch (IOException e1) { e1.printStackTrace();} }); - // TODO add methos imilar to Python's shared_memory.SharedMemory(name="") in SharedArrays class in JDLL - this.shmaNamesList.forEach(shm -> { - try { SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA(shm); } catch (Exception e1) {} + shmaInputList = null; + shmaOutputList.forEach(shm -> { + try { shm.close(); } catch (IOException e1) { e1.printStackTrace();} }); + shmaOutputList = null; } private static List modifyForWinCmd(List ins){ @@ -385,26 +401,25 @@ private static List modifyForWinCmd(List ins){ } - private List encodeInputs(List> inputTensors) { - int i = 0; + private List encodeInputs(List> inputTensors) throws FileAlreadyExistsException { List encodedInputTensors = new ArrayList(); Gson gson = new Gson(); for (Tensor tt : inputTensors) { - shmaList.add(SharedMemoryArray.buildNumpyLikeSHMA(tt.getData())); + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + shmaInputList.add(shma); HashMap map = new HashMap(); map.put(NAME_KEY, tt.getName()); map.put(SHAPE_KEY, tt.getShape()); map.put(DTYPE_KEY, CommonUtils.getDataType(tt.getData())); map.put(IS_INPUT_KEY, true); - map.put(MEM_NAME_KEY, shmaList.get(i).getName()); + map.put(MEM_NAME_KEY, shma.getName()); encodedInputTensors.add(gson.toJson(map)); - i ++; } return encodedInputTensors; } - private List encodeOutputs(List> outputTensors) { + private List encodeOutputs(List> outputTensors) throws FileAlreadyExistsException { Gson gson = new Gson(); List encodedOutputTensors = new ArrayList(); for (Tensor tt : outputTensors) { @@ -414,13 +429,13 @@ private List encodeOutputs(List> outputTensors) { if (!tt.isEmpty()) { map.put(SHAPE_KEY, tt.getShape()); map.put(DTYPE_KEY, CommonUtils.getDataType(tt.getData())); - SharedMemoryArray shma = SharedMemoryArray.buildNumpyLikeSHMA(tt.getData()); - shmaList.add(shma); + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + shmaOutputList.add(shma); map.put(MEM_NAME_KEY, shma.getName()); } else if (PlatformDetection.isWindows()){ String memName = SharedMemoryArray.createShmName(); - SharedMemoryArray shma = SharedMemoryArray.buildSHMA(memName, null); - shmaList.add(shma); + SharedMemoryArray shma = SharedMemoryArray.create(0); + shmaOutputList.add(shma); map.put(MEM_NAME_KEY, memName); } else { String memName = SharedMemoryArray.createShmName(); diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java index e02f5e9..a8b44fe 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java @@ -33,17 +33,18 @@ * @author Carlos Garcia Lopez de Haro */ public class NDArrayShmBuilder { + /** * Build a shared memory segment from a Pytorch tensor * @param tensor * the Pytorch tensor created using JavaCPP * @param memoryName - * the sahred memory region name + * the shared memory region name * @return the {@link SharedMemoryArray} object created * @throws IOException if there is any error creating the shared memory segment */ - public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException { - return SharedMemoryArray.buildNumpyLikeSHMA(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor))); + public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException { + return SharedMemoryArray.createSHMAFromRAI(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor)), false, true); } }