Skip to content

Commit

Permalink
Merge pull request #27 from mmschlk/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
mmschlk authored Jan 4, 2024
2 parents f3d0bd8 + aa57aff commit 829f3f4
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 13 deletions.
294 changes: 294 additions & 0 deletions notebooks/bike.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ matplotlib
colour
networkx
scikit-learn
pandas
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
with io.open(os.path.join(work_directory, "README.md"), encoding="utf-8") as f:
long_description = "\n" + f.read()

base_packages = ["numpy", "scipy"]
base_packages = ["numpy", "scipy", "pandas"]

plotting_packages = ["matplotlib", "colour", "networkx"]

Expand Down
4 changes: 4 additions & 0 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
split_subsets_budget,
)

from .datasets import load_bike

__all__ = [
# version
"__version__",
Expand All @@ -51,4 +53,6 @@
"split_subsets_budget",
"get_conditional_sample_weights",
"get_parent_array",
# datasets
"load_bike",
]
5 changes: 5 additions & 0 deletions shapiq/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""This module contains small datasets for testing and examples."""

from ._all import load_bike

__all__ = ["load_bike"]
38 changes: 38 additions & 0 deletions shapiq/datasets/_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import pandas as pd


# Base URL from which the example datasets (e.g. ``bike.csv``) are downloaded.
GITHUB_DATA_URL = "https://github.com/mmschlk/shapiq/raw/main/data/"


def load_bike() -> pd.DataFrame:
    """Load the bike-sharing dataset from a Kaggle competition.

    Original source: https://www.kaggle.com/c/bike-sharing-demand

    The CSV file is downloaded from ``GITHUB_DATA_URL``, so network access is required. The
    ``datetime`` column is split into ``year``, ``month``, ``day`` and ``hour`` columns, which
    are placed in front of the remaining columns, and all column names are title-cased.

    Note:
        The function and preprocessing is taken from the `sage` package.

    Returns:
        The bike-sharing dataset as a pandas DataFrame.
    """
    # Build the URL with plain string handling: ``os.path.join`` follows the platform's
    # file-path rules and is not meant for URLs.
    data = pd.read_csv(f"{GITHUB_DATA_URL.rstrip('/')}/bike.csv")

    # Remember the non-datetime columns in file order; selecting them by name (instead of
    # slicing off the first column) does not depend on where ``datetime`` sits in the file.
    other_columns = [column for column in data.columns if column != "datetime"]

    # Split and remove datetime column.
    timestamps = pd.to_datetime(data["datetime"])
    data["year"] = timestamps.dt.year
    data["month"] = timestamps.dt.month
    data["day"] = timestamps.dt.day
    data["hour"] = timestamps.dt.hour
    data = data.drop("datetime", axis=1)

    # Reorder and rename columns.
    data = data[["year", "month", "day", "hour"] + other_columns]
    data.columns = list(map(str.title, data.columns))

    return data


if __name__ == "__main__":
    # Manual smoke check: fetch the dataset and show it on stdout.
    bike_data = load_bike()
    print(bike_data)
5 changes: 2 additions & 3 deletions shapiq/explainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


from .interaction import InteractionExplainer
from .tree import TreeExplainer

__all__ = [
"InteractionExplainer",
]
__all__ = ["InteractionExplainer", "TreeExplainer"]
2 changes: 1 addition & 1 deletion shapiq/explainer/imputer/marginal_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
background_data: np.ndarray,
x_explain: Optional[np.ndarray] = None,
sample_replacements: bool = False,
sample_size: int = 5,
sample_size: int = 1,
categorical_features: list[int] = None,
random_state: Optional[int] = None,
) -> None:
Expand Down
20 changes: 20 additions & 0 deletions shapiq/explainer/tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""This module contains the TreeSHAP-IQ explainer for computing exact any order Shapley interactions
for trees and tree ensembles."""
import numpy as np

from approximator._base import InteractionValues
from explainer._base import Explainer


__all__ = ["TreeExplainer"]


