Skip to content

Commit

Permalink
adds InteractionExplainer and closes #24
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Dec 4, 2023
1 parent 4d64a17 commit f92ea50
Show file tree
Hide file tree
Showing 9 changed files with 492 additions and 13 deletions.
4 changes: 2 additions & 2 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

# explainer classes
from .explainer import Explainer
from .explainer import InteractionExplainer

# game classes
from .games import DummyGame
Expand All @@ -40,7 +40,7 @@
"RegressionSII",
"RegressionFSI",
# explainers
"Explainer",
"InteractionExplainer",
# games
"DummyGame",
# plots
Expand Down
3 changes: 2 additions & 1 deletion shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def __repr__(self) -> str:
):
representation += f" {interaction}: "
interaction_value = str(round(self[interaction], 4))
interaction_value = interaction_value.replace("-0.0", "0.0").replace("0.0", "0")
interaction_value = interaction_value.replace("-0.0", "0.0").replace(" 0.0", " 0")
interaction_value = interaction_value.replace("0.0 ", "0 ")
representation += f"{interaction_value},\n"
representation = representation[:-2] # remove last "," and add closing bracket
representation += "\n }\n)"
Expand Down
5 changes: 3 additions & 2 deletions shapiq/explainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""This module contains the explainer for the shapiq package."""

from ._base import Explainer

from .interaction import InteractionExplainer

# Names exported by the explainer subpackage (what `from ... import *` exposes).
__all__ = [
"Explainer",
"InteractionExplainer",
]
31 changes: 23 additions & 8 deletions shapiq/explainer/_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
"""This module contains the base explainer classes for the shapiq package."""
import warnings
from abc import ABC, abstractmethod
from typing import Any
from typing import Callable

import numpy as np

from approximator._base import InteractionValues
from explainer.imputer.marginal_imputer import MarginalImputer


# NOTE(review): this span is a rendered diff hunk, not a runnable class body. Lines of the old and
# the new version of `Explainer` appear interleaved, without +/- markers and without indentation
# (see the two class docstrings and the duplicated `def __init__` / `def explain` lines below).
class Explainer(ABC):
"""The base class for all explainers in the shapiq package."""
"""The base class for all explainers in the shapiq package.
Args:
model: The model to explain as a callable function expecting a data points as input and
returning the model's predictions.
background_data: The background data to use for the explainer.
"""

@abstractmethod
# NOTE(review): the zero-argument `__init__` that only warns is presumably the removed side of the
# diff; the `__init__(model, background_data)` that follows is presumably its replacement.
def __init__(self) -> None:
"""Initializes the explainer."""
warnings.warn("Explainer is not implemented yet.")
def __init__(
self, model: Callable[[np.ndarray], np.ndarray], background_data: np.ndarray
) -> None:
self._model = model
self._background_data = background_data
# the feature count is read from the second axis, so background_data must be two-dimensional
self._n_features = self._background_data.shape[1]
# imputer built from the same model and background data, using MarginalImputer's defaults
self._imputer = MarginalImputer(self._model, self._background_data)

@abstractmethod
# NOTE(review): same pattern as above — `explain(self) -> Any` with the warning is presumably the
# removed side, `explain(self, x_explain)` raising NotImplementedError the added side.
def explain(self) -> Any:
warnings.warn("Explainer is not implemented yet.")
def explain(self, x_explain: np.ndarray) -> InteractionValues:
"""Explains the model's predictions."""
raise NotImplementedError("Method `explain` must be implemented in a subclass.")
5 changes: 5 additions & 0 deletions shapiq/explainer/imputer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""This module contains the imputer for the shapiq package."""

from .marginal_imputer import MarginalImputer

# Names exported by the imputer subpackage (what `from ... import *` exposes).
__all__ = ["MarginalImputer"]
47 changes: 47 additions & 0 deletions shapiq/explainer/imputer/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Base class for imputers."""
from abc import ABC, abstractmethod
from typing import Callable, Optional

import numpy as np


