diff --git a/docs/source/_static/stacked_bar_exampl.png b/docs/source/_static/stacked_bar_exampl.png new file mode 100644 index 00000000..9ee78ac4 Binary files /dev/null and b/docs/source/_static/stacked_bar_exampl.png differ diff --git a/shapiq/__init__.py b/shapiq/__init__.py index 72ea5e39..717989af 100644 --- a/shapiq/__init__.py +++ b/shapiq/__init__.py @@ -19,7 +19,7 @@ from .games import DummyGame # plotting functions -from .plot import network_plot +from .plot import network_plot, stacked_bar_plot # public utils functions from .utils import ( # sets.py # tree.py @@ -48,6 +48,7 @@ "DummyGame", # plots "network_plot", + "stacked_bar_plot", # public utils "powerset", "get_explicit_subsets", diff --git a/shapiq/plot/__init__.py b/shapiq/plot/__init__.py index eb75a111..355b5c72 100644 --- a/shapiq/plot/__init__.py +++ b/shapiq/plot/__init__.py @@ -1,7 +1,6 @@ """This module contains all plotting functions for the shapiq package.""" from .network import network_plot +from .stacked_bar import stacked_bar_plot -__all__ = [ - "network_plot", -] +__all__ = ["network_plot", "stacked_bar_plot"] diff --git a/shapiq/plot/_config.py b/shapiq/plot/_config.py index 9bf34ca5..539df662 100644 --- a/shapiq/plot/_config.py +++ b/shapiq/plot/_config.py @@ -1,12 +1,29 @@ """This module contains the configuration for the shapiq visualizations.""" from colour import Color -RED = Color("#ff0d57") -BLUE = Color("#1e88e5") -NEUTRAL = Color("#ffffff") - __all__ = [ "RED", "BLUE", "NEUTRAL", + "LINES", + "COLORS_N_SII", +] + +RED = Color("#ff0d57") +BLUE = Color("#1e88e5") +NEUTRAL = Color("#ffffff") +LINES = Color("#cccccc") + +COLORS_N_SII = [ + "#D81B60", + "#FFB000", + "#1E88E5", + "#FE6100", + "#7F975F", + "#74ced2", + "#708090", + "#9966CC", + "#CCCCCC", + "#800080", ] +COLORS_N_SII = COLORS_N_SII * (100 + (len(COLORS_N_SII))) # repeat the colors list diff --git a/shapiq/plot/network.py b/shapiq/plot/network.py index 0f660533..5004559f 100644 --- a/shapiq/plot/network.py +++ b/shapiq/plot/network.py @@ -11,134 +11,25 @@ from approximator._base import InteractionValues from utils import powerset -from ._config import BLUE, RED +from ._config import BLUE, RED, NEUTRAL, LINES __all__ = [ "network_plot", ] -def _get_color(value: float) -> str: - """Returns blue color for negative values and red color for positive values.""" - if value >= 0: - return RED.hex - return BLUE.hex - - -def _add_weight_to_edges_in_graph( - graph: nx.Graph, - first_order_values: np.ndarray, - second_order_values: np.ndarray, - n_features: int, - feature_names: list[str], -) -> None: - """Adds the weights to the edges in the graph.""" - - # get min and max value for n_shapley_values - min_node_value, max_node_value = np.min(first_order_values), np.max(first_order_values) - min_edge_value, max_edge_value = np.min(second_order_values), np.max(second_order_values) - - all_range = abs(max(max_node_value, max_edge_value) - min(min_node_value, min_edge_value)) - - size_scaler = 30 - - for node in graph.nodes: - weight: float = first_order_values[node] - size = abs(weight) / all_range - color = _get_color(weight) - graph.nodes[node]["node_color"] = color - graph.nodes[node]["node_size"] = size * 250 - graph.nodes[node]["label"] = feature_names[node] - graph.nodes[node]["linewidths"] = 1 - graph.nodes[node]["edgecolors"] = color - - for edge in powerset(range(n_features), min_size=2, max_size=2): - weight: float = float(second_order_values[edge]) - color = _get_color(weight) - # scale weight between min and max edge value - size = abs(weight) / all_range - graph_edge = graph.get_edge_data(*edge) - graph_edge["width"] = size * (size_scaler + 1) - graph_edge["color"] = color - - -def _add_legend_to_axis(axis: plt.Axes) -> None: - """Adds a legend for order 1 (nodes) and order 2 (edges) interactions to the axis.""" - sizes = [1.0, 0.2, 0.2, 1] - labels = ["high pos.", "low pos.", "low neg.", "high neg."] - alphas_line = [0.5, 0.2, 0.2, 0.5] - - # order 1 (circles) - plot_circles = [] - for i in range(4): - size = sizes[i] - if i < 2: - color = RED.hex - else: - color = BLUE.hex - circle = axis.plot([], [], c=color, marker="o", markersize=size * 8, linestyle="None") - plot_circles.append(circle[0]) - - legend1 = plt.legend( - plot_circles, - labels, - frameon=True, - framealpha=0.5, - facecolor="white", - title=r"$\bf{Order\ 1}$", - fontsize=7, - labelspacing=0.5, - handletextpad=0.5, - borderpad=0.5, - handlelength=1.5, - bbox_to_anchor=(1.12, 1.1), - title_fontsize=7, - loc="upper right", - ) - - # order 2 (lines) - plot_lines = [] - for i in range(4): - size = sizes[i] - alpha = alphas_line[i] - if i < 2: - color = RED.hex - else: - color = BLUE.hex - line = axis.plot([], [], c=color, linewidth=size * 3, alpha=alpha) - plot_lines.append(line[0]) - - legend2 = plt.legend( - plot_lines, - labels, - frameon=True, - framealpha=0.5, - facecolor="white", - title=r"$\bf{Order\ 2}$", - fontsize=7, - labelspacing=0.5, - handletextpad=0.5, - borderpad=0.5, - handlelength=1.5, - bbox_to_anchor=(1.12, 0.92), - title_fontsize=7, - loc="upper right", - ) - - axis.add_artist(legend1) - axis.add_artist(legend2) - - def network_plot( + interaction_values: Optional[InteractionValues] = None, *, - first_order_values: np.ndarray[float], - second_order_values: np.ndarray[float], - interaction_values: InteractionValues = None, + first_order_values: Optional[np.ndarray[float]] = None, + second_order_values: Optional[np.ndarray[float]] = None, feature_names: Optional[list[Any]] = None, feature_image_patches: Optional[dict[int, Image.Image]] = None, feature_image_patches_size: Optional[Union[float, dict[int, float]]] = 0.2, center_image: Optional[Image.Image] = None, center_image_size: Optional[float] = 0.6, + draw_legend: bool = True, + center_text: Optional[str] = None, ) -> tuple[plt.Figure, plt.Axes]: """Draws the interaction network. @@ -167,6 +58,8 @@ def network_plot( Defaults to 0.2. center_image: The image to be displayed in the center of the network. Defaults to None. center_image_size: The size of the center image. Defaults to 0.6. + draw_legend: Whether to draw the legend. Defaults to True. + center_text: The text to be displayed in the center of the network. Defaults to None. Returns: The figure and the axis containing the plot. @@ -174,6 +67,22 @@ def network_plot( fig, axis = plt.subplots(figsize=(6, 6)) axis.axis("off") + if interaction_values is not None: + n_players = interaction_values.n_players + first_order_values = np.zeros(n_players) + second_order_values = np.zeros((n_players, n_players)) + for interaction in powerset(range(n_players), min_size=1, max_size=2): + if len(interaction) == 1: + first_order_values[interaction[0]] = interaction_values[interaction] + else: + second_order_values[interaction] = interaction_values[interaction] + else: + if first_order_values is None or second_order_values is None: + raise ValueError( + "Either interaction_values or first_order_values and second_order_values must be " + "provided. If interaction_values is provided this will be used." + ) + # get the number of features and the feature names n_features = first_order_values.shape[0] if feature_names is None: @@ -182,6 +91,8 @@ def network_plot( # create a fully connected graph up to the n_sii_order graph = nx.complete_graph(n_features) + nodes_visit_order = _order_nodes(len(graph.nodes)) + # add the weights to the edges _add_weight_to_edges_in_graph( graph=graph, @@ -189,6 +100,7 @@ def network_plot( second_order_values=second_order_values, n_features=n_features, feature_names=feature_names, + nodes_visit_order=nodes_visit_order, ) # get node and edge attributes @@ -217,9 +129,10 @@ def network_plot( ) # add the labels or image patches to the nodes - for node, (x, y) in pos.items(): + for i, node in enumerate(nodes_visit_order): + (x, y) = pos[node] size = graph.nodes[node]["linewidths"] - label = node_labels[node] + label = node_labels[i] radius = 1.15 + size / 300 theta = np.arctan2(x, y) if abs(theta) <= 0.001: @@ -233,10 +146,10 @@ def network_plot( if feature_image_patches is None: axis.text(x, y, label, horizontalalignment="center", verticalalignment="center") else: # draw the image instead of the text - image = feature_image_patches[node] + image = feature_image_patches[i] patch_size = feature_image_patches_size if isinstance(patch_size, dict): - patch_size = patch_size[node] + patch_size = patch_size[i] extend = patch_size / 2 axis.imshow(image, extent=(x - extend, x + extend, y - extend, y + extend)) @@ -244,16 +157,184 @@ def network_plot( if center_image is not None: _add_center_image(axis, center_image, center_image_size, n_features) + # add the center text if provided + if center_text is not None: + background_color = NEUTRAL.hex + line_color = LINES.hex + axis.text( + 0, + 0, + center_text, + horizontalalignment="center", + verticalalignment="center", + bbox=dict(facecolor=background_color, alpha=0.5, edgecolor=line_color, pad=7), + color="black", + fontsize=plt.rcParams["font.size"] + 3, + ) + # add the legends to the plot - _add_legend_to_axis(axis) + if draw_legend: + _add_legend_to_axis(axis) return fig, axis +def _get_color(value: float) -> str: + """Returns blue color for negative values and red color for positive values. + + Args: + value (float): The value to determine the color for. + + Returns: + str: The color as a hex string. + """ + if value >= 0: + return RED.hex + return BLUE.hex + + +def _add_weight_to_edges_in_graph( + graph: nx.Graph, + first_order_values: np.ndarray, + second_order_values: np.ndarray, + n_features: int, + feature_names: list[str], + nodes_visit_order: list[int], +) -> None: + """Adds the weights to the edges in the graph. + + Args: + graph (nx.Graph): The graph to add the weights to. + first_order_values (np.ndarray): The first order n-SII values. + second_order_values (np.ndarray): The second order n-SII values. + n_features (int): The number of features. + feature_names (list[str]): The names of the features. + nodes_visit_order (list[int]): The order of the nodes to visit. + + Returns: + None + """ + + # get min and max value for n_shapley_values + min_node_value, max_node_value = np.min(first_order_values), np.max(first_order_values) + min_edge_value, max_edge_value = np.min(second_order_values), np.max(second_order_values) + + all_range = abs(max(max_node_value, max_edge_value) - min(min_node_value, min_edge_value)) + + size_scaler = 30 + + for i, node_id in enumerate(nodes_visit_order): + weight: float = first_order_values[i] + size = abs(weight) / all_range + color = _get_color(weight) + graph.nodes[node_id]["node_color"] = color + graph.nodes[node_id]["node_size"] = size * 250 + graph.nodes[node_id]["label"] = feature_names[node_id] + graph.nodes[node_id]["linewidths"] = 1 + graph.nodes[node_id]["edgecolors"] = color + + for interaction in powerset(range(n_features), min_size=2, max_size=2): + weight: float = float(second_order_values[interaction]) + edge = list(sorted(interaction)) + edge[0] = nodes_visit_order.index(interaction[0]) + edge[1] = nodes_visit_order.index(interaction[1]) + edge = tuple(edge) + color = _get_color(weight) + # scale weight between min and max edge value + size = abs(weight) / all_range + graph_edge = graph.get_edge_data(*edge) + graph_edge["width"] = size * (size_scaler + 1) + graph_edge["color"] = color + + +def _add_legend_to_axis(axis: plt.Axes) -> None: + """Adds a legend for order 1 (nodes) and order 2 (edges) interactions to the axis. + + Args: + axis (plt.Axes): The axis to add the legend to. + + Returns: + None + """ + sizes = [1.0, 0.2, 0.2, 1] + labels = ["high pos.", "low pos.", "low neg.", "high neg."] + alphas_line = [0.5, 0.2, 0.2, 0.5] + + # order 1 (circles) + plot_circles = [] + for i in range(4): + size = sizes[i] + if i < 2: + color = RED.hex + else: + color = BLUE.hex + circle = axis.plot([], [], c=color, marker="o", markersize=size * 8, linestyle="None") + plot_circles.append(circle[0]) + + font_size = plt.rcParams["legend.fontsize"] + + legend1 = plt.legend( + plot_circles, + labels, + frameon=True, + framealpha=0.5, + facecolor="white", + title=r"$\bf{Order\ 1}$", + fontsize=font_size, + labelspacing=0.5, + handletextpad=0.5, + borderpad=0.5, + handlelength=1.5, + title_fontsize=font_size, + loc="best", + ) + + # order 2 (lines) + plot_lines = [] + for i in range(4): + size = sizes[i] + alpha = alphas_line[i] + if i < 2: + color = RED.hex + else: + color = BLUE.hex + line = axis.plot([], [], c=color, linewidth=size * 3, alpha=alpha) + plot_lines.append(line[0]) + + legend2 = plt.legend( + plot_lines, + labels, + frameon=True, + framealpha=0.5, + facecolor="white", + title=r"$\bf{Order\ 2}$", + fontsize=font_size, + labelspacing=0.5, + handletextpad=0.5, + borderpad=0.5, + handlelength=1.5, + title_fontsize=font_size, + loc="best", + ) + + axis.add_artist(legend1) + axis.add_artist(legend2) + + def _add_center_image( axis: plt.Axes, center_image: Image.Image, center_image_size: float, n_features: int -): - """Adds the center image to the axis.""" +) -> None: + """Adds the center image to the axis. + + Args: + axis (plt.Axes): The axis to add the image to. + center_image (Image.Image): The image to add to the axis. + center_image_size (float): The size of the center image. + n_features (int): The number of features. + + Returns: + None + """ # plot the center image image_to_plot = Image.fromarray(np.asarray(copy.deepcopy(center_image))) extend = center_image_size @@ -269,3 +350,38 @@ def _add_center_image( axis.set_zorder(1) for edge in axis.collections: edge.set_zorder(0) + + +def _get_highest_node_index(n_nodes: int) -> int: + """Calculates the node with the highest position on the y-axis given the total number of nodes. + + Args: + n_nodes (int): The total number of nodes. + + Returns: + int: The index of the highest node. + """ + n_connections = 0 + # highest node is the last node below 1/4 of all connections in the circle + while n_connections <= n_nodes / 4: + n_connections += 1 + n_connections -= 1 + return n_connections + + +def _order_nodes(n_nodes: int) -> list[int]: + """Orders the nodes in the network plot. + + Args: + n_nodes (int): The total number of nodes. + + Returns: + list[int]: The order of the nodes. + """ + highest_node = _get_highest_node_index(n_nodes) + nodes_visit_order = [highest_node] + desired_order = list(reversed(list(range(n_nodes)))) + highest_node_index = desired_order.index(highest_node) + nodes_visit_order += desired_order[highest_node_index + 1 :] + nodes_visit_order += desired_order[:highest_node_index] + return nodes_visit_order diff --git a/shapiq/plot/stacked_bar.py b/shapiq/plot/stacked_bar.py new file mode 100644 index 00000000..b727f528 --- /dev/null +++ b/shapiq/plot/stacked_bar.py @@ -0,0 +1,136 @@ +"""This module contains functions to plot the n_sii stacked bar charts.""" +__all__ = ["stacked_bar_plot"] + +from copy import deepcopy +from typing import Union, Optional + +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.patches import Patch + +from ._config import COLORS_N_SII + + +def stacked_bar_plot( + feature_names: Union[list, np.ndarray], + n_shapley_values_pos: dict, + n_shapley_values_neg: dict, + n_sii_max_order: Optional[int] = None, + title: Optional[str] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, +): + """Plot the n-SII values for a given instance. + + This stacked bar plot can be used to visualize the amount of interaction between the features + for a given instance. The n-SII values are plotted as stacked bars with positive and negative + parts stacked on top of each other. The colors represent the order of the n-SII values. For a + detailed explanation of this plot, see this `research paper `_. + + An example of the plot is shown below. + + .. image:: /_static/stacked_bar_exampl.png + :width: 400 + :align: center + + Args: + feature_names (list): The names of the features. + n_shapley_values_pos (dict): The positive n-SII values. + n_shapley_values_neg (dict): The negative n-SII values. + n_sii_max_order (int): The order of the n-SII values. + title (str): The title of the plot. + xlabel (str): The label of the x-axis. + ylabel (str): The label of the y-axis. + + Returns: + tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]: A tuple containing the figure and + the axis of the plot. + + Note: + To change the figure size, font size, etc., use the [matplotlib parameters](https://matplotlib.org/stable/users/explain/customizing.html). + + Example: + >>> import numpy as np + >>> from shapiq.plot import stacked_bar_plot + >>> n_shapley_values_pos = { + ... 1: np.asarray([1, 0, 1.75]), + ... 2: np.asarray([0.25, 0.5, 0.75]), + ... 3: np.asarray([0.5, 0.25, 0.25]), + ... } + >>> n_shapley_values_neg = { + ... 1: np.asarray([0, -1.5, 0]), + ... 2: np.asarray([-0.25, -0.5, -0.75]), + ... 3: np.asarray([-0.5, -0.25, -0.25]), + ... } + >>> feature_names = ["a", "b", "c"] + >>> fig, axes = stacked_bar_plot( + ... feature_names=feature_names, + ... n_shapley_values_pos=n_shapley_values_pos, + ... n_shapley_values_neg=n_shapley_values_neg, + ... ) + >>> plt.show() + """ + # sanitize inputs + if n_sii_max_order is None: + n_sii_max_order = len(n_shapley_values_pos) + + fig, axis = plt.subplots() + + # transform data to make plotting easier + n_features = len(feature_names) + x = np.arange(n_features) + values_pos = np.array( + [values for order, values in n_shapley_values_pos.items() if order >= n_sii_max_order] + ) + values_neg = np.array( + [values for order, values in n_shapley_values_neg.items() if order >= n_sii_max_order] + ) + + # get helper variables for plotting the bars + min_max_values = [0, 0] # to set the y-axis limits after all bars are plotted + reference_pos = np.zeros(n_features) # to plot the bars on top of each other + reference_neg = deepcopy(values_neg[0]) # to plot the bars below of each other + + # plot the bar segments + for order in range(len(values_pos)): + axis.bar(x, height=values_pos[order], bottom=reference_pos, color=COLORS_N_SII[order]) + axis.bar(x, height=abs(values_neg[order]), bottom=reference_neg, color=COLORS_N_SII[order]) + axis.axhline(y=0, color="black", linestyle="solid", linewidth=0.5) + reference_pos += values_pos[order] + try: + reference_neg += values_neg[order + 1] + except IndexError: + pass + min_max_values[0] = min(min_max_values[0], min(reference_neg)) + min_max_values[1] = max(min_max_values[1], max(reference_pos)) + + # add a legend to the plots + legend_elements = [] + for order in range(n_sii_max_order): + legend_elements.append( + Patch(facecolor=COLORS_N_SII[order], edgecolor="black", label=f"Order {order + 1}") + ) + axis.legend(handles=legend_elements, loc="upper center", ncol=min(n_sii_max_order, 4)) + + x_ticks_labels = [feature for feature in feature_names] # might be unnecessary + axis.set_xticks(x) + axis.set_xticklabels(x_ticks_labels, rotation=45, ha="right") + + axis.set_xlim(-0.5, n_features - 0.5) + axis.set_ylim( + min_max_values[0] - abs(min_max_values[1] - min_max_values[0]) * 0.02, + min_max_values[1] + abs(min_max_values[1] - min_max_values[0]) * 0.3, + ) + + # set title and labels if not provided + + axis.set_title( + f"n-SII values up to order ${n_sii_max_order}$" + ) if title is None else axis.set_title(title) + + axis.set_xlabel("features") if xlabel is None else axis.set_xlabel(xlabel) + axis.set_ylabel("n-SII values") if ylabel is None else axis.set_ylabel(ylabel) + + plt.tight_layout() + + return fig, axis diff --git a/tests/tests_plots/test_network_plot.py b/tests/tests_plots/test_network_plot.py index 47404c39..48db6c8e 100644 --- a/tests/tests_plots/test_network_plot.py +++ b/tests/tests_plots/test_network_plot.py @@ -1,9 +1,12 @@ """This module contains all tests for the network plots.""" import numpy as np import matplotlib.pyplot as plt +import pytest from PIL import Image +from scipy.special import binom from shapiq.plot import network_plot +from shapiq.approximator._base import InteractionValues def test_network_plot(): @@ -29,8 +32,29 @@ def test_network_plot(): assert axes is not None plt.close(fig) + # test with InteractionValues object + n_players = 5 + n_values = n_players + int(binom(n_players, 2)) + iv = InteractionValues( + values=np.random.rand(n_values), + index="nSII", + n_players=n_players, + min_order=1, + max_order=2, + ) + fig, axes = network_plot(interaction_values=iv) + assert fig is not None + assert axes is not None + plt.close(fig) -def test_network_plot_with_image(): + # value error if neither first_order_values nor interaction_values are given + with pytest.raises(ValueError): + network_plot() + + assert True + + +def test_network_plot_with_image_or_text(): first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6]) second_order_values = np.random.rand(6, 6) - 0.5 n_features = len(first_order_values) @@ -66,3 +90,14 @@ def test_network_plot_with_image(): assert fig is not None assert axes is not None plt.close(fig) + + # with text + fig, axes = network_plot( + first_order_values=first_order_values, + second_order_values=second_order_values, + center_text="center text", + ) + assert fig is not None + assert axes is not None + plt.close(fig) + assert True diff --git a/tests/tests_plots/test_stacked_bar.py b/tests/tests_plots/test_stacked_bar.py new file mode 100644 index 00000000..640b9fe8 --- /dev/null +++ b/tests/tests_plots/test_stacked_bar.py @@ -0,0 +1,45 @@ +"""This module contains all tests for the stacked bar plots.""" +import numpy as np + +import matplotlib.pyplot as plt + + +from shapiq.plot import stacked_bar_plot + + +def test_stacked_bar_plot(): + """Tests whether the stacked bar plot can be created.""" + + n_shapley_values_pos = { + 1: np.asarray([1, 0, 1.75]), + 2: np.asarray([0.25, 0.5, 0.75]), + 3: np.asarray([0.5, 0.25, 0.25]), + } + n_shapley_values_neg = { + 1: np.asarray([0, -1.5, 0]), + 2: np.asarray([-0.25, -0.5, -0.75]), + 3: np.asarray([-0.5, -0.25, -0.25]), + } + feature_names = ["a", "b", "c"] + fig, axes = stacked_bar_plot( + feature_names=feature_names, + n_shapley_values_pos=n_shapley_values_pos, + n_shapley_values_neg=n_shapley_values_neg, + ) + assert fig is not None + assert axes is not None + plt.close(fig) + assert True + + fig, axes = stacked_bar_plot( + feature_names=feature_names, + n_shapley_values_pos=n_shapley_values_pos, + n_shapley_values_neg=n_shapley_values_neg, + n_sii_max_order=2, + title="Title", + xlabel="X", + ylabel="Y", + ) + assert fig is not None + assert axes is not None + plt.close(fig)