This shows how to use a pytorch model in C++ in an Android app.
Running this will require compiling pytorch for android and setting TORCHPATH
in app/src/main/cpp/CMakeLists.txt
.
Also requires glog
: sudo apt install libgoogle-glog-dev
The model was created like this:
import torch as to
import torch.nn as nn
class Mod(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 2)
def forward(self, x):
return self.fc(x)
def main():
model = Mod()
x = to.rand(1, 3)
traced_module = to.jit.trace(model, x)
traced_module.save('traced_model.pt')
main()