diff --git a/README.md b/README.md
index dc3b4537..830c1c4e 100644
--- a/README.md
+++ b/README.md
@@ -83,3 +83,23 @@ The pseudo-code above can produce the following plot (here also an image is added)
## 📖 Documentation
The documentation for ``shapiq`` can be found [here](https://shapiq.readthedocs.io/en/latest/).
+
+## 💬 Citation
+
+If you **enjoy** `shapiq`, consider starring ⭐ the repository. If you **really enjoy** the package, or it has been useful to you and you would like to cite it in a scientific publication, please refer to the [paper](https://openreview.net/forum?id=IEMLNF4gK4) accepted at NeurIPS'23:
+
+```bibtex
+@article{shapiq,
+ author = {Fabian Fumagalli and
+ Maximilian Muschalik and
+ Patrick Kolpaczki and
+ Eyke H{\"{u}}llermeier and
+ Barbara Hammer},
+ title = {{SHAP-IQ:} Unified Approximation of any-order Shapley Interactions},
+ journal = {CoRR},
+ volume = {abs/2303.01179},
+ year = {2023},
+ doi = {10.48550/ARXIV.2303.01179},
+ eprinttype = {arXiv}
+}
+```
diff --git a/requirements.txt b/requirements.txt
index 174669c9..f900214e 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/shapiq/__init__.py b/shapiq/__init__.py
index 2369cf86..94a456a7 100644
--- a/shapiq/__init__.py
+++ b/shapiq/__init__.py
@@ -13,7 +13,7 @@
)
# explainer classes
-from .explainer import Explainer
+from .explainer import InteractionExplainer
# game classes
from .games import DummyGame
@@ -40,7 +40,7 @@
"RegressionSII",
"RegressionFSI",
# explainers
- "Explainer",
+ "InteractionExplainer",
# games
"DummyGame",
# plots
diff --git a/shapiq/approximator/__init__.py b/shapiq/approximator/__init__.py
index dbc73d19..f05bbeec 100644
--- a/shapiq/approximator/__init__.py
+++ b/shapiq/approximator/__init__.py
@@ -2,12 +2,13 @@
from ._base import convert_nsii_into_one_dimension, transforms_sii_to_nsii # TODO add to tests
from .permutation.sii import PermutationSamplingSII
from .permutation.sti import PermutationSamplingSTI
-from .regression import RegressionSII, RegressionFSI
+from .regression import RegressionSII, RegressionFSI, KernelSHAP
from .shapiq import ShapIQ
__all__ = [
"PermutationSamplingSII",
"PermutationSamplingSTI",
+ "KernelSHAP",
"RegressionFSI",
"RegressionSII",
"ShapIQ",
diff --git a/shapiq/approximator/_base.py b/shapiq/approximator/_base.py
index 66d1559d..18bd0130 100644
--- a/shapiq/approximator/_base.py
+++ b/shapiq/approximator/_base.py
@@ -8,7 +8,7 @@
from scipy.special import binom, bernoulli
from utils import get_explicit_subsets, powerset, split_subsets_budget
-AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"}
+AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI", "SV"}
__all__ = [
@@ -50,10 +50,9 @@ class InteractionValues:
def __post_init__(self) -> None:
"""Checks if the index is valid."""
- if self.index not in ["SII", "nSII", "STI", "FSI"]:
+ if self.index not in AVAILABLE_INDICES:
raise ValueError(
- f"Index {self.index} is not valid. "
- f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'."
+ f"Index {self.index} is not valid. " f"Available indices are {AVAILABLE_INDICES}."
)
if self.interaction_lookup is None:
self.interaction_lookup = _generate_interaction_lookup(
@@ -67,10 +66,13 @@ def __repr__(self) -> str:
f" index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
) + " values={\n"
- for interaction in powerset(set(range(self.n_players)), min_size=1, max_size=2):
+ for interaction in powerset(
+ set(range(self.n_players)), min_size=1, max_size=self.max_order
+ ):
representation += f" {interaction}: "
interaction_value = str(round(self[interaction], 4))
- interaction_value = interaction_value.replace("-0.0", "0.0").replace("0.0", "0")
+ interaction_value = interaction_value.replace("-0.0", "0.0").replace(" 0.0", " 0")
+ interaction_value = interaction_value.replace("0.0 ", "0 ")
representation += f"{interaction_value},\n"
representation = representation[:-2] # remove last "," and add closing bracket
representation += "\n }\n)"
diff --git a/shapiq/approximator/regression/__init__.py b/shapiq/approximator/regression/__init__.py
index 203525b2..0693191b 100644
--- a/shapiq/approximator/regression/__init__.py
+++ b/shapiq/approximator/regression/__init__.py
@@ -2,5 +2,6 @@
"""
from .sii import RegressionSII
from .fsi import RegressionFSI
+from .sv import KernelSHAP
-__all__ = ["RegressionSII", "RegressionFSI"]
+__all__ = ["RegressionSII", "RegressionFSI", "KernelSHAP"]
diff --git a/shapiq/approximator/regression/_base.py b/shapiq/approximator/regression/_base.py
index 838ffe0d..9623145a 100644
--- a/shapiq/approximator/regression/_base.py
+++ b/shapiq/approximator/regression/_base.py
@@ -7,7 +7,7 @@
from utils import powerset, get_explicit_subsets
-AVAILABLE_INDICES_REGRESSION = ["FSI", "SII"]
+AVAILABLE_INDICES_REGRESSION = ["FSI", "SII", "SV"]
class Regression(Approximator, ShapleySamplingMixin):
@@ -120,7 +120,7 @@ def approximate(
# if SII is used regression_subsets needs to be changed
if self.index == "SII":
regression_subsets, num_players = self._get_sii_subset_representation(all_subsets) # A
- else:
+ else: # FSI or SV
regression_subsets, num_players = self._get_fsi_subset_representation(all_subsets) # A
# initialize the regression variables
diff --git a/shapiq/approximator/regression/fsi.py b/shapiq/approximator/regression/fsi.py
index 4d8fb56f..a0d888ad 100644
--- a/shapiq/approximator/regression/fsi.py
+++ b/shapiq/approximator/regression/fsi.py
@@ -6,7 +6,7 @@
class RegressionFSI(Regression, NShapleyMixin):
- """Estimates the FSI values using the weighted least square approach.
+ """Estimates the FSI values [1] using the weighted least square approach.
Args:
n: The number of players.
@@ -20,11 +20,16 @@ class RegressionFSI(Regression, NShapleyMixin):
min_order: The minimum order of the approximation. For FSI, min_order is equal to 1.
iteration_cost: The cost of a single iteration of the regression FSI.
+ References:
+ [1]: Tsai, C.-P., Yeh, C.-K., & Ravikumar, P. (2023). Faith-Shap: The Faithful Shapley
+ Interaction Index. J. Mach. Learn. Res., 24, 94:1-94:42. Retrieved from
+ http://jmlr.org/papers/v24/22-0202.html
+
Example:
>>> from games import DummyGame
>>> from approximator import RegressionFSI
>>> game = DummyGame(n=5, interaction=(1, 2))
- >>> approximator = RegressionFsi(n=5, max_order=2)
+ >>> approximator = RegressionFSI(n=5, max_order=2)
>>> approximator.approximate(budget=100, game=game)
InteractionValues(
index=FSI, order=2, estimated=False, estimation_budget=32,
diff --git a/shapiq/approximator/regression/sv.py b/shapiq/approximator/regression/sv.py
new file mode 100644
index 00000000..d1f9beaf
--- /dev/null
+++ b/shapiq/approximator/regression/sv.py
@@ -0,0 +1,43 @@
+"""This module contains the KernelSHAP regression approximator for estimating the SV."""
+
+"""Regression with Faithful Shapley Interaction (FSI) index approximation."""
+from typing import Optional
+
+from ._base import Regression
+
+
+class KernelSHAP(Regression):
+ """Estimates the FSI values using the weighted least square approach.
+
+ Args:
+ n: The number of players.
+ max_order: The interaction order of the approximation.
+ random_state: The random state of the estimator. Defaults to `None`.
+
+ Attributes:
+ n: The number of players.
+ N: The set of players (starting from 0 to n - 1).
+ max_order: The interaction order of the approximation.
+ min_order: The minimum order of the approximation. For FSI, min_order is equal to 1.
+ iteration_cost: The cost of a single iteration of the regression FSI.
+
+ Example:
+ >>> from games import DummyGame
+ >>> from approximator import KernelSHAP
+ >>> game = DummyGame(n=5, interaction=(1, 2))
+ >>> approximator = KernelSHAP(n=5)
+ >>> approximator.approximate(budget=100, game=game)
+ InteractionValues(
+ index=SV, order=1, estimated=False, estimation_budget=32,
+ values={
+ (0,): 0.2,
+ (1,): 0.7,
+ (2,): 0.7,
+ (3,): 0.2,
+ (4,): 0.2,
+ }
+ )
+ """
+
+ def __init__(self, n: int, random_state: Optional[int] = None):
+ super().__init__(n, max_order=1, index="SV", random_state=random_state)
diff --git a/shapiq/explainer/__init__.py b/shapiq/explainer/__init__.py
index fb9b0c4c..23b90c70 100644
--- a/shapiq/explainer/__init__.py
+++ b/shapiq/explainer/__init__.py
@@ -1,7 +1,8 @@
"""This module contains the explainer for the shapiq package."""
-from ._base import Explainer
+
+from .interaction import InteractionExplainer
__all__ = [
- "Explainer",
+ "InteractionExplainer",
]
diff --git a/shapiq/explainer/_base.py b/shapiq/explainer/_base.py
index 91d324b1..8716fa78 100644
--- a/shapiq/explainer/_base.py
+++ b/shapiq/explainer/_base.py
@@ -1,17 +1,32 @@
"""This module contains the base explainer classes for the shapiq package."""
-import warnings
from abc import ABC, abstractmethod
-from typing import Any
+from typing import Callable
+
+import numpy as np
+
+from approximator._base import InteractionValues
+from explainer.imputer.marginal_imputer import MarginalImputer
class Explainer(ABC):
- """The base class for all explainers in the shapiq package."""
+ """The base class for all explainers in the shapiq package.
+
+ Args:
+ model: The model to explain as a callable function expecting data points as input and
+ returning the model's predictions.
+ background_data: The background data to use for the explainer.
+ """
@abstractmethod
- def __init__(self) -> None:
- """Initializes the explainer."""
- warnings.warn("Explainer is not implemented yet.")
+ def __init__(
+ self, model: Callable[[np.ndarray], np.ndarray], background_data: np.ndarray
+ ) -> None:
+ self._model = model
+ self._background_data = background_data
+ self._n_features = self._background_data.shape[1]
+ self._imputer = MarginalImputer(self._model, self._background_data)
@abstractmethod
- def explain(self) -> Any:
- warnings.warn("Explainer is not implemented yet.")
+ def explain(self, x_explain: np.ndarray) -> InteractionValues:
+ """Explains the model's predictions."""
+ raise NotImplementedError("Method `explain` must be implemented in a subclass.")
diff --git a/shapiq/explainer/imputer/__init__.py b/shapiq/explainer/imputer/__init__.py
new file mode 100644
index 00000000..40b0ab25
--- /dev/null
+++ b/shapiq/explainer/imputer/__init__.py
@@ -0,0 +1,5 @@
+"""This module contains the imputer for the shapiq package."""
+
+from .marginal_imputer import MarginalImputer
+
+__all__ = ["MarginalImputer"]
diff --git a/shapiq/explainer/imputer/_base.py b/shapiq/explainer/imputer/_base.py
new file mode 100644
index 00000000..3f309230
--- /dev/null
+++ b/shapiq/explainer/imputer/_base.py
@@ -0,0 +1,47 @@
+"""Base class for imputers."""
+from abc import abstractmethod
+from typing import Callable, Optional
+
+import numpy as np
+
+
+class Imputer:
+ """Base class for imputers.
+
+ Args:
+ model: The model to explain as a callable function expecting data points as input and
+ returning the model's predictions.
+ background_data: The background data to use for the explainer as a two-dimensional array
+ with shape (n_samples, n_features).
+ categorical_features: A list of indices of the categorical features in the background data.
+ random_state: The random state to use for sampling. Defaults to `None`.
+ """
+
+ @abstractmethod
+ def __init__(
+ self,
+ model: Callable[[np.ndarray], np.ndarray],
+ background_data: np.ndarray,
+ categorical_features: Optional[list[int]] = None,
+ random_state: Optional[int] = None,
+ ) -> None:
+ self._model = model
+ self._background_data = background_data
+ self._n_features = self._background_data.shape[1]
+ self._cat_features: list = [] if categorical_features is None else categorical_features
+ self._random_state = random_state
+ self._rng = np.random.default_rng(self._random_state)
+
+ @abstractmethod
+ def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
+ """Imputes the missing values of a data point and calls the model.
+
+ Args:
+ subsets: A boolean array indicating which features are present (`True`) and which are
+ missing (`False`). The shape of the array must be (n_subsets, n_features).
+
+ Returns:
+ The model's predictions on the imputed data points. The shape of the array is
+ (n_subsets, n_outputs).
+ """
+ raise NotImplementedError("Method `__call__` must be implemented in a subclass.")
diff --git a/shapiq/explainer/imputer/marginal_imputer.py b/shapiq/explainer/imputer/marginal_imputer.py
new file mode 100644
index 00000000..30768b14
--- /dev/null
+++ b/shapiq/explainer/imputer/marginal_imputer.py
@@ -0,0 +1,159 @@
+"""This module contains the marginal imputer for the shapiq package."""
+from typing import Callable, Optional
+
+import numpy as np
+
+from explainer.imputer._base import Imputer
+
+
+class MarginalImputer(Imputer):
+ """The marginal imputer for the shapiq package.
+
+ The marginal imputer is used to impute the missing values of a data point by using the
+ marginal distribution of the background data.
+
+ Args:
+ model: The model to explain as a callable function expecting data points as input and
+ returning the model's predictions.
+ background_data: The background data to use for the explainer as a two-dimensional array
+ with shape (n_samples, n_features).
+ x_explain: The data point to explain. If given, the imputer is fitted to this point at
+ initialization. Defaults to `None`.
+ sample_replacements: Whether to sample replacements from the background data or to use the
+ mean (for numerical features) or the median (for categorical features) of the background
+ data. Defaults to `False`.
+ sample_size: The number of samples to draw from the background data. Only used if
+ `sample_replacements` is `True`. Increasing this value will linearly increase the
+ runtime of the explainer. Defaults to `5`.
+ categorical_features: A list of indices of the categorical features in the background data.
+ If no categorical features are given, all features are assumed to be numerical; for
+ string features (where `np.mean` fails) the most frequent value is used. Defaults to
+ `None`.
+ random_state: The random state to use for sampling. Defaults to `None`.
+
+ Attributes:
+ replacement_data: The data to use for imputation. Either samples from the background data
+ or the mean/median of the background data.
+ empty_prediction: The model's prediction on an empty data point (all features missing).
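+
+ Example:
+ A minimal usage sketch (the `model` below is a hypothetical stand-in; any callable
+ mapping an array of shape (n_samples, n_features) to predictions works):
+ >>> import numpy as np
+ >>> background_data = np.random.rand(100, 3)
+ >>> model = lambda x: np.sum(x, axis=1) # hypothetical model function
+ >>> imputer = MarginalImputer(model, background_data)
+ >>> imputer = imputer.fit(np.array([[1.0, 1.0, 1.0]]))
+ >>> subsets = np.array([[True, False, True]]) # feature 1 is treated as missing
+ >>> predictions = imputer(subsets) # model output on the imputed point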
+ """
+
+ def __init__(
+ self,
+ model: Callable[[np.ndarray], np.ndarray],
+ background_data: np.ndarray,
+ x_explain: Optional[np.ndarray] = None,
+ sample_replacements: bool = False,
+ sample_size: int = 5,
+ categorical_features: Optional[list[int]] = None,
+ random_state: Optional[int] = None,
+ ) -> None:
+ super().__init__(model, background_data, categorical_features, random_state)
+ self._sample_replacements = sample_replacements
+ self._sample_size: int = sample_size
+ self.replacement_data: np.ndarray = np.zeros((1, self._n_features)) # will be overwritten
+ self.init_background(self._background_data)
+ self._x_explain: np.ndarray = np.zeros((1, self._n_features)) # will be overwritten @ fit
+ if x_explain is not None:
+ self.fit(x_explain)
+ self.empty_prediction: float = self._calc_empty_prediction()
+
+ def __call__(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
+ """Imputes the missing values of a data point and calls the model.
+
+ Args:
+ subsets: A boolean array indicating which features are present (`True`) and which are
+ missing (`False`). The shape of the array must be (n_subsets, n_features).
+
+ Returns:
+ The model's predictions on the imputed data points. The shape of the array is
+ (n_subsets, n_outputs).
+ """
+ n_subsets = subsets.shape[0]
+ data = np.tile(np.copy(self._x_explain), (n_subsets, 1))
+ if not self._sample_replacements:
+ replacement_data = np.tile(self.replacement_data, (n_subsets, 1))
+ data[~subsets] = replacement_data[~subsets]
+ outputs = self._model(data)
+ else:
+ # sampling from background returning array of shape (sample_size, n_subsets, n_features)
+ replacement_data = self._sample_replacement_values(subsets)
+ outputs = np.zeros((self._sample_size, n_subsets))
+ for i in range(self._sample_size):
+ replacements = replacement_data[i].reshape(n_subsets, self._n_features)
+ data[~subsets] = replacements[~subsets]
+ outputs[i] = self._model(data)
+ outputs = np.mean(outputs, axis=0) # average over the samples
+ outputs -= self.empty_prediction
+ return outputs
+
+ def init_background(self, x_background: np.ndarray) -> "MarginalImputer":
+ """Initializes the imputer to the background data.
+
+ Args:
+ x_background: The background data to use for the imputer. The shape of the array must
+ be (n_samples, n_features).
+
+ Returns:
+ The initialized imputer.
+ """
+ if self._sample_replacements:
+ self.replacement_data = x_background
+ else:
+ self.replacement_data = np.zeros((1, self._n_features))
+ for feature in range(self._n_features):
+ feature_column = x_background[:, feature]
+ if feature in self._cat_features:
+ summarized_feature = np.median(feature_column)
+ else:
+ try: # try to use the mean for numerical features
+ summarized_feature = np.mean(feature_column)
+ except TypeError: # fallback to the most frequent value for string features
+ values, counts = np.unique(feature_column, return_counts=True)
+ summarized_feature = values[np.argmax(counts)]
+ self.replacement_data[:, feature] = summarized_feature
+ return self
+
+ def fit(self, x_explain: np.ndarray[float]) -> "MarginalImputer":
+ """Fits the imputer to the explanation point.
+
+ Args:
+ x_explain: The data point to fit the imputer to.
+
+ Returns:
+ The fitted imputer.
+ """
+ self._x_explain = x_explain
+ return self
+
+ def _sample_replacement_values(self, subsets: np.ndarray[bool]) -> np.ndarray:
+ """Samples replacement values from the background data.
+
+ Args:
+ subsets: A boolean array indicating which features are present (`True`) and which are
+ missing (`False`). The shape of the array must be (n_subsets, n_features).
+
+ Returns:
+ The sampled replacement values. The shape of the array is (sample_size, n_subsets,
+ n_features).
+ """
+ n_subsets = subsets.shape[0]
+ replacement_data = np.zeros((self._sample_size, n_subsets, self._n_features))
+ for feature in range(self._n_features):
+ sampled_feature_values = self._rng.choice(
+ self.replacement_data[:, feature], size=(self._sample_size, n_subsets), replace=True
+ )
+ replacement_data[:, :, feature] = sampled_feature_values
+ return replacement_data
+
+ def _calc_empty_prediction(self) -> float:
+ """Runs the model on empty data points (all features missing) to get the empty prediction.
+
+ Returns:
+ The empty prediction.
+ """
+ if self._sample_replacements:
+ shuffled_background = self._rng.permutation(self._background_data)
+ empty_predictions = self._model(shuffled_background)
+ empty_prediction = float(np.mean(empty_predictions))
+ return empty_prediction
+ empty_prediction = self._model(self.replacement_data)
+ try: # convert the model output to a scalar
+ empty_prediction = float(empty_prediction)
+ except TypeError: # the model returned an array of predictions
+ empty_prediction = float(empty_prediction[0])
+ return empty_prediction
diff --git a/shapiq/explainer/interaction.py b/shapiq/explainer/interaction.py
new file mode 100644
index 00000000..623d3f57
--- /dev/null
+++ b/shapiq/explainer/interaction.py
@@ -0,0 +1,127 @@
+"""This module contains the interaction explainer for the shapiq package. This is the main interface
+for users of the shapiq package."""
+from typing import Callable, Union, Optional
+
+import numpy as np
+
+from approximator._base import InteractionValues, Approximator
+from ._base import Explainer
+from approximator import (
+ RegressionSII,
+ RegressionFSI,
+ PermutationSamplingSII,
+ PermutationSamplingSTI,
+ ShapIQ,
+)
+
+
+__all__ = ["InteractionExplainer"]
+
+
+APPROXIMATOR_CONFIGURATIONS = {
+ "Regression": {"SII": RegressionSII, "FSI": RegressionFSI, "nSII": RegressionSII},
+ "Permutation": {
+ "SII": PermutationSamplingSII,
+ "STI": PermutationSamplingSTI,
+ "nSII": PermutationSamplingSII,
+ },
+ "ShapIQ": {"SII": ShapIQ, "STI": ShapIQ, "FSI": ShapIQ, "nSII": ShapIQ},
+}
+
+AVAILABLE_INDICES = {
+ index
+ for approximator_dict in APPROXIMATOR_CONFIGURATIONS.values()
+ for index in approximator_dict.keys()
+}
+
+
+class InteractionExplainer(Explainer):
+ """The interaction explainer as the main interface for the shapiq package.
+
+ The interaction explainer is the main interface for the shapiq package. It can be used to
+ explain the predictions of a model by estimating the Shapley interaction values.
+
+ Args:
+ model: The model to explain as a callable function expecting data points as input and
+ returning the model's predictions.
+ background_data: The background data to use for the explainer.
+ approximator: The approximator to use for the explainer. Defaults to `"auto"`, which will
+ automatically choose the approximator based on the number of features and the number of
+ samples in the background data.
+ index: The Shapley interaction index to use. Must be one of `"SII"` (Shapley Interaction Index),
+ `"nSII"` (n-Shapley Interaction Index), `"STI"` (Shapley-Taylor Interaction Index), or
+ `"FSI"` (Faithful Shapley Interaction Index). Defaults to `"nSII"`.
+ """
+
+ def __init__(
+ self,
+ model: Callable[[np.ndarray], np.ndarray],
+ background_data: np.ndarray,
+ approximator: Union[str, Approximator] = "auto",
+ index: str = "nSII",
+ max_order: int = 2,
+ random_state: Optional[int] = None,
+ ) -> None:
+ super().__init__(model, background_data)
+ if index not in AVAILABLE_INDICES:
+ raise ValueError(f"Invalid index `{index}`. " f"Valid indices are {AVAILABLE_INDICES}.")
+ self.index = index
+ self._default_budget: int = 2_000
+ if max_order < 2:
+ raise ValueError("The maximum order must be at least 2.")
+ self._max_order: int = max_order
+ self._random_state = random_state
+ self._rng = np.random.default_rng(self._random_state)
+ self.approximator = self._init_approximator(approximator, index, max_order)
+
+ def explain(self, x_explain: np.ndarray, budget: Optional[int] = None) -> InteractionValues:
+ """Explains the model's predictions.
+
+ Args:
+ x_explain: The data point to explain as a 2-dimensional array with shape
+ (1, n_features).
+ budget: The budget to use for the approximation. Defaults to `None`, which will choose
+ the budget automatically based on the number of features.
+ """
+ if budget is None:
+ budget = min(2**self._n_features, self._default_budget)
+
+ # initialize the imputer with the explanation point
+ imputer = self._imputer.fit(x_explain)
+
+ # explain
+ interaction_values = self.approximator.approximate(budget=budget, game=imputer)
+
+ return interaction_values
+
+ def _init_approximator(
+ self, approximator: Union[Approximator, str], index: str, max_order: int
+ ) -> Approximator:
+ if isinstance(approximator, Approximator): # if the approximator is already given
+ return approximator
+ if approximator == "auto":
+ if index == "FSI":
+ return RegressionFSI(
+ n=self._n_features,
+ max_order=max_order,
+ random_state=self._random_state,
+ )
+ else: # default to ShapIQ
+ return ShapIQ(
+ n=self._n_features,
+ max_order=max_order,
+ top_order=False,
+ random_state=self._random_state,
+ index=index,
+ )
+ # assume that the approximator is a string
+ try:
+ approximator_class = APPROXIMATOR_CONFIGURATIONS[approximator][index]
+ except KeyError:
+ raise ValueError(
+ f"Invalid approximator `{approximator}` or index `{index}`. "
+ f"Valid configurations are described in {APPROXIMATOR_CONFIGURATIONS}."
+ )
+ # initialize the approximator class with the explainer's parameters
+ init_approximator = approximator_class(n=self._n_features, max_order=max_order)
+ return init_approximator
diff --git a/tests/tests_approximators/__init__.py b/tests/tests_approximators/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_approximator_base_interaction_values.py b/tests/tests_approximators/test_approximator_base_interaction_values.py
similarity index 100%
rename from tests/test_approximator_base_interaction_values.py
rename to tests/tests_approximators/test_approximator_base_interaction_values.py
diff --git a/tests/test_approximator_nsii_estimation.py b/tests/tests_approximators/test_approximator_nsii_estimation.py
similarity index 94%
rename from tests/test_approximator_nsii_estimation.py
rename to tests/tests_approximators/test_approximator_nsii_estimation.py
index ce311c96..84a76333 100644
--- a/tests/test_approximator_nsii_estimation.py
+++ b/tests/tests_approximators/test_approximator_nsii_estimation.py
@@ -2,8 +2,13 @@
import numpy as np
import pytest
-from approximator import convert_nsii_into_one_dimension, transforms_sii_to_nsii
-from shapiq import DummyGame, PermutationSamplingSII, ShapIQ
+from approximator import (
+ convert_nsii_into_one_dimension,
+ transforms_sii_to_nsii,
+ PermutationSamplingSII,
+ ShapIQ,
+)
+from games import DummyGame
@pytest.mark.parametrize(
diff --git a/tests/test_approximator_permutation_sii.py b/tests/tests_approximators/test_approximator_permutation_sii.py
similarity index 100%
rename from tests/test_approximator_permutation_sii.py
rename to tests/tests_approximators/test_approximator_permutation_sii.py
diff --git a/tests/test_approximator_permutation_sti.py b/tests/tests_approximators/test_approximator_permutation_sti.py
similarity index 100%
rename from tests/test_approximator_permutation_sti.py
rename to tests/tests_approximators/test_approximator_permutation_sti.py
diff --git a/tests/test_approximator_regression_fsi.py b/tests/tests_approximators/test_approximator_regression_fsi.py
similarity index 100%
rename from tests/test_approximator_regression_fsi.py
rename to tests/tests_approximators/test_approximator_regression_fsi.py
diff --git a/tests/test_approximator_regression_sii.py b/tests/tests_approximators/test_approximator_regression_sii.py
similarity index 100%
rename from tests/test_approximator_regression_sii.py
rename to tests/tests_approximators/test_approximator_regression_sii.py
diff --git a/tests/tests_approximators/test_approximator_regression_sv.py b/tests/tests_approximators/test_approximator_regression_sv.py
new file mode 100644
index 00000000..8801601f
--- /dev/null
+++ b/tests/tests_approximators/test_approximator_regression_sv.py
@@ -0,0 +1,68 @@
+"""This test module contains all tests regarding the SV KernelSHAP regression approximator."""
+from copy import deepcopy, copy
+
+import numpy as np
+import pytest
+
+from approximator._base import InteractionValues
+from approximator.regression import KernelSHAP
+from games import DummyGame
+
+
+@pytest.mark.parametrize(
+ "n",
+ [
+ 3,
+ 7, # used in subsequent tests
+ 10,
+ ],
+)
+def test_initialization(n):
+ """Tests the initialization of the RegressionFSI approximator."""
+ approximator = KernelSHAP(n)
+ assert approximator.n == n
+ assert approximator.max_order == 1
+ assert approximator.top_order is False
+ assert approximator.min_order == 1
+ assert approximator.iteration_cost == 1
+ assert approximator.index == "SV"
+
+ approximator_copy = copy(approximator)
+ approximator_deepcopy = deepcopy(approximator)
+ approximator_deepcopy.index = "something"
+ assert approximator_copy == approximator # check that the copy is equal
+ assert approximator_deepcopy != approximator # check that the deepcopy is not equal
+ approximator_string = str(approximator)
+ assert repr(approximator) == approximator_string
+ assert hash(approximator) == hash(approximator_copy)
+ assert hash(approximator) != hash(approximator_deepcopy)
+ with pytest.raises(ValueError):
+ _ = approximator == 1
+
+
+@pytest.mark.parametrize("n, budget, batch_size", [(7, 380, 100), (7, 380, None), (7, 100, None)])
+def test_approximate(n, budget, batch_size):
+ """Tests the approximation of the KernelSHAP approximator."""
+
+ interaction = (1, 2)
+ game = DummyGame(n, interaction)
+
+ approximator = KernelSHAP(n)
+ sv_estimates = approximator.approximate(budget, game, batch_size=batch_size)
+ assert isinstance(sv_estimates, InteractionValues)
+ assert sv_estimates.max_order == 1
+ assert sv_estimates.min_order == 1
+ assert sv_estimates.index == "SV"
+
+ # check that the budget is respected
+ assert game.access_counter <= budget + 2
+
+ # check that the estimates are correct
+ # for order 1, players 1 and 2 are the most important with an SV of 0.6429
+ assert sv_estimates[(1,)] == pytest.approx(0.6429, 0.1)
+ assert sv_estimates[(2,)] == pytest.approx(0.6429, 0.1)
+
+ # check efficiency
+ efficiency = np.sum(sv_estimates.values)
+ assert efficiency == pytest.approx(2.0, 0.1)
diff --git a/tests/test_approximator_shapiq.py b/tests/tests_approximators/test_approximator_shapiq.py
similarity index 100%
rename from tests/test_approximator_shapiq.py
rename to tests/tests_approximators/test_approximator_shapiq.py
diff --git a/tests/tests_explainer/__init__.py b/tests/tests_explainer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/tests_explainer/test_explainer_interaction.py b/tests/tests_explainer/test_explainer_interaction.py
new file mode 100644
index 00000000..3bc5bc4b
--- /dev/null
+++ b/tests/tests_explainer/test_explainer_interaction.py
@@ -0,0 +1,124 @@
+"""This test module contains all tests regarding the interaciton explainer for the shapiq package.
+"""
+
+import pytest
+
+from sklearn.tree import DecisionTreeRegressor
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.datasets import make_regression
+
+from shapiq.explainer import InteractionExplainer
+
+
+@pytest.fixture
+def dt_model():
+ """Return a simple decision tree model."""
+ X, y = make_regression(n_samples=100, n_features=7, random_state=42)
+ model = DecisionTreeRegressor(random_state=42, max_depth=3)
+ model.fit(X, y)
+ return model
+
+
+@pytest.fixture
+def rf_model():
+ """Return a simple decision tree model."""
+ X, y = make_regression(n_samples=100, n_features=7, random_state=42)
+ model = RandomForestRegressor(random_state=42, max_depth=3, n_estimators=10)
+ model.fit(X, y)
+ return model
+
+
+@pytest.fixture
+def background_data():
+ """Return data to use as background data."""
+ X, y = make_regression(n_samples=100, n_features=7, random_state=42)
+ return X
+
+
+INDICES = ["SII", "nSII", "STI", "FSI"]
+MAX_ORDERS = [2, 3]
+
+
+@pytest.mark.parametrize("index", INDICES)
+@pytest.mark.parametrize("max_order", MAX_ORDERS)
+def test_init_params(dt_model, background_data, index, max_order):
+ """Test the initialization of the interaction explainer."""
+ model_function = dt_model.predict
+ explainer = InteractionExplainer(
+ model=model_function,
+ background_data=background_data,
+ random_state=42,
+ index=index,
+ max_order=max_order,
+ approximator="auto",
+ )
+ assert explainer.index == index
+ assert explainer.approximator.index == index
+ assert explainer._max_order == max_order
+ assert explainer._random_state == 42
+ # test defaults
+ if index == "FSI":
+ assert explainer.approximator.__class__.__name__ == "RegressionFSI"
+ else:
+ assert explainer.approximator.__class__.__name__ == "ShapIQ"
+
+
+def test_auto_params(dt_model, background_data):
+ """Test the initialization of the interaction explainer."""
+ model_function = dt_model.predict
+ explainer = InteractionExplainer(
+ model=model_function,
+ background_data=background_data,
+ )
+ assert explainer.index == "nSII"
+ assert explainer.approximator.index == "nSII"
+ assert explainer._max_order == 2
+ assert explainer._random_state is None
+ assert explainer.approximator.__class__.__name__ == "ShapIQ"
+
+
+def test_init_params_error(dt_model, background_data):
+ """Test the initialization of the interaction explainer."""
+ model_function = dt_model.predict
+ with pytest.raises(ValueError):
+ InteractionExplainer(
+ model=model_function,
+ background_data=background_data,
+ index="invalid",
+ )
+ with pytest.raises(ValueError):
+ InteractionExplainer(
+ model=model_function,
+ background_data=background_data,
+ max_order=0,
+ )
+ with pytest.raises(ValueError):
+ InteractionExplainer(
+ model=model_function,
+ background_data=background_data,
+ approximator="invalid",
+ )
+
+
+BUDGETS = [2**5, 2**8]
+
+
+@pytest.mark.parametrize("budget", BUDGETS)
+@pytest.mark.parametrize("index", INDICES)
+@pytest.mark.parametrize("max_order", MAX_ORDERS)
+def test_explain(dt_model, background_data, index, budget, max_order):
+ """Test the initialization of the interaction explainer."""
+ model_function = dt_model.predict
+ explainer = InteractionExplainer(
+ model=model_function,
+ background_data=background_data,
+ random_state=42,
+ index=index,
+ max_order=max_order,
+ approximator="auto",
+ )
+ x_explain = background_data[0].reshape(1, -1)
+ interaction_values = explainer.explain(x_explain, budget=budget)
+ assert interaction_values.index == index
+ assert interaction_values.max_order == max_order
+ assert interaction_values.estimation_budget <= budget + 2
diff --git a/tests/tests_games/__init__.py b/tests/tests_games/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_games_dummy.py b/tests/tests_games/test_games_dummy.py
similarity index 100%
rename from tests/test_games_dummy.py
rename to tests/tests_games/test_games_dummy.py
diff --git a/tests/tests_utils/__init__.py b/tests/tests_utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_utils_sets.py b/tests/tests_utils/test_utils_sets.py
similarity index 100%
rename from tests/test_utils_sets.py
rename to tests/tests_utils/test_utils_sets.py