From 34a86e3970447f694b7418b682aab085332b1ac6 Mon Sep 17 00:00:00 2001 From: Santo Thies <49311372+Advueu963@users.noreply.github.com> Date: Fri, 10 Jan 2025 19:28:37 +0100 Subject: [PATCH] 275 remove shap as optional dependency (#296) --- requirements.txt | 1 - shapiq/interaction_values.py | 132 +++- shapiq/plot/__init__.py | 3 +- shapiq/plot/bar.py | 212 +++++- shapiq/plot/force.py | 635 +++++++++++++++++- shapiq/plot/utils.py | 96 +-- shapiq/plot/watefall.py | 367 +++++++++- tests/conftest.py | 38 +- tests/requirements/requirements.txt | 1 - tests/test_base_interaction_values.py | 102 ++- .../tests_explainer/test_explainer_tabular.py | 19 +- tests/tests_plots/test_bar.py | 26 +- tests/tests_plots/test_force.py | 27 +- tests/tests_plots/test_utils.py | 35 + tests/tests_plots/test_waterfall.py | 25 +- 15 files changed, 1541 insertions(+), 178 deletions(-) create mode 100644 tests/tests_plots/test_utils.py diff --git a/requirements.txt b/requirements.txt index b9bf4c2a..3b7c68ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,6 @@ ruff==0.8.4 scikit-image==0.25.0 scikit-learn==1.6.0 scipy==1.14.1 -shap==0.46.0 tqdm==4.67.1 torch==2.5.1 torchvision==0.20.1 diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index ae98d8fa..cb1fb931 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -4,6 +4,7 @@ import copy import os import pickle +from collections.abc import Sequence from dataclasses import dataclass from typing import Optional, Union from warnings import warn @@ -630,6 +631,25 @@ def to_dict(self) -> dict: "baseline_value": self.baseline_value, } + def aggregate( + self, others: Sequence["InteractionValues"], aggregation: str = "mean" + ) -> "InteractionValues": + """Aggregates InteractionValues objects using a specific aggregation method. + + Args: + others: A list of InteractionValues objects to aggregate. + aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are + ``"median"``, ``"sum"``, ``"max"``, and ``"min"``. + + Returns: + The aggregated InteractionValues object. + + Note: + For documentation on the aggregation methods, see the ``aggregate_interaction_values()`` + function. + """ + return aggregate_interaction_values([self, *others], aggregation) + def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]: """Visualize InteractionValues on a graph. @@ -682,18 +702,13 @@ def plot_stacked_bar( def plot_force( self, feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, - matplotlib=True, show: bool = True, abbreviate: bool = True, - **kwargs, ) -> Optional[plt.Figure]: """Visualize InteractionValues on a force plot. For arguments, see shapiq.plots.force_plot(). - Requires the ``shap`` Python package to be installed. - Args: feature_names: The feature names used for plotting. If no feature names are provided, the feature indices are used instead. Defaults to ``None``. @@ -710,18 +725,14 @@ def plot_force( return force_plot( self, - feature_values=feature_values, feature_names=feature_names, - matplotlib=matplotlib, show=show, abbreviate=abbreviate, - **kwargs, ) def plot_waterfall( self, feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, show: bool = True, abbreviate: bool = True, max_display: int = 10, @@ -743,11 +754,10 @@ def plot_waterfall( return waterfall_plot( self, - feature_values=feature_values, feature_names=feature_names, show=show, - abbreviate=abbreviate, max_display=max_display, + abbreviate=abbreviate, ) def plot_sentence( @@ -779,3 +789,103 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]: from shapiq.plot.upset import upset_plot return upset_plot(self, show=show, **kwargs) + + +def aggregate_interaction_values( + interaction_values: Sequence[InteractionValues], + aggregation: str = "mean", +) -> InteractionValues: + """Aggregates InteractionValues objects using a specific aggregation method. + + Args: + interaction_values: A list of InteractionValues objects to aggregate. + aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are + ``"median"``, ``"sum"``, ``"max"``, and ``"min"``. + + Returns: + The aggregated InteractionValues object. + + Example: + >>> iv1 = InteractionValues( + ... values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]), + ... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5}, + ... index="SII", + ... max_order=2, + ... n_players=3, + ... min_order=1, + ... baseline_value=0.0, + ... ) + >>> iv2 = InteractionValues( + ... values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), # this iv is missing the (1, 2) value + ... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, # no (1, 2) + ... index="SII", + ... max_order=2, + ... n_players=3, + ... min_order=1, + ... baseline_value=1.0, + ... ) + >>> aggregate_interaction_values([iv1, iv2], "mean") + InteractionValues( + index=SII, max_order=2, min_order=1, estimated=True, estimation_budget=None, + n_players=3, baseline_value=0.5, + Top 10 interactions: + (1, 2): 0.60 + (0, 2): 0.35 + (0, 1): 0.25 + (0,): 0.15 + (1,): 0.25 + (2,): 0.35 + ) + Note: + The index of the aggregated InteractionValues object is set to the index of the first + InteractionValues object in the list. + + Raises: + ValueError: If the aggregation method is not supported. + """ + + def _aggregate(vals: list[float], method: str) -> float: + """Does the actual aggregation of the values.""" + if method == "mean": + return np.mean(vals) + elif method == "median": + return np.median(vals) + elif method == "sum": + return np.sum(vals) + elif method == "max": + return np.max(vals) + elif method == "min": + return np.min(vals) + else: + raise ValueError(f"Aggregation method {method} is not supported.") + + # get all keys from all InteractionValues objects + all_keys = set() + for iv in interaction_values: + all_keys.update(iv.interaction_lookup.keys()) + all_keys = sorted(all_keys) + + # aggregate the values + new_values = np.zeros(len(all_keys), dtype=float) + new_lookup = {} + for i, key in enumerate(all_keys): + new_lookup[key] = i + values = [iv[key] for iv in interaction_values] + new_values[i] = _aggregate(values, aggregation) + + max_order = max([iv.max_order for iv in interaction_values]) + min_order = min([iv.min_order for iv in interaction_values]) + n_players = max([iv.n_players for iv in interaction_values]) + baseline_value = _aggregate([iv.baseline_value for iv in interaction_values], aggregation) + + return InteractionValues( + values=new_values, + index=interaction_values[0].index, + max_order=max_order, + n_players=n_players, + min_order=min_order, + interaction_lookup=new_lookup, + estimated=True, + estimation_budget=None, + baseline_value=baseline_value, + ) diff --git a/shapiq/plot/__init__.py b/shapiq/plot/__init__.py index c5503158..3cb2a16c 100644 --- a/shapiq/plot/__init__.py +++ b/shapiq/plot/__init__.py @@ -7,7 +7,7 @@ from .si_graph import si_graph_plot from .stacked_bar import stacked_bar_plot from .upset import upset_plot -from .utils import abbreviate_feature_names, get_interaction_values_and_feature_names +from .utils import abbreviate_feature_names from .watefall import waterfall_plot __all__ = [ @@ -21,5 +21,4 @@ "upset_plot", # utils "abbreviate_feature_names", - "get_interaction_values_and_feature_names", ] diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index af7f5163..941590d2 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -5,19 +5,165 @@ import matplotlib.pyplot as plt import numpy as np -from ..interaction_values import InteractionValues -from ..utils.modules import check_import_module -from .utils import get_interaction_values_and_feature_names +from ..interaction_values import InteractionValues, aggregate_interaction_values +from ._config import BLUE, RED +from .utils import abbreviate_feature_names, format_labels, format_value __all__ = ["bar_plot"] +def _bar( + values: np.ndarray, + feature_names: np.ndarray, + max_display: Optional[int] = 10, + ax: Optional[plt.Axes] = None, +) -> plt.Axes: + """Create a bar plot of a set of SHAP values. + + This is a modified version of the bar plot from the SHAP package. The original code can be found + at https://github.com/shap/shap. + """ + # determine how many top features we will plot + num_features = len(values[0]) + if max_display is None: + max_display = num_features + max_display = min(max_display, num_features) + num_cut = max(num_features - max_display, 0) # number of features that are not displayed + + # get order of features in descending order + feature_order = np.argsort(np.mean(values, axis=0))[::-1] + + # if there are more features than we are displaying then we aggregate the features not shown + if num_cut > 0: + cut_feature_values = values[:, feature_order[max_display:]] + sum_of_remaining = np.sum(cut_feature_values, axis=None) + index_of_last = feature_order[max_display] + values = np.insert(values, index_of_last, sum_of_remaining, axis=1) + max_display += 1 # include the sum of the remaining in the display + + # get the top features and their names + feature_inds = feature_order[:max_display] + y_pos = np.arange(len(feature_inds), 0, -1) + yticklabels = [feature_names[i] for i in feature_inds] + if num_cut > 0: + yticklabels[-1] = f"Sum of {int(num_cut)} other features" + + # create a figure if one was not provided + if ax is None: + ax = plt.gca() + # only modify the figure size if ax was not passed in + # compute our figure size based on how many features we are showing + fig = plt.gcf() + row_height = 0.5 + fig.set_size_inches( + 8 + 0.3 * max([len(feature_name) for feature_name in feature_names]), + max_display * row_height * np.sqrt(len(values)) + 1.5, + ) + + # if negative values are present, we draw a vertical line to mark 0 + negative_values_present = np.sum(values[:, feature_order[:max_display]] < 0) > 0 + if negative_values_present: + ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1) + + # draw the bars + patterns = (None, "\\\\", "++", "xx", "////", "*", "o", "O", ".", "-") + total_width = 0.7 + bar_width = total_width / len(values) + for i in range(len(values)): + ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) + ax.barh( + y_pos + ypos_offset, + values[i, feature_inds], + bar_width, + align="center", + color=[ + BLUE.hex if values[i, feature_inds[j]] <= 0 else RED.hex for j in range(len(y_pos)) + ], + hatch=patterns[i], + edgecolor=(1, 1, 1, 0.8), + label="Group " + str(i + 1), + ) + + # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks) + ax.set_yticks( + list(y_pos) + list(y_pos + 1e-8), + yticklabels + [t.split("=")[-1] for t in yticklabels], + fontsize=13, + ) + + xlen = ax.get_xlim()[1] - ax.get_xlim()[0] + bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted()) + width = bbox.width + bbox_to_xscale = xlen / width + + # draw the bar labels as text next to the bars + for i in range(len(values)): + ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) + for j in range(len(y_pos)): + ind = feature_inds[j] + if values[i, ind] < 0: + ax.text( + values[i, ind] - (5 / 72) * bbox_to_xscale, + float(y_pos[j] + ypos_offset), + format_value(values[i, ind], "%+0.02f"), + horizontalalignment="right", + verticalalignment="center", + color=BLUE.hex, + fontsize=12, + ) + else: + ax.text( + values[i, ind] + (5 / 72) * bbox_to_xscale, + float(y_pos[j] + ypos_offset), + format_value(values[i, ind], "%+0.02f"), + horizontalalignment="left", + verticalalignment="center", + color=RED.hex, + fontsize=12, + ) + + # put horizontal lines for each feature row + for i in range(max_display): + ax.axhline(i + 1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1) + + # remove plot frame and y-axis ticks + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("none") + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + if negative_values_present: + ax.spines["left"].set_visible(False) + ax.tick_params("x", labelsize=11) + + # set the x-axis limits to cover the data + xmin, xmax = ax.get_xlim() + x_buffer = (xmax - xmin) * 0.05 + if negative_values_present: + ax.set_xlim(xmin - x_buffer, xmax + x_buffer) + else: + ax.set_xlim(xmin, xmax + x_buffer) + + ax.set_xlabel("Attribution", fontsize=13) + + if len(values) > 1: + ax.legend(fontsize=12, loc="lower right") + + # color the y tick labels that have the feature values as gray + # (these fall behind the black ones with just the feature name) + tick_labels = ax.yaxis.get_majorticklabels() + for i in range(max_display): + tick_labels[i].set_color("#999999") + + return ax + + def bar_plot( list_of_interaction_values: list[InteractionValues], feature_names: Optional[np.ndarray] = None, show: bool = False, abbreviate: bool = True, - **kwargs, + max_display: Optional[int] = 10, + global_plot: bool = True, ) -> Optional[plt.Axes]: """Draws interaction values on a bar plot. @@ -30,37 +176,43 @@ def bar_plot( show: Whether ``matplotlib.pyplot.show()`` is called before returning. Default is ``True``. Setting this to ``False`` allows the plot to be customized further after it has been created. abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. - **kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``. + max_display: The maximum number of features to display. Defaults to ``10``. If set to + ``None``, all features are displayed. + global_plot: Weather to aggregate the values of the different InteractionValues objects + into a global explanation (``True``) or to plot them as separate bars (``False``). + Defaults to ``True``. If only one InteractionValues object is provided, this parameter + is ignored. """ - check_import_module("shap") - import shap + n_players = list_of_interaction_values[0].n_players - assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1 + if feature_names is not None: + if abbreviate: + feature_names = abbreviate_feature_names(feature_names) + feature_mapping = {i: feature_names[i] for i in range(n_players)} + else: + feature_mapping = {i: "F" + str(i) for i in range(n_players)} - _global_values = [] - _base_values = [] - _labels = [] - _first_iv = True - for iv in list_of_interaction_values: + # aggregate the interaction values if global_plot is True + if global_plot: + global_values = aggregate_interaction_values(list_of_interaction_values) + values = np.expand_dims(global_values.values, axis=0) + interaction_list = global_values.interaction_lookup.keys() + else: # plot the interaction values separately (also includes the case of a single object) + all_interactions = set() + for iv in list_of_interaction_values: + all_interactions.update(iv.interaction_lookup.keys()) + all_interactions = sorted(all_interactions) + interaction_list = [] + values = np.zeros((len(list_of_interaction_values), len(all_interactions))) + for j, interaction in enumerate(all_interactions): + interaction_list.append(interaction) + for i, iv in enumerate(list_of_interaction_values): + values[i, j] = iv[interaction] - _shap_values, _names = get_interaction_values_and_feature_names( - iv, feature_names, None, abbreviate=abbreviate - ) - if _first_iv: - _labels = _names - _first_iv = False - _global_values.append(_shap_values) - _base_values.append(iv.baseline_value) - - _labels = np.array(_labels) if feature_names is not None else None - explanation = shap.Explanation( - values=np.stack(_global_values), - base_values=np.array(_base_values), - feature_names=_labels, - ) + # format the labels + labels = [format_labels(feature_mapping, interaction) for interaction in interaction_list] - ax = shap.plots.bar(explanation, **kwargs, show=False) - ax.set_xlabel("mean(|Shapley Interaction value|)") + ax = _bar(values=values, feature_names=labels, max_display=max_display) if not show: return ax plt.show() diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index 9907282a..debb9a49 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -2,51 +2,626 @@ from typing import Optional +import matplotlib import matplotlib.pyplot as plt import numpy as np +from matplotlib import lines +from matplotlib.font_manager import FontProperties +from matplotlib.patches import PathPatch +from matplotlib.path import Path from ..interaction_values import InteractionValues -from ..utils.modules import check_import_module -from .utils import get_interaction_values_and_feature_names +from .utils import abbreviate_feature_names, format_labels __all__ = ["force_plot"] +def _create_bars( + out_value: float, + features: np.ndarray, + feature_type: str, + width_separators: float, + width_bar: float, +) -> tuple[list, list]: + """ + Create bars and separators for the plot. + Args: + out_value: the output value + features: names and values of the features to add + feature_type: Indicating whether positive or negative features + width_separators: width to separate the bars + width_bar: width of the bars + + Returns: List of bars and separators + """ + rectangle_list = [] + separator_list = [] + + pre_val = out_value + for index, features in zip(range(len(features)), features): + if feature_type == "positive": + left_bound = float(features[0]) + right_bound = pre_val + pre_val = left_bound + + separator_indent = np.abs(width_separators) + separator_pos = left_bound + colors = ["#FF0D57", "#FFC3D5"] + else: + left_bound = pre_val + right_bound = float(features[0]) + pre_val = right_bound + + separator_indent = -np.abs(width_separators) + separator_pos = right_bound + colors = ["#1E88E5", "#D1E6FA"] + + # Create rectangle + if index == 0: + if feature_type == "positive": + points_rectangle = [ + [left_bound, 0], + [right_bound, 0], + [right_bound, width_bar], + [left_bound, width_bar], + [left_bound + separator_indent, (width_bar / 2)], + ] + else: + points_rectangle = [ + [right_bound, 0], + [left_bound, 0], + [left_bound, width_bar], + [right_bound, width_bar], + [right_bound + separator_indent, (width_bar / 2)], + ] + + else: + points_rectangle = [ + [left_bound, 0], + [right_bound, 0], + [right_bound + separator_indent * 0.90, (width_bar / 2)], + [right_bound, width_bar], + [left_bound, width_bar], + [left_bound + separator_indent * 0.90, (width_bar / 2)], + ] + + line = plt.Polygon( + points_rectangle, closed=True, fill=True, facecolor=colors[0], linewidth=0 + ) + rectangle_list += [line] + + # Create separator + points_separator = [ + [separator_pos, 0], + [separator_pos + separator_indent, (width_bar / 2)], + [separator_pos, width_bar], + ] + + line = plt.Polygon(points_separator, closed=None, fill=None, edgecolor=colors[1], lw=3) + separator_list += [line] + + return rectangle_list, separator_list + + +def _add_labels( + fig: plt.Figure, + ax: plt.Axes, + out_value: float, + features: np.ndarray, + feature_type: str, + offset_text: float, + total_effect: float = 0, + min_perc: float = 0.05, + text_rotation: float = 0, +) -> None: + """ + Add labels to the plot. + Args: + fig: Figure of the plot + ax: Axes of the plot + out_value: output value + features: The values and names of the features + feature_type: Indicating whether positive or negative features + offset_text: value to offset name of the features + total_effect: Total value of all features. Used to filter out features that do not contribute at least min_perc to the total effect. + Defaults to 0 indicating that all features are shown. + min_perc: minimal percentage of the total effect that a feature must contribute to be shown. Defaults to 0.05. + text_rotation: Degree the text should be rotated. Defaults to 0. + + Returns: + + """ + start_text = out_value + pre_val = out_value + + # Define variables specific to positive and negative effect features + if feature_type == "positive": + colors = ["#FF0D57", "#FFC3D5"] + alignment = "right" + sign = 1 + else: + colors = ["#1E88E5", "#D1E6FA"] + alignment = "left" + sign = -1 + + # Draw initial line + if feature_type == "positive": + x, y = np.array([[pre_val, pre_val], [0, -0.18]]) + line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0]) + line.set_clip_on(False) + ax.add_line(line) + start_text = pre_val + + box_end = out_value + val = out_value + for feature in features: + # Exclude all labels that do not contribute at least 10% to the total + feature_contribution = np.abs(float(feature[0]) - pre_val) / np.abs(total_effect) + if feature_contribution < min_perc: + break + + # Compute value for current feature + val = float(feature[0]) + + # Draw labels. + text = feature[1] + + if text_rotation != 0: + va_alignment = "top" + else: + va_alignment = "baseline" + + text_out_val = plt.text( + start_text - sign * offset_text, + -0.15, + text, + fontsize=12, + color=colors[0], + horizontalalignment=alignment, + va=va_alignment, + rotation=text_rotation, + ) + text_out_val.set_bbox(dict(facecolor="none", edgecolor="none")) + + # We need to draw the plot to be able to get the size of the + # text box + fig.canvas.draw() + box_size = text_out_val.get_bbox_patch().get_extents().transformed(ax.transData.inverted()) + if feature_type == "positive": + box_end_ = box_size.get_points()[0][0] + else: + box_end_ = box_size.get_points()[1][0] + + # Create end line + if (sign * box_end_) > (sign * val): + x, y = np.array([[val, val], [0, -0.18]]) + line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0]) + line.set_clip_on(False) + ax.add_line(line) + start_text = val + box_end = val + + else: + box_end = box_end_ - sign * offset_text + x, y = np.array([[val, box_end, box_end], [0, -0.08, -0.18]]) + line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0]) + line.set_clip_on(False) + ax.add_line(line) + start_text = box_end + + # Update previous value + pre_val = float(feature[0]) + + # Create line for labels + extent_shading = [out_value, box_end, 0, -0.31] + path = [ + [out_value, 0], + [pre_val, 0], + [box_end, -0.08], + [box_end, -0.2], + [out_value, -0.2], + [out_value, 0], + ] + + path = Path(path) + patch = PathPatch(path, facecolor="none", edgecolor="none") + ax.add_patch(patch) + + # Extend axis if needed + lower_lim, upper_lim = ax.get_xlim() + if box_end < lower_lim: + ax.set_xlim(box_end, upper_lim) + + if box_end > upper_lim: + ax.set_xlim(lower_lim, box_end) + + # Create shading + if feature_type == "positive": + colors = np.array([(255, 13, 87), (255, 255, 255)]) / 255.0 + else: + colors = np.array([(30, 136, 229), (255, 255, 255)]) / 255.0 + + cm = matplotlib.colors.LinearSegmentedColormap.from_list("cm", colors) + + _, Z2 = np.meshgrid(np.linspace(0, 10), np.linspace(-10, 10)) + im = plt.imshow( + Z2, + interpolation="quadric", + cmap=cm, + vmax=0.01, + alpha=0.3, + origin="lower", + extent=extent_shading, + clip_path=patch, + clip_on=True, + aspect="auto", + ) + im.set_clip_path(patch) + + return fig, ax + + +def _add_output_element(out_name: str, out_value: float, ax: plt.Axes) -> None: + """ + Add grew line indicating the output value to the plot. + Args: + out_name: Name of the output value + out_value: Value of the output + ax: Axis of the plot + + Returns: Nothing + + """ + # Add output value + x, y = np.array([[out_value, out_value], [0, 0.24]]) + line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2") + line.set_clip_on(False) + ax.add_line(line) + + font0 = FontProperties() + font = font0.copy() + font.set_weight("bold") + text_out_val = plt.text( + out_value, + 0.25, + f"{out_value:.2f}", + fontproperties=font, + fontsize=14, + horizontalalignment="center", + ) + text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) + + text_out_val = plt.text( + out_value, 0.33, out_name, fontsize=12, alpha=0.5, horizontalalignment="center" + ) + text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) + + +def _add_base_value(base_value: float, ax: plt.Axes) -> None: + """ + Add base value to the plot. + Args: + base_value: the base value of the game + ax: Axes of the plot + + Returns: None + + """ + x, y = np.array([[base_value, base_value], [0.13, 0.25]]) + line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2") + line.set_clip_on(False) + ax.add_line(line) + + text_out_val = ax.text( + base_value, 0.25, "base value", fontsize=12, alpha=1, horizontalalignment="center" + ) + text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) + + +def update_axis_limits( + ax: plt.Axes, + total_pos: float, + pos_features: np.ndarray, + total_neg: float, + neg_features: np.ndarray, + base_value: float, + out_value: float, +) -> None: + """ + Adjust the axis limits of the plot according to values. + Args: + ax: Axes of the plot + total_pos: value of the total positive features + pos_features: values and names of the positive features + total_neg: value of the total negative features + neg_features: values and names of the negative features + base_value: the base value of the game + out_value: the output value + + Returns: None + + """ + ax.set_ylim(-0.5, 0.15) + padding = np.max([np.abs(total_pos) * 0.2, np.abs(total_neg) * 0.2]) + + if len(pos_features) > 0: + min_x = min(np.min(pos_features[:, 0].astype(float)), base_value) - padding + else: + min_x = out_value - padding + if len(neg_features) > 0: + max_x = max(np.max(neg_features[:, 0].astype(float)), base_value) + padding + else: + max_x = out_value + padding + ax.set_xlim(min_x, max_x) + + plt.tick_params( + top=True, + bottom=False, + left=False, + right=False, + labelleft=False, + labeltop=True, + labelbottom=False, + ) + plt.locator_params(axis="x", nbins=12) + + for key, spine in zip(plt.gca().spines.keys(), plt.gca().spines.values()): + if key != "top": + spine.set_visible(False) + + +def _split_features( + interaction_dictionary: dict[tuple[int, ...], float], + feature_to_names: dict[int, str], + out_value: float, +) -> tuple[np.ndarray, np.ndarray, float, float]: + """Splits the features into positive and negative values. + + Args: + interaction_dictionary: Dictionary containing the interaction values mapping from + feature indices to their values. + feature_to_names: Dictionary mapping feature indices to feature names. + out_value: The output value. + + Returns: + tuple: A tuple containing the positive features, negative features, total positive value, + and total negative value. + """ + # split features into positive and negative values + pos_features, neg_features = [], [] + for coaltion, value in interaction_dictionary.items(): + if len(coaltion) == 0: + continue + label = format_labels(feature_to_names, coaltion) + if value >= 0: + pos_features.append([str(value), label]) + elif value < 0: + neg_features.append([str(value), label]) + pos_features = sorted(pos_features, key=lambda x: x[0], reverse=True) + neg_features = sorted(neg_features, key=lambda x: x[0], reverse=True) + pos_features = np.array(pos_features, dtype=object) + neg_features = np.array(neg_features, dtype=object) + + # convert negative feature values to plot values + neg_val = out_value + for i in neg_features: + val = float(i[0]) + neg_val = neg_val + np.abs(val) + i[0] = neg_val + if len(neg_features) > 0: + total_neg = np.max(neg_features[:, 0].astype(float)) - np.min( + neg_features[:, 0].astype(float) + ) + else: + total_neg = 0 + + # convert positive feature values to plot values + pos_val = out_value + for i in pos_features: + val = float(i[0]) + pos_val = pos_val - np.abs(val) + i[0] = pos_val + + if len(pos_features) > 0: + total_pos = np.max(pos_features[:, 0].astype(float)) - np.min( + pos_features[:, 0].astype(float) + ) + else: + total_pos = 0 + + return pos_features, neg_features, total_pos, total_neg + + +def _add_bars( + ax: plt.Axes, out_value: float, pos_features: np.ndarray, neg_features: np.ndarray +) -> None: + """ + Add bars to the plot. + Args: + ax: Axes of the plot + out_value: grand total value + pos_features: positive features + neg_features: negative features + + Returns: + + """ + width_bar = 0.1 + width_separators = (ax.get_xlim()[1] - ax.get_xlim()[0]) / 200 + # Create bar for negative shap values + rectangle_list, separator_list = _create_bars( + out_value, neg_features, "negative", width_separators, width_bar + ) + for i in rectangle_list: + ax.add_patch(i) + + for i in separator_list: + ax.add_patch(i) + + # Create bar for positive shap values + rectangle_list, separator_list = _create_bars( + out_value, pos_features, "positive", width_separators, width_bar + ) + for i in rectangle_list: + ax.add_patch(i) + + for i in separator_list: + ax.add_patch(i) + + +def draw_higher_lower_element(out_value, offset_text): + plt.text( + out_value - offset_text, + 0.35, + "higher", + fontsize=13, + color="#FF0D57", + horizontalalignment="right", + ) + plt.text( + out_value + offset_text, + 0.35, + "lower", + fontsize=13, + color="#1E88E5", + horizontalalignment="left", + ) + plt.text( + out_value, 0.34, r"$\leftarrow$", fontsize=13, color="#1E88E5", horizontalalignment="center" + ) + plt.text( + out_value, + 0.36, + r"$\rightarrow$", + fontsize=13, + color="#FF0D57", + horizontalalignment="center", + ) + + +def _draw_force_plot( + interaction_value: InteractionValues, + feature_names: np.ndarray, + figsize: tuple[int, int], + min_perc: float = 0.05, + draw_higher_lower: bool = True, +) -> plt.Figure: + """ + Draw the force plot. + Args: + interaction_value: The ``InteractionValues`` to be plotted. + feature_names: Names of the features to be plotted provided as an array. + figsize: The size of the figure. + min_perc: Define the minimum percentage of the total effect that a feature must contribute + to be shown in the plot. Defaults to 0.05. + + Returns: None + + """ + # turn off interactive plot + plt.ioff() + + # compute overall metrics + base_value = interaction_value.baseline_value + out_value = np.sum(interaction_value.values) # Sum of all values with the baseline value + + # split features into positive and negative values + features_to_names = {i: str(name) for i, name in enumerate(feature_names)} + pos_features, neg_features, total_pos, total_neg = _split_features( + interaction_value.dict_values, features_to_names, out_value + ) + + # define plots + offset_text = (np.abs(total_neg) + np.abs(total_pos)) * 0.04 + + fig, ax = plt.subplots(figsize=figsize) + + # compute axis limit + update_axis_limits(ax, total_pos, pos_features, total_neg, neg_features, base_value, out_value) + + # add the bars to the plot + _add_bars(ax, out_value, pos_features, neg_features) + + # add labels + total_effect = np.abs(total_neg) + total_pos + fig, ax = _add_labels( + fig, + ax, + out_value, + neg_features, + "negative", + offset_text, + total_effect, + min_perc=min_perc, + text_rotation=0, + ) + + fig, ax = _add_labels( + fig, + ax, + out_value, + pos_features, + "positive", + offset_text, + total_effect, + min_perc=min_perc, + text_rotation=0, + ) + + # add higher and lower element + if draw_higher_lower: + draw_higher_lower_element(out_value, offset_text) + + # add label for base value + _add_base_value(base_value, ax) + + # add output label + out_names = "" + _add_output_element(out_names, out_value, ax) + + # fix the whitespace around the plot + plt.tight_layout() + + return plt.gcf() + + def force_plot( interaction_values: InteractionValues, feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, - matplotlib: bool = True, - show: bool = False, abbreviate: bool = True, - **kwargs, + show: bool = False, + figsize: tuple[int, int] = (15, 4), + draw_higher_lower: bool = True, + min_percentage: float = 0.05, ) -> Optional[plt.Figure]: - """Draws interaction values on a force plot. - - Requires the ``shap`` Python package to be installed. + """Draws a force plot for the given interaction values. Args: - interaction_values: The interaction values as an interaction object. - feature_names: The feature names used for plotting. If no feature names are provided, the - feature indices are used instead. Defaults to ``None``. - feature_values: The feature values used for plotting. Defaults to ``None``. - matplotlib: Whether to return a ``matplotlib`` figure. Defaults to ``True``. - show: Whether to show the plot. Defaults to ``False``. - abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. - **kwargs: Keyword arguments passed to ``shap.plots.force()``. - """ - check_import_module("shap") - import shap + interaction_values: The ``InteractionValues`` to be plotted. + feature_names: The names of the features. If ``None``, the features are named by their index. + show: Whether to show or return the plot. Defaults to ``False`` and returns the plot. + abbreviate: Whether to abbreviate the feature names. Defaults to ``True.`` + figsize: The size of the figure. Defaults to ``(15, 4)``. + draw_higher_lower: Whether to draw the higher and lower indicator. Defaults to ``True``. + min_percentage: Define the minimum percentage of the total effect that a feature must contribute + to be shown in the plot. Defaults to 0.05. - _shap_values, _labels = get_interaction_values_and_feature_names( - interaction_values, feature_names, feature_values, abbreviate=abbreviate - ) + Returns: + plt.Figure: The figure of the plot - return shap.plots.force( - base_value=np.array([interaction_values.baseline_value], dtype=float), # must be array - shap_values=np.array(_shap_values), - feature_names=_labels, - matplotlib=matplotlib, - show=show, - **kwargs, + """ + if feature_names is None: + feature_names = [str(i) for i in range(interaction_values.n_players)] + if abbreviate: + feature_names = abbreviate_feature_names(feature_names) + feature_names = np.array(feature_names) + plot = _draw_force_plot( + interaction_values, + feature_names, + figsize=figsize, + draw_higher_lower=draw_higher_lower, + min_perc=min_percentage, ) + if not show: + return plot + plt.show() diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index d5c6a117..664592f8 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -1,61 +1,67 @@ """This utility module contains helper functions for plotting.""" -import copy +import re from collections.abc import Iterable -from typing import Optional +from typing import Union -import numpy as np +__all__ = ["abbreviate_feature_names", "format_value", "format_labels"] -from ..interaction_values import InteractionValues -from ..utils import powerset -__all__ = ["get_interaction_values_and_feature_names", "abbreviate_feature_names"] +def format_value( + s: Union[float, str], + format_str: str = "%.2f", +) -> str: + """Strips trailing zeros and uses a unicode minus sign. + + Args: + s: The value to be formatted. + format_str: The format string to be used. Defaults to "%.2f". + + Returns: + str: The formatted value. + + Examples: + >>> format_value(1.0) + "1" + >>> format_value(1.234) + "1.23" + """ + if not issubclass(type(s), str): + s = format_str % s + s = re.sub(r"\.?0+$", "", s) + if s[0] == "-": + s = "\u2212" + s[1:] + return s -def get_interaction_values_and_feature_names( - interaction_values: InteractionValues, - feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, - abbreviate: bool = True, -) -> tuple[np.ndarray, np.ndarray]: - """Converts higher-order interaction values to SHAP-like vectors with associated labels. +def format_labels( + feature_mapping: dict[int, str], + feature_tuple: tuple[int, ...], +) -> str: + """Formats the feature labels for the plots. Args: - interaction_values: The interaction values as an interaction object. - feature_names: The feature names used for plotting. If no feature names are provided, the - feature indices are used instead. Defaults to ``None``. - feature_values: The feature values used for plotting. Defaults to ``None``. - abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. + feature_mapping: A dictionary mapping feature indices to feature names. + feature_tuple: The feature tuple to be formatted. Returns: - A tuple containing the SHAP values and the corresponding labels. + str: The formatted feature tuple. + + Example: + >>> feature_mapping = {0: "A", 1: "B", 2: "C"} + >>> format_labels(feature_mapping, (0, 1)) + "A x B" + >>> format_labels(feature_mapping, (0,)) + "A" + >>> format_labels(feature_mapping, ()) + "Base Value" """ - feature_names = copy.deepcopy(feature_names) - if feature_names is not None and abbreviate: - feature_names = abbreviate_feature_names(feature_names) - _values_dict = {} - for i in range(1, interaction_values.max_order + 1): - _values_dict[i] = interaction_values.get_n_order_values(i) - _n_features = len(_values_dict[1]) - _shap_values = [] - _labels = [] - for interaction in powerset( - range(_n_features), min_size=1, max_size=interaction_values.max_order - ): - _order = len(interaction) - _values = _values_dict[_order] - _shap_values.append(_values[interaction]) - if feature_names is not None: - _name = " x ".join(str(feature_names[i]) for i in interaction) - else: - _name = " x ".join(f"{feature}" for feature in interaction) - if feature_values is not None: - _name += "\n" - _name += " x ".join(f"{feature_values[i]}".strip()[0:4] for i in interaction) - _labels.append(_name) - _shap_values = np.array(_shap_values) - _labels = np.array(_labels) - return _shap_values, _labels + if len(feature_tuple) == 0: + return "Base Value" + elif len(feature_tuple) == 1: + return str(feature_mapping[feature_tuple[0]]) + else: + return " x ".join([str(feature_mapping[f]) for f in feature_tuple]) def abbreviate_feature_names(feature_names: Iterable[str]) -> list[str]: diff --git a/shapiq/plot/watefall.py b/shapiq/plot/watefall.py index 0af4b941..4b8a19db 100644 --- a/shapiq/plot/watefall.py +++ b/shapiq/plot/watefall.py @@ -2,58 +2,359 @@ from typing import Optional +import matplotlib import matplotlib.pyplot as plt import numpy as np from ..interaction_values import InteractionValues -from ..utils.modules import check_import_module -from .utils import get_interaction_values_and_feature_names +from ._config import BLUE, RED +from .utils import abbreviate_feature_names, format_labels, format_value __all__ = ["waterfall_plot"] +def _draw_waterfall_plot( + values: np.ndarray, base_values: float, feature_names: list[str], max_display=10, show=True +) -> Optional[plt.Axes]: + """ + Create a waterfall plot idential to SHAP waterfall plot (https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/waterfall.html). + Args: + values: the explanation values + base_values: the base value of the game + feature_names: the names of the features + max_display: the maximum number of features to display + show: whether to show the plot + + Returns: the plot if show is False + + """ + # Turn off interactive plot + if show is False: + plt.ioff() + + # init variables we use for tracking the plot locations + num_features = min(max_display, len(values)) + row_height = 0.5 + rng = range(num_features - 1, -1, -1) + order = np.argsort(-np.abs(values)) + pos_lefts = [] + pos_inds = [] + pos_widths = [] + pos_low = [] + pos_high = [] + neg_lefts = [] + neg_inds = [] + neg_widths = [] + neg_low = [] + neg_high = [] + loc = base_values + values.sum() + yticklabels = ["" for _ in range(num_features + 1)] + + # size the plot based on how many features we are plotting + plt.gcf().set_size_inches(8, num_features * row_height + 1.5) + + # see how many individual (vs. grouped at the end) features we are plotting + if num_features == len(values): + num_individual = num_features + else: + num_individual = num_features - 1 + + # compute the locations of the individual features and plot the dashed connecting lines + for i in range(num_individual): + sval = values[order[i]] + loc -= sval + if sval >= 0: + pos_inds.append(rng[i]) + pos_widths.append(sval) + pos_lefts.append(loc) + else: + neg_inds.append(rng[i]) + neg_widths.append(sval) + neg_lefts.append(loc) + if num_individual != num_features or i + 4 < num_individual: + plt.plot( + [loc, loc], + [rng[i] - 1 - 0.4, rng[i] + 0.4], + color="#bbbbbb", + linestyle="--", + linewidth=0.5, + zorder=-1, + ) + yticklabels[rng[i]] = feature_names[order[i]] + + # add a last grouped feature to represent the impact of all the features we didn't show + if num_features < len(values): + yticklabels[0] = "%d other features".format() + remaining_impact = base_values - loc + if remaining_impact < 0: + pos_inds.append(0) + pos_widths.append(-remaining_impact) + pos_lefts.append(loc + remaining_impact) + else: + neg_inds.append(0) + neg_widths.append(-remaining_impact) + neg_lefts.append(loc + remaining_impact) + + points = ( + pos_lefts + + list(np.array(pos_lefts) + np.array(pos_widths)) + + neg_lefts + + list(np.array(neg_lefts) + np.array(neg_widths)) + ) + dataw = np.max(points) - np.min(points) + + # draw invisible bars just for sizing the axes + label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths]) + plt.barh( + pos_inds, + np.array(pos_widths) + label_padding + 0.02 * dataw, + left=np.array(pos_lefts) - 0.01 * dataw, + color=RED.hex, + alpha=0, + ) + label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths]) + plt.barh( + neg_inds, + np.array(neg_widths) + label_padding - 0.02 * dataw, + left=np.array(neg_lefts) + 0.01 * dataw, + color=BLUE.hex, + alpha=0, + ) + + # define variable we need for plotting the arrows + head_length = 0.08 + bar_width = 0.8 + xlen = plt.xlim()[1] - plt.xlim()[0] + fig = plt.gcf() + ax = plt.gca() + bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + width = bbox.width + bbox_to_xscale = xlen / width + hl_scaled = bbox_to_xscale * head_length + renderer = fig.canvas.get_renderer() + + # draw the positive arrows + for i in range(len(pos_inds)): + dist = pos_widths[i] + arrow_obj = plt.arrow( + pos_lefts[i], + pos_inds[i], + max(dist - hl_scaled, 0.000001), + 0, + head_length=min(dist, hl_scaled), + color=RED.hex, + width=bar_width, + head_width=bar_width, + ) + + if pos_low is not None and i < len(pos_low): + plt.errorbar( + pos_lefts[i] + pos_widths[i], + pos_inds[i], + xerr=np.array([[pos_widths[i] - pos_low[i]], [pos_high[i] - pos_widths[i]]]), + ecolor=BLUE.hex, + ) + + txt_obj = plt.text( + pos_lefts[i] + 0.5 * dist, + pos_inds[i], + format_value(pos_widths[i], "%+0.02f"), + horizontalalignment="center", + verticalalignment="center", + color="white", + fontsize=12, + ) + text_bbox = txt_obj.get_window_extent(renderer=renderer) + arrow_bbox = arrow_obj.get_window_extent(renderer=renderer) + + # if the text overflows the arrow then draw it after the arrow + if text_bbox.width > arrow_bbox.width: + txt_obj.remove() + + txt_obj = plt.text( + pos_lefts[i] + (5 / 72) * bbox_to_xscale + dist, + pos_inds[i], + format_value(pos_widths[i], "%+0.02f"), + horizontalalignment="left", + verticalalignment="center", + color=RED.hex, + fontsize=12, + ) + + # draw the negative arrows + for i in range(len(neg_inds)): + dist = neg_widths[i] + + arrow_obj = plt.arrow( + neg_lefts[i], + neg_inds[i], + -max(-dist - hl_scaled, 0.000001), + 0, + head_length=min(-dist, hl_scaled), + color=BLUE.hex, + width=bar_width, + head_width=bar_width, + ) + + if neg_low is not None and i < len(neg_low): + plt.errorbar( + neg_lefts[i] + neg_widths[i], + neg_inds[i], + xerr=np.array([[neg_widths[i] - neg_low[i]], [neg_high[i] - neg_widths[i]]]), + ecolor=RED.hex, + ) + + txt_obj = plt.text( + neg_lefts[i] + 0.5 * dist, + neg_inds[i], + format_value(neg_widths[i], "%+0.02f"), + horizontalalignment="center", + verticalalignment="center", + color="white", + fontsize=12, + ) + text_bbox = txt_obj.get_window_extent(renderer=renderer) + arrow_bbox = arrow_obj.get_window_extent(renderer=renderer) + + # if the text overflows the arrow then draw it after the arrow + if text_bbox.width > arrow_bbox.width: + txt_obj.remove() + + plt.text( + neg_lefts[i] - (5 / 72) * bbox_to_xscale + dist, + neg_inds[i], + format_value(neg_widths[i], "%+0.02f"), + horizontalalignment="right", + verticalalignment="center", + color=BLUE.hex, + fontsize=12, + ) + + # draw the y-ticks twice, once in gray and then again with just the feature names in black + # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ytick_pos = list(range(num_features)) + list(np.arange(num_features) + 1e-8) + plt.yticks( + ytick_pos, + yticklabels[:-1] + [label.split("=")[-1] for label in yticklabels[:-1]], + fontsize=13, + ) + + # put horizontal lines for each feature row + for i in range(num_features): + plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1) + + # mark the prior expected value and the model prediction + plt.axvline( + base_values, 0, 1 / num_features, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1 + ) + fx = base_values + values.sum() + plt.axvline(fx, 0, 1, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1) + + # clean up the main axis + plt.gca().xaxis.set_ticks_position("bottom") + plt.gca().yaxis.set_ticks_position("none") + plt.gca().spines["right"].set_visible(False) + plt.gca().spines["top"].set_visible(False) + plt.gca().spines["left"].set_visible(False) + ax.tick_params(labelsize=13) + # plt.xlabel("\nModel output", fontsize=12) + + # draw the E[f(X)] tick mark + xmin, xmax = ax.get_xlim() + ax2 = ax.twiny() + ax2.set_xlim(xmin, xmax) + ax2.set_xticks( + [base_values, base_values + 1e-8] + ) # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ax2.set_xticklabels( + ["\n$E[f(X)]$", "\n$ = " + format_value(base_values, "%0.03f") + "$"], + fontsize=12, + ha="left", + ) + ax2.spines["right"].set_visible(False) + ax2.spines["top"].set_visible(False) + ax2.spines["left"].set_visible(False) + + # draw the f(x) tick mark + ax3 = ax2.twiny() + ax3.set_xlim(xmin, xmax) + # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ax3.set_xticks([base_values + values.sum(), base_values + values.sum() + 1e-8]) + ax3.set_xticklabels( + ["$f(x)$", "$ = " + format_value(fx, "%0.03f") + "$"], fontsize=12, ha="left" + ) + tick_labels = ax3.xaxis.get_majorticklabels() + tick_labels[0].set_transform( + tick_labels[0].get_transform() + + matplotlib.transforms.ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_transform( + tick_labels[1].get_transform() + + matplotlib.transforms.ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_color("#999999") + ax3.spines["right"].set_visible(False) + ax3.spines["top"].set_visible(False) + ax3.spines["left"].set_visible(False) + + # adjust the position of the E[f(X)] = x.xx label + tick_labels = ax2.xaxis.get_majorticklabels() + tick_labels[0].set_transform( + tick_labels[0].get_transform() + + matplotlib.transforms.ScaledTranslation(-20 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_transform( + tick_labels[1].get_transform() + + matplotlib.transforms.ScaledTranslation(22 / 72.0, -1 / 72.0, fig.dpi_scale_trans) + ) + + tick_labels[1].set_color("#999999") + + # color the y tick labels that have the feature values as gray + # (these fall behind the black ones with just the feature name) + tick_labels = ax.yaxis.get_majorticklabels() + for i in range(num_features): + tick_labels[i].set_color("#999999") + + if show: + plt.show() + else: + return plt.gca() + + def waterfall_plot( interaction_values: InteractionValues, - feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, + feature_names: Optional[np.ndarray[str]] = None, show: bool = False, - abbreviate: bool = True, max_display: int = 10, + abbreviate: bool = True, ) -> Optional[plt.Axes]: - """Draws interaction values on a waterfall plot. - - Note: - Requires the ``shap`` Python package to be installed. + """Draws a waterfall plot with the interaction values. Args: interaction_values: The interaction values as an interaction object. - feature_names: The feature names used for plotting. If no feature names are provided, the - feature indices are used instead. Defaults to ``None``. - feature_values: The feature values used for plotting. Defaults to ``None``. + feature_names: The names of the features. Defaults to ``None``. show: Whether to show the plot. Defaults to ``False``. - abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``. max_display: The maximum number of interactions to display. Defaults to ``10``. + abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. """ - check_import_module("shap") - import shap - - if interaction_values.max_order == 1: - shap_explanation = shap.Explanation( - values=interaction_values.get_n_order_values(1), - base_values=interaction_values.baseline_value, - data=feature_values, - feature_names=feature_names, - ) + + if feature_names is None: + feature_mapping = {i: str(i) for i in range(interaction_values.n_players)} else: - _shap_values, _labels = get_interaction_values_and_feature_names( - interaction_values, feature_names, feature_values, abbreviate=abbreviate - ) + if abbreviate: + feature_names = abbreviate_feature_names(feature_names) + feature_mapping = {i: feature_names[i] for i in range(interaction_values.n_players)} - shap_explanation = shap.Explanation( - values=np.array(_shap_values), - base_values=np.array([interaction_values.baseline_value], dtype=float), - data=None, - feature_names=_labels, - ) + # create the data for the waterfall plot in the correct format + data = [] + for feature_tuple, value in interaction_values.dict_values.items(): + if len(feature_tuple) > 0: + data.append((format_labels(feature_mapping, feature_tuple), str(value))) + data = np.array(data, dtype=object) + values = data[:, 1].astype(float) + feature_names = data[:, 0] - return shap.plots.waterfall(shap_explanation, max_display=max_display, show=show) + return _draw_waterfall_plot( + values, interaction_values.baseline_value, feature_names, max_display=max_display, show=show + ) diff --git a/tests/conftest.py b/tests/conftest.py index 7da6bfca..b8c5a40e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,40 @@ NR_FEATURES = 7 +@pytest.fixture +def cooking_game(): + """Return a simple game object.""" + import shapiq + + class CookingGame(shapiq.Game): + def __init__(self): + self.characteristic_function = { + (): 10, + (0,): 4, + (1,): 3, + (2,): 2, + (0, 1): 9, + (0, 2): 8, + (1, 2): 7, + (0, 1, 2): 15, + } + super().__init__( + n_players=3, + player_names=["Alice", "Bob", "Charlie"], # Optional list of names + normalization_value=self.characteristic_function[()], # 0 + normalize=False, + ) + + def value_function(self, coalitions: np.ndarray) -> np.ndarray: + """Defines the worth of a coalition as a lookup in the characteristic function.""" + output = [] + for coalition in coalitions: + output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) + return np.array(output) + + return CookingGame() + + @pytest.fixture def dt_reg_model() -> DecisionTreeRegressor: """Return a simple decision tree model.""" @@ -326,6 +360,8 @@ def mae_loss(): @pytest.fixture def interaction_values_list(): """Returns a list of three InteractionValues objects.""" + rng = np.random.RandomState(42) + from shapiq.interaction_values import InteractionValues from shapiq.utils import powerset @@ -341,7 +377,7 @@ def interaction_values_list(): powerset(range(n_players), min_size=min_order, max_size=max_order) ): interaction_lookup[interaction] = i - values.append(np.random.rand()) + values.append(rng.uniform(0, 1)) values = np.array(values) iv = InteractionValues( n_players=n_players, diff --git a/tests/requirements/requirements.txt b/tests/requirements/requirements.txt index 303ab789..7a750393 100644 --- a/tests/requirements/requirements.txt +++ b/tests/requirements/requirements.txt @@ -11,7 +11,6 @@ ruff==0.6.2 scikit-image==0.24.0 scikit-learn==1.5.1 scipy==1.13.0 -shap==0.46.0 tqdm==4.66.5 torch==2.4.0 torchvision==0.19.0 diff --git a/tests/test_base_interaction_values.py b/tests/test_base_interaction_values.py index 4e4e81bc..a93f066f 100644 --- a/tests/test_base_interaction_values.py +++ b/tests/test_base_interaction_values.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from shapiq.interaction_values import InteractionValues +from shapiq.interaction_values import InteractionValues, aggregate_interaction_values from shapiq.utils import powerset @@ -626,3 +626,103 @@ def test_subset(): assert subset_interaction_values.estimated == interaction_values.estimated assert subset_interaction_values.estimation_budget == interaction_values.estimation_budget assert subset_interaction_values.index == interaction_values.index + + +@pytest.mark.parametrize("aggregation", ["sum", "mean", "median", "max", "min"]) +def test_aggregation(aggregation): + + n_objects = 3 + + n, min_order, max_order = 5, 1, 3 + interaction_values_list = [] + for _ in range(n_objects): + values = np.random.rand(2**n - 1) + interaction_lookup = { + interaction: i for i, interaction in enumerate(powerset(range(n), min_order, max_order)) + } + interaction_values = InteractionValues( + values=values, + index="SII", + max_order=max_order, + n_players=n, + min_order=min_order, + interaction_lookup=interaction_lookup, + estimated=False, + estimation_budget=0, + baseline_value=0.0, + ) + interaction_values_list.append(interaction_values) + + aggregated_interaction_values = aggregate_interaction_values( + interaction_values_list, aggregation=aggregation + ) + + assert isinstance(aggregated_interaction_values, InteractionValues) + assert aggregated_interaction_values.index == "SII" + assert aggregated_interaction_values.n_players == n + assert aggregated_interaction_values.min_order == min_order + assert aggregated_interaction_values.max_order == max_order + + # check that all interactions are equal to the expected value + for interaction in powerset(range(n), 1, n): + aggregated_value = np.array( + [interaction_values[interaction] for interaction_values in interaction_values_list] + ) + if aggregation == "sum": + expected_value = np.sum(aggregated_value) + elif aggregation == "mean": + expected_value = np.mean(aggregated_value) + elif aggregation == "median": + expected_value = np.median(aggregated_value) + elif aggregation == "max": + expected_value = np.max(aggregated_value) + elif aggregation == "min": + expected_value = np.min(aggregated_value) + assert aggregated_interaction_values[interaction] == expected_value + + # test aggregate from InteractionValues object + aggregated_from_object = interaction_values_list[0].aggregate( + aggregation=aggregation, others=interaction_values_list[1:] + ) + assert isinstance(aggregated_from_object, InteractionValues) + assert aggregated_from_object == aggregated_interaction_values # same values + assert aggregated_from_object is not aggregated_interaction_values # but different objects + + +def test_docs_aggregation_function(): + """Tests the aggregation function in the InteractionValues dataclass like in the docs.""" + + iv1 = InteractionValues( + values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]), + index="SII", + n_players=3, + min_order=1, + max_order=2, + interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5}, + baseline_value=0.0, + ) + + # this does not contain the (1, 2) interaction (i.e. is 0) + iv2 = InteractionValues( + values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), + index="SII", + n_players=3, + min_order=1, + max_order=2, + interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, + baseline_value=1.0, + ) + + # test sum + aggregated_interaction_values = aggregate_interaction_values([iv1, iv2], aggregation="sum") + assert pytest.approx(aggregated_interaction_values[(0,)]) == 0.3 + assert pytest.approx(aggregated_interaction_values[(1,)]) == 0.5 + assert pytest.approx(aggregated_interaction_values[(1, 2)]) == 0.6 + assert pytest.approx(aggregated_interaction_values.baseline_value) == 1.0 + + # test mean + aggregated_interaction_values = aggregate_interaction_values([iv1, iv2], aggregation="mean") + assert pytest.approx(aggregated_interaction_values[(0,)]) == 0.15 + assert pytest.approx(aggregated_interaction_values[(1,)]) == 0.25 + assert pytest.approx(aggregated_interaction_values[(1, 2)]) == 0.3 + assert pytest.approx(aggregated_interaction_values.baseline_value) == 0.5 diff --git a/tests/tests_explainer/test_explainer_tabular.py b/tests/tests_explainer/test_explainer_tabular.py index 829eb586..0d5f33bf 100644 --- a/tests/tests_explainer/test_explainer_tabular.py +++ b/tests/tests_explainer/test_explainer_tabular.py @@ -176,25 +176,34 @@ def test_explain(dt_model, data, index, budget, max_order, imputer): def test_against_shap_linear(): """Tests weather TabularExplainer yields similar results as SHAP with a basic linear model.""" - import shap n_samples = 3 dim = 5 + rng = np.random.default_rng(42) def make_linear_model(): - w = np.random.default_rng().normal(size=dim) + w = rng.normal(size=dim) def model(X: np.ndarray): return np.dot(X, w) return model - X = np.random.default_rng().normal(size=(n_samples, dim)) + X = rng.normal(size=(n_samples, dim)) model = make_linear_model() + # import shap # compute with shap - explainer_shap = shap.explainers.Exact(model, X) - shap_values = explainer_shap(X).values + # explainer_shap = shap.explainers.Exact(model, X) + # shap_values = explainer_shap(X).values + # print(shap_values) + shap_values = np.array( + [ + [-0.29565839, -0.36698085, -0.55970434, 0.22567077, 0.05852208], + [1.08513574, 0.06365536, 0.46312977, -0.61532757, 0.00370387], + [-0.78947735, 0.30332549, 0.09657457, 0.38965679, -0.06222595], + ] + ) # compute with shapiq explainer_shapiq = TabularExplainer( diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index 4e980045..16109109 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -3,8 +3,18 @@ import matplotlib.pyplot as plt import numpy as np -from shapiq.interaction_values import InteractionValues -from shapiq.plot import bar_plot +from shapiq import ExactComputer, InteractionValues, bar_plot + + +def test_bar_cooking_game(cooking_game): + """Test the bar plot function with concrete values from the cooking game.""" + exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) + sv_exact = exact_computer(index="k-SII", order=2) + print(sv_exact.dict_values) + bar_plot([sv_exact], show=True) + + # visual inspection: + # - Order from top to bottom: Base Value, the interactions (all equal), F0, F1, F2 def test_bar_plot(interaction_values_list: list[InteractionValues]): @@ -27,3 +37,15 @@ def test_bar_plot(interaction_values_list: list[InteractionValues]): output = bar_plot(interaction_values_list, show=True) assert output is None plt.close("all") + + # test max_display=None + output = bar_plot(interaction_values_list, show=False, max_display=None) + assert output is not None + assert isinstance(output, plt.Axes) + plt.close("all") + + # test global = false + output = bar_plot(interaction_values_list, show=False, global_plot=False) + assert output is not None + assert isinstance(output, plt.Axes) + plt.close("all") diff --git a/tests/tests_plots/test_force.py b/tests/tests_plots/test_force.py index 3803437e..8cfd128a 100644 --- a/tests/tests_plots/test_force.py +++ b/tests/tests_plots/test_force.py @@ -3,8 +3,24 @@ import matplotlib.pyplot as plt import numpy as np -from shapiq.interaction_values import InteractionValues -from shapiq.plot import force_plot +from shapiq import ExactComputer, InteractionValues, force_plot + + +def test_force_cooking_game(cooking_game): + """Test the force plot function with concrete values from the cooking game.""" + exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) + interaction_values = exact_computer(index="k-SII", order=2) + print(interaction_values.dict_values) + feature_names = list(cooking_game.player_name_lookup.keys()) + force_plot(interaction_values, show=True, min_percentage=0.2, feature_names=feature_names) + plt.close() + + # visual inspection: + # - E[f(X)] = 10 + # - f(x) = 15 + # - 0, 1, and 2 should individually have negative contributions (go left) + # - all interactions should have a positive +7 contribution (go right) + # - feature 0 is too small to be displayed because of min_percentage=0.2 def test_force_plot(interaction_values_list: list[InteractionValues]): @@ -13,7 +29,6 @@ def test_force_plot(interaction_values_list: list[InteractionValues]): n_players = iv.n_players feature_names = [f"feature-{i}" for i in range(n_players)] feature_names = np.array(feature_names) - feature_values = np.array([i for i in range(n_players)]) fp = force_plot(iv, show=False) assert fp is not None @@ -25,11 +40,7 @@ def test_force_plot(interaction_values_list: list[InteractionValues]): assert isinstance(fp, plt.Figure) plt.close() - fp = force_plot(iv, show=False, feature_names=feature_names, feature_values=feature_values) - assert isinstance(fp, plt.Figure) - plt.close() - - fp = force_plot(iv, show=False, feature_names=None, feature_values=feature_values) + fp = force_plot(iv, show=False, feature_names=feature_names) assert isinstance(fp, plt.Figure) plt.close() diff --git a/tests/tests_plots/test_utils.py b/tests/tests_plots/test_utils.py new file mode 100644 index 00000000..243963d8 --- /dev/null +++ b/tests/tests_plots/test_utils.py @@ -0,0 +1,35 @@ +"""This test module tests all plotting utilities.""" + +from shapiq.plot.utils import abbreviate_feature_names, format_labels, format_value + + +def test_format_value(): + """Test the format_value function.""" + assert format_value(1.0) == "1" + assert format_value(1.234) == "1.23" + assert format_value(-1.234) == "\u22121.23" + assert format_value("1.234") == "1.234" + + +def test_format_labels(): + """Test the format_labels function.""" + feature_mapping = {0: "A", 1: "B", 2: "C"} + assert format_labels(feature_mapping, (0, 1)) == "A x B" + assert format_labels(feature_mapping, (0,)) == "A" + assert format_labels(feature_mapping, ()) == "Base Value" + assert format_labels(feature_mapping, (0, 1, 2)) == "A x B x C" + + +def test_abbreviate_feature_names(): + """Tests the abbreviate_feature_names function.""" + # check for splitting characters + feature_names = ["feature-0", "feature_1", "feature 2", "feature.3"] + assert abbreviate_feature_names(feature_names) == ["F0", "F1", "F2", "F3"] + + # check for long names + feature_names = ["longfeaturenamethatisnotshort", "stilllong"] + assert abbreviate_feature_names(feature_names) == ["lon.", "sti."] + + # check for abbreviation with capital letters + feature_names = ["LongFeatureName", "Short"] + assert abbreviate_feature_names(feature_names) == ["LFN", "Sho."] diff --git a/tests/tests_plots/test_waterfall.py b/tests/tests_plots/test_waterfall.py index bedc047e..cd06d7f9 100644 --- a/tests/tests_plots/test_waterfall.py +++ b/tests/tests_plots/test_waterfall.py @@ -3,8 +3,22 @@ import matplotlib.pyplot as plt import numpy as np -from shapiq.interaction_values import InteractionValues -from shapiq.plot import waterfall_plot +from shapiq import ExactComputer, InteractionValues, waterfall_plot + + +def test_waterfall_cooking_game(cooking_game): + """Test the waterfall plot function with concrete values from the cooking game.""" + exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) + interaction_values = exact_computer(index="k-SII", order=2) + print(interaction_values.dict_values) + waterfall_plot(interaction_values, show=True) + plt.close() + + # visual inspection: + # - E[f(X)] = 10 + # - f(x) = 15 + # - 0, 1, and 2 should individually have negative contributions (go left) + # - all interactions should have a positive +7 contribution (go right) def test_waterfall_plot(interaction_values_list: list[InteractionValues]): @@ -13,18 +27,13 @@ def test_waterfall_plot(interaction_values_list: list[InteractionValues]): n_players = iv.n_players feature_names = [f"feature-{i}" for i in range(n_players)] feature_names = np.array(feature_names) - feature_values = np.array([i for i in range(n_players)]) wp = waterfall_plot(iv, show=False) assert wp is not None assert isinstance(wp, plt.Axes) plt.close() - wp = waterfall_plot(iv, show=False, feature_names=feature_names, feature_values=feature_values) - assert isinstance(wp, plt.Axes) - plt.close() - - wp = waterfall_plot(iv, show=False, feature_names=None, feature_values=feature_values) + wp = waterfall_plot(iv, show=False, feature_names=feature_names) assert isinstance(wp, plt.Axes) plt.close()