Commit 4df3c85

example notebook and network with added documentation
j-bl committed Aug 28, 2023
1 parent f203613 commit 4df3c85
Showing 8 changed files with 483 additions and 6 deletions.
42 changes: 38 additions & 4 deletions README.md
@@ -1,6 +1,15 @@
# Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization

## Overview
This repository provides code to use the method presented in our GCPR 2023 paper **"Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization"**. If you want to get started, take a look at our [example network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and the corresponding [Jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb).

<div style="text-align:center">
<img src="teaser.png" alt="By measuring the feature usage, we can steer the model towards (not) using features that are specifically (un-)desired." width="55%"/>
</div>

Our method generalizes from debiasing to the **encouragement and discouragement of arbitrary features**. That is, it aims not only at removing the influence of undesired features (biases) but also at increasing the influence of features that are well-established from domain knowledge.

If you use our method, please cite:

@inproceedings{Blunk23:FS,
author = {Jan Blunk and Niklas Penzel and Paul Bodesheim and Joachim Denzler},
title = {Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization},
booktitle = {DAGM German Conference on Pattern Recognition (GCPR)},
@@ -9,8 +18,33 @@
year = {2023},
}


## Installation
**Install with pip, Python and PyTorch 2.0+**

git clone https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing.git
cd beyond-debiasing
pip install -r requirements.txt

First, create an environment with pip and Python (e.g., an Anaconda environment or a Python virtual environment). We recommend installing [PyTorch with CUDA support](https://pytorch.org/get-started/locally/). Then, you can install all remaining packages via pip as described above.
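
To verify that PyTorch was installed with a suitable version and working CUDA support, a quick check like the following can help:

    import torch
    print(torch.__version__)          # should report 2.0 or newer
    print(torch.cuda.is_available())  # True if CUDA support is usable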

## Usage in Python
Since our method relies on loss regularization, it is very simple to add to your own networks: you only need to modify your loss function. To help with that, we provide an [exemplary network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and a [Jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb) with example code.
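
The following is a minimal sketch of such a regularized loss. It assumes a hypothetical helper `feature_attributions(...)` that returns one attribution score $c_i$ per input feature; the names and signatures here are illustrative only, the actual implementation can be found in [`regression_network.py`](regression_network.py).

    import torch

    def feature_steering_loss(base_loss, attributions, encourage, discourage, lam, norm="l2"):
        # Loss from the paper:
        # L'(theta) = L(theta) + lambda * (sum_{i in D} ||c_i|| - sum_{i in E} ||c_i||)
        # For the scalar attributions c_i, the L1 norm is |c_i|; as a sketch,
        # we use c_i^2 in the L2 case.
        weigh = torch.abs if norm == "l1" else torch.square
        penalty = weigh(attributions[discourage]).sum() - weigh(attributions[encourage]).sum()
        return base_loss + lam * penalty

    # Hypothetical usage: encourage features 0-2, discourage feature 5,
    # where mse is the unregularized loss and c the attribution vector.
    # loss = feature_steering_loss(mse, c, encourage=[0, 1, 2], discourage=[5], lam=100.0)
    # loss.backward()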

## Repository Organization
* Installation:
* [`requirements.txt`](requirements.txt): List of required packages for installation with pip
* Feature attribution:
* [`contextual_decomposition.py`](contextual_decomposition.py): Wrapper for contextual decomposition
* [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py): Python port of the CMIh estimator of the conditional mutual information presented by Zan et al.
* Redundant regression dataset:
* [`algebra.py`](algebra.py): Generation of random orthogonal matrices
* [`make_regression.py`](make_regression.py): An adapted version of scikit-learn's make_regression(...), where the coefficients are drawn from a standard uniform distribution
* [`regression_dataset.py`](regression_dataset.py): Generation of the redundant regression dataset
* [`dataset_utils.py`](dataset_utils.py): Creation of torch dataset from numpy arrays
* [`tensor_utils.py`](tensor_utils.py): Some helpful functions for dealing with tensors
* Example:
* [`feature_steering_example.ipynb`](feature_steering_example.ipynb): Example for generating the dataset, creating and training the network with detailed comments
* [`regression_network.py`](regression_network.py): Neural network (PyTorch) used in the example notebook

With [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py), this repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234) The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022); CMIh was published under the MIT license.
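
As a rough usage sketch, the estimator computes the conditional mutual information $I(X;Y \mid Z)$ for mixed discrete-continuous data. The function name and signature below are illustrative assumptions and may differ from the actual API in [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py):

    import numpy as np
    # Hypothetical import; check mixed_cmi_estimator.py for the actual API.
    from mixed_cmi_estimator import mixed_cmi

    rng = np.random.default_rng(0)
    z = rng.normal(size=1000)
    x = z + rng.normal(scale=0.5, size=1000)
    y = x + rng.normal(scale=0.5, size=1000)

    # I(X; Y | Z) should be clearly positive here: X carries information
    # about Y beyond what Z already provides.
    print(mixed_cmi(x, y, z))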

37 changes: 37 additions & 0 deletions dataset_utils.py
@@ -0,0 +1,37 @@
import torch

def get_dataset_from_arrays(train_features, train_outputs, test_features, test_outputs, validation_features=None, validation_outputs=None, batch_size=1):
    """
    Creates torch datasets and data loaders from numpy arrays.

    Both the train and test features are numpy arrays with observations as
    rows and features as columns. train_outputs and test_outputs are vectors
    containing one target value per row.
    """

    # Convert the numpy arrays to float32 tensors and wrap them in datasets.
    train_inputs = torch.tensor(train_features, dtype=torch.float32)
    train_targets = torch.tensor(train_outputs, dtype=torch.float32)
    train_dataset = torch.utils.data.TensorDataset(train_inputs, train_targets)

    test_inputs = torch.tensor(test_features, dtype=torch.float32)
    test_targets = torch.tensor(test_outputs, dtype=torch.float32)
    test_dataset = torch.utils.data.TensorDataset(test_inputs, test_targets)

    # Only the training data is shuffled.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=1
    )

    if validation_features is not None:
        validation_inputs = torch.tensor(validation_features, dtype=torch.float32)
        validation_targets = torch.tensor(validation_outputs, dtype=torch.float32)
        validation_dataset = torch.utils.data.TensorDataset(validation_inputs, validation_targets)

        validation_loader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=False, num_workers=1
        )

        return (train_dataset, train_loader, test_dataset, test_loader, validation_dataset, validation_loader)
    else:
        return (train_dataset, train_loader, test_dataset, test_loader)
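
A short usage sketch of the helper above (the array shapes are made up for illustration):

    import numpy as np
    from dataset_utils import get_dataset_from_arrays

    # 100 training / 20 test observations with 9 features each.
    train_x, test_x = np.random.rand(100, 9), np.random.rand(20, 9)
    train_y, test_y = np.random.rand(100), np.random.rand(20)

    train_ds, train_loader, test_ds, test_loader = get_dataset_from_arrays(
        train_x, train_y, test_x, test_y, batch_size=10
    )
    for inputs, targets in train_loader:
        print(inputs.shape, targets.shape)  # torch.Size([10, 9]) torch.Size([10])
        break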
222 changes: 222 additions & 0 deletions feature_steering_example.ipynb
@@ -0,0 +1,222 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from regression_dataset import make_regression_dataset\n",
"from regression_network import RegressionNetwork\n",
"\n",
"# Deterministic execution.\n",
"CUDA_LAUNCH_BLOCKING = 1\n",
"seed = 42\n",
"torch.manual_seed(seed)\n",
"np.random.seed(seed)\n",
"torch.backends.cudnn.benchmark = False\n",
"torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example for the Application of the Feature Steering Method Presented in “Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization”\n",
"\n",
"This jupyter notebook provides an example for our method for the redundant regression dataset presented in our paper.\n",
"\n",
"You can choose to generate feature attributions with the feature attribution method provided by Reimers et al. based on both **contextual decomposition** and **conditional mutual information**. Additionally, you can choose other hyperparameters such as the weight factor $\\lambda$ and the norm that is applied (L1 / L2 norm).\n",
"\n",
"## Dataset\n",
"We create a small regression dataset with redundant variables as described in our paper. That is, the created dataset has 9 input variables with a redundancy of 3 variables. In total, we generate 2000 samples, of which 1400 are used for training.\n",
"\n",
"*Note:* In the evaluations for our paper we not only generate one, but rather 9 datasets with different seeds."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Configuration of the datasets.\n",
"high_dim_transform = True\n",
"normalize = True # Should only be set if HIGH_DIM_TRANSFORM\n",
"n_informative_low_dim = 6\n",
"n_high_dim = 9\n",
"n_train, n_test, n_validation = 1400, 300, 300\n",
"n_uninformative_low_dim = 0\n",
"dataset_seed = 42\n",
"batch_size = 100\n",
"n_datasets = 9\n",
"\n",
"noise_on_output = 0.0\n",
"noise_on_high_dim_snrdb = None"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create and load the regression dataset.\n",
"train_dataloader, test_dataloader, validation_dataloader, _ = make_regression_dataset(\n",
" high_dim_transform=high_dim_transform,\n",
" n_features_low_dim=n_informative_low_dim,\n",
" n_uninformative_low_dim=n_uninformative_low_dim,\n",
" n_high_dim=n_high_dim,\n",
" noise_on_high_dim_snrdb=noise_on_high_dim_snrdb,\n",
" noise_on_output=noise_on_output,\n",
" n_train=n_train,\n",
" n_test=n_test,\n",
" n_validation=n_validation,\n",
" normalize=normalize,\n",
" seed=dataset_seed,\n",
" batch_size=batch_size,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network\n",
"We follow the paper and create a network with a single hidden layer of size 9 and input size 9."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Network architecture.\n",
"input_size = n_high_dim\n",
"hidden_dim_size = n_high_dim\n",
"n_hidden_layers = 1\n",
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# Create Network.\n",
"mlp = RegressionNetwork(\n",
" input_shape=input_size,\n",
" n_hidden_layers=n_hidden_layers,\n",
" hidden_dim_size=hidden_dim_size,\n",
" device=device,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After creating the network, we can train it with the *feature steering loss*.\n",
"\n",
"Recall from the paper that our method to steer feature usage is implemented via loss regularization. Let $D$ refer to the set of features that should be discouraged and $E$ to the set of features that should be encouraged. With $c_i$ being a measure of the influence of feature $i$ on the model's prediction process, $\\lambda \\in \\mathbb{R}_{\\ge 0}$ as a weight factor and $\\mathcal{L}$ as the standard maximum-likelihood loss for network parameters $\\theta$, our model is trained with the following loss function:\n",
"\n",
"$$ \\mathcal{L}'(\\theta) = \\mathcal{L}(\\theta) + \\lambda \\left( \\sum_{i \\in D} || c_i || - \\sum_{i \\in E} || c_i || \\right) .$$\n",
"For $|| \\cdot ||$, we consider the L1 and L2 norms.\n",
"\n",
"**Parameters:**\n",
"\n",
"Our implementation allows you to choose several *hyperparameters* for the feature steering process. You can adapt the following aspects of the calculation of the loss function:\n",
"\n",
"* The feature attributions $c_i$ are generated based on the feature attribution method proposed by Reimers et al. For this, the attribution modes `cmi` for feature attribution based on the (transformed) conditional mutual information and `contextual_decomposition` for feature attribution performed with contextual decomposition are available.\n",
"* Feature steering can be performed with feature attributions weighted with L1 norm (`loss_l1`) and L2 norm (`loss_l2`). That is, this modifies the norm applied for $|| \\cdot ||$.\n",
"* The indices of the features that shall be encouraged or discouraged (defining $D$ and $E$) are passed as lists.\n",
"* The weight factor $\\lambda$ is specified as `lambda`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss (per sample) after epoch 1: 4712089.607142857\n",
"Loss (per sample) after epoch 2: 4480544.142857143\n",
"Loss (per sample) after epoch 3: 4258867.017857143\n",
"Loss (per sample) after epoch 4: 4050848.214285714\n",
"Loss (per sample) after epoch 5: 3851129.5714285714\n",
"Loss (per sample) after epoch 6: 3662716.375\n",
"Loss (per sample) after epoch 7: 3484030.3214285714\n",
"Loss (per sample) after epoch 8: 3317511.035714286\n",
"Loss (per sample) after epoch 9: 3158147.5714285714\n",
"Loss (per sample) after epoch 10: 3006144.9821428573\n",
"Loss (per sample) after epoch 11: 2864496.8035714286\n",
"Loss (per sample) after epoch 12: 2727674.410714286\n",
"Loss (per sample) after epoch 13: 2597680.910714286\n",
"Loss (per sample) after epoch 14: 2478867.535714286\n",
"Loss (per sample) after epoch 15: 2361367.4553571427\n",
"Loss (per sample) after epoch 16: 2251085.125\n",
"Loss (per sample) after epoch 17: 2148403.6160714286\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 13\u001b[0m\n\u001b[1;32m 4\u001b[0m feat_steering_config \u001b[39m=\u001b[39m {\n\u001b[1;32m 5\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mattrib_mode\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mcmi\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 6\u001b[0m \u001b[39m\"\u001b[39m\u001b[39msteering_mode\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mloss_l2\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlambda\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m100.0\u001b[39m, \u001b[39m# Adapt accordingly for CMI / CD\u001b[39;00m\n\u001b[1;32m 10\u001b[0m }\n\u001b[1;32m 12\u001b[0m \u001b[39m# Train the network.\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m mlp\u001b[39m.\u001b[39;49mtrain(train_dataloader, feat_steering_config, epochs, learning_rate)\n",
"File \u001b[0;32m~/OneDrive/Publikationen/2023 - Feature Steering GCPR/Offizielles Repository/beyond-debiasing/regression_network.py:169\u001b[0m, in \u001b[0;36mRegressionNetwork.train\u001b[0;34m(self, train_dataloader, feat_steering_config, epochs, learning_rate)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[39mif\u001b[39;00m loss\u001b[39m.\u001b[39misnan():\n\u001b[1;32m 167\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mThe loss of your model is nan. Thus, no reasonable gradient can be computed!\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 169\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m 170\u001b[0m optimizer\u001b[39m.\u001b[39mstep()\n\u001b[1;32m 172\u001b[0m \u001b[39m# Print statistics.\u001b[39;00m\n",
"File \u001b[0;32m~/anaconda3/envs/featuresteering-minimal/lib/python3.11/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m 478\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 479\u001b[0m Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m 480\u001b[0m (\u001b[39mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 485\u001b[0m inputs\u001b[39m=\u001b[39minputs,\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[1;32m 488\u001b[0m \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[1;32m 489\u001b[0m )\n",
"File \u001b[0;32m~/anaconda3/envs/featuresteering-minimal/lib/python3.11/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m 197\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward( \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m 202\u001b[0m allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Training configuration.\n",
"learning_rate = 0.01\n",
"epochs = 90\n",
"feat_steering_config = {\n",
" \"attrib_mode\": \"cmi\",\n",
" \"steering_mode\": \"loss_l2\",\n",
" \"encourage\": [0, 1, 2],\n",
" \"discourage\": [],\n",
" \"lambda\": 100.0, # Adapt accordingly for CMI / CD\n",
"}\n",
"\n",
"# Train the network.\n",
"mlp.train(train_dataloader, feat_steering_config, epochs, learning_rate)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "featuresteering",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
File renamed without changes.
4 changes: 2 additions & 2 deletions regression_dataset.py
@@ -1,7 +1,7 @@
import os
import numpy as np
from toy_data.algebra import random_orthogonal_matrix
from toy_data.sklearn_adaptions import make_regression
from algebra import random_orthogonal_matrix
from make_regression import make_regression
from dataset_utils import get_dataset_from_arrays

def make_regression_dataset(high_dim_transform=True, n_features_low_dim=4, n_uninformative_low_dim=0, n_high_dim = 128, noise_on_high_dim_snrdb=None,
