
Use TabNet with standard PyTorch model API for data too large to fit in memory #378

Closed
jlehrer1 opened this issue Apr 3, 2022 · 6 comments
Labels: enhancement (New feature or request)

jlehrer1 commented Apr 3, 2022

Feature request

Although the sklearn API is quite nice for ease-of-use, it would also be great to use the TabNet model with the standard PyTorch API.

What is the expected behavior?
Call net = TabNet(100, 10), then use out = net(sample), a standard loss, loss.backward(), and optimizer.step() to train the model via SGD.

What is the motivation or use case for adding/changing the behavior?
There are many cases where writing a manual training loop is preferable, especially when I want to hot-swap this model into an existing pipeline, or when the dataset is too large to fit in memory and can only be accessed sample-wise. This is my main reason for using TabNet over XGBoost, where building an in-memory (or distributed-in-memory) dataset is not trivial in some cases.

How should this be implemented in your opinion?
I see that pytorch_tabnet.tab_network.TabNet already exists. What I'm unsure about is the output of the forward pass: it seems to return both the network output and the M_loss defined in the encoder. Should I be using this loss, or a standard CrossEntropy loss for classification?

Are you willing to work on this yourself?
Yes! This should be simple to do; I just need to know whether M_loss can be ignored in the output of a forward pass of pytorch_tabnet.tab_network.TabNet.

@jlehrer1 jlehrer1 added the enhancement New feature or request label Apr 3, 2022
Optimox (Collaborator) commented Apr 3, 2022

Hello @jlehrer1,

The M_loss part is used to add extra sparsity, as you can see here:

loss = loss - self.lambda_sparse * M_loss

So you can either ignore M_loss, which is equivalent to using lambda_sparse=0, or combine it with your loss as shown above in your custom pipeline to add the sparsity constraint (as in the original paper).
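For reference, a minimal sketch of such a custom loop, assuming (as discussed above) that pytorch_tabnet.tab_network.TabNet takes (input_dim, output_dim) and that its forward pass returns both the logits and M_loss; the synthetic data, learning rate, and lambda_sparse value are placeholders:

```python
import torch
from torch.nn.functional import cross_entropy
from pytorch_tabnet.tab_network import TabNet

# assumption: TabNet(input_dim, output_dim), forward returns (logits, M_loss)
net = TabNet(input_dim=100, output_dim=10)
optimizer = torch.optim.Adam(net.parameters(), lr=2e-2)
lambda_sparse = 1e-3  # placeholder sparsity coefficient

# placeholder batch: 128 samples with 100 features each
x = torch.randn(128, 100)
y = torch.randint(0, 10, (128,))

net.train()
for epoch in range(10):
    optimizer.zero_grad()
    logits, M_loss = net(x)
    # standard classification loss combined with the sparsity term above
    loss = cross_entropy(logits, y) - lambda_sparse * M_loss
    loss.backward()
    optimizer.step()
```

Dropping the `- lambda_sparse * M_loss` term is then equivalent to training with lambda_sparse=0.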

Optimox (Collaborator) commented Apr 4, 2022

related to #143

jlehrer1 (Author) commented

@Optimox Gotcha, thank you. Is the correct base model pytorch_tabnet.tab_model.TabNetClassifier, or should I use pytorch_tabnet.tab_network.TabNet as the raw PyTorch equivalent of the sklearn-style wrapper?

jlehrer1 (Author) commented Apr 11, 2022

It would be really nice to use the base model with raw PyTorch, as in having the usual class Model(torch.nn.Module). Should I submit this as a feature request? :) I'm happy to work on it if pointed in the right direction!

Optimox (Collaborator) commented Apr 12, 2022

I'm not sure I understand which feature you are requesting: every class in this file already inherits from torch.nn.Module: https://github.com/dreamquark-ai/tabnet/blob/develop/pytorch_tabnet/tab_network.py

Feel free to reuse it and insert it in your own pipeline.
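For the out-of-memory use case from the original request, one possible way to plug that module into a plain PyTorch pipeline is to stream samples through an IterableDataset. This is only an illustrative sketch; the dataset class, dimensions, and the on-the-fly data generation are placeholder assumptions standing in for lazy reads from disk, and the (input_dim, output_dim) constructor signature is assumed as above:

```python
import torch
from torch.utils.data import IterableDataset, DataLoader
from pytorch_tabnet.tab_network import TabNet

class StreamingTabularDataset(IterableDataset):
    """Placeholder dataset that yields one (features, label) pair at a time,
    standing in for rows read lazily from disk instead of loaded into memory."""

    def __init__(self, n_samples=10_000, n_features=100, n_classes=10):
        self.n_samples = n_samples
        self.n_features = n_features
        self.n_classes = n_classes

    def __iter__(self):
        for _ in range(self.n_samples):
            # stand-in for reading a single row from a large file or database
            x = torch.randn(self.n_features)
            y = torch.randint(0, self.n_classes, ()).item()
            yield x, y

loader = DataLoader(StreamingTabularDataset(), batch_size=256)
net = TabNet(input_dim=100, output_dim=10)  # assumed signature
optimizer = torch.optim.Adam(net.parameters(), lr=2e-2)

net.train()
for x, y in loader:
    optimizer.zero_grad()
    logits, M_loss = net(x)
    loss = torch.nn.functional.cross_entropy(logits, y) - 1e-3 * M_loss
    loss.backward()
    optimizer.step()
```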

jlehrer1 (Author) commented

Perfect, thanks. I wasn't sure whether there was any extra training logic outside of tab_network.py beyond adding the M_loss term, but it seems there isn't (apart from the explainability).
