ECCV 2024: Improving Intervention Efficacy via Concept Realignment in Concept Bottleneck Models

License: MIT | Python 3.7+

Figure: CIRM architecture

This repository contains the official PyTorch implementation of our paper "Improving Intervention Efficacy via Concept Realignment in Concept Bottleneck Models", accepted at ECCV 2024.

This work was done by Nishad Singhi, Jae Myung Kim, Karsten Roth, and Zeynep Akata.

Abstract

Concept Bottleneck Models (CBMs) ground image classification on human-understandable concepts to allow for interpretable model decisions. Crucially, the CBM design inherently allows for human interventions, in which expert users are given the ability to modify potentially misaligned concept choices to influence the decision behavior of the model in an interpretable fashion. However, existing approaches often require numerous human interventions per image to achieve strong performance, posing practical challenges in scenarios where obtaining human feedback is expensive. In this paper, we find that this is noticeably driven by an independent treatment of concepts during intervention, wherein a change of one concept does not influence the use of other ones in the model's final decision. To address this issue, we introduce a trainable concept intervention realignment module, which leverages concept relations to realign concept assignments post-intervention. Across standard, real-world benchmarks, we find that concept realignment can significantly improve intervention efficacy, significantly reducing the number of interventions needed to reach a target classification performance or concept prediction accuracy. In addition, it easily integrates into existing concept-based architectures without requiring changes to the models themselves. This reduced cost of human-model collaboration is crucial to enhancing the feasibility of CBMs in resource-constrained environments.
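To make the mechanism concrete, here is a minimal, dependency-free sketch of an intervention followed by realignment. It is not the repository's realignment module: the helper names `intervene` and `realign`, the fixed linear coupling matrix `weights`, and the three concept names are hypothetical, chosen only to show how correcting one concept can propagate to related, non-intervened concepts while the intervened concepts stay fixed.

```python
from typing import List, Sequence


def intervene(c_pred: Sequence[float], c_true: Sequence[float], idx: Sequence[int]) -> List[float]:
    """Replace the predicted concept probabilities at `idx` with ground-truth values."""
    out = list(c_pred)
    for i in idx:
        out[i] = float(c_true[i])
    return out


def realign(
    c_pred: Sequence[float],
    c_intervened: Sequence[float],
    idx: Sequence[int],
    weights: Sequence[Sequence[float]],
) -> List[float]:
    """Hypothetical linear realignment (illustration only).

    Each non-intervened concept j is shifted by sum_i weights[j][i] * delta_i,
    where delta_i is the change the intervention made to concept i.
    Intervened concepts keep their ground-truth values; outputs are clipped to [0, 1].
    """
    intervened = set(idx)
    deltas = {i: c_intervened[i] - c_pred[i] for i in intervened}
    out = list(c_intervened)
    for j in range(len(c_pred)):
        if j in intervened:
            continue  # never overwrite the expert's correction
        shift = sum(weights[j][i] * d for i, d in deltas.items())
        out[j] = min(1.0, max(0.0, c_pred[j] + shift))
    return out


# Hypothetical concepts: 0 = "has_wings", 1 = "can_fly", 2 = "is_red".
c_pred = [0.2, 0.3, 0.5]
c_true = [1.0, 1.0, 0.0]
weights = [
    [0.0, 0.0, 0.0],
    [0.8, 0.0, 0.0],  # "can_fly" is positively coupled to "has_wings"
    [0.0, 0.0, 0.0],  # "is_red" is unrelated to the other concepts
]

after_intervention = intervene(c_pred, c_true, [0])
after_realignment = realign(c_pred, after_intervention, [0], weights)
print(after_intervention)  # "can_fly" is still 0.3: concepts are treated independently
print(after_realignment)   # "can_fly" rises to about 0.94; "is_red" stays at 0.5
```

In the paper the realignment module is trained rather than hand-specified; the fixed matrix here only stands in for whatever concept relations such a module would learn.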

Acknowledgements

This repository builds on the GitHub repository for Concept Embedding Models. We recommend reading their README.

Installation

You can locally install this package by first cloning this repository:

$ git clone https://github.com/ExplainableML/concept_realignment

Move into the cloned directory:

$ cd concept_realignment

Set up the conda environment:

$ conda env create --name CEM --file=conda_environment.yml

Activate the environment, then install the cem package:

$ conda activate CEM
$ python setup.py install

After running this, you should be able to import the package locally:

import cem

Experiment Reproducibility

Downloading the Datasets

To run our experiments, download the pre-processed CUB dataset found here to cem/data/CUB200/ and the CelebA dataset found here to cem/data/celeba/. You may store them in different locations, but you will then need to update their paths, either in the respective experiment configs (via the root_dir parameter) or via the DATASET_DIR environment variable.

Running Experiments

The scripts to run experiments are in this directory:

$ cd concept-realignment-experiments

Step 1: Training the Base CBM/CEM Models

In train_base_models_and_save_predictions.py, specify the config file for the model you want to train. The config files are located in experiments/configs. Then, run the script:

$ python train_base_models_and_save_predictions.py

This script trains the base CBM or CEM model, runs it on the train and test datasets, and stores its predictions (along with other outputs) in a dictionary. This dictionary is typically saved to the results_dir specified in the config file.
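The exact contents and file format of this dictionary are defined by the script and its config, so the sketch below is illustrative only. It assumes a pickled dictionary with hypothetical per-split keys `c_pred` (concept probabilities) and `c_true` (binary concept labels), and shows how one might reload it and compute concept accuracy. Check the keys actually written by train_base_models_and_save_predictions.py before adapting it.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Sequence


def concept_accuracy(
    c_pred: Sequence[Sequence[float]],
    c_true: Sequence[Sequence[int]],
    threshold: float = 0.5,
) -> float:
    """Fraction of (sample, concept) pairs whose thresholded prediction matches the label."""
    correct = total = 0
    for pred_row, true_row in zip(c_pred, c_true):
        for p, t in zip(pred_row, true_row):
            correct += int((p >= threshold) == bool(t))
            total += 1
    return correct / total


def save_predictions(preds: Dict[str, Any], path: Path) -> None:
    """Pickle a predictions dictionary (hypothetical format)."""
    with open(path, "wb") as f:
        pickle.dump(preds, f)


def load_predictions(path: Path) -> Dict[str, Any]:
    """Load a pickled predictions dictionary (hypothetical format)."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical layout: {"train": {...}, "test": {"c_pred": [...], "c_true": [...]}}
preds = {"test": {"c_pred": [[0.9, 0.2], [0.4, 0.7]], "c_true": [[1, 0], [1, 1]]}}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "predictions.pkl"
    save_predictions(preds, path)
    split = load_predictions(path)["test"]
    acc = concept_accuracy(split["c_pred"], split["c_true"])
print(acc)  # 0.75
```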

Step 2: Training the Realignment models

Next, we train the model that updates concept predictions after an intervention. To do so, specify the relevant paths in run_optuna.sh, then make the script executable:

$ chmod +x run_optuna.sh

Then, run it as:

$ ./run_optuna.sh <model_name>

where <model_name> could be CUB_CEM, CUB_IntCEM, ... This script performs hyperparameter optimization using the Optuna package and trains the concept realignment model. The trained model is saved to the checkpoint_save_dir specified in concept-realignment-experiments/configs/<dataset>/base.yml.
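Conceptually, the realignment model learns to map a partially intervened concept vector back to the full ground-truth concept vector. One plausible way to build such training examples, shown here as a hypothetical sketch rather than the repository's actual data pipeline, is to pick a random number of concepts per sample, replace their predictions with ground truth, and record which entries were intervened on. The function names and the uniform choice of the number of interventions are assumptions.

```python
import random
from typing import List, Sequence, Tuple

# (intervened concept vector, intervention mask, ground-truth target)
Example = Tuple[List[float], List[int], List[float]]


def simulate_intervention(
    c_pred: Sequence[float], c_true: Sequence[float], rng: random.Random
) -> Tuple[List[float], List[int]]:
    """Intervene on a random subset of concepts; return (concept vector, mask)."""
    n = len(c_pred)
    k = rng.randint(0, n)  # number of interventions, uniform over 0..n inclusive
    chosen = rng.sample(range(n), k)
    x, mask = list(c_pred), [0] * n
    for i in chosen:
        x[i] = float(c_true[i])
        mask[i] = 1
    return x, mask


def make_training_examples(
    c_preds: Sequence[Sequence[float]],
    c_trues: Sequence[Sequence[float]],
    seed: int = 0,
) -> List[Example]:
    """Build (input, mask, target) triples; the target is the ground-truth concept vector."""
    rng = random.Random(seed)  # seeded for reproducibility
    examples = []
    for p, t in zip(c_preds, c_trues):
        x, mask = simulate_intervention(p, t, rng)
        examples.append((x, mask, [float(v) for v in t]))
    return examples


examples = make_training_examples([[0.2, 0.3, 0.5], [0.9, 0.1, 0.6]], [[1, 1, 0], [1, 0, 1]])
for x, mask, target in examples:
    print(x, mask, target)
```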

Plot Concept Loss and Accuracy (Figs. 3, 4)

In concept_loss_vs_num_interventions.py, specify the path to the predictions from the base model in predictions_dict_path, and the path to the concept realignment model in concept_corrector_path. Then, you can run it as:

$ python3 concept_loss_vs_num_interventions.py
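The concept-loss curve can be sketched as follows: intervene on the first k concepts of some ordering, optionally realign, and average the binary cross-entropy over all concepts, for k from 0 up to the number of concepts. This is a simplified stand-in, not the script itself; the fixed intervention `order`, the `realign` callback signature, and the use of mean BCE as the concept loss are assumptions.

```python
import math
from typing import Callable, List, Optional, Sequence

# Hypothetical callback: (original preds, intervened preds, intervened indices) -> realigned preds
Realigner = Callable[[Sequence[float], List[float], List[int]], List[float]]


def bce(p: float, t: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one concept, with p clipped away from 0 and 1."""
    p = min(max(p, eps), 1.0 - eps)
    return -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))


def concept_loss_curve(
    c_preds: Sequence[Sequence[float]],
    c_trues: Sequence[Sequence[float]],
    order: Sequence[int],
    realign: Optional[Realigner] = None,
) -> List[float]:
    """Mean concept BCE after intervening on order[:k], for k = 0..len(order)."""
    curve = []
    for k in range(len(order) + 1):
        idx = list(order[:k])
        total, count = 0.0, 0
        for c_pred, c_true in zip(c_preds, c_trues):
            x = list(c_pred)
            for i in idx:
                x[i] = float(c_true[i])  # the intervention: copy ground truth
            if realign is not None and idx:
                x = realign(c_pred, x, idx)
            total += sum(bce(p, t) for p, t in zip(x, c_true))
            count += len(c_true)
        curve.append(total / count)
    return curve


curve = concept_loss_curve([[0.2, 0.3, 0.5]], [[1, 1, 0]], order=[0, 1, 2])
print([round(v, 3) for v in curve])
```

Passing a realignment function (for example, a wrapper around a trained module) lets the same loop produce both the with- and without-realignment curves.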

Compare Performance on IntCEMs (Fig. 5)

For post-hoc realignment, the procedure to train the concept realignment model and generate plots is the same as described above. For joint training of the concept realignment model and the IntCEM, use the appropriate config from experiments/configs/intcem_configs/<dataset>.yaml, then generate the plots using intcem_concept_loss_vs_num_interventions.py. In that script, vanilla_predictions_dict_path refers to the predictions of the IntCEM trained without concept realignment, and predictions_dict_path refers to the predictions of the IntCEM trained jointly with concept realignment.
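One way to summarise such a comparison, in the terms used in the abstract, is the number of interventions each model needs before its curve reaches a target value. A minimal sketch, assuming each curve is a list whose k-th entry is a higher-is-better metric (for example, task accuracy) after k interventions; the helper names are hypothetical and not part of intcem_concept_loss_vs_num_interventions.py.

```python
from typing import Optional, Sequence


def interventions_to_reach(curve: Sequence[float], target: float) -> Optional[int]:
    """Smallest k with curve[k] >= target, or None if the target is never reached."""
    for k, value in enumerate(curve):
        if value >= target:
            return k
    return None


def interventions_saved(
    vanilla: Sequence[float], realigned: Sequence[float], target: float
) -> Optional[int]:
    """How many fewer interventions the realigned model needs; None if either curve never reaches the target."""
    k_vanilla = interventions_to_reach(vanilla, target)
    k_realigned = interventions_to_reach(realigned, target)
    if k_vanilla is None or k_realigned is None:
        return None
    return k_vanilla - k_realigned
```

For a lower-is-better curve such as concept loss, negate both the values and the target before calling `interventions_to_reach`.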

Architecture Ablations (Fig. 6)

First, train the concept realignment models with the different architectures by passing the appropriate model_name in Step 2. Then, specify the paths to these models' checkpoints in architecture_ablations.py and run it.


Coupling Ablations (Fig. 8)

$ python3 coupling_ablation.py

Citation

If you would like to cite this repository, or the accompanying paper, please use the following citation:

@article{singhi2024improving,
  title={Improving Intervention Efficacy via Concept Realignment in Concept Bottleneck Models},
  author={Singhi, Nishad and Kim, Jae Myung and Roth, Karsten and Akata, Zeynep},
  journal={European Conference on Computer Vision},
  year={2024}
}