Skip to content

Commit b7e331d

Browse files
committed
remove sklearn shap support to fix dependency issues
1 parent f51ba86 commit b7e331d

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

pypythia/predictor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,14 @@ def predict(self, query: Dict) -> float:
112112
return prediction
113113

114114
def plot_shapley_values(self, query: Dict) -> Figure:
115+
if isinstance(self.predictor, RandomForestRegressor):
116+
raise PyPythiaException("Cannot infer shapley values for scikit-learn predictors")
117+
115118
explainer = shap.TreeExplainer(self.predictor)
116119
df = self._check_and_pack_query(query)
117120
shap_values = explainer.shap_values(df)
118121
base_values = explainer.expected_value
119122

120-
if isinstance(self.predictor, RandomForestRegressor):
121-
base_values = base_values[0]
122-
123123
return shap.plots.waterfall(
124124
shap.Explanation(
125125
values=shap_values[0],

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ classifiers =
1313
Programming Language :: Python :: 3.9
1414
Programming Language :: Python :: 3.10
1515
Programming Language :: Python :: 3.11
16-
1716

1817
[options]
1918
include_package_data = true

tests/test_predictor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import matplotlib.figure
2+
import pytest
23

34
from tests.fixtures import *
45
import numpy as np
@@ -122,7 +123,7 @@ def test_predict(self, sklearn_predictor):
122123
assert isinstance(prediction, float)
123124
assert 0.0 <= prediction <= 1.0
124125

125-
def test_plot_shapley_values(self, sklearn_predictor):
126+
def test_plot_shapley_values_raises_error(self, sklearn_predictor):
126127
query = {
127128
"num_patterns/num_taxa": 0.0,
128129
"num_sites/num_taxa": 0.0,
@@ -134,5 +135,5 @@ def test_plot_shapley_values(self, sklearn_predictor):
134135
"proportion_unique_topos_parsimony": 0.0,
135136
}
136137

137-
fig = sklearn_predictor.plot_shapley_values(query)
138-
assert isinstance(fig, matplotlib.figure.Figure)
138+
with pytest.raises(PyPythiaException):
139+
sklearn_predictor.plot_shapley_values(query)

0 commit comments

Comments
 (0)