Skip to content

Commit

Permalink
Remove all print statements that were used for debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 24, 2024
1 parent 2a12088 commit 3a5e528
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ public static void main(String[] args) {
try {
pi = new PytorchJavaCPPInterface(false);
} catch (IOException | URISyntaxException e) {
e.printStackTrace();
return;
}

Expand Down Expand Up @@ -88,16 +87,11 @@ private void executeScript(String script, Map<String, Object> inputs) {
this.reportLaunch();
try {
if (script.equals("loadModel")) {
update("STATY IN WORKER LOAD LOAD", null, null);
pi.loadModel((String) inputs.get("modelFolder"), (String) inputs.get("modelSource"));
} else if (script.equals("inference")) {
update("STATY IN WORKER ------------RUN", null, null);
pi.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
} else if (script.equals("close")) {
pi.closeModel();
} else {
update("LOL WTF", null, null);
update("LOL WTF -------------- " + script, null, null);
}
} catch(Exception | Error ex) {
this.fail(Types.stackTrace(ex));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,28 +239,21 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws Run
}

protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
System.out.println("REACH0");
IValueVector inputsVector = new IValueVector();
System.out.println("REACH1");
for (String ee : inputs) {
System.out.println("REACH2");
Map<String, Object> decoded = Types.decode(ee);
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
org.bytedeco.pytorch.Tensor inT = TensorBuilder.build(shma);
inputsVector.put(new IValue(inT));
if (PlatformDetection.isWindows()) shma.close();
}
System.out.println("REACH3");
// Run model
model.eval();
System.out.println("REACH4");
IValue output = model.forward(inputsVector);
TensorVector outputTensorVector = null;
if (output.isTensorList()) {
System.out.println("SSECRET_KEY : 1 ");
outputTensorVector = output.toTensorVector();
} else {
System.out.println("SSECRET_KEY : 2 ");
outputTensorVector = new TensorVector();
outputTensorVector.put(output.toTensor());
}
Expand All @@ -269,7 +262,6 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
int c = 0;
for (String ee : outputs) {
Map<String, Object> decoded = Types.decode(ee);
System.out.println("ENTERED: " + ee);
ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY));
}
outputTensorVector.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,14 @@ public static void build(Tensor tensor, String memoryName) throws IllegalArgumen
{
if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte)
|| tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) {
System.out.println("SSECRET_KEY : BYTE ");
buildFromTensorByte(tensor, memoryName);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) {
System.out.println("SSECRET_KEY : INT ");
buildFromTensorInt(tensor, memoryName);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) {
System.out.println("SSECRET_KEY : FLOAT ");
buildFromTensorFloat(tensor, memoryName);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) {
System.out.println("SSECRET_KEY : SOUBKE ");
buildFromTensorDouble(tensor, memoryName);
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) {
System.out.println("SSECRET_KEY : LONG ");
buildFromTensorLong(tensor, memoryName);
} else {
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type());
Expand All @@ -98,10 +93,9 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
byte[] flat = new byte[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize));
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize)).order(ByteOrder.LITTLE_ENDIAN);
tensor.data_ptr_byte().get(flat);
byteBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -116,11 +110,10 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
int[] flat = new int[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES));
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
IntBuffer floatBuffer = byteBuffer.asIntBuffer();
tensor.data_ptr_int().get(flat);
floatBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -140,10 +133,6 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
tensor.data_ptr_float().get(flat);
floatBuffer.put(flat);
shma.getDataBufferNoHeader().put(byteBuffer);
System.out.println("equals " + (shma.getDataBufferNoHeader().get(100) == byteBuffer.get(100)));
System.out.println("equals " + (shma.getDataBufferNoHeader().get(500) == byteBuffer.get(500)));
System.out.println("equals " + (shma.getDataBufferNoHeader().get(300) == byteBuffer.get(300)));
System.out.println("equals " + (shma.getDataBufferNoHeader().get(1000) == byteBuffer.get(1000)));
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -157,11 +146,10 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
double[] flat = new double[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES));
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
DoubleBuffer floatBuffer = byteBuffer.asDoubleBuffer();
tensor.data_ptr_double().get(flat);
floatBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -176,11 +164,10 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
long[] flat = new long[(int) flatSize];
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES));
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
LongBuffer floatBuffer = byteBuffer.asLongBuffer();
tensor.data_ptr_long().get(flat);
floatBuffer.put(flat);
byteBuffer.rewind();
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}
Expand Down

0 comments on commit 3a5e528

Please sign in to comment.