diff --git a/README.md b/README.md
index 4c6f1d7e..adf85e35 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,12 @@
[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
+[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
+
+Platform | Version Support
+---------|:---------------
+OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%203.7-blue.svg)](https://www.python.org/)
+Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%203.7%203.8-blue.svg)](https://www.python.org/)
# pytorch-widedeep
@@ -34,11 +40,11 @@ few lines of code.
-Architecture 1 combines the `Wide`, one-hot encoded features with the outputs
-from the `DeepDense`, `DeepText` and `DeepImage` components connected to a
-final output neuron or neurons, depending on whether we are performing a
-binary classification or regression, or a multi-class classification. The
-components within the faded-pink rectangles are concatenated.
+Architecture 1 combines the `Wide` linear model with the outputs from the
+`DeepDense`, `DeepText` and `DeepImage` components, connected to a final
+output neuron or neurons, depending on whether we are performing a binary
+classification or regression, or a multi-class classification. The components
+within the faded-pink rectangles are concatenated.
In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
@@ -65,10 +71,10 @@ otherwise".*
-Architecture 2 combines the `Wide` one-hot encoded features with the Deep
-components of the model connected to the output neuron(s), after the different
-Deep components have been themselves combined through a FC-Head (that I refer
-as `deephead`).
+Architecture 2 combines the `Wide` linear model with the Deep components of
+the model connected to the output neuron(s), after the different Deep
+components have been themselves combined through a FC-Head (which I refer to
+as `deephead`).
In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 2 can be formulated
@@ -84,7 +90,8 @@ and `DeepImage` are optional. `pytorch-widedeep` includes standard text (stack
of LSTMs) and image (pre-trained ResNets or stack of CNNs) models. However,
the user can use any custom model as long as it has an attribute called
`output_dim` with the size of the last layer of activations, so that
-`WideDeep` can be constructed. See the examples folder for more information.
+`WideDeep` can be constructed. See the examples folder or the docs for more
+information.
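For instance, a minimal sketch of such a custom component (the class name and layer sizes here are invented for illustration; the only contract the library imposes is the `output_dim` attribute):

```python
import torch
from torch import nn

class MyDeepText(nn.Module):
    # illustrative stand-in for a custom `deeptext` component
    def __init__(self, vocab_size: int, embed_dim: int = 32, hidden_dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        # `WideDeep` reads this attribute to size the layers it builds on top
        self.output_dim = hidden_dim

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        _, h = self.rnn(self.embed(X))  # h: (num_layers, batch, hidden_dim)
        return h[-1]                    # last layer of activations
```

An instance of such a class could then be passed to `WideDeep` in place of the built-in `DeepText`.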
### Installation
@@ -112,14 +119,6 @@ cd pytorch-widedeep
pip install -e .
```
-### Examples
-
-There are a number of notebooks in the `examples` folder plus some additional
-files. These notebooks cover most of the utilities of this package and can
-also act as documentation. In the case that github does not render the
-notebooks, or it renders them missing some parts, they are saved as markdown
-files in the `docs` folder.
-
### Quick start
Binary classification with the [adult
@@ -128,6 +127,7 @@ using `Wide` and `DeepDense` and defaults settings.
```python
import pandas as pd
+import numpy as np
from sklearn.model_selection import train_test_split
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
@@ -166,7 +166,7 @@ target = df_train[target_col].values
# wide
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df_train)
-wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)
+wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
# deepdense
preprocess_deep = DensePreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
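# The rest of the quick start is elided from this diff. For orientation, a
# sketch of how it continues, mirroring calls that appear verbatim in the
# example notebooks below; `method="binary"` is an assumption about the
# compile signature, not shown in this diff:
from pytorch_widedeep.models import DeepDense, WideDeep

X_deep = preprocess_deep.fit_transform(df_train)
deepdense = DeepDense(
    hidden_layers=[64, 32],
    deep_column_idx=preprocess_deep.deep_column_idx,
    embed_input=preprocess_deep.embeddings_input,
    continuous_cols=cont_cols,
)
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary")
model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=256, val_split=0.2)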
diff --git a/VERSION b/VERSION
index f7abe273..c8a5397f 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.4.2
\ No newline at end of file
+0.4.5
\ No newline at end of file
diff --git a/docs/wide_deep/callbacks.rst b/docs/callbacks.rst
similarity index 100%
rename from docs/wide_deep/callbacks.rst
rename to docs/callbacks.rst
diff --git a/docs/figures/architecture_1.png b/docs/figures/architecture_1.png
index a9fb25df..5ffa7c98 100644
Binary files a/docs/figures/architecture_1.png and b/docs/figures/architecture_1.png differ
diff --git a/docs/figures/architecture_2.png b/docs/figures/architecture_2.png
index 074af49d..7c0068bc 100644
Binary files a/docs/figures/architecture_2.png and b/docs/figures/architecture_2.png differ
diff --git a/docs/index.rst b/docs/index.rst
index 8dbe21e6..02f4291a 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -18,7 +18,10 @@ Documentation
Utilities
Preprocessing
Model Components
- Wide and Deep Models
+ Metrics
+ Callbacks
+ Focal Loss
+ Wide and Deep Models
Examples
@@ -45,12 +48,11 @@ Architectures
:width: 600px
:align: center
-Architecture 1 combines the ``Wide``, one-hot encoded features with the
-outputs from the ``DeepDense``, ``DeepText`` and ``DeepImage`` components
-connected to a final output neuron or neurons, depending on whether we are
-performing a binary classification or regression, or a multi-class
-classification. The components within the faded-pink rectangles are
-concatenated.
+Architecture 1 combines the ``Wide`` linear model with the outputs from the
+``DeepDense``, ``DeepText`` and ``DeepImage`` components, connected to a
+final output neuron or neurons, depending on whether we are performing a
+binary classification or regression, or a multi-class classification. The
+components within the faded-pink rectangles are concatenated.
In math terms, and following the notation in the `paper
`_, Architecture 1 can be formulated as:
@@ -76,10 +78,10 @@ is the activation function.
:width: 600px
:align: center
-Architecture 2 combines the ``Wide`` one-hot encoded features with the Deep
-components of the model connected to the output neuron(s), after the different
-Deep components have been themselves combined through a FC-Head (referred as
-as ``deephead``).
+Architecture 2 combines the ``Wide`` linear model with the Deep components of
+the model connected to the output neuron(s), after the different Deep
+components have been themselves combined through a FC-Head (which I refer to
+as ``deephead``).
In math terms, and following the notation in the `paper
`_, Architecture 2 can be formulated as:
diff --git a/docs/wide_deep/losses.rst b/docs/losses.rst
similarity index 100%
rename from docs/wide_deep/losses.rst
rename to docs/losses.rst
diff --git a/docs/wide_deep/metrics.rst b/docs/metrics.rst
similarity index 100%
rename from docs/wide_deep/metrics.rst
rename to docs/metrics.rst
diff --git a/docs/model_components.rst b/docs/model_components.rst
index b308d21c..cca672c6 100644
--- a/docs/model_components.rst
+++ b/docs/model_components.rst
@@ -1,10 +1,9 @@
The ``models`` module
-=====================
+======================
This module contains the four main Wide and Deep model components. These are:
``Wide``, ``DeepDense``, ``DeepText`` and ``DeepImage``.
-
.. autoclass:: pytorch_widedeep.models.wide.Wide
:members:
:undoc-members:
diff --git a/docs/quick_start.rst b/docs/quick_start.rst
index 2fc305a5..b37fff65 100644
--- a/docs/quick_start.rst
+++ b/docs/quick_start.rst
@@ -15,6 +15,7 @@ The following code snippet is not directly related to ``pytorch-widedeep``.
.. code-block:: python
import pandas as pd
+ import numpy as np
from sklearn.model_selection import train_test_split
df = pd.read_csv("data/adult/adult.csv.zip")
@@ -23,6 +24,7 @@ The following code snippet is not directly related to ``pytorch-widedeep``.
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)
+
Prepare the wide and deep columns
---------------------------------
@@ -63,7 +65,7 @@ Preprocessing and model components definition
# wide
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df_train)
- wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)
+ wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
# deepdense
preprocess_deep = DensePreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
diff --git a/docs/wide_deep/wide_deep.rst b/docs/wide_deep.rst
similarity index 72%
rename from docs/wide_deep/wide_deep.rst
rename to docs/wide_deep.rst
index 82f909e7..0675afef 100644
--- a/docs/wide_deep/wide_deep.rst
+++ b/docs/wide_deep.rst
@@ -1,6 +1,9 @@
Building Wide and Deep Models
=============================
+Here is the documentation for building the two architectures, and for the
+different options available in ``pytorch-widedeep`` as one builds the model.
+
:class:`pytorch_widedeep.models.wide_deep.WideDeep` is the main class. It will
collect all model components and build one of the two possible architectures
with a series of optional parameters.
diff --git a/docs/wide_deep/index.rst b/docs/wide_deep/index.rst
deleted file mode 100644
index c95cbf17..00000000
--- a/docs/wide_deep/index.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-Wide and Deep Models
-=====================
-
-Here is the documentation to build the two architectures, and the different
-options available in ``pytorch-widedeep`` as one builds the model.
-
-Objects
--------
-
-.. toctree::
-
- metrics
- callbacks
- losses
- wide_deep
\ No newline at end of file
diff --git a/examples/01_Preprocessors_and_utils.ipynb b/examples/01_Preprocessors_and_utils.ipynb
index b0fbdb48..8457bb6c 100644
--- a/examples/01_Preprocessors_and_utils.ipynb
+++ b/examples/01_Preprocessors_and_utils.ipynb
@@ -50,7 +50,9 @@
"source": [
"## 1. WidePreprocessor\n",
"\n",
- "This class simply takes a dataset and one-hot encodes it, with a few additional rings and bells. "
+ "The Wide component of the model is a linear model that in principle, could be implemented as a linear layer receiving the result of on one-hot encoding categorical columns. However, this is not memory efficient. Therefore, we implement a liner layer as an Embedding layer plus a bias. I will explain in a bit more detail later. \n",
+ "\n",
+ "With that in mind, `WidePreprocessor` simply encodes the categories numerically so that they are the indexes of the lookup table that is an Embedding layer."
]
},
{
@@ -284,13 +286,13 @@
{
"data": {
"text/plain": [
- "array([[0., 1., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
+ "array([[ 1, 17, 23, ..., 89, 91, 316],\n",
+ " [ 2, 18, 23, ..., 89, 92, 317],\n",
+ " [ 3, 18, 24, ..., 89, 93, 318],\n",
" ...,\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.]])"
+ " [ 2, 20, 23, ..., 90, 103, 323],\n",
+ " [ 2, 17, 23, ..., 89, 103, 323],\n",
+ " [ 2, 21, 29, ..., 90, 115, 324]])"
]
},
"execution_count": 6,
@@ -306,45 +308,103 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "or sparse"
+ "Let's take from example the first entry"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 1, 17, 23, 32, 47, 89, 91, 316])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "wide_preprocessor_sparse = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols, sparse=True)\n",
- "X_wide_sparse = wide_preprocessor_sparse.fit_transform(df)"
+ "X_wide[0]"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " education | \n",
+ " relationship | \n",
+ " workclass | \n",
+ " occupation | \n",
+ " native-country | \n",
+ " gender | \n",
+ " education_occupation | \n",
+ " native-country_occupation | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 11th | \n",
+ " Own-child | \n",
+ " Private | \n",
+ " Machine-op-inspct | \n",
+ " United-States | \n",
+ " Male | \n",
+ " 11th-Machine-op-inspct | \n",
+ " United-States-Machine-op-inspct | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
"text/plain": [
- "<48842x796 sparse matrix of type ''\n",
- "\twith 390736 stored elements in Compressed Sparse Row format>"
+ " education relationship workclass occupation native-country gender \\\n",
+ "0 11th Own-child Private Machine-op-inspct United-States Male \n",
+ "\n",
+ " education_occupation native-country_occupation \n",
+ "0 11th-Machine-op-inspct United-States-Machine-op-inspct "
]
},
- "execution_count": 8,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "X_wide_sparse"
+ "wide_preprocessor.inverse_transform(X_wide[:1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note that while this will save memory on disk, due to the batch generation process for `WideDeep` the running time will be notably slow. See [here](https://github.com/jrzaurin/pytorch-widedeep/blob/bfbe6e5d2309857db0dcc5cf3282dfa60504aa52/pytorch_widedeep/models/_wd_dataset.py#L47) for more details."
+ "As we can see, `wide_preprocessor` numerically encodes the `wide_cols` and the `crossed_cols`, which can be recovered using the method `inverse_transform`."
]
},
{
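As an aside, the encoding that `WidePreprocessor` performs can be sketched in a few lines. This is a simplification for illustration, not the library's actual implementation: each (column, category) pair receives its own global index, starting at 1 so that 0 stays reserved for padding/unseen values, which is exactly the kind of index an `nn.Embedding` lookup expects.

```python
import pandas as pd

def encode_wide(df: pd.DataFrame, cols: list):
    """Toy version of the wide encoding: one global index per (col, category)."""
    mapping, idx = {}, 1  # 0 is reserved for padding / unseen categories
    for col in cols:
        for cat in df[col].unique():
            mapping[(col, cat)] = idx
            idx += 1
    encoded = df[cols].apply(lambda s: s.map(lambda v: mapping[(s.name, v)]))
    return encoded.values, mapping

X, mapping = encode_wide(pd.DataFrame({"color": ["r", "b", "g", "b"]}), ["color"])
print(X.ravel())  # [1 2 3 2] -- compare with the X_wide arrays above
```

An `inverse_transform` is then just the reverse lookup over the same mapping.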
diff --git a/examples/02_Model_Components.ipynb b/examples/02_Model_Components.ipynb
index 8e4fe64a..d5ea250d 100644
--- a/examples/02_Model_Components.ipynb
+++ b/examples/02_Model_Components.ipynb
@@ -23,7 +23,11 @@
"source": [
"### 1. Wide\n",
"\n",
- "The wide component is simply a Linear layer \"plugged\" into the output neuron(s)"
+ "The wide component is a Linear layer \"plugged\" into the output neuron(s)\n",
+ "\n",
+ "The only particularity of our implementation is that we have implemented the linear layer via an Embedding layer plus a bias. While the implementations are equivalent, the latter is faster and far more memory efficient, since we do not need to one hot encode the categorical features. \n",
+ "\n",
+ "Let's assume we the following dataset:"
]
},
{
@@ -31,13 +35,199 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
+ "source": [
+ "import torch\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "from torch import nn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " color | \n",
+ " size | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " r | \n",
+ " s | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " b | \n",
+ " n | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " g | \n",
+ " l | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " color size\n",
+ "0 r s\n",
+ "1 b n\n",
+ "2 g l"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = pd.DataFrame({'color': ['r', 'b', 'g'], 'size': ['s', 'n', 'l']})\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "one hot encoded, the first observation would be"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "obs_0_oh = (np.array([1., 0., 0., 1., 0., 0.])).astype('float32')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "if we simply numerically encode (label encode or `le`) the values, starting from 1 (we will save 0 for padding, i.e. unseen values)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "obs_0_le = (np.array([0, 3])).astype('int64')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "now, let's see if the two implementations are equivalent"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we have 6 different values. Let's assume we are performing a regression, so pred_dim = 1\n",
+ "lin = nn.Linear(6, 1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "emb = nn.Embedding(6, 1) \n",
+ "emb.weight = nn.Parameter(lin.weight.reshape_as(emb.weight))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([-0.9452], grad_fn=)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "lin(torch.tensor(obs_0_oh))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([-0.9452], grad_fn=)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "emb(torch.tensor(obs_0_le)).sum() + lin.bias"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And this is precisely how the linear component `Wide` is implemented"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
"source": [
"from pytorch_widedeep.models import Wide"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
@@ -46,27 +236,34 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Wide(\n",
- " (wide_linear): Linear(in_features=100, out_features=1, bias=True)\n",
+ " (wide_linear): Embedding(11, 1, padding_idx=0)\n",
")"
]
},
- "execution_count": 2,
+ "execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "wide = Wide(100, 1)\n",
+ "wide = Wide(wide_dim=10, pred_dim=1)\n",
"wide"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that even though the input dim is 10, the Embedding layer has 11 weights. This is because we save 0 for padding, which is used for unseen values during the encoding process"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -78,12 +275,10 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
- "import torch\n",
- "\n",
"from pytorch_widedeep.models import DeepDense"
]
},
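The `padding_idx=0` detail mentioned in the notebook is standard PyTorch behaviour and easy to verify: the padding row of the embedding matrix is initialised to zeros and excluded from gradient updates, so unseen categories contribute nothing to the linear term. A quick check:

```python
import torch
from torch import nn

emb = nn.Embedding(11, 1, padding_idx=0)  # wide_dim=10 -> 10 categories + 1 padding row
print(emb.weight.data[0])                 # tensor([0.]): the padding row starts at zero

loss = emb(torch.tensor([0, 3])).sum()
loss.backward()
print(emb.weight.grad[0])                 # tensor([0.]): ...and receives no gradient
```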
diff --git a/examples/03_Binary_Classification_with_Defaults.ipynb b/examples/03_Binary_Classification_with_Defaults.ipynb
index 1d97d8fa..c645333c 100644
--- a/examples/03_Binary_Classification_with_Defaults.ipynb
+++ b/examples/03_Binary_Classification_with_Defaults.ipynb
@@ -419,14 +419,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[[0. 1. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
+ "[[ 1 17 23 ... 89 91 316]\n",
+ " [ 2 18 23 ... 89 92 317]\n",
+ " [ 3 18 24 ... 89 93 318]\n",
" ...\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]]\n",
- "(48842, 796)\n"
+ " [ 2 20 23 ... 90 103 323]\n",
+ " [ 2 17 23 ... 89 103 323]\n",
+ " [ 2 21 29 ... 90 115 324]]\n",
+ "(48842, 8)\n"
]
}
],
@@ -479,7 +479,7 @@
"metadata": {},
"outputs": [],
"source": [
- "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
+ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[64,32], \n",
" deep_column_idx=preprocess_deep.deep_column_idx,\n",
" embed_input=preprocess_deep.embeddings_input,\n",
@@ -497,7 +497,7 @@
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
- " (wide_linear): Linear(in_features=796, out_features=1, bias=True)\n",
+ " (wide_linear): Embedding(797, 1, padding_idx=0)\n",
" )\n",
" (deepdense): Sequential(\n",
" (0): DeepDense(\n",
@@ -577,7 +577,7 @@
"output_type": "stream",
"text": [
"\r",
- " 0%| | 0/153 [00:00, ?it/s]"
+ " 0%| | 0/611 [00:00, ?it/s]"
]
},
{
@@ -591,21 +591,21 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}] \n",
- "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}] \n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n",
- "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}] \n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n",
- "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}] \n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n",
- "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n"
+ "epoch 1: 100%|██████████| 611/611 [00:04<00:00, 128.34it/s, loss=0.655, metrics={'acc': 0.6487, 'prec': 0.2352}]\n",
+ "valid: 100%|██████████| 153/153 [00:00<00:00, 173.04it/s, loss=0.517, metrics={'acc': 0.6659, 'prec': 0.2524}]\n",
+ "epoch 2: 100%|██████████| 611/611 [00:04<00:00, 133.99it/s, loss=0.466, metrics={'acc': 0.7725, 'prec': 0.5456}]\n",
+ "valid: 100%|██████████| 153/153 [00:00<00:00, 171.03it/s, loss=0.433, metrics={'acc': 0.7765, 'prec': 0.5598}]\n",
+ "epoch 3: 100%|██████████| 611/611 [00:04<00:00, 132.06it/s, loss=0.413, metrics={'acc': 0.803, 'prec': 0.6451}] \n",
+ "valid: 100%|██████████| 153/153 [00:00<00:00, 172.22it/s, loss=0.4, metrics={'acc': 0.8045, 'prec': 0.648}] \n",
+ "epoch 4: 100%|██████████| 611/611 [00:04<00:00, 131.57it/s, loss=0.39, metrics={'acc': 0.8181, 'prec': 0.6836}] \n",
+ "valid: 100%|██████████| 153/153 [00:00<00:00, 169.62it/s, loss=0.384, metrics={'acc': 0.8195, 'prec': 0.6841}]\n",
+ "epoch 5: 100%|██████████| 611/611 [00:04<00:00, 130.85it/s, loss=0.378, metrics={'acc': 0.8247, 'prec': 0.6941}]\n",
+ "valid: 100%|██████████| 153/153 [00:00<00:00, 171.85it/s, loss=0.376, metrics={'acc': 0.8254, 'prec': 0.6946}]\n"
]
}
],
"source": [
- "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=256, val_split=0.2)"
+ "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
]
},
{
diff --git a/examples/04_Binary_Classification_Varying_Parameters.ipynb b/examples/04_Binary_Classification_Varying_Parameters.ipynb
index 3333cb9d..e4aa227e 100644
--- a/examples/04_Binary_Classification_Varying_Parameters.ipynb
+++ b/examples/04_Binary_Classification_Varying_Parameters.ipynb
@@ -419,14 +419,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[[0. 1. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
+ "[[ 1 17 23 ... 89 91 316]\n",
+ " [ 2 18 23 ... 89 92 317]\n",
+ " [ 3 18 24 ... 89 93 318]\n",
" ...\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]\n",
- " [0. 0. 0. ... 0. 0. 0.]]\n",
- "(48842, 796)\n"
+ " [ 2 20 23 ... 90 103 323]\n",
+ " [ 2 17 23 ... 89 103 323]\n",
+ " [ 2 21 29 ... 90 115 324]]\n",
+ "(48842, 8)\n"
]
}
],
@@ -488,7 +488,7 @@
"metadata": {},
"outputs": [],
"source": [
- "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
+ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"# We can add dropout and batchnorm to the dense layers\n",
"deepdense = DeepDense(hidden_layers=[64,32], dropout=[0.5, 0.5], batchnorm=True,\n",
" deep_column_idx=preprocess_deep.deep_column_idx,\n",
@@ -507,7 +507,7 @@
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
- " (wide_linear): Linear(in_features=796, out_features=1, bias=True)\n",
+ " (wide_linear): Embedding(797, 1, padding_idx=0)\n",
" )\n",
" (deepdense): Sequential(\n",
" (0): DeepDense(\n",
@@ -575,13 +575,13 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Optimizers\n",
- "wide_opt = torch.optim.Adam(model.wide.parameters())\n",
- "deep_opt = RAdam(model.deepdense.parameters())\n",
+ "wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.03)\n",
+ "deep_opt = RAdam(model.deepdense.parameters(), lr=0.01)\n",
"# LR Schedulers\n",
"wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)\n",
"deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)"
@@ -596,7 +596,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -611,7 +611,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -623,7 +623,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -645,26 +645,26 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 96.80it/s, loss=0.582, metrics={'acc': 0.7447, 'rec': 0.0374}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 117.19it/s, loss=0.512, metrics={'acc': 0.7488, 'rec': 0.0347}]\n",
- "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 112.47it/s, loss=0.481, metrics={'acc': 0.7819, 'rec': 0.1127}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 119.48it/s, loss=0.454, metrics={'acc': 0.7866, 'rec': 0.139}]\n",
- "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 99.29it/s, loss=0.44, metrics={'acc': 0.8091, 'rec': 0.2838}] \n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 88.13it/s, loss=0.425, metrics={'acc': 0.8108, 'rec': 0.2925}]\n",
- "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 103.34it/s, loss=0.426, metrics={'acc': 0.8131, 'rec': 0.3124}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 101.78it/s, loss=0.423, metrics={'acc': 0.814, 'rec': 0.3156}]\n",
- "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 100.77it/s, loss=0.423, metrics={'acc': 0.8132, 'rec': 0.3134}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 103.36it/s, loss=0.421, metrics={'acc': 0.814, 'rec': 0.3165}]\n",
- "epoch 6: 100%|██████████| 153/153 [00:01<00:00, 100.09it/s, loss=0.421, metrics={'acc': 0.8134, 'rec': 0.3147}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 111.64it/s, loss=0.418, metrics={'acc': 0.8141, 'rec': 0.3178}]\n",
- "epoch 7: 100%|██████████| 153/153 [00:01<00:00, 103.15it/s, loss=0.42, metrics={'acc': 0.8133, 'rec': 0.3148}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 100.57it/s, loss=0.418, metrics={'acc': 0.8141, 'rec': 0.3179}]\n",
- "epoch 8: 100%|██████████| 153/153 [00:01<00:00, 98.05it/s, loss=0.42, metrics={'acc': 0.8133, 'rec': 0.3148}] \n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 105.68it/s, loss=0.418, metrics={'acc': 0.8141, 'rec': 0.3179}]\n",
- "epoch 9: 100%|██████████| 153/153 [00:01<00:00, 101.05it/s, loss=0.419, metrics={'acc': 0.8133, 'rec': 0.3149}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 99.49it/s, loss=0.418, metrics={'acc': 0.8141, 'rec': 0.3181}]\n",
- "epoch 10: 100%|██████████| 153/153 [00:01<00:00, 97.72it/s, loss=0.419, metrics={'acc': 0.8133, 'rec': 0.3149}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 102.56it/s, loss=0.418, metrics={'acc': 0.8141, 'rec': 0.3181}]\n"
+ "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 72.33it/s, loss=0.503, metrics={'acc': 0.7885, 'rec': 0.4864}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 127.72it/s, loss=0.386, metrics={'acc': 0.7962, 'rec': 0.4998}]\n",
+ "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.76it/s, loss=0.374, metrics={'acc': 0.8268, 'rec': 0.5242}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 126.72it/s, loss=0.372, metrics={'acc': 0.8277, 'rec': 0.5281}]\n",
+ "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 73.21it/s, loss=0.367, metrics={'acc': 0.8298, 'rec': 0.5242}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 126.68it/s, loss=0.37, metrics={'acc': 0.8303, 'rec': 0.5279}]\n",
+ "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 71.37it/s, loss=0.36, metrics={'acc': 0.8319, 'rec': 0.5372}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 128.64it/s, loss=0.369, metrics={'acc': 0.8324, 'rec': 0.5412}]\n",
+ "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.53it/s, loss=0.359, metrics={'acc': 0.8322, 'rec': 0.5378}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 119.31it/s, loss=0.369, metrics={'acc': 0.8325, 'rec': 0.5412}]\n",
+ "epoch 6: 100%|██████████| 153/153 [00:02<00:00, 71.37it/s, loss=0.359, metrics={'acc': 0.8322, 'rec': 0.5361}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 125.99it/s, loss=0.369, metrics={'acc': 0.8326, 'rec': 0.5398}]\n",
+ "epoch 7: 100%|██████████| 153/153 [00:02<00:00, 70.20it/s, loss=0.358, metrics={'acc': 0.8329, 'rec': 0.5396}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 124.88it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5416}]\n",
+ "epoch 8: 100%|██████████| 153/153 [00:02<00:00, 70.75it/s, loss=0.358, metrics={'acc': 0.833, 'rec': 0.5374}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 125.81it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5397}]\n",
+ "epoch 9: 100%|██████████| 153/153 [00:02<00:00, 70.40it/s, loss=0.358, metrics={'acc': 0.833, 'rec': 0.5368}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 125.07it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5391}]\n",
+ "epoch 10: 100%|██████████| 153/153 [00:02<00:00, 70.20it/s, loss=0.358, metrics={'acc': 0.8329, 'rec': 0.537}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 124.43it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5392}]\n"
]
}
],
@@ -674,7 +674,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -799,7 +799,7 @@
" 'zero_grad']"
]
},
- "execution_count": 15,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -817,7 +817,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -826,7 +826,7 @@
"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
]
},
- "execution_count": 16,
+ "execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -837,14 +837,14 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'train_loss': [0.582023054166557, 0.48075080015300925, 0.44022563099861145, 0.42563695144030006, 0.42342905612552867, 0.42120904256315794, 0.41995110737732033, 0.419722778734818, 0.4194869099099652, 0.41935000135228523], 'train_acc': [0.7446574360811814, 0.7818954265093543, 0.8090753205538351, 0.8130934404831981, 0.8132214060860441, 0.8133749648094593, 0.8132981854477517, 0.8132981854477517, 0.8132981854477517, 0.8132725923271824], 'train_rec': [0.037437159568071365, 0.11273933202028275, 0.28377366065979004, 0.31243982911109924, 0.3134025037288666, 0.3146860599517822, 0.31479302048683167, 0.31479302048683167, 0.3148999810218811, 0.3148999810218811], 'val_loss': [0.5119115939507117, 0.4539328179298303, 0.42495738925077975, 0.4227801507864243, 0.42057751004512495, 0.41838682691256207, 0.4181652260132325, 0.41793563885566515, 0.4176993484680469, 0.4176750809718401], 'val_acc': [0.7488432087138119, 0.7865771262438066, 0.8107571352524466, 0.8139715818353057, 0.8139920560173621, 0.8141353752917571, 0.8140739527455878, 0.8140739527455878, 0.8140739527455878, 0.8140534785635314], 'val_rec': [0.034739453345537186, 0.13904337584972382, 0.29246172308921814, 0.3155643045902252, 0.3165055215358734, 0.3177889883518219, 0.317874550819397, 0.317874550819397, 0.3181312680244446, 0.3181312680244446]}\n"
+ "{'train_loss': [0.5026861273385341, 0.37383826573689777, 0.36658557158669614, 0.3601557047538508, 0.3594148938172783, 0.35907501001763187, 0.358282413942362, 0.35823015644659406, 0.35819698957835927, 0.3581014702133104], 'train_acc': [0.788549637857344, 0.8267857599877153, 0.8297545619737414, 0.8318787909809843, 0.8321859084278146, 0.832237094668953, 0.832902515803752, 0.8329792951654595, 0.8329537020448904, 0.8329281089243211], 'train_rec': [0.48636218905448914, 0.5242272019386292, 0.5242272019386292, 0.5371697545051575, 0.5378115177154541, 0.5361000895500183, 0.5396299362182617, 0.5373836755752563, 0.5368488430976868, 0.5369558334350586], 'val_loss': [0.38589231249613637, 0.371902360365941, 0.36999432627971357, 0.36935041348139447, 0.3691598016482133, 0.36905216712218064, 0.36900061674607104, 0.36898223635477895, 0.36896658937136334, 0.36896434120642835], 'val_acc': [0.79624094017444, 0.8277302321772245, 0.8302895049342779, 0.832357397321977, 0.8325211907784285, 0.8325826133245977, 0.833073993693952, 0.8331354162401212, 0.8330944678760084, 0.833073993693952], 'val_rec': [0.4997860789299011, 0.5281081795692444, 0.5279369950294495, 0.5411996245384216, 0.5411996245384216, 0.5398305654525757, 0.5416274666786194, 0.5396594405174255, 0.5391460657119751, 0.5392315983772278]}\n"
]
}
],
@@ -854,14 +854,14 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'lr_wide_0': [0.001, 0.001, 0.001, 0.0001, 0.0001, 0.0001, 1e-05, 1e-05, 1e-05, 1.0000000000000002e-06], 'lr_deepdense_0': [0.001, 0.001, 0.001, 0.001, 0.001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]}\n"
+ "{'lr_wide_0': [0.03, 0.03, 0.03, 0.003, 0.003, 0.003, 0.00030000000000000003, 0.00030000000000000003, 0.00030000000000000003, 3.0000000000000004e-05], 'lr_deepdense_0': [0.01, 0.01, 0.01, 0.01, 0.01, 0.001, 0.001, 0.001, 0.001, 0.001]}\n"
]
}
],
@@ -880,84 +880,83 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "{'11th': array([-0.02114219, -0.3634936 , 0.03710679, -0.07243915, -0.28715202,\n",
- " -0.29929525, 0.11913099, -0.01372065, -0.06960961, 0.11129184,\n",
- " -0.11541647, 0.02515038, -0.32817808, 0.19789433, -0.6190677 ,\n",
- " 0.13042031], dtype=float32),\n",
- " 'HS-grad': array([ 2.00377326e-04, -2.61859149e-01, -2.51907468e-01, -1.70783494e-02,\n",
- " -1.04680985e-01, -1.51709780e-01, 1.73194274e-01, -1.53597221e-01,\n",
- " -2.76275307e-01, -3.37639779e-01, 5.94966952e-03, 2.58735180e-01,\n",
- " -1.08496705e-02, -8.25304538e-02, -2.43277356e-01, 4.01295513e-01],\n",
- " dtype=float32),\n",
- " 'Assoc-acdm': array([ 0.0126759 , 0.15168795, -0.03856753, -0.27679357, -0.47500238,\n",
- " 0.45382416, -0.545228 , 0.1339748 , 0.02405205, 0.02809528,\n",
- " -0.41063702, -0.06350306, -0.08130409, -0.13869216, 0.4932242 ,\n",
- " 0.17304394], dtype=float32),\n",
- " 'Some-college': array([-0.1610479 , -0.25853214, -0.236602 , 0.13044621, 0.03830301,\n",
- " -0.1743144 , -0.28899103, -0.11883932, 0.08455969, -0.08742228,\n",
- " -0.46097067, -0.11231954, 0.37493324, -0.2029054 , -0.07289007,\n",
- " -0.03197158], dtype=float32),\n",
- " '10th': array([-0.07396724, 0.06467492, -0.08107238, -0.03854853, -0.06056274,\n",
- " 0.17571206, 0.1095883 , 0.12067619, -0.40733424, 0.32879853,\n",
- " -0.17957865, 0.6560938 , -0.10061017, 0.23316202, 0.3059522 ,\n",
- " -0.14240988], dtype=float32),\n",
- " 'Prof-school': array([-0.04007137, -0.2707076 , 0.23113215, -0.41783914, 0.25955105,\n",
- " -0.30054352, 0.32043606, -0.20860812, -0.15136348, -0.36359408,\n",
- " 0.49961898, -0.13973509, 0.51584864, 0.47093126, -0.0276325 ,\n",
- " -0.20662539], dtype=float32),\n",
- " '7th-8th': array([-0.07773081, 0.09345848, 0.072533 , -0.24359678, 0.14904591,\n",
- " 0.18480958, 0.01799594, 0.12402041, -0.35343906, 0.23270686,\n",
- " 0.10102016, -0.0258682 , 0.40796915, -0.05507657, 0.4019308 ,\n",
- " 0.05231443], dtype=float32),\n",
- " 'Bachelors': array([-0.12334875, 0.2271091 , -0.22389385, 0.5577601 , -0.05163969,\n",
- " -0.37246484, -0.02689779, 0.18202123, 0.59914356, -0.07744938,\n",
- " 0.5633556 , -0.18728566, -0.43923494, 0.2014725 , 0.00761633,\n",
- " -0.0447193 ], dtype=float32),\n",
- " 'Masters': array([-0.16333768, -0.16029981, -0.01482454, -0.04896322, -0.0047817 ,\n",
- " 0.09887701, -0.15091099, 0.22599514, -0.17000915, 0.16678709,\n",
- " -0.3679181 , -0.18114986, -0.16266271, -0.27970657, -0.1254899 ,\n",
- " 0.31768733], dtype=float32),\n",
- " 'Doctorate': array([ 0.01903534, -0.02743328, 0.16066255, -0.11599138, -0.00787276,\n",
- " 0.145728 , 0.24741152, -0.09514342, -0.23147094, -0.1098811 ,\n",
- " -0.12666361, 0.19410084, -0.05531591, -0.37460938, 0.42867297,\n",
- " 0.01255902], dtype=float32),\n",
- " '5th-6th': array([-0.1874508 , -0.01520642, -0.23055367, -0.10444976, 0.1880218 ,\n",
- " 0.06044631, -0.17084908, 0.28993553, 0.19094709, 0.01088051,\n",
- " -0.05885294, 0.26692954, -0.10718243, -0.07673435, -0.00814716,\n",
- " 0.46550933], dtype=float32),\n",
- " 'Assoc-voc': array([-0.1354011 , -0.44471708, 0.18469264, -0.02088883, 0.0346331 ,\n",
- " 0.07825129, 0.22990814, -0.38387823, 0.01530089, -0.30289283,\n",
- " -0.18230931, 0.19571105, -0.03887892, 0.01946613, 0.16479516,\n",
- " 0.41735104], dtype=float32),\n",
- " '9th': array([-0.29879665, 0.04421502, -0.11862607, 0.1772717 , -0.12706555,\n",
- " 0.04192697, -0.28609815, 0.3248482 , -0.04987352, -0.39138898,\n",
- " -0.057826 , 0.4970304 , -0.1326947 , 0.22000486, -0.00846681,\n",
- " 0.2706219 ], dtype=float32),\n",
- " '12th': array([-0.14942442, -0.22520241, 0.04879642, -0.16480213, -0.00521241,\n",
- " -0.07897403, 0.07396449, -0.29127416, -0.26175758, -0.5076894 ,\n",
- " -0.06036085, 0.3846129 , -0.6074103 , 0.27427655, -0.15219459,\n",
- " 0.24506666], dtype=float32),\n",
- " '1st-4th': array([ 0.02215668, 0.21796826, -0.35868096, 0.03803689, 0.02591529,\n",
- " 0.3914331 , -0.58327377, 0.3261264 , 0.36127493, -0.25838605,\n",
- " -0.05334533, -0.04685102, 0.17751735, 0.08530575, 0.13134745,\n",
- " 0.44403064], dtype=float32),\n",
- " 'Preschool': array([-0.33479828, 0.19172014, 0.26898265, 0.04768471, 0.01425556,\n",
- " 0.02984914, -0.02165659, 0.09084602, -0.26122406, 0.06567731,\n",
- " 0.0431284 , 0.3698193 , 0.6405797 , -0.00345286, 0.10917825,\n",
- " -0.07227341], dtype=float32),\n",
- " 'unseen': array([ 0.34454626, 0.08338903, 0.00250609, -0.27078775, 0.12649588,\n",
- " 0.35320354, -0.02497412, 0.2975028 , 0.21158105, 0.04682659,\n",
- " 0.03411686, -0.02839612, 0.16605824, 0.15381509, -0.00892953,\n",
- " -0.820573 ], dtype=float32)}"
+ "{'11th': array([ 0.33238176, 0.02123132, 0.42671534, -0.16836806, 0.04070434,\n",
+ " 0.21476945, -0.05866506, 0.09599391, 0.21264766, -0.08261641,\n",
+ " -0.4364204 , 0.5176953 , -0.17785792, 0.1990719 , 0.05055304,\n",
+ " -0.05390744], dtype=float32),\n",
+ " 'HS-grad': array([ 0.1851779 , -0.0601109 , -0.04134565, -0.17099169, 0.01647249,\n",
+ " 0.1691518 , -0.03775224, -0.01711482, -0.13714994, -0.02202759,\n",
+ " -0.2350222 , 0.20368417, 0.06420711, 0.08465873, 0.11443923,\n",
+ " -0.28585908], dtype=float32),\n",
+ " 'Assoc-acdm': array([-0.2891686 , -0.25329128, -0.03977084, 0.34204823, 0.4393897 ,\n",
+ " 0.24583909, -0.08771466, 0.3398704 , 0.06197336, -0.09200054,\n",
+ " 0.13266966, -0.27940965, -0.10639463, 0.16516595, 0.20191231,\n",
+ " -0.11804624], dtype=float32),\n",
+ " 'Some-college': array([ 0.17284533, -0.34509236, -0.22175975, -0.11192639, 0.14154772,\n",
+ " 0.04188053, 0.14860624, 0.28312132, 0.06071718, -0.10315312,\n",
+ " -0.05902205, -0.03197744, 0.20363455, 0.04027565, 0.43063605,\n",
+ " 0.21163562], dtype=float32),\n",
+ " '10th': array([ 0.13888928, 0.28386956, 0.18166119, 0.02652328, 0.11637231,\n",
+ " 0.24056876, -0.06386037, 0.05930374, 0.04393852, 0.17677549,\n",
+ " 0.27980283, -0.01221516, 0.12281907, 0.04273703, 0.22282158,\n",
+ " -0.25718638], dtype=float32),\n",
+ " 'Prof-school': array([ 0.26996085, 0.06557842, 0.0957497 , 0.06524102, 0.05351401,\n",
+ " 0.34774455, -0.39007127, -0.35276353, -0.19460988, 0.06306136,\n",
+ " -0.03555794, 0.02946662, 0.47177076, 0.21887466, 0.34440616,\n",
+ " 0.17761633], dtype=float32),\n",
+ " '7th-8th': array([-0.14013144, -0.20337081, 0.6704599 , -0.10210201, 0.1633953 ,\n",
+ " -0.03677108, -0.04664218, -0.13967332, -0.02610652, -0.15920916,\n",
+ " -0.18137608, -0.01846946, 0.35807863, 0.0148629 , 0.2857368 ,\n",
+ " 0.28930005], dtype=float32),\n",
+ " 'Bachelors': array([-0.38666266, 0.17745058, -0.6287257 , 0.22080924, 0.25037012,\n",
+ " -0.10224682, 0.5612052 , -0.24709803, 0.03214271, -0.22835065,\n",
+ " -0.14132145, 0.3010941 , -0.23835489, 0.08622 , -0.04518703,\n",
+ " 0.31074366], dtype=float32),\n",
+ " 'Masters': array([-0.41403466, -0.33947882, 0.14072244, -0.22146806, -0.18230349,\n",
+ " -0.1195543 , -0.84759206, 0.25256675, 0.14532281, -0.01060636,\n",
+ " -0.03578382, -0.07117725, 0.10634375, -0.11669173, 0.17765476,\n",
+ " -0.03559739], dtype=float32),\n",
+ " 'Doctorate': array([ 0.00375404, -0.02784416, -0.28326795, 0.22763273, 0.03977633,\n",
+ " 0.2893272 , 0.25680798, 0.36434892, -0.65951985, -0.23679003,\n",
+ " -0.11408209, -0.23283346, -0.27024168, 0.0655888 , -0.28381783,\n",
+ " -0.01525949], dtype=float32),\n",
+ " '5th-6th': array([ 0.00683184, 0.23564084, -0.132059 , -0.3406017 , -0.06710123,\n",
+ " -0.09649926, 0.50411046, -0.12363172, -0.0353502 , -0.53238744,\n",
+ " -0.05181202, -0.05146485, -0.23931046, -0.26453286, 0.08420272,\n",
+ " 0.0235041 ], dtype=float32),\n",
+ " 'Assoc-voc': array([ 0.01930698, -0.2455314 , 0.2246628 , 0.16216752, -0.4528598 ,\n",
+ " -0.6121017 , 0.15893641, 0.01993939, -0.3148845 , 0.03837916,\n",
+ " 0.0767131 , -0.36453167, 0.19929656, 0.28016493, 0.29385152,\n",
+ " -0.47822088], dtype=float32),\n",
+ " '9th': array([-0.03110321, 0.69687057, -0.33127317, 0.06741869, 0.08373164,\n",
+ " 0.25090563, 0.07099659, 0.21758935, -0.07414749, -0.19316533,\n",
+ " 0.21613942, 0.28149685, -0.41364396, -0.0439614 , -0.02726781,\n",
+ " -0.04664526], dtype=float32),\n",
+ " '12th': array([ 0.46782094, 0.1987633 , 0.11554655, -0.23237073, -0.35828865,\n",
+ " -0.08366812, 0.0086338 , 0.46672872, -0.24939838, 0.22630745,\n",
+ " -0.16754937, -0.4713689 , -0.08152255, 0.02004629, 0.1118032 ,\n",
+ " 0.20979449], dtype=float32),\n",
+ " '1st-4th': array([-0.16926417, 0.11347993, 0.02692448, -0.10284851, 0.25171363,\n",
+ " -0.04539176, -0.24491136, 0.3281045 , -0.08861455, 0.18578447,\n",
+ " 0.23892452, -0.00729677, 0.16713212, 0.2949316 , -0.00725389,\n",
+ " -0.20607162], dtype=float32),\n",
+ " 'Preschool': array([-0.30532706, 0.25465214, -0.5603218 , -0.16249408, -0.32321507,\n",
+ " 0.11698078, 0.01557691, -0.3124683 , -0.25044286, 0.08334377,\n",
+ " 0.2094927 , 0.03301949, -0.01236501, -0.24443303, -0.03395106,\n",
+ " -0.01797807], dtype=float32),\n",
+ " 'unseen': array([-0.17771505, 0.3246768 , -0.29062387, 0.12164559, 0.34164497,\n",
+ " -0.5451506 , 0.22189835, 0.21224639, 0.4933099 , -0.03533744,\n",
+ " -0.12335563, 0.12472781, 0.1412489 , 0.17336178, 0.4160364 ,\n",
+ " -0.32417113], dtype=float32)}"
]
},
- "execution_count": 19,
+ "execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/examples/05_Regression_with_Images_and_Text.ipynb b/examples/05_Regression_with_Images_and_Text.ipynb
index b5831e53..1b1319be 100644
--- a/examples/05_Regression_with_Images_and_Text.ipynb
+++ b/examples/05_Regression_with_Images_and_Text.ipynb
@@ -1058,7 +1058,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- " 4%|▍ | 42/1001 [00:00<00:02, 411.81it/s]"
+ " 4%|▍ | 43/1001 [00:00<00:02, 424.34it/s]"
]
},
{
@@ -1072,7 +1072,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|██████████| 1001/1001 [00:02<00:00, 402.73it/s]\n"
+ "100%|██████████| 1001/1001 [00:02<00:00, 400.65it/s]\n"
]
},
{
@@ -1097,12 +1097,12 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Linear model\n",
- "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
+ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"# DeepDense: 2 Dense layers\n",
"deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n",
" deep_column_idx=deep_preprocessor.deep_column_idx,\n",
@@ -1125,7 +1125,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -1141,7 +1141,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -1150,7 +1150,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -1172,8 +1172,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [01:13<00:00, 2.93s/it, loss=135]\n",
- "valid: 100%|██████████| 7/7 [00:14<00:00, 2.10s/it, loss=124] \n"
+ "epoch 1: 100%|██████████| 25/25 [01:08<00:00, 2.74s/it, loss=1.73e+4]\n",
+ "valid: 100%|██████████| 7/7 [00:14<00:00, 2.01s/it, loss=1.45e+4]\n"
]
}
],
@@ -1193,11 +1193,11 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
- "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
+ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n",
" deep_column_idx=deep_preprocessor.deep_column_idx,\n",
" embed_input=deep_preprocessor.embeddings_input,\n",
@@ -1217,7 +1217,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -1233,7 +1233,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -1241,7 +1241,7 @@
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
- " (wide_linear): Linear(in_features=356, out_features=1, bias=True)\n",
+ " (wide_linear): Embedding(357, 1, padding_idx=0)\n",
" )\n",
" (deepdense): DeepDense(\n",
" (embed_layers): ModuleDict(\n",
@@ -1386,7 +1386,7 @@
")"
]
},
- "execution_count": 15,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -1406,7 +1406,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -1420,11 +1420,11 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
- "wide_opt = torch.optim.Adam(model.wide.parameters())\n",
+ "wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.03)\n",
"deep_opt = torch.optim.Adam(deep_params)\n",
"text_opt = RAdam(model.deeptext.parameters())\n",
"img_opt = RAdam(model.deepimage.parameters())\n",
@@ -1433,7 +1433,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@@ -1446,7 +1446,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -1472,7 +1472,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -1482,7 +1482,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -1490,7 +1490,7 @@
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
- " (wide_linear): Linear(in_features=356, out_features=1, bias=True)\n",
+ " (wide_linear): Embedding(357, 1, padding_idx=0)\n",
" )\n",
" (deepdense): DeepDense(\n",
" (embed_layers): ModuleDict(\n",
@@ -1635,7 +1635,7 @@
")"
]
},
- "execution_count": 21,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -1646,7 +1646,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -1668,8 +1668,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [02:02<00:00, 4.88s/it, loss=128]\n",
- "valid: 100%|██████████| 7/7 [00:14<00:00, 2.09s/it, loss=94.5]\n"
+ "epoch 1: 100%|██████████| 25/25 [02:04<00:00, 4.98s/it, loss=1.24e+4]\n",
+ "valid: 100%|██████████| 7/7 [00:16<00:00, 2.33s/it, loss=9.26e+3]\n"
]
}
],
@@ -1687,13 +1687,13 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "{'lr_wide_0': [0.001, 0.001],\n",
+ "{'lr_wide_0': [0.03, 0.03],\n",
" 'lr_deepdense_0': [0.0001, 0.0001],\n",
" 'lr_deepdense_1': [0.0001, 0.0001],\n",
" 'lr_deepdense_2': [0.0001, 0.0001],\n",
@@ -1712,7 +1712,7 @@
" 'lr_deephead_0': [0.001, 0.001]}"
]
},
- "execution_count": 23,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/examples/06_WarmUp_Model_Components.ipynb b/examples/06_WarmUp_Model_Components.ipynb
index 70176697..3ee10f57 100644
--- a/examples/06_WarmUp_Model_Components.ipynb
+++ b/examples/06_WarmUp_Model_Components.ipynb
@@ -259,7 +259,7 @@
"metadata": {},
"outputs": [],
"source": [
- "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
+ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[64,32], \n",
" deep_column_idx=preprocess_deep.deep_column_idx,\n",
" embed_input=preprocess_deep.embeddings_input,\n",
@@ -307,11 +307,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 127.54it/s, loss=0.476, metrics={'acc': 0.7808972948071559}]\n",
- "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 126.88it/s, loss=0.373, metrics={'acc': 0.8048268625393494}]\n",
- "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 141.92it/s, loss=0.365, metrics={'acc': 0.8136820822562895}]\n",
- "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 151.56it/s, loss=0.362, metrics={'acc': 0.8182312594374632}]\n",
- "epoch 5: 100%|██████████| 153/153 [00:00<00:00, 158.22it/s, loss=0.36, metrics={'acc': 0.8210477823561027}]\n",
+ "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 92.77it/s, loss=0.598, metrics={'acc': 0.697438128631024}] \n",
+ "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 106.70it/s, loss=0.394, metrics={'acc': 0.758272976223991}] \n",
+ "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 101.83it/s, loss=0.371, metrics={'acc': 0.7821172335542873}]\n",
+ "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 107.60it/s, loss=0.365, metrics={'acc': 0.7943976659074041}]\n",
+ "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 105.79it/s, loss=0.363, metrics={'acc': 0.8018273488086403}]\n",
" 0%| | 0/153 [00:00, ?it/s]"
]
},
@@ -326,11 +326,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 78.65it/s, loss=0.397, metrics={'acc': 0.8198073691125158}]\n",
- "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 75.69it/s, loss=0.348, metrics={'acc': 0.8221936229255862}]\n",
- "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 74.79it/s, loss=0.343, metrics={'acc': 0.8243576126737133}]\n",
- "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 76.79it/s, loss=0.338, metrics={'acc': 0.8264502057402526}]\n",
- "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 79.57it/s, loss=0.334, metrics={'acc': 0.8283059913495252}]\n",
+ "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 73.75it/s, loss=0.398, metrics={'acc': 0.802813537054573}] \n",
+ "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 73.16it/s, loss=0.349, metrics={'acc': 0.8077444782842372}]\n",
+ "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.68it/s, loss=0.343, metrics={'acc': 0.811737005093031}] \n",
+ "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 75.04it/s, loss=0.338, metrics={'acc': 0.8150868602075317}]\n",
+ "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 74.24it/s, loss=0.335, metrics={'acc': 0.8180226755048243}]\n",
" 0%| | 0/153 [00:00, ?it/s]"
]
},
@@ -345,16 +345,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 114.10it/s, loss=0.36, metrics={'acc': 0.8323}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 123.16it/s, loss=0.364, metrics={'acc': 0.8325}]\n",
- "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 113.50it/s, loss=0.359, metrics={'acc': 0.8325}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 122.56it/s, loss=0.364, metrics={'acc': 0.8327}]\n",
- "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 110.90it/s, loss=0.359, metrics={'acc': 0.8325}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 119.56it/s, loss=0.363, metrics={'acc': 0.8327}]\n",
- "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 112.92it/s, loss=0.359, metrics={'acc': 0.8326}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 121.00it/s, loss=0.363, metrics={'acc': 0.8329}]\n",
- "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 114.15it/s, loss=0.358, metrics={'acc': 0.8327}]\n",
- "valid: 100%|██████████| 39/39 [00:00<00:00, 108.91it/s, loss=0.363, metrics={'acc': 0.8329}]\n"
+ "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 74.96it/s, loss=0.361, metrics={'acc': 0.8315}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 136.56it/s, loss=0.365, metrics={'acc': 0.8318}]\n",
+ "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 75.09it/s, loss=0.361, metrics={'acc': 0.8317}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 129.90it/s, loss=0.365, metrics={'acc': 0.8321}]\n",
+ "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 73.24it/s, loss=0.36, metrics={'acc': 0.8317}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 131.37it/s, loss=0.365, metrics={'acc': 0.8321}]\n",
+ "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 72.38it/s, loss=0.36, metrics={'acc': 0.832}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 130.72it/s, loss=0.365, metrics={'acc': 0.8324}]\n",
+ "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.359, metrics={'acc': 0.8322}]\n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 130.20it/s, loss=0.364, metrics={'acc': 0.8326}]\n"
]
}
],
@@ -450,7 +450,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- " 8%|▊ | 84/1001 [00:00<00:02, 416.73it/s]"
+ " 8%|▊ | 84/1001 [00:00<00:02, 418.96it/s]"
]
},
{
@@ -464,7 +464,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|██████████| 1001/1001 [00:02<00:00, 400.82it/s]\n"
+ "100%|██████████| 1001/1001 [00:02<00:00, 409.78it/s]\n"
]
},
{
@@ -497,7 +497,7 @@
"metadata": {},
"outputs": [],
"source": [
- "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
+ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"deepdense = DeepDense( hidden_layers=[64,32], dropout=[0.2,0.2],\n",
" deep_column_idx=prepare_deep.deep_column_idx,\n",
" embed_input=prepare_deep.embeddings_input,\n",
@@ -519,7 +519,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -565,7 +565,7 @@
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
- " (wide_linear): Linear(in_features=356, out_features=1, bias=True)\n",
+ " (wide_linear): Embedding(357, 1, padding_idx=0)\n",
" )\n",
" (deepdense): Sequential(\n",
" (0): DeepDense(\n",
@@ -848,7 +848,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [00:00<00:00, 58.03it/s, loss=127]\n",
+ "epoch 1: 100%|██████████| 25/25 [00:00<00:00, 34.36it/s, loss=1.64e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -863,7 +863,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [00:00<00:00, 47.80it/s, loss=116]\n",
+ "epoch 1: 100%|██████████| 25/25 [00:00<00:00, 46.93it/s, loss=1.37e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -878,7 +878,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [00:04<00:00, 5.94it/s, loss=132]\n",
+ "epoch 1: 100%|██████████| 25/25 [00:04<00:00, 5.53it/s, loss=1.74e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -893,7 +893,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [01:12<00:00, 2.92s/it, loss=119]\n",
+ "epoch 1: 100%|██████████| 25/25 [01:05<00:00, 2.63s/it, loss=1.41e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -908,7 +908,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [01:48<00:00, 4.34s/it, loss=108]\n",
+ "epoch 1: 100%|██████████| 25/25 [01:29<00:00, 3.57s/it, loss=1.17e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -923,7 +923,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [02:05<00:00, 5.01s/it, loss=106] \n",
+ "epoch 1: 100%|██████████| 25/25 [01:51<00:00, 4.46s/it, loss=1.11e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -938,7 +938,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [02:57<00:00, 7.11s/it, loss=105] \n",
+ "epoch 1: 100%|██████████| 25/25 [02:17<00:00, 5.48s/it, loss=1.11e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -953,7 +953,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [03:40<00:00, 8.83s/it, loss=104] \n",
+ "epoch 1: 100%|██████████| 25/25 [02:50<00:00, 6.83s/it, loss=1.08e+4]\n",
" 0%| | 0/25 [00:00, ?it/s]"
]
},
@@ -968,8 +968,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 25/25 [01:20<00:00, 3.23s/it, loss=120]\n",
- "valid: 100%|██████████| 7/7 [00:14<00:00, 2.06s/it, loss=109] \n"
+ "epoch 1: 100%|██████████| 25/25 [01:10<00:00, 2.83s/it, loss=1.45e+4]\n",
+ "valid: 100%|██████████| 7/7 [00:14<00:00, 2.00s/it, loss=1.19e+4]\n"
]
}
],
diff --git a/examples/adult_script.py b/examples/adult_script.py
index 840acae4..cdcb7750 100644
--- a/examples/adult_script.py
+++ b/examples/adult_script.py
@@ -53,7 +53,7 @@
)
X_deep = prepare_deep.fit_transform(df)
- wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)
+ wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
deepdense = DeepDense(
hidden_layers=[64, 32],
dropout=[0.2, 0.2],
@@ -63,7 +63,7 @@
)
model = WideDeep(wide=wide, deepdense=deepdense)
- wide_opt = torch.optim.Adam(model.wide.parameters())
+ wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
deep_opt = RAdam(model.deepdense.parameters())
wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)
@@ -92,6 +92,6 @@
X_deep=X_deep,
target=target,
n_epochs=10,
- batch_size=256,
+ batch_size=64,
val_split=0.2,
)
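
Worth spelling out, since the same one-line change repeats across the example scripts and tests below: the wide input is now label encoded instead of one-hot encoded, so `wide_dim` becomes the number of unique tokens in `X_wide` rather than its number of columns. A minimal sketch with made-up values:

```python
import numpy as np

# Label-encoded wide matrix: 3 rows, 2 feature columns
X_wide = np.array([[1, 4],
                   [2, 5],
                   [3, 6]])

# New convention: count distinct tokens, not columns
wide_dim = np.unique(X_wide).shape[0]  # 6 unique values -> wide_dim = 6
```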
diff --git a/examples/airbnb_script.py b/examples/airbnb_script.py
index 36af56d7..f62b960c 100644
--- a/examples/airbnb_script.py
+++ b/examples/airbnb_script.py
@@ -1,3 +1,4 @@
+import numpy as np
import torch
import pandas as pd
from torchvision.transforms import ToTensor, Normalize
@@ -64,7 +65,7 @@
image_processor = ImagePreprocessor(img_col=img_col, img_path=img_path)
X_images = image_processor.fit_transform(df)
- wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)
+ wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
deepdense = DeepDense(
hidden_layers=[64, 32],
dropout=[0.2, 0.2],
@@ -85,7 +86,7 @@
wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage
)
- wide_opt = torch.optim.Adam(model.wide.parameters())
+ wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
deep_opt = torch.optim.Adam(model.deepdense.parameters())
text_opt = RAdam(model.deeptext.parameters())
img_opt = RAdam(model.deepimage.parameters())
diff --git a/examples/airbnb_script_multiclass.py b/examples/airbnb_script_multiclass.py
index 838d27ce..a0cda5b0 100644
--- a/examples/airbnb_script_multiclass.py
+++ b/examples/airbnb_script_multiclass.py
@@ -39,7 +39,8 @@
embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_deep = prepare_deep.fit_transform(df)
- wide = Wide(wide_dim=X_wide.shape[1], pred_dim=3)
+
+ wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=3)
deepdense = DeepDense(
hidden_layers=[64, 32],
dropout=[0.2, 0.2],
@@ -48,7 +49,10 @@
continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
- model.compile(method="multiclass", metrics=[Accuracy, F1Score])
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.03)
+ model.compile(
+ method="multiclass", metrics=[Accuracy, F1Score], optimizers=optimizer
+ )
model.fit(
X_wide=X_wide,
diff --git a/pypi_README.md b/pypi_README.md
index a93b7b84..8b31d7cb 100644
--- a/pypi_README.md
+++ b/pypi_README.md
@@ -1,6 +1,11 @@
[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
+[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
+Platform | Version Support
+---------|:---------------
+OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%203.7-blue.svg)](https://www.python.org/)
+Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%203.7%203.8-blue.svg)](https://www.python.org/)
# pytorch-widedeep
@@ -23,6 +28,7 @@ that in mind there are two architectures that can be implemented with just a
few lines of code. For details on these architectures please visit the
[repo](https://github.com/jrzaurin/pytorch-widedeep).
+
### Installation
Install using pip:
@@ -48,14 +54,6 @@ cd pytorch-widedeep
pip install -e .
```
-### Examples
-
-There are a number of notebooks in the `examples` folder plus some additional
-files. These notebooks cover most of the utilities of this package and can
-also act as documentation. In the case that github does not render the
-notebooks, or it renders them missing some parts, they are saved as markdown
-files in the `docs` folder.
-
### Quick start
Binary classification with the [adult
@@ -64,6 +62,7 @@ using `Wide` and `DeepDense` and defaults settings.
```python
import pandas as pd
+import numpy as np
from sklearn.model_selection import train_test_split
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
@@ -102,7 +101,7 @@ target = df_train[target_col].values
# wide
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df_train)
-wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)
+wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
# deepdense
preprocess_deep = DensePreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
@@ -164,4 +163,4 @@ their `Tokenizer` is the best in class.
The `ImageProcessor` class in this library uses code from the fantastic [Deep
Learning for Computer
Vision](https://www.pyimagesearch.com/deep-learning-computer-vision-python-book/)
-(DL4CV) book by Adrian Rosebrock.
\ No newline at end of file
+(DL4CV) book by Adrian Rosebrock.
diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py
index 937c12dd..19b9cd86 100644
--- a/pytorch_widedeep/callbacks.py
+++ b/pytorch_widedeep/callbacks.py
@@ -115,7 +115,7 @@ class History(Callback):
r"""Callback that records events into a :obj:`History` object.
This callback runs by default within :obj:`WideDeep`. See
- :class:`pytorch_widedeep.models.wide_deep.WideDeep`. Documentation ss
+ :class:`pytorch_widedeep.models.wide_deep.WideDeep`. Documentation is
included here for completion.
"""
diff --git a/pytorch_widedeep/models/_warmup.py b/pytorch_widedeep/models/_warmup.py
index 5e858222..5b53fbb6 100644
--- a/pytorch_widedeep/models/_warmup.py
+++ b/pytorch_widedeep/models/_warmup.py
@@ -263,7 +263,7 @@ def _warm(
acc = self.metric(F.softmax(y_pred, dim=1), y)
t.set_postfix(metrics=acc, loss=avg_loss)
else:
- t.set_postfix(loss=np.sqrt(avg_loss))
+ t.set_postfix(loss=avg_loss)
def _steps_up_down(self, steps: int, n_epochs: int = 1) -> Tuple[int, int]:
r"""
diff --git a/pytorch_widedeep/models/_wd_dataset.py b/pytorch_widedeep/models/_wd_dataset.py
index fafde273..aa5dc6e4 100644
--- a/pytorch_widedeep/models/_wd_dataset.py
+++ b/pytorch_widedeep/models/_wd_dataset.py
@@ -11,12 +11,8 @@ class WideDeepDataset(Dataset):
Parameters
----------
- X_wide: np.ndarray, scipy csr sparse matrix.
- wide input.Note that if a sparse matrix is passed to the
- WideDeepDataset class, the loading process will be notably slow since
- the transformation to a dense matrix is done on an index basis 'on the
- fly'. At the moment this is the best option given the current support
- offered for sparse tensors for pytorch.
+ X_wide: np.ndarray
+ wide input
X_deep: np.ndarray
deepdense input
X_text: np.ndarray
@@ -24,13 +20,14 @@ class WideDeepDataset(Dataset):
X_img: np.ndarray
deepimage input
target: np.ndarray
- transforms: MultipleTransforms() object (which is in itself a torchvision
- Compose). See in models/_multiple_transforms.py
+ target array
+ transforms: :obj:`MultipleTransforms`
+ torchvision Compose object. See models/_multiple_transforms.py
"""
def __init__(
self,
- X_wide: Union[np.ndarray, sparse_matrix],
+ X_wide: np.ndarray,
X_deep: np.ndarray,
target: Optional[np.ndarray] = None,
X_text: Optional[np.ndarray] = None,
@@ -53,10 +50,7 @@ def __init__(
def __getitem__(self, idx: int):
# X_wide and X_deep are assumed to be *always* present
- if isinstance(self.X_wide, sparse_matrix):
- X = Bunch(wide=np.array(self.X_wide[idx].todense()).squeeze())
- else:
- X = Bunch(wide=self.X_wide[idx])
+ X = Bunch(wide=self.X_wide[idx])
X.deepdense = self.X_deep[idx]
if self.X_text is not None:
X.deeptext = self.X_text[idx]
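
With the sparse branch gone, `WideDeepDataset` expects a plain `int64` array for `X_wide` and indexes it directly, with no on-the-fly densification. A minimal sketch of how the dataset is consumed (the module is internal, and the return value of `__getitem__` when a target is present is assumed to be the `(Bunch, target)` pair implied by the code above):

```python
import numpy as np
from pytorch_widedeep.models._wd_dataset import WideDeepDataset

# Toy arrays; X_wide is dense and label encoded
X_wide = np.random.choice(50, (100, 10)).astype("int64")
X_deep = np.random.rand(100, 10)
target = np.random.choice(2, 100)

dataset = WideDeepDataset(X_wide=X_wide, X_deep=X_deep, target=target)
X, y = dataset[3]                       # X is a sklearn Bunch
print(X.wide.shape, X.deepdense.shape)  # (10,) (10,)
```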
diff --git a/pytorch_widedeep/models/wide.py b/pytorch_widedeep/models/wide.py
index 10cc7906..eaf4c0f3 100644
--- a/pytorch_widedeep/models/wide.py
+++ b/pytorch_widedeep/models/wide.py
@@ -1,16 +1,24 @@
+import math
+
+import torch
from torch import nn
from ..wdtypes import *
class Wide(nn.Module):
- r"""Simple linear layer that will receive the one-hot encoded `'wide'`
- input and connect it to the output neuron(s).
+ r"""Wide component
+
+ Linear model implemented via an Embedding layer connected to the output
+ neuron(s).
Parameters
-----------
wide_dim: int
- size of the input tensor
+ size of the Embedding layer. `wide_dim` is the total number of
+ unique individual values across all the features that go through
+ the wide component. For example, if the wide component receives 2
+ features with 5 unique values each, `wide_dim = 10`
pred_dim: int
size of the output tensor containing the predictions
@@ -23,21 +31,34 @@ class Wide(nn.Module):
--------
>>> import torch
>>> from pytorch_widedeep.models import Wide
- >>> X = torch.empty(4, 4).random_(2)
- >>> wide = Wide(wide_dim=X.size(0), pred_dim=1)
+ >>> X = torch.empty(4, 4).random_(6)
+ >>> wide = Wide(wide_dim=X.unique().size(0), pred_dim=1)
>>> wide(X)
- tensor([[-0.8841],
- [-0.8633],
- [-1.2713],
- [-0.4762]], grad_fn=<AddmmBackward>)
+ tensor([[-0.1138],
+ [ 0.4603],
+ [ 1.0762],
+ [ 0.8160]], grad_fn=<AddBackward0>)
"""
def __init__(self, wide_dim: int, pred_dim: int = 1):
super(Wide, self).__init__()
- self.wide_linear = nn.Linear(wide_dim, pred_dim)
+ self.wide_linear = nn.Embedding(wide_dim + 1, pred_dim, padding_idx=0)
+ # (Sum(Embedding) + bias) is equivalent to (OneHotVector + Linear)
+ self.bias = nn.Parameter(torch.zeros(pred_dim))
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ r"""initialize Embedding and bias like nn.Linear. See `original
+ implementation
+ `_.
+ """
+ nn.init.kaiming_uniform_(self.wide_linear.weight, a=math.sqrt(5))
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.wide_linear.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
def forward(self, X: Tensor) -> Tensor: # type: ignore
- r"""Forward pass. Simply connecting the one-hot encoded input with the
- ouput neuron(s) """
- out = self.wide_linear(X.float())
+ r"""Forward pass. Simply connecting the Embedding layer with the ouput
+ neuron(s)"""
+ out = self.wide_linear(X.long()).sum(dim=1) + self.bias
return out
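
The comment `(Sum(Embedding) + bias) is equivalent to (OneHotVector + Linear)` is worth unpacking: summing the embedding rows selected by the label-encoded indices performs exactly the dot product a `Linear` layer would compute against the corresponding multi-hot vector. A quick numerical check of that claim (a standalone sketch, not part of the diff):

```python
import torch
from torch import nn

torch.manual_seed(0)
wide_dim, pred_dim = 6, 1

# Embedding formulation: row 0 is the padding index for unseen categories
emb = nn.Embedding(wide_dim + 1, pred_dim, padding_idx=0)
bias = torch.zeros(pred_dim)
X = torch.tensor([[1, 4], [2, 5]])  # label-encoded wide input
out_emb = emb(X).sum(dim=1) + bias  # what Wide.forward computes

# Equivalent linear formulation on the multi-hot vector
multi_hot = torch.zeros(2, wide_dim + 1)
multi_hot.scatter_(1, X, 1.0)
linear = nn.Linear(wide_dim + 1, pred_dim)
with torch.no_grad():
    linear.weight.copy_(emb.weight.t())  # share the weights
    linear.bias.copy_(bias)

assert torch.allclose(out_emb, linear(multi_hot), atol=1e-6)
```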
diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py
index 0430af08..06b4e396 100644
--- a/pytorch_widedeep/models/wide_deep.py
+++ b/pytorch_widedeep/models/wide_deep.py
@@ -232,7 +232,15 @@ def compile(
Parameters
----------
method: str
- One of `regression`, `binary` or `multiclass`
+ One of `regression`, `binary` or `multiclass`. The default loss when
+ performing a `regression`, a `binary` classification or a
+ `multiclass` classification is the `mean squared error
+ `_
+ (MSE), `Binary Cross Entropy
+ `_
+ (BCE) and `Cross Entropy
+ `_
+ (CE) respectively.
optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
- An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
- a dictionary where there keys are the model components (i.e.
@@ -594,7 +602,7 @@ def fit(
loss=train_loss,
)
else:
- t.set_postfix(loss=np.sqrt(train_loss))
+ t.set_postfix(loss=train_loss)
if self.lr_scheduler:
self._lr_scheduler_step(step_location="on_batch_end")
self.callback_container.on_batch_end(batch=batch_idx)
@@ -626,7 +634,7 @@ def fit(
loss=val_loss,
)
else:
- v.set_postfix(loss=np.sqrt(val_loss))
+ v.set_postfix(loss=val_loss)
epoch_logs["val_loss"] = val_loss
if score is not None:
for k, v in score.items():
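
Since `compile` accepts either a single optimizer or a dictionary keyed by component name, a short sketch of the per-component form may help. It mirrors the test fixtures that appear later in this diff; the `RAdam` import path is the one used in the example scripts above, and all data values are made up:

```python
import string
import numpy as np
import torch
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.optim import RAdam

# Toy data, following the test setup elsewhere in this diff
X_wide = np.random.choice(50, (100, 10))
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 100) for _ in range(5)]
cont_cols = [np.random.rand(100) for _ in range(5)]
X_deep = np.vstack(embed_cols + cont_cols).transpose()
target = np.random.choice(2, 100)

wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
    hidden_layers=[32, 16],
    dropout=[0.5, 0.5],
    deep_column_idx={k: v for v, k in enumerate(colnames)},
    embed_input=[(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)],
    continuous_cols=colnames[5:],
)
model = WideDeep(wide=wide, deepdense=deepdense)

# One optimizer per model component, keyed by component name
optimizers = {
    "wide": torch.optim.Adam(model.wide.parameters(), lr=0.01),
    "deepdense": RAdam(model.deepdense.parameters()),
}
model.compile(method="binary", optimizers=optimizers)
model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=1, batch_size=32)
```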
diff --git a/pytorch_widedeep/preprocessing/_preprocessors.py b/pytorch_widedeep/preprocessing/_preprocessors.py
index 0c04f79a..564c0d67 100644
--- a/pytorch_widedeep/preprocessing/_preprocessors.py
+++ b/pytorch_widedeep/preprocessing/_preprocessors.py
@@ -42,24 +42,28 @@ def fit_transform(self, df: pd.DataFrame):
class WidePreprocessor(BasePreprocessor):
r"""Preprocessor to prepare the wide input dataset
+ This Preprocessor prepares the data for the wide, linear component. This
+ linear model is implemented via an Embedding layer that is connected to
+ the output neuron. ``WidePreprocessor`` simply numerically encodes all the
+ unique values of all categorical columns ``wide_cols + crossed_cols``. See
+ the Example below.
+
Parameters
----------
wide_cols: List[str]
- List with the name of the columns that will be one-hot encoded and
- passed through the Wide model
+ List with the name of the columns that will be label encoded and
+ passed through the Wide model
crossed_cols: List[Tuple[str, str]]
List of Tuples with the name of the columns that will be `'crossed'`
- and then one-hot encoded. e.g. [('education', 'occupation'), ...]
- already_dummies: List[str]
- List of columns that are already dummies/one-hot encoded, and
- therefore do not need to be processed
+ and then label encoded. e.g. [('education', 'occupation'), ...]
Attributes
----------
- one_hot_enc: :obj:`OneHotEncoder`
- an instance of :class:`sklearn.preprocessing.OneHotEncoder`
wide_crossed_cols: :obj:`List`
- List with the names of all columns that will be one-hot encoded
+ List with the names of all columns that will be label encoded
+ feature_dict: :obj:`Dict`
+ Dictionary where the keys are the result of concatenating `colname +
+ '_' + column value` and the values are the corresponding mapped integers.
Examples
--------
@@ -69,67 +73,93 @@ class WidePreprocessor(BasePreprocessor):
>>> wide_cols = ['color']
>>> crossed_cols = [('color', 'size')]
>>> wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
- >>> wide_preprocessor.fit_transform(df)
- array([[0., 0., 1., 0., 0., 1.],
- [1., 0., 0., 1., 0., 0.],
- [0., 1., 0., 0., 1., 0.]])
+ >>> X_wide = wide_preprocessor.fit_transform(df)
+ >>> X_wide
+ array([[1, 4],
+ [2, 5],
+ [3, 6]])
+ >>> wide_preprocessor.feature_dict
+ {'color_r': 1,
+ 'color_b': 2,
+ 'color_g': 3,
+ 'color_size_r-s': 4,
+ 'color_size_b-n': 5,
+ 'color_size_g-l': 6}
+ >>> wide_preprocessor.inverse_transform(X_wide)
+ color color_size
+ 0 r r-s
+ 1 b b-n
+ 2 g g-l
"""
def __init__(
- self,
- wide_cols: List[str],
- crossed_cols=None,
- already_dummies: Optional[List[str]] = None,
- sparse=False,
- handle_unknown="ignore",
self, wide_cols: List[str], crossed_cols: Optional[List[Tuple[str, str]]] = None,
):
super(WidePreprocessor, self).__init__()
self.wide_cols = wide_cols
self.crossed_cols = crossed_cols
- self.already_dummies = already_dummies
- self.one_hot_enc = OneHotEncoder(sparse=sparse, handle_unknown=handle_unknown)
def fit(self, df: pd.DataFrame) -> BasePreprocessor:
"""Fits the Preprocessor and creates required attributes
"""
df_wide = self._prepare_wide(df)
self.wide_crossed_cols = df_wide.columns.tolist()
- if self.already_dummies:
- dummy_cols = [
- c for c in self.wide_crossed_cols if c not in self.already_dummies
- ]
- self.one_hot_enc.fit(df_wide[dummy_cols])
- else:
- self.one_hot_enc.fit(df_wide[self.wide_crossed_cols])
+ vocab = self._make_global_feature_list(df_wide[self.wide_crossed_cols])
+ # leave 0 as padding index
+ self.feature_dict = {v: i + 1 for i, v in enumerate(vocab)}
return self
- def transform(self, df: pd.DataFrame) -> Union[sparse_matrix, np.ndarray]:
- """Returns the processed dataframe as a one hot encoded dense or
- sparse matrix
+ def transform(self, df: pd.DataFrame) -> np.ndarray:
+ r"""Returns the label-encoded dataframe as a np.ndarray
"""
try:
- self.one_hot_enc.categories_
+ self.feature_dict
except:
raise NotFittedError(
"This WidePreprocessor instance is not fitted yet. "
"Call 'fit' with appropriate arguments before using this estimator."
)
df_wide = self._prepare_wide(df)
- if self.already_dummies:
- X_oh_1 = df_wide[self.already_dummies].values
- dummy_cols = [
- c for c in self.wide_crossed_cols if c not in self.already_dummies
- ]
- X_oh_2 = self.one_hot_enc.transform(df_wide[dummy_cols])
- return np.hstack((X_oh_1, X_oh_2))
- else:
- return self.one_hot_enc.transform(df_wide[self.wide_crossed_cols])
+ encoded = np.zeros([len(df_wide), len(self.wide_crossed_cols)], dtype=np.int64)
+ for col_i, col in enumerate(self.wide_crossed_cols):
+ encoded[:, col_i] = df_wide[col].apply(
+ lambda x: self.feature_dict[col + "_" + str(x)]
+ if col + "_" + str(x) in self.feature_dict
+ else 0
+ )
+ return encoded.astype("int64")
+
+ def inverse_transform(self, encoded: np.ndarray) -> pd.DataFrame:
+ r"""Takes as input the output from the ``transform`` method and it will
+ return the original values.
- def fit_transform(self, df: pd.DataFrame) -> Union[sparse_matrix, np.ndarray]:
+ Parameters
+ ----------
+ encoded: np.ndarray
+ array with the output of the ``transform`` method
+ """
+ decoded = pd.DataFrame(encoded, columns=self.wide_crossed_cols)
+ inverse_dict = {v: k for k, v in self.feature_dict.items()}
+ decoded = decoded.applymap(lambda x: inverse_dict[x])
+ for col in decoded.columns:
+ rm_str = "".join([col, "_"])
+ decoded[col] = decoded[col].apply(lambda x: x.replace(rm_str, ""))
+ return decoded
+
+ def fit_transform(self, df: pd.DataFrame) -> np.ndarray:
"""Combines ``fit`` and ``transform``
"""
return self.fit(df).transform(df)
+ def _make_global_feature_list(self, df: pd.DataFrame) -> List:
+ vocab = []
+ for column in df.columns:
+ vocab += self._make_column_feature_list(df[column])
+ return vocab
+
+ def _make_column_feature_list(self, s: pd.Series) -> List:
+ return [s.name + "_" + str(x) for x in s.unique()]
+
def _cross_cols(self, df: pd.DataFrame):
df_cc = df.copy()
crossed_colnames = []
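
One behaviour worth calling out in the new `transform`: values not seen at fit time fall back to 0, which is exactly the `padding_idx` that the `Wide` embedding zeroes out, so unseen categories contribute nothing to the prediction. A sketch using the same toy dataframe as the docstring example above:

```python
import pandas as pd
from pytorch_widedeep.preprocessing import WidePreprocessor

df = pd.DataFrame({"color": ["r", "b", "g"], "size": ["s", "n", "l"]})
wide_preprocessor = WidePreprocessor(wide_cols=["color"], crossed_cols=[("color", "size")])
wide_preprocessor.fit(df)

df_new = pd.DataFrame({"color": ["r", "y"], "size": ["s", "s"]})  # 'y' unseen
wide_preprocessor.transform(df_new)
# array([[1, 4],
#        [0, 0]])   -> the unseen row maps to the padding index in both columns
```

Note that `inverse_transform` as written assumes every code is present in `feature_dict`, so a 0 produced by an unseen value would raise a `KeyError` there; the padding index has no inverse mapping.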
diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py
index df124332..98a433b3 100644
--- a/pytorch_widedeep/version.py
+++ b/pytorch_widedeep/version.py
@@ -1 +1 @@
-__version__ = "0.4.2"
+__version__ = "0.4.5"
diff --git a/pytorch_widedeep/wdtypes.py b/pytorch_widedeep/wdtypes.py
index 232e7d83..ed46ddc3 100644
--- a/pytorch_widedeep/wdtypes.py
+++ b/pytorch_widedeep/wdtypes.py
@@ -18,7 +18,6 @@
from torch import Tensor
from torch.nn import Module
-from scipy.sparse.csr import csr_matrix as sparse_matrix
from torch.optim.optimizer import Optimizer
from torchvision.transforms import (
Pad,
diff --git a/tests/test_data_utils/test_du_wide.py b/tests/test_data_utils/test_du_wide.py
index ec250eb3..ed02d8ae 100644
--- a/tests/test_data_utils/test_du_wide.py
+++ b/tests/test_data_utils/test_du_wide.py
@@ -39,7 +39,7 @@ def create_test_dataset(input_type, with_crossed=True):
)
def test_preprocessor1(input_df, expected_shape):
wide_mtx = preprocessor1.fit_transform(input_df)
- assert wide_mtx.shape[1] == expected_shape
+ assert np.unique(wide_mtx).shape[0] == expected_shape
###############################################################################
@@ -63,4 +63,4 @@ def test_preprocessor1(input_df, expected_shape):
)
def test_prepare_wide_wo_crossed(input_df, expected_shape):
wide_mtx = preprocessor2.fit_transform(input_df)
- assert wide_mtx.shape[1] == expected_shape
+ assert np.unique(wide_mtx).shape[0] == expected_shape
diff --git a/tests/test_model_functioning/test_callbacks.py b/tests/test_model_functioning/test_callbacks.py
index 9c866815..6dcb6ff7 100644
--- a/tests/test_model_functioning/test_callbacks.py
+++ b/tests/test_model_functioning/test_callbacks.py
@@ -16,7 +16,7 @@
)
# Wide array
-X_wide = np.random.choice(2, (100, 100), p=[0.8, 0.2])
+X_wide = np.random.choice(50, (100, 10))
# Deep Array
colnames = list(string.ascii_lowercase)[:10]
@@ -38,7 +38,7 @@
###############################################################################
# Test that history saves the information adequately
###############################################################################
-wide = Wide(100, 1)
+wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
@@ -92,7 +92,7 @@ def test_history_callback(optimizers, schedulers, len_loss_output, len_lr_output
# Test that EarlyStopping stops as expected
###############################################################################
def test_early_stop():
- wide = Wide(100, 1)
+ wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
@@ -105,7 +105,7 @@ def test_early_stop():
method="binary",
callbacks=[
EarlyStopping(
- min_delta=0.1, patience=3, restore_best_weights=True, verbose=1
+ min_delta=5.0, patience=3, restore_best_weights=True, verbose=1
)
],
verbose=1,
@@ -122,7 +122,7 @@ def test_early_stop():
"save_best_only, max_save, n_files", [(True, 2, 2), (False, 2, 2), (False, 0, 5)]
)
def test_model_checkpoint(save_best_only, max_save, n_files):
- wide = Wide(100, 1)
+ wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
diff --git a/tests/test_model_functioning/test_data_inputs.py b/tests/test_model_functioning/test_data_inputs.py
index 7189c2a0..1819fb8b 100644
--- a/tests/test_model_functioning/test_data_inputs.py
+++ b/tests/test_model_functioning/test_data_inputs.py
@@ -14,7 +14,7 @@
)
# Wide array
-X_wide = np.random.choice(2, (100, 100), p=[0.8, 0.2])
+X_wide = np.random.choice(50, (100, 100))
# Deep Array
colnames = list(string.ascii_lowercase)[:10]
@@ -50,7 +50,7 @@
) = train_test_split(X_wide, X_deep, X_text, X_img, target)
# build model components
-wide = Wide(100, 1)
+wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
diff --git a/tests/test_model_functioning/test_fit_methods.py b/tests/test_model_functioning/test_fit_methods.py
index d41042c1..e3ec4134 100644
--- a/tests/test_model_functioning/test_fit_methods.py
+++ b/tests/test_model_functioning/test_fit_methods.py
@@ -6,7 +6,7 @@
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
# Wide array
-X_wide = np.random.choice(2, (100, 100), p=[0.8, 0.2])
+X_wide = np.random.choice(50, (100, 100))
# Deep Array
colnames = list(string.ascii_lowercase)[:10]
@@ -51,7 +51,7 @@ def test_fit_methods(
pred_dim,
probs_dim,
):
- wide = Wide(100, pred_dim)
+ wide = Wide(np.unique(X_wide).shape[0], pred_dim)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
diff --git a/tests/test_model_functioning/test_focal_loss.py b/tests/test_model_functioning/test_focal_loss.py
index 2009e67a..82bb33d5 100644
--- a/tests/test_model_functioning/test_focal_loss.py
+++ b/tests/test_model_functioning/test_focal_loss.py
@@ -6,7 +6,7 @@
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
# Wide array
-X_wide = np.random.choice(2, (100, 100), p=[0.8, 0.2])
+X_wide = np.random.choice(50, (100, 10))
# Deep Array
colnames = list(string.ascii_lowercase)[:10]
@@ -32,7 +32,7 @@
],
)
def test_focal_loss(X_wide, X_deep, target, method, pred_dim, probs_dim):
- wide = Wide(100, pred_dim)
+ wide = Wide(np.unique(X_wide).shape[0], pred_dim)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
diff --git a/tests/test_model_functioning/test_initializers.py b/tests/test_model_functioning/test_initializers.py
index e6129544..d97a6d79 100644
--- a/tests/test_model_functioning/test_initializers.py
+++ b/tests/test_model_functioning/test_initializers.py
@@ -19,7 +19,7 @@
)
# Wide array
-X_wide = np.random.choice(2, (100, 100), p=[0.8, 0.2])
+X_wide = np.random.choice(50, (100, 100))
# Deep Array
colnames = list(string.ascii_lowercase)[:10]
@@ -58,7 +58,7 @@
def test_initializers_1():
- wide = Wide(100, 1)
+ wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
diff --git a/tests/test_warm_up/test_warm_up_routines.py b/tests/test_warm_up/test_warm_up_routines.py
index 8fc2164e..4cd2dbdf 100644
--- a/tests/test_warm_up/test_warm_up_routines.py
+++ b/tests/test_warm_up/test_warm_up_routines.py
@@ -87,7 +87,7 @@ def loss_fn(y_pred, y_true):
target = torch.empty(100, 1).random_(0, 2)
# wide
-X_wide = torch.empty(100, 10).random_(0, 2)
+X_wide = torch.empty(100, 4).random_(1, 20)
# deep
colnames = list(string.ascii_lowercase)[:10]
@@ -107,7 +107,7 @@ def loss_fn(y_pred, y_true):
# Define the model components
# wide
-wide = Wide(10, 1)
+wide = Wide(X_wide.unique().size(0), 1)
if use_cuda:
wide.cuda()