-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adds InteractionExplainer and closes #24
- Loading branch information
Showing
9 changed files
with
492 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
"""This module contains the explainer for the shapiq package.""" | ||
|
||
from ._base import Explainer | ||
|
||
from .interaction import InteractionExplainer | ||
|
||
__all__ = [ | ||
"Explainer", | ||
"InteractionExplainer", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,32 @@ | ||
"""This module contains the base explainer classes for the shapiq package.""" | ||
import warnings | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
from typing import Callable | ||
|
||
import numpy as np | ||
|
||
from approximator._base import InteractionValues | ||
from explainer.imputer.marginal_imputer import MarginalImputer | ||
|
||
|
||
class Explainer(ABC):
    """The base class for all explainers in the shapiq package.

    Stores the model and the background data and builds a `MarginalImputer` from them.
    Subclasses must implement `explain`.

    Args:
        model: The model to explain as a callable function expecting data points as input and
            returning the model's predictions.
        background_data: The background data to use for the explainer. Its second axis is read
            as the number of features.
    """

    @abstractmethod
    def __init__(
        self, model: Callable[[np.ndarray], np.ndarray], background_data: np.ndarray
    ) -> None:
        self._model = model
        self._background_data = background_data
        # the feature count is taken from the column axis of the background data
        self._n_features = self._background_data.shape[1]
        self._imputer = MarginalImputer(self._model, self._background_data)

    @abstractmethod
    def explain(self, x_explain: np.ndarray) -> InteractionValues:
        """Explains the model's predictions.

        Args:
            x_explain: The data point to explain.

        Returns:
            The interaction values computed for `x_explain`.

        Raises:
            NotImplementedError: Always, unless overridden in a subclass.
        """
        raise NotImplementedError("Method `explain` must be implemented in a subclass.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""This module contains the imputer for the shapiq package.""" | ||
|
||
from .marginal_imputer import MarginalImputer | ||
|
||
__all__ = ["MarginalImputer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Base class for imputers.""" | ||
from abc import abstractmethod | ||
from typing import Callable, Optional | ||
|
||
import numpy as np | ||
|
||
|
||
class Imputer:
    """Base class for imputers.

    An imputer keeps the model, the background data and a seeded random generator. Subclasses
    override `__call__` to impute missing features and evaluate the model.

    Args:
        model: The model to explain as a callable function expecting data points as input and
            returning the model's predictions.
        background_data: The background data to use for the explainer as a two-dimensional array
            with shape (n_samples, n_features).
        categorical_features: A list of indices of the categorical features in the background
            data. Defaults to `None`, which is treated as "no categorical features".
        random_state: The random state to use for sampling. Defaults to `None`.
    """

    # NOTE: this class does not use an ABC metaclass, so `abstractmethod` only marks intent here
    # and does not prevent instantiation.
    @abstractmethod
    def __init__(
        self,
        model: Callable[[np.ndarray], np.ndarray],
        background_data: np.ndarray,
        categorical_features: Optional[list[int]] = None,
        random_state: Optional[int] = None,
    ) -> None:
        self._model = model
        self._background_data = background_data
        self._n_features = self._background_data.shape[1]
        # copy the indices so that later changes to the caller's list cannot leak into the imputer
        self._cat_features: list = (
            [] if categorical_features is None else list(categorical_features)
        )
        self._random_state = random_state
        self._rng = np.random.default_rng(self._random_state)

    @abstractmethod
    def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
        """Imputes the missing values of a data point and calls the model.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features).

        Returns:
            The model's predictions on the imputed data points. The shape of the array is
            (n_subsets, n_outputs).

        Raises:
            NotImplementedError: Always, unless overridden in a subclass.
        """
        raise NotImplementedError("Method `__call__` must be implemented in a subclass.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
"""This module contains the marginal imputer for the shapiq package.""" | ||
from typing import Callable, Optional | ||
|
||
import numpy as np | ||
|
||
from explainer.imputer._base import Imputer | ||
|
||
|
||
class MarginalImputer(Imputer):
    """The marginal imputer for the shapiq package.

    The marginal imputer is used to impute the missing values of a data point by using the
    marginal distribution of the background data.

    Args:
        model: The model to explain as a callable function expecting data points as input and
            returning the model's predictions.
        background_data: The background data to use for the explainer as a two-dimensional array
            with shape (n_samples, n_features).
        x_explain: The data point to explain. If given, the imputer is fitted to it right away.
            Defaults to `None`.
        sample_replacements: Whether to sample replacements from the background data or to use
            the mean (for numerical features) or the most frequent value (for categorical and
            non-numerical features) of the background data. Defaults to `True`.
        sample_size: The number of samples to draw from the background data. Only used if
            `sample_replacements` is `True`. Increasing this value will linearly increase the
            runtime of the explainer. Defaults to `5`.
        categorical_features: A list of indices of the categorical features in the background
            data. If no categorical features are given, a feature is treated as numerical
            unless `np.mean` fails on it (e.g. string features). Defaults to `None`.
        random_state: The random state to use for sampling. Defaults to `None`.

    Attributes:
        replacement_data: The data to use for imputation. Either the background data itself
            (when sampling) or one row holding the per-feature summary of the background data.
        empty_prediction: The model's prediction on an empty data point (all features missing).
    """

    def __init__(
        self,
        model: Callable[[np.ndarray], np.ndarray],
        background_data: np.ndarray,
        x_explain: Optional[np.ndarray] = None,
        sample_replacements: bool = True,  # TODO: change to False
        sample_size: int = 5,
        categorical_features: Optional[list[int]] = None,
        random_state: Optional[int] = None,
    ) -> None:
        super().__init__(model, background_data, categorical_features, random_state)
        self._sample_replacements = sample_replacements
        self._sample_size: int = sample_size
        self.replacement_data: np.ndarray = np.zeros((1, self._n_features))  # will be overwritten
        self.init_background(self._background_data)
        self._x_explain: np.ndarray = np.zeros((1, self._n_features))  # will be overwritten @ fit
        if x_explain is not None:
            self.fit(x_explain)
        self.empty_prediction: float = self._calc_empty_prediction()

    def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
        """Imputes the missing values of a data point and calls the model.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features).

        Returns:
            The model's predictions on the imputed data points minus the empty prediction, one
            value per subset.
        """
        # coerce to bool: `~` on an integer mask would be a bitwise NOT, not a logical one
        subsets = np.asarray(subsets, dtype=bool)
        missing = ~subsets
        n_subsets = subsets.shape[0]
        data = np.tile(np.copy(self._x_explain), (n_subsets, 1))
        if not self._sample_replacements:
            replacement_data = np.tile(self.replacement_data, (n_subsets, 1))
            data[missing] = replacement_data[missing]
            outputs = np.asarray(self._model(data))
        else:
            # sampling from background returning array of shape (sample_size, n_subsets, n_features)
            replacement_data = self._sample_replacement_values(subsets)
            outputs = np.zeros((self._sample_size, n_subsets))
            for i in range(self._sample_size):
                # present features keep the values of x_explain, only missing ones are replaced
                data[missing] = replacement_data[i][missing]
                outputs[i] = self._model(data)
            outputs = np.mean(outputs, axis=0)  # average over the samples
        # out-of-place subtraction: an in-place `-=` fails when the model returns integers
        return outputs - self.empty_prediction

    def init_background(self, x_background: np.ndarray) -> "MarginalImputer":
        """Initializes the imputer to the background data.

        When sampling is enabled the background data is stored as is. Otherwise every feature is
        summarized into a single value: the mean for numerical features and the most frequent
        value for categorical features and for features on which `np.mean` fails.

        Args:
            x_background: The background data to use for the imputer. The shape of the array must
                be (n_samples, n_features).

        Returns:
            The initialized imputer.
        """
        if self._sample_replacements:
            self.replacement_data = x_background
            return self

        summaries = []
        for feature in range(self._n_features):
            feature_column = x_background[:, feature]
            if feature in self._cat_features:
                # a median of category codes may not be an existing category, the mode always is
                summaries.append(self._most_frequent(feature_column))
                continue
            try:  # try to use mean for numerical features
                summaries.append(np.mean(feature_column))
            except TypeError:  # fallback to the mode for string features
                summaries.append(self._most_frequent(feature_column))

        # keep a float array whenever possible, fall back to object for non-numerical values
        all_numeric = all(np.issubdtype(np.asarray(s).dtype, np.number) for s in summaries)
        self.replacement_data = np.empty(
            (1, self._n_features), dtype=float if all_numeric else object
        )
        for feature, summarized_feature in enumerate(summaries):
            self.replacement_data[0, feature] = summarized_feature
        return self

    @staticmethod
    def _most_frequent(feature_column: np.ndarray):
        """Returns the most frequent value of a feature column (ties go to the smallest value)."""
        values, counts = np.unique(feature_column, return_counts=True)
        return values[np.argmax(counts)]

    def fit(self, x_explain: np.ndarray[float]) -> "MarginalImputer":
        """Fits the imputer to the explanation point.

        Args:
            x_explain: The explanation point to use the imputer on.

        Returns:
            The fitted imputer.
        """
        self._x_explain = x_explain
        return self

    def _sample_replacement_values(self, subsets: np.ndarray[bool]) -> np.ndarray:
        """Samples replacement values from the background data.

        Every feature is sampled independently (with replacement) from its own column.

        Args:
            subsets: A boolean array indicating which features are present (`True`) and which are
                missing (`False`). The shape of the array must be (n_subsets, n_features). Only
                the number of subsets is read from it.

        Returns:
            The sampled replacement values. The shape of the array is (sample_size, n_subsets,
            n_features).
        """
        n_subsets = subsets.shape[0]
        # reuse the dtype of the source data so that non-numerical features can be stored
        replacement_data = np.empty(
            (self._sample_size, n_subsets, self._n_features), dtype=self.replacement_data.dtype
        )
        for feature in range(self._n_features):
            replacement_data[:, :, feature] = self._rng.choice(
                self.replacement_data[:, feature], size=(self._sample_size, n_subsets), replace=True
            )
        return replacement_data

    def _calc_empty_prediction(self) -> float:
        """Runs the model on empty data points (all features missing) to get the empty prediction.

        Returns:
            The empty prediction: the mean prediction over the background data when sampling is
            enabled, otherwise the (first) prediction on the summarized replacement row.
        """
        if self._sample_replacements:
            # the permutation is kept so the random stream is consumed as before
            shuffled_background = self._rng.permutation(self._background_data)
            empty_predictions = self._model(shuffled_background)
            return float(np.mean(empty_predictions))
        empty_prediction = self._model(self.replacement_data)
        # take the first element explicitly: `float()` on arrays with ndim > 0 is deprecated and
        # fails for outputs that are not of size one
        return float(np.asarray(empty_prediction).ravel()[0])
Oops, something went wrong.