diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 8bc16e6..840bc4f 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -253,8 +253,10 @@ protected void runFromShmas(List inputs, List outputs) throws IO IValue output = model.forward(inputsVector); TensorVector outputTensorVector = null; if (output.isTensorList()) { + System.out.println("SSECRET_KEY : 1 "); outputTensorVector = output.toTensorVector(); } else { + System.out.println("SSECRET_KEY : 2 "); outputTensorVector = new TensorVector(); outputTensorVector.put(output.toTensor()); } @@ -263,6 +265,7 @@ protected void runFromShmas(List inputs, List outputs) throws IO int c = 0; for (String ee : outputs) { Map decoded = Types.decode(ee); + System.out.println("ENTERED: " + ee); ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY)); } outputTensorVector.close(); diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index 6637008..c5b2486 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -68,14 +68,19 @@ public static void build(Tensor tensor, String memoryName) throws IllegalArgumen { if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte) || tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) { + System.out.println("SSECRET_KEY : BYTE "); buildFromTensorByte(tensor, memoryName); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) { + System.out.println("SSECRET_KEY : INT "); buildFromTensorInt(tensor, memoryName); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) { + System.out.println("SSECRET_KEY : FLOAT "); buildFromTensorFloat(tensor, memoryName); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) { + System.out.println("SSECRET_KEY : SOUBKE "); buildFromTensorDouble(tensor, memoryName); } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) { + System.out.println("SSECRET_KEY : LONG "); buildFromTensorLong(tensor, memoryName); } else { throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type());