PyTorch models
DeepImageJ can load PyTorch models through a third-party library called Deep Java Library (DJL), developed by Amazon Web Services. Detailed documentation is available in the PyTorch section of the DJL documentation. Internally, DJL differs slightly from the TensorFlow Java library, even though both call the underlying C++ code (of PyTorch or TensorFlow, respectively) through Java Native Interface (JNI) bindings packaged as .jar files. TensorFlow Java works with tensors, whereas DJL works with NDArrays, which mimic Python NumPy arrays and avoid the difficulty of handling plain Java arrays.
Because DJL uses the PyTorch C++ API (LibTorch), models have to be saved in TorchScript format. This adds little complexity to the Python code: it only requires two extra lines, one to trace the model and one to save it:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)
# Switch the model to eval mode
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
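Before loading the .pt file in DeepImageJ, it can help to reload the traced model in Python and check that it reproduces the outputs of the original model. The sketch below uses only torch.jit.trace, ScriptModule.save, torch.jit.load and torch.allclose; TinyNet and export_and_check are hypothetical names, and the 1×1×64×64 input is an assumed shape that should be replaced by your model's real input.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for your own model (not part of DeepImageJ)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def export_and_check(model, example, path, atol=1e-5):
    """Trace `model`, save it as TorchScript, reload it and compare outputs."""
    model.eval()  # same eval-mode switch as in the export example
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)
        reloaded = torch.jit.load(path)
        expected = model(example)
        actual = reloaded(example)
    # True if the reloaded TorchScript model matches the eager model
    return torch.allclose(expected, actual, atol=atol)


if __name__ == "__main__":
    example = torch.rand(1, 1, 64, 64)  # assumed input shape
    ok = export_and_check(TinyNet(), example, "tiny_traced.pt")
    print("TorchScript output matches eager model:", ok)
```

Note that tracing records only the operations executed for the example input, so data-dependent control flow (for example, an if on tensor values) is not captured; torch.jit.script is the usual alternative for such models.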
DeepImageJ is installed with version 1.7 of DJL, which supports PyTorch 1.7. This PyTorch version should be backward compatible with previous versions. However, the PyTorch version can be changed manually by replacing the corresponding .jar files in the jars folder, in the same way as the TensorFlow Java library version is changed manually.
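Since the bundled runtime is PyTorch 1.7, a TorchScript file saved with a newer PyTorch may fail to load in DeepImageJ. A minimal pre-export check is sketched below, assuming that comparing only the major.minor version is enough; parse_major_minor and export_is_compatible are hypothetical helpers, and runtime_version should be updated if the .jar files are replaced.

```python
import re


def parse_major_minor(version):
    """Return (major, minor) from strings such as '1.7.1' or '1.7.0+cu110'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def export_is_compatible(export_version, runtime_version="1.7"):
    """True if the exporting PyTorch is not newer than the runtime (major.minor).

    Assumption: a TorchScript file saved with a newer PyTorch may not load in
    an older runtime, so only major.minor is compared here.
    """
    return parse_major_minor(export_version) <= parse_major_minor(runtime_version)


if __name__ == "__main__":
    import torch  # only needed to read the locally installed version

    if not export_is_compatible(torch.__version__):
        print(f"Warning: exporting with PyTorch {torch.__version__}, "
              "which is newer than the PyTorch 1.7 runtime assumed here.")
```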
On Windows, DJL requires the Microsoft Visual C++ Redistributable for Visual Studio 2019 to be installed.