Skip to content

Commit a08206e

Browse files
committed
import shap only on --shap call, suppress shap's numba deprecation warnings
1 parent b7e331d commit a08206e

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pypythia/predictor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pickle
2-
import shap
2+
import warnings
33

44
from lightgbm import LGBMRegressor
55
from matplotlib.figure import Figure
@@ -112,6 +112,10 @@ def predict(self, query: Dict) -> float:
112112
return prediction
113113

114114
def plot_shapley_values(self, query: Dict) -> Figure:
115+
with warnings.catch_warnings():
116+
warnings.simplefilter("ignore")
117+
import shap
118+
115119
if isinstance(self.predictor, RandomForestRegressor):
116120
raise PyPythiaException("Cannot infer shapley values for scikit-learn predictors")
117121

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ long_description_content_type = text/markdown
66
url = https://github.com/tschuelia/PyPythia
77
author = Julia Haag
88
author_email = [email protected]
9-
version = 1.1.0
9+
version = 1.1.2
1010
classifiers =
1111
Programming Language :: Python :: 3.7
1212
Programming Language :: Python :: 3.8

0 commit comments

Comments
 (0)