Skip to content

Commit

Permalink
more prints
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 23, 2024
1 parent 76cfc77 commit fe05822
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ private void executeScript(String script, Map<String, Object> inputs) {
this.reportLaunch();
try {
if (script.equals("loadModel")) {
System.out.println("STATY IN WORKER LOAD");
pi.loadModel((String) inputs.get("modelFolder"), (String) inputs.get("modelSource"));
} else if (script.equals("inference")) {
System.out.println("STATY IN WORKER");
pi.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
} else if (script.equals("close")) {
pi.closeModel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,21 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws Run
}

protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {

System.out.println("REACH0");
IValueVector inputsVector = new IValueVector();
System.out.println("REACH1");
for (String ee : inputs) {
System.out.println("REACH2");
Map<String, Object> decoded = Types.decode(ee);
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
org.bytedeco.pytorch.Tensor inT = TensorBuilder.build(shma);
inputsVector.put(new IValue(inT));
if (PlatformDetection.isWindows()) shma.close();
}
System.out.println("REACH3");
// Run model
model.eval();
System.out.println("REACH4");
IValue output = model.forward(inputsVector);
TensorVector outputTensorVector = null;
if (output.isTensorList()) {
Expand Down

0 comments on commit fe05822

Please sign in to comment.