1515
1616if TYPE_CHECKING :
1717 from shapiq .interaction_values import InteractionValues
18+ from shapiq .utils .custom_types import CoalitionMatrix , GameValues
1819
1920
2021class Game :
@@ -29,7 +30,7 @@ class Game:
2930 game_name: The name of the game.
3031
3132 Attributes:
32- precompute_flag : A flag to manually override the precomputed check. If set to ``True``, the
33+ _precompute_flag : A flag to manually override the precomputed check. If set to ``True``, the
3334 game is considered precomputed and only uses the lookup.
3435 value_storage: The storage for the game values without normalization applied.
3536 coalition_lookup: A lookup dictionary mapping from coalitions to indices in the
@@ -79,6 +80,19 @@ class Game:
7980
8081 """
8182
83+ n_players : int
84+ """The number of players in the game."""
85+
86+ game_id : str
87+ """A unique identifier for the game, based on its class name and hash."""
88+
89+ normalization_value : float
90+ """The value which is used to normalize (center) the game values such that the value for the
91+ empty coalition is zero. If this is zero, the game values are not normalized."""
92+
93+ value_storage : GameValues
94+ """The storage for the game values without normalization applied."""
95+
8296 def __init__ (
8397 self ,
8498 n_players : int | None = None ,
@@ -120,19 +134,19 @@ def __init__(
120134
121135 """
122136 # manual flag for choosing precomputed values even if not all values might be stored
123- self .precompute_flag : bool = False # flag to manually override the precomputed check
137+ self ._precompute_flag : bool = False # flag to manually override the precomputed check
124138
125139 # define storage variables
126- self .value_storage : np . ndarray = np .zeros (0 , dtype = float )
140+ self .value_storage : GameValues = np .zeros (0 , dtype = float )
127141 self .coalition_lookup : dict [tuple [int , ...], int ] = {}
128- self .n_players : int = n_players # if path_to_values is provided, this may be overwritten
142+ self .n_players = n_players
129143
130144 if n_players is None and path_to_values is None :
131145 msg = "The number of players has to be provided if game is not loaded from values."
132146 raise ValueError (msg )
133147
134148 # setup normalization of the game
135- self .normalization_value : float = 0.0
149+ self .normalization_value = 0.0
136150 if normalize and path_to_values is None :
137151 self .normalization_value = normalization_value
138152 if normalization_value is None :
@@ -147,7 +161,7 @@ def __init__(
147161 stacklevel = 2 ,
148162 )
149163
150- game_id : str = str (hash (self ))[:8 ]
164+ game_id = str (hash (self ))[:8 ]
151165 self .game_id = f"{ self .get_game_name ()} _{ game_id } "
152166 if path_to_values is not None :
153167 self .load_values (path_to_values , precomputed = True )
@@ -177,7 +191,7 @@ def n_values_stored(self) -> int:
177191 @property
178192 def precomputed (self ) -> bool :
179193 """Indication whether the game has been precomputed."""
180- return self .n_values_stored >= 2 ** self .n_players or self .precompute_flag
194+ return self .n_values_stored >= 2 ** self .n_players or self ._precompute_flag
181195
182196 @property
183197 def normalize (self ) -> bool :
@@ -192,13 +206,13 @@ def is_normalized(self) -> bool:
192206 def _check_coalitions (
193207 self ,
194208 coalitions : (
195- np . ndarray
209+ CoalitionMatrix
196210 | list [tuple [int , ...]]
197211 | list [tuple [str , ...]]
198212 | tuple [int , ...]
199213 | tuple [str , ...]
200214 ),
201- ) -> np . ndarray :
215+ ) -> CoalitionMatrix :
202216 """Validates the coalitions and convert them to one-hot encoding.
203217
204218 Check if the coalitions are in the correct format and convert them to one-hot encoding.
@@ -280,15 +294,15 @@ def _check_coalitions(
280294 def __call__ (
281295 self ,
282296 coalitions : (
283- np . ndarray
297+ CoalitionMatrix
284298 | list [tuple [int , ...]]
285299 | list [tuple [str , ...]]
286300 | tuple [int , ...]
287301 | tuple [str , ...]
288302 ),
289303 * ,
290304 verbose : bool = False ,
291- ) -> np . ndarray :
305+ ) -> GameValues :
292306 """Calls the game with the given coalitions.
293307
294308 Calls the game's value function with the given coalitions and returns the output of the
@@ -305,7 +319,7 @@ def __call__(
305319 The values of the coalitions.
306320
307321 """
308- coalitions : np . ndarray = self ._check_coalitions (coalitions )
322+ coalitions = self ._check_coalitions (coalitions )
309323 verbose = verbose or self .verbose
310324 if not self .precomputed and not verbose :
311325 values = self .value_function (coalitions )
@@ -335,7 +349,7 @@ def _lookup_coalitions(self, coalitions: np.ndarray) -> np.ndarray:
335349 raise KeyError (msg ) from error
336350 return values
337351
338- def value_function (self , coalitions : np . ndarray ) -> np . ndarray :
352+ def value_function (self , coalitions : CoalitionMatrix ) -> GameValues :
339353 """Returns the value of the coalitions.
340354
341355 The value function of the game, which models the behavior of the game. The value function
@@ -355,7 +369,7 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray:
355369 msg = "The value function has to be implemented in inherited classes."
356370 raise NotImplementedError (msg )
357371
358- def precompute (self , coalitions : np . ndarray | None = None ) -> None :
372+ def precompute (self , coalitions : CoalitionMatrix | None = None ) -> None :
359373 """Precompute the game values for all or a given set of coalitions.
360374
361375 The pre-computation iterates over the powerset of all coalitions or a given set of
@@ -401,26 +415,25 @@ def precompute(self, coalitions: np.ndarray | None = None) -> None:
401415 stacklevel = 2 ,
402416 )
403417 if coalitions is None :
404- coalitions = list (powerset (range (self .n_players ))) # might be getting slow
405- coalitions_array = transform_coalitions_to_array (coalitions , self .n_players )
406- coalitions_dict = {coal : i for i , coal in enumerate (coalitions )}
418+ all_coalitions = list (powerset (range (self .n_players ))) # might be getting slow
419+ coalitions = transform_coalitions_to_array (all_coalitions , self .n_players )
420+ coalitions_dict = {coal : i for i , coal in enumerate (all_coalitions )}
407421 else :
408- coalitions_array = coalitions
409- coalitions_tuple = transform_array_to_coalitions (coalitions = coalitions_array )
422+ coalitions_tuple = transform_array_to_coalitions (coalitions = coalitions )
410423 coalitions_dict = {coal : i for i , coal in enumerate (coalitions_tuple )}
411424
412425 # run the game for all coalitions (no normalization)
413426 norm_value , self .normalization_value = self .normalization_value , 0
414- game_values : np . ndarray = self (coalitions_array ) # call the game with the coalitions
427+ game_values = self (coalitions ) # call the game with the coalitions
415428 self .normalization_value = norm_value
416429
417430 # update the storage with the new coalitions and values
418431 self .value_storage = game_values .astype (float )
419432 self .coalition_lookup = coalitions_dict
420- self .precompute_flag = True
433+ self ._precompute_flag = True
421434
422435 def compute (
423- self , coalitions : np . ndarray | None = None
436+ self , coalitions : CoalitionMatrix | None = None
424437 ) -> tuple [np .ndarray , dict [tuple [int , ...], int ], float ]:
425438 """Compute the game values for all or a given set of coalitions.
426439
@@ -443,8 +456,7 @@ def compute(
443456 (array([0.25, 1.5]), {(1): 0, (1, 2): 1.5}, 0.0)
444457
445458 """
446- coalitions : np .ndarray = self ._check_coalitions (coalitions )
447- game_values = self .value_function (coalitions )
459+ game_values = self .value_function (self ._check_coalitions (coalitions ))
448460
449461 return game_values , self .coalition_lookup , self .normalization_value
450462
@@ -470,7 +482,7 @@ def save_values(self, path: Path | str) -> None:
470482 self .precompute ()
471483
472484 # transform the values_storage to float16 for compression
473- self .value_storage .astype (np .float16 )
485+ values = self .value_storage .astype (np .float16 )
474486
475487 # cast the coalitions_in_storage to bool
476488 coalitions_in_storage = transform_coalitions_to_array (
@@ -481,7 +493,7 @@ def save_values(self, path: Path | str) -> None:
481493 # save the data
482494 np .savez_compressed (
483495 path ,
484- values = self . value_storage ,
496+ values = values ,
485497 coalitions = coalitions_in_storage ,
486498 n_players = self .n_players ,
487499 normalization_value = self .normalization_value ,
@@ -517,7 +529,7 @@ def load_values(self, path: Path | str, *, precomputed: bool = False) -> None:
517529 self .value_storage = data ["values" ]
518530 coalition_lookup : list [tuple ] = transform_array_to_coalitions (data ["coalitions" ])
519531 self .coalition_lookup = {coal : i for i , coal in enumerate (coalition_lookup )}
520- self .precompute_flag = precomputed
532+ self ._precompute_flag = precomputed
521533 self .normalization_value = float (data ["normalization_value" ])
522534
523535 def save (self , path : Path | str ) -> None :
0 commit comments