
Commit

Merge pull request #49 from jrzaurin/jrzaurin/perceiver
Jrzaurin/perceiver
jrzaurin authored Sep 7, 2021
2 parents 0c79deb + e865fb2 commit 86217bb
Showing 57 changed files with 5,103 additions and 2,692 deletions.
88 changes: 46 additions & 42 deletions README.md
@@ -9,7 +9,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)

# pytorch-widedeep

@@ -24,6 +24,13 @@ using wide and deep models.

**Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)

The content of this document is organized as follows:

1. [Introduction](#introduction)
2. [The deeptabular component](#the-deeptabular-component)
3. [Installation](#installation)
4. [Quick start (TL;DR)](#quick-start)

### Introduction

``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
@@ -82,61 +89,58 @@ into:
<img width="300" src="docs/figures/architecture_2_math.png">
</p>

### The ``deeptabular`` component

It is important to emphasize that **each individual component, `wide`,
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
isolation. For example, one could use only `wide`, which is simply a
linear model. In fact, one of the most interesting functionalities
in ``pytorch-widedeep`` would be the use of the ``deeptabular`` component on
its own, i.e. what one might normally refer to as Deep Learning for Tabular
Data. Currently, ``pytorch-widedeep`` offers the following models
for that component:

1. **TabMlp**: a simple MLP that receives embeddings representing the
categorical features, concatenated with the continuous features.
2. **TabResnet**: similar to the previous model, but the embeddings are
passed through a series of ResNet blocks built with dense layers.

3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)

4. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
5. **SAINT**: details on SAINT can be found in
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
6. **FT-Transformer**: details on the FT-Transformer can be found in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
When using the ``FT-Transformer`` each continuous feature is "embedded"
(i.e. passed through a 1-layer MLP, with or without an activation function)
and then fed to the attention blocks along with the categorical features.
This is available in ``pytorch-widedeep``'s ``TabTransformer`` by setting the
parameter ``embed_continuous = True``.

7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
on the FastFormer can be found in
[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)
8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)

Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
adaptations of those algorithms for tabular data.

For details on these models and their options please see the examples in the
Examples folder and the documentation.
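
As a quick illustration, here is a minimal sketch of plugging one of these
models into ``WideDeep`` as the ``deeptabular`` component. The toy dataframe
and column names are illustrative assumptions, not part of the library; see
the Examples folder for the full set-up:

```python
import pandas as pd

from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor

# toy dataframe; the column names are illustrative assumptions
df = pd.DataFrame(
    {
        "city": ["london", "paris", "london", "paris"],
        "age": [25, 31, 42, 57],
        "target": [0, 1, 0, 1],
    }
)

# encode the categorical columns and scale/collect the continuous ones
tab_preprocessor = TabPreprocessor(embed_cols=["city"], continuous_cols=["age"])
X_tab = tab_preprocessor.fit_transform(df)

deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=["age"],
)
model = WideDeep(deeptabular=deeptabular)
```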

Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
``pytorch-widedeep``, it is very likely that users will want to use their own
models for the ``deeptext`` and ``deepimage`` components. That is perfectly
possible as long as the custom models have an attribute called
``output_dim`` with the size of the last layer of activations, so that
``WideDeep`` can be constructed. Again, examples on how to use custom
components can be found in the Examples folder. Just in case,
``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
(pre-trained ResNets or stack of CNNs) models; see the sketch below.
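
For example, a minimal, hypothetical sketch of a custom ``deeptext``
component — the class name and layer sizes are invented for illustration; the
only hard requirement is the ``output_dim`` attribute:

```python
import torch
from torch import nn


class MyDeepText(nn.Module):
    """A hypothetical custom text component: embeddings + LSTM."""

    def __init__(self, vocab_size: int, embed_dim: int = 32, hidden_dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # WideDeep reads this attribute to build the prediction head
        self.output_dim = hidden_dim

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        _, (h, _) = self.rnn(self.embed(X.long()))
        return h[-1]  # (batch_size, output_dim)


# e.g. model = WideDeep(deeptabular=deeptabular, deeptext=MyDeepText(vocab_size=2000))
```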


### Installation

Install using pip:

@@ -167,8 +171,8 @@ when running on Mac, present in previous versions, persist on this release
and the data-loaders will not run in parallel. In addition, since `python
3.8`, [the `multiprocessing` library start method changed from `'fork'` to `'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python 3.7`
and `torch <= 1.6` (with the corresponding, consistent
version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to
force this versioning in the `setup.py` file since I expect that all these
issues are fixed in the future. Therefore, after installing
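
Relatedly, one can check which start method a given interpreter uses — a
quick sketch of the `'fork'` vs `'spawn'` difference mentioned above:

```python
import multiprocessing

# python >= 3.8 defaults to 'spawn' on Mac; the parallel
# data-loaders relied on 'fork'
print(multiprocessing.get_start_method())
```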
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
1.0.5
1.0.9
1 change: 1 addition & 0 deletions docs/examples.rst
@@ -16,3 +16,4 @@ them to address different problems
* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
* `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
* `The Transformer Family <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/10_The_Transformer_Family.ipynb>`__
* `Extracting Embeddings <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/11_Extracting_Embeddings.ipynb>`__
1 change: 1 addition & 0 deletions docs/index.rst
@@ -23,6 +23,7 @@ Documentation
Dataloaders <dataloaders>
Callbacks <callbacks>
The Trainer <trainer>
Tab2Vec <tab2vec>
Examples <examples>


4 changes: 2 additions & 2 deletions docs/losses.rst
@@ -17,8 +17,8 @@ on their own and can be imported as:
from pytorch_widedeep.losses import FocalLoss
.. note:: Losses in this module expect the predictions and ground truth to have the
same dimensions for regression and binary classification problems
:math:`(N_{samples}, 1)`. In the case of multiclass classification problems
the ground truth is expected to be a 1D tensor with the corresponding
classes. See Examples below
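
For example, a minimal sketch of the expected shapes for a binary problem
(assuming the loss can be called directly on predictions and targets):

.. code-block:: python

    import torch

    from pytorch_widedeep.losses import FocalLoss

    # binary classification: predictions and ground truth both (N_samples, 1)
    y_pred = torch.rand(16, 1)
    y_true = torch.randint(0, 2, (16, 1)).float()

    loss = FocalLoss()(y_pred, y_true)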

7 changes: 3 additions & 4 deletions docs/metrics.rst
@@ -2,10 +2,9 @@ Metrics
=======

.. note:: Metrics in this module expect the predictions and ground truth to have the
same dimensions for regression and binary classification problems: :math:`(N_{samples}, 1)`.
In the case of multiclass classification problems the ground truth is expected to be
a 1D tensor with the corresponding classes. See Examples below

We have added the possibility of using the metrics available at the
`torchmetrics <https://torchmetrics.readthedocs.io/en/latest/>`_ library.
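
A minimal sketch of how that might look (assuming a built ``model`` and the
usual ``Trainer`` signature):

.. code-block:: python

    from torchmetrics import Accuracy

    from pytorch_widedeep import Trainer

    # torchmetrics objects can be passed alongside (or instead of)
    # pytorch-widedeep's own metrics
    trainer = Trainer(model, objective="binary", metrics=[Accuracy()])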
19 changes: 16 additions & 3 deletions docs/model_components.rst
@@ -5,9 +5,10 @@ This module contains the four main components that will comprise a Wide and
Deep model, and the ``WideDeep`` "constructor" class. These four components
are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.

.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer``, ``SAINT``,
``FTTransformer``, ``TabPerceiver`` and ``TabFastFormer`` can all be used
as the ``deeptabular`` component of the model and simply represent different
alternatives

.. autoclass:: pytorch_widedeep.models.wide.Wide
:exclude-members: forward
@@ -33,6 +34,18 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
:exclude-members: forward
:members:

.. autoclass:: pytorch_widedeep.models.transformers.ft_transformer.FTTransformer
:exclude-members: forward
:members:

.. autoclass:: pytorch_widedeep.models.transformers.tab_perceiver.TabPerceiver
:exclude-members: forward
:members:

.. autoclass:: pytorch_widedeep.models.transformers.tab_fastformer.TabFastFormer
:exclude-members: forward
:members:

.. autoclass:: pytorch_widedeep.models.deep_text.DeepText
:exclude-members: forward
:members:
7 changes: 7 additions & 0 deletions docs/tab2vec.rst
@@ -0,0 +1,7 @@
Tab2Vec
=======

.. autoclass:: pytorch_widedeep.tab2vec.Tab2Vec
:members:
:undoc-members:
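
A minimal usage sketch (assuming a trained ``WideDeep`` model and the fitted
``TabPreprocessor`` used to prepare its tabular input; the dataframe name is
illustrative):

.. code-block:: python

    from pytorch_widedeep.tab2vec import Tab2Vec

    # model: a trained WideDeep model
    # tab_preprocessor: the fitted preprocessor used to build its input
    t2v = Tab2Vec(model, tab_preprocessor)
    X_vec = t2v.transform(df_test)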

