Skip to content

Commit

Permalink
adds the stack bar plot for nSII values and closes #31
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Feb 5, 2024
1 parent 848f7b0 commit 1a2046c
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 85 deletions.
Binary file added docs/source/_static/stacked_bar_exampl.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 2 additions & 3 deletions shapiq/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""This module contains all plotting functions for the shapiq package."""

from .network import network_plot
from .stacked_bar import stacked_bar_plot

__all__ = [
"network_plot",
]
__all__ = ["network_plot", "stacked_bar_plot"]
82 changes: 0 additions & 82 deletions shapiq/plot/n_sii_stacked_bar.py

This file was deleted.

136 changes: 136 additions & 0 deletions shapiq/plot/stacked_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""This module contains functions to plot the n_sii stacked bar charts."""
__all__ = ["stacked_bar_plot"]

from copy import deepcopy
from typing import Union, Optional

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Patch

from ._config import COLORS_N_SII


def stacked_bar_plot(
    feature_names: Union[list, np.ndarray],
    n_shapley_values_pos: dict,
    n_shapley_values_neg: dict,
    n_sii_max_order: Optional[int] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
):
    """Plot the n-SII values for a given instance.

    This stacked bar plot can be used to visualize the amount of interaction between the features
    for a given instance. The n-SII values are plotted as stacked bars with positive and negative
    parts stacked on top of each other. The colors represent the order of the n-SII values. For a
    detailed explanation of this plot, see this `research paper <https://proceedings.mlr.press/v206/bordt23a/bordt23a.pdf>`_.

    An example of the plot is shown below.

    .. image:: /_static/stacked_bar_exampl.png
        :width: 400
        :align: center

    Args:
        feature_names (list): The names of the features. One bar is drawn per feature.
        n_shapley_values_pos (dict): The positive n-SII values. Maps the interaction order
            (starting at 1) to an array holding one value per feature.
        n_shapley_values_neg (dict): The negative n-SII values. Maps the interaction order
            (starting at 1) to an array holding one value per feature.
        n_sii_max_order (int): The highest order to plot. All orders from 1 up to and including
            this order are drawn. Defaults to ``None``, in which case the number of entries in
            ``n_shapley_values_pos`` is used.
        title (str): The title of the plot. Defaults to ``None`` (a generic title is used).
        xlabel (str): The label of the x-axis. Defaults to ``None`` (``"features"``).
        ylabel (str): The label of the y-axis. Defaults to ``None`` (``"n-SII values"``).

    Returns:
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]: A tuple containing the figure and
            the axis of the plot.

    Note:
        To change the figure size, font size, etc., use the `matplotlib parameters
        <https://matplotlib.org/stable/users/explain/customizing.html>`_.

    Example:
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from shapiq.plot import stacked_bar_plot
        >>> n_shapley_values_pos = {
        ...     1: np.asarray([1, 0, 1.75]),
        ...     2: np.asarray([0.25, 0.5, 0.75]),
        ...     3: np.asarray([0.5, 0.25, 0.25]),
        ... }
        >>> n_shapley_values_neg = {
        ...     1: np.asarray([0, -1.5, 0]),
        ...     2: np.asarray([-0.25, -0.5, -0.75]),
        ...     3: np.asarray([-0.5, -0.25, -0.25]),
        ... }
        >>> feature_names = ["a", "b", "c"]
        >>> fig, axes = stacked_bar_plot(
        ...     feature_names=feature_names,
        ...     n_shapley_values_pos=n_shapley_values_pos,
        ...     n_shapley_values_neg=n_shapley_values_neg,
        ... )
        >>> plt.show()
    """
    # sanitize inputs: by default every order of the positive values is plotted
    if n_sii_max_order is None:
        n_sii_max_order = len(n_shapley_values_pos)

    fig, axis = plt.subplots()

    n_features = len(feature_names)
    x = np.arange(n_features)

    # Select the orders 1..n_sii_max_order in ascending order. The filter has to be "<=": the plot
    # shows the values *up to* the maximum order, which is also what the legend and title state.
    orders_pos = [order for order in sorted(n_shapley_values_pos) if 1 <= order <= n_sii_max_order]
    orders_neg = [order for order in sorted(n_shapley_values_neg) if 1 <= order <= n_sii_max_order]

    # the zero line is independent of the bars and is drawn exactly once
    axis.axhline(y=0, color="black", linestyle="solid", linewidth=0.5)

    # extent of the stacked bars, used to set the y-axis limits after all bars are plotted
    y_min, y_max = 0.0, 0.0

    # positive segments are stacked upwards, starting at zero
    stack_top = np.zeros(n_features)
    for order in orders_pos:
        heights = np.asarray(n_shapley_values_pos[order], dtype=float)
        # the color is tied to the order so that it always agrees with the legend below
        axis.bar(x, height=heights, bottom=stack_top, color=COLORS_N_SII[order - 1])
        stack_top = stack_top + heights
        y_max = stack_top.max(initial=y_max)

    # negative segments are stacked downwards: each bar starts at the new (lower) end of the
    # stack and reaches up to the previous end
    stack_bottom = np.zeros(n_features)
    for order in orders_neg:
        depths = np.asarray(n_shapley_values_neg[order], dtype=float)
        stack_bottom = stack_bottom + depths
        axis.bar(x, height=np.abs(depths), bottom=stack_bottom, color=COLORS_N_SII[order - 1])
        y_min = stack_bottom.min(initial=y_min)

    # add a legend to the plot with one entry per plotted order
    legend_elements = [
        Patch(facecolor=COLORS_N_SII[order], edgecolor="black", label=f"Order {order + 1}")
        for order in range(n_sii_max_order)
    ]
    axis.legend(handles=legend_elements, loc="upper center", ncol=min(n_sii_max_order, 4))

    axis.set_xticks(x)
    axis.set_xticklabels(list(feature_names), rotation=45, ha="right")

    # leave a small margin below the bars and a larger one above them for the legend
    y_range = abs(y_max - y_min)
    axis.set_xlim(-0.5, n_features - 0.5)
    axis.set_ylim(y_min - y_range * 0.02, y_max + y_range * 0.3)

    # fall back to default title and labels if none are provided
    if title is None:
        title = f"n-SII values up to order ${n_sii_max_order}$"
    if xlabel is None:
        xlabel = "features"
    if ylabel is None:
        ylabel = "n-SII values"
    axis.set_title(title)
    axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)

    plt.tight_layout()

    return fig, axis