class Imputer(ABC):
    """Base class for imputers.

    An imputer keeps a reference to the model and the background data. Subclasses implement
    `__call__`, which receives a batch of feature subsets and returns the model's predictions
    for the correspondingly imputed data points.

    The class derives from `ABC` so that the `@abstractmethod` marker on `__call__` is actually
    enforced: instantiating `Imputer` itself (or a subclass that does not implement `__call__`)
    raises `TypeError`. `__init__` is intentionally *not* abstract, because it is the shared
    initializer that subclasses reuse through `super().__init__(...)`.

    Args:
        model: The model to explain as a callable function expecting data points as input and
            returning the model's predictions.
        background_data: The background data to use for the explainer as a two-dimensional array
            with shape (n_samples, n_features).
        categorical_features: A list of indices of the categorical features in the background
            data. Defaults to `None`, which is stored as an empty list.
        random_state: The random state to use for sampling. Defaults to `None`.
    """

    def __init__(
        self,
        model: Callable[[np.ndarray], np.ndarray],
        background_data: np.ndarray,
        categorical_features: Optional[list[int]] = None,
        random_state: Optional[int] = None,
    ) -> None:
        self._model = model
        self._background_data = background_data
        # the number of features is read from the second axis of the background data
        self._n_features = self._background_data.shape[1]
        self._cat_features: list = [] if categorical_features is None else categorical_features
        self._random_state = random_state
        # one generator per imputer, seeded with `random_state`
        self._rng = np.random.default_rng(self._random_state)

    @abstractmethod
    def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
        """Imputes the missing values of a data point and calls the model.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features).

        Returns:
            The model's predictions on the imputed data points. The shape of the array is
            (n_subsets, n_outputs).

        Raises:
            NotImplementedError: Always, when reached through `super().__call__(...)`.
        """
        raise NotImplementedError("Method `__call__` must be implemented in a subclass.")
