Skip to content

Commit

Permalink
refactors Regression into RegressionSII and RegressionFSI
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Dec 1, 2023
1 parent b4c94fc commit 051fa22
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 20 deletions.
11 changes: 9 additions & 2 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from __version__ import __version__

# approximator classes
from .approximator import PermutationSamplingSII, PermutationSamplingSTI, Regression, ShapIQ
from .approximator import (
PermutationSamplingSII,
PermutationSamplingSTI,
RegressionSII,
RegressionFSI,
ShapIQ,
)

# explainer classes
from .explainer import Explainer
Expand All @@ -31,7 +37,8 @@
"ShapIQ",
"PermutationSamplingSII",
"PermutationSamplingSTI",
"Regression",
"RegressionSII",
"RegressionFSI",
# explainers
"Explainer",
# games
Expand Down
5 changes: 3 additions & 2 deletions shapiq/approximator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from ._base import convert_nsii_into_one_dimension, transforms_sii_to_nsii # TODO add to tests
from .permutation.sii import PermutationSamplingSII
from .permutation.sti import PermutationSamplingSTI
from .regression import Regression
from .regression import RegressionSII, RegressionFSI
from .shapiq import ShapIQ

__all__ = [
"PermutationSamplingSII",
"PermutationSamplingSTI",
"Regression",
"RegressionFSI",
"RegressionSII",
"ShapIQ",
"transforms_sii_to_nsii",
"convert_nsii_into_one_dimension",
Expand Down
5 changes: 3 additions & 2 deletions shapiq/approximator/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains the regression-based approximators to estimate Shapley interaction values.
"""
from ._base import Regression
from .sii import RegressionSII
from .fsi import RegressionFSI

__all__ = ["Regression"]
__all__ = ["RegressionSII", "RegressionFSI"]
6 changes: 3 additions & 3 deletions shapiq/approximator/regression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class Regression(Approximator, ShapleySamplingMixin):
"""Estimates the FSI values using the weighted least square approach.
"""Estimates the InteractionScores values using the weighted least square approach.
Args:
n: The number of players.
Expand All @@ -27,9 +27,9 @@ class Regression(Approximator, ShapleySamplingMixin):
Example:
>>> from games import DummyGame
>>> from approximator import Regression
>>> from approximator import RegressionSII
>>> game = DummyGame(n=5, interaction=(1, 2))
>>> approximator = Regression(n=5, max_order=2)
>>> approximator = RegressionSII(n=5, max_order=2)
>>> approximator.approximate(budget=100, game=game)
InteractionValues(
index=FSI, order=2, estimated=False, estimation_budget=32,
Expand Down
52 changes: 52 additions & 0 deletions shapiq/approximator/regression/fsi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Regression with Faithful Shapley Interaction (FSI) index approximation."""
from typing import Optional

from ._base import Regression
from .._base import NShapleyMixin


class RegressionFSI(Regression, NShapleyMixin):
"""Estimates the FSI values using the weighted least square approach.
Args:
n: The number of players.
max_order: The interaction order of the approximation.
random_state: The random state of the estimator. Defaults to `None`.
Attributes:
n: The number of players.
N: The set of players (starting from 0 to n - 1).
max_order: The interaction order of the approximation.
min_order: The minimum order of the approximation. For FSI, min_order is equal to 1.
iteration_cost: The cost of a single iteration of the regression FSI.
Example:
>>> from games import DummyGame
>>> from approximator import RegressionFSI
>>> game = DummyGame(n=5, interaction=(1, 2))
>>> approximator = RegressionFsi(n=5, max_order=2)
>>> approximator.approximate(budget=100, game=game)
InteractionValues(
index=FSI, order=2, estimated=False, estimation_budget=32,
values={
(0,): 0.2,
(1,): 0.7,
(2,): 0.7,
(3,): 0.2,
(4,): 0.2,
(0, 1): 0,
(0, 2): 0,
(0, 3): 0,
(0, 4): 0,
(1, 2): 1.0,
(1, 3): 0,
(1, 4): 0,
(2, 3): 0,
(2, 4): 0,
(3, 4): 0
}
)
"""

def __init__(self, n: int, max_order: int, random_state: Optional[int] = None):
super().__init__(n, max_order, index="FSI", random_state=random_state)
53 changes: 53 additions & 0 deletions shapiq/approximator/regression/sii.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Regression with Shapley interaction index (SII) approximation."""
from typing import Optional

from ._base import Regression
from .._base import NShapleyMixin


class RegressionSII(Regression, NShapleyMixin):
"""Estimates the SII values using the weighted least square approach.
Args:
n: The number of players.
max_order: The interaction order of the approximation.
random_state: The random state of the estimator. Defaults to `None`.
Attributes:
n: The number of players.
N: The set of players (starting from 0 to n - 1).
max_order: The interaction order of the approximation.
min_order: The minimum order of the approximation. For the regression estimator, min_order
is equal to 1.
iteration_cost: The cost of a single iteration of the regression SII.
Example:
>>> from games import DummyGame
>>> from approximator import RegressionSII
>>> game = DummyGame(n=5, interaction=(1, 2))
>>> approximator = RegressionSII(n=5, max_order=2)
>>> approximator.approximate(budget=100, game=game)
InteractionValues(
index=SII, order=2, estimated=False, estimation_budget=32,
values={
(0,): 0.2,
(1,): 0.7,
(2,): 0.7,
(3,): 0.2,
(4,): 0.2,
(0, 1): 0,
(0, 2): 0,
(0, 3): 0,
(0, 4): 0,
(1, 2): 1.0,
(1, 3): 0,
(1, 4): 0,
(2, 3): 0,
(2, 4): 0,
(3, 4): 0
}
)
"""

def __init__(self, n: int, max_order: int, random_state: Optional[int] = None):
super().__init__(n, max_order, index="SII", random_state=random_state)
8 changes: 3 additions & 5 deletions tests/test_approximator_regression_fsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from approximator._base import InteractionValues
from approximator.regression import Regression
from approximator.regression import RegressionFSI
from games import DummyGame


Expand All @@ -22,7 +22,7 @@
)
def test_initialization(n, max_order):
"""Tests the initialization of the RegressionFSI approximator."""
approximator = Regression(n, max_order, index="FSI")
approximator = RegressionFSI(n, max_order)
assert approximator.n == n
assert approximator.max_order == max_order
assert approximator.top_order is False
Expand All @@ -41,8 +41,6 @@ def test_initialization(n, max_order):
assert hash(approximator) != hash(approximator_deepcopy)
with pytest.raises(ValueError):
_ = approximator == 1
with pytest.raises(ValueError):
_ = Regression(n, max_order, index="something")


@pytest.mark.parametrize(
Expand All @@ -52,7 +50,7 @@ def test_approximate(n, max_order, budget, batch_size):
"""Tests the approximation of the RegressionFSI approximator."""
interaction = (1, 2)
game = DummyGame(n, interaction)
approximator = Regression(n, max_order, index="FSI", random_state=42)
approximator = RegressionFSI(n, max_order, random_state=42)
fsi_estimates = approximator.approximate(budget, game, batch_size=batch_size)
assert isinstance(fsi_estimates, InteractionValues)
assert fsi_estimates.max_order == max_order
Expand Down
23 changes: 19 additions & 4 deletions tests/test_approximator_regression_sii.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

from approximator._base import InteractionValues
from approximator.regression import Regression
from approximator.regression._base import Regression
from approximator.regression import RegressionSII
from games import DummyGame


Expand All @@ -22,7 +23,7 @@
)
def test_initialization(n, max_order):
"""Tests the initialization of the Regression approximator for SII."""
approximator = Regression(n, max_order, index="SII")
approximator = RegressionSII(n, max_order)
assert approximator.n == n
assert approximator.max_order == max_order
assert approximator.top_order is False
Expand Down Expand Up @@ -52,7 +53,7 @@ def test_approximate(n, max_order, budget, batch_size):
"""Tests the approximation of the Regression approximator for SII."""
interaction = (1, 2)
game = DummyGame(n, interaction)
approximator = Regression(n, max_order, index="SII", random_state=42)
approximator = RegressionSII(n, max_order, random_state=42)
sii_estimates = approximator.approximate(budget, game, batch_size=batch_size)
assert isinstance(sii_estimates, InteractionValues)
assert sii_estimates.max_order == max_order
Expand All @@ -61,4 +62,18 @@ def test_approximate(n, max_order, budget, batch_size):
# check that the budget is respected
assert game.access_counter <= budget + 2

# TODO check that the estimates are correct
# check that the estimates are correct
# for order 1 player 1 and 2 are the most important with 0.6429
assert sii_estimates[(1,)] == pytest.approx(0.6429, 0.4) # quite a large interval
assert sii_estimates[(2,)] == pytest.approx(0.6429, 0.4)

# for order 2 the interaction between player 1 and 2 is the most important
assert sii_estimates[(1, 2)] == pytest.approx(1.0, 0.2)

# check efficiency
efficiency = np.sum(sii_estimates.values[:n])
assert efficiency == pytest.approx(2.0, 0.01)

# try covert to nSII
nsii_estimates = approximator.transforms_sii_to_nsii(sii_estimates)
assert nsii_estimates.index == "nSII"
11 changes: 9 additions & 2 deletions tests/test_integration_import_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ def test_approximator_imports():
from shapiq.approximator import (
PermutationSamplingSII,
PermutationSamplingSTI,
Regression,
RegressionSII,
RegressionFSI,
ShapIQ,
)

from shapiq import ShapIQ, PermutationSamplingSII, PermutationSamplingSTI, Regression
from shapiq import (
ShapIQ,
PermutationSamplingSII,
PermutationSamplingSTI,
RegressionSII,
RegressionFSI,
)

assert True

0 comments on commit 051fa22

Please sign in to comment.