Merge pull request #39 from jrzaurin/tabnet

Tabnet

jrzaurin authored Jun 25, 2021
2 parents b487b06 + bc873a0 commit a71699d
Showing 71 changed files with 6,338 additions and 2,864 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,6 +13,7 @@ Untitled*.ipynb
# data related dirs
data/
model_weights/
tmp_dir/
weights/

# Unit Tests/Coverage
4 changes: 2 additions & 2 deletions .travis.yml
@@ -1,9 +1,9 @@
dist: xenial
language: python
python:
- "3.6"
- "3.7"
- "3.7.9"
- "3.8"
- "3.9"
matrix:
fast_finish: true
include:
68 changes: 45 additions & 23 deletions README.md
@@ -9,7 +9,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)

# pytorch-widedeep

@@ -18,12 +18,13 @@ using wide and deep models.

**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

**Companion posts:** [infinitoml](https://jrzaurin.github.io/infinitoml/)
**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)

**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

### Introduction

`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).
``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792).

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
@@ -84,7 +85,7 @@ It is important to emphasize that **each individual component, `wide`,
isolation. For example, one could use only `wide`, which is simply a linear
model. In fact, one of the most interesting functionalities
in ``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
``pytorch-widedeep`` offers 3 models for that component:
``pytorch-widedeep`` offers 4 models for that component:

1. ``TabMlp``: this is almost identical to the [tabular
model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
@@ -95,12 +96,15 @@ features, and then passed through an MLP.
2. ``TabResnet``: This is similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.

3. ``TabTransformer``: Details on the TabTransformer can be found in:
3. ``TabNet``: Details on TabNet can be found in:
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)

4. ``TabTransformer``: Details on the TabTransformer can be found in:
[TabTransformer: Tabular Data Modeling Using Contextual
Embeddings](https://arxiv.org/pdf/2012.06678.pdf)


For details on these 3 models and their options please see the examples in the
For details on these 4 models and their options please see the examples in the
Examples folder and the documentation.
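
Since all four models plug into the same ``deeptabular`` slot, switching
between them is mostly a matter of instantiating a different class. What
follows is a minimal, illustrative sketch, not the canonical API: it assumes
a fitted ``TabPreprocessor`` (named ``tab_preprocessor``, as in the example
later in this README) and adult-dataset column names; please check the
documentation for the exact constructor arguments of each model.

```python
from pytorch_widedeep.models import TabNet, WideDeep

# Illustrative sketch: `tab_preprocessor` is assumed to be a fitted
# TabPreprocessor. `column_idx` maps column names to positions and
# `embeddings_input` maps categorical columns to embedding sizes.
deeptabular = TabNet(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=["age", "hours_per_week"],  # assumed continuous columns
)

# the deeptabular component can also be used on its own within WideDeep
model = WideDeep(deeptabular=deeptabular)
```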

Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
@@ -139,20 +143,20 @@ cd pytorch-widedeep
pip install -e .
```

**Important note for Mac users**: at the time of writing (Feb-2020) the latest
`torch` release is `1.7.1`. This release has some
**Important note for Mac users**: at the time of writing (June-2021) the
latest `torch` release is `1.9`. Some past
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac and the data-loaders will not run in parallel. In
addition, since `python 3.8`, [the `multiprocessing` library start method
changed from `'fork'` to
when running on Mac, present in previous versions, persist on this release and
the data-loaders will not run in parallel. In addition, since `python 3.8`,
[the `multiprocessing` library start method changed from `'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will not
run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python 3.6`
or `3.7` and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect that all these issues are
fixed in the future. Therefore, after installing `pytorch-widedeep` via pip or
directly from github, downgrade `torch` and `torchvision` manually:
fixed in the future. Therefore, after installing `pytorch-widedeep` via pip
or directly from github, downgrade `torch` and `torchvision` manually:

```bash
pip install pytorch-widedeep
@@ -167,16 +171,13 @@ Binary classification with the [adult
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset)
using `Wide` and `DeepDense` and default settings.


Building a wide (linear) and deep model with ``pytorch-widedeep``:

```python

import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
@@ -248,8 +249,29 @@
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# save and load
trainer.save_model("model_weights/model.t")
# Save and load

# Option 1: this will also save training history and lr history if the
# LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=deeptabular)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(
model_new,
objective="binary",
)

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)
```
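
A note on the two saving options above: with ``save_state_dict=True`` both
options persist a plain ``state_dict``, so restoring always goes through
``load_state_dict`` on a freshly built, identically configured ``WideDeep``
model, as in step 1. The practical difference is that ``trainer.save`` can
additionally store the training history (and the learning rate history, if
the ``LRHistory`` callback is used). The exact filenames written by
``trainer.save`` are not shown here, so inspect the output directory or the
documentation.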

Of course, one can do **much more**. See the Examples folder, the
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.4.8
1.0.0
3 changes: 3 additions & 0 deletions docs/callbacks.rst
@@ -10,6 +10,9 @@ Here are the 4 callbacks available in ``pytorch-widedeep``: ``History``,
.. autoclass:: pytorch_widedeep.callbacks.History
:members:

.. autoclass:: pytorch_widedeep.callbacks.LRShedulerCallback
:members:

.. autoclass:: pytorch_widedeep.callbacks.LRHistory
:members:

2 changes: 1 addition & 1 deletion docs/conf.py
@@ -111,7 +111,7 @@
# Remove the prompt when copying examples
copybutton_prompt_text = ">>> "

autoclass_content = "init" # 'both'
# autoclass_content = "init" # 'both'
autodoc_member_order = "bysource"
# autodoc_default_flags = ["show-inheritance"]

1 change: 1 addition & 0 deletions docs/examples.rst
@@ -13,3 +13,4 @@ them to address different problems
* `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
* `FineTune routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_FineTune_and_WarmUp_Model_Components.ipynb>`__
* `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__
* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
12 changes: 7 additions & 5 deletions docs/index.rst
@@ -90,7 +90,7 @@ deeptabular, deeptext and deepimage, can be used independently** and in
isolation. For example, one could use only ``wide``, which is simply a
linear model. In fact, one of the most interesting offerings of
``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
``pytorch-widedeep`` offers 3 models for that component:
``pytorch-widedeep`` offers 4 models for that component:

1. ``TabMlp``: this is almost identical to the `tabular
model <https://docs.fast.ai/tutorial.tabular.html>`_ in the fantastic
@@ -101,12 +101,14 @@ features, and then passed through an MLP.
2. ``TabResnet``: This is similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.

3. ``TabTransformer``: Details on the TabTransformer can be found in:
`TabTransformer: Tabular Data Modeling Using Contextual
Embeddings <https://arxiv.org/pdf/2012.06678.pdf>`_.
3. ``TabNet``: Details on TabNet can be found in: `TabNet: Attentive
Interpretable Tabular Learning <https://arxiv.org/abs/1908.07442>`_.

4. ``TabTransformer``: Details on the TabTransformer can be found in:
`TabTransformer: Tabular Data Modeling Using Contextual Embeddings
<https://arxiv.org/pdf/2012.06678.pdf>`_.

For details on these 3 models and their options please see the examples in the
For details on these 4 models and their options please see the examples in the
Examples folder and the documentation.

Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
7 changes: 4 additions & 3 deletions docs/losses.rst
@@ -2,9 +2,10 @@ Losses
======

``pytorch-widedeep`` accepts a number of losses and objectives that can be
passed to the ``Trainer`` class via the ``str`` parameter ``objective`` (see
``pytorch-widedeep.training.Trainer``). For most cases the loss function that
``pytorch-widedeep`` will use internally is already implemented in Pytorch.
passed to the ``Trainer`` class via the parameter ``objective``
(see ``pytorch-widedeep.training.Trainer``). For most cases the loss function
that ``pytorch-widedeep`` will use internally is already implemented in
PyTorch.

In addition, ``pytorch-widedeep`` implements four "custom" loss functions.
These are described below for completeness since, as I mentioned before, they
10 changes: 7 additions & 3 deletions docs/model_components.rst
@@ -5,9 +5,9 @@ This module contains the four main components that will comprise a Wide and
Deep model, and the ``WideDeep`` "constructor" class. These four components
are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.

.. note:: ``TabMlp``, ``TabResnet`` and ``TabTransformer`` can all be used
as the ``deeptabular`` component of the model and simply represent
different alternatives
.. note:: ``TabMlp``, ``TabResnet``, ``TabNet`` and ``TabTransformer`` can all
be used as the ``deeptabular`` component of the model and simply
represent different alternatives

.. autoclass:: pytorch_widedeep.models.wide.Wide
:exclude-members: forward
@@ -21,6 +21,10 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
:exclude-members: forward
:members:

.. autoclass:: pytorch_widedeep.models.tabnet.tab_net.TabNet
:exclude-members: forward
:members:

.. autoclass:: pytorch_widedeep.models.tab_transformer.TabTransformer
:exclude-members: forward
:members:
37 changes: 35 additions & 2 deletions docs/quick_start.rst
@@ -100,8 +100,41 @@ Fit and predict
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
# save and load
trainer.save_model("model_weights/model.t")
Save and load
-------------------------------

.. code-block:: python

    # Option 1: this will also save training history and lr history if the
    # LRHistory callback is used
    # Day 0, you have trained your model; save it using the trainer.save
    # method
    trainer.save(path="model_weights", save_state_dict=True)

    # Option 2: save as any other torch model
    # Day 0, you have trained your model; save as any other torch model
    torch.save(model.state_dict(), "model_weights/wd_model.pt")

    # From here onwards, Options 1 and 2 are the same
    # A few days have passed... I assume the user has prepared the data and
    # defined the model components:
    # 1. Build the model
    model_new = WideDeep(wide=wide, deeptabular=deeptabular)
    model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

    # 2. Instantiate the trainer
    trainer_new = Trainer(
        model_new,
        objective="binary",
    )

    # 3. Either fit or directly predict
    preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)

Of course, one can do **much more**. See the Examples folder in the repo, this
documentation or the companion posts for a better understanding of the content