From 3c3dfc24a23834b323140cae49d145bf2391ea30 Mon Sep 17 00:00:00 2001 From: Theo <49311372+Advueu963@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:15:10 +0200 Subject: [PATCH] 183 make games be also callable with a tuplelist of tuples (#242) * added possibility of tuple/list of tuple in BaseGame __call__ * Added possibility for to Game coalitions with given player_names * Added possibility to call game based on given player_names * removed player_names property as it caused conflicts * made __call__ annotations for python3.9 * refactor game call with strings * refactor game call of tuple and list * Reimplement _check_coalitions to fewer lines. --------- Co-authored-by: Maximilian <maximilian.muschalik@gmail.com> --- shapiq/games/base.py | 113 ++++++++++++++++++++++++++-- tests/tests_games/test_base_game.py | 86 +++++++++++++++++++++ 2 files changed, 194 insertions(+), 5 deletions(-) diff --git a/shapiq/games/base.py b/shapiq/games/base.py index 89e9800b..f8e01953 100644 --- a/shapiq/games/base.py +++ b/shapiq/games/base.py @@ -4,7 +4,7 @@ import pickle import warnings from abc import ABC -from typing import Optional +from typing import Optional, Union import numpy as np from tqdm.auto import tqdm @@ -94,6 +94,7 @@ def __init__( normalization_value: Optional[float] = None, path_to_values: Optional[str] = None, verbose: bool = False, + player_names: Optional[list[str]] = None, *args, **kwargs, ) -> None: @@ -137,6 +138,11 @@ def __init__( self._empty_coalition_value_property = None self._grand_coalition_value_property = None + # define player_names + self.player_name_lookup: dict[str, int] = ( + {name: i for i, name in enumerate(player_names)} if player_names is not None else None + ) + self.verbose = verbose @property @@ -159,7 +165,105 @@ def is_normalized(self) -> bool: """Checks if the game is normalized/centered.""" return self(self.empty_coalition) == 0 - def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray: + def _check_coalitions( + self, + coalitions: Union[np.ndarray, list[Union[tuple[int], tuple[str]]]], + ) -> np.ndarray: + """ + Check if the coalitions are in the correct format and convert them to one-hot encoding. + The format may either be a numpy array containg the coalitions in one-hot encoding or a list of tuples with integers or strings. + Args: + coalitions: The coalitions to convert to one-hot encoding. + Returns: + np.ndarray: The coalitions in the correct format + Raises: + TypeError: If the coalitions are not in the correct format. + Examples: + >>> coalitions = np.asarray([[1, 0, 0, 0], [0, 1, 1, 0]]) + >>> coalitions = [(0, 1), (1, 2)] + >>> coalitions = [()] + >>> coalitions = [(0, 1), (1, 2), (0, 1, 2)] + if player_name_lookup is not None: + >>> coalitions = [("Alice", "Bob"), ("Bob", "Charlie")] + Wrong format: + >>> coalitions = [1, 0, 0, 0] + >>> coalitions = [(1,"Alice")] + >>> coalitions = np.array([1,-1,2]) + + + """ + error_message = ( + "List may only contain tuples of integers or strings." + "The tuples are not allowed to have heterogeneous types." + "Reconcile the docs for correct format of coalitions." + ) + + if isinstance(coalitions, np.ndarray): + + # Check that coalition is contained in array + if len(coalitions) == 0: + raise TypeError("The array of coalitions is empty.") + + # Check if single coalition is correctly given + if coalitions.ndim == 1: + if len(coalitions) < self.n_players or len(coalitions) > self.n_players: + raise TypeError( + "The array of coalitions is not correctly formatted." + f"It should have a length of {self.n_players}" + ) + coalitions = coalitions.reshape((1, self.n_players)) + + # Check that all coalitions have the correct number of players + if coalitions.shape[1] != self.n_players: + raise TypeError( + f"The number of players in the coalitions ({coalitions.shape[1]}) does not match " + f"the number of players in the game ({self.n_players})." + ) + + # Check that values of numpy array are either 0 or 1 + if not np.all(np.logical_or(coalitions == 0, coalitions == 1)): + raise TypeError("The values in the array of coalitions are not binary.") + + return coalitions + + # We now assume to work with list of tuples + if isinstance(coalitions, tuple): + # if by any chance a tuple was given wrap into a list + coalitions = [coalitions] + + try: + # convert list of tuples to one-hot encoding + coalitions = transform_coalitions_to_array(coalitions, self.n_players) + + return coalitions + except Exception as err: + # It may either be the tuples contain strings or wrong format + if self.player_name_lookup is not None: + # We now assume the tuples to contain strings + try: + coalitions = [ + ( + tuple(self.player_name_lookup[player] for player in coalition) + if coalition != tuple() + else tuple() + ) + for coalition in coalitions + ] + coalitions = transform_coalitions_to_array(coalitions, self.n_players) + + return coalitions + except Exception as err: + raise TypeError(error_message) from err + + raise TypeError(error_message) from err + + def __call__( + self, + coalitions: Union[ + np.ndarray, list[Union[tuple[int], tuple[str]]], tuple[Union[int, str]], str + ], + verbose: bool = False, + ) -> np.ndarray: """Calls the game's value function with the given coalitions and returns the output of the value function. @@ -170,9 +274,8 @@ def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray: Returns: The values of the coalitions. """ - # check if coalitions are correct dimensions - if coalitions.ndim == 1: - coalitions = coalitions.reshape((1, self.n_players)) + # check if coalitions are correct format + coalitions = self._check_coalitions(coalitions) verbose = verbose or self.verbose diff --git a/tests/tests_games/test_base_game.py b/tests/tests_games/test_base_game.py index 28f6b5cc..5e3ef654 100644 --- a/tests/tests_games/test_base_game.py +++ b/tests/tests_games/test_base_game.py @@ -5,10 +5,96 @@ import numpy as np import pytest +from shapiq.games.base import Game from shapiq.games.benchmark import DummyGame # used to test the base class from shapiq.utils.sets import powerset, transform_coalitions_to_array +def test_call(): + """This test tests the call function of the base game class.""" + + class TestGame(Game): + """This is a test game class that inherits from the base game class. + Its value function is the amount of players divided by the number of players. + """ + + def __init__(self, n, **kwargs): + super().__init__(n_players=n, normalization_value=0, **kwargs) + + def value_function(self, coalition): + return np.sum(coalition) / self.n_players + + n_players = 6 + test_game = TestGame( + n=n_players, player_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"] + ) + + # assert that player names are correctly stored + assert test_game.player_name_lookup == { + "Alice": 0, + "Bob": 1, + "Charlie": 2, + "David": 3, + "Eve": 4, + "Frank": 5, + } + + assert test_game([]) == 0.0 + + # test coalition calls with wrong datatype + with pytest.raises(TypeError): + assert test_game([(0, 1), "Alice", "Charlie"]) + with pytest.raises(TypeError): + assert test_game([(0, 1), ("Alice",), ("Bob",)]) + with pytest.raises(TypeError): + assert test_game(("Alice", 1)) + + # test wrong coalition size in call + with pytest.raises(TypeError): + assert test_game(np.array([True, False, True])) == 0.0 + with pytest.raises(TypeError): + assert test_game(np.array([])) == 0.0 + + # test wrong method for numpy array values + with pytest.raises(TypeError): + assert test_game(np.array([1, 2, 3, 4, 5, 6])) == 0.0 + + # test wrong coalition size in shape[1] + with pytest.raises(TypeError): + assert test_game(np.array([[True, False, True]])) == 0.0 + + # test with empty coalition all call variants + test_coalition = test_game.empty_coalition + assert test_game(test_coalition) == 0.0 + assert test_game(()) == 0.0 + assert test_game([()]) == 0.0 + + # test with grand coalition all call variants + test_coalition = test_game.grand_coalition + assert test_game(test_coalition) == 1.0 + assert test_game(tuple(range(0, test_game.n_players))) == 1.0 + assert test_game([tuple(range(0, test_game.n_players))]) == 1.0 + assert test_game(tuple(test_game.player_name_lookup.values())) == 1.0 + assert test_game([tuple(test_game.player_name_lookup.values())]) == 1.0 + + # test with single player coalition all call variants + test_coalition = np.array([True] + [False for _ in range(test_game.n_players - 1)]) + assert test_game(test_coalition) - 1 / 6 < 10e-7 + assert test_game((0,)) - 1 / 6 < 10e-7 + assert test_game([(0,)]) - 1 / 6 < 10e-7 + assert test_game(("Alice",)) - 1 / 6 < 10e-7 + assert test_game([("Alice",)]) - 1 / 6 < 10e-7 + + # test string calls with missing player names + test_game2 = TestGame(n=n_players) + with pytest.raises(TypeError): + assert test_game2("Alice") == 0.0 + with pytest.raises(TypeError): + assert test_game2(("Bob",)) == 0.0 + with pytest.raises(TypeError): + assert test_game2([("Charlie",)]) == 0.0 + + def test_precompute(): """This test tests the precompute function of the base game class""" n_players = 6