diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java index 554f933..6e5d7e5 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java @@ -87,8 +87,10 @@ private void executeScript(String script, Map inputs) { this.reportLaunch(); try { if (script.equals("loadModel")) { + System.out.println("STATY IN WORKER LOAD"); pi.loadModel((String) inputs.get("modelFolder"), (String) inputs.get("modelSource")); } else if (script.equals("inference")) { + System.out.println("STATY IN WORKER"); pi.runFromShmas((List) inputs.get("inputs"), (List) inputs.get("outputs")); } else if (script.equals("close")) { pi.closeModel(); diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 840bc4f..2bec60e 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -239,17 +239,21 @@ void run(List> inputTensors, List> outputTensors) throws Run } protected void runFromShmas(List inputs, List outputs) throws IOException { - + System.out.println("REACH0"); IValueVector inputsVector = new IValueVector(); + System.out.println("REACH1"); for (String ee : inputs) { + System.out.println("REACH2"); Map decoded = Types.decode(ee); SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY)); org.bytedeco.pytorch.Tensor inT = TensorBuilder.build(shma); inputsVector.put(new IValue(inT)); if (PlatformDetection.isWindows()) shma.close(); } + System.out.println("REACH3"); // Run model model.eval(); + System.out.println("REACH4"); IValue output = model.forward(inputsVector); TensorVector outputTensorVector = null; if (output.isTensorList()) {