Skip to content

Commit dede390

Browse files
CloseChoicemmschlk
andauthored
Remove zeros from interaction lookup, check in test (mmschlk#424)
* remove zeros from interaction lookup, check in test * remove debug statements * fix pre-commit issues * remove interaction values if they are summed to zero, add reproducible tests --------- Co-authored-by: Maximilian <[email protected]>
1 parent 4c4f0d3 commit dede390

File tree

3 files changed

+71
-25
lines changed

3 files changed

+71
-25
lines changed

src/shapiq/game_theory/aggregation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,14 @@ def aggregate_base_attributions(
7272
for interaction in powerset(base_interaction, min_size=1, max_size=order):
7373
scaling = float(bernoulli_numbers[len(base_interaction) - len(interaction)])
7474
update_interaction = scaling * base_interaction_value
75-
try:
76-
transformed_interactions[interaction] += update_interaction
77-
except KeyError:
78-
transformed_interactions[interaction] = update_interaction
75+
if update_interaction == 0:
76+
continue
77+
transformed_interactions[interaction] = (
78+
transformed_interactions.get(interaction, 0) + update_interaction
79+
)
80+
# if the interactions sum to 0, we pop them from the dict
81+
if transformed_interactions[interaction] == 0:
82+
transformed_interactions.pop(interaction)
7983

8084
# update the index name after the aggregation (e.g., SII -> k-SII)
8185
new_index = _change_index(index)

src/shapiq/game_theory/moebius_converter.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,17 @@ def _moebius_to_base_interaction(
157157
strict=False,
158158
):
159159
moebius_size = len(moebius_set)
160-
# for higher-order Möbius sets (size > order) distribute the value on all interactions
161160
for interaction in powerset(moebius_set, min_size=0, max_size=order):
162161
val_distributed = distribution_weights[moebius_size, len(interaction)]
163162
# Check if Möbius value is distributed onto this interaction
164-
if interaction in base_interaction_dict:
165-
base_interaction_dict[interaction] += moebius_val * val_distributed
166-
else:
167-
base_interaction_dict[interaction] = moebius_val * val_distributed
163+
moebius_val_calc = moebius_val * val_distributed
164+
if moebius_val_calc == 0:
165+
continue
166+
base_interaction_dict[interaction] = (
167+
base_interaction_dict.get(interaction, 0) + moebius_val_calc
168+
)
169+
if base_interaction_dict[interaction] == 0:
170+
base_interaction_dict.pop(interaction)
168171

169172
base_interaction_values = np.zeros(len(base_interaction_dict))
170173
base_interaction_lookup = {}
@@ -227,19 +230,24 @@ def _stii_routine(
227230
if moebius_size < order:
228231
# For STII, interaction below size order are the Möbius coefficients
229232
val_distributed = distribution_weights[moebius_size, moebius_size]
230-
if moebius_set in stii_dict:
231-
stii_dict[moebius_set] += moebius_val * val_distributed
232-
else:
233-
stii_dict[moebius_set] = moebius_val * val_distributed
233+
moebius_val_calc = moebius_val * val_distributed
234+
if moebius_val_calc == 0:
235+
continue
236+
stii_dict[moebius_set] = stii_dict.get(moebius_set, 0) + moebius_val_calc
237+
# if Möbius values sum up to zero, we pop it from the dict
238+
if stii_dict[moebius_set] == 0:
239+
stii_dict.pop(moebius_set)
234240
else:
235241
# higher-order Möbius sets (size > order) distribute to all top-order interactions
236242
for interaction in powerset(moebius_set, min_size=order, max_size=order):
237243
val_distributed = distribution_weights[moebius_size, len(interaction)]
238244
# Check if Möbius value is distributed onto this interaction
239-
if interaction in stii_dict:
240-
stii_dict[interaction] += moebius_val * val_distributed
241-
else:
242-
stii_dict[interaction] = moebius_val * val_distributed
245+
moebius_val_calc = moebius_val * val_distributed
246+
if moebius_val_calc == 0:
247+
continue
248+
stii_dict[interaction] = stii_dict.get(interaction, 0) + moebius_val_calc
249+
if stii_dict[interaction] == 0:
250+
stii_dict.pop(interaction)
243251

244252
stii_values = np.zeros(len(stii_dict))
245253
stii_lookup = {}
@@ -311,10 +319,13 @@ def _fii_routine(self, index: Literal["FSII", "FBII"], order: int) -> Interactio
311319
for interaction in powerset(moebius_set, min_size=1, max_size=order):
312320
val_distributed = distribution_weights[moebius_size, len(interaction)]
313321
# Check if Möbius value is distributed onto this interaction
314-
if interaction in fii_dict:
315-
fii_dict[interaction] += moebius_val * val_distributed
316-
else:
317-
fii_dict[interaction] = moebius_val * val_distributed
322+
moebius_val_calc = moebius_val * val_distributed
323+
if moebius_val_calc == 0:
324+
continue
325+
fii_dict[interaction] = fii_dict.get(interaction, 0) + moebius_val_calc
326+
# if Möbius values sum up to zero, we pop it from the dict
327+
if fii_dict[interaction] == 0:
328+
fii_dict.pop(interaction)
318329

319330
fii_values = np.zeros(len(fii_dict))
320331
fii_lookup = {}

tests/shapiq/tests_unit/tests_game_theory/test_moebius_converter.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import numpy as np
6+
import pytest
67

78
from shapiq.game_theory.moebius_converter import MoebiusConverter
89
from shapiq_games.synthetic.soum import SOUM
@@ -22,13 +23,43 @@ def test_soum_moebius_conversion():
2223

2324
moebius_converter = MoebiusConverter(soum.moebius_coefficients)
2425

25-
shapley_interactions = {}
2626
for index in ["STII", "k-SII", "FSII"]:
27-
shapley_interactions[index] = moebius_converter(index=index, order=order)
27+
shapley_interactions = moebius_converter(index=index, order=order)
2828
# Assert efficiency
29-
assert (np.sum(shapley_interactions[index].values) - predicted_value) ** 2 < 10e-7
30-
assert (shapley_interactions[index][()] - emptyset_prediction) ** 2 < 10e-7
29+
assert (np.sum(shapley_interactions.values) - predicted_value) ** 2 < 10e-7
30+
assert (shapley_interactions[()] - emptyset_prediction) ** 2 < 10e-7
31+
for v in moebius_converter._computed.values():
32+
# check that no 0's are in the interaction lookup (except for the empty set, which is the first entry)
33+
interactions = v.interactions
34+
assert all(v != 0 for idx, v in enumerate(interactions.values()) if idx > 0)
3135

3236
# test direct call of Möbius converter
3337
for index in ["STII", "k-SII", "SII", "FSII"]:
3438
moebius_converter(index=index, order=order)
39+
40+
41+
@pytest.mark.parametrize("random_state", [10, 19, 20, 21, 23])
42+
def test_soum_moebius_conversion_failing_states(random_state):
43+
"""Test SOUM moebius conversion with specific failing random states."""
44+
order = 3
45+
n_basis_games = 1
46+
n = 7
47+
48+
soum = SOUM(n, n_basis_games=n_basis_games, random_state=random_state)
49+
predicted_value = soum(np.ones(n))[0]
50+
emptyset_prediction = soum(np.zeros(n))[0]
51+
52+
moebius_converter = MoebiusConverter(soum.moebius_coefficients)
53+
54+
for index in ["STII", "k-SII", "FSII"]:
55+
shapley_interactions = moebius_converter(index=index, order=order)
56+
# Assert efficiency
57+
assert (np.sum(shapley_interactions.values) - predicted_value) ** 2 < 10e-7
58+
assert (shapley_interactions[()] - emptyset_prediction) ** 2 < 10e-7
59+
for v in moebius_converter._computed.values():
60+
interactions = v.interactions
61+
# Check that no 0's are in the interaction values (except for empty set)
62+
non_empty_values = [val for idx, val in enumerate(interactions.values()) if idx > 0]
63+
assert all(val != 0 for val in non_empty_values), (
64+
f"Found zero values in non-empty interactions with random state {random_state}"
65+
)

0 commit comments

Comments
 (0)