class TreeExplainer(Explainer):
    """Placeholder for the TreeSHAP-IQ explainer for trees and tree ensembles.

    The explainer is not implemented yet: instantiating it always raises a
    ``NotImplementedError``. An initial version can be found at
    https://github.com/mmschlk/TreeSHAP-IQ.

    Raises:
        NotImplementedError: Always, on instantiation.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "The TreeExplainer is not yet implemented. An initial version can be found here: "
            "'https://github.com/mmschlk/TreeSHAP-IQ'."
        )

    def explain(self, x_explain: np.ndarray) -> InteractionValues:
        """Explain a data point (not implemented yet).

        Args:
            x_explain: The input to explain as a numpy array.

        Raises:
            NotImplementedError: Always, since the explainer is not implemented yet.
        """
        # Fail loudly instead of silently returning ``None``, which would contradict the
        # annotated ``InteractionValues`` return type.
        raise NotImplementedError("The TreeExplainer is not yet implemented.")
14 changes: 6 additions & 8 deletions shapiq/plot/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

from approximator._base import InteractionValues
from utils import powerset

from ._config import BLUE, RED
Expand All @@ -23,12 +25,6 @@ def _get_color(value: float) -> str:
return BLUE.hex


def _min_max_normalization(value: float, min_value: float, max_value: float) -> float:
    """Rescale ``value`` linearly so that ``min_value`` maps to 0 and ``max_value`` maps to 1."""
    offset = value - min_value
    value_range = max_value - min_value
    return offset / value_range


def _add_weight_to_edges_in_graph(
graph: nx.Graph,
first_order_values: np.ndarray,
Expand Down Expand Up @@ -57,7 +53,7 @@ def _add_weight_to_edges_in_graph(
graph.nodes[node]["edgecolors"] = color

for edge in powerset(range(n_features), min_size=2, max_size=2):
weight: float = second_order_values[edge]
weight: float = float(second_order_values[edge])
color = _get_color(weight)
# scale weight between min and max edge value
size = abs(weight) / all_range
Expand Down Expand Up @@ -134,9 +130,10 @@ def _add_legend_to_axis(axis: plt.Axes) -> None:


def network_plot(
*,
interaction_values: InteractionValues,
first_order_values: np.ndarray[float],
second_order_values: np.ndarray[float],
*,
feature_names: Optional[list[Any]] = None,
feature_image_patches: Optional[dict[int, Image.Image]] = None,
feature_image_patches_size: Optional[Union[float, dict[int, float]]] = 0.2,
Expand All @@ -156,6 +153,7 @@ def network_plot(
:align: center
Args:
interaction_values: The interaction values as an interaction object.
first_order_values: The first order n-SII values of shape (n_features,).
second_order_values: The second order n-SII values of shape (n_features, n_features). The
diagonal values are ignored. Only the upper triangular values are used.
Expand Down
57 changes: 57 additions & 0 deletions shapiq/utils/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys
from typing import Any, Union


def safe_isinstance(obj: Any, class_path_str: Union[str, list[str], tuple[str]]) -> bool:
    """Check ``isinstance`` against classes given by their full import path.

    Acts as a safe version of isinstance without having to explicitly import packages which may
    not exist in the user's environment. Checks if obj is an instance of type specified by
    class_path_str.

    Note:
        This function was directly taken from the `shap` repository.

    Args:
        obj: Some object you want to test against
        class_path_str: A string or list of strings specifying full class paths Example:
            `sklearn.ensemble.RandomForestRegressor`

    Returns:
        True if isinstance is true and the package exists, False otherwise

    Raises:
        ValueError: If ``class_path_str`` is not a string, list or tuple, or if one of the given
            paths is not a string containing a "." (i.e. not a full ``module.Class`` path).
    """
    if isinstance(class_path_str, str):
        class_paths = [class_path_str]
    elif isinstance(class_path_str, (list, tuple)):
        class_paths = class_path_str
    else:
        # Unsupported argument type: the empty path fails the validation below.
        class_paths = [""]

    # try each module path in order
    for class_path in class_paths:
        if not isinstance(class_path, str) or "." not in class_path:
            # Adjacent literals (not a backslash continuation inside the string) keep the
            # source indentation out of the message.
            raise ValueError(
                "class_path_str must be a string or list of strings specifying a full "
                "module path to a class. Eg, 'sklearn.ensemble.RandomForestRegressor'"
            )

        # Splits on last occurrence of "."
        module_name, class_name = class_path.rsplit(".", 1)

        # Here we don't check further if the module is not imported, since we shouldn't have
        # an object of that type passed to us if the module the type is from has never been
        # imported (and we don't want to import lots of new modules for no reason).
        module = sys.modules.get(module_name)
        if module is None:
            continue

        # Get class
        _class = getattr(module, class_name, None)
        if _class is None:
            continue

        if isinstance(obj, _class):
            return True

    return False
9 changes: 9 additions & 0 deletions tests/tests_explainer/test_explainer_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""This module contains all tests for the TreeExplainer class of the shapiq package."""
import pytest

from shapiq.explainer import TreeExplainer


def test_init():
    """Tests that instantiating the TreeExplainer raises a NotImplementedError."""
    with pytest.raises(NotImplementedError):
        # No binding needed: the constructor is expected to raise before returning.
        TreeExplainer()

0 comments on commit 829f3f4

Please sign in to comment.