Skip to content

Commit 9061511

Browse files
mmschlk and annprzy authored
🛠️ Add pyright (mmschlk#409)
* adds pyright to project * fixes some typing problems * reworks workflows (mmschlk#405) * reworks workflows * limit tensorflow to not install on windows * adds device="cpu" to SentimentAnalysis test * adds doc building workflow as a test * reverts windows unit test runner to use python 3.12 as a fix * makes related software an rst file * moved copy_notebooks into scripts * improves docs * try re-installing python in uv to fix win errors * removed unnecessary extension * Add beeswarm plot to SHAP-IQ (mmschlk#406) * add beeswarm plot * improve plot readability * add tests for beeswarm plot * update docs for beeswarm plot * more informative error messages for beeswarm plot * refactor code * add beeswarm plot to init * simplify colormap generation in beeswarm plot * add tutorial for beeswarm plot --------- Co-authored-by: Anna Przybyłowska <[email protected]>
1 parent 9a1892d commit 9061511

File tree

16 files changed

+218
-131
lines changed

16 files changed

+218
-131
lines changed

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ exclude = [
151151
# Exclude a variety of commonly ignored directories.
152152
# Note to developers: If you add a new ignore at least put a comment next to it why it is ignored
153153
[tool.ruff.lint.per-file-ignores]
154+
"shapiq/typing.py" = [
155+
"A005", # we want to be similar to other libraries that also shadow typing
156+
]
154157
"tests/*.py" = [
155158
# "ALL",
156159
"S101", # we need asserts in tests
@@ -212,6 +215,11 @@ force-wrap-aliases = true
212215
no-lines-before = ["future"]
213216
required-imports = ["from __future__ import annotations"]
214217

218+
[tool.pyright]
219+
include = ["shapiq"]
220+
exclude = ["tests", "docs", "benchmark", "scripts", "shapiq/plot"]
221+
pythonVersion = "3.10"
222+
215223
[dependency-groups]
216224
test = [
217225
"pytest>=8.3.5",
@@ -222,6 +230,7 @@ test = [
222230
lint = [
223231
"ruff>=0.11.2",
224232
"pre-commit>=4.2.0",
233+
"pyright>=1.1.402",
225234
]
226235
docs = [
227236
"sphinx>=8.0.0",

shapiq/explainer/tree/base.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING
67

78
import numpy as np
89

910
from .utils import compute_empty_prediction
1011

12+
if TYPE_CHECKING:
13+
from numpy.typing import NDArray
14+
1115

1216
@dataclass
1317
class TreeModel:
@@ -52,22 +56,22 @@ class TreeModel:
5256
5357
"""
5458

55-
children_left: np.ndarray[int]
56-
children_right: np.ndarray[int]
57-
features: np.ndarray[int]
58-
thresholds: np.ndarray[float]
59-
values: np.ndarray[float]
60-
node_sample_weight: np.ndarray[float]
61-
empty_prediction: float | None = None
62-
leaf_mask: np.ndarray[bool] | None = None
63-
n_features_in_tree: int | None = None
64-
max_feature_id: int | None = None
65-
feature_ids: set | None = None
66-
root_node_id: int | None = None
67-
n_nodes: int | None = None
68-
nodes: np.ndarray[int] | None = None
69-
feature_map_original_internal: dict[int, int] | None = None
70-
feature_map_internal_original: dict[int, int] | None = None
59+
children_left: NDArray[np.int_]
60+
children_right: NDArray[np.int_]
61+
features: NDArray[np.int_]
62+
thresholds: NDArray[np.floating]
63+
values: NDArray[np.floating]
64+
node_sample_weight: NDArray[np.floating]
65+
empty_prediction: float = None # type: ignore[assignment]
66+
leaf_mask: NDArray[np.bool_] = None # type: ignore[assignment]
67+
n_features_in_tree: int = None # type: ignore[assignment]
68+
max_feature_id: int = None # type: ignore[assignment]
69+
feature_ids: set = None # type: ignore[assignment]
70+
root_node_id: int = None # type: ignore[assignment]
71+
n_nodes: int = None # type: ignore[assignment]
72+
nodes: NDArray[np.int_] = None # type: ignore[assignment]
73+
feature_map_original_internal: dict[int, int] = None # type: ignore[assignment]
74+
feature_map_internal_original: dict[int, int] = None # type: ignore[assignment]
7175
original_output_type: str = "raw" # not used at the moment
7276

7377
def compute_empty_prediction(self) -> None:
@@ -187,7 +191,7 @@ def predict_one(self, x: np.ndarray) -> float:
187191
else:
188192
node = self.children_right[node]
189193
is_leaf = self.leaf_mask[node]
190-
return self.values[node]
194+
return float(self.values[node])
191195

192196

193197
@dataclass

shapiq/explainer/tree/conversion/edges.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,27 @@
22

33
from __future__ import annotations
44

5+
from typing import TYPE_CHECKING
6+
57
import numpy as np
68
from scipy.special import binom
79

810
from shapiq.explainer.tree.base import EdgeTree
911

12+
if TYPE_CHECKING:
13+
from numpy.typing import NDArray
14+
1015

1116
def create_edge_tree(
12-
children_left: np.ndarray[int],
13-
children_right: np.ndarray[int],
14-
features: np.ndarray[int],
15-
node_sample_weight: np.ndarray[float],
16-
values: np.ndarray[float],
17+
children_left: NDArray[np.int_],
18+
children_right: NDArray[np.int_],
19+
features: NDArray[np.int_],
20+
node_sample_weight: NDArray[np.floating],
21+
values: NDArray[np.floating],
1722
n_nodes: int,
1823
n_features: int,
1924
max_interaction: int,
20-
subset_updates_pos_store: dict[int, dict[int, np.ndarray[int]]],
25+
subset_updates_pos_store: dict[int, dict[int, NDArray[np.int_]]],
2126
) -> EdgeTree:
2227
"""Extracts edge information recursively from the tree information.
2328

shapiq/explainer/tree/utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,21 @@
22

33
from __future__ import annotations
44

5+
from typing import TYPE_CHECKING
6+
57
import numpy as np
68

9+
if TYPE_CHECKING:
10+
from numpy.typing import NDArray
11+
12+
713
__all__ = ["compute_empty_prediction", "get_conditional_sample_weights"]
814

915

1016
def get_conditional_sample_weights(
11-
sample_count: np.ndarray[int],
12-
parent_array: np.ndarray[int],
13-
) -> np.ndarray[float]:
17+
sample_count: NDArray[np.int_],
18+
parent_array: NDArray[np.int_],
19+
) -> NDArray[np.floating]:
1420
"""Get the conditional sample weights for a tree at each decision node.
1521
1622
The conditional sample weights are the probabilities of going left or right at each decision
@@ -41,8 +47,8 @@ def get_conditional_sample_weights(
4147

4248

4349
def compute_empty_prediction(
44-
leaf_values: np.ndarray[float],
45-
leaf_sample_weights: np.ndarray[float],
50+
leaf_values: NDArray[np.floating],
51+
leaf_sample_weights: NDArray[np.floating],
4652
) -> float:
4753
"""Compute the empty prediction of a tree model.
4854
@@ -56,4 +62,4 @@ def compute_empty_prediction(
5662
The empty prediction of the tree model.
5763
5864
"""
59-
return np.sum(leaf_values * leaf_sample_weights) / np.sum(leaf_sample_weights)
65+
return float(np.sum(leaf_values * leaf_sample_weights) / np.sum(leaf_sample_weights))

shapiq/games/base.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
if TYPE_CHECKING:
1717
from shapiq.interaction_values import InteractionValues
18+
from shapiq.utils.custom_types import CoalitionMatrix, GameValues
1819

1920

2021
class Game:
@@ -29,7 +30,7 @@ class Game:
2930
game_name: The name of the game.
3031
3132
Attributes:
32-
precompute_flag: A flag to manually override the precomputed check. If set to ``True``, the
33+
_precompute_flag: A flag to manually override the precomputed check. If set to ``True``, the
3334
game is considered precomputed and only uses the lookup.
3435
value_storage: The storage for the game values without normalization applied.
3536
coalition_lookup: A lookup dictionary mapping from coalitions to indices in the
@@ -79,6 +80,19 @@ class Game:
7980
8081
"""
8182

83+
n_players: int
84+
"""The number of players in the game."""
85+
86+
game_id: str
87+
"""A unique identifier for the game, based on its class name and hash."""
88+
89+
normalization_value: float
90+
"""The value which is used to normalize (center) the game values such that the value for the
91+
empty coalition is zero. If this is zero, the game values are not normalized."""
92+
93+
value_storage: GameValues
94+
"""The storage for the game values without normalization applied."""
95+
8296
def __init__(
8397
self,
8498
n_players: int | None = None,
@@ -120,19 +134,19 @@ def __init__(
120134
121135
"""
122136
# manual flag for choosing precomputed values even if not all values might be stored
123-
self.precompute_flag: bool = False # flag to manually override the precomputed check
137+
self._precompute_flag: bool = False # flag to manually override the precomputed check
124138

125139
# define storage variables
126-
self.value_storage: np.ndarray = np.zeros(0, dtype=float)
140+
self.value_storage: GameValues = np.zeros(0, dtype=float)
127141
self.coalition_lookup: dict[tuple[int, ...], int] = {}
128-
self.n_players: int = n_players # if path_to_values is provided, this may be overwritten
142+
self.n_players = n_players
129143

130144
if n_players is None and path_to_values is None:
131145
msg = "The number of players has to be provided if game is not loaded from values."
132146
raise ValueError(msg)
133147

134148
# setup normalization of the game
135-
self.normalization_value: float = 0.0
149+
self.normalization_value = 0.0
136150
if normalize and path_to_values is None:
137151
self.normalization_value = normalization_value
138152
if normalization_value is None:
@@ -147,7 +161,7 @@ def __init__(
147161
stacklevel=2,
148162
)
149163

150-
game_id: str = str(hash(self))[:8]
164+
game_id = str(hash(self))[:8]
151165
self.game_id = f"{self.get_game_name()}_{game_id}"
152166
if path_to_values is not None:
153167
self.load_values(path_to_values, precomputed=True)
@@ -177,7 +191,7 @@ def n_values_stored(self) -> int:
177191
@property
178192
def precomputed(self) -> bool:
179193
"""Indication whether the game has been precomputed."""
180-
return self.n_values_stored >= 2**self.n_players or self.precompute_flag
194+
return self.n_values_stored >= 2**self.n_players or self._precompute_flag
181195

182196
@property
183197
def normalize(self) -> bool:
@@ -192,13 +206,13 @@ def is_normalized(self) -> bool:
192206
def _check_coalitions(
193207
self,
194208
coalitions: (
195-
np.ndarray
209+
CoalitionMatrix
196210
| list[tuple[int, ...]]
197211
| list[tuple[str, ...]]
198212
| tuple[int, ...]
199213
| tuple[str, ...]
200214
),
201-
) -> np.ndarray:
215+
) -> CoalitionMatrix:
202216
"""Validates the coalitions and convert them to one-hot encoding.
203217
204218
Check if the coalitions are in the correct format and convert them to one-hot encoding.
@@ -280,15 +294,15 @@ def _check_coalitions(
280294
def __call__(
281295
self,
282296
coalitions: (
283-
np.ndarray
297+
CoalitionMatrix
284298
| list[tuple[int, ...]]
285299
| list[tuple[str, ...]]
286300
| tuple[int, ...]
287301
| tuple[str, ...]
288302
),
289303
*,
290304
verbose: bool = False,
291-
) -> np.ndarray:
305+
) -> GameValues:
292306
"""Calls the game with the given coalitions.
293307
294308
Calls the game's value function with the given coalitions and returns the output of the
@@ -305,7 +319,7 @@ def __call__(
305319
The values of the coalitions.
306320
307321
"""
308-
coalitions: np.ndarray = self._check_coalitions(coalitions)
322+
coalitions = self._check_coalitions(coalitions)
309323
verbose = verbose or self.verbose
310324
if not self.precomputed and not verbose:
311325
values = self.value_function(coalitions)
@@ -335,7 +349,7 @@ def _lookup_coalitions(self, coalitions: np.ndarray) -> np.ndarray:
335349
raise KeyError(msg) from error
336350
return values
337351

338-
def value_function(self, coalitions: np.ndarray) -> np.ndarray:
352+
def value_function(self, coalitions: CoalitionMatrix) -> GameValues:
339353
"""Returns the value of the coalitions.
340354
341355
The value function of the game, which models the behavior of the game. The value function
@@ -355,7 +369,7 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray:
355369
msg = "The value function has to be implemented in inherited classes."
356370
raise NotImplementedError(msg)
357371

358-
def precompute(self, coalitions: np.ndarray | None = None) -> None:
372+
def precompute(self, coalitions: CoalitionMatrix | None = None) -> None:
359373
"""Precompute the game values for all or a given set of coalitions.
360374
361375
The pre-computation iterates over the powerset of all coalitions or a given set of
@@ -401,26 +415,25 @@ def precompute(self, coalitions: np.ndarray | None = None) -> None:
401415
stacklevel=2,
402416
)
403417
if coalitions is None:
404-
coalitions = list(powerset(range(self.n_players))) # might be getting slow
405-
coalitions_array = transform_coalitions_to_array(coalitions, self.n_players)
406-
coalitions_dict = {coal: i for i, coal in enumerate(coalitions)}
418+
all_coalitions = list(powerset(range(self.n_players))) # might be getting slow
419+
coalitions = transform_coalitions_to_array(all_coalitions, self.n_players)
420+
coalitions_dict = {coal: i for i, coal in enumerate(all_coalitions)}
407421
else:
408-
coalitions_array = coalitions
409-
coalitions_tuple = transform_array_to_coalitions(coalitions=coalitions_array)
422+
coalitions_tuple = transform_array_to_coalitions(coalitions=coalitions)
410423
coalitions_dict = {coal: i for i, coal in enumerate(coalitions_tuple)}
411424

412425
# run the game for all coalitions (no normalization)
413426
norm_value, self.normalization_value = self.normalization_value, 0
414-
game_values: np.ndarray = self(coalitions_array) # call the game with the coalitions
427+
game_values = self(coalitions) # call the game with the coalitions
415428
self.normalization_value = norm_value
416429

417430
# update the storage with the new coalitions and values
418431
self.value_storage = game_values.astype(float)
419432
self.coalition_lookup = coalitions_dict
420-
self.precompute_flag = True
433+
self._precompute_flag = True
421434

422435
def compute(
423-
self, coalitions: np.ndarray | None = None
436+
self, coalitions: CoalitionMatrix | None = None
424437
) -> tuple[np.ndarray, dict[tuple[int, ...], int], float]:
425438
"""Compute the game values for all or a given set of coalitions.
426439
@@ -443,8 +456,7 @@ def compute(
443456
(array([0.25, 1.5]), {(1,): 0, (1, 2): 1}, 0.0)
444457
445458
"""
446-
coalitions: np.ndarray = self._check_coalitions(coalitions)
447-
game_values = self.value_function(coalitions)
459+
game_values = self.value_function(self._check_coalitions(coalitions))
448460

449461
return game_values, self.coalition_lookup, self.normalization_value
450462

@@ -470,7 +482,7 @@ def save_values(self, path: Path | str) -> None:
470482
self.precompute()
471483

472484
# transform the values_storage to float16 for compression
473-
self.value_storage.astype(np.float16)
485+
values = self.value_storage.astype(np.float16)
474486

475487
# cast the coalitions_in_storage to bool
476488
coalitions_in_storage = transform_coalitions_to_array(
@@ -481,7 +493,7 @@ def save_values(self, path: Path | str) -> None:
481493
# save the data
482494
np.savez_compressed(
483495
path,
484-
values=self.value_storage,
496+
values=values,
485497
coalitions=coalitions_in_storage,
486498
n_players=self.n_players,
487499
normalization_value=self.normalization_value,
@@ -517,7 +529,7 @@ def load_values(self, path: Path | str, *, precomputed: bool = False) -> None:
517529
self.value_storage = data["values"]
518530
coalition_lookup: list[tuple] = transform_array_to_coalitions(data["coalitions"])
519531
self.coalition_lookup = {coal: i for i, coal in enumerate(coalition_lookup)}
520-
self.precompute_flag = precomputed
532+
self._precompute_flag = precomputed
521533
self.normalization_value = float(data["normalization_value"])
522534

523535
def save(self, path: Path | str) -> None:

0 commit comments

Comments
 (0)