[Pytorch] torch class usage question #1128
-
Hello, and a big thank you to the developers for this project. I am trying to use the PyTorch presets. I got stuck on something that is probably really simple, but after about two hours I thought I should ask. The `torch` class has `stack` and `cat` methods that accept a `TensorArrayRef` instance. In Python, one would generally do something like:
I am trying to replicate the above in Java as below:
I just cannot figure out how to use the `TensorArrayRef` class. Thanks in advance for your help.
-
`c10::ArrayRef<>` is a wrapper around arrays, in this case arrays of `Tensor` objects, so something like this should work:

```java
// Store a in element 0 and b in element 1, then rewind with position(0) so the ArrayRef starts at element 0
cat(new TensorArrayRef(new Tensor(2).put(a).position(1).put(b).position(0), 2), 1);
```

But yeah, that's not super user-friendly. There is also a constructor taking `std::vector<>`, so I've just mapped that in commit eeee3e7. With that, we can do something like this instead, which is already much better:

```java
cat(new TensorArrayRef(new TensorVector(a, b)), 1);
```

Please give it a try with the snapshots: http://bytedeco.org/builds/
And thanks for reporting this issue!