diff --git a/README.md b/README.md
index 8ec73e9..ff30c75 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,15 @@
# Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization
-This repository provides code to use the method presented in our GCPR 2023 paper "Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization". If you use this method, please cite:
+## Overview
+This repository provides code to use the method presented in our GCPR 2023 paper **"Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization"**. To get started, take a look at our [example network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and the corresponding [Jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb).
+
+Our method generalizes from debiasing to the **encouragement and discouragement of arbitrary features**. That is, it not only aims at removing the influence of undesired features or biases, but also at increasing the influence of features that are known from domain knowledge to be well-established.
+
+If you use our method, please cite:
@inproceedings{Blunk23:FS,
author = {Jan Blunk and Niklas Penzel and Paul Bodesheim and Joachim Denzler},
@@ -9,8 +18,33 @@ This repository provides code to use the method presented in our GCPR 2023 paper
year = {2023},
}
-This repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234) The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022). CMIh was published under the MIT license.
-
## Installation
-First, you have to install ACD as described in the [Repository for "Hierarchical interpretations for neural network predictions" by Singh et al.](https://github.com/csinva/hierarchical-dnn-interpretations)
+**Installation with pip (Python 3, PyTorch 2.0+):**
+
+ git clone https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing.git
+ cd beyond-debiasing
+ pip install -r requirements.txt
+
+First, create a Python environment (e.g., an Anaconda environment or a Python virtual environment). We recommend installing [PyTorch with CUDA support](https://pytorch.org/get-started/locally/). Then, you can install all remaining packages via pip as described above.
+
+## Usage in Python
+Since our method relies on loss regularization, it is simple to add to your own networks: you only need to modify your loss function. To help with that, we provide an [exemplary network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and a [Jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb) with example code. A minimal sketch of the regularized loss is shown below.
+
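+The sketch below illustrates the general pattern, assuming the per-feature influence scores have already been computed (all names are illustrative; see [`regression_network.py`](regression_network.py) for the actual implementation):
+
+    import torch
+
+    def feature_steering_loss(base_loss, scores, encourage, discourage, lam, norm="l1"):
+        # scores holds one influence estimate c_i per input feature,
+        # e.g. obtained via CMI or contextual decomposition.
+        reduce = torch.abs if norm == "l1" else torch.square
+        penalty = reduce(scores[discourage]).sum() - reduce(scores[encourage]).sum()
+        return base_loss + lam * penalty
+
+    # Toy usage with made-up attribution scores:
+    scores = torch.tensor([0.8, 0.1, 0.3])
+    loss = feature_steering_loss(torch.tensor(1.5), scores, encourage=[0], discourage=[2], lam=1.0)
+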
+## Repository Organization
+* Installation:
+ * [`requirements.txt`](requirements.txt): List of required packages for installation with pip
+* Feature attribution:
+ * [`contextual_decomposition.py`](contextual_decomposition.py): Wrapper for contextual decomposition
+  * [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py): Python port of CMIh, a hybrid estimator of the conditional mutual information (see below)
+* Redundant regression dataset:
+ * [`algebra.py`](algebra.py): Generation of random orthogonal matrices
+  * [`make_regression.py`](make_regression.py): An adapted version of scikit-learn's `make_regression(...)` where the coefficients are drawn from a standard uniform distribution
+ * [`regression_dataset.py`](regression_dataset.py): Generation of the redundant regression dataset
+ * [`dataset_utils.py`](dataset_utils.py): Creation of torch dataset from numpy arrays
+ * [`tensor_utils.py`](tensor_utils.py): Some helpful functions for dealing with tensors
+* Example:
+ * [`feature_steering_example.ipynb`](feature_steering_example.ipynb): Example for generating the dataset, creating and training the network with detailed comments
+ * [`regression_network.py`](regression_network.py): Neural network (PyTorch) used in the example notebook
+
+With [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py), this repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234) The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022); CMIh was published under the MIT license.
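+
+A minimal call sketch (the signature follows its use in [`regression_network.py`](regression_network.py); the tensors here are random placeholders):
+
+    import torch
+    from mixed_cmi_estimator import mixed_cmi_model
+
+    x = torch.randn(500)  # feature values
+    y = torch.randn(500)  # model outputs
+    z = torch.randn(500)  # targets
+    cmi = mixed_cmi_model(x, y, z, feature_is_categorical=False,
+                          target_is_categorical=False)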
diff --git a/dataset_utils.py b/dataset_utils.py
new file mode 100644
index 0000000..ee6b039
--- /dev/null
+++ b/dataset_utils.py
@@ -0,0 +1,37 @@
+import torch
+
+def get_dataset_from_arrays(train_features, train_outputs, test_features, test_outputs, validation_features=None, validation_outputs=None, batch_size=1):
+    """
+    Creates torch datasets and data loaders from numpy arrays.
+
+    train_features and test_features contain observations as rows and
+    features as columns. train_outputs and test_outputs are vectors with one
+    target value per observation. Optionally, validation data can be passed
+    in the same format.
+    """
+
+    train_inputs = torch.tensor(train_features, dtype=torch.float32)
+ train_targets = torch.FloatTensor(train_outputs)
+ train_dataset = torch.utils.data.TensorDataset(train_inputs, train_targets)
+
+    test_inputs = torch.tensor(test_features, dtype=torch.float32)
+ test_targets = torch.FloatTensor(test_outputs)
+ test_dataset = torch.utils.data.TensorDataset(test_inputs, test_targets)
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
+ )
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=1
+ )
+
+    if validation_features is not None:
+        validation_inputs = torch.tensor(validation_features, dtype=torch.float32)
+ validation_targets = torch.FloatTensor(validation_outputs)
+ validation_dataset = torch.utils.data.TensorDataset(validation_inputs, validation_targets)
+
+ validation_loader = torch.utils.data.DataLoader(
+ validation_dataset, batch_size=batch_size, shuffle=False, num_workers=1
+ )
+
+ return (train_dataset, train_loader, test_dataset, test_loader, validation_dataset, validation_loader)
+ else:
+ return (train_dataset, train_loader, test_dataset, test_loader)
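+
+# Usage sketch (shapes are illustrative):
+#
+#     import numpy as np
+#     X_train, y_train = np.random.rand(100, 9), np.random.rand(100)
+#     X_test, y_test = np.random.rand(20, 9), np.random.rand(20)
+#     train_ds, train_loader, test_ds, test_loader = get_dataset_from_arrays(
+#         X_train, y_train, X_test, y_test, batch_size=10)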
diff --git a/feature_steering_example.ipynb b/feature_steering_example.ipynb
new file mode 100644
index 0000000..30a4017
--- /dev/null
+++ b/feature_steering_example.ipynb
@@ -0,0 +1,222 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "import os\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "from regression_dataset import make_regression_dataset\n",
+    "from regression_network import RegressionNetwork\n",
+    "\n",
+    "# Deterministic execution. CUDA_LAUNCH_BLOCKING must be set as an\n",
+    "# environment variable; a plain Python variable has no effect.\n",
+    "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n",
+    "seed = 42\n",
+    "torch.manual_seed(seed)\n",
+    "np.random.seed(seed)\n",
+    "torch.backends.cudnn.benchmark = False\n",
+    "torch.backends.cudnn.deterministic = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Example for the Application of the Feature Steering Method Presented in “Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization”\n",
+ "\n",
+    "This Jupyter notebook provides an example of applying our method to the redundant regression dataset presented in our paper.\n",
+ "\n",
+    "You can choose to generate feature attributions with the feature attribution method of Reimers et al. based on either **contextual decomposition** or **conditional mutual information**. Additionally, you can choose other hyperparameters such as the weight factor $\\lambda$ and the norm that is applied (L1 / L2 norm).\n",
+ "\n",
+ "## Dataset\n",
+ "We create a small regression dataset with redundant variables as described in our paper. That is, the created dataset has 9 input variables with a redundancy of 3 variables. In total, we generate 2000 samples, of which 1400 are used for training.\n",
+ "\n",
+    "*Note:* For the evaluations in our paper we generate not just one but nine datasets with different seeds."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Configuration of the datasets.\n",
+ "high_dim_transform = True\n",
+    "normalize = True  # Should only be set if high_dim_transform is enabled.\n",
+ "n_informative_low_dim = 6\n",
+ "n_high_dim = 9\n",
+ "n_train, n_test, n_validation = 1400, 300, 300\n",
+ "n_uninformative_low_dim = 0\n",
+ "dataset_seed = 42\n",
+ "batch_size = 100\n",
+ "n_datasets = 9\n",
+ "\n",
+ "noise_on_output = 0.0\n",
+ "noise_on_high_dim_snrdb = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create and load the regression dataset.\n",
+ "train_dataloader, test_dataloader, validation_dataloader, _ = make_regression_dataset(\n",
+ " high_dim_transform=high_dim_transform,\n",
+ " n_features_low_dim=n_informative_low_dim,\n",
+ " n_uninformative_low_dim=n_uninformative_low_dim,\n",
+ " n_high_dim=n_high_dim,\n",
+ " noise_on_high_dim_snrdb=noise_on_high_dim_snrdb,\n",
+ " noise_on_output=noise_on_output,\n",
+ " n_train=n_train,\n",
+ " n_test=n_test,\n",
+ " n_validation=n_validation,\n",
+ " normalize=normalize,\n",
+ " seed=dataset_seed,\n",
+ " batch_size=batch_size,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Network\n",
+ "We follow the paper and create a network with a single hidden layer of size 9 and input size 9."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Network architecture.\n",
+ "input_size = n_high_dim\n",
+ "hidden_dim_size = n_high_dim\n",
+ "n_hidden_layers = 1\n",
+ "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ "# Create Network.\n",
+ "mlp = RegressionNetwork(\n",
+ " input_shape=input_size,\n",
+ " n_hidden_layers=n_hidden_layers,\n",
+ " hidden_dim_size=hidden_dim_size,\n",
+ " device=device,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "After creating the network, we can train it with the *feature steering loss*.\n",
+ "\n",
+ "Recall from the paper that our method to steer feature usage is implemented via loss regularization. Let $D$ refer to the set of features that should be discouraged and $E$ to the set of features that should be encouraged. With $c_i$ being a measure of the influence of feature $i$ on the model's prediction process, $\\lambda \\in \\mathbb{R}_{\\ge 0}$ as a weight factor and $\\mathcal{L}$ as the standard maximum-likelihood loss for network parameters $\\theta$, our model is trained with the following loss function:\n",
+ "\n",
+ "$$ \\mathcal{L}'(\\theta) = \\mathcal{L}(\\theta) + \\lambda \\left( \\sum_{i \\in D} || c_i || - \\sum_{i \\in E} || c_i || \\right) .$$\n",
+ "For $|| \\cdot ||$, we consider the L1 and L2 norms.\n",
+ "\n",
+ "**Parameters:**\n",
+ "\n",
+ "Our implementation allows you to choose several *hyperparameters* for the feature steering process. You can adapt the following aspects of the calculation of the loss function:\n",
+ "\n",
+ "* The feature attributions $c_i$ are generated based on the feature attribution method proposed by Reimers et al. For this, the attribution modes `cmi` for feature attribution based on the (transformed) conditional mutual information and `contextual_decomposition` for feature attribution performed with contextual decomposition are available.\n",
+    "* Feature steering can be performed with feature attributions aggregated via the L1 norm (`loss_l1`) or the L2 norm (`loss_l2`). That is, this choice determines the norm applied for $|| \\cdot ||$.\n",
+ "* The indices of the features that shall be encouraged or discouraged (defining $D$ and $E$) are passed as lists.\n",
+ "* The weight factor $\\lambda$ is specified as `lambda`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loss (per sample) after epoch 1: 4712089.607142857\n",
+ "Loss (per sample) after epoch 2: 4480544.142857143\n",
+ "Loss (per sample) after epoch 3: 4258867.017857143\n",
+ "Loss (per sample) after epoch 4: 4050848.214285714\n",
+ "Loss (per sample) after epoch 5: 3851129.5714285714\n",
+ "Loss (per sample) after epoch 6: 3662716.375\n",
+ "Loss (per sample) after epoch 7: 3484030.3214285714\n",
+ "Loss (per sample) after epoch 8: 3317511.035714286\n",
+ "Loss (per sample) after epoch 9: 3158147.5714285714\n",
+ "Loss (per sample) after epoch 10: 3006144.9821428573\n",
+ "Loss (per sample) after epoch 11: 2864496.8035714286\n",
+ "Loss (per sample) after epoch 12: 2727674.410714286\n",
+ "Loss (per sample) after epoch 13: 2597680.910714286\n",
+ "Loss (per sample) after epoch 14: 2478867.535714286\n",
+ "Loss (per sample) after epoch 15: 2361367.4553571427\n",
+ "Loss (per sample) after epoch 16: 2251085.125\n",
+ "Loss (per sample) after epoch 17: 2148403.6160714286\n"
+ ]
+    }
+ ],
+ "source": [
+ "# Training configuration.\n",
+ "learning_rate = 0.01\n",
+ "epochs = 90\n",
+ "feat_steering_config = {\n",
+ " \"attrib_mode\": \"cmi\",\n",
+ " \"steering_mode\": \"loss_l2\",\n",
+ " \"encourage\": [0, 1, 2],\n",
+ " \"discourage\": [],\n",
+ " \"lambda\": 100.0, # Adapt accordingly for CMI / CD\n",
+ "}\n",
+ "\n",
+ "# Train the network.\n",
+ "mlp.train(train_dataloader, feat_steering_config, epochs, learning_rate)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "featuresteering",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mixed_cmiI_estimator.py b/mixed_cmi_estimator.py
similarity index 100%
rename from mixed_cmiI_estimator.py
rename to mixed_cmi_estimator.py
diff --git a/regression_dataset.py b/regression_dataset.py
index 766756c..68d7cd9 100644
--- a/regression_dataset.py
+++ b/regression_dataset.py
@@ -1,7 +1,7 @@
import os
import numpy as np
-from toy_data.algebra import random_orthogonal_matrix
-from toy_data.sklearn_adaptions import make_regression
+from algebra import random_orthogonal_matrix
+from make_regression import make_regression
from dataset_utils import get_dataset_from_arrays
def make_regression_dataset(high_dim_transform=True, n_features_low_dim=4, n_uninformative_low_dim=0, n_high_dim = 128, noise_on_high_dim_snrdb=None,
diff --git a/regression_network.py b/regression_network.py
new file mode 100644
index 0000000..d19acb5
--- /dev/null
+++ b/regression_network.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+from contextual_decomposition import get_cd_1d_by_modules
+from mixed_cmi_estimator import mixed_cmi_model
+
+class RegressionNetwork(nn.Module):
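+    """MLP for regression, trainable with the feature steering loss.
+
+    Usage sketch (values taken from feature_steering_example.ipynb):
+
+        net = RegressionNetwork(input_shape=9, n_hidden_layers=1,
+                                hidden_dim_size=9, device="cpu")
+        cfg = {"attrib_mode": "cmi", "steering_mode": "loss_l2",
+               "encourage": [0, 1, 2], "discourage": [], "lambda": 100.0}
+        net.train(train_dataloader, cfg, epochs=90, learning_rate=0.01)
+    """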
+ def __init__(self, input_shape, n_hidden_layers=2, hidden_dim_size=32, device='cpu'):
+ super().__init__()
+ self.device = device
+
+        # The network always has at least one hidden layer (input_shape -> hidden_dim_size).
+ # Make sure that n_hidden_layers is valid.
+ if n_hidden_layers < 1:
+ raise ValueError("The network cannot have less than 1 hidden layer.")
+
+ # Generate and initialize hidden layers.
+ # Note: we only need to generate n_hidden_layers-1 hidden layers!
+        # Build distinct layer objects; list multiplication would replicate a
+        # single module and share its weights across all hidden layers.
+        lin_layers = [nn.Linear(in_features=hidden_dim_size, out_features=hidden_dim_size)
+                      for _ in range(n_hidden_layers - 1)]
+ for lin_layer in lin_layers:
+ nn.init.xavier_uniform_(lin_layer.weight)
+ nn.init.zeros_(lin_layer.bias)
+        relus = [nn.ReLU() for _ in lin_layers]
+
+        # Generate and initialize first and last layer.
+ input_layer = nn.Linear(input_shape, hidden_dim_size)
+ output_layer = nn.Linear(hidden_dim_size, 1)
+ for lin_layer in [input_layer, output_layer]:
+ nn.init.xavier_uniform_(lin_layer.weight)
+ nn.init.zeros_(lin_layer.bias)
+
+ # Combine layers to a model.
+        modules = [nn.Flatten(),
+                   input_layer,
+                   nn.ReLU(),
+                   *[layer for pair in zip(lin_layers, relus) for layer in pair],
+                   output_layer,
+                   ]
+ self.layers = nn.Sequential(*modules)
+ self.to(device)
+
+ def forward(self, x):
+ return self.layers(x)
+
+ def feat_steering_loss(self, inputs, targets, outputs, feat_steering_config=None):
+ # Get configuration for feature steering.
+ # Do not perform feature steering if it is not desired.
+ if feat_steering_config["steering_mode"] == "none":
+ return torch.tensor(0.0)
+        elif feat_steering_config["steering_mode"] not in ["loss_l1", "loss_l2"]:
+            raise ValueError("The feature steering mode is invalid.")
+        if feat_steering_config["attrib_mode"] not in ["contextual_decomposition", "cmi"]:
+            raise ValueError("The feature attribution mode is invalid.")
+ feat_to_encourage, feat_to_discourage = feat_steering_config["encourage"], feat_steering_config["discourage"]
+
+ # Feature attribution.
+ if feat_steering_config["attrib_mode"] == "contextual_decomposition":
+ scores_feat_to_encourage, _ = get_cd_1d_by_modules(self, self.layers, inputs, feat_to_encourage, device=self.device)
+ scores_feat_to_discourage, _ = get_cd_1d_by_modules(self, self.layers, inputs, feat_to_discourage, device=self.device)
+
+ elif feat_steering_config["attrib_mode"] == "cmi":
+ # Estimate CMI.
+ if len(feat_to_encourage) > 0:
+ scores_feat_to_encourage = torch.stack([mixed_cmi_model(inputs[:,feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_encourage], 0)
+ else:
+ scores_feat_to_encourage = torch.tensor([]).float()
+ if len(feat_to_discourage) > 0:
+ scores_feat_to_discourage = torch.stack([mixed_cmi_model(inputs[:,feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_discourage], 0)
+ else:
+ scores_feat_to_discourage = torch.tensor([]).float()
+
+            # Transform to [0,1].
+            # This maps the CMI estimate onto a correlation-like scale: for
+            # Gaussian variables, MI = -0.5 * ln(1 - rho^2), hence
+            # rho = sqrt(1 - exp(-2 * MI)).
+            # NOTE: Even though in theory CMI >= 0, in practice our estimates
+            # can be smaller than zero.
+            # Make sure that sqrt does not receive values < 0. In analogy to
+            # Straight-Through Estimators (STEs) we apply our transformation
+            # only to inputs >= 0 and use the identity transformation for
+            # inputs < 0.
+            scores_feat_to_encourage[scores_feat_to_encourage > 0] = torch.sqrt(1 - torch.exp(-2 * scores_feat_to_encourage[scores_feat_to_encourage > 0]))
+            scores_feat_to_discourage[scores_feat_to_discourage > 0] = torch.sqrt(1 - torch.exp(-2 * scores_feat_to_discourage[scores_feat_to_discourage > 0]))
+
+ else:
+ raise NotImplementedError("The selected feature attribution mode is not yet implemented!")
+
+ # Small corrections:
+ # If there are no features to en- or discourage, we can explicitly set their contribution to 0.
+ if len(feat_to_encourage) == 0:
+ scores_feat_to_encourage = torch.tensor(0)
+ if len(feat_to_discourage) == 0:
+ scores_feat_to_discourage = torch.tensor(0)
+
+ # Feature steering.
+ if feat_steering_config["attrib_mode"] == "cmi":
+ # With the CMI estimates we can have negative values even though in theory CMI >= 0.
+ # L1 / L2 would emphasize them, but we want values < 0 to result in a smaller loss.
+ # L1-Loss:
+ # We know that our values should be almost > 0. Therefore, we apply the absolute value
+ # only to values >= 0 and the identity transformation to all others (analogous to
+ # Straight-Through Estimators, keeps gradients).
+ # In practice, this results in ignoring the absolute value.
+ #
+ # L2-Loss:
+ # Here, we also only square for values >= 0 and the identity transformation to all
+ # others.
+ if feat_steering_config["steering_mode"] == "loss_l2":
+ scores_feat_to_encourage[scores_feat_to_encourage >= 0] = torch.square(scores_feat_to_encourage[scores_feat_to_encourage >= 0])
+ scores_feat_to_discourage[scores_feat_to_discourage >= 0] = torch.square(scores_feat_to_discourage[scores_feat_to_discourage >= 0])
+ return feat_steering_config["lambda"] * (torch.sum(scores_feat_to_discourage) - torch.sum(scores_feat_to_encourage)) / inputs.size()[0] # Average over Batch
+
+ if feat_steering_config["lambda"] == 0:
+ return torch.tensor(0.0)
+ elif feat_steering_config["steering_mode"] == "loss_l1":
+ feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.abs(scores_feat_to_discourage)) - torch.sum(torch.abs(scores_feat_to_encourage)))
+ elif feat_steering_config["steering_mode"] == "loss_l2":
+ feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.square(scores_feat_to_discourage)) - torch.sum(torch.square(scores_feat_to_encourage)))
+ else:
+ raise NotImplementedError("The selected feature steering mode is not yet implemented!")
+
+ return feat_steering_loss / inputs.size()[0] # Average over Batch
+
+ def loss(self, inputs, targets, outputs, feat_steering_config=None):
+ # For MSE make sure that outputs is a 1D tensor. That is, we need to
+ # prevent tensors of shape torch.Size([batch_size, 1]).
+ if len(outputs.size()) > 1:
+            outputs = outputs.squeeze(dim=1)
+
+ # Compute default loss.
+ loss_func = nn.MSELoss()
+ loss = loss_func(outputs, targets)
+
+ # No feature steering if in evaluation mode or explicitly specified.
+ if not self.training or feat_steering_config["steering_mode"] == "none":
+ return loss
+ else:
+ feat_steering_loss = self.feat_steering_loss(inputs, targets, outputs, feat_steering_config=feat_steering_config)
+ if feat_steering_loss.isnan():
+                raise ValueError("The feature steering loss of your model is nan. Thus, no reasonable gradient "
+                                 "can be computed! The feature steering config was: " + str(feat_steering_config) + ".")
+ return loss + feat_steering_loss
+
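+    # Note: this method shadows nn.Module.train(mode). The module is created in
+    # training mode (self.training == True), which self.loss(...) relies on.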
+ def train(self, train_dataloader, feat_steering_config, epochs=90, learning_rate=0.01):
+ optimizer = torch.optim.AdamW(self.layers.parameters(), lr=learning_rate)
+
+ for epoch in range(epochs):
+ epoch_loss = 0.0
+
+ for inputs, targets in train_dataloader:
+ # Pass data to GPU / CPU if necessary.
+ inputs, targets = inputs.to(self.device), targets.to(self.device)
+
+ # Zero the gradients.
+ optimizer.zero_grad()
+
+ # Perform forward pass.
+ outputs = self(inputs)
+                if outputs.isnan().any():
+                    raise ValueError("The output of the model contains nan. Thus, no reasonable loss can be computed!")
+
+ # Calculate loss.
+ loss = self.loss(inputs, targets, outputs, feat_steering_config=feat_steering_config)
+
+ # Perform backward pass and modify weights accordingly.
+ if loss == torch.inf:
+ raise ValueError("The loss of your model is inf. Thus, no reasonable gradient can be computed!")
+ if loss.isnan():
+ raise ValueError("The loss of your model is nan. Thus, no reasonable gradient can be computed!")
+
+ loss.backward()
+ optimizer.step()
+
+ # Print statistics.
+ epoch_loss += loss.item()
+ print("Loss (per sample) after epoch " + str(epoch+1) + ": " + str(epoch_loss / len(train_dataloader)))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2555d89
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+acd
+jupyterlab
+numpy
+scikit-learn
+torch
+torchaudio
+torchvision
+torchmetrics
\ No newline at end of file
diff --git a/teaser.png b/teaser.png
new file mode 100644
index 0000000..dbe147c
Binary files /dev/null and b/teaser.png differ