diff --git a/README.md b/README.md index 8ec73e9..ff30c75 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,15 @@ # Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization -This repository provides code to use the method presented in our GCPR 2023 paper "Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization". If you use this method, please cite: +## Overview +This repository provides code to use the method presented in our GCPR 2023 paper **"Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization"**. If you want to get started, take a look at our [example network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and the corresponding [jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb). + +
+By measuring the feature usage, we can steer the model towards (not) using features that are specifically (un-)desired. +
+ +Our method generalizes from debiasing to the **encouragement and discouragement of arbitrary features**. That is, it not only aims at removing the influence of undesired features / biases but also at increasing the influence of features whose relevance is known from domain knowledge. + +If you use our method, please cite: @inproceedings{Blunk23:FS, author = {Jan Blunk and Niklas Penzel and Paul Bodesheim and Joachim Denzler}, booktitle = {DAGM German Conference on Pattern Recognition (DAGM-GCPR)}, title = {Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization}, year = {2023}, } -This repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234) The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022). CMIh was published under the MIT license. - ## Installation -First, you have to install ACD as described in the [Repository for "Hierarchical interpretations for neural network predictions" by Singh et al.](https://github.com/csinva/hierarchical-dnn-interpretations) +**Install with pip, Python and PyTorch 2.0+** + + git clone https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing.git + cd beyond-debiasing + pip install -r requirements.txt + +First, create an environment with pip and Python (Anaconda environment / Python virtual environment). We recommend installing [PyTorch with CUDA support](https://pytorch.org/get-started/locally/). Then, you can install all remaining packages via pip as described above. + +## Usage in Python +Since our method relies on loss regularization, it is very simple to add to your own networks - you only need to modify your loss function. To help with that, we provide an [exemplary network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and a [jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb) with example code. + +## Repository Organization +* Installation: + * [`requirements.txt`](requirements.txt): List of required packages for installation with pip +* Feature attribution: + * [`contextual_decomposition.py`](contextual_decomposition.py): Wrapper for contextual decomposition + * [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py): Python port of the CMIh estimator of the conditional mutual information (CMI) +* Redundant regression dataset: + * [`algebra.py`](algebra.py): Generation of random orthogonal matrices + * [`make_regression.py`](make_regression.py): An adapted version of scikit-learn's `make_regression(...)` in which the coefficients are drawn from a standard uniform distribution + * [`regression_dataset.py`](regression_dataset.py): Generation of the redundant regression dataset + * [`dataset_utils.py`](dataset_utils.py): Creation of PyTorch datasets from NumPy arrays + * [`tensor_utils.py`](tensor_utils.py): Helper functions for dealing with tensors +* Example: + * [`feature_steering_example.ipynb`](feature_steering_example.ipynb): Example for generating the dataset and creating and training the network, with detailed comments + * [`regression_network.py`](regression_network.py): Neural network (PyTorch) used in the example notebook + +With [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py), this repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234). The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022).
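To make the "modify your loss function" step concrete, here is a minimal sketch (not code from this repository; `steered_loss` and the `attributions` tensor are illustrative placeholders) of how the regularizer from the paper can be wrapped around an existing loss. In this repository, the attribution scores are computed via CMI or contextual decomposition, see [`regression_network.py`](regression_network.py) and the example notebook.

```python
import torch

def steered_loss(base_loss, attributions, encourage, discourage, lam, norm="l1"):
    """Add the feature-steering term lambda * (sum_D ||c_i|| - sum_E ||c_i||).

    attributions: 1D tensor with one influence score c_i per input feature.
    encourage / discourage: lists of feature indices (E and D in the paper).
    """
    c = attributions.abs() if norm == "l1" else attributions.square()
    penalty = c[discourage].sum() - c[encourage].sum()
    return base_loss + lam * penalty

# Usage sketch: encourage features 0-2, discourage feature 3 of a regression model.
outputs, targets = torch.randn(8), torch.randn(8)
attributions = torch.rand(9)  # placeholder scores, e.g. CMI-based in practice
base = torch.nn.functional.mse_loss(outputs, targets)
loss = steered_loss(base, attributions, encourage=[0, 1, 2], discourage=[3], lam=100.0)
```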
diff --git a/dataset_utils.py b/dataset_utils.py new file mode 100644 index 0000000..ee6b039 --- /dev/null +++ b/dataset_utils.py @@ -0,0 +1,37 @@ +import torch + +def get_dataset_from_arrays(train_features, train_outputs, test_features, test_outputs, validation_features=None, validation_outputs=None, batch_size=1): + """ + Both train and test features are numpy arrays. Observations are represented + as rows, features as columns. + train_outputs and test_outputs are vectors containing one value per row + (expected results). + """ + + train_inputs = torch.tensor(train_features.tolist()) + train_targets = torch.FloatTensor(train_outputs) + train_dataset = torch.utils.data.TensorDataset(train_inputs, train_targets) + + test_inputs = torch.tensor(test_features.tolist()) + test_targets = torch.FloatTensor(test_outputs) + test_dataset = torch.utils.data.TensorDataset(test_inputs, test_targets) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=1 + ) + + if validation_features is not None: + validation_inputs = torch.tensor(validation_features.tolist()) + validation_targets = torch.FloatTensor(validation_outputs) + validation_dataset = torch.utils.data.TensorDataset(validation_inputs, validation_targets) + + validation_loader = torch.utils.data.DataLoader( + validation_dataset, batch_size=batch_size, shuffle=False, num_workers=1 + ) + + return (train_dataset, train_loader, test_dataset, test_loader, validation_dataset, validation_loader) + else: + return (train_dataset, train_loader, test_dataset, test_loader) diff --git a/feature_steering_example.ipynb b/feature_steering_example.ipynb new file mode 100644 index 0000000..30a4017 --- /dev/null +++ b/feature_steering_example.ipynb @@ -0,0 +1,222 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from regression_dataset import make_regression_dataset\n", + "from regression_network import RegressionNetwork\n", + "\n", + "# Deterministic execution.\n", + "import os; os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n", + "seed = 42\n", + "torch.manual_seed(seed)\n", + "np.random.seed(seed)\n", + "torch.backends.cudnn.benchmark = False\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example for the Application of the Feature Steering Method Presented in “Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization”\n", + "\n", + "This Jupyter notebook provides an example of applying our method to the redundant regression dataset presented in our paper.\n", + "\n", + "You can choose to generate feature attributions with the feature attribution method provided by Reimers et al. based on either **contextual decomposition** or **conditional mutual information**. Additionally, you can choose other hyperparameters such as the weight factor $\\lambda$ and the norm that is applied (L1 / L2 norm).\n", + "\n", + "## Dataset\n", + "We create a small regression dataset with redundant variables as described in our paper. That is, the created dataset has 9 input variables with a redundancy of 3 variables. In total, we generate 2000 samples, of which 1400 are used for training.\n", + "\n", + "*Note:* For the evaluations in our paper we generate not just one but nine such datasets with different seeds."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration of the datasets.\n", + "high_dim_transform = True\n", + "normalize = True # Should only be set if HIGH_DIM_TRANSFORM\n", + "n_informative_low_dim = 6\n", + "n_high_dim = 9\n", + "n_train, n_test, n_validation = 1400, 300, 300\n", + "n_uninformative_low_dim = 0\n", + "dataset_seed = 42\n", + "batch_size = 100\n", + "n_datasets = 9\n", + "\n", + "noise_on_output = 0.0\n", + "noise_on_high_dim_snrdb = None" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Create and load the regression dataset.\n", + "train_dataloader, test_dataloader, validation_dataloader, _ = make_regression_dataset(\n", + " high_dim_transform=high_dim_transform,\n", + " n_features_low_dim=n_informative_low_dim,\n", + " n_uninformative_low_dim=n_uninformative_low_dim,\n", + " n_high_dim=n_high_dim,\n", + " noise_on_high_dim_snrdb=noise_on_high_dim_snrdb,\n", + " noise_on_output=noise_on_output,\n", + " n_train=n_train,\n", + " n_test=n_test,\n", + " n_validation=n_validation,\n", + " normalize=normalize,\n", + " seed=dataset_seed,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network\n", + "We follow the paper and create a network with a single hidden layer of size 9 and input size 9." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Network architecture.\n", + "input_size = n_high_dim\n", + "hidden_dim_size = n_high_dim\n", + "n_hidden_layers = 1\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# Create Network.\n", + "mlp = RegressionNetwork(\n", + " input_shape=input_size,\n", + " n_hidden_layers=n_hidden_layers,\n", + " hidden_dim_size=hidden_dim_size,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After creating the network, we can train it with the *feature steering loss*.\n", + "\n", + "Recall from the paper that our method to steer feature usage is implemented via loss regularization. Let $D$ refer to the set of features that should be discouraged and $E$ to the set of features that should be encouraged. With $c_i$ being a measure of the influence of feature $i$ on the model's prediction process, $\\lambda \\in \\mathbb{R}_{\\ge 0}$ as a weight factor and $\\mathcal{L}$ as the standard maximum-likelihood loss for network parameters $\\theta$, our model is trained with the following loss function:\n", + "\n", + "$$ \\mathcal{L}'(\\theta) = \\mathcal{L}(\\theta) + \\lambda \\left( \\sum_{i \\in D} || c_i || - \\sum_{i \\in E} || c_i || \\right) .$$\n", + "For $|| \\cdot ||$, we consider the L1 and L2 norms.\n", + "\n", + "**Parameters:**\n", + "\n", + "Our implementation allows you to choose several *hyperparameters* for the feature steering process. You can adapt the following aspects of the calculation of the loss function:\n", + "\n", + "* The feature attributions $c_i$ are generated based on the feature attribution method proposed by Reimers et al. 
For this, the attribution modes `cmi` for feature attribution based on the (transformed) conditional mutual information and `contextual_decomposition` for feature attribution performed with contextual decomposition are available.\n", + "* Feature steering can be performed with feature attributions weighted with L1 norm (`loss_l1`) and L2 norm (`loss_l2`). That is, this modifies the norm applied for $|| \\cdot ||$.\n", + "* The indices of the features that shall be encouraged or discouraged (defining $D$ and $E$) are passed as lists.\n", + "* The weight factor $\\lambda$ is specified as `lambda`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss (per sample) after epoch 1: 4712089.607142857\n", + "Loss (per sample) after epoch 2: 4480544.142857143\n", + "Loss (per sample) after epoch 3: 4258867.017857143\n", + "Loss (per sample) after epoch 4: 4050848.214285714\n", + "Loss (per sample) after epoch 5: 3851129.5714285714\n", + "Loss (per sample) after epoch 6: 3662716.375\n", + "Loss (per sample) after epoch 7: 3484030.3214285714\n", + "Loss (per sample) after epoch 8: 3317511.035714286\n", + "Loss (per sample) after epoch 9: 3158147.5714285714\n", + "Loss (per sample) after epoch 10: 3006144.9821428573\n", + "Loss (per sample) after epoch 11: 2864496.8035714286\n", + "Loss (per sample) after epoch 12: 2727674.410714286\n", + "Loss (per sample) after epoch 13: 2597680.910714286\n", + "Loss (per sample) after epoch 14: 2478867.535714286\n", + "Loss (per sample) after epoch 15: 2361367.4553571427\n", + "Loss (per sample) after epoch 16: 2251085.125\n", + "Loss (per sample) after epoch 17: 2148403.6160714286\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 13\u001b[0m\n\u001b[1;32m 4\u001b[0m feat_steering_config \u001b[39m=\u001b[39m {\n\u001b[1;32m 5\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mattrib_mode\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mcmi\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 6\u001b[0m \u001b[39m\"\u001b[39m\u001b[39msteering_mode\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mloss_l2\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlambda\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m100.0\u001b[39m, \u001b[39m# Adapt accordingly for CMI / CD\u001b[39;00m\n\u001b[1;32m 10\u001b[0m }\n\u001b[1;32m 12\u001b[0m \u001b[39m# Train the network.\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m mlp\u001b[39m.\u001b[39;49mtrain(train_dataloader, feat_steering_config, epochs, learning_rate)\n", + "File \u001b[0;32m~/OneDrive/Publikationen/2023 - Feature Steering GCPR/Offizielles Repository/beyond-debiasing/regression_network.py:169\u001b[0m, in \u001b[0;36mRegressionNetwork.train\u001b[0;34m(self, train_dataloader, feat_steering_config, epochs, learning_rate)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[39mif\u001b[39;00m loss\u001b[39m.\u001b[39misnan():\n\u001b[1;32m 167\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mThe loss of your model is nan. 
Thus, no reasonable gradient can be computed!\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 169\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m 170\u001b[0m optimizer\u001b[39m.\u001b[39mstep()\n\u001b[1;32m 172\u001b[0m \u001b[39m# Print statistics.\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/envs/featuresteering-minimal/lib/python3.11/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m 478\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 479\u001b[0m Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m 480\u001b[0m (\u001b[39mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 485\u001b[0m inputs\u001b[39m=\u001b[39minputs,\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[1;32m 488\u001b[0m \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[1;32m 489\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/envs/featuresteering-minimal/lib/python3.11/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m 197\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward( \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m 202\u001b[0m allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Training configuration.\n", + "learning_rate = 0.01\n", + "epochs = 90\n", + "feat_steering_config = {\n", + " \"attrib_mode\": \"cmi\",\n", + " \"steering_mode\": \"loss_l2\",\n", + " \"encourage\": [0, 1, 2],\n", + " \"discourage\": [],\n", + " \"lambda\": 100.0, # Adapt accordingly for CMI / CD\n", + "}\n", + "\n", + "# Train the network.\n", + "mlp.train(train_dataloader, feat_steering_config, epochs, learning_rate)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "featuresteering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mixed_cmiI_estimator.py b/mixed_cmi_estimator.py similarity index 100% rename from mixed_cmiI_estimator.py rename to mixed_cmi_estimator.py diff --git a/regression_dataset.py b/regression_dataset.py index 766756c..68d7cd9 100644 --- a/regression_dataset.py +++ 
b/regression_dataset.py @@ -1,7 +1,7 @@ import os import numpy as np -from toy_data.algebra import random_orthogonal_matrix -from toy_data.sklearn_adaptions import make_regression +from algebra import random_orthogonal_matrix +from make_regression import make_regression from dataset_utils import get_dataset_from_arrays def make_regression_dataset(high_dim_transform=True, n_features_low_dim=4, n_uninformative_low_dim=0, n_high_dim = 128, noise_on_high_dim_snrdb=None, diff --git a/regression_network.py b/regression_network.py new file mode 100644 index 0000000..d19acb5 --- /dev/null +++ b/regression_network.py @@ -0,0 +1,176 @@ +import torch +from torch import nn +from contextual_decomposition import get_cd_1d_by_modules +from mixed_cmi_estimator import mixed_cmi_model + +class RegressionNetwork(nn.Module): + def __init__(self, input_shape, n_hidden_layers=2, hidden_dim_size=32, device='cpu'): + super().__init__() + self.device = device + + # The network always has at least one hidden layer (input_shape -> 32). + # Make sure that n_hidden_layers is valid. + if n_hidden_layers < 1: + raise ValueError("The network must have at least 1 hidden layer.") + + # Generate and initialize hidden layers. + # Note: we only need to generate n_hidden_layers-1 hidden layers, each as a separate module instance! + lin_layers = [nn.Linear(in_features=hidden_dim_size, out_features=hidden_dim_size) for _ in range(n_hidden_layers - 1)] + for lin_layer in lin_layers: + nn.init.xavier_uniform_(lin_layer.weight) + nn.init.zeros_(lin_layer.bias) + relus = [nn.ReLU()] * len(lin_layers) + + # Generate and initialize first and last layer. + input_layer = nn.Linear(input_shape, hidden_dim_size) + output_layer = nn.Linear(hidden_dim_size, 1) + for lin_layer in [input_layer, output_layer]: + nn.init.xavier_uniform_(lin_layer.weight) + nn.init.zeros_(lin_layer.bias) + + # Combine layers into a model. + modules = [ nn.Flatten(), + input_layer, + nn.ReLU(), + *[z for pair in zip(lin_layers, relus) for z in pair], + output_layer, + ] + self.layers = nn.Sequential(*modules) + self.to(device) + + def forward(self, x): + return self.layers(x) + + def feat_steering_loss(self, inputs, targets, outputs, feat_steering_config=None): + # Get configuration for feature steering. + # Do not perform feature steering if it is not desired. + if feat_steering_config["steering_mode"] == "none": + return torch.tensor(0.0) + elif feat_steering_config["steering_mode"] not in ["loss_l1", "loss_l2"]: + raise ValueError("The feature steering mode is invalid.") + if feat_steering_config["attrib_mode"] not in ["contextual_decomposition", "cmi"]: + raise ValueError("The feature attribution mode is invalid.") + feat_to_encourage, feat_to_discourage = feat_steering_config["encourage"], feat_steering_config["discourage"] + + # Feature attribution. + if feat_steering_config["attrib_mode"] == "contextual_decomposition": + scores_feat_to_encourage, _ = get_cd_1d_by_modules(self, self.layers, inputs, feat_to_encourage, device=self.device) + scores_feat_to_discourage, _ = get_cd_1d_by_modules(self, self.layers, inputs, feat_to_discourage, device=self.device) + + elif feat_steering_config["attrib_mode"] == "cmi": + # Estimate CMI. 
+ if len(feat_to_encourage) > 0: + scores_feat_to_encourage = torch.stack([mixed_cmi_model(inputs[:,feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_encourage], 0) + # scores_feat_to_encourage = torch.stack([get_continuous_cmi(inputs[:,feat], outputs, z=targets, knn=0.2, seed=42) for feat in feat_to_encourage], 0) + else: + scores_feat_to_encourage = torch.tensor([]).float() + if len(feat_to_discourage) > 0: + scores_feat_to_discourage = torch.stack([mixed_cmi_model(inputs[:,feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_discourage], 0) + # scores_feat_to_discourage = torch.stack([get_continuous_cmi(inputs[:,feat], outputs, z=targets, knn=0.2, seed=42) for feat in feat_to_discourage], 0) + else: + scores_feat_to_discourage = torch.tensor([]).float() + + # Transform to [0,1]. + # NOTE: Even though in theory CMI >= 0, in practice our estimates + # can be smaller than zero. + # Make sure that sqrt does not receive values < 0. In analogy to + # Straight-Through Estimators (STEs), we apply our transformation only + # to inputs >= 0 and use the identity transformation for inputs < 0. + scores_feat_to_encourage[scores_feat_to_encourage > 0] = torch.sqrt(1 - torch.exp(-2*scores_feat_to_encourage[scores_feat_to_encourage > 0])) + scores_feat_to_discourage[scores_feat_to_discourage > 0] = torch.sqrt(1 - torch.exp(-2*scores_feat_to_discourage[scores_feat_to_discourage > 0])) + + else: + raise NotImplementedError("The selected feature attribution mode is not yet implemented!") + + # Small corrections: + # If there are no features to en- or discourage, we can explicitly set their contribution to 0. + if len(feat_to_encourage) == 0: + scores_feat_to_encourage = torch.tensor(0.0) + if len(feat_to_discourage) == 0: + scores_feat_to_discourage = torch.tensor(0.0) + + # Feature steering. + if feat_steering_config["attrib_mode"] == "cmi": + # With the CMI estimates we can have negative values even though in theory CMI >= 0. + # L1 / L2 would emphasize them, but we want values < 0 to result in a smaller loss. + # L1-Loss: + # We know that almost all of our values should be > 0. Therefore, we apply the absolute value + # only to values >= 0 and the identity transformation to all others (analogous to + # Straight-Through Estimators, keeps gradients). + # In practice, this results in ignoring the absolute value. + # + # L2-Loss: + # Here, we also square only values >= 0 and apply the identity transformation to all + # others. 
+ if feat_steering_config["steering_mode"] == "loss_l2": + scores_feat_to_encourage[scores_feat_to_encourage >= 0] = torch.square(scores_feat_to_encourage[scores_feat_to_encourage >= 0]) + scores_feat_to_discourage[scores_feat_to_discourage >= 0] = torch.square(scores_feat_to_discourage[scores_feat_to_discourage >= 0]) + return feat_steering_config["lambda"] * (torch.sum(scores_feat_to_discourage) - torch.sum(scores_feat_to_encourage)) / inputs.size()[0] # Average over Batch + + if feat_steering_config["lambda"] == 0: + return torch.tensor(0.0) + elif feat_steering_config["steering_mode"] == "loss_l1": + feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.abs(scores_feat_to_discourage)) - torch.sum(torch.abs(scores_feat_to_encourage))) + elif feat_steering_config["steering_mode"] == "loss_l2": + feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.square(scores_feat_to_discourage)) - torch.sum(torch.square(scores_feat_to_encourage))) + else: + raise NotImplementedError("The selected feature steering mode is not yet implemented!") + + return feat_steering_loss / inputs.size()[0] # Average over Batch + + def loss(self, inputs, targets, outputs, feat_steering_config=None): + # For MSE make sure that outputs is a 1D tensor. That is, we need to + # prevent tensors of shape torch.Size([batch_size, 1]). + if len(outputs.size()) > 1: + outputs = outputs.squeeze(dim=1) + + # Compute default loss. + loss_func = nn.MSELoss() + loss = loss_func(outputs, targets) + + # No feature steering if in evaluation mode or if it is explicitly disabled. + if not self.training or feat_steering_config["steering_mode"] == "none": + return loss + else: + feat_steering_loss = self.feat_steering_loss(inputs, targets, outputs, feat_steering_config=feat_steering_config) + if feat_steering_loss.isnan(): + raise ValueError("The feature steering loss of your model is nan. Thus, no reasonable gradient can be computed! \ + The feature steering config was: " + str(feat_steering_config) + ".") + return loss + feat_steering_loss + + def train(self, train_dataloader, feat_steering_config, epochs=90, learning_rate=0.01): + optimizer = torch.optim.AdamW(self.layers.parameters(), lr=learning_rate) + + for epoch in range(epochs): + epoch_loss = 0.0 + + for inputs, targets in train_dataloader: + # Pass data to GPU / CPU if necessary. + inputs, targets = inputs.to(self.device), targets.to(self.device) + + # Zero the gradients. + optimizer.zero_grad() + + # Perform forward pass. + outputs = self(inputs) + if outputs.isnan().any(): + raise ValueError("The output of the model contains nan. Thus, no \ + reasonable loss can be computed!") + + # Calculate loss. + loss = self.loss(inputs, targets, outputs, feat_steering_config=feat_steering_config) + + # Perform backward pass and modify weights accordingly. + if loss.isinf(): + raise ValueError("The loss of your model is inf. Thus, no reasonable gradient can be computed!") + if loss.isnan(): + raise ValueError("The loss of your model is nan. Thus, no reasonable gradient can be computed!") + + loss.backward() + optimizer.step() + + # Print statistics. 
+ epoch_loss += loss.item() + print("Loss (per sample) after epoch " + str(epoch+1) + ": " + str(epoch_loss / len(train_dataloader))) + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2555d89 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +acd +jupyterlab +numpy +scikit-learn +torch +torchaudio +torchvision +torchmetrics \ No newline at end of file diff --git a/teaser.png b/teaser.png new file mode 100644 index 0000000..dbe147c Binary files /dev/null and b/teaser.png differ
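A closing note on the CMI-based attributions in `regression_network.py`: before weighting, raw CMI estimates are mapped to [0, 1] via sqrt(1 - exp(-2 * CMI)), and only strictly positive estimates are transformed (negative estimates, which the estimator can produce although CMI is non-negative in theory, pass through unchanged). The following standalone sketch, with hypothetical names and assuming PyTorch, restates that mapping:

```python
import torch

def cmi_to_unit_interval(cmi_estimates: torch.Tensor) -> torch.Tensor:
    """Map CMI estimates to influence scores in [0, 1].

    For jointly Gaussian variables, MI = -0.5 * ln(1 - rho^2), so
    sqrt(1 - exp(-2 * MI)) recovers the absolute correlation |rho|.
    Negative estimates are passed through unchanged, mirroring the
    straight-through-style handling in regression_network.py.
    """
    scores = cmi_estimates.clone()
    positive = scores > 0
    scores[positive] = torch.sqrt(1.0 - torch.exp(-2.0 * scores[positive]))
    return scores

# Example: an estimate of 0.5 nats maps to roughly 0.795; negative noise stays put.
print(cmi_to_unit_interval(torch.tensor([0.5, -0.01])))
```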