diff --git a/VERSION b/VERSION
index 0495c4a8..f0bb29e7 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.2.3
+1.3.0
diff --git a/examples/notebooks/18_feature_importance_via_attention_weights.ipynb b/examples/notebooks/18_feature_importance_via_attention_weights.ipynb
new file mode 100644
index 00000000..6ae3f447
--- /dev/null
+++ b/examples/notebooks/18_feature_importance_via_attention_weights.ipynb
@@ -0,0 +1,1151 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0e4d6c56",
+ "metadata": {},
+ "source": [
+ "### Feature Importance via the attention weights\n",
+ "\n",
+ "I will start by saying that I consider this feature of the library purely experimental. First of all, I think there are multiple ways one could approach finding the feature importances for these models. However, and more importantly, one has to bear in mind that even tree-based algorithms on the same dataset produce different feature importances. This is more \"dramatic\" if one uses different techniques, such as SHAP or feature permutation (see for example [this](https://reneelin2019.medium.com/calculating-feature-importance-with-permutation-to-explain-the-model-income-prediction-example-38a52e67441d) and references therein). All this to say that, sometimes, feature importance is just a measure that is meaningful only within the experiment run and for the model used.\n",
+ "\n",
+ "With that in mind, each instantiation of a deep tabular model, which has millions of trainable parameters, will potentially produce a different set of feature importances, even if the model has the same architecture. Moreover, this effect will become more apparent if the dataset is relatively easy and there are dependent/related columns, so that one could reach the same success metric with different parameters.\n",
+ "\n",
+ "In summary, feature importances are implemented in this library for all attention-based models for tabular data, with the exception of the `TabPerceiver`. However, this functionality has to be used and interpreted with care, and the results should be considered of value only within the 'universe' (or context) of the model with which these feature importances were produced.\n",
+ "\n",
+ "Nonetheless, let's have a look at how one would access the feature importances when using this library."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "365df643",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "import numpy as np\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import accuracy_score\n",
+ "\n",
+ "\n",
+ "from pytorch_widedeep import Trainer\n",
+ "from pytorch_widedeep.models import TabTransformer, ContextAttentionMLP, WideDeep\n",
+ "from pytorch_widedeep.callbacks import EarlyStopping\n",
+ "from pytorch_widedeep.metrics import Accuracy\n",
+ "from pytorch_widedeep.datasets import load_adult\n",
+ "from pytorch_widedeep.preprocessing import TabPreprocessor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "0f935ef4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# use_cuda = torch.cuda.is_available()\n",
+ "df = load_adult(as_frame=True)\n",
+ "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
+ "df[\"income_label\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
+ "df.drop([\"income\", \"fnlwgt\", \"educational_num\"], axis=1, inplace=True)\n",
+ "target_colname = \"income_label\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "202e7377",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " age | \n",
+ " workclass | \n",
+ " education | \n",
+ " marital_status | \n",
+ " occupation | \n",
+ " relationship | \n",
+ " race | \n",
+ " gender | \n",
+ " capital_gain | \n",
+ " capital_loss | \n",
+ " hours_per_week | \n",
+ " native_country | \n",
+ " income_label | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 25 | \n",
+ " Private | \n",
+ " 11th | \n",
+ " Never-married | \n",
+ " Machine-op-inspct | \n",
+ " Own-child | \n",
+ " Black | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 40 | \n",
+ " United-States | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 38 | \n",
+ " Private | \n",
+ " HS-grad | \n",
+ " Married-civ-spouse | \n",
+ " Farming-fishing | \n",
+ " Husband | \n",
+ " White | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 50 | \n",
+ " United-States | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 28 | \n",
+ " Local-gov | \n",
+ " Assoc-acdm | \n",
+ " Married-civ-spouse | \n",
+ " Protective-serv | \n",
+ " Husband | \n",
+ " White | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 40 | \n",
+ " United-States | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 44 | \n",
+ " Private | \n",
+ " Some-college | \n",
+ " Married-civ-spouse | \n",
+ " Machine-op-inspct | \n",
+ " Husband | \n",
+ " Black | \n",
+ " Male | \n",
+ " 7688 | \n",
+ " 0 | \n",
+ " 40 | \n",
+ " United-States | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 18 | \n",
+ " ? | \n",
+ " Some-college | \n",
+ " Never-married | \n",
+ " ? | \n",
+ " Own-child | \n",
+ " White | \n",
+ " Female | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 30 | \n",
+ " United-States | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age workclass education marital_status occupation \\\n",
+ "0 25 Private 11th Never-married Machine-op-inspct \n",
+ "1 38 Private HS-grad Married-civ-spouse Farming-fishing \n",
+ "2 28 Local-gov Assoc-acdm Married-civ-spouse Protective-serv \n",
+ "3 44 Private Some-college Married-civ-spouse Machine-op-inspct \n",
+ "4 18 ? Some-college Never-married ? \n",
+ "\n",
+ " relationship race gender capital_gain capital_loss hours_per_week \\\n",
+ "0 Own-child Black Male 0 0 40 \n",
+ "1 Husband White Male 0 0 50 \n",
+ "2 Husband White Male 0 0 40 \n",
+ "3 Husband Black Male 7688 0 40 \n",
+ "4 Own-child White Female 0 0 30 \n",
+ "\n",
+ " native_country income_label \n",
+ "0 United-States 0 \n",
+ "1 United-States 0 \n",
+ "2 United-States 1 \n",
+ "3 United-States 1 \n",
+ "4 United-States 0 "
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "4264cba5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cat_embed_cols = []\n",
+ "for col in df.columns:\n",
+ " if df[col].dtype == \"O\" or df[col].nunique() < 200 and col != target_colname:\n",
+ " cat_embed_cols.append(col)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "48ae0977",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# all cols will be categorical\n",
+ "assert len(cat_embed_cols) == df.shape[1] - 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "0abe7bb4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train, test = train_test_split(\n",
+ " df, test_size=0.1, random_state=1, stratify=df[[target_colname]]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "a762aee5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols, with_attention=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "a233f9a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_tab_train = tab_preprocessor.fit_transform(train)\n",
+ "X_tab_test = tab_preprocessor.transform(test)\n",
+ "target = train[target_colname].values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "c17233fe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tab_transformer = TabTransformer(\n",
+ " column_idx=tab_preprocessor.column_idx,\n",
+ " cat_embed_input=tab_preprocessor.cat_embed_input,\n",
+ " cat_embed_dropout=0.0,\n",
+ " input_dim=8,\n",
+ " n_heads=2,\n",
+ " n_blocks=1,\n",
+ " attn_dropout=0.1,\n",
+ " transformer_activation=\"relu\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "5010c08d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = WideDeep(deeptabular=tab_transformer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "145803d9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "24133088",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
+ " optimizer,\n",
+ " threshold=0.001,\n",
+ " threshold_mode=\"abs\",\n",
+ " patience=10,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "06330871",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "early_stopping = EarlyStopping(\n",
+ " min_delta=0.001, patience=30, restore_best_weights=True, verbose=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "405befbe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model,\n",
+ " objective=\"binary\",\n",
+ " optimizers=optimizer,\n",
+ " lr_schedulers=lr_scheduler,\n",
+ " reducelronplateau_criterion=\"loss\",\n",
+ " callbacks=[early_stopping],\n",
+ " metrics=[Accuracy],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "997a5283",
+ "metadata": {},
+ "source": [
+ "The feature importances will be computed after training, using a sample of the training dataset of size `feature_importance_sample_size`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "459d0f69",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "epoch 1: 100%|████| 275/275 [00:03<00:00, 90.98it/s, loss=0.332, metrics={'acc': 0.847}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 148.31it/s, loss=0.292, metrics={'acc': 0.866}]\n",
+ "epoch 2: 100%|████| 275/275 [00:02<00:00, 96.35it/s, loss=0.289, metrics={'acc': 0.868}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 147.03it/s, loss=0.278, metrics={'acc': 0.8717}]\n",
+ "epoch 3: 100%|████| 275/275 [00:02<00:00, 94.23it/s, loss=0.28, metrics={'acc': 0.8719}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 139.13it/s, loss=0.275, metrics={'acc': 0.8732}]\n",
+ "epoch 4: 100%|████| 275/275 [00:03<00:00, 90.15it/s, loss=0.276, metrics={'acc': 0.872}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 133.21it/s, loss=0.275, metrics={'acc': 0.8706}]\n",
+ "epoch 5: 100%|███| 275/275 [00:03<00:00, 86.75it/s, loss=0.274, metrics={'acc': 0.8736}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 132.85it/s, loss=0.275, metrics={'acc': 0.8717}]\n",
+ "epoch 6: 100%|███| 275/275 [00:03<00:00, 89.16it/s, loss=0.272, metrics={'acc': 0.8742}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 129.50it/s, loss=0.273, metrics={'acc': 0.8733}]\n",
+ "epoch 7: 100%|███| 275/275 [00:03<00:00, 88.98it/s, loss=0.271, metrics={'acc': 0.8748}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 129.81it/s, loss=0.273, metrics={'acc': 0.8715}]\n",
+ "epoch 8: 100%|████| 275/275 [00:03<00:00, 88.63it/s, loss=0.27, metrics={'acc': 0.8748}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 129.81it/s, loss=0.272, metrics={'acc': 0.8739}]\n",
+ "epoch 9: 100%|███| 275/275 [00:03<00:00, 88.68it/s, loss=0.269, metrics={'acc': 0.8762}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 129.69it/s, loss=0.271, metrics={'acc': 0.8741}]\n",
+ "epoch 10: 100%|██| 275/275 [00:03<00:00, 89.56it/s, loss=0.267, metrics={'acc': 0.8761}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 132.71it/s, loss=0.271, metrics={'acc': 0.874}]\n",
+ "epoch 11: 100%|██| 275/275 [00:03<00:00, 81.74it/s, loss=0.267, metrics={'acc': 0.8757}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 120.83it/s, loss=0.272, metrics={'acc': 0.8744}]\n",
+ "epoch 12: 100%|██| 275/275 [00:03<00:00, 75.27it/s, loss=0.266, metrics={'acc': 0.8768}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 113.89it/s, loss=0.272, metrics={'acc': 0.8741}]\n",
+ "epoch 13: 100%|██| 275/275 [00:03<00:00, 76.29it/s, loss=0.265, metrics={'acc': 0.8787}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 113.24it/s, loss=0.273, metrics={'acc': 0.8726}]\n",
+ "epoch 14: 100%|██| 275/275 [00:03<00:00, 76.86it/s, loss=0.265, metrics={'acc': 0.8777}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 114.75it/s, loss=0.273, metrics={'acc': 0.8741}]\n",
+ "epoch 15: 100%|██| 275/275 [00:03<00:00, 72.97it/s, loss=0.264, metrics={'acc': 0.8779}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 115.13it/s, loss=0.273, metrics={'acc': 0.8735}]\n",
+ "epoch 16: 100%|██| 275/275 [00:03<00:00, 71.00it/s, loss=0.263, metrics={'acc': 0.8789}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 159.56it/s, loss=0.274, metrics={'acc': 0.8718}]\n",
+ "epoch 17: 100%|██| 275/275 [00:02<00:00, 98.42it/s, loss=0.263, metrics={'acc': 0.8789}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 172.29it/s, loss=0.274, metrics={'acc': 0.8727}]\n",
+ "epoch 18: 100%|█| 275/275 [00:02<00:00, 111.46it/s, loss=0.262, metrics={'acc': 0.8786}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 172.87it/s, loss=0.274, metrics={'acc': 0.8737}]\n",
+ "epoch 19: 100%|█| 275/275 [00:02<00:00, 109.79it/s, loss=0.261, metrics={'acc': 0.8785}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 167.10it/s, loss=0.275, metrics={'acc': 0.8737}]\n",
+ "epoch 20: 100%|██| 275/275 [00:02<00:00, 99.87it/s, loss=0.261, metrics={'acc': 0.8789}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 156.25it/s, loss=0.275, metrics={'acc': 0.8724}]\n",
+ "epoch 21: 100%|██| 275/275 [00:03<00:00, 83.22it/s, loss=0.256, metrics={'acc': 0.8813}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 107.03it/s, loss=0.274, metrics={'acc': 0.8727}]\n",
+ "epoch 22: 100%|██| 275/275 [00:03<00:00, 79.09it/s, loss=0.255, metrics={'acc': 0.8832}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 140.85it/s, loss=0.275, metrics={'acc': 0.8733}]\n",
+ "epoch 23: 100%|██| 275/275 [00:02<00:00, 99.53it/s, loss=0.254, metrics={'acc': 0.8822}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 128.29it/s, loss=0.276, metrics={'acc': 0.8724}]\n",
+ "epoch 24: 100%|██| 275/275 [00:03<00:00, 86.39it/s, loss=0.253, metrics={'acc': 0.8846}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 112.40it/s, loss=0.276, metrics={'acc': 0.8724}]\n",
+ "epoch 25: 100%|███| 275/275 [00:03<00:00, 75.98it/s, loss=0.253, metrics={'acc': 0.884}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 123.45it/s, loss=0.276, metrics={'acc': 0.8715}]\n",
+ "epoch 26: 100%|██| 275/275 [00:03<00:00, 80.57it/s, loss=0.253, metrics={'acc': 0.8839}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 134.19it/s, loss=0.277, metrics={'acc': 0.8706}]\n",
+ "epoch 27: 100%|██| 275/275 [00:03<00:00, 83.52it/s, loss=0.253, metrics={'acc': 0.8838}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 120.70it/s, loss=0.277, metrics={'acc': 0.8709}]\n",
+ "epoch 28: 100%|██| 275/275 [00:03<00:00, 87.30it/s, loss=0.252, metrics={'acc': 0.8837}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 136.44it/s, loss=0.277, metrics={'acc': 0.8706}]\n",
+ "epoch 29: 100%|██| 275/275 [00:03<00:00, 84.86it/s, loss=0.252, metrics={'acc': 0.8847}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 126.91it/s, loss=0.278, metrics={'acc': 0.8698}]\n",
+ "epoch 30: 100%|██| 275/275 [00:03<00:00, 83.93it/s, loss=0.252, metrics={'acc': 0.8837}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 130.11it/s, loss=0.277, metrics={'acc': 0.8715}]\n",
+ "epoch 31: 100%|██| 275/275 [00:03<00:00, 81.44it/s, loss=0.252, metrics={'acc': 0.8824}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 120.13it/s, loss=0.278, metrics={'acc': 0.8711}]\n",
+ "epoch 32: 100%|███| 275/275 [00:03<00:00, 81.13it/s, loss=0.25, metrics={'acc': 0.8843}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 109.88it/s, loss=0.278, metrics={'acc': 0.8705}]\n",
+ "epoch 33: 100%|███| 275/275 [00:03<00:00, 69.72it/s, loss=0.25, metrics={'acc': 0.8838}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 118.55it/s, loss=0.278, metrics={'acc': 0.8699}]\n",
+ "epoch 34: 100%|██| 275/275 [00:03<00:00, 85.04it/s, loss=0.251, metrics={'acc': 0.8843}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 136.37it/s, loss=0.278, metrics={'acc': 0.8698}]\n",
+ "epoch 35: 100%|███| 275/275 [00:02<00:00, 92.70it/s, loss=0.25, metrics={'acc': 0.8841}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 145.48it/s, loss=0.278, metrics={'acc': 0.8697}]\n",
+ "epoch 36: 100%|██| 275/275 [00:02<00:00, 95.07it/s, loss=0.251, metrics={'acc': 0.8843}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 144.44it/s, loss=0.278, metrics={'acc': 0.87}]\n",
+ "epoch 37: 100%|███| 275/275 [00:02<00:00, 93.25it/s, loss=0.25, metrics={'acc': 0.8837}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 144.45it/s, loss=0.278, metrics={'acc': 0.8701}]\n",
+ "epoch 38: 100%|███| 275/275 [00:03<00:00, 91.39it/s, loss=0.25, metrics={'acc': 0.8843}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 139.95it/s, loss=0.278, metrics={'acc': 0.8702}]\n",
+ "epoch 39: 100%|███| 275/275 [00:02<00:00, 92.64it/s, loss=0.25, metrics={'acc': 0.8847}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 140.00it/s, loss=0.278, metrics={'acc': 0.8703}]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best Epoch: 9. Best val_loss: 0.27098\n",
+ "Restoring model weights from the end of the best epoch\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer.fit(\n",
+ " X_tab=X_tab_train,\n",
+ " target=target,\n",
+ " val_split=0.2,\n",
+ " n_epochs=100,\n",
+ " batch_size=128,\n",
+ " validation_freq=1,\n",
+ " feature_importance_sample_size=1000,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "cd603a44",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'age': 0.098023,\n",
+ " 'workclass': 0.07621125,\n",
+ " 'education': 0.07414728,\n",
+ " 'marital_status': 0.113280274,\n",
+ " 'occupation': 0.07292068,\n",
+ " 'relationship': 0.08008792,\n",
+ " 'race': 0.104180396,\n",
+ " 'gender': 0.07037963,\n",
+ " 'capital_gain': 0.06584223,\n",
+ " 'capital_loss': 0.07647487,\n",
+ " 'hours_per_week': 0.09369389,\n",
+ " 'native_country': 0.0747586}"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.feature_importance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "8342dfdc",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "predict: 100%|█████████████████████████████████████████| 39/39 [00:00<00:00, 213.15it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "preds = trainer.predict(X_tab=X_tab_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "7cfdabda",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.8734902763561925"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "accuracy_score(preds, test.income_label)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "5b49516b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test.reset_index(drop=True, inplace=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "6ad4b3b6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " age | \n",
+ " workclass | \n",
+ " education | \n",
+ " marital_status | \n",
+ " occupation | \n",
+ " relationship | \n",
+ " race | \n",
+ " gender | \n",
+ " capital_gain | \n",
+ " capital_loss | \n",
+ " hours_per_week | \n",
+ " native_country | \n",
+ " income_label | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 26 | \n",
+ " Private | \n",
+ " Some-college | \n",
+ " Never-married | \n",
+ " Exec-managerial | \n",
+ " Not-in-family | \n",
+ " White | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 60 | \n",
+ " United-States | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age workclass education marital_status occupation relationship \\\n",
+ "0 26 Private Some-college Never-married Exec-managerial Not-in-family \n",
+ "\n",
+ " race gender capital_gain capital_loss hours_per_week native_country \\\n",
+ "0 White Male 0 0 60 United-States \n",
+ "\n",
+ " income_label \n",
+ "0 0 "
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test[test.income_label == 0].head(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "6419edff",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " age | \n",
+ " workclass | \n",
+ " education | \n",
+ " marital_status | \n",
+ " occupation | \n",
+ " relationship | \n",
+ " race | \n",
+ " gender | \n",
+ " capital_gain | \n",
+ " capital_loss | \n",
+ " hours_per_week | \n",
+ " native_country | \n",
+ " income_label | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 3 | \n",
+ " 36 | \n",
+ " Local-gov | \n",
+ " Doctorate | \n",
+ " Married-civ-spouse | \n",
+ " Prof-specialty | \n",
+ " Husband | \n",
+ " White | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 1887 | \n",
+ " 50 | \n",
+ " United-States | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age workclass education marital_status occupation relationship \\\n",
+ "3 36 Local-gov Doctorate Married-civ-spouse Prof-specialty Husband \n",
+ "\n",
+ " race gender capital_gain capital_loss hours_per_week native_country \\\n",
+ "3 White Male 0 1887 50 United-States \n",
+ "\n",
+ " income_label \n",
+ "3 1 "
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test[test.income_label == 1].head(1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a6860689",
+ "metadata": {},
+ "source": [
+ "To get the feature importances for each sample of a test dataset, simply use the `explain` method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "ba6432ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "feat_imp_per_sample = trainer.explain(X_tab_test, save_step_masks=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "f350e822",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['marital_status',\n",
+ " 'race',\n",
+ " 'age',\n",
+ " 'capital_loss',\n",
+ " 'occupation',\n",
+ " 'native_country',\n",
+ " 'workclass',\n",
+ " 'education',\n",
+ " 'gender',\n",
+ " 'relationship',\n",
+ " 'hours_per_week',\n",
+ " 'capital_gain']"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(test.iloc[0].index[np.argsort(-feat_imp_per_sample[0])])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "f05a8302",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['marital_status',\n",
+ " 'race',\n",
+ " 'capital_loss',\n",
+ " 'occupation',\n",
+ " 'education',\n",
+ " 'native_country',\n",
+ " 'hours_per_week',\n",
+ " 'relationship',\n",
+ " 'age',\n",
+ " 'workclass',\n",
+ " 'gender',\n",
+ " 'capital_gain']"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(test.iloc[3].index[np.argsort(-feat_imp_per_sample[3])])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "558fc30f",
+ "metadata": {},
+ "source": [
+ "We could do the same with the `ContextAttentionMLP`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "45b8a869",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "context_attn_mlp = ContextAttentionMLP(\n",
+ " column_idx=tab_preprocessor.column_idx,\n",
+ " cat_embed_input=tab_preprocessor.cat_embed_input,\n",
+ " cat_embed_dropout=0.0,\n",
+ " input_dim=16,\n",
+ " attn_dropout=0.1,\n",
+ " attn_activation=\"relu\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "07f0c434",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlp_model = WideDeep(deeptabular=context_attn_mlp)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "d9c1ef75",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.01, weight_decay=0.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "4c2f21ca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlp_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
+ " mlp_optimizer,\n",
+ " threshold=0.001,\n",
+ " threshold_mode=\"abs\",\n",
+ " patience=10,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "504c1755",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlp_early_stopping = EarlyStopping(\n",
+ " min_delta=0.001, patience=30, restore_best_weights=True, verbose=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "78c79a3d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mlp_trainer = Trainer(\n",
+ " mlp_model,\n",
+ " objective=\"binary\",\n",
+ " optimizers=mlp_optimizer,\n",
+ " lr_schedulers=mlp_lr_scheduler,\n",
+ " reducelronplateau_criterion=\"loss\",\n",
+ " callbacks=[mlp_early_stopping],\n",
+ " metrics=[Accuracy],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "af40082b",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "epoch 1: 100%|███| 275/275 [00:03<00:00, 91.33it/s, loss=0.395, metrics={'acc': 0.8139}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 125.66it/s, loss=0.306, metrics={'acc': 0.8577}]\n",
+ "epoch 2: 100%|███| 275/275 [00:03<00:00, 87.31it/s, loss=0.333, metrics={'acc': 0.8396}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 140.84it/s, loss=0.291, metrics={'acc': 0.8631}]\n",
+ "epoch 3: 100%|███| 275/275 [00:03<00:00, 84.57it/s, loss=0.323, metrics={'acc': 0.8494}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 141.55it/s, loss=0.293, metrics={'acc': 0.8632}]\n",
+ "epoch 4: 100%|███| 275/275 [00:03<00:00, 84.62it/s, loss=0.312, metrics={'acc': 0.8518}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 156.92it/s, loss=0.3, metrics={'acc': 0.8543}]\n",
+ "epoch 5: 100%|████| 275/275 [00:03<00:00, 85.74it/s, loss=0.31, metrics={'acc': 0.8552}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 130.66it/s, loss=0.303, metrics={'acc': 0.8545}]\n",
+ "epoch 6: 100%|███| 275/275 [00:02<00:00, 93.70it/s, loss=0.303, metrics={'acc': 0.8579}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 161.71it/s, loss=0.291, metrics={'acc': 0.8609}]\n",
+ "epoch 7: 100%|███| 275/275 [00:03<00:00, 81.87it/s, loss=0.302, metrics={'acc': 0.8584}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 126.83it/s, loss=0.294, metrics={'acc': 0.8579}]\n",
+ "epoch 8: 100%|█████| 275/275 [00:04<00:00, 62.67it/s, loss=0.299, metrics={'acc': 0.86}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 161.52it/s, loss=0.284, metrics={'acc': 0.8653}]\n",
+ "epoch 9: 100%|███| 275/275 [00:03<00:00, 85.13it/s, loss=0.296, metrics={'acc': 0.8615}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 153.13it/s, loss=0.282, metrics={'acc': 0.8672}]\n",
+ "epoch 10: 100%|███| 275/275 [00:02<00:00, 92.61it/s, loss=0.296, metrics={'acc': 0.861}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 156.41it/s, loss=0.281, metrics={'acc': 0.8716}]\n",
+ "epoch 11: 100%|██| 275/275 [00:02<00:00, 99.43it/s, loss=0.295, metrics={'acc': 0.8619}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 155.07it/s, loss=0.283, metrics={'acc': 0.8649}]\n",
+ "epoch 12: 100%|██| 275/275 [00:03<00:00, 83.91it/s, loss=0.294, metrics={'acc': 0.8603}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 108.25it/s, loss=0.285, metrics={'acc': 0.8652}]\n",
+ "epoch 13: 100%|██| 275/275 [00:03<00:00, 77.13it/s, loss=0.292, metrics={'acc': 0.8631}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 116.60it/s, loss=0.284, metrics={'acc': 0.8656}]\n",
+ "epoch 14: 100%|██| 275/275 [00:03<00:00, 81.83it/s, loss=0.294, metrics={'acc': 0.8622}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 126.96it/s, loss=0.283, metrics={'acc': 0.8665}]\n",
+ "epoch 15: 100%|██| 275/275 [00:03<00:00, 81.45it/s, loss=0.292, metrics={'acc': 0.8621}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 128.30it/s, loss=0.284, metrics={'acc': 0.8664}]\n",
+ "epoch 16: 100%|██| 275/275 [00:03<00:00, 81.18it/s, loss=0.293, metrics={'acc': 0.8626}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 129.63it/s, loss=0.283, metrics={'acc': 0.8659}]\n",
+ "epoch 17: 100%|██| 275/275 [00:03<00:00, 76.42it/s, loss=0.292, metrics={'acc': 0.8619}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 112.97it/s, loss=0.285, metrics={'acc': 0.8658}]\n",
+ "epoch 18: 100%|██| 275/275 [00:03<00:00, 71.98it/s, loss=0.294, metrics={'acc': 0.8609}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 117.61it/s, loss=0.279, metrics={'acc': 0.8691}]\n",
+ "epoch 19: 100%|███| 275/275 [00:03<00:00, 76.80it/s, loss=0.29, metrics={'acc': 0.8652}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 113.99it/s, loss=0.283, metrics={'acc': 0.8707}]\n",
+ "epoch 20: 100%|██| 275/275 [00:03<00:00, 69.47it/s, loss=0.293, metrics={'acc': 0.8613}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 115.32it/s, loss=0.283, metrics={'acc': 0.8646}]\n",
+ "epoch 21: 100%|███| 275/275 [00:03<00:00, 69.08it/s, loss=0.29, metrics={'acc': 0.8636}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 117.65it/s, loss=0.289, metrics={'acc': 0.8621}]\n",
+ "epoch 22: 100%|██| 275/275 [00:03<00:00, 69.49it/s, loss=0.291, metrics={'acc': 0.8626}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 125.45it/s, loss=0.284, metrics={'acc': 0.8675}]\n",
+ "epoch 23: 100%|███| 275/275 [00:03<00:00, 81.16it/s, loss=0.29, metrics={'acc': 0.8637}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 124.96it/s, loss=0.284, metrics={'acc': 0.8649}]\n",
+ "epoch 24: 100%|███| 275/275 [00:03<00:00, 76.37it/s, loss=0.29, metrics={'acc': 0.8646}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 114.71it/s, loss=0.28, metrics={'acc': 0.8678}]\n",
+ "epoch 25: 100%|██| 275/275 [00:03<00:00, 68.76it/s, loss=0.289, metrics={'acc': 0.8634}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 113.49it/s, loss=0.282, metrics={'acc': 0.8675}]\n",
+ "epoch 26: 100%|██| 275/275 [00:03<00:00, 68.98it/s, loss=0.291, metrics={'acc': 0.8624}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 121.90it/s, loss=0.282, metrics={'acc': 0.8668}]\n",
+ "epoch 27: 100%|██| 275/275 [00:03<00:00, 72.16it/s, loss=0.289, metrics={'acc': 0.8648}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 109.51it/s, loss=0.28, metrics={'acc': 0.8674}]\n",
+ "epoch 28: 100%|██| 275/275 [00:03<00:00, 68.83it/s, loss=0.288, metrics={'acc': 0.8653}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 110.77it/s, loss=0.283, metrics={'acc': 0.8637}]\n",
+ "epoch 29: 100%|███| 275/275 [00:04<00:00, 68.63it/s, loss=0.29, metrics={'acc': 0.8655}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 113.82it/s, loss=0.283, metrics={'acc': 0.8643}]\n",
+ "epoch 30: 100%|███| 275/275 [00:04<00:00, 67.75it/s, loss=0.284, metrics={'acc': 0.868}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 106.96it/s, loss=0.276, metrics={'acc': 0.8708}]\n",
+ "epoch 31: 100%|██| 275/275 [00:04<00:00, 61.95it/s, loss=0.284, metrics={'acc': 0.8664}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 97.29it/s, loss=0.276, metrics={'acc': 0.8699}]\n",
+ "epoch 32: 100%|██| 275/275 [00:05<00:00, 53.82it/s, loss=0.284, metrics={'acc': 0.8663}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 87.62it/s, loss=0.275, metrics={'acc': 0.8703}]\n",
+ "epoch 33: 100%|██| 275/275 [00:05<00:00, 52.92it/s, loss=0.282, metrics={'acc': 0.8669}]\n",
+ "valid: 100%|█████████| 69/69 [00:00<00:00, 87.98it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 34: 100%|██| 275/275 [00:05<00:00, 54.57it/s, loss=0.282, metrics={'acc': 0.8666}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 95.37it/s, loss=0.275, metrics={'acc': 0.8708}]\n",
+ "epoch 35: 100%|██| 275/275 [00:04<00:00, 58.35it/s, loss=0.281, metrics={'acc': 0.8701}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 98.99it/s, loss=0.276, metrics={'acc': 0.8693}]\n",
+ "epoch 36: 100%|██| 275/275 [00:04<00:00, 57.91it/s, loss=0.281, metrics={'acc': 0.8673}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 96.97it/s, loss=0.277, metrics={'acc': 0.8689}]\n",
+ "epoch 37: 100%|██| 275/275 [00:04<00:00, 58.82it/s, loss=0.279, metrics={'acc': 0.8687}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 98.79it/s, loss=0.277, metrics={'acc': 0.8676}]\n",
+ "epoch 38: 100%|███| 275/275 [00:04<00:00, 61.86it/s, loss=0.28, metrics={'acc': 0.8707}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 102.77it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 39: 100%|███| 275/275 [00:04<00:00, 61.11it/s, loss=0.281, metrics={'acc': 0.869}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 101.65it/s, loss=0.276, metrics={'acc': 0.8683}]\n",
+ "epoch 40: 100%|███| 275/275 [00:04<00:00, 59.90it/s, loss=0.28, metrics={'acc': 0.8692}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 102.92it/s, loss=0.278, metrics={'acc': 0.8681}]\n",
+ "epoch 41: 100%|██| 275/275 [00:04<00:00, 60.41it/s, loss=0.281, metrics={'acc': 0.8674}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 98.36it/s, loss=0.276, metrics={'acc': 0.8703}]\n",
+ "epoch 42: 100%|██| 275/275 [00:04<00:00, 55.65it/s, loss=0.281, metrics={'acc': 0.8691}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 97.88it/s, loss=0.276, metrics={'acc': 0.8693}]\n",
+ "epoch 43: 100%|███| 275/275 [00:04<00:00, 59.98it/s, loss=0.28, metrics={'acc': 0.8679}]\n",
+ "valid: 100%|███████| 69/69 [00:00<00:00, 98.41it/s, loss=0.276, metrics={'acc': 0.8701}]\n",
+ "epoch 44: 100%|██| 275/275 [00:04<00:00, 59.67it/s, loss=0.279, metrics={'acc': 0.8688}]\n",
+ "valid: 100%|█████████| 69/69 [00:00<00:00, 98.48it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 45: 100%|██| 275/275 [00:04<00:00, 61.62it/s, loss=0.279, metrics={'acc': 0.8694}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 103.90it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 46: 100%|███| 275/275 [00:04<00:00, 61.50it/s, loss=0.28, metrics={'acc': 0.8688}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 102.02it/s, loss=0.276, metrics={'acc': 0.87}]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "epoch 47: 100%|██| 275/275 [00:03<00:00, 70.92it/s, loss=0.279, metrics={'acc': 0.8693}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 121.47it/s, loss=0.276, metrics={'acc': 0.8699}]\n",
+ "epoch 48: 100%|██| 275/275 [00:03<00:00, 75.09it/s, loss=0.277, metrics={'acc': 0.8701}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 126.88it/s, loss=0.276, metrics={'acc': 0.8706}]\n",
+ "epoch 49: 100%|██| 275/275 [00:03<00:00, 74.75it/s, loss=0.279, metrics={'acc': 0.8694}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 130.57it/s, loss=0.276, metrics={'acc': 0.8705}]\n",
+ "epoch 50: 100%|██| 275/275 [00:03<00:00, 81.06it/s, loss=0.278, metrics={'acc': 0.8704}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 137.71it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 51: 100%|██| 275/275 [00:03<00:00, 79.21it/s, loss=0.278, metrics={'acc': 0.8689}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 134.51it/s, loss=0.276, metrics={'acc': 0.8702}]\n",
+ "epoch 52: 100%|███| 275/275 [00:03<00:00, 81.95it/s, loss=0.279, metrics={'acc': 0.869}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 136.85it/s, loss=0.276, metrics={'acc': 0.8701}]\n",
+ "epoch 53: 100%|██| 275/275 [00:03<00:00, 82.84it/s, loss=0.279, metrics={'acc': 0.8702}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 139.43it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 54: 100%|███| 275/275 [00:03<00:00, 84.35it/s, loss=0.28, metrics={'acc': 0.8678}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 139.32it/s, loss=0.276, metrics={'acc': 0.8701}]\n",
+ "epoch 55: 100%|██| 275/275 [00:03<00:00, 82.40it/s, loss=0.278, metrics={'acc': 0.8694}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 140.20it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 56: 100%|███| 275/275 [00:03<00:00, 84.02it/s, loss=0.28, metrics={'acc': 0.8686}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 138.68it/s, loss=0.276, metrics={'acc': 0.87}]\n",
+ "epoch 57: 100%|████| 275/275 [00:03<00:00, 83.66it/s, loss=0.28, metrics={'acc': 0.868}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 139.25it/s, loss=0.276, metrics={'acc': 0.8702}]\n",
+ "epoch 58: 100%|███| 275/275 [00:03<00:00, 83.12it/s, loss=0.279, metrics={'acc': 0.869}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 134.59it/s, loss=0.276, metrics={'acc': 0.8702}]\n",
+ "epoch 59: 100%|██| 275/275 [00:03<00:00, 81.14it/s, loss=0.278, metrics={'acc': 0.8682}]\n",
+ "valid: 100%|██████| 69/69 [00:00<00:00, 135.79it/s, loss=0.276, metrics={'acc': 0.8702}]\n",
+ "epoch 60: 100%|███| 275/275 [00:03<00:00, 80.78it/s, loss=0.28, metrics={'acc': 0.8684}]\n",
+ "valid: 100%|████████| 69/69 [00:00<00:00, 133.84it/s, loss=0.276, metrics={'acc': 0.87}]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best Epoch: 30. Best val_loss: 0.27563\n",
+ "Restoring model weights from the end of the best epoch\n"
+ ]
+ }
+ ],
+ "source": [
+ "mlp_trainer.fit(\n",
+ " X_tab=X_tab_train,\n",
+ " target=target,\n",
+ " val_split=0.2,\n",
+ " n_epochs=100,\n",
+ " batch_size=128,\n",
+ " validation_freq=1,\n",
+ " feature_importance_sample_size=1000,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "0a4fdc77",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'age': 0.103683405,\n",
+ " 'workclass': 0.066264994,\n",
+ " 'education': 0.10014994,\n",
+ " 'marital_status': 0.1235957,\n",
+ " 'occupation': 0.12825337,\n",
+ " 'relationship': 0.15234835,\n",
+ " 'race': 0.061964743,\n",
+ " 'gender': 0.05328226,\n",
+ " 'capital_gain': 0.03052448,\n",
+ " 'capital_loss': 0.037544865,\n",
+ " 'hours_per_week': 0.07689079,\n",
+ " 'native_country': 0.0654971}"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mlp_trainer.feature_importance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "c10be01f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "predict: 100%|█████████████████████████████████████████| 39/39 [00:00<00:00, 211.89it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "mlp_preds = mlp_trainer.predict(X_tab=X_tab_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "0b2a46a9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.873899692937564"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "accuracy_score(mlp_preds, test.income_label)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.15"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/scripts/adult_census_feature_importance.py b/examples/scripts/adult_census_feature_importance.py
new file mode 100644
index 00000000..64e7447d
--- /dev/null
+++ b/examples/scripts/adult_census_feature_importance.py
@@ -0,0 +1,162 @@
+from sklearn.model_selection import train_test_split
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import (
+ SAINT,
+ TabNet,
+ WideDeep,
+ FTTransformer,
+ TabFastFormer,
+ TabTransformer,
+ SelfAttentionMLP,
+ ContextAttentionMLP,
+)
+from pytorch_widedeep.metrics import Accuracy
+from pytorch_widedeep.datasets import load_adult
+from pytorch_widedeep.preprocessing import TabPreprocessor
+
+# use_cuda = torch.cuda.is_available()
+
+df = load_adult(as_frame=True)
+df.columns = [c.replace("-", "_") for c in df.columns]
+# Binary target: 1 when the raw 'income' string contains ">50K", 0 otherwise
+df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
+df.drop("income", axis=1, inplace=True)
+target_colname = "income_label"
+
+# Object-typed columns and columns with fewer than 200 distinct values are
+# passed to the preprocessor as 'cat_embed_cols'. The target is excluded
+# explicitly for both conditions: without the parentheses 'and' binds tighter
+# than 'or', so the exclusion would only guard the 'nunique' branch.
+cat_embed_cols = [
+    col
+    for col in df.columns
+    if col != target_colname
+    and (df[col].dtype == "O" or df[col].nunique() < 200)
+]
+
+train, test = train_test_split(
+    df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
+)
+
+# 'with_cls_token' is reused further down when checking how many features
+# receive an importance value
+with_cls_token = True
+tab_preprocessor = TabPreprocessor(
+    cat_embed_cols=cat_embed_cols, with_attention=True, with_cls_token=with_cls_token
+)
+
+X_tab_train = tab_preprocessor.fit_transform(train)
+X_tab_test = tab_preprocessor.transform(test)
+target = train[target_colname].values
+
+
+# Arguments shared by every attention-based model built in this script
+common_kwargs = dict(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+)
+
+# Small configuration reused by the four transformer variants
+transformer_kwargs = dict(input_dim=8, n_heads=2, n_blocks=2, **common_kwargs)
+
+tab_transformer = TabTransformer(**transformer_kwargs)
+
+saint = SAINT(**transformer_kwargs)
+
+tab_fastformer = TabFastFormer(**transformer_kwargs)
+
+ft_transformer = FTTransformer(
+    # feature importance is only available when no key/value compression is
+    # applied, i.e. when this factor is exactly 1
+    kv_compression_factor=1.0,
+    **transformer_kwargs,
+)
+
+context_attention_mlp = ContextAttentionMLP(
+    input_dim=16,
+    attn_dropout=0.2,
+    n_blocks=3,
+    **common_kwargs,
+)
+
+self_attention_mlp = SelfAttentionMLP(n_blocks=3, **common_kwargs)
+
+# Train each attention-based model for a few epochs and check that both the
+# aggregated feature importance and the per-sample explanations have the
+# expected sizes.
+for attention_based_model in [
+    tab_transformer,
+    saint,
+    tab_fastformer,
+    ft_transformer,
+    context_attention_mlp,
+    self_attention_mlp,
+]:
+    model = WideDeep(deeptabular=attention_based_model)  # type: ignore[arg-type]
+
+    trainer = Trainer(
+        model,
+        objective="binary",
+        metrics=[Accuracy],
+    )
+
+    trainer.fit(
+        X_tab=X_tab_train,
+        target=target,
+        n_epochs=5,
+        batch_size=128,
+        val_split=0.2,
+        feature_importance_sample_size=1000,
+    )
+
+    feat_imp_per_sample = trainer.explain(X_tab_test)
+
+    # When a [CLS] token is used, X_tab_train carries one extra column that
+    # receives no importance value. The expected count is computed first so
+    # that the conditional expression cannot swallow the equality check (a
+    # conditional binds looser than '==', which made the former single
+    # assert vacuous whenever 'with_cls_token' was False).
+    n_expected_features = (
+        X_tab_train.shape[1] - 1 if with_cls_token else X_tab_train.shape[1]
+    )
+    assert len(trainer.feature_importance) == n_expected_features
+    assert feat_imp_per_sample.shape == test[cat_embed_cols].shape
+
+
+# TabNet: same checks as above, with a preprocessor built without the
+# attention-specific options.
+train, test = train_test_split(
+    df,
+    test_size=0.1,
+    random_state=1,
+    stratify=df[[target_colname]],
+)
+
+tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols)
+
+X_tab_train = tab_preprocessor.fit_transform(train)
+X_tab_test = tab_preprocessor.transform(test)
+target = train[target_colname].values
+
+tabnet = TabNet(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+)
+
+model = WideDeep(deeptabular=tabnet)
+
+trainer = Trainer(model, objective="binary", metrics=[Accuracy])
+
+trainer.fit(
+    X_tab=X_tab_train,
+    target=target,
+    n_epochs=5,
+    batch_size=128,
+    val_split=0.2,
+    feature_importance_sample_size=1000,
+)
+
+# Only the aggregated per-sample importances are requested here, not the
+# per-step masks
+feat_imp_per_sample = trainer.explain(X_tab_test, save_step_masks=False)
+
+# One importance value per input column, one explanation row per test sample
+assert len(trainer.feature_importance) == X_tab_train.shape[1]
+assert feat_imp_per_sample.shape == test[cat_embed_cols].shape
diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py
index 9fd38e2b..d73957c1 100644
--- a/pytorch_widedeep/callbacks.py
+++ b/pytorch_widedeep/callbacks.py
@@ -655,6 +655,7 @@ def on_epoch_end(
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
+ self.best_epoch = epoch
if self.restore_best_weights:
self.state_dict = copy.deepcopy(self.model.state_dict())
else:
@@ -665,7 +666,9 @@ def on_epoch_end(
def on_train_end(self, logs: Optional[Dict] = None):
if self.stopped_epoch > 0 and self.verbose > 0:
- print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
+ print(
+ f"Best Epoch: {self.best_epoch + 1}. Best {self.monitor}: {self.best:.5f}"
+ )
if self.restore_best_weights and self.state_dict is not None:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
diff --git a/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py b/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py
index f98e2647..d31beded 100644
--- a/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py
+++ b/pytorch_widedeep/models/tabular/mlp/context_attention_mlp.py
@@ -188,7 +188,7 @@ def output_dim(self) -> int:
)
@property
- def attention_weights(self) -> List:
+ def attention_weights(self) -> List[Tensor]:
r"""List with the attention weights per block
The shape of the attention weights is $(N, F)$, where $N$ is the batch
diff --git a/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py b/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py
index 8ca6890d..4b17b312 100644
--- a/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py
+++ b/pytorch_widedeep/models/tabular/mlp/self_attention_mlp.py
@@ -198,7 +198,7 @@ def output_dim(self) -> int:
)
@property
- def attention_weights(self) -> List:
+ def attention_weights(self) -> List[Tensor]:
r"""List with the attention weights per block
The shape of the attention weights is $(N, H, F, F)$, where $N$ is the
diff --git a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
index c7c8149d..7b1589a6 100644
--- a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
@@ -275,7 +275,7 @@ def output_dim(self) -> int:
)
@property
- def attention_weights(self) -> List:
+ def attention_weights(self) -> List[Tensor]:
r"""List with the attention weights per block
The shape of the attention weights is: $(N, H, F, k)$, where $N$ is
diff --git a/pytorch_widedeep/models/tabular/transformers/saint.py b/pytorch_widedeep/models/tabular/transformers/saint.py
index 27da26d4..cfed7488 100644
--- a/pytorch_widedeep/models/tabular/transformers/saint.py
+++ b/pytorch_widedeep/models/tabular/transformers/saint.py
@@ -248,7 +248,7 @@ def output_dim(self) -> int:
)
@property
- def attention_weights(self) -> List:
+ def attention_weights(self) -> List[Tuple[Tensor, Tensor]]:
r"""List with the attention weights. Each element of the list is a tuple
where the first and the second elements are the column and row
attention weights respectively
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
index 02601fa6..17e9114b 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
@@ -281,7 +281,7 @@ def output_dim(self) -> int:
)
@property
- def attention_weights(self) -> List:
+ def attention_weights(self) -> List[Tuple[Tensor, Tensor]]:
r"""List with the attention weights. Each element of the list is a
tuple where the first and second elements are the $\alpha$
and $\beta$ attention weights in the paper.
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
index 295ec7ba..868e3cbf 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
@@ -284,7 +284,7 @@ def output_dim(self) -> int:
)
@property
- def attention_weights(self) -> List:
+ def attention_weights(self) -> List[Tensor]:
r"""List with the attention weights per block
The shape of the attention weights is $(N, H, F, F)$, where $N$ is the
diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py
index 35798e03..43943edd 100644
--- a/pytorch_widedeep/models/wide_deep.py
+++ b/pytorch_widedeep/models/wide_deep.py
@@ -15,7 +15,7 @@
warnings.filterwarnings("default", category=UserWarning)
-WDModel = Union[nn.Module, BaseWDModelComponent]
+WDModel = Union[nn.Module, nn.Sequential, BaseWDModelComponent]
class WideDeep(nn.Module):
diff --git a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
index 231072ae..1da60c99 100644
--- a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
@@ -1,5 +1,6 @@
import os
import sys
+import warnings
from abc import ABC, abstractmethod
import numpy as np
@@ -174,17 +175,34 @@ def _set_callbacks(self, callbacks: Any):
self.callback_container.set_model(self.cd_model)
self.callback_container.set_trainer(self)
- def _restore_best_weights(self):
- already_restored = any(
- [
- (
- callback.__class__.__name__ == "EarlyStopping"
- and callback.restore_best_weights
- )
- for callback in self.callback_container.callbacks
- ]
- )
+ def _restore_best_weights(self): # noqa: C901
+ early_stopping_min_delta = None
+ model_checkpoint_min_delta = None
+ already_restored = False
+
+ for callback in self.callback_container.callbacks:
+ if (
+ callback.__class__.__name__ == "EarlyStopping"
+ and callback.restore_best_weights
+ ):
+ early_stopping_min_delta = callback.min_delta
+ already_restored = True
+
+ if callback.__class__.__name__ == "ModelCheckpoint":
+ model_checkpoint_min_delta = callback.min_delta
+
+ if (
+ early_stopping_min_delta is not None
+ and model_checkpoint_min_delta is not None
+ ) and (early_stopping_min_delta != model_checkpoint_min_delta):
+ warnings.warn(
+ "'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
+ "This implies a different definition of 'improvement' for these two callbacks",
+ UserWarning,
+ )
+
if already_restored:
+ # already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
diff --git a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
index 29515320..aea5b9d6 100644
--- a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
@@ -1,5 +1,6 @@
import os
import sys
+import warnings
from abc import ABC, abstractmethod
import numpy as np
@@ -121,17 +122,34 @@ def _set_callbacks(self, callbacks: Any):
self.callback_container.set_model(self.ed_model)
self.callback_container.set_trainer(self)
- def _restore_best_weights(self):
- already_restored = any(
- [
- (
- callback.__class__.__name__ == "EarlyStopping"
- and callback.restore_best_weights
- )
- for callback in self.callback_container.callbacks
- ]
- )
+ def _restore_best_weights(self): # noqa: C901
+ early_stopping_min_delta = None
+ model_checkpoint_min_delta = None
+ already_restored = False
+
+ for callback in self.callback_container.callbacks:
+ if (
+ callback.__class__.__name__ == "EarlyStopping"
+ and callback.restore_best_weights
+ ):
+ early_stopping_min_delta = callback.min_delta
+ already_restored = True
+
+ if callback.__class__.__name__ == "ModelCheckpoint":
+ model_checkpoint_min_delta = callback.min_delta
+
+ if (
+ early_stopping_min_delta is not None
+ and model_checkpoint_min_delta is not None
+ ) and (early_stopping_min_delta != model_checkpoint_min_delta):
+ warnings.warn(
+ "'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
+ "This implies a different definition of 'improvement' for these two callbacks",
+ UserWarning,
+ )
+
if already_restored:
+ # already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
diff --git a/pytorch_widedeep/training/_base_bayesian_trainer.py b/pytorch_widedeep/training/_base_bayesian_trainer.py
index cefc111d..e6da0285 100644
--- a/pytorch_widedeep/training/_base_bayesian_trainer.py
+++ b/pytorch_widedeep/training/_base_bayesian_trainer.py
@@ -1,5 +1,6 @@
import os
import sys
+import warnings
from abc import ABC, abstractmethod
import numpy as np
@@ -120,17 +121,34 @@ def save(
):
raise NotImplementedError("Trainer.save method not implemented")
- def _restore_best_weights(self):
- already_restored = any(
- [
- (
- callback.__class__.__name__ == "EarlyStopping"
- and callback.restore_best_weights
- )
- for callback in self.callback_container.callbacks
- ]
- )
+ def _restore_best_weights(self): # noqa: C901
+ early_stopping_min_delta = None
+ model_checkpoint_min_delta = None
+ already_restored = False
+
+ for callback in self.callback_container.callbacks:
+ if (
+ callback.__class__.__name__ == "EarlyStopping"
+ and callback.restore_best_weights
+ ):
+ early_stopping_min_delta = callback.min_delta
+ already_restored = True
+
+ if callback.__class__.__name__ == "ModelCheckpoint":
+ model_checkpoint_min_delta = callback.min_delta
+
+ if (
+ early_stopping_min_delta is not None
+ and model_checkpoint_min_delta is not None
+ ) and (early_stopping_min_delta != model_checkpoint_min_delta):
+ warnings.warn(
+ "'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
+ "This implies a different definition of 'improvement' for these two callbacks",
+ UserWarning,
+ )
+
if already_restored:
+ # already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py
index 4f0527c0..3c11f77c 100644
--- a/pytorch_widedeep/training/_base_trainer.py
+++ b/pytorch_widedeep/training/_base_trainer.py
@@ -30,7 +30,6 @@
)
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._trainer_utils import alias_to_loss
-from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
@@ -67,7 +66,6 @@ def __init__(
self.model = model
if self.model.is_tabnet:
self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3)
- self.reducing_matrix = create_explain_matrix(self.model)
self.model.to(self.device)
self.model.wd_device = self.device
diff --git a/pytorch_widedeep/training/_feature_importance.py b/pytorch_widedeep/training/_feature_importance.py
new file mode 100644
index 00000000..4e8711a0
--- /dev/null
+++ b/pytorch_widedeep/training/_feature_importance.py
@@ -0,0 +1,379 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+from scipy.sparse import csc_matrix
+from torch.utils.data import DataLoader
+
+from pytorch_widedeep.wdtypes import (
+ Any,
+ Dict,
+ List,
+ Tuple,
+ Union,
+ Tensor,
+ Optional,
+ WideDeep,
+)
+from pytorch_widedeep.models.tabular import (
+ SAINT,
+ TabPerceiver,
+ FTTransformer,
+ TabFastFormer,
+ TabTransformer,
+ SelfAttentionMLP,
+ ContextAttentionMLP,
+)
+from pytorch_widedeep.utils.general_utils import Alias
+from pytorch_widedeep.training._wd_dataset import WideDeepDataset
+from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix
+
+# Backbones for which the attention-based feature importance / explanation
+# code paths in this module are selected (the tuple is used as the second
+# argument of isinstance). Despite its name, the tuple also contains the two
+# attention MLPs. TabPerceiver is intentionally not listed.
+TransformerBasedModels = (
+ SAINT,
+ FTTransformer,
+ TabFastFormer,
+ TabTransformer,
+ SelfAttentionMLP,
+ ContextAttentionMLP,
+)
+
+__all__ = ["FeatureImportance", "Explainer"]
+
+
+# TODO: review the typing of WideDeep, in particular of its deeptabular
+# component. That component is currently typed as
+# Optional[BaseWDModelComponent]. While that is correct, it is not fully
+# informative; a more precise annotation would perhaps be a Union of
+# [BaseWDModelComponent, BaseTabularModelWithAttention,
+# BaseTabularModelWithoutAttention]. That way, when we access
+# model.deeptabular._modules["0"], the type checker could infer that the
+# value has a property called `attention_weights`, and the type ignores
+# scattered through this module would not be needed. For the time being
+# this is left as it is (since it is not wrong), but it needs to be
+# revisited.
+
+
+class FeatureImportance:
+    """Select and run the feature importance implementation for a model.
+
+    Parameters
+    ----------
+    device: str
+        Device forwarded to the selected implementation
+    n_samples: int, default = 1000
+        Sample size forwarded to the selected implementation
+    """
+
+    def __init__(self, device: str, n_samples: int = 1000):
+        self.device = device
+        self.n_samples = n_samples
+
+    def feature_importance(
+        self, loader: DataLoader, model: WideDeep
+    ) -> Dict[str, float]:
+        """Compute the feature importance of ``model`` using data from ``loader``.
+
+        TabNet models are handled by ``TabNetFeatureImportance``; models whose
+        deeptabular backbone is one of ``TransformerBasedModels`` are handled
+        by ``TransformerBasedFeatureImportance``.
+
+        Raises
+        ------
+        ValueError
+            If the model is neither of the above
+        """
+        if model.is_tabnet:
+            importance_cls: Any = TabNetFeatureImportance
+        elif isinstance(model.deeptabular._modules["0"], TransformerBasedModels):
+            importance_cls = TransformerBasedFeatureImportance
+        else:
+            raise ValueError(
+                "The computation of feature importance is not supported for this particular model"
+            )
+
+        importance_computer = importance_cls(self.device, self.n_samples)
+        return importance_computer.feature_importance(loader, model)
+
+
+class Explainer:
+    """Select and run the per-sample explainer matching a model.
+
+    Parameters
+    ----------
+    device: str
+        Device forwarded to the selected explainer
+    """
+
+    def __init__(self, device: str):
+        self.device = device
+
+    def explain(
+        self,
+        model: WideDeep,
+        X_tab: np.ndarray,
+        num_workers: int,
+        batch_size: Optional[int] = None,
+        save_step_masks: Optional[bool] = None,
+    ) -> Union[Tuple, np.ndarray]:
+        """Return the explanation of ``X_tab`` produced by the matching explainer.
+
+        ``save_step_masks`` is only forwarded for TabNet models, for which it
+        must be set explicitly.
+
+        Raises
+        ------
+        AssertionError
+            If the model is TabNet and ``save_step_masks`` is None
+        ValueError
+            If the model is neither TabNet nor one of ``TransformerBasedModels``
+        """
+        if model.is_tabnet:
+            assert (
+                save_step_masks is not None
+            ), "If the model is TabNet, please set 'save_step_masks' to True/False"
+            return TabNetExplainer(self.device).explain(
+                model,
+                X_tab,
+                num_workers,
+                batch_size,
+                save_step_masks,
+            )
+
+        if isinstance(model.deeptabular._modules["0"], TransformerBasedModels):
+            return TransformerBasedExplainer(self.device).explain(
+                model,
+                X_tab,
+                num_workers,
+                batch_size,
+            )
+
+        raise ValueError(
+            "The computation of feature importance is not supported for this particular model"
+        )
+
+
+class BaseFeatureImportance(ABC):
+    """Base class of the feature importance implementations.
+
+    Parameters
+    ----------
+    device: str
+        Device the sampled data is moved to
+    n_samples: int, default = 1000
+        Upper bound on the number of rows sampled from the loader
+    """
+
+    def __init__(self, device: str, n_samples: int = 1000):
+        self.device = device
+        self.n_samples = n_samples
+
+    @abstractmethod
+    def feature_importance(
+        self, loader: DataLoader, model: WideDeep
+    ) -> Dict[str, float]:
+        raise NotImplementedError(
+            "Any Feature Importance technique must implement this method"
+        )
+
+    def _sample_data(self, loader: DataLoader) -> Tensor:
+        """Concatenate the 'deeptabular' inputs of the first batches of ``loader``.
+
+        Whole batches are collected until ``n_samples // batch_size`` batches
+        have been read or the loader is exhausted.
+        """
+        # At least one batch is always read: when 'n_samples' is smaller than
+        # the batch size the integer division is 0, which used to leave the
+        # list empty and made torch.cat fail.
+        n_iterations = max(1, self.n_samples // loader.batch_size)
+
+        batches = []
+        for i, (data, _, _) in enumerate(loader):
+            if i >= n_iterations:
+                break
+            batches.append(data["deeptabular"].to(self.device))
+
+        # The slice only has an effect in the 'n_samples < batch_size' case;
+        # otherwise the collected rows never exceed 'n_samples'.
+        return torch.cat(batches, dim=0)[: self.n_samples]
+
+
+class TabNetFeatureImportance(BaseFeatureImportance):
+    """Feature importance for TabNet, computed from the model's masks."""
+
+    def __init__(self, device: str, n_samples: int = 1000):
+        super().__init__(device, n_samples)
+
+    def feature_importance(
+        self, loader: DataLoader, model: WideDeep
+    ) -> Dict[str, float]:
+        """Return a normalised importance value per entry of ``column_idx``.
+
+        The explanation mask of a data sample is summed over rows, multiplied
+        by the matrix returned by ``create_explain_matrix`` and divided by its
+        total so that the values add up to 1.
+        """
+        model.eval()
+
+        # presumably maps embedding dimensions back to the input columns --
+        # TODO confirm against 'create_explain_matrix'
+        reducing_matrix = create_explain_matrix(model)
+        backbone = list(model.deeptabular.children())[0]
+
+        importance = np.zeros((backbone.embed_out_dim))  # type: ignore[arg-type]
+
+        sampled_X = self._sample_data(loader)
+        M_explain, _ = backbone.forward_masks(sampled_X)  # type: ignore[operator]
+
+        importance += M_explain.sum(dim=0).cpu().detach().numpy()
+        importance = csc_matrix.dot(importance, reducing_matrix)
+        importance = importance / np.sum(importance)
+
+        return dict(zip(backbone.column_idx.keys(), importance))  # type: ignore
+
+
+class TransformerBasedFeatureImportance(BaseFeatureImportance):
+    """Feature importance computed from the attention weights of the backbone.
+
+    Supported backbones are those listed in ``TransformerBasedModels``. The
+    attention weights are read from the backbone's ``attention_weights``
+    property after a forward pass in evaluation mode.
+
+    Parameters
+    ----------
+    device: str
+        Device the sampled data is moved to
+    n_samples: int, default = 1000
+        Upper bound on the number of rows used to compute the importance
+    """
+
+    def __init__(self, device: str, n_samples: int = 1000):
+        super().__init__(
+            device=device,
+            n_samples=n_samples,
+        )
+
+    def feature_importance(
+        self, loader: DataLoader, model: WideDeep
+    ) -> Dict[str, float]:
+        """Return the importance per column, averaged over the sampled rows.
+
+        Raises
+        ------
+        ValueError
+            If the backbone is not supported (see ``_check_inputs`` and
+            ``_model_type``)
+        """
+        self._check_inputs(model)
+        self.model_type = self._model_type(model)
+
+        X = self._sample_data(loader)
+
+        model.eval()
+        # The forward pass is only run to populate the attention weights; no
+        # autograd graph is needed for that.
+        with torch.no_grad():
+            _ = model.deeptabular(X)
+
+        feature_importance, column_idx = self._feature_importance(model)
+
+        agg_feature_importance = feature_importance.mean(0).cpu().detach().numpy()
+
+        return {k: v for k, v in zip(column_idx, agg_feature_importance)}
+
+    def _model_type_attention_weights(self, model: WideDeep) -> Tensor:
+        """Stack the per-block attention weights of the backbone along dim 0."""
+        block_weights = model.deeptabular[0].attention_weights  # type: ignore[index]
+
+        if self.model_type == "saint":
+            # each element is a (column attention, row attention) tuple; only
+            # the column attention is used
+            return torch.stack([aw[0] for aw in block_weights], dim=0)
+
+        if self.model_type == "tabfastformer":
+            # each element is an (alpha, beta) tuple; both sets of weights are
+            # stacked together
+            alpha_weights, beta_weights = zip(*block_weights)
+            return torch.stack(alpha_weights + beta_weights, dim=0)
+
+        return torch.stack(block_weights, dim=0)
+
+    def _model_type_feature_importance(
+        self, model: WideDeep, attention_weights: Tensor
+    ) -> Tuple[Tensor, List[str]]:
+        """Reduce the stacked attention weights to one value per row and column.
+
+        When the backbone uses a [CLS] token, the first entry of
+        ``column_idx`` and the first feature position are dropped.
+        """
+        model_backbone = list(model.deeptabular.children())[0]
+        with_cls_token = model.deeptabular._modules["0"].with_cls_token
+
+        column_idx = list(model_backbone.column_idx.keys())  # type: ignore
+        if with_cls_token:
+            column_idx = column_idx[1:]
+
+        # average over the stacked blocks first
+        block_mean = attention_weights.mean(0)
+
+        if self.model_type == "contextattentionmlp":
+            feat_imp = block_mean[:, 1:] if with_cls_token else block_mean
+        elif self.model_type == "tabfastformer":
+            head_mean = block_mean.mean(1)
+            feat_imp = head_mean[:, 1:] if with_cls_token else head_mean
+        else:
+            head_mean = block_mean.mean(1)
+            # with a [CLS] token only the weights at index 0 of the next
+            # dimension are kept; otherwise that dimension is averaged out
+            feat_imp = head_mean[:, 0, 1:] if with_cls_token else head_mean.mean(1)
+
+        return feat_imp, column_idx
+
+    @staticmethod
+    def _model_type(model: WideDeep) -> str:
+        """Return the lower-case name identifying the backbone class.
+
+        Raises
+        ------
+        ValueError
+            If the backbone is none of the supported classes
+        """
+        model_backbone = model.deeptabular._modules["0"]
+        known_types = (
+            (SAINT, "saint"),
+            (FTTransformer, "fttransformer"),
+            (TabFastFormer, "tabfastformer"),
+            (TabTransformer, "tabtransformer"),
+            (SelfAttentionMLP, "selfattentionmlp"),
+            (ContextAttentionMLP, "contextattentionmlp"),
+        )
+
+        model_type: Optional[str] = None
+        for model_cls, type_name in known_types:
+            # no early exit: if several entries match, the last one is kept
+            if isinstance(model_backbone, model_cls):
+                model_type = type_name
+
+        if model_type is None:
+            # previously an unmatched backbone surfaced as an UnboundLocalError
+            raise ValueError(
+                "The computation of feature importance is not supported for this particular model"
+            )
+
+        return model_type
+
+    def _feature_importance(self, model: WideDeep) -> Tuple[Tensor, List[str]]:
+        """Return the per-row feature importance and the matching column names.
+
+        Relies on ``self.model_type`` having been set and on a forward pass
+        having been run on the backbone beforehand.
+        """
+        attention_weights = self._model_type_attention_weights(model)
+
+        feature_importance, column_idx = self._model_type_feature_importance(
+            model, attention_weights
+        )
+
+        return feature_importance, column_idx
+
+    def _check_inputs(self, model: WideDeep):
+        """Raise ValueError for backbones or settings that are not supported."""
+        if isinstance(model.deeptabular._modules["0"], TabPerceiver):
+            raise ValueError(
+                "At this stage the feature importance is not supported for the 'TabPerceiver'"
+            )
+        if isinstance(model.deeptabular._modules["0"], FTTransformer) and (
+            model.deeptabular._modules["0"].kv_compression_factor != 1
+        ):
+            raise ValueError(
+                "Feature importance can only be computed if the compression factor "
+                "'kv_compression_factor' is set to 1"
+            )
+
+
+class TabNetExplainer:
+    """Per-sample explanations for TabNet models, computed from its masks.
+
+    Parameters
+    ----------
+    device: str
+        Device the input batches are moved to
+    n_samples: int, default = 1000
+        Accepted but not used by this class
+    """
+
+    def __init__(self, device: str, n_samples: int = 1000):
+        self.device = device
+
+    @Alias("X_tab", "X")
+    def explain(
+        self,
+        model: WideDeep,
+        X_tab: np.ndarray,
+        num_workers: int,
+        batch_size: Optional[int] = None,
+        save_step_masks: bool = False,
+    ) -> Union[Tuple, np.ndarray]:
+        """Return the row-normalised aggregated explanation mask of ``X_tab``.
+
+        If ``save_step_masks`` is True a tuple is returned whose second
+        element is a dictionary with the masks of each step, stacked over all
+        batches.
+        """
+        model.eval()
+        backbone = list(model.deeptabular.children())[0]
+        reducing_matrix = create_explain_matrix(model)
+
+        loader = DataLoader(
+            dataset=WideDeepDataset(X_tab=X_tab),
+            batch_size=batch_size,
+            num_workers=num_workers,
+            shuffle=False,
+        )
+
+        agg_chunks: List[np.ndarray] = []
+        step_chunks: Dict[Any, List[np.ndarray]] = {}
+        for data in loader:
+            X = data["deeptabular"].to(self.device)
+            M_explain, masks = backbone.forward_masks(X)  # type: ignore[operator]
+            agg_chunks.append(
+                csc_matrix.dot(M_explain.cpu().detach().numpy(), reducing_matrix)
+            )
+            if save_step_masks:
+                # collect the reduced mask of every step, batch by batch
+                for key, value in masks.items():
+                    step_chunks.setdefault(key, []).append(
+                        csc_matrix.dot(value.cpu().detach().numpy(), reducing_matrix)
+                    )
+
+        m_explain_agg = np.vstack(agg_chunks)
+        # each row is divided by its own sum
+        m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]
+
+        if save_step_masks:
+            m_explain_step = {
+                key: np.vstack(chunks) for key, chunks in step_chunks.items()
+            }
+            return m_explain_agg_norm, m_explain_step
+
+        return np.vstack(m_explain_agg_norm)
+
+
+class TransformerBasedExplainer(TransformerBasedFeatureImportance):
+    """Per-sample feature importance for the attention-based backbones.
+
+    Parameters
+    ----------
+    device: str
+        Device the input batches are moved to
+    """
+
+    def __init__(self, device: str):
+        super().__init__(
+            device=device,
+            n_samples=1000,  # irrelevant: 'explain' processes all of X_tab
+        )
+
+    @Alias("X_tab", "X")
+    def explain(
+        self,
+        model: WideDeep,
+        X_tab: np.ndarray,
+        num_workers: int,
+        batch_size: Optional[int] = None,
+    ) -> np.ndarray:
+        """Return the feature importance of every row in ``X_tab``.
+
+        Raises
+        ------
+        ValueError
+            If the backbone or its settings are not supported (same checks as
+            for the aggregated feature importance)
+        """
+        # Same validation as 'feature_importance', e.g. an FTTransformer with
+        # 'kv_compression_factor' != 1 is rejected instead of being explained
+        self._check_inputs(model)
+
+        model.eval()
+
+        self.model_type = self._model_type(model)
+
+        loader = DataLoader(
+            dataset=WideDeepDataset(X_tab=X_tab),
+            batch_size=batch_size,
+            num_workers=num_workers,
+            shuffle=False,
+        )
+
+        batch_feat_imp: List[Tensor] = []
+        # Inference only: without this, the autograd graph of every batch
+        # would be kept alive by the tensors collected in the list
+        with torch.no_grad():
+            for data in loader:
+                X = data["deeptabular"].to(self.device)
+                _ = model.deeptabular(X)
+
+                feat_imp, _ = self._feature_importance(model)
+
+                batch_feat_imp.append(feat_imp)
+
+        return torch.cat(batch_feat_imp).cpu().detach().numpy()
+
+
+# Type aliases: unions of the concrete feature importance and explainer
+# classes defined in this module.
+ModelFeatureImportance = Union[
+ TabNetFeatureImportance, TransformerBasedFeatureImportance
+]
+ModelExplainer = Union[TabNetExplainer, TransformerBasedExplainer]
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index cd005071..5dd45ae3 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -8,7 +8,6 @@
import torch.nn.functional as F
from tqdm import trange
from torch import nn
-from scipy.sparse import csc_matrix
from torchmetrics import Metric as TorchMetric
from torch.utils.data import DataLoader
@@ -39,6 +38,10 @@
wd_train_val_split,
print_loss_and_metric,
)
+from pytorch_widedeep.training._feature_importance import (
+ Explainer,
+ FeatureImportance,
+)
class Trainer(BaseTrainer):
@@ -265,6 +268,7 @@ def fit( # noqa: C901
validation_freq: int = 1,
batch_size: int = 32,
custom_dataloader: Optional[DataLoader] = None,
+ feature_importance_sample_size: Optional[int] = None,
finetune: bool = False,
with_lds: bool = False,
**kwargs,
@@ -520,15 +524,18 @@ def fit( # noqa: C901
self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)
if self.early_stop:
- self.callback_container.on_train_end(epoch_logs)
+ # self.callback_container.on_train_end(epoch_logs)
break
if self.model.with_fds:
self._update_fds_stats(train_loader, epoch)
self.callback_container.on_train_end(epoch_logs)
- if self.model.is_tabnet:
- self._compute_feature_importance(train_loader)
+
+ if feature_importance_sample_size is not None:
+ self.feature_importance = FeatureImportance(
+ self.device, feature_importance_sample_size
+ ).feature_importance(train_loader, self.model)
self._restore_best_weights()
self.model.train()
@@ -740,64 +747,13 @@ def predict_proba( # type: ignore[return]
if self.method == "multiclass":
return np.vstack(preds_l)
- def explain(
- self, X_tab: np.ndarray, save_step_masks: bool = False
- ) -> Union[Tuple, np.ndarray]:
- """
- if the `deeptabular` component is a `Tabnet` model, returns the
- aggregated feature importance for each instance (or observation) in
- the `X_tab` array. If `save_step_masks` is set to `True`, the
- masks per step will also be returned.
-
- Parameters
- ----------
- X_tab: np.ndarray
- Input array corresponding **only** to the deeptabular component
- save_step_masks: bool
- Boolean indicating if the masks per step will be returned
-
- Returns
- -------
- Union[Tuple, np.ndarray]
- Array or Tuple of two arrays with the corresponding aggregated
- feature importance and the masks per step if `save_step_masks`
- is set to `True`
- """
- loader = DataLoader(
- dataset=WideDeepDataset(X_tab=X_tab),
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- shuffle=False,
- )
+ def explain(self, X_tab: np.ndarray, save_step_masks: Optional[bool] = None):
+ # TODO: add docs for this method, for the feature importance parameter
+ # and for all the related classes
+ explainer = Explainer(self.device)
- self.model.eval()
- tabnet_backbone = list(self.model.deeptabular.children())[0]
-
- m_explain_l = []
- for batch_nb, data in enumerate(loader):
- X = data["deeptabular"].to(self.device)
- M_explain, masks = tabnet_backbone.forward_masks(X) # type: ignore[operator]
- m_explain_l.append(
- csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
- )
- if save_step_masks:
- for key, value in masks.items():
- masks[key] = csc_matrix.dot(
- value.cpu().detach().numpy(), self.reducing_matrix
- )
- if batch_nb == 0:
- m_explain_step = masks
- else:
- for key, value in masks.items():
- m_explain_step[key] = np.vstack([m_explain_step[key], value])
-
- m_explain_agg = np.vstack(m_explain_l)
- m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]
-
- res: Union[Tuple, np.ndarray] = (
- (m_explain_agg_norm, m_explain_step)
- if save_step_masks
- else np.vstack(m_explain_agg_norm)
+ res = explainer.explain(
+ self.model, X_tab, self.num_workers, self.batch_size, save_step_masks
)
return res
@@ -1069,28 +1025,6 @@ def _update_fds_stats(self, train_loader: DataLoader, epoch: int):
self.model.fds_layer.update_last_epoch_stats(epoch)
self.model.fds_layer.update_running_stats(features, y_pred, epoch)
- def _compute_feature_importance(self, loader: DataLoader):
- self.model.eval()
- tabnet_backbone = list(self.model.deeptabular.children())[0]
- feat_imp = np.zeros((tabnet_backbone.embed_out_dim)) # type: ignore[arg-type]
- for data, target, _ in loader:
- X = data["deeptabular"].to(self.device)
- y = (
- target.view(-1, 1).float()
- if self.method not in ["multiclass", "qregression"]
- else target
- )
- y = y.to(self.device)
- M_explain, masks = tabnet_backbone.forward_masks(X) # type: ignore[operator]
- feat_imp += M_explain.sum(dim=0).cpu().detach().numpy()
-
- feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix)
- feat_imp = feat_imp / np.sum(feat_imp)
-
- self.feature_importance = {
- k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp) # type: ignore[operator, union-attr]
- }
-
def _predict( # noqa: C901
self,
X_wide: Optional[np.ndarray] = None,
diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py
index 10aa336c..67bc602a 100644
--- a/pytorch_widedeep/version.py
+++ b/pytorch_widedeep/version.py
@@ -1 +1 @@
-__version__ = "1.2.3"
+__version__ = "1.3.0"
diff --git a/setup.py b/setup.py
index 8213c3a1..504f09b3 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,7 @@ def requirements(fname):
"1.0": "Development Status :: 5 - Production/Stable", # v1.0 - most functionality + doc + test # noqa
"1.1": "Development Status :: 5 - Production/Stable", # v1.1 - new functionality
"1.2": "Development Status :: 5 - Production/Stable", # v1.2 - new functionality
+ "1.3": "Development Status :: 5 - Production/Stable", # v1.3 - new functionality
"2.0": "Development Status :: 6 - Mature", # v2.0 - new functionality?
}
diff --git a/tests/test_feature_importance/test_transformers_feat_imp.py b/tests/test_feature_importance/test_transformers_feat_imp.py
new file mode 100644
index 00000000..f176400a
--- /dev/null
+++ b/tests/test_feature_importance/test_transformers_feat_imp.py
@@ -0,0 +1,187 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import (
+ SAINT,
+ TabNet,
+ WideDeep,
+ FTTransformer,
+ TabFastFormer,
+ TabTransformer,
+ SelfAttentionMLP,
+ ContextAttentionMLP,
+)
+from pytorch_widedeep.preprocessing import TabPreprocessor
+
+ np.random.seed(42)
+
+ # Define the column names
+ cat_cols = ["cat1", "cat2", "cat3", "cat4"]
+ cont_cols = ["cont1", "cont2", "cont3", "cont4"]
+ columns = cat_cols + cont_cols
+
+ # Generate random categorical data
+ categorical_data = np.random.choice(["A", "B", "C"], size=(32, 4))
+
+ # Generate random numerical data
+ numerical_data = np.random.randn(32, 4)
+
+ # Create the DataFrame. Concatenating the string and float arrays yields a
+ # single array of strings, so the continuous columns are cast back to float
+ # to make sure the preprocessor receives real numerical data.
+ data = np.concatenate((categorical_data, numerical_data), axis=1)
+ df = pd.DataFrame(data, columns=columns)
+ df = df.astype({col: float for col in cont_cols})
+ target = np.random.choice(2, 32)
+
+ # First half of the rows is used for training, second half for testing
+ df_tr = df[:16].copy()
+ df_te = df[16:].copy().reset_index(drop=True)
+
+ y_tr = target[:16]
+ y_te = target[16:]
+
+# ############################TESTS BEGIN #######################################
+
+
+ def _build_model_for_feat_imp_test(model_name, params):
+     """Build the attention-based tabular model identified by ``model_name``.
+
+     Parameters
+     ----------
+     model_name: str
+         One of 'tabtransformer', 'saint', 'fttransformer', 'tabfastformer',
+         'self_attn_mlp' or 'cxt_attn_mlp'.
+     params: dict
+         Extra keyword arguments forwarded to the model constructor.
+
+     Raises
+     ------
+     ValueError
+         If ``model_name`` is not one of the supported names. Without this
+         the function would return ``None`` and the failure would only show
+         up later, away from its cause.
+     """
+     if model_name == "tabtransformer":
+         return TabTransformer(
+             input_dim=6, n_blocks=2, n_heads=2, embed_continuous=True, **params
+         )
+     if model_name == "saint":
+         return SAINT(input_dim=6, n_blocks=2, n_heads=2, **params)
+     if model_name == "fttransformer":
+         # kv_compression_factor=1.0 is the setting under which feature
+         # importance is supported (see test_fttransformer_valueerror)
+         return FTTransformer(
+             input_dim=6, n_blocks=2, n_heads=2, kv_compression_factor=1.0, **params
+         )
+     if model_name == "tabfastformer":
+         return TabFastFormer(input_dim=6, n_blocks=2, n_heads=2, **params)
+     if model_name == "self_attn_mlp":
+         return SelfAttentionMLP(input_dim=6, n_blocks=2, n_heads=2, **params)
+     if model_name == "cxt_attn_mlp":
+         return ContextAttentionMLP(input_dim=6, n_blocks=2, **params)
+     raise ValueError(f"Unknown model name: '{model_name}'")
+
+
+ @pytest.mark.parametrize("with_cls_token", [True, False])
+ @pytest.mark.parametrize(
+     "model_name",
+     [
+         "tabtransformer",
+         "saint",
+         "fttransformer",
+         "tabfastformer",
+         "self_attn_mlp",
+         "cxt_attn_mlp",
+     ],
+ )
+ def test_feature_importances(with_cls_token, model_name):
+     """Check the sizes of the aggregated and per-sample feature importances
+     for every attention-based model, with and without the cls token."""
+     tab_preprocessor = TabPreprocessor(
+         cat_embed_cols=cat_cols,
+         continuous_cols=cont_cols,
+         with_attention=True,
+         with_cls_token=with_cls_token,
+     )
+     X_tr = tab_preprocessor.fit_transform(df_tr).astype(float)
+     X_te = tab_preprocessor.transform(df_te).astype(float)
+
+     params = {
+         "column_idx": tab_preprocessor.column_idx,
+         "cat_embed_input": tab_preprocessor.cat_embed_input,
+         "continuous_cols": tab_preprocessor.continuous_cols,
+     }
+
+     tab_model = _build_model_for_feat_imp_test(model_name, params)
+
+     model = WideDeep(deeptabular=tab_model)
+
+     trainer = Trainer(
+         model,
+         objective="binary",
+     )
+
+     # Train on the labels that belong to the training rows: X_tr has 16
+     # rows, whereas the full ``target`` array has 32 entries.
+     trainer.fit(
+         X_tab=X_tr,
+         target=y_tr,
+         n_epochs=1,
+         batch_size=16,
+         feature_importance_sample_size=32,
+     )
+
+     feat_imps = trainer.feature_importance
+     feat_imp_per_sample = trainer.explain(X_te)
+
+     # Separate asserts so that a failure points at the offending check
+     assert len(feat_imps) == df_tr.shape[1]
+     assert feat_imp_per_sample.shape == df_te.shape
+
+
+ def test_fttransformer_valueerror():
+     """Requesting feature importance from an FTTransformer built with
+     ``kv_compression_factor != 1`` must raise a ``ValueError``."""
+     tab_preprocessor = TabPreprocessor(
+         cat_embed_cols=cat_cols,
+         continuous_cols=cont_cols,
+         with_attention=True,
+     )
+     X_tr = tab_preprocessor.fit_transform(df_tr).astype(float)
+
+     params = {
+         "column_idx": tab_preprocessor.column_idx,
+         "cat_embed_input": tab_preprocessor.cat_embed_input,
+         "continuous_cols": tab_preprocessor.continuous_cols,
+     }
+
+     model = FTTransformer(
+         input_dim=6, n_blocks=2, n_heads=2, kv_compression_factor=0.5, **params
+     )
+     model = WideDeep(deeptabular=model)
+
+     trainer = Trainer(
+         model,
+         objective="binary",
+     )
+
+     # Train on the labels that belong to the training rows: X_tr has 16
+     # rows, whereas the full ``target`` array has 32 entries.
+     with pytest.raises(ValueError) as ve:
+         trainer.fit(
+             X_tab=X_tr,
+             target=y_tr,
+             n_epochs=1,
+             batch_size=16,
+             feature_importance_sample_size=32,
+         )
+
+     assert (
+         ve.value.args[0]
+         == "Feature importance can only be computed if the compression factor 'kv_compression_factor' is set to 1"
+     )
+
+
+ def test_feature_importances_tabnet():
+     """Check the sizes of the aggregated and per-sample feature importances
+     when the ``deeptabular`` component is a TabNet model."""
+     tab_preprocessor = TabPreprocessor(
+         cat_embed_cols=cat_cols,
+         continuous_cols=cont_cols,
+     )
+     X_tr = tab_preprocessor.fit_transform(df_tr).astype(float)
+     X_te = tab_preprocessor.transform(df_te).astype(float)
+
+     tabnet = TabNet(
+         column_idx=tab_preprocessor.column_idx,
+         cat_embed_input=tab_preprocessor.cat_embed_input,
+         continuous_cols=tab_preprocessor.continuous_cols,
+         embed_continuous=True,
+     )
+
+     model = WideDeep(deeptabular=tabnet)
+
+     trainer = Trainer(
+         model,
+         objective="binary",
+     )
+
+     # Train on the labels that belong to the training rows: X_tr has 16
+     # rows, whereas the full ``target`` array has 32 entries.
+     trainer.fit(
+         X_tab=X_tr,
+         target=y_tr,
+         n_epochs=1,
+         batch_size=16,
+         feature_importance_sample_size=32,
+     )
+
+     feat_imps = trainer.feature_importance
+     feat_imp_per_sample = trainer.explain(X_te, save_step_masks=False)
+
+     # Separate asserts so that a failure points at the offending check
+     assert len(feat_imps) == df_tr.shape[1]
+     assert feat_imp_per_sample.shape == df_te.shape
diff --git a/tests/test_model_functioning/test_miscellaneous.py b/tests/test_model_functioning/test_miscellaneous.py
index e93c313b..3b356d9d 100644
--- a/tests/test_model_functioning/test_miscellaneous.py
+++ b/tests/test_model_functioning/test_miscellaneous.py
@@ -287,32 +287,32 @@ def test_save_and_load_dict():
assert same_weights and history_saved
-###############################################################################
-# test explain matrices and feature importance for TabNet
-###############################################################################
+# ###############################################################################
+# # test explain matrices and feature importance for TabNet (Moved to 'test_feature_importance')
+# ###############################################################################
-def test_explain_mtx_and_feat_imp():
- model = WideDeep(deeptabular=tabnet)
- trainer = Trainer(model, objective="binary", verbose=0)
- trainer.fit(
- X_tab=X_tab,
- target=target,
- batch_size=16,
- )
+# def test_explain_mtx_and_feat_imp():
+# model = WideDeep(deeptabular=tabnet)
+# trainer = Trainer(model, objective="binary", verbose=0)
+# trainer.fit(
+# X_tab=X_tab,
+# target=target,
+# batch_size=16,
+# )
- checks = []
- checks.append(len(trainer.feature_importance) == len(tabnet.column_idx))
+# checks = []
+# checks.append(len(trainer.feature_importance) == len(tabnet.column_idx))
- expl_mtx, step_masks = trainer.explain(X_tab[:6], save_step_masks=True)
- checks.append(expl_mtx.shape[0] == 6)
- checks.append(expl_mtx.shape[1] == 10)
+# expl_mtx, step_masks = trainer.explain(X_tab[:6], save_step_masks=True)
+# checks.append(expl_mtx.shape[0] == 6)
+# checks.append(expl_mtx.shape[1] == 10)
- for i in range(tabnet.n_steps):
- checks.append(step_masks[i].shape[0] == 6)
- checks.append(step_masks[i].shape[1] == 10)
+# for i in range(tabnet.n_steps):
+# checks.append(step_masks[i].shape[0] == 6)
+# checks.append(step_masks[i].shape[1] == 10)
- assert all(checks)
+# assert all(checks)
###############################################################################