Save TabNet in ONNX format #277
Comments
I'm not familiar with ONNX, but it would be quite easy to save the network as a traced script (via PyTorch JIT: https://pytorch.org/docs/stable/jit.html), which could be used for inference without the need for pytorch-tabnet, and even without Python itself (it can be called from C++ alone). Training the model and then saving the network as a traced script should be straightforward. Feel free to open a PR giving examples for either ONNX or JIT and I'll be happy to review (not sure about adding onnx as a dependency in the repo, however). If you have questions I might be able to help with JIT; I guess it would look something like this:
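(The snippet originally posted here was not preserved. Below is a minimal sketch of what tracing might look like; `clf` as a fitted TabNetClassifier and `input_dim` as the feature count are illustrative assumptions, not names from the original comment.)

```python
import torch

# Minimal sketch (names are assumptions): trace the forward pass of the
# underlying network of a fitted TabNetClassifier `clf`.
input_dim = 54  # illustrative: number of features seen at fit time
clf.network.eval()
dummy_input = torch.randn(1, input_dim)
traced = torch.jit.trace(clf.network, dummy_input)
traced.save("tabnet_traced.pt")

# The saved module can later be loaded without pytorch-tabnet
# (or from C++ via libtorch):
loaded = torch.jit.load("tabnet_traced.pt")
output = loaded(dummy_input)  # note: the network's forward returns a tuple here
```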
I think that's it! (This will only trace the forward pass, without explanations; you'll need to create a wrapper with a custom forward function to get both predictions and explanations.)
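(For illustration, a hedged sketch of such a wrapper; `forward_masks` is an assumption based on the repo's explain() code, not something confirmed in this thread.)

```python
import torch

# Sketch of a wrapper whose forward returns both predictions and the
# aggregated feature masks, so tracing captures explanations as well.
class TabNetWithExplain(torch.nn.Module):
    def __init__(self, network):
        super().__init__()
        self.network = network

    def forward(self, x):
        out, _ = self.network(x)                      # predictions, sparsity loss
        m_explain, _ = self.network.forward_masks(x)  # aggregated masks (assumed API)
        return out, m_explain
```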
Hi @Optimox, thanks for your quick answer. My main use case is to be able to pack the trained model into a REST service for predictions, in Python. As far as I understand (I'm not a great expert in PyTorch), a TabNetClassifier is a torch.nn.Module, so as explained here: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html we should be able to export through tracing, using torch.onnx.export. As soon as I have time I'll follow up; I need to see what the params mean. I agree with you that adding onnx to TabNet is not what should be done; I was thinking of adding example notebooks and best practices.
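(For concreteness, a sketch of the export call from the linked tutorial adapted to TabNet; all names here, `clf`, `input_dim`, the "input"/"output" tensor names, are illustrative assumptions.)

```python
import torch

# Hedged sketch: export only the underlying torch module of a fitted
# TabNetClassifier `clf` to ONNX.
input_dim = 54  # illustrative feature count
clf.network.eval()
dummy_input = torch.randn(1, input_dim)
torch.onnx.export(
    clf.network,        # the torch.nn.Module, not the sklearn-style wrapper
    dummy_input,        # must be a torch.Tensor, not a NumPy array
    "tabnet.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},  # allow variable batch size
)
```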
It seems more difficult than I expected. When I call:
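(The original snippet was lost in extraction; judging by the error below, the call presumably passed a NumPy array as the example input, something like this reconstruction:)

```python
import torch

# Presumed failing call (reconstruction): X_train here is a numpy.ndarray,
# which torch.onnx.export rejects as example input.
torch.onnx.export(clf.network, X_train[:1], "tabnet.onnx")
```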
I get the following error:

RuntimeError: Only tuples, lists, and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type numpy.ndarray
What kind of data can I pass as input? Only a NumPy array? Strange, since under the covers it is PyTorch; it should be easy to accept a Tensor.
Hi, I will try to see how this can work.
Only Python. I want to export the trained model using ONNX to be able to (for example) develop a REST service, and avoid having to install pytorch-tabnet by using only the ONNX runtime.
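(For concreteness, the target deployment would look something like this, reusing the assumed file and tensor names from the export sketch above; `54` as the feature count is likewise illustrative.)

```python
import numpy as np
import onnxruntime as ort

# Serve predictions with onnxruntime only; no pytorch-tabnet install needed.
session = ort.InferenceSession("tabnet.onnx")
features = np.random.randn(1, 54).astype(np.float32)
outputs = session.run(None, {"input": features})
print(outputs[0])  # raw network output
```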
For that, you can use PyTorch's save method, but I will come back with some tests using ONNX.
@luigisaetta I think trying to ONNXify (whatever this is called) the entire class is not the way to go; only the underlying network should be exported.
@Optimox You are right, only the network should be exported, but I also have to check what the input format is (only one input, or several for the embeddings, and so on).

I will see if we can set up some kind of optional dependency to have a custom exporter, or at least have a notebook covering this: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

Will try to have a look this week.
Hi, as you can see from my comment above, I get an error when I apply torch.onnx.export to clf.network. So, as far as I understand, I am already trying to export only the network. But since it wants an input... it doesn't accept a NumPy array.
@luigisaetta sorry, I wrote numpy array in the docstring example but it should be a torch.Tensor. Does it work with tensors?
No, it doesn't. I'll find my test and post the error here. Basically, I think it calls a method that exists on a NumPy array but not on a torch Tensor.
@Optimox I have published an article on Towards Data Science: https://towardsdatascience.com/pytorch-tabnet-integration-with-mlflow-cb14f3920cb0
Hello @luigisaetta,

Great article, very detailed! I'm happy to see that integration with MLflow is made so easy by the callbacks. @queraq worked on this and I'm sure neither of us had this specific usage in mind.

There is one part of the article where I think a bit of clarification would be welcome: the encoder-decoder part. In fact, TabNet models are only a sort of encoder (plus all the sequential attention part); there is no decoder at all. The only reason a decoder part exists is to enable self-supervised pre-training, which needs one. Since you do not mention pre-training in the article, I think you should not talk about an encoder-decoder model, or maybe you could add a paragraph about pre-training.

Anyway, great article. Thanks for sharing with us and giving credit to the repo. But... the article does not tell me whether you managed to get the ONNX format working?! :)

Cheers!
@Optimox regarding ONNX, no, I didn't make any progress. The point where I got blocked is that TabNet doesn't seem to accept tensors as input. I think the code calls some methods that exist only for NumPy arrays. Do you have any suggestions?
hmm, actually I think I know. If you have a look at this file https://github.com/dreamquark-ai/tabnet/blob/develop/pytorch_tabnet/tab_network.py where everything about the network happens, we are actually using numpy for some stupid reason (laziness and bad habits mainly). I think it would be very easy to replace all the numpy calls with their torch equivalents.

I don't have much time at the moment but I'll definitely change that. If you want to make those changes and see if it works for ONNX, don't hesitate. You can also open a PR and I'll review it carefully. Otherwise I'll do this as soon as I can, or maybe @eduardocarvp will have a look before me? I think we might have found your problem :)
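(A purely hypothetical before/after, not taken from tab_network.py, illustrating the kind of swap described above: anything applied to the input inside the network must use torch ops, otherwise tracing and ONNX export break on tensor inputs.)

```python
import numpy as np
import torch

# Before: numpy-only, fails when x is a torch.Tensor under tracing.
def scale_np(x):
    return (x - np.mean(x, axis=0)) / np.std(x, axis=0)

# After: the torch equivalent stays traceable end-to-end.
def scale_torch(x: torch.Tensor) -> torch.Tensor:
    return (x - x.mean(dim=0)) / x.std(dim=0)
```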
Hi, has anyone made any progress on this yet? It would be really appreciated if you could share a bit about how the export would work. Right now I am stuck at exporting and get an error. Could anyone help?
Hi! I tried this but I get an error. Is there any way to fix it?
Yes, you probably need to "scriptify" sparsemax (and the entmax functions) so that they can be accepted for tracing. I don't know how hard it would be; you can try adding the @torch.jit.script decorator on top of the definitions of sparsemax and entmax and see if it works.
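(For what it's worth, here is a hedged sketch of a forward-only sparsemax that torch.jit.script accepts. The repo's real implementation uses a custom autograd Function, so this is only an illustration of the "scriptify" idea, not a drop-in patch.)

```python
import torch

# A minimal, forward-only sparsemax written with pure tensor ops so that
# torch.jit.script accepts it; it covers only what inference needs.
@torch.jit.script
def sparsemax(logits: torch.Tensor) -> torch.Tensor:
    # Sort each row (last dim) in descending order.
    z, _ = torch.sort(logits, dim=-1, descending=True)
    cumsum = z.cumsum(dim=-1) - 1.0
    k = torch.arange(1, z.size(-1) + 1, device=logits.device, dtype=logits.dtype)
    # Support: positions where 1 + k * z_k exceeds the cumulative sum.
    support = k * z > cumsum
    k_z = support.to(logits.dtype).sum(dim=-1, keepdim=True)
    # Threshold tau chosen so each row of the output sums to 1.
    tau = (torch.where(support, z, torch.zeros_like(z)).sum(dim=-1, keepdim=True) - 1.0) / k_z
    return torch.clamp(logits - tau, min=0.0)
```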
Thanks for the reply. I'm trying to speed up inference, and TorchScript is one way I was trying. Is there any other, more straightforward method you would suggest to speed it up before I try this with TorchScript?
Oh guys, is it still not possible to export to ONNX right now?
Feature request
Is it possible to save TabNet in ONNX format?
What is the expected behavior?
What is the motivation or use case for adding/changing the behavior?
ONNX is quickly becoming the de facto standard for saving models, also because this way you avoid having to import extra packages when you package a model for inference.
How should this be implemented in your opinion?
Well, for now I don't have a precise idea, but the feature could probably be implemented as a notebook example, therefore with no changes needed to the core implementation.
Are you willing to work on this yourself?
Yes, I'm willing to give some help if I get some suggestions on where to start.