Skip to content

Commit

Permalink
update to new JDLL
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Mar 25, 2024
1 parent 9949c98 commit e8d0cb6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
<releaseProfiles>sign,deploy-to-scijava</releaseProfiles>

<pytorch-javacpp.version>2.0.1-1.5.9</pytorch-javacpp.version>
<dl-modelrunner.version>0.5.1</dl-modelrunner.version>
<dl-modelrunner.version>0.5.6-SNAPSHOT</dl-modelrunner.version>
<cuda-javacpp.version>11.8-8.6-1.5.8</cuda-javacpp.version>
<mkl-javacpp.version>2023.1-1.5.9</mkl-javacpp.version>
</properties>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.net.URL;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileAlreadyExistsException;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -59,6 +60,8 @@
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/**
* This class implements an interface that allows the main plugin to interact in
Expand Down Expand Up @@ -96,8 +99,9 @@ public class PytorchJavaCPPInterface implements DeepLearningEngineInterface
private boolean interprocessing = true;

private Process process;

private List<SharedMemoryArray> shmaList = new ArrayList<SharedMemoryArray>();

private List<SharedMemoryArray> shmaInputList = new ArrayList<SharedMemoryArray>();
private List<SharedMemoryArray> shmaOutputList = new ArrayList<SharedMemoryArray>();

private List<String> shmaNamesList = new ArrayList<String>();

Expand Down Expand Up @@ -179,8 +183,10 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a
for (int i = 1; i < args.length; i ++) {
HashMap<String, Object> map = gson.fromJson(args[i], mapType);
if ((boolean) map.get(IS_INPUT_KEY)) {
RandomAccessibleInterval<T> rai = SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA((String) map.get(MEM_NAME_KEY));
inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai)));
SharedMemoryArray shma = SharedMemoryArray.read((String) map.get(MEM_NAME_KEY));
RandomAccessibleInterval<T> rai = shma.getSharedRAI();
inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai)));
if (PlatformDetection.isWindows()) shma.close();
}
}
// Run model
Expand All @@ -199,10 +205,11 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a
for (int i = 1; i < args.length; i ++) {
HashMap<String, Object> map = gson.fromJson(args[i], mapType);
if (!((boolean) map.get(IS_INPUT_KEY))) {
NDArrayShmBuilder.buildShma(outputTensorVector.get(c), (String) map.get(MEM_NAME_KEY));
SharedMemoryArray shma = NDArrayShmBuilder.buildShma(outputTensorVector.get(c), (String) map.get(MEM_NAME_KEY));
outputTensorVector.get(c).close();
outputTensorVector.get(c).deallocate();
c ++;
if (PlatformDetection.isWindows()) shma.close();
}
}
outputTensorVector.close();
Expand Down Expand Up @@ -335,7 +342,8 @@ public static void fillOutputTensors(TensorVector tensorVector, List<Tensor<?>>
* @throws RunModelException if there is any issue running the model
*/
public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException {
shmaList = new ArrayList<SharedMemoryArray>();
shmaInputList = new ArrayList<SharedMemoryArray>();
shmaOutputList = new ArrayList<SharedMemoryArray>();
try {
List<String> args = getProcessCommandsWithoutArgs();
List<String> encIns = encodeInputs(inputTensors);
Expand All @@ -355,7 +363,14 @@ public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> out
process = null;
for (int i = 0; i < outputTensors.size(); i ++) {
String name = (String) decodeString(encOuts.get(i)).get(MEM_NAME_KEY);
outputTensors.get(i).setData(SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA(name));
SharedMemoryArray shm = shmaOutputList.stream()
.filter(ss -> ss.getName().equals(name)).findFirst().orElse(null);
if (shm == null) {
shm = SharedMemoryArray.read(name);
shmaOutputList.add(shm);
}
RandomAccessibleInterval<?> rai = shm.getSharedRAI();
outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
}
closeShmas();
} catch (Exception e) {
Expand All @@ -366,13 +381,14 @@ public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> out
}

private void closeShmas() {
shmaList.forEach(shm -> {
shmaInputList.forEach(shm -> {
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
});
// TODO add method similar to Python's shared_memory.SharedMemory(name="") in SharedArrays class in JDLL
this.shmaNamesList.forEach(shm -> {
try { SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA(shm); } catch (Exception e1) {}
shmaInputList = null;
shmaOutputList.forEach(shm -> {
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
});
shmaOutputList = null;
}

private static List<String> modifyForWinCmd(List<String> ins){
Expand All @@ -385,26 +401,25 @@ private static List<String> modifyForWinCmd(List<String> ins){
}


private List<String> encodeInputs(List<Tensor<?>> inputTensors) {
int i = 0;
private List<String> encodeInputs(List<Tensor<?>> inputTensors) throws FileAlreadyExistsException {
List<String> encodedInputTensors = new ArrayList<String>();
Gson gson = new Gson();
for (Tensor<?> tt : inputTensors) {
shmaList.add(SharedMemoryArray.buildNumpyLikeSHMA(tt.getData()));
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
shmaInputList.add(shma);
HashMap<String, Object> map = new HashMap<String, Object>();
map.put(NAME_KEY, tt.getName());
map.put(SHAPE_KEY, tt.getShape());
map.put(DTYPE_KEY, CommonUtils.getDataType(tt.getData()));
map.put(IS_INPUT_KEY, true);
map.put(MEM_NAME_KEY, shmaList.get(i).getName());
map.put(MEM_NAME_KEY, shma.getName());
encodedInputTensors.add(gson.toJson(map));
i ++;
}
return encodedInputTensors;
}


private List<String> encodeOutputs(List<Tensor<?>> outputTensors) {
private List<String> encodeOutputs(List<Tensor<?>> outputTensors) throws FileAlreadyExistsException {
Gson gson = new Gson();
List<String> encodedOutputTensors = new ArrayList<String>();
for (Tensor<?> tt : outputTensors) {
Expand All @@ -414,13 +429,13 @@ private List<String> encodeOutputs(List<Tensor<?>> outputTensors) {
if (!tt.isEmpty()) {
map.put(SHAPE_KEY, tt.getShape());
map.put(DTYPE_KEY, CommonUtils.getDataType(tt.getData()));
SharedMemoryArray shma = SharedMemoryArray.buildNumpyLikeSHMA(tt.getData());
shmaList.add(shma);
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
shmaOutputList.add(shma);
map.put(MEM_NAME_KEY, shma.getName());
} else if (PlatformDetection.isWindows()){
String memName = SharedMemoryArray.createShmName();
SharedMemoryArray shma = SharedMemoryArray.buildSHMA(memName, null);
shmaList.add(shma);
SharedMemoryArray shma = SharedMemoryArray.create(0);
shmaOutputList.add(shma);
map.put(MEM_NAME_KEY, memName);
} else {
String memName = SharedMemoryArray.createShmName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@
* @author Carlos Garcia Lopez de Haro
*/
public class NDArrayShmBuilder {


/**
* Build a shared memory segment from a Pytorch tensor
* @param tensor
* the Pytorch tensor created using JavaCPP
* @param memoryName
* the sahred memory region name
* the shared memory region name
* @return the {@link SharedMemoryArray} object created
* @throws IOException if there is any error creating the shared memory segment
*/
public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException {
return SharedMemoryArray.buildNumpyLikeSHMA(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor)));
public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException {
return SharedMemoryArray.createSHMAFromRAI(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor)), false, true);
}
}

0 comments on commit e8d0cb6

Please sign in to comment.