Skip to content

Commit

Permalink
275 remove shap as optional dependency (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
Advueu963 authored Jan 10, 2025
1 parent d2ecb56 commit 34a86e3
Show file tree
Hide file tree
Showing 15 changed files with 1,541 additions and 178 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ ruff==0.8.4
scikit-image==0.25.0
scikit-learn==1.6.0
scipy==1.14.1
shap==0.46.0
tqdm==4.67.1
torch==2.5.1
torchvision==0.20.1
Expand Down
132 changes: 121 additions & 11 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import os
import pickle
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
from warnings import warn
Expand Down Expand Up @@ -630,6 +631,25 @@ def to_dict(self) -> dict:
"baseline_value": self.baseline_value,
}

def aggregate(
self, others: Sequence["InteractionValues"], aggregation: str = "mean"
) -> "InteractionValues":
"""Aggregates InteractionValues objects using a specific aggregation method.
Args:
others: A list of InteractionValues objects to aggregate.
aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
``"median"``, ``"sum"``, ``"max"``, and ``"min"``.
Returns:
The aggregated InteractionValues object.
Note:
For documentation on the aggregation methods, see the ``aggregate_interaction_values()``
function.
"""
return aggregate_interaction_values([self, *others], aggregation)

def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]:
"""Visualize InteractionValues on a graph.
Expand Down Expand Up @@ -682,18 +702,13 @@ def plot_stacked_bar(
def plot_force(
self,
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
matplotlib=True,
show: bool = True,
abbreviate: bool = True,
**kwargs,
) -> Optional[plt.Figure]:
"""Visualize InteractionValues on a force plot.
For arguments, see shapiq.plots.force_plot().
Requires the ``shap`` Python package to be installed.
Args:
feature_names: The feature names used for plotting. If no feature names are provided, the
feature indices are used instead. Defaults to ``None``.
Expand All @@ -710,18 +725,14 @@ def plot_force(

return force_plot(
self,
feature_values=feature_values,
feature_names=feature_names,
matplotlib=matplotlib,
show=show,
abbreviate=abbreviate,
**kwargs,
)

def plot_waterfall(
self,
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
show: bool = True,
abbreviate: bool = True,
max_display: int = 10,
Expand All @@ -743,11 +754,10 @@ def plot_waterfall(

return waterfall_plot(
self,
feature_values=feature_values,
feature_names=feature_names,
show=show,
abbreviate=abbreviate,
max_display=max_display,
abbreviate=abbreviate,
)

def plot_sentence(
Expand Down Expand Up @@ -779,3 +789,103 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]:
from shapiq.plot.upset import upset_plot

return upset_plot(self, show=show, **kwargs)


def aggregate_interaction_values(
interaction_values: Sequence[InteractionValues],
aggregation: str = "mean",
) -> InteractionValues:
"""Aggregates InteractionValues objects using a specific aggregation method.
Args:
interaction_values: A list of InteractionValues objects to aggregate.
aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
``"median"``, ``"sum"``, ``"max"``, and ``"min"``.
Returns:
The aggregated InteractionValues object.
Example:
>>> iv1 = InteractionValues(
... values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=0.0,
... )
>>> iv2 = InteractionValues(
... values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), # this iv is missing the (1, 2) value
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, # no (1, 2)
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=1.0,
... )
>>> aggregate_interaction_values([iv1, iv2], "mean")
InteractionValues(
index=SII, max_order=2, min_order=1, estimated=True, estimation_budget=None,
n_players=3, baseline_value=0.5,
Top 10 interactions:
(1, 2): 0.60
(0, 2): 0.35
(0, 1): 0.25
(0,): 0.15
(1,): 0.25
(2,): 0.35
)
Note:
The index of the aggregated InteractionValues object is set to the index of the first
InteractionValues object in the list.
Raises:
ValueError: If the aggregation method is not supported.
"""

def _aggregate(vals: list[float], method: str) -> float:
"""Does the actual aggregation of the values."""
if method == "mean":
return np.mean(vals)
elif method == "median":
return np.median(vals)
elif method == "sum":
return np.sum(vals)
elif method == "max":
return np.max(vals)
elif method == "min":
return np.min(vals)
else:
raise ValueError(f"Aggregation method {method} is not supported.")

# get all keys from all InteractionValues objects
all_keys = set()
for iv in interaction_values:
all_keys.update(iv.interaction_lookup.keys())
all_keys = sorted(all_keys)

# aggregate the values
new_values = np.zeros(len(all_keys), dtype=float)
new_lookup = {}
for i, key in enumerate(all_keys):
new_lookup[key] = i
values = [iv[key] for iv in interaction_values]
new_values[i] = _aggregate(values, aggregation)

max_order = max([iv.max_order for iv in interaction_values])
min_order = min([iv.min_order for iv in interaction_values])
n_players = max([iv.n_players for iv in interaction_values])
baseline_value = _aggregate([iv.baseline_value for iv in interaction_values], aggregation)

return InteractionValues(
values=new_values,
index=interaction_values[0].index,
max_order=max_order,
n_players=n_players,
min_order=min_order,
interaction_lookup=new_lookup,
estimated=True,
estimation_budget=None,
baseline_value=baseline_value,
)
3 changes: 1 addition & 2 deletions shapiq/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .si_graph import si_graph_plot
from .stacked_bar import stacked_bar_plot
from .upset import upset_plot
from .utils import abbreviate_feature_names, get_interaction_values_and_feature_names
from .utils import abbreviate_feature_names
from .watefall import waterfall_plot

__all__ = [
Expand All @@ -21,5 +21,4 @@
"upset_plot",
# utils
"abbreviate_feature_names",
"get_interaction_values_and_feature_names",
]
Loading

0 comments on commit 34a86e3

Please sign in to comment.