Skip to content

Commit

Permalink
Merge pull request #25 from mmschlk/development
Browse files Browse the repository at this point in the history
Add initial explainer.
  • Loading branch information
mmschlk authored Dec 7, 2023
2 parents eb64dda + be53703 commit 893c7a3
Show file tree
Hide file tree
Showing 30 changed files with 652 additions and 29 deletions.
26 changes: 23 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
</a>

<!-- Coverage Test -->
<a href='https://coveralls.io/github/mmschlk/shapiq'>
<img src='https://coveralls.io/repos/github/mmschlk/shapiq/badge.svg' alt='Coverage Status' />
<a href='https://coveralls.io/github/mmschlk/shapiq?branch=main'>
<img src='https://coveralls.io/repos/github/mmschlk/shapiq/badge.svg?branch=main' alt='Coverage Status' />
</a>

<!-- Read the Docs -->
<a href='https://shapiq.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/shapiq/badge/?version=latest' alt='Documentation Status' />
Expand Down Expand Up @@ -83,3 +83,23 @@ The pseudo-code above can produce the following plot (here also an image is adde

## 📖 Documentation
The documentation for ``shapiq`` can be found [here](https://shapiq.readthedocs.io/en/latest/).

## 💬 Citation

If you **enjoy** `shapiq`, consider starring ⭐ the repository. If you **really enjoy** the package or it has been useful to you, and you would like to cite it in a scientific publication, please refer to the [paper](https://openreview.net/forum?id=IEMLNF4gK4) accepted at NeurIPS'23:

```bibtex
@article{shapiq,
author = {Fabian Fumagalli and
Maximilian Muschalik and
Patrick Kolpaczki and
Eyke H{\"{u}}llermeier and
Barbara Hammer},
title = {{SHAP-IQ:} Unified Approximation of any-order Shapley Interactions},
journal = {CoRR},
volume = {abs/2303.01179},
year = {2023},
doi = {10.48550/ARXIV.2303.01179},
eprinttype = {arXiv}
}
```
Binary file modified requirements.txt
Binary file not shown.
4 changes: 2 additions & 2 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

# explainer classes
from .explainer import Explainer
from .explainer import InteractionExplainer

# game classes
from .games import DummyGame
Expand All @@ -40,7 +40,7 @@
"RegressionSII",
"RegressionFSI",
# explainers
"Explainer",
"InteractionExplainer",
# games
"DummyGame",
# plots
Expand Down
3 changes: 2 additions & 1 deletion shapiq/approximator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from ._base import convert_nsii_into_one_dimension, transforms_sii_to_nsii # TODO add to tests
from .permutation.sii import PermutationSamplingSII
from .permutation.sti import PermutationSamplingSTI
from .regression import RegressionSII, RegressionFSI
from .regression import RegressionSII, RegressionFSI, KernelSHAP
from .shapiq import ShapIQ

__all__ = [
"PermutationSamplingSII",
"PermutationSamplingSTI",
"KernelSHAP",
"RegressionFSI",
"RegressionSII",
"ShapIQ",
Expand Down
14 changes: 8 additions & 6 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scipy.special import binom, bernoulli
from utils import get_explicit_subsets, powerset, split_subsets_budget

AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"}
AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI", "SV"}


__all__ = [
Expand Down Expand Up @@ -50,10 +50,9 @@ class InteractionValues:

def __post_init__(self) -> None:
"""Checks if the index is valid."""
if self.index not in ["SII", "nSII", "STI", "FSI"]:
if self.index not in AVAILABLE_INDICES:
raise ValueError(
f"Index {self.index} is not valid. "
f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'."
f"Index {self.index} is not valid. " f"Available indices are {AVAILABLE_INDICES}."
)
if self.interaction_lookup is None:
self.interaction_lookup = _generate_interaction_lookup(
Expand All @@ -67,10 +66,13 @@ def __repr__(self) -> str:
f" index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
) + " values={\n"
for interaction in powerset(set(range(self.n_players)), min_size=1, max_size=2):
for interaction in powerset(
set(range(self.n_players)), min_size=1, max_size=self.max_order
):
representation += f" {interaction}: "
interaction_value = str(round(self[interaction], 4))
interaction_value = interaction_value.replace("-0.0", "0.0").replace("0.0", "0")
interaction_value = interaction_value.replace("-0.0", "0.0").replace(" 0.0", " 0")
interaction_value = interaction_value.replace("0.0 ", "0 ")
representation += f"{interaction_value},\n"
representation = representation[:-2] # remove last "," and add closing bracket
representation += "\n }\n)"
Expand Down
3 changes: 2 additions & 1 deletion shapiq/approximator/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"""
from .sii import RegressionSII
from .fsi import RegressionFSI
from .sv import KernelSHAP

__all__ = ["RegressionSII", "RegressionFSI"]
__all__ = ["RegressionSII", "RegressionFSI", "KernelSHAP"]
4 changes: 2 additions & 2 deletions shapiq/approximator/regression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from utils import powerset, get_explicit_subsets

AVAILABLE_INDICES_REGRESSION = ["FSI", "SII"]
AVAILABLE_INDICES_REGRESSION = ["FSI", "SII", "SV"]


class Regression(Approximator, ShapleySamplingMixin):
Expand Down Expand Up @@ -120,7 +120,7 @@ def approximate(
# if SII is used regression_subsets needs to be changed
if self.index == "SII":
regression_subsets, num_players = self._get_sii_subset_representation(all_subsets) # A
else:
else: # FSI or SV
regression_subsets, num_players = self._get_fsi_subset_representation(all_subsets) # A

# initialize the regression variables
Expand Down
9 changes: 7 additions & 2 deletions shapiq/approximator/regression/fsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class RegressionFSI(Regression, NShapleyMixin):
"""Estimates the FSI values using the weighted least square approach.
"""Estimates the FSI values [1] using the weighted least square approach.
Args:
n: The number of players.
Expand All @@ -20,11 +20,16 @@ class RegressionFSI(Regression, NShapleyMixin):
min_order: The minimum order of the approximation. For FSI, min_order is equal to 1.
iteration_cost: The cost of a single iteration of the regression FSI.
References:
[1]: Tsai, C.-P., Yeh, C.-K., & Ravikumar, P. (2023). Faith-Shap: The Faithful Shapley
Interaction Index. J. Mach. Learn. Res., 24, 94:1-94:42. Retrieved from
http://jmlr.org/papers/v24/22-0202.html
Example:
>>> from games import DummyGame
>>> from approximator import RegressionFSI
>>> game = DummyGame(n=5, interaction=(1, 2))
>>> approximator = RegressionFsi(n=5, max_order=2)
>>> approximator = RegressionFSI(n=5, max_order=2)
>>> approximator.approximate(budget=100, game=game)
InteractionValues(
index=FSI, order=2, estimated=False, estimation_budget=32,
Expand Down
43 changes: 43 additions & 0 deletions shapiq/approximator/regression/sv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""This module contains the KernelSHAP regression approximator for estimating the SV."""

"""Regression with Faithful Shapley Interaction (FSI) index approximation."""
from typing import Optional

from ._base import Regression


class KernelSHAP(Regression):
    """Estimates the Shapley values (SV) using the weighted least square approach.

    The class configures the `Regression` approximator with ``index="SV"`` and a fixed
    ``max_order=1``; it adds no logic of its own.

    Args:
        n: The number of players.
        random_state: The random state of the estimator. Defaults to `None`.

    Attributes:
        n: The number of players.
        N: The set of players (starting from 0 to n - 1).
        max_order: The interaction order of the approximation. KernelSHAP always passes 1
            to the base class.
        min_order: The minimum order of the approximation.
        iteration_cost: The cost of a single iteration of the regression.

    Note:
        The attributes above are presumably initialized by the `Regression` base class, which
        is not defined in this module -- verify against `_base.py`.

    Example:
        >>> from games import DummyGame
        >>> from approximator import KernelSHAP
        >>> game = DummyGame(n=5, interaction=(1, 2))
        >>> approximator = KernelSHAP(n=5)
        >>> approximator.approximate(budget=100, game=game)
        InteractionValues(
            index=SV, order=1, estimated=False, estimation_budget=32,
            values={
                (0,): 0.2,
                (1,): 0.7,
                (2,): 0.7,
                (3,): 0.2,
                (4,): 0.2,
            }
        )
    """

    def __init__(self, n: int, random_state: Optional[int] = None):
        # Shapley values are order-1 quantities, so the order is fixed and not exposed to callers.
        super().__init__(n, max_order=1, index="SV", random_state=random_state)
5 changes: 3 additions & 2 deletions shapiq/explainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""This module contains the explainer for the shapiq package."""

from ._base import Explainer

from .interaction import InteractionExplainer

__all__ = [
"Explainer",
"InteractionExplainer",
]
31 changes: 23 additions & 8 deletions shapiq/explainer/_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
"""This module contains the base explainer classes for the shapiq package."""
import warnings
from abc import ABC, abstractmethod
from typing import Any
from typing import Callable

import numpy as np

from approximator._base import InteractionValues
from explainer.imputer.marginal_imputer import MarginalImputer


class Explainer(ABC):
"""The base class for all explainers in the shapiq package."""
"""The base class for all explainers in the shapiq package.
Args:
model: The model to explain as a callable function expecting data points as input and
returning the model's predictions.
background_data: The background data to use for the explainer.
"""

@abstractmethod
def __init__(self) -> None:
"""Initializes the explainer."""
warnings.warn("Explainer is not implemented yet.")
def __init__(
self, model: Callable[[np.ndarray], np.ndarray], background_data: np.ndarray
) -> None:
self._model = model
self._background_data = background_data
self._n_features = self._background_data.shape[1]
self._imputer = MarginalImputer(self._model, self._background_data)

@abstractmethod
def explain(self) -> Any:
warnings.warn("Explainer is not implemented yet.")
def explain(self, x_explain: np.ndarray) -> InteractionValues:
"""Explains the model's predictions."""
raise NotImplementedError("Method `explain` must be implemented in a subclass.")
5 changes: 5 additions & 0 deletions shapiq/explainer/imputer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""This module contains the imputer for the shapiq package."""

from .marginal_imputer import MarginalImputer

__all__ = ["MarginalImputer"]
47 changes: 47 additions & 0 deletions shapiq/explainer/imputer/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Base class for imputers."""
from abc import abstractmethod
from typing import Callable, Optional

import numpy as np


class Imputer:
    """Base class for imputers.

    Stores the model, the background data and a seeded random number generator that
    concrete imputers use to impute missing feature values before calling the model.

    Args:
        model: The model to explain as a callable function expecting data points as input and
            returning the model's predictions.
        background_data: The background data to use for the explainer as a two-dimensional array
            with shape (n_samples, n_features).
        categorical_features: A list of indices of the categorical features in the background
            data. Defaults to `None`, which is stored as an empty list.
        random_state: The random state to use for sampling. Defaults to `None`.

    Note:
        `abstractmethod` is only enforced on classes whose metaclass is `ABCMeta`. This class
        does not derive from `ABC`, so the decorators below document intent but do not prevent
        direct instantiation.
    """

    @abstractmethod
    def __init__(
        self,
        model: Callable[[np.ndarray], np.ndarray],
        background_data: np.ndarray,
        categorical_features: Optional[list[int]] = None,
        random_state: Optional[int] = None,
    ) -> None:
        self._model = model
        self._background_data = background_data
        # The second axis of the background data holds the features.
        self._n_features = self._background_data.shape[1]
        # Normalize `None` to an empty list so subclasses can always iterate over it.
        self._cat_features: list = [] if categorical_features is None else categorical_features
        self._random_state = random_state
        # A dedicated generator keeps sampling reproducible for a given `random_state`.
        self._rng = np.random.default_rng(self._random_state)

    @abstractmethod
    def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
        """Imputes the missing values of a data point and calls the model.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features).

        Returns:
            The model's predictions on the imputed data points. The shape of the array is
            (n_subsets, n_outputs).

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override the method.
        """
        raise NotImplementedError("Method `__call__` must be implemented in a subclass.")
Loading

0 comments on commit 893c7a3

Please sign in to comment.