Skip to content

Commit

Permalink
Merge pull request #17 from jrzaurin/precision_recall
Browse files Browse the repository at this point in the history
Precision recall
  • Loading branch information
jrzaurin committed Jul 21, 2020
2 parents 31c2d8e + 393ea43 commit 65465a4
Show file tree
Hide file tree
Showing 21 changed files with 573 additions and 308 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ final output neuron or neurons, depending on whether we are performing a
binary classification or regression, or a multi-class classification. The
components within the faded-pink rectangles are concatenated.

In math terms, and following the notation in the [paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated as:
In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
as:

<p align="center">
<img width="500" src="docs/figures/architecture_1_math.png">
Expand Down Expand Up @@ -130,7 +132,7 @@ from sklearn.model_selection import train_test_split

from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy

# these next 4 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and place it in a dir called data/adult/
Expand Down Expand Up @@ -178,7 +180,7 @@ deepdense = DeepDense(

# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[BinaryAccuracy])
model.compile(method="binary", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
X_deep=X_deep,
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.4.1
0.4.2
2 changes: 1 addition & 1 deletion code_style.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# sort imports
isort --recursive . pytorch_widedeep tests examples setup.py
isort . pytorch_widedeep tests examples setup.py
# Black code style
black . pytorch_widedeep tests examples setup.py
# flake8 standards
Expand Down
13 changes: 13 additions & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
pytorch-widedeep Examples
*****************************

This section provides links to example notebooks that may be helpful to better
understand the functionalities withing ``pytorch-widedeep`` and how to use
them to address different problems

* `Preprocessors and Utils <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/01_Preprocessors_and_utils.ipynb>`__
* `Model Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/02_Model_Components.ipynb>`__
* `Binary Classification with default parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/03_Binary_Classification_with_Defaults.ipynb>`__
* `Binary Classification with varying parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/04_Binary_Classification_Varying_Parameters.ipynb>`__
* `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
* `Warm up routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_WarmUp_Model_Components.ipynb>`__
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Documentation
Preprocessing <preprocessing>
Model Components <model_components>
Wide and Deep Models <wide_deep/index>
Examples <examples>


Introduction
Expand Down
4 changes: 2 additions & 2 deletions docs/quick_start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Prepare the wide and deep columns
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import Accuracy
# prepare wide, crossed, embedding and continuous columns
wide_cols = [
Expand Down Expand Up @@ -83,7 +83,7 @@ Build, compile, fit and predict
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[BinaryAccuracy])
model.compile(method="binary", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
X_deep=X_deep,
Expand Down
19 changes: 17 additions & 2 deletions docs/wide_deep/metrics.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
Metrics
=======

.. autoclass:: pytorch_widedeep.metrics.BinaryAccuracy
.. autoclass:: pytorch_widedeep.metrics.Accuracy
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.CategoricalAccuracy
.. autoclass:: pytorch_widedeep.metrics.Precision
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.Recall
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.FBetaScore
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pytorch_widedeep.metrics.F1Score
:members:
:undoc-members:
:show-inheritance:
80 changes: 40 additions & 40 deletions examples/01_Preprocessors_and_utils.ipynb

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions examples/02_Model_Components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@
{
"data": {
"text/plain": [
"tensor([[-0.0000, -1.0061, -0.0000, -0.9828, -0.0000, -0.0000, -0.9944, -1.0133],\n",
" [-0.0000, -0.9996, 0.0000, -1.0374, 0.0000, -0.0000, -1.0313, -0.0000],\n",
" [-0.8576, -1.0017, -0.0000, -0.9881, -0.0000, 0.0000, -0.0000, -0.0000],\n",
" [ 3.9816, 0.0000, 0.0000, 0.0000, 3.7309, 1.1728, 0.0000, -1.1160],\n",
" [-1.1339, -0.9925, -0.0000, -0.0000, -0.0000, 0.0000, -0.9638, 0.0000]],\n",
"tensor([[-0.0000, -0.9949, 3.8273, 0.0000, -1.3889, -2.9641, 0.0000, -0.0000],\n",
" [ 3.9123, -0.0000, -0.0000, 1.9555, -1.3561, 1.7069, -0.0000, 0.9275],\n",
" [-0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -1.6489, -0.0000, -1.4985],\n",
" [-1.2736, 0.0000, -1.2819, 2.1232, 0.0000, 2.2767, -0.0000, 3.5354],\n",
" [-0.1726, -0.0000, -1.3275, -0.0000, -1.3703, 0.0000, -0.0000, -1.4637]],\n",
" grad_fn=<MulBackward0>)"
]
},
Expand Down Expand Up @@ -484,10 +484,10 @@
{
"data": {
"text/plain": [
"tensor([[-1.4630e-04, -6.1540e-04, -2.4541e-04, 2.7543e-01, 1.2993e-01,\n",
" -1.6553e-03, 6.7002e-02, 2.3974e-01],\n",
" [-9.9619e-04, -1.9412e-03, 1.2113e-01, 1.0122e-01, 2.9080e-01,\n",
" -2.0852e-03, -1.8016e-04, 2.7996e-02]], grad_fn=<LeakyReluBackward1>)"
"tensor([[-2.2825e-03, -8.3100e-04, -8.8423e-04, -1.1084e-04, 8.8529e-02,\n",
" -5.1577e-04, 2.8343e-01, -1.7071e-03],\n",
" [-1.8486e-03, -8.5602e-04, -1.8552e-03, 3.6481e-01, 9.0812e-02,\n",
" -9.6603e-04, 3.9017e-01, -2.6355e-03]], grad_fn=<LeakyReluBackward1>)"
]
},
"execution_count": 18,
Expand Down
52 changes: 26 additions & 26 deletions examples/03_Binary_Classification_with_Defaults.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -21,12 +21,12 @@
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import BinaryAccuracy"
"from pytorch_widedeep.metrics import Accuracy, Precision"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -185,7 +185,7 @@
"4 30 United-States <=50K "
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -197,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -356,7 +356,7 @@
"4 30 United-States 0 "
]
},
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -381,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -394,7 +394,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -412,7 +412,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -475,7 +475,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -489,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -527,7 +527,7 @@
")"
]
},
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -560,16 +560,16 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[BinaryAccuracy])"
"model.compile(method='binary', metrics=[Accuracy, Precision])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n",
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n"
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}] \n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n"
]
}
],
Expand Down
Loading

0 comments on commit 65465a4

Please sign in to comment.