-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #27 from mmschlk/development
Development
- Loading branch information
Showing
12 changed files
with
438 additions
and
13 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ matplotlib | |
colour | ||
networkx | ||
scikit-learn | ||
pandas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""This module contains small datasets for testing and examples.""" | ||
|
||
from ._all import load_bike | ||
|
||
__all__ = ["load_bike"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os | ||
import pandas as pd | ||
|
||
|
||
GITHUB_DATA_URL = "https://github.com/mmschlk/shapiq/raw/main/data/" | ||
|
||
|
||
def load_bike() -> pd.DataFrame: | ||
"""Load the bike-sharing dataset from a Kaggle competition. | ||
Original source: https://www.kaggle.com/c/bike-sharing-demand | ||
Note: | ||
The function and preprocessing is taken from the `sage` package. | ||
Returns: | ||
The bike-sharing dataset as a pandas DataFrame. | ||
""" | ||
data = pd.read_csv(os.path.join(GITHUB_DATA_URL, "bike.csv")) | ||
columns = data.columns.tolist() | ||
|
||
# Split and remove datetime column. | ||
data["datetime"] = pd.to_datetime(data["datetime"]) | ||
data["year"] = data["datetime"].dt.year | ||
data["month"] = data["datetime"].dt.month | ||
data["day"] = data["datetime"].dt.day | ||
data["hour"] = data["datetime"].dt.hour | ||
data = data.drop("datetime", axis=1) | ||
|
||
# Reorder and rename columns. | ||
data = data[["year", "month", "day", "hour"] + columns[1:]] | ||
data.columns = list(map(str.title, data.columns)) | ||
|
||
return data | ||
|
||
|
||
if __name__ == "__main__": | ||
print(load_bike()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
"""This module contains the TreeSHAP-IQ explainer for computing exact any order Shapley interactions | ||
for trees and tree ensembles.""" | ||
import numpy as np | ||
|
||
from approximator._base import InteractionValues | ||
from explainer._base import Explainer | ||
|
||
|
||
__all__ = ["TreeExplainer"] | ||
|
||
|
||
class TreeExplainer(Explainer): | ||
def __init__(self) -> None: | ||
raise NotImplementedError( | ||
"The TreeExplainer is not yet implemented. An initial version can be found here: " | ||
"'https://github.com/mmschlk/TreeSHAP-IQ'." | ||
) | ||
|
||
def explain(self, x_explain: np.ndarray) -> InteractionValues: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import sys | ||
from typing import Any, Union | ||
|
||
|
||
def safe_isinstance(obj: Any, class_path_str: Union[str, list[str], tuple[str]]) -> bool: | ||
""" | ||
Acts as a safe version of isinstance without having to explicitly import packages which may not | ||
exist in the user's environment. Checks if obj is an instance of type specified by | ||
class_path_str. | ||
Note: | ||
This function was directly taken from the `shap` repository. | ||
Args: | ||
obj: Some object you want to test against | ||
class_path_str: A string or list of strings specifying full class paths Example: | ||
`sklearn.ensemble.RandomForestRegressor` | ||
Returns: | ||
True if isinstance is true and the package exists, False otherwise | ||
""" | ||
if isinstance(class_path_str, str): | ||
class_path_strs = [class_path_str] | ||
elif isinstance(class_path_str, list) or isinstance(class_path_str, tuple): | ||
class_path_strs = class_path_str | ||
else: | ||
class_path_strs = [""] | ||
|
||
# try each module path in order | ||
for class_path_str in class_path_strs: | ||
if "." not in class_path_str: | ||
raise ValueError( | ||
"class_path_str must be a string or list of strings specifying a full \ | ||
module path to a class. Eg, 'sklearn.ensemble.RandomForestRegressor'" | ||
) | ||
|
||
# Splits on last occurrence of "." | ||
module_name, class_name = class_path_str.rsplit(".", 1) | ||
|
||
# here we don't check further if the model is not imported, since we shouldn't have | ||
# an object of that types passed to us if the model the type is from has never been | ||
# imported. (and we don't want to import lots of new modules for no reason) | ||
if module_name not in sys.modules: | ||
continue | ||
|
||
module = sys.modules[module_name] | ||
|
||
# Get class | ||
_class = getattr(module, class_name, None) | ||
|
||
if _class is None: | ||
continue | ||
|
||
if isinstance(obj, _class): | ||
return True | ||
|
||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""This module contains all tests for the TreeExplainer class of the shapiq package.""" | ||
import pytest | ||
|
||
from shapiq.explainer import TreeExplainer | ||
|
||
|
||
def test_init(): | ||
with pytest.raises(NotImplementedError): | ||
explainer = TreeExplainer() |