|
8 | 8 | from shapiq.utils import ( |
9 | 9 | count_interactions, |
10 | 10 | generate_interaction_lookup, |
| 11 | + generate_interaction_lookup_from_coalitions, |
11 | 12 | get_explicit_subsets, |
12 | 13 | pair_subset_sizes, |
13 | 14 | powerset, |
@@ -110,6 +111,37 @@ def test_generate_interaction_lookup(n, min_order, max_order, expected): |
110 | 111 | assert generate_interaction_lookup(n, min_order, max_order) == expected |
111 | 112 |
|
112 | 113 |
|
| 114 | +@pytest.mark.parametrize( |
| 115 | + ("coalitions", "expected"), |
| 116 | + [ |
| 117 | + ( |
| 118 | + np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]]), |
| 119 | + {(0, 2): 0, (1, 2): 1, (0, 1): 2, (2,): 3}, |
| 120 | + ), |
| 121 | + ( |
| 122 | + np.array([[1, 1, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1]]), |
| 123 | + {(0, 1, 2): 0, (1,): 1, (0,): 2, (2,): 3}, |
| 124 | + ), |
| 125 | + ( |
| 126 | + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), |
| 127 | + {(0,): 0, (1,): 1, (2,): 2}, |
| 128 | + ), |
| 129 | + ( |
| 130 | + np.array([[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 0]]), |
| 131 | + {(0, 1, 3): 0, (2, 3): 1, (0, 2): 2}, |
| 132 | + ), |
| 133 | + ( |
| 134 | + np.array([[0, 0, 0], [1, 1, 1]]), |
| 135 | + {(): 0, (0, 1, 2): 1}, |
| 136 | + ), |
| 137 | + ], |
| 138 | +) |
| 139 | +def test_generate_interaction_lookup_from_coalitions(coalitions, expected): |
| 140 | + """Tests the generate_interaction_lookup_from_coalitions function.""" |
| 141 | + result = generate_interaction_lookup_from_coalitions(coalitions) |
| 142 | + assert result == expected |
| 143 | + |
| 144 | + |
113 | 145 | @pytest.mark.parametrize( |
114 | 146 | ("coalitions", "n_player", "expected"), |
115 | 147 | [ |
|
0 commit comments