diff --git a/VERSION b/VERSION index 0495c4a8..f0bb29e7 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.2.3 +1.3.0 diff --git a/examples/notebooks/18_feature_importance_via_attention_weights.ipynb b/examples/notebooks/18_feature_importance_via_attention_weights.ipynb new file mode 100644 index 00000000..6ae3f447 --- /dev/null +++ b/examples/notebooks/18_feature_importance_via_attention_weights.ipynb @@ -0,0 +1,1151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0e4d6c56", + "metadata": {}, + "source": [ + "### Feature Importance via the attention weights\n", + "\n", + "I will start by saying that I consider this feature of the library purely experimental. First of all I think there are multiple ways one could address finding the features importances for these models. However, and more importantly, one has to bear in mind that even tree-based algorithms on the same dataset produce different feature importances. This is more \"dramatic\" if one uses different techniques, such as shap or feature permutation (see for example [this](https://reneelin2019.medium.com/calculating-feature-importance-with-permutation-to-explain-the-model-income-prediction-example-38a52e67441d) and references therein). All this to say that, sometimes, feature importance is just a measure contained within the experiment run, and for the model used.\n", + "\n", + "With that in mind, each instantiation of a deep tabular model, that has millions of trainable parameters, will potentially produce a different set of feature importances, even if the model has the same architecture. Moreover, this effect will become more apparent if the dataset is relatively easy and there are dependent/related columns so that one could get to the same success metric with different parameters. \n", + "\n", + "In summary, feature importances are implemented in this librray for all attention-based models for tabular data, with the exception of the `TabPerceiver`. However this functionality has to be used and interpreted with care and consider of value within the 'universe' (or context) of the model with which these features were produced.\n", + "\n", + "Nonetheless, let's have a look to how one would access to the feature importances when using this library. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "365df643", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score\n", + "\n", + "\n", + "from pytorch_widedeep import Trainer\n", + "from pytorch_widedeep.models import TabTransformer, ContextAttentionMLP, WideDeep\n", + "from pytorch_widedeep.callbacks import EarlyStopping\n", + "from pytorch_widedeep.metrics import Accuracy\n", + "from pytorch_widedeep.datasets import load_adult\n", + "from pytorch_widedeep.preprocessing import TabPreprocessor" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0f935ef4", + "metadata": {}, + "outputs": [], + "source": [ + "# use_cuda = torch.cuda.is_available()\n", + "df = load_adult(as_frame=True)\n", + "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n", + "df[\"income_label\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n", + "df.drop([\"income\", \"fnlwgt\", \"educational_num\"], axis=1, inplace=True)\n", + "target_colname = \"income_label\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "202e7377", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationmarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryincome_label
025Private11thNever-marriedMachine-op-inspctOwn-childBlackMale0040United-States0
138PrivateHS-gradMarried-civ-spouseFarming-fishingHusbandWhiteMale0050United-States0
228Local-govAssoc-acdmMarried-civ-spouseProtective-servHusbandWhiteMale0040United-States1
344PrivateSome-collegeMarried-civ-spouseMachine-op-inspctHusbandBlackMale7688040United-States1
418?Some-collegeNever-married?Own-childWhiteFemale0030United-States0
\n", + "
" + ], + "text/plain": [ + " age workclass education marital_status occupation \\\n", + "0 25 Private 11th Never-married Machine-op-inspct \n", + "1 38 Private HS-grad Married-civ-spouse Farming-fishing \n", + "2 28 Local-gov Assoc-acdm Married-civ-spouse Protective-serv \n", + "3 44 Private Some-college Married-civ-spouse Machine-op-inspct \n", + "4 18 ? Some-college Never-married ? \n", + "\n", + " relationship race gender capital_gain capital_loss hours_per_week \\\n", + "0 Own-child Black Male 0 0 40 \n", + "1 Husband White Male 0 0 50 \n", + "2 Husband White Male 0 0 40 \n", + "3 Husband Black Male 7688 0 40 \n", + "4 Own-child White Female 0 0 30 \n", + "\n", + " native_country income_label \n", + "0 United-States 0 \n", + "1 United-States 0 \n", + "2 United-States 1 \n", + "3 United-States 1 \n", + "4 United-States 0 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4264cba5", + "metadata": {}, + "outputs": [], + "source": [ + "cat_embed_cols = []\n", + "for col in df.columns:\n", + " if df[col].dtype == \"O\" or df[col].nunique() < 200 and col != target_colname:\n", + " cat_embed_cols.append(col)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "48ae0977", + "metadata": {}, + "outputs": [], + "source": [ + "# all cols will be categorical\n", + "assert len(cat_embed_cols) == df.shape[1] - 1" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0abe7bb4", + "metadata": {}, + "outputs": [], + "source": [ + "train, test = train_test_split(\n", + " df, test_size=0.1, random_state=1, stratify=df[[target_colname]]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a762aee5", + "metadata": {}, + "outputs": [], + "source": [ + "tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols, with_attention=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a233f9a8", + "metadata": {}, + "outputs": [], + "source": [ + "X_tab_train = tab_preprocessor.fit_transform(train)\n", + "X_tab_test = tab_preprocessor.transform(test)\n", + "target = train[target_colname].values" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c17233fe", + "metadata": {}, + "outputs": [], + "source": [ + "tab_transformer = TabTransformer(\n", + " column_idx=tab_preprocessor.column_idx,\n", + " cat_embed_input=tab_preprocessor.cat_embed_input,\n", + " cat_embed_dropout=0.0,\n", + " input_dim=8,\n", + " n_heads=2,\n", + " n_blocks=1,\n", + " attn_dropout=0.1,\n", + " transformer_activation=\"relu\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5010c08d", + "metadata": {}, + "outputs": [], + "source": [ + "model = WideDeep(deeptabular=tab_transformer)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "145803d9", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "24133088", + "metadata": {}, + "outputs": [], + "source": [ + "lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer,\n", + " threshold=0.001,\n", + " threshold_mode=\"abs\",\n", + " patience=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "06330871", + "metadata": {}, + "outputs": [], + "source": [ + "early_stopping = EarlyStopping(\n", + " min_delta=0.001, patience=30, restore_best_weights=True, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "405befbe", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model,\n", + " objective=\"binary\",\n", + " optimizers=optimizer,\n", + " lr_schedulers=lr_scheduler,\n", + " reducelronplateau_criterion=\"loss\",\n", + " callbacks=[early_stopping],\n", + " metrics=[Accuracy],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "997a5283", + "metadata": {}, + "source": [ + "The feature importances will be computed after training, using a sample of the training dataset of size `feature_importance_sample_size`" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "459d0f69", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|████| 275/275 [00:03<00:00, 90.98it/s, loss=0.332, metrics={'acc': 0.847}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 148.31it/s, loss=0.292, metrics={'acc': 0.866}]\n", + "epoch 2: 100%|████| 275/275 [00:02<00:00, 96.35it/s, loss=0.289, metrics={'acc': 0.868}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 147.03it/s, loss=0.278, metrics={'acc': 0.8717}]\n", + "epoch 3: 100%|████| 275/275 [00:02<00:00, 94.23it/s, loss=0.28, metrics={'acc': 0.8719}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 139.13it/s, loss=0.275, metrics={'acc': 0.8732}]\n", + "epoch 4: 100%|████| 275/275 [00:03<00:00, 90.15it/s, loss=0.276, metrics={'acc': 0.872}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 133.21it/s, loss=0.275, metrics={'acc': 0.8706}]\n", + "epoch 5: 100%|███| 275/275 [00:03<00:00, 86.75it/s, loss=0.274, metrics={'acc': 0.8736}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 132.85it/s, loss=0.275, metrics={'acc': 0.8717}]\n", + "epoch 6: 100%|███| 275/275 [00:03<00:00, 89.16it/s, loss=0.272, metrics={'acc': 0.8742}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 129.50it/s, loss=0.273, metrics={'acc': 0.8733}]\n", + "epoch 7: 100%|███| 275/275 [00:03<00:00, 88.98it/s, loss=0.271, metrics={'acc': 0.8748}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 129.81it/s, loss=0.273, metrics={'acc': 0.8715}]\n", + "epoch 8: 100%|████| 275/275 [00:03<00:00, 88.63it/s, loss=0.27, metrics={'acc': 0.8748}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 129.81it/s, loss=0.272, metrics={'acc': 0.8739}]\n", + "epoch 9: 100%|███| 275/275 [00:03<00:00, 88.68it/s, loss=0.269, metrics={'acc': 0.8762}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 129.69it/s, loss=0.271, metrics={'acc': 0.8741}]\n", + "epoch 10: 100%|██| 275/275 [00:03<00:00, 89.56it/s, loss=0.267, metrics={'acc': 0.8761}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 132.71it/s, loss=0.271, metrics={'acc': 0.874}]\n", + "epoch 11: 100%|██| 275/275 [00:03<00:00, 81.74it/s, loss=0.267, metrics={'acc': 0.8757}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 120.83it/s, loss=0.272, metrics={'acc': 0.8744}]\n", + "epoch 12: 100%|██| 275/275 [00:03<00:00, 75.27it/s, loss=0.266, metrics={'acc': 0.8768}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 113.89it/s, loss=0.272, metrics={'acc': 0.8741}]\n", + "epoch 13: 100%|██| 275/275 [00:03<00:00, 76.29it/s, loss=0.265, metrics={'acc': 0.8787}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 113.24it/s, loss=0.273, metrics={'acc': 0.8726}]\n", + "epoch 14: 100%|██| 275/275 [00:03<00:00, 76.86it/s, loss=0.265, metrics={'acc': 0.8777}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 114.75it/s, loss=0.273, metrics={'acc': 0.8741}]\n", + "epoch 15: 100%|██| 275/275 [00:03<00:00, 72.97it/s, loss=0.264, metrics={'acc': 0.8779}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 115.13it/s, loss=0.273, metrics={'acc': 0.8735}]\n", + "epoch 16: 100%|██| 275/275 [00:03<00:00, 71.00it/s, loss=0.263, metrics={'acc': 0.8789}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 159.56it/s, loss=0.274, metrics={'acc': 0.8718}]\n", + "epoch 17: 100%|██| 275/275 [00:02<00:00, 98.42it/s, loss=0.263, metrics={'acc': 0.8789}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 172.29it/s, loss=0.274, metrics={'acc': 0.8727}]\n", + "epoch 18: 100%|█| 275/275 [00:02<00:00, 111.46it/s, loss=0.262, metrics={'acc': 0.8786}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 172.87it/s, loss=0.274, metrics={'acc': 0.8737}]\n", + "epoch 19: 100%|█| 275/275 [00:02<00:00, 109.79it/s, loss=0.261, metrics={'acc': 0.8785}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 167.10it/s, loss=0.275, metrics={'acc': 0.8737}]\n", + "epoch 20: 100%|██| 275/275 [00:02<00:00, 99.87it/s, loss=0.261, metrics={'acc': 0.8789}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 156.25it/s, loss=0.275, metrics={'acc': 0.8724}]\n", + "epoch 21: 100%|██| 275/275 [00:03<00:00, 83.22it/s, loss=0.256, metrics={'acc': 0.8813}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 107.03it/s, loss=0.274, metrics={'acc': 0.8727}]\n", + "epoch 22: 100%|██| 275/275 [00:03<00:00, 79.09it/s, loss=0.255, metrics={'acc': 0.8832}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 140.85it/s, loss=0.275, metrics={'acc': 0.8733}]\n", + "epoch 23: 100%|██| 275/275 [00:02<00:00, 99.53it/s, loss=0.254, metrics={'acc': 0.8822}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 128.29it/s, loss=0.276, metrics={'acc': 0.8724}]\n", + "epoch 24: 100%|██| 275/275 [00:03<00:00, 86.39it/s, loss=0.253, metrics={'acc': 0.8846}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 112.40it/s, loss=0.276, metrics={'acc': 0.8724}]\n", + "epoch 25: 100%|███| 275/275 [00:03<00:00, 75.98it/s, loss=0.253, metrics={'acc': 0.884}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 123.45it/s, loss=0.276, metrics={'acc': 0.8715}]\n", + "epoch 26: 100%|██| 275/275 [00:03<00:00, 80.57it/s, loss=0.253, metrics={'acc': 0.8839}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 134.19it/s, loss=0.277, metrics={'acc': 0.8706}]\n", + "epoch 27: 100%|██| 275/275 [00:03<00:00, 83.52it/s, loss=0.253, metrics={'acc': 0.8838}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 120.70it/s, loss=0.277, metrics={'acc': 0.8709}]\n", + "epoch 28: 100%|██| 275/275 [00:03<00:00, 87.30it/s, loss=0.252, metrics={'acc': 0.8837}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 136.44it/s, loss=0.277, metrics={'acc': 0.8706}]\n", + "epoch 29: 100%|██| 275/275 [00:03<00:00, 84.86it/s, loss=0.252, metrics={'acc': 0.8847}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 126.91it/s, loss=0.278, metrics={'acc': 0.8698}]\n", + "epoch 30: 100%|██| 275/275 [00:03<00:00, 83.93it/s, loss=0.252, metrics={'acc': 0.8837}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 130.11it/s, loss=0.277, metrics={'acc': 0.8715}]\n", + "epoch 31: 100%|██| 275/275 [00:03<00:00, 81.44it/s, loss=0.252, metrics={'acc': 0.8824}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 120.13it/s, loss=0.278, metrics={'acc': 0.8711}]\n", + "epoch 32: 100%|███| 275/275 [00:03<00:00, 81.13it/s, loss=0.25, metrics={'acc': 0.8843}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 109.88it/s, loss=0.278, metrics={'acc': 0.8705}]\n", + "epoch 33: 100%|███| 275/275 [00:03<00:00, 69.72it/s, loss=0.25, metrics={'acc': 0.8838}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 118.55it/s, loss=0.278, metrics={'acc': 0.8699}]\n", + "epoch 34: 100%|██| 275/275 [00:03<00:00, 85.04it/s, loss=0.251, metrics={'acc': 0.8843}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 136.37it/s, loss=0.278, metrics={'acc': 0.8698}]\n", + "epoch 35: 100%|███| 275/275 [00:02<00:00, 92.70it/s, loss=0.25, metrics={'acc': 0.8841}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 145.48it/s, loss=0.278, metrics={'acc': 0.8697}]\n", + "epoch 36: 100%|██| 275/275 [00:02<00:00, 95.07it/s, loss=0.251, metrics={'acc': 0.8843}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 144.44it/s, loss=0.278, metrics={'acc': 0.87}]\n", + "epoch 37: 100%|███| 275/275 [00:02<00:00, 93.25it/s, loss=0.25, metrics={'acc': 0.8837}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 144.45it/s, loss=0.278, metrics={'acc': 0.8701}]\n", + "epoch 38: 100%|███| 275/275 [00:03<00:00, 91.39it/s, loss=0.25, metrics={'acc': 0.8843}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 139.95it/s, loss=0.278, metrics={'acc': 0.8702}]\n", + "epoch 39: 100%|███| 275/275 [00:02<00:00, 92.64it/s, loss=0.25, metrics={'acc': 0.8847}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 140.00it/s, loss=0.278, metrics={'acc': 0.8703}]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best Epoch: 9. Best val_loss: 0.27098\n", + "Restoring model weights from the end of the best epoch\n" + ] + } + ], + "source": [ + "trainer.fit(\n", + " X_tab=X_tab_train,\n", + " target=target,\n", + " val_split=0.2,\n", + " n_epochs=100,\n", + " batch_size=128,\n", + " validation_freq=1,\n", + " feature_importance_sample_size=1000,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "cd603a44", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'age': 0.098023,\n", + " 'workclass': 0.07621125,\n", + " 'education': 0.07414728,\n", + " 'marital_status': 0.113280274,\n", + " 'occupation': 0.07292068,\n", + " 'relationship': 0.08008792,\n", + " 'race': 0.104180396,\n", + " 'gender': 0.07037963,\n", + " 'capital_gain': 0.06584223,\n", + " 'capital_loss': 0.07647487,\n", + " 'hours_per_week': 0.09369389,\n", + " 'native_country': 0.0747586}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.feature_importance" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "8342dfdc", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "predict: 100%|█████████████████████████████████████████| 39/39 [00:00<00:00, 213.15it/s]\n" + ] + } + ], + "source": [ + "preds = trainer.predict(X_tab=X_tab_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7cfdabda", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8734902763561925" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(preds, test.income_label)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5b49516b", + "metadata": {}, + "outputs": [], + "source": [ + "test.reset_index(drop=True, inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "6ad4b3b6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationmarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryincome_label
026PrivateSome-collegeNever-marriedExec-managerialNot-in-familyWhiteMale0060United-States0
\n", + "
" + ], + "text/plain": [ + " age workclass education marital_status occupation relationship \\\n", + "0 26 Private Some-college Never-married Exec-managerial Not-in-family \n", + "\n", + " race gender capital_gain capital_loss hours_per_week native_country \\\n", + "0 White Male 0 0 60 United-States \n", + "\n", + " income_label \n", + "0 0 " + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test[test.income_label == 0].head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6419edff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationmarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryincome_label
336Local-govDoctorateMarried-civ-spouseProf-specialtyHusbandWhiteMale0188750United-States1
\n", + "
" + ], + "text/plain": [ + " age workclass education marital_status occupation relationship \\\n", + "3 36 Local-gov Doctorate Married-civ-spouse Prof-specialty Husband \n", + "\n", + " race gender capital_gain capital_loss hours_per_week native_country \\\n", + "3 White Male 0 1887 50 United-States \n", + "\n", + " income_label \n", + "3 1 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test[test.income_label == 1].head(1)" + ] + }, + { + "cell_type": "markdown", + "id": "a6860689", + "metadata": {}, + "source": [ + "To get the feature importance of a test dataset, simply use the `explain` method" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "ba6432ba", + "metadata": {}, + "outputs": [], + "source": [ + "feat_imp_per_sample = trainer.explain(X_tab_test, save_step_masks=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "f350e822", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['marital_status',\n", + " 'race',\n", + " 'age',\n", + " 'capital_loss',\n", + " 'occupation',\n", + " 'native_country',\n", + " 'workclass',\n", + " 'education',\n", + " 'gender',\n", + " 'relationship',\n", + " 'hours_per_week',\n", + " 'capital_gain']" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(test.iloc[0].index[np.argsort(-feat_imp_per_sample[0])])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "f05a8302", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['marital_status',\n", + " 'race',\n", + " 'capital_loss',\n", + " 'occupation',\n", + " 'education',\n", + " 'native_country',\n", + " 'hours_per_week',\n", + " 'relationship',\n", + " 'age',\n", + " 'workclass',\n", + " 'gender',\n", + " 'capital_gain']" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(test.iloc[3].index[np.argsort(-feat_imp_per_sample[3])])" + ] + }, + { + "cell_type": "markdown", + "id": "558fc30f", + "metadata": {}, + "source": [ + "We could do the same with the `ContextAttentionMLP`" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "45b8a869", + "metadata": {}, + "outputs": [], + "source": [ + "context_attn_mlp = ContextAttentionMLP(\n", + " column_idx=tab_preprocessor.column_idx,\n", + " cat_embed_input=tab_preprocessor.cat_embed_input,\n", + " cat_embed_dropout=0.0,\n", + " input_dim=16,\n", + " attn_dropout=0.1,\n", + " attn_activation=\"relu\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "07f0c434", + "metadata": {}, + "outputs": [], + "source": [ + "mlp_model = WideDeep(deeptabular=context_attn_mlp)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "d9c1ef75", + "metadata": {}, + "outputs": [], + "source": [ + "mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.01, weight_decay=0.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "4c2f21ca", + "metadata": {}, + "outputs": [], + "source": [ + "mlp_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " mlp_optimizer,\n", + " threshold=0.001,\n", + " threshold_mode=\"abs\",\n", + " patience=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "504c1755", + "metadata": {}, + "outputs": [], + "source": [ + "mlp_early_stopping = EarlyStopping(\n", + " min_delta=0.001, patience=30, restore_best_weights=True, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "78c79a3d", + "metadata": {}, + "outputs": [], + "source": [ + "mlp_trainer = Trainer(\n", + " mlp_model,\n", + " objective=\"binary\",\n", + " optimizers=mlp_optimizer,\n", + " lr_schedulers=mlp_lr_scheduler,\n", + " reducelronplateau_criterion=\"loss\",\n", + " callbacks=[mlp_early_stopping],\n", + " metrics=[Accuracy],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "af40082b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|███| 275/275 [00:03<00:00, 91.33it/s, loss=0.395, metrics={'acc': 0.8139}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 125.66it/s, loss=0.306, metrics={'acc': 0.8577}]\n", + "epoch 2: 100%|███| 275/275 [00:03<00:00, 87.31it/s, loss=0.333, metrics={'acc': 0.8396}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 140.84it/s, loss=0.291, metrics={'acc': 0.8631}]\n", + "epoch 3: 100%|███| 275/275 [00:03<00:00, 84.57it/s, loss=0.323, metrics={'acc': 0.8494}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 141.55it/s, loss=0.293, metrics={'acc': 0.8632}]\n", + "epoch 4: 100%|███| 275/275 [00:03<00:00, 84.62it/s, loss=0.312, metrics={'acc': 0.8518}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 156.92it/s, loss=0.3, metrics={'acc': 0.8543}]\n", + "epoch 5: 100%|████| 275/275 [00:03<00:00, 85.74it/s, loss=0.31, metrics={'acc': 0.8552}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 130.66it/s, loss=0.303, metrics={'acc': 0.8545}]\n", + "epoch 6: 100%|███| 275/275 [00:02<00:00, 93.70it/s, loss=0.303, metrics={'acc': 0.8579}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 161.71it/s, loss=0.291, metrics={'acc': 0.8609}]\n", + "epoch 7: 100%|███| 275/275 [00:03<00:00, 81.87it/s, loss=0.302, metrics={'acc': 0.8584}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 126.83it/s, loss=0.294, metrics={'acc': 0.8579}]\n", + "epoch 8: 100%|█████| 275/275 [00:04<00:00, 62.67it/s, loss=0.299, metrics={'acc': 0.86}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 161.52it/s, loss=0.284, metrics={'acc': 0.8653}]\n", + "epoch 9: 100%|███| 275/275 [00:03<00:00, 85.13it/s, loss=0.296, metrics={'acc': 0.8615}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 153.13it/s, loss=0.282, metrics={'acc': 0.8672}]\n", + "epoch 10: 100%|███| 275/275 [00:02<00:00, 92.61it/s, loss=0.296, metrics={'acc': 0.861}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 156.41it/s, loss=0.281, metrics={'acc': 0.8716}]\n", + "epoch 11: 100%|██| 275/275 [00:02<00:00, 99.43it/s, loss=0.295, metrics={'acc': 0.8619}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 155.07it/s, loss=0.283, metrics={'acc': 0.8649}]\n", + "epoch 12: 100%|██| 275/275 [00:03<00:00, 83.91it/s, loss=0.294, metrics={'acc': 0.8603}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 108.25it/s, loss=0.285, metrics={'acc': 0.8652}]\n", + "epoch 13: 100%|██| 275/275 [00:03<00:00, 77.13it/s, loss=0.292, metrics={'acc': 0.8631}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 116.60it/s, loss=0.284, metrics={'acc': 0.8656}]\n", + "epoch 14: 100%|██| 275/275 [00:03<00:00, 81.83it/s, loss=0.294, metrics={'acc': 0.8622}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 126.96it/s, loss=0.283, metrics={'acc': 0.8665}]\n", + "epoch 15: 100%|██| 275/275 [00:03<00:00, 81.45it/s, loss=0.292, metrics={'acc': 0.8621}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 128.30it/s, loss=0.284, metrics={'acc': 0.8664}]\n", + "epoch 16: 100%|██| 275/275 [00:03<00:00, 81.18it/s, loss=0.293, metrics={'acc': 0.8626}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 129.63it/s, loss=0.283, metrics={'acc': 0.8659}]\n", + "epoch 17: 100%|██| 275/275 [00:03<00:00, 76.42it/s, loss=0.292, metrics={'acc': 0.8619}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 112.97it/s, loss=0.285, metrics={'acc': 0.8658}]\n", + "epoch 18: 100%|██| 275/275 [00:03<00:00, 71.98it/s, loss=0.294, metrics={'acc': 0.8609}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 117.61it/s, loss=0.279, metrics={'acc': 0.8691}]\n", + "epoch 19: 100%|███| 275/275 [00:03<00:00, 76.80it/s, loss=0.29, metrics={'acc': 0.8652}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 113.99it/s, loss=0.283, metrics={'acc': 0.8707}]\n", + "epoch 20: 100%|██| 275/275 [00:03<00:00, 69.47it/s, loss=0.293, metrics={'acc': 0.8613}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 115.32it/s, loss=0.283, metrics={'acc': 0.8646}]\n", + "epoch 21: 100%|███| 275/275 [00:03<00:00, 69.08it/s, loss=0.29, metrics={'acc': 0.8636}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 117.65it/s, loss=0.289, metrics={'acc': 0.8621}]\n", + "epoch 22: 100%|██| 275/275 [00:03<00:00, 69.49it/s, loss=0.291, metrics={'acc': 0.8626}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 125.45it/s, loss=0.284, metrics={'acc': 0.8675}]\n", + "epoch 23: 100%|███| 275/275 [00:03<00:00, 81.16it/s, loss=0.29, metrics={'acc': 0.8637}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 124.96it/s, loss=0.284, metrics={'acc': 0.8649}]\n", + "epoch 24: 100%|███| 275/275 [00:03<00:00, 76.37it/s, loss=0.29, metrics={'acc': 0.8646}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 114.71it/s, loss=0.28, metrics={'acc': 0.8678}]\n", + "epoch 25: 100%|██| 275/275 [00:03<00:00, 68.76it/s, loss=0.289, metrics={'acc': 0.8634}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 113.49it/s, loss=0.282, metrics={'acc': 0.8675}]\n", + "epoch 26: 100%|██| 275/275 [00:03<00:00, 68.98it/s, loss=0.291, metrics={'acc': 0.8624}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 121.90it/s, loss=0.282, metrics={'acc': 0.8668}]\n", + "epoch 27: 100%|██| 275/275 [00:03<00:00, 72.16it/s, loss=0.289, metrics={'acc': 0.8648}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 109.51it/s, loss=0.28, metrics={'acc': 0.8674}]\n", + "epoch 28: 100%|██| 275/275 [00:03<00:00, 68.83it/s, loss=0.288, metrics={'acc': 0.8653}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 110.77it/s, loss=0.283, metrics={'acc': 0.8637}]\n", + "epoch 29: 100%|███| 275/275 [00:04<00:00, 68.63it/s, loss=0.29, metrics={'acc': 0.8655}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 113.82it/s, loss=0.283, metrics={'acc': 0.8643}]\n", + "epoch 30: 100%|███| 275/275 [00:04<00:00, 67.75it/s, loss=0.284, metrics={'acc': 0.868}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 106.96it/s, loss=0.276, metrics={'acc': 0.8708}]\n", + "epoch 31: 100%|██| 275/275 [00:04<00:00, 61.95it/s, loss=0.284, metrics={'acc': 0.8664}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 97.29it/s, loss=0.276, metrics={'acc': 0.8699}]\n", + "epoch 32: 100%|██| 275/275 [00:05<00:00, 53.82it/s, loss=0.284, metrics={'acc': 0.8663}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 87.62it/s, loss=0.275, metrics={'acc': 0.8703}]\n", + "epoch 33: 100%|██| 275/275 [00:05<00:00, 52.92it/s, loss=0.282, metrics={'acc': 0.8669}]\n", + "valid: 100%|█████████| 69/69 [00:00<00:00, 87.98it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 34: 100%|██| 275/275 [00:05<00:00, 54.57it/s, loss=0.282, metrics={'acc': 0.8666}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 95.37it/s, loss=0.275, metrics={'acc': 0.8708}]\n", + "epoch 35: 100%|██| 275/275 [00:04<00:00, 58.35it/s, loss=0.281, metrics={'acc': 0.8701}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 98.99it/s, loss=0.276, metrics={'acc': 0.8693}]\n", + "epoch 36: 100%|██| 275/275 [00:04<00:00, 57.91it/s, loss=0.281, metrics={'acc': 0.8673}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 96.97it/s, loss=0.277, metrics={'acc': 0.8689}]\n", + "epoch 37: 100%|██| 275/275 [00:04<00:00, 58.82it/s, loss=0.279, metrics={'acc': 0.8687}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 98.79it/s, loss=0.277, metrics={'acc': 0.8676}]\n", + "epoch 38: 100%|███| 275/275 [00:04<00:00, 61.86it/s, loss=0.28, metrics={'acc': 0.8707}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 102.77it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 39: 100%|███| 275/275 [00:04<00:00, 61.11it/s, loss=0.281, metrics={'acc': 0.869}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 101.65it/s, loss=0.276, metrics={'acc': 0.8683}]\n", + "epoch 40: 100%|███| 275/275 [00:04<00:00, 59.90it/s, loss=0.28, metrics={'acc': 0.8692}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 102.92it/s, loss=0.278, metrics={'acc': 0.8681}]\n", + "epoch 41: 100%|██| 275/275 [00:04<00:00, 60.41it/s, loss=0.281, metrics={'acc': 0.8674}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 98.36it/s, loss=0.276, metrics={'acc': 0.8703}]\n", + "epoch 42: 100%|██| 275/275 [00:04<00:00, 55.65it/s, loss=0.281, metrics={'acc': 0.8691}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 97.88it/s, loss=0.276, metrics={'acc': 0.8693}]\n", + "epoch 43: 100%|███| 275/275 [00:04<00:00, 59.98it/s, loss=0.28, metrics={'acc': 0.8679}]\n", + "valid: 100%|███████| 69/69 [00:00<00:00, 98.41it/s, loss=0.276, metrics={'acc': 0.8701}]\n", + "epoch 44: 100%|██| 275/275 [00:04<00:00, 59.67it/s, loss=0.279, metrics={'acc': 0.8688}]\n", + "valid: 100%|█████████| 69/69 [00:00<00:00, 98.48it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 45: 100%|██| 275/275 [00:04<00:00, 61.62it/s, loss=0.279, metrics={'acc': 0.8694}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 103.90it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 46: 100%|███| 275/275 [00:04<00:00, 61.50it/s, loss=0.28, metrics={'acc': 0.8688}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 102.02it/s, loss=0.276, metrics={'acc': 0.87}]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 47: 100%|██| 275/275 [00:03<00:00, 70.92it/s, loss=0.279, metrics={'acc': 0.8693}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 121.47it/s, loss=0.276, metrics={'acc': 0.8699}]\n", + "epoch 48: 100%|██| 275/275 [00:03<00:00, 75.09it/s, loss=0.277, metrics={'acc': 0.8701}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 126.88it/s, loss=0.276, metrics={'acc': 0.8706}]\n", + "epoch 49: 100%|██| 275/275 [00:03<00:00, 74.75it/s, loss=0.279, metrics={'acc': 0.8694}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 130.57it/s, loss=0.276, metrics={'acc': 0.8705}]\n", + "epoch 50: 100%|██| 275/275 [00:03<00:00, 81.06it/s, loss=0.278, metrics={'acc': 0.8704}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 137.71it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 51: 100%|██| 275/275 [00:03<00:00, 79.21it/s, loss=0.278, metrics={'acc': 0.8689}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 134.51it/s, loss=0.276, metrics={'acc': 0.8702}]\n", + "epoch 52: 100%|███| 275/275 [00:03<00:00, 81.95it/s, loss=0.279, metrics={'acc': 0.869}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 136.85it/s, loss=0.276, metrics={'acc': 0.8701}]\n", + "epoch 53: 100%|██| 275/275 [00:03<00:00, 82.84it/s, loss=0.279, metrics={'acc': 0.8702}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 139.43it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 54: 100%|███| 275/275 [00:03<00:00, 84.35it/s, loss=0.28, metrics={'acc': 0.8678}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 139.32it/s, loss=0.276, metrics={'acc': 0.8701}]\n", + "epoch 55: 100%|██| 275/275 [00:03<00:00, 82.40it/s, loss=0.278, metrics={'acc': 0.8694}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 140.20it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 56: 100%|███| 275/275 [00:03<00:00, 84.02it/s, loss=0.28, metrics={'acc': 0.8686}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 138.68it/s, loss=0.276, metrics={'acc': 0.87}]\n", + "epoch 57: 100%|████| 275/275 [00:03<00:00, 83.66it/s, loss=0.28, metrics={'acc': 0.868}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 139.25it/s, loss=0.276, metrics={'acc': 0.8702}]\n", + "epoch 58: 100%|███| 275/275 [00:03<00:00, 83.12it/s, loss=0.279, metrics={'acc': 0.869}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 134.59it/s, loss=0.276, metrics={'acc': 0.8702}]\n", + "epoch 59: 100%|██| 275/275 [00:03<00:00, 81.14it/s, loss=0.278, metrics={'acc': 0.8682}]\n", + "valid: 100%|██████| 69/69 [00:00<00:00, 135.79it/s, loss=0.276, metrics={'acc': 0.8702}]\n", + "epoch 60: 100%|███| 275/275 [00:03<00:00, 80.78it/s, loss=0.28, metrics={'acc': 0.8684}]\n", + "valid: 100%|████████| 69/69 [00:00<00:00, 133.84it/s, loss=0.276, metrics={'acc': 0.87}]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best Epoch: 30. Best val_loss: 0.27563\n", + "Restoring model weights from the end of the best epoch\n" + ] + } + ], + "source": [ + "mlp_trainer.fit(\n", + " X_tab=X_tab_train,\n", + " target=target,\n", + " val_split=0.2,\n", + " n_epochs=100,\n", + " batch_size=128,\n", + " validation_freq=1,\n", + " feature_importance_sample_size=1000,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "0a4fdc77", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'age': 0.103683405,\n", + " 'workclass': 0.066264994,\n", + " 'education': 0.10014994,\n", + " 'marital_status': 0.1235957,\n", + " 'occupation': 0.12825337,\n", + " 'relationship': 0.15234835,\n", + " 'race': 0.061964743,\n", + " 'gender': 0.05328226,\n", + " 'capital_gain': 0.03052448,\n", + " 'capital_loss': 0.037544865,\n", + " 'hours_per_week': 0.07689079,\n", + " 'native_country': 0.0654971}" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlp_trainer.feature_importance" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "c10be01f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "predict: 100%|█████████████████████████████████████████| 39/39 [00:00<00:00, 211.89it/s]\n" + ] + } + ], + "source": [ + "mlp_preds = mlp_trainer.predict(X_tab=X_tab_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "0b2a46a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.873899692937564" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(mlp_preds, test.income_label)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/scripts/adult_census_feature_importance.py b/examples/scripts/adult_census_feature_importance.py new file mode 100644 index 00000000..64e7447d --- /dev/null +++ b/examples/scripts/adult_census_feature_importance.py @@ -0,0 +1,162 @@ +from sklearn.model_selection import train_test_split + +from pytorch_widedeep import Trainer +from pytorch_widedeep.models import ( + SAINT, + TabNet, + WideDeep, + FTTransformer, + TabFastFormer, + TabTransformer, + SelfAttentionMLP, + ContextAttentionMLP, +) +from pytorch_widedeep.metrics import Accuracy +from pytorch_widedeep.datasets import load_adult +from pytorch_widedeep.preprocessing import TabPreprocessor + +# use_cuda = torch.cuda.is_available() + +df = load_adult(as_frame=True) +df.columns = [c.replace("-", "_") for c in df.columns] +df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int) +df.drop("income", axis=1, inplace=True) +target_colname = "income_label" + +cat_embed_cols = [] +for col in df.columns: + if df[col].dtype == "O" or df[col].nunique() < 200 and col != target_colname: + cat_embed_cols.append(col) + +train, test = train_test_split( + df, test_size=0.1, random_state=1, stratify=df[[target_colname]] +) + +with_cls_token = True +tab_preprocessor = TabPreprocessor( + cat_embed_cols=cat_embed_cols, with_attention=True, with_cls_token=with_cls_token +) + +X_tab_train = tab_preprocessor.fit_transform(train) +X_tab_test = tab_preprocessor.transform(test) +target = train[target_colname].values + + +tab_transformer = TabTransformer( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + input_dim=8, + n_heads=2, + n_blocks=2, +) + +saint = SAINT( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + input_dim=8, + n_heads=2, + n_blocks=2, +) + +tab_fastformer = TabFastFormer( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + input_dim=8, + n_heads=2, + n_blocks=2, +) + +ft_transformer = FTTransformer( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + input_dim=8, + n_heads=2, + n_blocks=2, + kv_compression_factor=1.0, # if this is diff than one, we cannot do this +) + +context_attention_mlp = ContextAttentionMLP( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + input_dim=16, + attn_dropout=0.2, + n_blocks=3, +) + +self_attention_mlp = SelfAttentionMLP( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + n_blocks=3, +) + +for attention_based_model in [ + tab_transformer, + saint, + tab_fastformer, + ft_transformer, + context_attention_mlp, + self_attention_mlp, +]: + model = WideDeep(deeptabular=attention_based_model) # type: ignore[arg-type] + + trainer = Trainer( + model, + objective="binary", + metrics=[Accuracy], + ) + + trainer.fit( + X_tab=X_tab_train, + target=target, + n_epochs=5, + batch_size=128, + val_split=0.2, + feature_importance_sample_size=1000, + ) + + feat_imp_per_sample = trainer.explain(X_tab_test) + + assert ( + len(trainer.feature_importance) == X_tab_train.shape[1] - 1 + if with_cls_token + else X_tab_train.shape[1] + ) and feat_imp_per_sample.shape == test[cat_embed_cols].shape + + +train, test = train_test_split( + df, test_size=0.1, random_state=1, stratify=df[[target_colname]] +) + +tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols) + +X_tab_train = tab_preprocessor.fit_transform(train) +X_tab_test = tab_preprocessor.transform(test) +target = train[target_colname].values + +tabnet = TabNet( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, +) + +model = WideDeep(deeptabular=tabnet) + +trainer = Trainer( + model, + objective="binary", + metrics=[Accuracy], +) + +trainer.fit( + X_tab=X_tab_train, + target=target, + n_epochs=5, + batch_size=128, + val_split=0.2, + feature_importance_sample_size=1000, +) +feat_imp_per_sample = trainer.explain(X_tab_test, save_step_masks=False) + +assert ( + len(trainer.feature_importance) == X_tab_train.shape[1] + and feat_imp_per_sample.shape == test[cat_embed_cols].shape +) diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py index 9fd38e2b..d73957c1 100644 --- a/pytorch_widedeep/callbacks.py +++ b/pytorch_widedeep/callbacks.py @@ -655,6 +655,7 @@ def on_epoch_end( if self.monitor_op(current - self.min_delta, self.best): self.best = current self.wait = 0 + self.best_epoch = epoch if self.restore_best_weights: self.state_dict = copy.deepcopy(self.model.state_dict()) else: @@ -665,7 +666,9 @@ def on_epoch_end( def on_train_end(self, logs: Optional[Dict] = None): if self.stopped_epoch > 0 and self.verbose > 0: - print("Epoch %05d: early stopping" % (self.stopped_epoch + 1)) + print( + f"Best Epoch: {self.best_epoch + 1}. Best {self.monitor}: {self.best:.5f}" + ) if self.restore_best_weights and self.state_dict is not None: if self.verbose > 0: print("Restoring model weights from the end of the best epoch") diff --git a/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py b/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py index f98e2647..d31beded 100644 --- a/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py +++ b/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py @@ -188,7 +188,7 @@ def output_dim(self) -> int: ) @property - def attention_weights(self) -> List: + def attention_weights(self) -> List[Tensor]: r"""List with the attention weights per block The shape of the attention weights is $(N, F)$, where $N$ is the batch diff --git a/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py b/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py index 8ca6890d..4b17b312 100644 --- a/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py +++ b/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py @@ -198,7 +198,7 @@ def output_dim(self) -> int: ) @property - def attention_weights(self) -> List: + def attention_weights(self) -> List[Tensor]: r"""List with the attention weights per block The shape of the attention weights is $(N, H, F, F)$, where $N$ is the diff --git a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py index c7c8149d..7b1589a6 100644 --- a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py +++ b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py @@ -275,7 +275,7 @@ def output_dim(self) -> int: ) @property - def attention_weights(self) -> List: + def attention_weights(self) -> List[Tensor]: r"""List with the attention weights per block The shape of the attention weights is: $(N, H, F, k)$, where $N$ is diff --git a/pytorch_widedeep/models/tabular/transformers/saint.py b/pytorch_widedeep/models/tabular/transformers/saint.py index 27da26d4..cfed7488 100644 --- a/pytorch_widedeep/models/tabular/transformers/saint.py +++ b/pytorch_widedeep/models/tabular/transformers/saint.py @@ -248,7 +248,7 @@ def output_dim(self) -> int: ) @property - def attention_weights(self) -> List: + def attention_weights(self) -> List[Tuple[Tensor, Tensor]]: r"""List with the attention weights. Each element of the list is a tuple where the first and the second elements are the column and row attention weights respectively diff --git a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py index 02601fa6..17e9114b 100644 --- a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py +++ b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py @@ -281,7 +281,7 @@ def output_dim(self) -> int: ) @property - def attention_weights(self) -> List: + def attention_weights(self) -> List[Tuple[Tensor, Tensor]]: r"""List with the attention weights. Each element of the list is a tuple where the first and second elements are the $\alpha$ and $\beta$ attention weights in the paper. diff --git a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py index 295ec7ba..868e3cbf 100644 --- a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py +++ b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py @@ -284,7 +284,7 @@ def output_dim(self) -> int: ) @property - def attention_weights(self) -> List: + def attention_weights(self) -> List[Tensor]: r"""List with the attention weights per block The shape of the attention weights is $(N, H, F, F)$, where $N$ is the diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py index 35798e03..43943edd 100644 --- a/pytorch_widedeep/models/wide_deep.py +++ b/pytorch_widedeep/models/wide_deep.py @@ -15,7 +15,7 @@ warnings.filterwarnings("default", category=UserWarning) -WDModel = Union[nn.Module, BaseWDModelComponent] +WDModel = Union[nn.Module, nn.Sequential, BaseWDModelComponent] class WideDeep(nn.Module): diff --git a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py index 231072ae..1da60c99 100644 --- a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py +++ b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py @@ -1,5 +1,6 @@ import os import sys +import warnings from abc import ABC, abstractmethod import numpy as np @@ -174,17 +175,34 @@ def _set_callbacks(self, callbacks: Any): self.callback_container.set_model(self.cd_model) self.callback_container.set_trainer(self) - def _restore_best_weights(self): - already_restored = any( - [ - ( - callback.__class__.__name__ == "EarlyStopping" - and callback.restore_best_weights - ) - for callback in self.callback_container.callbacks - ] - ) + def _restore_best_weights(self): # noqa: C901 + early_stopping_min_delta = None + model_checkpoint_min_delta = None + already_restored = False + + for callback in self.callback_container.callbacks: + if ( + callback.__class__.__name__ == "EarlyStopping" + and callback.restore_best_weights + ): + early_stopping_min_delta = callback.min_delta + already_restored = True + + if callback.__class__.__name__ == "ModelCheckpoint": + model_checkpoint_min_delta = callback.min_delta + + if ( + early_stopping_min_delta is not None + and model_checkpoint_min_delta is not None + ) and (early_stopping_min_delta != model_checkpoint_min_delta): + warnings.warn( + "'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. " + "This implies a different definition of 'improvement' for these two callbacks", + UserWarning, + ) + if already_restored: + # already restored via EarlyStopping pass else: for callback in self.callback_container.callbacks: diff --git a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py index 29515320..aea5b9d6 100644 --- a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py +++ b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py @@ -1,5 +1,6 @@ import os import sys +import warnings from abc import ABC, abstractmethod import numpy as np @@ -121,17 +122,34 @@ def _set_callbacks(self, callbacks: Any): self.callback_container.set_model(self.ed_model) self.callback_container.set_trainer(self) - def _restore_best_weights(self): - already_restored = any( - [ - ( - callback.__class__.__name__ == "EarlyStopping" - and callback.restore_best_weights - ) - for callback in self.callback_container.callbacks - ] - ) + def _restore_best_weights(self): # noqa: C901 + early_stopping_min_delta = None + model_checkpoint_min_delta = None + already_restored = False + + for callback in self.callback_container.callbacks: + if ( + callback.__class__.__name__ == "EarlyStopping" + and callback.restore_best_weights + ): + early_stopping_min_delta = callback.min_delta + already_restored = True + + if callback.__class__.__name__ == "ModelCheckpoint": + model_checkpoint_min_delta = callback.min_delta + + if ( + early_stopping_min_delta is not None + and model_checkpoint_min_delta is not None + ) and (early_stopping_min_delta != model_checkpoint_min_delta): + warnings.warn( + "'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. " + "This implies a different definition of 'improvement' for these two callbacks", + UserWarning, + ) + if already_restored: + # already restored via EarlyStopping pass else: for callback in self.callback_container.callbacks: diff --git a/pytorch_widedeep/training/_base_bayesian_trainer.py b/pytorch_widedeep/training/_base_bayesian_trainer.py index cefc111d..e6da0285 100644 --- a/pytorch_widedeep/training/_base_bayesian_trainer.py +++ b/pytorch_widedeep/training/_base_bayesian_trainer.py @@ -1,5 +1,6 @@ import os import sys +import warnings from abc import ABC, abstractmethod import numpy as np @@ -120,17 +121,34 @@ def save( ): raise NotImplementedError("Trainer.save method not implemented") - def _restore_best_weights(self): - already_restored = any( - [ - ( - callback.__class__.__name__ == "EarlyStopping" - and callback.restore_best_weights - ) - for callback in self.callback_container.callbacks - ] - ) + def _restore_best_weights(self): # noqa: C901 + early_stopping_min_delta = None + model_checkpoint_min_delta = None + already_restored = False + + for callback in self.callback_container.callbacks: + if ( + callback.__class__.__name__ == "EarlyStopping" + and callback.restore_best_weights + ): + early_stopping_min_delta = callback.min_delta + already_restored = True + + if callback.__class__.__name__ == "ModelCheckpoint": + model_checkpoint_min_delta = callback.min_delta + + if ( + early_stopping_min_delta is not None + and model_checkpoint_min_delta is not None + ) and (early_stopping_min_delta != model_checkpoint_min_delta): + warnings.warn( + "'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. " + "This implies a different definition of 'improvement' for these two callbacks", + UserWarning, + ) + if already_restored: + # already restored via EarlyStopping pass else: for callback in self.callback_container.callbacks: diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py index 4f0527c0..3c11f77c 100644 --- a/pytorch_widedeep/training/_base_trainer.py +++ b/pytorch_widedeep/training/_base_trainer.py @@ -30,7 +30,6 @@ ) from pytorch_widedeep.initializers import Initializer, MultipleInitializer from pytorch_widedeep.training._trainer_utils import alias_to_loss -from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer from pytorch_widedeep.training._multiple_transforms import MultipleTransforms from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod @@ -67,7 +66,6 @@ def __init__( self.model = model if self.model.is_tabnet: self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3) - self.reducing_matrix = create_explain_matrix(self.model) self.model.to(self.device) self.model.wd_device = self.device diff --git a/pytorch_widedeep/training/_feature_importance.py b/pytorch_widedeep/training/_feature_importance.py new file mode 100644 index 00000000..4e8711a0 --- /dev/null +++ b/pytorch_widedeep/training/_feature_importance.py @@ -0,0 +1,379 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch +from scipy.sparse import csc_matrix +from torch.utils.data import DataLoader + +from pytorch_widedeep.wdtypes import ( + Any, + Dict, + List, + Tuple, + Union, + Tensor, + Optional, + WideDeep, +) +from pytorch_widedeep.models.tabular import ( + SAINT, + TabPerceiver, + FTTransformer, + TabFastFormer, + TabTransformer, + SelfAttentionMLP, + ContextAttentionMLP, +) +from pytorch_widedeep.utils.general_utils import Alias +from pytorch_widedeep.training._wd_dataset import WideDeepDataset +from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix + +TransformerBasedModels = ( + SAINT, + FTTransformer, + TabFastFormer, + TabTransformer, + SelfAttentionMLP, + ContextAttentionMLP, +) + +__all__ = ["FeatureImportance", "Explainer"] + + +# TODO: review typing for WideDeep (in particular the deeptabular part) The +# issue with the typing of the deeptabular part is the following: the +# deeptabular part is typed as Optional[BaseWDModelComponent]. While that is +# correct, is not fully informative, and the most correct approach would +# perhaps be [BaseWDModelComponent, BaseTabularModelWithAttention, +# BaseTabularModelWithoutAttention]. Perhaps this way as we do +# model.deeptabular._modules["0"] it would be understood that the type of +# this value is such that has a property called `attention_weights` and I +# would not have to use type ignores here and there. For the time being I am +# going to leave it as it is (since it is not wrong), but this needs to be +# revisited + + +class FeatureImportance: + def __init__(self, device: str, n_samples: int = 1000): + self.device = device + self.n_samples = n_samples + + def feature_importance( + self, loader: DataLoader, model: WideDeep + ) -> Dict[str, float]: + if model.is_tabnet: + model_feat_importance: ModelFeatureImportance = TabNetFeatureImportance( + self.device, self.n_samples + ) + elif isinstance(model.deeptabular._modules["0"], TransformerBasedModels): + model_feat_importance = TransformerBasedFeatureImportance( + self.device, self.n_samples + ) + else: + raise ValueError( + "The computation of feature importance is not supported for this particular model" + ) + + return model_feat_importance.feature_importance(loader, model) + + +class Explainer: + def __init__(self, device: str): + self.device = device + + def explain( + self, + model: WideDeep, + X_tab: np.ndarray, + num_workers: int, + batch_size: Optional[int] = None, + save_step_masks: Optional[bool] = None, + ) -> Union[Tuple, np.ndarray]: + if model.is_tabnet: + assert ( + save_step_masks is not None + ), "If the model is TabNet, please set 'save_step_masks' to True/False" + model_explainer: ModelExplainer = TabNetExplainer(self.device) + res = model_explainer.explain( + model, + X_tab, + num_workers, + batch_size, + save_step_masks, + ) + elif isinstance(model.deeptabular._modules["0"], TransformerBasedModels): + model_explainer = TransformerBasedExplainer(self.device) + res = model_explainer.explain( + model, + X_tab, + num_workers, + batch_size, + ) + else: + raise ValueError( + "The computation of feature importance is not supported for this particular model" + ) + + return res + + +class BaseFeatureImportance(ABC): + def __init__(self, device: str, n_samples: int = 1000): + self.device = device + self.n_samples = n_samples + + @abstractmethod + def feature_importance( + self, loader: DataLoader, model: WideDeep + ) -> Dict[str, float]: + raise NotImplementedError( + "Any Feature Importance technique must implement this method" + ) + + def _sample_data(self, loader: DataLoader) -> Tensor: + n_iterations = self.n_samples // loader.batch_size + + batches = [] + for i, (data, _, _) in enumerate(loader): + if i < n_iterations: + batches.append(data["deeptabular"].to(self.device)) + else: + break + + return torch.cat(batches, dim=0) + + +class TabNetFeatureImportance(BaseFeatureImportance): + def __init__(self, device: str, n_samples: int = 1000): + super().__init__( + device=device, + n_samples=n_samples, + ) + + def feature_importance( + self, loader: DataLoader, model: WideDeep + ) -> Dict[str, float]: + model.eval() + + reducing_matrix = create_explain_matrix(model) + model_backbone = list(model.deeptabular.children())[0] + feat_imp = np.zeros((model_backbone.embed_out_dim)) # type: ignore[arg-type] + + X = self._sample_data(loader) + M_explain, _ = model_backbone.forward_masks(X) # type: ignore[operator] + feat_imp += M_explain.sum(dim=0).cpu().detach().numpy() + feat_imp = csc_matrix.dot(feat_imp, reducing_matrix) + feat_imp = feat_imp / np.sum(feat_imp) + + return {k: v for k, v in zip(model_backbone.column_idx.keys(), feat_imp)} # type: ignore + + +class TransformerBasedFeatureImportance(BaseFeatureImportance): + def __init__(self, device: str, n_samples: int = 1000): + super().__init__( + device=device, + n_samples=n_samples, + ) + + def feature_importance( + self, loader: DataLoader, model: WideDeep + ) -> Dict[str, float]: + self._check_inputs(model) + self.model_type = self._model_type(model) + + X = self._sample_data(loader) + + model.eval() + _ = model.deeptabular(X) + + feature_importance, column_idx = self._feature_importance(model) + + agg_feature_importance = feature_importance.mean(0).cpu().detach().numpy() + + return {k: v for k, v in zip(column_idx, agg_feature_importance)} + + def _model_type_attention_weights(self, model: WideDeep) -> Tensor: + if self.model_type == "saint": + attention_weights = torch.stack( + [aw[0] for aw in model.deeptabular[0].attention_weights], # type: ignore[index] + dim=0, + ) + elif self.model_type == "tabfastformer": + alpha_weights, beta_weights = zip(*model.deeptabular[0].attention_weights) # type: ignore[index] + attention_weights = torch.stack(alpha_weights + beta_weights, dim=0) + else: + attention_weights = torch.stack( + model.deeptabular[0].attention_weights, dim=0 # type: ignore[index] + ) + + return attention_weights + + def _model_type_feature_importance( + self, model: WideDeep, attention_weights: Tensor + ) -> Tuple[Tensor, List[str]]: + model_backbone = list(model.deeptabular.children())[0] + with_cls_token = model.deeptabular._modules["0"].with_cls_token + + column_idx = ( + list(model_backbone.column_idx.keys())[1:] # type: ignore + if with_cls_token + else list(model_backbone.column_idx.keys()) # type: ignore + ) + + if self.model_type == "contextattentionmlp": + feat_imp = ( + attention_weights.mean(0) + if not with_cls_token + else attention_weights.mean(0)[:, 1:] + ) + elif self.model_type == "tabfastformer": + feat_imp = ( + attention_weights.mean(0).mean(1) + if not with_cls_token + else attention_weights.mean(0).mean(1)[:, 1:] + ) + else: + feat_imp = ( + attention_weights.mean(0).mean(1).mean(1) + if not with_cls_token + else attention_weights.mean(0).mean(1)[:, 0, 1:] + ) + + return feat_imp, column_idx + + @staticmethod + def _model_type(model: WideDeep) -> str: + if isinstance(model.deeptabular._modules["0"], SAINT): + model_type = "saint" + if isinstance(model.deeptabular._modules["0"], FTTransformer): + model_type = "fttransformer" + if isinstance(model.deeptabular._modules["0"], TabFastFormer): + model_type = "tabfastformer" + if isinstance(model.deeptabular._modules["0"], TabTransformer): + model_type = "tabtransformer" + if isinstance(model.deeptabular._modules["0"], SelfAttentionMLP): + model_type = "selfattentionmlp" + if isinstance(model.deeptabular._modules["0"], ContextAttentionMLP): + model_type = "contextattentionmlp" + + return model_type + + def _feature_importance(self, model: WideDeep) -> Tuple[Tensor, List[str]]: + attention_weights = self._model_type_attention_weights(model) + + feature_importance, column_idx = self._model_type_feature_importance( + model, attention_weights + ) + + return feature_importance, column_idx + + def _check_inputs(self, model: WideDeep): + if isinstance(model.deeptabular._modules["0"], TabPerceiver): + raise ValueError( + "At this stage the feature importance is not supported for the 'TabPerceiver'" + ) + if isinstance(model.deeptabular._modules["0"], FTTransformer) and ( + model.deeptabular._modules["0"].kv_compression_factor != 1 + ): + raise ValueError( + "Feature importance can only be computed if the compression factor " + "'kv_compression_factor' is set to 1" + ) + + +class TabNetExplainer: + def __init__(self, device: str, n_samples: int = 1000): + self.device = device + + @Alias("X_tab", "X") + def explain( + self, + model: WideDeep, + X_tab: np.ndarray, + num_workers: int, + batch_size: Optional[int] = None, + save_step_masks: bool = False, + ) -> Union[Tuple, np.ndarray]: + model.eval() + model_backbone = list(model.deeptabular.children())[0] + reducing_matrix = create_explain_matrix(model) + + loader = DataLoader( + dataset=WideDeepDataset(X_tab=X_tab), + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + + m_explain_l = [] + for batch_nb, data in enumerate(loader): + X = data["deeptabular"].to(self.device) + M_explain, masks = model_backbone.forward_masks(X) # type: ignore[operator] + m_explain_l.append( + csc_matrix.dot(M_explain.cpu().detach().numpy(), reducing_matrix) + ) + if save_step_masks: + for key, value in masks.items(): + masks[key] = csc_matrix.dot( + value.cpu().detach().numpy(), reducing_matrix + ) + if batch_nb == 0: + m_explain_step = masks + else: + for key, value in masks.items(): + m_explain_step[key] = np.vstack([m_explain_step[key], value]) + + m_explain_agg = np.vstack(m_explain_l) + m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis] + + res: Union[Tuple, np.ndarray] = ( + (m_explain_agg_norm, m_explain_step) + if save_step_masks + else np.vstack(m_explain_agg_norm) + ) + + return res + + +class TransformerBasedExplainer(TransformerBasedFeatureImportance): + def __init__(self, device: str): + super().__init__( + device=device, + n_samples=1000, # irrelevant + ) + + @Alias("X_tab", "X") + def explain( + self, + model: WideDeep, + X_tab: np.ndarray, + num_workers: int, + batch_size: Optional[int] = None, + ) -> np.ndarray: + model.eval() + + self.model_type = self._model_type(model) + + loader = DataLoader( + dataset=WideDeepDataset(X_tab=X_tab), + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + + batch_feat_imp: Any = [] + for _, data in enumerate(loader): + X = data["deeptabular"].to(self.device) + _ = model.deeptabular(X) + + feat_imp, col_idx = self._feature_importance(model) + + batch_feat_imp.append(feat_imp) + + return torch.cat(batch_feat_imp).cpu().detach().numpy() + + +ModelFeatureImportance = Union[ + TabNetFeatureImportance, TransformerBasedFeatureImportance +] +ModelExplainer = Union[TabNetExplainer, TransformerBasedExplainer] diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index cd005071..5dd45ae3 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -8,7 +8,6 @@ import torch.nn.functional as F from tqdm import trange from torch import nn -from scipy.sparse import csc_matrix from torchmetrics import Metric as TorchMetric from torch.utils.data import DataLoader @@ -39,6 +38,10 @@ wd_train_val_split, print_loss_and_metric, ) +from pytorch_widedeep.training._feature_importance import ( + Explainer, + FeatureImportance, +) class Trainer(BaseTrainer): @@ -265,6 +268,7 @@ def fit( # noqa: C901 validation_freq: int = 1, batch_size: int = 32, custom_dataloader: Optional[DataLoader] = None, + feature_importance_sample_size: Optional[int] = None, finetune: bool = False, with_lds: bool = False, **kwargs, @@ -520,15 +524,18 @@ def fit( # noqa: C901 self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric) if self.early_stop: - self.callback_container.on_train_end(epoch_logs) + # self.callback_container.on_train_end(epoch_logs) break if self.model.with_fds: self._update_fds_stats(train_loader, epoch) self.callback_container.on_train_end(epoch_logs) - if self.model.is_tabnet: - self._compute_feature_importance(train_loader) + + if feature_importance_sample_size is not None: + self.feature_importance = FeatureImportance( + self.device, feature_importance_sample_size + ).feature_importance(train_loader, self.model) self._restore_best_weights() self.model.train() @@ -740,64 +747,13 @@ def predict_proba( # type: ignore[return] if self.method == "multiclass": return np.vstack(preds_l) - def explain( - self, X_tab: np.ndarray, save_step_masks: bool = False - ) -> Union[Tuple, np.ndarray]: - """ - if the `deeptabular` component is a `Tabnet` model, returns the - aggregated feature importance for each instance (or observation) in - the `X_tab` array. If `save_step_masks` is set to `True`, the - masks per step will also be returned. - - Parameters - ---------- - X_tab: np.ndarray - Input array corresponding **only** to the deeptabular component - save_step_masks: bool - Boolean indicating if the masks per step will be returned - - Returns - ------- - Union[Tuple, np.ndarray] - Array or Tuple of two arrays with the corresponding aggregated - feature importance and the masks per step if `save_step_masks` - is set to `True` - """ - loader = DataLoader( - dataset=WideDeepDataset(X_tab=X_tab), - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + def explain(self, X_tab: np.ndarray, save_step_masks: Optional[bool] = None): + # TO DO: Add docs to this, to the feat imp parameter and the all + # related classes + explainer = Explainer(self.device) - self.model.eval() - tabnet_backbone = list(self.model.deeptabular.children())[0] - - m_explain_l = [] - for batch_nb, data in enumerate(loader): - X = data["deeptabular"].to(self.device) - M_explain, masks = tabnet_backbone.forward_masks(X) # type: ignore[operator] - m_explain_l.append( - csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix) - ) - if save_step_masks: - for key, value in masks.items(): - masks[key] = csc_matrix.dot( - value.cpu().detach().numpy(), self.reducing_matrix - ) - if batch_nb == 0: - m_explain_step = masks - else: - for key, value in masks.items(): - m_explain_step[key] = np.vstack([m_explain_step[key], value]) - - m_explain_agg = np.vstack(m_explain_l) - m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis] - - res: Union[Tuple, np.ndarray] = ( - (m_explain_agg_norm, m_explain_step) - if save_step_masks - else np.vstack(m_explain_agg_norm) + res = explainer.explain( + self.model, X_tab, self.num_workers, self.batch_size, save_step_masks ) return res @@ -1069,28 +1025,6 @@ def _update_fds_stats(self, train_loader: DataLoader, epoch: int): self.model.fds_layer.update_last_epoch_stats(epoch) self.model.fds_layer.update_running_stats(features, y_pred, epoch) - def _compute_feature_importance(self, loader: DataLoader): - self.model.eval() - tabnet_backbone = list(self.model.deeptabular.children())[0] - feat_imp = np.zeros((tabnet_backbone.embed_out_dim)) # type: ignore[arg-type] - for data, target, _ in loader: - X = data["deeptabular"].to(self.device) - y = ( - target.view(-1, 1).float() - if self.method not in ["multiclass", "qregression"] - else target - ) - y = y.to(self.device) - M_explain, masks = tabnet_backbone.forward_masks(X) # type: ignore[operator] - feat_imp += M_explain.sum(dim=0).cpu().detach().numpy() - - feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix) - feat_imp = feat_imp / np.sum(feat_imp) - - self.feature_importance = { - k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp) # type: ignore[operator, union-attr] - } - def _predict( # noqa: C901 self, X_wide: Optional[np.ndarray] = None, diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py index 10aa336c..67bc602a 100644 --- a/pytorch_widedeep/version.py +++ b/pytorch_widedeep/version.py @@ -1 +1 @@ -__version__ = "1.2.3" +__version__ = "1.3.0" diff --git a/setup.py b/setup.py index 8213c3a1..504f09b3 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ def requirements(fname): "1.0": "Development Status :: 5 - Production/Stable", # v1.0 - most functionality + doc + test # noqa "1.1": "Development Status :: 5 - Production/Stable", # v1.1 - new functionality "1.2": "Development Status :: 5 - Production/Stable", # v1.2 - new functionality + "1.3": "Development Status :: 5 - Production/Stable", # v1.3 - new functionality "2.0": "Development Status :: 6 - Mature", # v2.0 - new functionality? } diff --git a/tests/test_feature_importance/test_transformers_feat_imp.py b/tests/test_feature_importance/test_transformers_feat_imp.py new file mode 100644 index 00000000..f176400a --- /dev/null +++ b/tests/test_feature_importance/test_transformers_feat_imp.py @@ -0,0 +1,187 @@ +import numpy as np +import pandas as pd +import pytest + +from pytorch_widedeep import Trainer +from pytorch_widedeep.models import ( + SAINT, + TabNet, + WideDeep, + FTTransformer, + TabFastFormer, + TabTransformer, + SelfAttentionMLP, + ContextAttentionMLP, +) +from pytorch_widedeep.preprocessing import TabPreprocessor + +np.random.seed(42) + +# Define the column names +cat_cols = ["cat1", "cat2", "cat3", "cat4"] +cont_cols = ["cont1", "cont2", "cont3", "cont4"] +columns = cat_cols + cont_cols + +# Generate random categorical data +categorical_data = np.random.choice(["A", "B", "C"], size=(32, 4)) + +# Generate random numerical data +numerical_data = np.random.randn(32, 4) + +# Create the DataFrame +data = np.concatenate((categorical_data, numerical_data), axis=1) +df = pd.DataFrame(data, columns=columns) +target = np.random.choice(2, 32) + +df_tr = df[:16].copy() +df_te = df[16:].copy().reset_index(drop=True) + +y_tr = target[:16] +y_te = target[16:] + +# ############################TESTS BEGIN ####################################### + + +def _build_model_for_feat_imp_test(model_name, params): + if model_name == "tabtransformer": + return TabTransformer( + input_dim=6, n_blocks=2, n_heads=2, embed_continuous=True, **params + ) + if model_name == "saint": + return SAINT(input_dim=6, n_blocks=2, n_heads=2, **params) + if model_name == "fttransformer": + return FTTransformer( + input_dim=6, n_blocks=2, n_heads=2, kv_compression_factor=1.0, **params + ) + if model_name == "tabfastformer": + return TabFastFormer(input_dim=6, n_blocks=2, n_heads=2, **params) + if model_name == "self_attn_mlp": + return SelfAttentionMLP(input_dim=6, n_blocks=2, n_heads=2, **params) + if model_name == "cxt_attn_mlp": + return ContextAttentionMLP(input_dim=6, n_blocks=2, **params) + + +@pytest.mark.parametrize("with_cls_token", [True, False]) +@pytest.mark.parametrize( + "model_name", + [ + "tabtransformer", + "saint", + "fttransformer", + "tabfastformer", + "self_attn_mlp", + "cxt_attn_mlp", + ], +) +def test_feature_importances(with_cls_token, model_name): + tab_preprocessor = TabPreprocessor( + cat_embed_cols=cat_cols, + continuous_cols=cont_cols, + with_attention=True, + with_cls_token=with_cls_token, + ) + X_tr = tab_preprocessor.fit_transform(df_tr).astype(float) + X_te = tab_preprocessor.transform(df_te).astype(float) + + params = { + "column_idx": tab_preprocessor.column_idx, + "cat_embed_input": tab_preprocessor.cat_embed_input, + "continuous_cols": tab_preprocessor.continuous_cols, + } + + tab_model = _build_model_for_feat_imp_test(model_name, params) + + model = WideDeep(deeptabular=tab_model) + + trainer = Trainer( + model, + objective="binary", + ) + + trainer.fit( + X_tab=X_tr, + target=target, + n_epochs=1, + batch_size=16, + feature_importance_sample_size=32, + ) + + feat_imps = trainer.feature_importance + feat_imp_per_sample = trainer.explain(X_te) + + assert len(feat_imps) == df_tr.shape[1] and feat_imp_per_sample.shape == df_te.shape + + +def test_fttransformer_valueerror(): + tab_preprocessor = TabPreprocessor( + cat_embed_cols=cat_cols, + continuous_cols=cont_cols, + with_attention=True, + ) + X_tr = tab_preprocessor.fit_transform(df_tr).astype(float) + + params = { + "column_idx": tab_preprocessor.column_idx, + "cat_embed_input": tab_preprocessor.cat_embed_input, + "continuous_cols": tab_preprocessor.continuous_cols, + } + + model = FTTransformer( + input_dim=6, n_blocks=2, n_heads=2, kv_compression_factor=0.5, **params + ) + model = WideDeep(deeptabular=model) + + trainer = Trainer( + model, + objective="binary", + ) + + with pytest.raises(ValueError) as ve: + trainer.fit( + X_tab=X_tr, + target=target, + n_epochs=1, + batch_size=16, + feature_importance_sample_size=32, + ) + + assert ( + ve.value.args[0] + == "Feature importance can only be computed if the compression factor 'kv_compression_factor' is set to 1" + ) + + +def test_feature_importances_tabnet(): + tab_preprocessor = TabPreprocessor( + cat_embed_cols=cat_cols, + continuous_cols=cont_cols, + ) + X_tr = tab_preprocessor.fit_transform(df_tr).astype(float) + X_te = tab_preprocessor.transform(df_te).astype(float) + + tabnet = TabNet( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + continuous_cols=tab_preprocessor.continuous_cols, + embed_continuous=True, + ) + + model = WideDeep(deeptabular=tabnet) + + trainer = Trainer( + model, + objective="binary", + ) + + trainer.fit( + X_tab=X_tr, + target=target, + n_epochs=1, + batch_size=16, + feature_importance_sample_size=32, + ) + + feat_imps = trainer.feature_importance + feat_imp_per_sample = trainer.explain(X_te, save_step_masks=False) + + assert len(feat_imps) == df_tr.shape[1] and feat_imp_per_sample.shape == df_te.shape diff --git a/tests/test_model_functioning/test_miscellaneous.py b/tests/test_model_functioning/test_miscellaneous.py index e93c313b..3b356d9d 100644 --- a/tests/test_model_functioning/test_miscellaneous.py +++ b/tests/test_model_functioning/test_miscellaneous.py @@ -287,32 +287,32 @@ def test_save_and_load_dict(): assert same_weights and history_saved -############################################################################### -# test explain matrices and feature importance for TabNet -############################################################################### +# ############################################################################### +# # test explain matrices and feature importance for TabNet (Moved to 'test_feature_importance') +# ############################################################################### -def test_explain_mtx_and_feat_imp(): - model = WideDeep(deeptabular=tabnet) - trainer = Trainer(model, objective="binary", verbose=0) - trainer.fit( - X_tab=X_tab, - target=target, - batch_size=16, - ) +# def test_explain_mtx_and_feat_imp(): +# model = WideDeep(deeptabular=tabnet) +# trainer = Trainer(model, objective="binary", verbose=0) +# trainer.fit( +# X_tab=X_tab, +# target=target, +# batch_size=16, +# ) - checks = [] - checks.append(len(trainer.feature_importance) == len(tabnet.column_idx)) +# checks = [] +# checks.append(len(trainer.feature_importance) == len(tabnet.column_idx)) - expl_mtx, step_masks = trainer.explain(X_tab[:6], save_step_masks=True) - checks.append(expl_mtx.shape[0] == 6) - checks.append(expl_mtx.shape[1] == 10) +# expl_mtx, step_masks = trainer.explain(X_tab[:6], save_step_masks=True) +# checks.append(expl_mtx.shape[0] == 6) +# checks.append(expl_mtx.shape[1] == 10) - for i in range(tabnet.n_steps): - checks.append(step_masks[i].shape[0] == 6) - checks.append(step_masks[i].shape[1] == 10) +# for i in range(tabnet.n_steps): +# checks.append(step_masks[i].shape[0] == 6) +# checks.append(step_masks[i].shape[1] == 10) - assert all(checks) +# assert all(checks) ###############################################################################