Skip to content

Commit

Permalink
add jna and json paths to process classpath
Browse files Browse the repository at this point in the history
  • Loading branch information
Cgarcia authored and Cgarcia committed Dec 15, 2023
1 parent 60daf6d commit bcbfb34
Showing 1 changed file with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ private PytorchJavaCPPInterface(boolean doInterprocessing)
public static < T extends RealType< T > & NativeType< T > > void main(String[] args) throws LoadModelException, RunModelException {
if (args.length == 0) {

String modelFolder = "/home/carlos/git/deep-icy/models/Neuron Segmentation in EM (Membrane Prediction)_07122023_193930";
String modelFolder = "/Users/Cgarcia/git/deep-icy/models/nse";
String modelSourc = modelFolder + "/weights-torchscript.pt";
PytorchJavaCPPInterface pi = new PytorchJavaCPPInterface();
pi.loadModel(modelFolder, modelSourc);
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(new long[] {1, 1, 16, 144, 144});
Tensor<?> inp = Tensor.build("aa", "bczyx", rai);
Tensor<?> out = Tensor.buildEmptyTensor("oo", "bczyx");
Tensor<?> out = Tensor.build("oo", "bczyx", rai);
List<Tensor<?>> ins = new ArrayList<Tensor<?>>();
List<Tensor<?>> ous = new ArrayList<Tensor<?>>();
ins.add(inp);
Expand Down Expand Up @@ -449,10 +449,16 @@ private List<String> getProcessCommandsWithoutArgs() throws IOException, URISynt

String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
String imglib2Path = getPathFromClass(NativeType.class);
String gsonPath = getPathFromClass(Gson.class);
String jnaPath = getPathFromClass(com.sun.jna.Library.class);
String jnaPlatformPath = getPathFromClass(com.sun.jna.platform.FileUtils.class);
if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
&& !modelrunnerPath.contains(File.pathSeparator)))
modelrunnerPath = System.getProperty("java.class.path");
String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
classpath = classpath + gsonPath + File.pathSeparator;
classpath = classpath + jnaPath + File.pathSeparator;
classpath = classpath + jnaPlatformPath + File.pathSeparator;
ProtectionDomain protectionDomain = PytorchJavaCPPInterface.class.getProtectionDomain();
String codeSource = protectionDomain.getCodeSource().getLocation().getPath();
String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString());
Expand Down

0 comments on commit bcbfb34

Please sign in to comment.