45 changes: 45 additions & 0 deletions tests/tests_plots/test_stacked_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""This module contains all tests for the stacked bar plots."""
import numpy as np

import matplotlib.pyplot as plt


from shapiq.plot import stacked_bar_plot


def test_stacked_bar_plot():
    """Tests whether the stacked bar plot can be created with default and custom settings."""

    n_shapley_values_pos = {
        1: np.asarray([1, 0, 1.75]),
        2: np.asarray([0.25, 0.5, 0.75]),
        3: np.asarray([0.5, 0.25, 0.25]),
    }
    n_shapley_values_neg = {
        1: np.asarray([0, -1.5, 0]),
        2: np.asarray([-0.25, -0.5, -0.75]),
        3: np.asarray([-0.5, -0.25, -0.25]),
    }
    feature_names = ["a", "b", "c"]

    # plot with default arguments only
    fig, axes = stacked_bar_plot(
        feature_names=feature_names,
        n_shapley_values_pos=n_shapley_values_pos,
        n_shapley_values_neg=n_shapley_values_neg,
    )
    assert fig is not None
    assert axes is not None
    assert [tick.get_text() for tick in axes.get_xticklabels()] == feature_names
    plt.close(fig)

    # plot with a restricted order and custom title and axis labels
    fig, axes = stacked_bar_plot(
        feature_names=feature_names,
        n_shapley_values_pos=n_shapley_values_pos,
        n_shapley_values_neg=n_shapley_values_neg,
        n_sii_max_order=2,
        title="Title",
        xlabel="X",
        ylabel="Y",
    )
    assert fig is not None
    assert axes is not None
    # the custom texts must actually end up on the axis
    assert axes.get_title() == "Title"
    assert axes.get_xlabel() == "X"
    assert axes.get_ylabel() == "Y"
    plt.close(fig)

0 comments on commit 1a2046c

Please sign in to comment.