From 50876ced2fb6d108a6933f66e22e74c91337ec93 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 6 Feb 2024 11:16:15 +0100 Subject: [PATCH] adds text support to network plot and fixes ordering and closes #33 --- shapiq/plot/_config.py | 2 + shapiq/plot/network.py | 374 ++++++++++++++++--------- tests/tests_plots/test_network_plot.py | 37 ++- 3 files changed, 274 insertions(+), 139 deletions(-) diff --git a/shapiq/plot/_config.py b/shapiq/plot/_config.py index 05166b06..539df662 100644 --- a/shapiq/plot/_config.py +++ b/shapiq/plot/_config.py @@ -5,12 +5,14 @@ "RED", "BLUE", "NEUTRAL", + "LINES", "COLORS_N_SII", ] RED = Color("#ff0d57") BLUE = Color("#1e88e5") NEUTRAL = Color("#ffffff") +LINES = Color("#cccccc") COLORS_N_SII = [ "#D81B60", diff --git a/shapiq/plot/network.py b/shapiq/plot/network.py index 080215d6..5004559f 100644 --- a/shapiq/plot/network.py +++ b/shapiq/plot/network.py @@ -11,152 +11,25 @@ from approximator._base import InteractionValues from utils import powerset -from ._config import BLUE, RED +from ._config import BLUE, RED, NEUTRAL, LINES __all__ = [ "network_plot", ] -def _get_color(value: float) -> str: - """Returns blue color for negative values and red color for positive values.""" - if value >= 0: - return RED.hex - return BLUE.hex - - -def _add_weight_to_edges_in_graph( - graph: nx.Graph, - first_order_values: np.ndarray, - second_order_values: np.ndarray, - n_features: int, - feature_names: list[str], -) -> None: - """Adds the weights to the edges in the graph. - - Args: - graph (nx.Graph): The graph to add the weights to. - first_order_values (np.ndarray): The first order n-SII values. - second_order_values (np.ndarray): The second order n-SII values. - n_features (int): The number of features. - feature_names (list[str]): The names of the features. - - Returns: - None - """ - - # get min and max value for n_shapley_values - min_node_value, max_node_value = np.min(first_order_values), np.max(first_order_values) - min_edge_value, max_edge_value = np.min(second_order_values), np.max(second_order_values) - - all_range = abs(max(max_node_value, max_edge_value) - min(min_node_value, min_edge_value)) - - size_scaler = 30 - - for node in graph.nodes: - weight: float = first_order_values[node] - size = abs(weight) / all_range - color = _get_color(weight) - graph.nodes[node]["node_color"] = color - graph.nodes[node]["node_size"] = size * 250 - graph.nodes[node]["label"] = feature_names[node] - graph.nodes[node]["linewidths"] = 1 - graph.nodes[node]["edgecolors"] = color - - for edge in powerset(range(n_features), min_size=2, max_size=2): - weight: float = float(second_order_values[edge]) - color = _get_color(weight) - # scale weight between min and max edge value - size = abs(weight) / all_range - graph_edge = graph.get_edge_data(*edge) - graph_edge["width"] = size * (size_scaler + 1) - graph_edge["color"] = color - - -def _add_legend_to_axis(axis: plt.Axes) -> None: - """Adds a legend for order 1 (nodes) and order 2 (edges) interactions to the axis. - - Args: - axis (plt.Axes): The axis to add the legend to. - - Returns: - None - """ - sizes = [1.0, 0.2, 0.2, 1] - labels = ["high pos.", "low pos.", "low neg.", "high neg."] - alphas_line = [0.5, 0.2, 0.2, 0.5] - - # order 1 (circles) - plot_circles = [] - for i in range(4): - size = sizes[i] - if i < 2: - color = RED.hex - else: - color = BLUE.hex - circle = axis.plot([], [], c=color, marker="o", markersize=size * 8, linestyle="None") - plot_circles.append(circle[0]) - - legend1 = plt.legend( - plot_circles, - labels, - frameon=True, - framealpha=0.5, - facecolor="white", - title=r"$\bf{Order\ 1}$", - fontsize=7, - labelspacing=0.5, - handletextpad=0.5, - borderpad=0.5, - handlelength=1.5, - bbox_to_anchor=(1.12, 1.1), - title_fontsize=7, - loc="upper right", - ) - - # order 2 (lines) - plot_lines = [] - for i in range(4): - size = sizes[i] - alpha = alphas_line[i] - if i < 2: - color = RED.hex - else: - color = BLUE.hex - line = axis.plot([], [], c=color, linewidth=size * 3, alpha=alpha) - plot_lines.append(line[0]) - - legend2 = plt.legend( - plot_lines, - labels, - frameon=True, - framealpha=0.5, - facecolor="white", - title=r"$\bf{Order\ 2}$", - fontsize=7, - labelspacing=0.5, - handletextpad=0.5, - borderpad=0.5, - handlelength=1.5, - bbox_to_anchor=(1.12, 0.92), - title_fontsize=7, - loc="upper right", - ) - - axis.add_artist(legend1) - axis.add_artist(legend2) - - def network_plot( + interaction_values: Optional[InteractionValues] = None, *, first_order_values: Optional[np.ndarray[float]] = None, second_order_values: Optional[np.ndarray[float]] = None, - interaction_values: Optional[InteractionValues] = None, feature_names: Optional[list[Any]] = None, feature_image_patches: Optional[dict[int, Image.Image]] = None, feature_image_patches_size: Optional[Union[float, dict[int, float]]] = 0.2, center_image: Optional[Image.Image] = None, center_image_size: Optional[float] = 0.6, + draw_legend: bool = True, + center_text: Optional[str] = None, ) -> tuple[plt.Figure, plt.Axes]: """Draws the interaction network. @@ -185,6 +58,8 @@ def network_plot( Defaults to 0.2. center_image: The image to be displayed in the center of the network. Defaults to None. center_image_size: The size of the center image. Defaults to 0.6. + draw_legend: Whether to draw the legend. Defaults to True. + center_text: The text to be displayed in the center of the network. Defaults to None. Returns: The figure and the axis containing the plot. @@ -192,6 +67,22 @@ def network_plot( fig, axis = plt.subplots(figsize=(6, 6)) axis.axis("off") + if interaction_values is not None: + n_players = interaction_values.n_players + first_order_values = np.zeros(n_players) + second_order_values = np.zeros((n_players, n_players)) + for interaction in powerset(range(n_players), min_size=1, max_size=2): + if len(interaction) == 1: + first_order_values[interaction[0]] = interaction_values[interaction] + else: + second_order_values[interaction] = interaction_values[interaction] + else: + if first_order_values is None or second_order_values is None: + raise ValueError( + "Either interaction_values or first_order_values and second_order_values must be " + "provided. If interaction_values is provided this will be used." + ) + # get the number of features and the feature names n_features = first_order_values.shape[0] if feature_names is None: @@ -200,6 +91,8 @@ def network_plot( # create a fully connected graph up to the n_sii_order graph = nx.complete_graph(n_features) + nodes_visit_order = _order_nodes(len(graph.nodes)) + # add the weights to the edges _add_weight_to_edges_in_graph( graph=graph, @@ -207,6 +100,7 @@ def network_plot( second_order_values=second_order_values, n_features=n_features, feature_names=feature_names, + nodes_visit_order=nodes_visit_order, ) # get node and edge attributes @@ -235,9 +129,10 @@ def network_plot( ) # add the labels or image patches to the nodes - for node, (x, y) in pos.items(): + for i, node in enumerate(nodes_visit_order): + (x, y) = pos[node] size = graph.nodes[node]["linewidths"] - label = node_labels[node] + label = node_labels[i] radius = 1.15 + size / 300 theta = np.arctan2(x, y) if abs(theta) <= 0.001: @@ -251,10 +146,10 @@ def network_plot( if feature_image_patches is None: axis.text(x, y, label, horizontalalignment="center", verticalalignment="center") else: # draw the image instead of the text - image = feature_image_patches[node] + image = feature_image_patches[i] patch_size = feature_image_patches_size if isinstance(patch_size, dict): - patch_size = patch_size[node] + patch_size = patch_size[i] extend = patch_size / 2 axis.imshow(image, extent=(x - extend, x + extend, y - extend, y + extend)) @@ -262,16 +157,184 @@ def network_plot( if center_image is not None: _add_center_image(axis, center_image, center_image_size, n_features) + # add the center text if provided + if center_text is not None: + background_color = NEUTRAL.hex + line_color = LINES.hex + axis.text( + 0, + 0, + center_text, + horizontalalignment="center", + verticalalignment="center", + bbox=dict(facecolor=background_color, alpha=0.5, edgecolor=line_color, pad=7), + color="black", + fontsize=plt.rcParams["font.size"] + 3, + ) + # add the legends to the plot - _add_legend_to_axis(axis) + if draw_legend: + _add_legend_to_axis(axis) return fig, axis +def _get_color(value: float) -> str: + """Returns blue color for negative values and red color for positive values. + + Args: + value (float): The value to determine the color for. + + Returns: + str: The color as a hex string. + """ + if value >= 0: + return RED.hex + return BLUE.hex + + +def _add_weight_to_edges_in_graph( + graph: nx.Graph, + first_order_values: np.ndarray, + second_order_values: np.ndarray, + n_features: int, + feature_names: list[str], + nodes_visit_order: list[int], +) -> None: + """Adds the weights to the edges in the graph. + + Args: + graph (nx.Graph): The graph to add the weights to. + first_order_values (np.ndarray): The first order n-SII values. + second_order_values (np.ndarray): The second order n-SII values. + n_features (int): The number of features. + feature_names (list[str]): The names of the features. + nodes_visit_order (list[int]): The order of the nodes to visit. + + Returns: + None + """ + + # get min and max value for n_shapley_values + min_node_value, max_node_value = np.min(first_order_values), np.max(first_order_values) + min_edge_value, max_edge_value = np.min(second_order_values), np.max(second_order_values) + + all_range = abs(max(max_node_value, max_edge_value) - min(min_node_value, min_edge_value)) + + size_scaler = 30 + + for i, node_id in enumerate(nodes_visit_order): + weight: float = first_order_values[i] + size = abs(weight) / all_range + color = _get_color(weight) + graph.nodes[node_id]["node_color"] = color + graph.nodes[node_id]["node_size"] = size * 250 + graph.nodes[node_id]["label"] = feature_names[node_id] + graph.nodes[node_id]["linewidths"] = 1 + graph.nodes[node_id]["edgecolors"] = color + + for interaction in powerset(range(n_features), min_size=2, max_size=2): + weight: float = float(second_order_values[interaction]) + edge = list(sorted(interaction)) + edge[0] = nodes_visit_order.index(interaction[0]) + edge[1] = nodes_visit_order.index(interaction[1]) + edge = tuple(edge) + color = _get_color(weight) + # scale weight between min and max edge value + size = abs(weight) / all_range + graph_edge = graph.get_edge_data(*edge) + graph_edge["width"] = size * (size_scaler + 1) + graph_edge["color"] = color + + +def _add_legend_to_axis(axis: plt.Axes) -> None: + """Adds a legend for order 1 (nodes) and order 2 (edges) interactions to the axis. + + Args: + axis (plt.Axes): The axis to add the legend to. + + Returns: + None + """ + sizes = [1.0, 0.2, 0.2, 1] + labels = ["high pos.", "low pos.", "low neg.", "high neg."] + alphas_line = [0.5, 0.2, 0.2, 0.5] + + # order 1 (circles) + plot_circles = [] + for i in range(4): + size = sizes[i] + if i < 2: + color = RED.hex + else: + color = BLUE.hex + circle = axis.plot([], [], c=color, marker="o", markersize=size * 8, linestyle="None") + plot_circles.append(circle[0]) + + font_size = plt.rcParams["legend.fontsize"] + + legend1 = plt.legend( + plot_circles, + labels, + frameon=True, + framealpha=0.5, + facecolor="white", + title=r"$\bf{Order\ 1}$", + fontsize=font_size, + labelspacing=0.5, + handletextpad=0.5, + borderpad=0.5, + handlelength=1.5, + title_fontsize=font_size, + loc="best", + ) + + # order 2 (lines) + plot_lines = [] + for i in range(4): + size = sizes[i] + alpha = alphas_line[i] + if i < 2: + color = RED.hex + else: + color = BLUE.hex + line = axis.plot([], [], c=color, linewidth=size * 3, alpha=alpha) + plot_lines.append(line[0]) + + legend2 = plt.legend( + plot_lines, + labels, + frameon=True, + framealpha=0.5, + facecolor="white", + title=r"$\bf{Order\ 2}$", + fontsize=font_size, + labelspacing=0.5, + handletextpad=0.5, + borderpad=0.5, + handlelength=1.5, + title_fontsize=font_size, + loc="best", + ) + + axis.add_artist(legend1) + axis.add_artist(legend2) + + def _add_center_image( axis: plt.Axes, center_image: Image.Image, center_image_size: float, n_features: int -): - """Adds the center image to the axis.""" +) -> None: + """Adds the center image to the axis. + + Args: + axis (plt.Axes): The axis to add the image to. + center_image (Image.Image): The image to add to the axis. + center_image_size (float): The size of the center image. + n_features (int): The number of features. + + Returns: + None + """ # plot the center image image_to_plot = Image.fromarray(np.asarray(copy.deepcopy(center_image))) extend = center_image_size @@ -287,3 +350,38 @@ def _add_center_image( axis.set_zorder(1) for edge in axis.collections: edge.set_zorder(0) + + +def _get_highest_node_index(n_nodes: int) -> int: + """Calculates the node with the highest position on the y-axis given the total number of nodes. + + Args: + n_nodes (int): The total number of nodes. + + Returns: + int: The index of the highest node. + """ + n_connections = 0 + # highest node is the last node below 1/4 of all connections in the circle + while n_connections <= n_nodes / 4: + n_connections += 1 + n_connections -= 1 + return n_connections + + +def _order_nodes(n_nodes: int) -> list[int]: + """Orders the nodes in the network plot. + + Args: + n_nodes (int): The total number of nodes. + + Returns: + list[int]: The order of the nodes. + """ + highest_node = _get_highest_node_index(n_nodes) + nodes_visit_order = [highest_node] + desired_order = list(reversed(list(range(n_nodes)))) + highest_node_index = desired_order.index(highest_node) + nodes_visit_order += desired_order[highest_node_index + 1 :] + nodes_visit_order += desired_order[:highest_node_index] + return nodes_visit_order diff --git a/tests/tests_plots/test_network_plot.py b/tests/tests_plots/test_network_plot.py index 47404c39..48db6c8e 100644 --- a/tests/tests_plots/test_network_plot.py +++ b/tests/tests_plots/test_network_plot.py @@ -1,9 +1,12 @@ """This module contains all tests for the network plots.""" import numpy as np import matplotlib.pyplot as plt +import pytest from PIL import Image +from scipy.special import binom from shapiq.plot import network_plot +from shapiq.approximator._base import InteractionValues def test_network_plot(): @@ -29,8 +32,29 @@ def test_network_plot(): assert axes is not None plt.close(fig) + # test with InteractionValues object + n_players = 5 + n_values = n_players + int(binom(n_players, 2)) + iv = InteractionValues( + values=np.random.rand(n_values), + index="nSII", + n_players=n_players, + min_order=1, + max_order=2, + ) + fig, axes = network_plot(interaction_values=iv) + assert fig is not None + assert axes is not None + plt.close(fig) -def test_network_plot_with_image(): + # value error if neither first_order_values nor interaction_values are given + with pytest.raises(ValueError): + network_plot() + + assert True + + +def test_network_plot_with_image_or_text(): first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6]) second_order_values = np.random.rand(6, 6) - 0.5 n_features = len(first_order_values) @@ -66,3 +90,14 @@ def test_network_plot_with_image(): assert fig is not None assert axes is not None plt.close(fig) + + # with text + fig, axes = network_plot( + first_order_values=first_order_values, + second_order_values=second_order_values, + center_text="center text", + ) + assert fig is not None + assert axes is not None + plt.close(fig) + assert True