Skip to content

Commit

Permalink
Update explained
Browse files Browse the repository at this point in the history
  • Loading branch information
toandm2 committed Mar 1, 2021
1 parent 24af041 commit 6db71c0
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions pipelineservice/explain/_explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ def __init__(self,
self.y = y
def get_estimator(self):
try:
return self.pipeline[-1]
pipeline = self.pipeline[-1]
except:
return self.pipeline
pipeline = self.pipeline
sklearn.utils.validation.check_is_fitted(pipeline)
return pipeline
def get_xtransform(self):
try:
X_transform = self.pipeline[:-1].transform(self.X)
Expand Down Expand Up @@ -51,8 +53,8 @@ def plot_feature_importance(self, top_k = None):

def get_permutation_importance(self, n_jobs = -1, **kwargs):

permutation = permutation_importance(estimator = self.pipeline,
X = self.X,
permutation = permutation_importance(estimator = self.get_estimator(),
X = self.get_xtransform(),
y = self.y,
n_jobs = n_jobs,
**kwargs
Expand All @@ -75,13 +77,16 @@ def plot_permutation_importance(self, top_k = None):
return facegrid.ax
def plot_partial_dependence(self,
features,
n_jobs = -1,
**kwargs):
fig = plot_partial_dependence(estimator = self.get_estimator(),
X = self.get_xtransform(),
features = features
features = features,
n_jobs = n_jobs,
**kwargs
)
return fig
def shapley_importance(self):
def plot_shap_importance(self):
explainer = shap.TreeExplainer(self.get_estimator())
shap_values = explainer.shap_values(self.get_xtransform())
return shap.summary_plot(shap_values, self.get_xtransform(), plot_type = "bar")

0 comments on commit 6db71c0

Please sign in to comment.