Skip to content

Commit

Permalink
Adapted examples to the new ,etrics
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzaurin committed Jul 21, 2020
1 parent 0c6a74a commit d4b92e5
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 224 deletions.
80 changes: 40 additions & 40 deletions examples/01_Preprocessors_and_utils.ipynb

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions examples/02_Model_Components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@
{
"data": {
"text/plain": [
"tensor([[-0.0000, -1.0061, -0.0000, -0.9828, -0.0000, -0.0000, -0.9944, -1.0133],\n",
" [-0.0000, -0.9996, 0.0000, -1.0374, 0.0000, -0.0000, -1.0313, -0.0000],\n",
" [-0.8576, -1.0017, -0.0000, -0.9881, -0.0000, 0.0000, -0.0000, -0.0000],\n",
" [ 3.9816, 0.0000, 0.0000, 0.0000, 3.7309, 1.1728, 0.0000, -1.1160],\n",
" [-1.1339, -0.9925, -0.0000, -0.0000, -0.0000, 0.0000, -0.9638, 0.0000]],\n",
"tensor([[-0.0000, -0.9949, 3.8273, 0.0000, -1.3889, -2.9641, 0.0000, -0.0000],\n",
" [ 3.9123, -0.0000, -0.0000, 1.9555, -1.3561, 1.7069, -0.0000, 0.9275],\n",
" [-0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -1.6489, -0.0000, -1.4985],\n",
" [-1.2736, 0.0000, -1.2819, 2.1232, 0.0000, 2.2767, -0.0000, 3.5354],\n",
" [-0.1726, -0.0000, -1.3275, -0.0000, -1.3703, 0.0000, -0.0000, -1.4637]],\n",
" grad_fn=<MulBackward0>)"
]
},
Expand Down Expand Up @@ -484,10 +484,10 @@
{
"data": {
"text/plain": [
"tensor([[-1.4630e-04, -6.1540e-04, -2.4541e-04, 2.7543e-01, 1.2993e-01,\n",
" -1.6553e-03, 6.7002e-02, 2.3974e-01],\n",
" [-9.9619e-04, -1.9412e-03, 1.2113e-01, 1.0122e-01, 2.9080e-01,\n",
" -2.0852e-03, -1.8016e-04, 2.7996e-02]], grad_fn=<LeakyReluBackward1>)"
"tensor([[-2.2825e-03, -8.3100e-04, -8.8423e-04, -1.1084e-04, 8.8529e-02,\n",
" -5.1577e-04, 2.8343e-01, -1.7071e-03],\n",
" [-1.8486e-03, -8.5602e-04, -1.8552e-03, 3.6481e-01, 9.0812e-02,\n",
" -9.6603e-04, 3.9017e-01, -2.6355e-03]], grad_fn=<LeakyReluBackward1>)"
]
},
"execution_count": 18,
Expand Down
52 changes: 26 additions & 26 deletions examples/03_Binary_Classification_with_Defaults.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -21,12 +21,12 @@
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import BinaryAccuracy"
"from pytorch_widedeep.metrics import Accuracy, Precision"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -185,7 +185,7 @@
"4 30 United-States <=50K "
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -197,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -356,7 +356,7 @@
"4 30 United-States 0 "
]
},
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -381,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -394,7 +394,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -412,7 +412,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -475,7 +475,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -489,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -527,7 +527,7 @@
")"
]
},
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -560,16 +560,16 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[BinaryAccuracy])"
"model.compile(method='binary', metrics=[Accuracy, Precision])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n",
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n"
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}] \n",
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n",
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n",
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n",
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n"
]
}
],
Expand Down
Loading

0 comments on commit d4b92e5

Please sign in to comment.