From 3c3dfc24a23834b323140cae49d145bf2391ea30 Mon Sep 17 00:00:00 2001
From: Theo <49311372+Advueu963@users.noreply.github.com>
Date: Fri, 25 Oct 2024 16:15:10 +0200
Subject: [PATCH] 183 make games be also callable with a tuplelist of tuples
 (#242)

* added possibility of tuple/list of tuple in BaseGame __call__

* Added possibility for to Game coalitions with given player_names

* Added possibility to call game based on given player_names

* removed player_names property as it caused conflicts

* made __call__ annotations for python3.9

* refactor game call with strings

* refactor game call of tuple and list

* Reimplement _check_coalitions to fewer lines.

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
---
 shapiq/games/base.py                | 113 ++++++++++++++++++++++++++--
 tests/tests_games/test_base_game.py |  86 +++++++++++++++++++++
 2 files changed, 194 insertions(+), 5 deletions(-)

diff --git a/shapiq/games/base.py b/shapiq/games/base.py
index 89e9800b..f8e01953 100644
--- a/shapiq/games/base.py
+++ b/shapiq/games/base.py
@@ -4,7 +4,7 @@
 import pickle
 import warnings
 from abc import ABC
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 from tqdm.auto import tqdm
@@ -94,6 +94,7 @@ def __init__(
         normalization_value: Optional[float] = None,
         path_to_values: Optional[str] = None,
         verbose: bool = False,
+        player_names: Optional[list[str]] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -137,6 +138,11 @@ def __init__(
         self._empty_coalition_value_property = None
         self._grand_coalition_value_property = None
 
+        # define player_names
+        self.player_name_lookup: dict[str, int] = (
+            {name: i for i, name in enumerate(player_names)} if player_names is not None else None
+        )
+
         self.verbose = verbose
 
     @property
@@ -159,7 +165,105 @@ def is_normalized(self) -> bool:
         """Checks if the game is normalized/centered."""
         return self(self.empty_coalition) == 0
 
-    def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray:
+    def _check_coalitions(
+        self,
+        coalitions: Union[np.ndarray, list[Union[tuple[int], tuple[str]]]],
+    ) -> np.ndarray:
+        """
+        Check if the coalitions are in the correct format and convert them to one-hot encoding.
+        The format may either be a numpy array containg the coalitions in one-hot encoding or a list of tuples with integers or strings.
+        Args:
+            coalitions: The coalitions to convert to one-hot encoding.
+        Returns:
+            np.ndarray: The coalitions in the correct format
+        Raises:
+            TypeError: If the coalitions are not in the correct format.
+        Examples:
+            >>> coalitions = np.asarray([[1, 0, 0, 0], [0, 1, 1, 0]])
+            >>> coalitions = [(0, 1), (1, 2)]
+            >>> coalitions = [()]
+            >>> coalitions = [(0, 1), (1, 2), (0, 1, 2)]
+            if player_name_lookup is not None:
+            >>> coalitions = [("Alice", "Bob"), ("Bob", "Charlie")]
+            Wrong format:
+            >>> coalitions = [1, 0, 0, 0]
+            >>> coalitions = [(1,"Alice")]
+            >>> coalitions = np.array([1,-1,2])
+
+
+        """
+        error_message = (
+            "List may only contain tuples of integers or strings."
+            "The tuples are not allowed to have heterogeneous types."
+            "Reconcile the docs for correct format of coalitions."
+        )
+
+        if isinstance(coalitions, np.ndarray):
+
+            # Check that coalition is contained in array
+            if len(coalitions) == 0:
+                raise TypeError("The array of coalitions is empty.")
+
+            # Check if single coalition is correctly given
+            if coalitions.ndim == 1:
+                if len(coalitions) < self.n_players or len(coalitions) > self.n_players:
+                    raise TypeError(
+                        "The array of coalitions is not correctly formatted."
+                        f"It should have a length of {self.n_players}"
+                    )
+                coalitions = coalitions.reshape((1, self.n_players))
+
+            # Check that all coalitions have the correct number of players
+            if coalitions.shape[1] != self.n_players:
+                raise TypeError(
+                    f"The number of players in the coalitions ({coalitions.shape[1]}) does not match "
+                    f"the number of players in the game ({self.n_players})."
+                )
+
+            # Check that values of numpy array are either 0 or 1
+            if not np.all(np.logical_or(coalitions == 0, coalitions == 1)):
+                raise TypeError("The values in the array of coalitions are not binary.")
+
+            return coalitions
+
+        # We now assume to work with list of tuples
+        if isinstance(coalitions, tuple):
+            # if by any chance a tuple was given wrap into a list
+            coalitions = [coalitions]
+
+        try:
+            # convert list of tuples to one-hot encoding
+            coalitions = transform_coalitions_to_array(coalitions, self.n_players)
+
+            return coalitions
+        except Exception as err:
+            # It may either be the tuples contain strings or wrong format
+            if self.player_name_lookup is not None:
+                # We now assume the tuples to contain strings
+                try:
+                    coalitions = [
+                        (
+                            tuple(self.player_name_lookup[player] for player in coalition)
+                            if coalition != tuple()
+                            else tuple()
+                        )
+                        for coalition in coalitions
+                    ]
+                    coalitions = transform_coalitions_to_array(coalitions, self.n_players)
+
+                    return coalitions
+                except Exception as err:
+                    raise TypeError(error_message) from err
+
+            raise TypeError(error_message) from err
+
+    def __call__(
+        self,
+        coalitions: Union[
+            np.ndarray, list[Union[tuple[int], tuple[str]]], tuple[Union[int, str]], str
+        ],
+        verbose: bool = False,
+    ) -> np.ndarray:
         """Calls the game's value function with the given coalitions and returns the output of the
         value function.
 
@@ -170,9 +274,8 @@ def __call__(self, coalitions: np.ndarray, verbose: bool = False) -> np.ndarray:
         Returns:
             The values of the coalitions.
         """
-        # check if coalitions are correct dimensions
-        if coalitions.ndim == 1:
-            coalitions = coalitions.reshape((1, self.n_players))
+        # check if coalitions are correct format
+        coalitions = self._check_coalitions(coalitions)
 
         verbose = verbose or self.verbose
 
diff --git a/tests/tests_games/test_base_game.py b/tests/tests_games/test_base_game.py
index 28f6b5cc..5e3ef654 100644
--- a/tests/tests_games/test_base_game.py
+++ b/tests/tests_games/test_base_game.py
@@ -5,10 +5,96 @@
 import numpy as np
 import pytest
 
+from shapiq.games.base import Game
 from shapiq.games.benchmark import DummyGame  # used to test the base class
 from shapiq.utils.sets import powerset, transform_coalitions_to_array
 
 
+def test_call():
+    """This test tests the call function of the base game class."""
+
+    class TestGame(Game):
+        """This is a test game class that inherits from the base game class.
+        Its value function is the amount of players divided by the number of players.
+        """
+
+        def __init__(self, n, **kwargs):
+            super().__init__(n_players=n, normalization_value=0, **kwargs)
+
+        def value_function(self, coalition):
+            return np.sum(coalition) / self.n_players
+
+    n_players = 6
+    test_game = TestGame(
+        n=n_players, player_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]
+    )
+
+    # assert that player names are correctly stored
+    assert test_game.player_name_lookup == {
+        "Alice": 0,
+        "Bob": 1,
+        "Charlie": 2,
+        "David": 3,
+        "Eve": 4,
+        "Frank": 5,
+    }
+
+    assert test_game([]) == 0.0
+
+    # test coalition calls with wrong datatype
+    with pytest.raises(TypeError):
+        assert test_game([(0, 1), "Alice", "Charlie"])
+    with pytest.raises(TypeError):
+        assert test_game([(0, 1), ("Alice",), ("Bob",)])
+    with pytest.raises(TypeError):
+        assert test_game(("Alice", 1))
+
+    # test wrong coalition size in call
+    with pytest.raises(TypeError):
+        assert test_game(np.array([True, False, True])) == 0.0
+    with pytest.raises(TypeError):
+        assert test_game(np.array([])) == 0.0
+
+    # test wrong method for numpy array values
+    with pytest.raises(TypeError):
+        assert test_game(np.array([1, 2, 3, 4, 5, 6])) == 0.0
+
+    # test wrong coalition size in shape[1]
+    with pytest.raises(TypeError):
+        assert test_game(np.array([[True, False, True]])) == 0.0
+
+    # test with empty coalition all call variants
+    test_coalition = test_game.empty_coalition
+    assert test_game(test_coalition) == 0.0
+    assert test_game(()) == 0.0
+    assert test_game([()]) == 0.0
+
+    # test with grand coalition all call variants
+    test_coalition = test_game.grand_coalition
+    assert test_game(test_coalition) == 1.0
+    assert test_game(tuple(range(0, test_game.n_players))) == 1.0
+    assert test_game([tuple(range(0, test_game.n_players))]) == 1.0
+    assert test_game(tuple(test_game.player_name_lookup.values())) == 1.0
+    assert test_game([tuple(test_game.player_name_lookup.values())]) == 1.0
+
+    # test with single player coalition all call variants
+    test_coalition = np.array([True] + [False for _ in range(test_game.n_players - 1)])
+    assert test_game(test_coalition) - 1 / 6 < 10e-7
+    assert test_game((0,)) - 1 / 6 < 10e-7
+    assert test_game([(0,)]) - 1 / 6 < 10e-7
+    assert test_game(("Alice",)) - 1 / 6 < 10e-7
+    assert test_game([("Alice",)]) - 1 / 6 < 10e-7
+
+    # test string calls with missing player names
+    test_game2 = TestGame(n=n_players)
+    with pytest.raises(TypeError):
+        assert test_game2("Alice") == 0.0
+    with pytest.raises(TypeError):
+        assert test_game2(("Bob",)) == 0.0
+    with pytest.raises(TypeError):
+        assert test_game2([("Charlie",)]) == 0.0
+
+
 def test_precompute():
     """This test tests the precompute function of the base game class"""
     n_players = 6