Skip to content

Commit

Permalink
keep updating for interprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 14, 2023
1 parent 6da9983 commit b0ea6c6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ private PytorchJavaCPPInterface(boolean doInterprocessing)
public static < T extends RealType< T > & NativeType< T > > void main(String[] args) throws LoadModelException, RunModelException {
if (args.length == 0) {

String modelFolder = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\deep-icy\\models\\Neuron Segmentation in EM (Membrane Prediction)_30102023_192607";
String modelSourc = modelFolder + "\\weights-torchscript.pt";
String modelFolder = "/home/carlos/git/deep-icy/models/Neuron Segmentation in EM (Membrane Prediction)_30102023_192607";
String modelSourc = modelFolder + "/weights-torchscript.pt";
PytorchJavaCPPInterface pi = new PytorchJavaCPPInterface();
pi.loadModel(modelFolder, modelSourc);
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(new long[] {1, 1, 16, 144, 144});
Expand Down Expand Up @@ -176,7 +176,7 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a
HashMap<String, Object> map = gson.fromJson(args[i], mapType);
if ((boolean) map.get(IS_INPUT_KEY)) {
RandomAccessibleInterval<T> rai = SharedMemoryArray.buildImgLib2FromNumpyLikeSHMA((String) map.get(MEM_NAME_KEY));
inputsVector.put(new IValue(JavaCPPTensorBuilder.build(rai)));
inputsVector.put(new IValue(JavaCPPTensorBuilder.buildFromRai(rai)));
}
}
// Run model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
Expand Down Expand Up @@ -59,7 +58,7 @@ public class JavaCPPTensorBuilder {
*/
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor build(Tensor<T> tensor) throws IllegalArgumentException
{
return build(tensor.getData());
return buildFromRai(tensor.getData());
}

/**
Expand All @@ -72,7 +71,7 @@ public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch
* @return The {@link org.bytedeco.pytorch.Tensor} built from the {@link RandomAccessibleInterval}.
* @throws IllegalArgumentException if the {@link RandomAccessibleInterval} is not supported
*/
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor build(RandomAccessibleInterval<T> tensor) throws IllegalArgumentException
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor buildFromRai(RandomAccessibleInterval<T> tensor) throws IllegalArgumentException
{
if (Util.getTypeFromInterval(tensor) instanceof ByteType) {
return buildFromTensorByte(Cast.unchecked(tensor));
Expand Down

0 comments on commit b0ea6c6

Please sign in to comment.