-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from mmschlk/development
Add initial explainer.
- Loading branch information
Showing
30 changed files
with
652 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""This module contains the KernelSHAP regression approximator for estimating the SV.""" | ||
|
||
"""Regression with Faithful Shapley Interaction (FSI) index approximation.""" | ||
from typing import Optional | ||
|
||
from ._base import Regression | ||
|
||
|
||
class KernelSHAP(Regression): | ||
"""Estimates the FSI values using the weighted least square approach. | ||
Args: | ||
n: The number of players. | ||
max_order: The interaction order of the approximation. | ||
random_state: The random state of the estimator. Defaults to `None`. | ||
Attributes: | ||
n: The number of players. | ||
N: The set of players (starting from 0 to n - 1). | ||
max_order: The interaction order of the approximation. | ||
min_order: The minimum order of the approximation. For FSI, min_order is equal to 1. | ||
iteration_cost: The cost of a single iteration of the regression FSI. | ||
Example: | ||
>>> from games import DummyGame | ||
>>> from approximator import KernelSHAP | ||
>>> game = DummyGame(n=5, interaction=(1, 2)) | ||
>>> approximator = KernelSHAP(n=5) | ||
>>> approximator.approximate(budget=100, game=game) | ||
InteractionValues( | ||
index=SV, order=1, estimated=False, estimation_budget=32, | ||
values={ | ||
(0,): 0.2, | ||
(1,): 0.7, | ||
(2,): 0.7, | ||
(3,): 0.2, | ||
(4,): 0.2, | ||
} | ||
) | ||
""" | ||
|
||
def __init__(self, n: int, random_state: Optional[int] = None): | ||
super().__init__(n, max_order=1, index="SV", random_state=random_state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
"""This module contains the explainer for the shapiq package.""" | ||
|
||
from ._base import Explainer | ||
|
||
from .interaction import InteractionExplainer | ||
|
||
__all__ = [ | ||
"Explainer", | ||
"InteractionExplainer", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,32 @@ | ||
"""This module contains the base explainer classes for the shapiq package.""" | ||
import warnings | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
from typing import Callable | ||
|
||
import numpy as np | ||
|
||
from approximator._base import InteractionValues | ||
from explainer.imputer.marginal_imputer import MarginalImputer | ||
|
||
|
||
class Explainer(ABC): | ||
"""The base class for all explainers in the shapiq package.""" | ||
"""The base class for all explainers in the shapiq package. | ||
Args: | ||
model: The model to explain as a callable function expecting a data points as input and | ||
returning the model's predictions. | ||
background_data: The background data to use for the explainer. | ||
""" | ||
|
||
@abstractmethod | ||
def __init__(self) -> None: | ||
"""Initializes the explainer.""" | ||
warnings.warn("Explainer is not implemented yet.") | ||
def __init__( | ||
self, model: Callable[[np.ndarray], np.ndarray], background_data: np.ndarray | ||
) -> None: | ||
self._model = model | ||
self._background_data = background_data | ||
self._n_features = self._background_data.shape[1] | ||
self._imputer = MarginalImputer(self._model, self._background_data) | ||
|
||
@abstractmethod | ||
def explain(self) -> Any: | ||
warnings.warn("Explainer is not implemented yet.") | ||
def explain(self, x_explain: np.ndarray) -> InteractionValues: | ||
"""Explains the model's predictions.""" | ||
raise NotImplementedError("Method `explain` must be implemented in a subclass.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""This module contains the imputer for the shapiq package.""" | ||
|
||
from .marginal_imputer import MarginalImputer | ||
|
||
__all__ = ["MarginalImputer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Base class for imputers.""" | ||
from abc import abstractmethod | ||
from typing import Callable, Optional | ||
|
||
import numpy as np | ||
|
||
|
||
class Imputer: | ||
"""Base class for imputers. | ||
Args: | ||
model: The model to explain as a callable function expecting a data points as input and | ||
returning the model's predictions. | ||
background_data: The background data to use for the explainer as a two-dimensional array | ||
with shape (n_samples, n_features). | ||
categorical_features: A list of indices of the categorical features in the background data. | ||
random_state: The random state to use for sampling. Defaults to `None`. | ||
""" | ||
|
||
@abstractmethod | ||
def __init__( | ||
self, | ||
model: Callable[[np.ndarray], np.ndarray], | ||
background_data: np.ndarray, | ||
categorical_features: list[int] = None, | ||
random_state: Optional[int] = None, | ||
) -> None: | ||
self._model = model | ||
self._background_data = background_data | ||
self._n_features = self._background_data.shape[1] | ||
self._cat_features: list = [] if categorical_features is None else categorical_features | ||
self._random_state = random_state | ||
self._rng = np.random.default_rng(self._random_state) | ||
|
||
@abstractmethod | ||
def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]: | ||
"""Imputes the missing values of a data point and calls the model. | ||
Args: | ||
subsets: A boolean array indicating which features are present (`True`) and which are | ||
missing (`False`). The shape of the array must be (n_subsets, n_features). | ||
Returns: | ||
The model's predictions on the imputed data points. The shape of the array is | ||
(n_subsets, n_outputs). | ||
""" | ||
raise NotImplementedError("Method `__call__` must be implemented in a subclass.") |
Oops, something went wrong.