From 65b4ed6fc7300cc6bdc65e1433fdc32d4792cc37 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 2 Oct 2024 13:41:59 +0200 Subject: [PATCH] increase robustness sending to other process --- .../io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java | 2 +- .../pytorch/javacpp/PytorchJavaCPPInterface.java | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java index adb98cc..e353b2f 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java @@ -87,7 +87,7 @@ private void executeScript(String script, Map inputs) { this.reportLaunch(); try { if (script.equals("loadModel")) { - pi.loadModel((String) inputs.get("modelFolder"), null); + pi.loadModel((String) inputs.get("modelFolder"), (String) inputs.get("modelSource")); } else if (script.equals("inference")) { pi.runFromShmas((List) inputs.get("inputs"), (List) inputs.get("outputs")); } else if (script.equals("close")) { diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 43dcf40..3597fdb 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -94,6 +94,8 @@ public class PytorchJavaCPPInterface implements DeepLearningEngineInterface */ private JitModule model; + private String modelFolder; + private String modelSource; private boolean interprocessing = true; @@ -157,6 +159,7 @@ private Service getRunner() throws IOException, URISyntaxException { */ @Override public void loadModel(String modelFolder, String modelSource) throws LoadModelException { + this.modelFolder = modelFolder; this.modelSource = modelSource; if (interprocessing) { try { @@ -177,7 +180,8 @@ public void loadModel(String modelFolder, String modelSource) throws LoadModelEx private void launchModelLoadOnProcess() throws IOException, InterruptedException { HashMap args = new HashMap(); - args.put("modelFolder", this.modelSource); + args.put("modelFolder", this.modelFolder); + args.put("modelSource", this.modelSource); Task task = runner.task("loadModel", args); task.waitFor(); if (task.status == TaskStatus.CANCELED)