159 changes: 159 additions & 0 deletions shapiq/explainer/imputer/marginal_imputer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""This module contains the marginal imputer for the shapiq package."""
from typing import Callable, Optional

import numpy as np

from explainer.imputer._base import Imputer


class MarginalImputer(Imputer):
    """The marginal imputer for the shapiq package.

    The marginal imputer is used to impute the missing values of a data point by using the
    marginal distribution of the background data.

    Args:
        model: The model to explain as a callable function expecting data points as input and
            returning the model's predictions.
        background_data: The background data to use for the explainer as a two-dimensional array
            with shape (n_samples, n_features).
        x_explain: The data point to explain. If given, the imputer is fitted to it right away,
            otherwise `fit` has to be called before the imputer is used. Defaults to `None`.
        sample_replacements: Whether to sample replacements from the background data or to use the
            mean (for numerical features) or the mode (for categorical and string features) of the
            background data. Defaults to `True`.
        sample_size: The number of samples to draw from the background data. Only used if
            `sample_replacements` is `True`. Increasing this value will linearly increase the
            runtime of the explainer. Defaults to `5`.
        categorical_features: A list of indices of the categorical features in the background data.
            If no categorical features are given, all features are assumed to be numerical or in
            string format (where `np.mean` fails) features. Defaults to `None`.
        random_state: The random state to use for sampling. Defaults to `None`.

    Attributes:
        replacement_data: The data to use for imputation. Either the background data itself (when
            sampling) or a single row holding the mean/mode of every background feature.
        empty_prediction: The model's prediction on an empty data point (all features missing).
    """

    def __init__(
        self,
        model: Callable[[np.ndarray], np.ndarray],
        background_data: np.ndarray,
        x_explain: Optional[np.ndarray] = None,
        sample_replacements: bool = True,  # TODO: change to False
        sample_size: int = 5,
        categorical_features: Optional[list[int]] = None,
        random_state: Optional[int] = None,
    ) -> None:
        super().__init__(model, background_data, categorical_features, random_state)
        self._sample_replacements = sample_replacements
        self._sample_size: int = sample_size
        self.replacement_data: np.ndarray = np.zeros((1, self._n_features))  # will be overwritten
        self.init_background(self._background_data)
        self._x_explain: np.ndarray = np.zeros((1, self._n_features))  # will be overwritten @ fit
        if x_explain is not None:
            self.fit(x_explain)
        # computed last: it needs `replacement_data` and the random generator to be set up
        self.empty_prediction: float = self._calc_empty_prediction()

    def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
        """Imputes the missing values of a data point and calls the model.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features).

        Returns:
            The model's predictions on the imputed data points minus the empty prediction. When
            replacements are sampled, the predictions are averaged over the samples and the
            result has shape (n_subsets,).
        """
        n_subsets = subsets.shape[0]
        data = np.tile(np.copy(self._x_explain), (n_subsets, 1))
        if not self._sample_replacements:
            replacement_data = np.tile(self.replacement_data, (n_subsets, 1))
            data[~subsets] = replacement_data[~subsets]
            outputs = self._model(data)
        else:
            # sampling from background returning array of shape (sample_size, n_subsets, n_features)
            replacement_data = self._sample_replacement_values(subsets)
            outputs = np.zeros((self._sample_size, n_subsets))
            for i in range(self._sample_size):
                replacements = replacement_data[i].reshape(n_subsets, self._n_features)
                # only the missing entries are overwritten, so `data` can be reused across samples
                data[~subsets] = replacements[~subsets]
                outputs[i] = self._model(data)
            outputs = np.mean(outputs, axis=0)  # average over the samples
        # out-of-place subtraction: works for integer-valued predictions as well and does not
        # modify the array object returned by the model
        return np.asarray(outputs) - self.empty_prediction

    def init_background(self, x_background: np.ndarray) -> "MarginalImputer":
        """Initializes the imputer to the background data.

        When replacements are sampled, the background data is stored as is. Otherwise every
        feature is summarized into a single value: the mean for numerical features and the mode
        for categorical features and for features where `np.mean` fails (e.g. strings).

        Args:
            x_background: The background data to use for the imputer. The shape of the array must
                be (n_samples, n_features).

        Returns:
            The initialized imputer.
        """
        if self._sample_replacements:
            self.replacement_data = x_background
            return self
        # numeric backgrounds keep the float buffer; anything else (e.g. strings) needs an object
        # buffer so that non-numeric summaries can be stored
        is_numeric = x_background.dtype.kind in "biuf"
        summary = np.zeros((1, self._n_features), dtype=float if is_numeric else object)
        for feature in range(self._n_features):
            feature_column = x_background[:, feature]
            if feature in self._cat_features:
                summarized_feature = self._mode(feature_column)
            else:
                try:  # try to use mean for numerical features
                    summarized_feature = np.mean(feature_column)
                except TypeError:  # fallback to mode for string features
                    summarized_feature = self._mode(feature_column)
            summary[:, feature] = summarized_feature
        self.replacement_data = summary
        return self

    @staticmethod
    def _mode(values: np.ndarray):
        """Returns the most frequent value of `values` (the smallest one in case of ties)."""
        unique_values, counts = np.unique(values, return_counts=True)
        return unique_values[np.argmax(counts)]

    def fit(self, x_explain: np.ndarray[float]) -> "MarginalImputer":
        """Fits the imputer to the explanation point.

        Args:
            x_explain: The explanation point to use the imputer to.

        Returns:
            The fitted imputer.
        """
        self._x_explain = x_explain
        return self

    def _sample_replacement_values(self, subsets: np.ndarray[bool]) -> np.ndarray:
        """Samples replacement values from the background data.

        Every feature is sampled independently (with replacement) from its own column.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features). Only
                the number of subsets is used.

        Returns:
            The sampled replacement values. The shape of the array is (sample_size, n_subsets,
            n_features).
        """
        n_subsets = subsets.shape[0]
        # same dtype as the data that is sampled from, so non-float features are not coerced
        replacement_data = np.zeros(
            (self._sample_size, n_subsets, self._n_features), dtype=self.replacement_data.dtype
        )
        for feature in range(self._n_features):
            sampled_feature_values = self._rng.choice(
                self.replacement_data[:, feature], size=(self._sample_size, n_subsets), replace=True
            )
            replacement_data[:, :, feature] = sampled_feature_values
        return replacement_data

    def _calc_empty_prediction(self) -> float:
        """Runs the model on empty data points (all features missing) to get the empty prediction.

        Returns:
            The empty prediction. When sampling, this is the mean prediction over the (row-permuted)
            background data, otherwise the prediction on the summarized replacement row.
        """
        if self._sample_replacements:
            shuffled_background = self._rng.permutation(self._background_data)
            empty_predictions = self._model(shuffled_background)
            return float(np.mean(empty_predictions))
        empty_prediction = self._model(self.replacement_data)
        try:  # reshape to scalar if the model is weird
            return float(empty_prediction)
        except TypeError:
            return float(empty_prediction[0])
Loading

0 comments on commit f92ea50

Please sign in to comment.