Skip to content

Commit

Permalink
Add explained
Browse files Browse the repository at this point in the history
  • Loading branch information
toandm2 committed Mar 1, 2021
1 parent 845babd commit 24af041
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pipelineservice/explain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._explain import Explained

87 changes: 87 additions & 0 deletions pipelineservice/explain/_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import sklearn
from sklearn.inspection import permutation_importance, plot_partial_dependence


class Explained(object):
    """Explanation helpers for an estimator or a pipeline ending in one.

    Args:
        pipeline: Either a pipeline-like object exposing ``steps`` (indexing
            with ``[-1]`` yields the final estimator and slicing with
            ``[:-1]`` yields the preprocessing part), or a bare estimator.
        X: Feature data handed to the pipeline. Column names are taken from
            ``X.columns`` when that attribute exists.
        y: Target values; only used for permutation importance.
    """

    def __init__(self, pipeline, X=None, y=None):
        self.pipeline = pipeline
        self.X = X
        self.y = y

    def _has_steps(self):
        """Return True when ``self.pipeline`` exposes a non-empty ``steps``."""
        # Duck-typing on ``steps`` instead of "try to index it": a bare
        # estimator may itself be indexable, and indexing it would hand back
        # one of its members rather than the estimator.
        return bool(getattr(self.pipeline, "steps", None))

    @staticmethod
    def _feature_names(X):
        """Return feature names for ``X``.

        Uses ``X.columns`` when available; otherwise falls back to
        positional names ``x0, x1, ...`` based on ``X.shape[1]``.

        Raises:
            ValueError: If ``X`` is None.
        """
        if X is None:
            raise ValueError("X is required to derive feature names")
        columns = getattr(X, "columns", None)
        if columns is not None:
            return columns.tolist()
        return [f"x{i}" for i in range(X.shape[1])]

    @staticmethod
    def _plot_height(n_rows):
        """Return the bar-chart height for ``n_rows`` bars (never below 1)."""
        # int(n_rows * 0.3) is 0 for fewer than four rows, which would make
        # the ``12 / height`` aspect computation divide by zero.
        return max(1, int(n_rows * 0.3))

    def _plot_importance(self, data, value_column, top_k):
        """Draw a horizontal bar chart of ``value_column`` per feature.

        Rows are sorted by ``value_column`` in descending order and
        optionally truncated to the first ``top_k`` rows.

        Returns:
            The ``ax`` attribute of the grid returned by ``sns.catplot``.
        """
        data = data.sort_values(by=[value_column], ascending=False)
        if top_k is not None:
            data = data[:top_k]

        height = self._plot_height(data.shape[0])
        facet_grid = sns.catplot(
            data=data,
            y="feature_name",
            x=value_column,
            kind="bar",
            height=height,
            aspect=12 / height,
        )
        return facet_grid.ax

    def get_estimator(self):
        """Return the final step of a pipeline, or the object itself."""
        if self._has_steps():
            return self.pipeline[-1]
        return self.pipeline

    def get_xtransform(self):
        """Return ``X`` after all steps except the last one were applied.

        ``X`` is returned unchanged when it is None, when ``pipeline`` has no
        ``steps``, or when it has a single step (nothing precedes the
        estimator). Errors raised by the transformers are propagated instead
        of silently falling back to the untransformed data.
        """
        steps = getattr(self.pipeline, "steps", None)
        if self.X is None or not steps or len(steps) < 2:
            return self.X
        return self.pipeline[:-1].transform(self.X)

    def get_feature_importance(self):
        """Return the estimator's built-in importances as a DataFrame.

        The first available attribute of ``feature_importances_`` and
        ``coef_`` is used. A 2-D value with a single row is flattened so it
        lines up with the feature names.

        Returns:
            DataFrame with columns ``feature_name`` and
            ``feature_importances``. The latter is all None when the
            estimator exposes neither attribute.

        Raises:
            ValueError: If ``X`` is None, or the importances have more than
                one row and cannot be mapped to one value per feature.
        """
        estimator = self.get_estimator()
        importances = None
        for attr in ("feature_importances_", "coef_"):
            importances = getattr(estimator, attr, None)
            if importances is not None:
                break

        names = self._feature_names(self.get_xtransform())

        if importances is not None:
            importances = np.asarray(importances)
            if importances.ndim == 2 and importances.shape[0] == 1:
                importances = importances.ravel()
            elif importances.ndim > 1:
                raise ValueError(
                    "cannot map importances of shape "
                    f"{importances.shape} to one value per feature"
                )

        return pd.DataFrame(
            {"feature_name": names, "feature_importances": importances}
        )

    def plot_feature_importance(self, top_k=None):
        """Plot built-in feature importances, largest first.

        Args:
            top_k: If given, only the first ``top_k`` sorted rows are drawn.

        Returns:
            The axes of the generated seaborn plot.
        """
        return self._plot_importance(
            self.get_feature_importance(), "feature_importances", top_k
        )

    def get_permutation_importance(self, n_jobs=-1, **kwargs):
        """Return mean permutation importances of the whole pipeline.

        Args:
            n_jobs: Forwarded to ``permutation_importance``.
            **kwargs: Further keyword arguments for
                ``permutation_importance``.

        Returns:
            DataFrame with columns ``feature_name`` and
            ``permutation_importance``.
        """
        permutation = permutation_importance(
            estimator=self.pipeline,
            X=self.X,
            y=self.y,
            n_jobs=n_jobs,
            **kwargs
        )
        return pd.DataFrame(
            {
                "feature_name": self._feature_names(self.X),
                "permutation_importance": permutation.importances_mean,
            }
        )

    def plot_permutation_importance(self, top_k=None, **kwargs):
        """Plot permutation importances, largest first.

        Args:
            top_k: If given, only the first ``top_k`` sorted rows are drawn.
            **kwargs: Forwarded to ``get_permutation_importance``.

        Returns:
            The axes of the generated seaborn plot.
        """
        return self._plot_importance(
            self.get_permutation_importance(**kwargs),
            "permutation_importance",
            top_k,
        )

    def plot_partial_dependence(self, features, **kwargs):
        """Plot partial dependence of the final estimator on ``features``.

        Args:
            features: Passed as ``features`` to the module-level
                ``plot_partial_dependence`` imported from
                ``sklearn.inspection``.
            **kwargs: Additional keyword arguments forwarded to that
                function.

        Returns:
            Whatever ``sklearn.inspection.plot_partial_dependence`` returns.
        """
        # Inside a method body the bare name resolves to the module-level
        # import, not to this method, so there is no recursion here.
        # NOTE(review): newer scikit-learn releases appear to replace this
        # function with PartialDependenceDisplay.from_estimator - confirm
        # against the pinned scikit-learn version.
        return plot_partial_dependence(
            estimator=self.get_estimator(),
            X=self.get_xtransform(),
            features=features,
            **kwargs
        )

    def shapley_importance(self):
        """Draw a SHAP bar summary plot for the final estimator.

        Builds a ``shap.TreeExplainer`` from the final estimator and
        evaluates it on the transformed data.

        Returns:
            The return value of ``shap.summary_plot``.
        """
        # Transform once and reuse; the preprocessing may be expensive.
        X_transform = self.get_xtransform()
        explainer = shap.TreeExplainer(self.get_estimator())
        shap_values = explainer.shap_values(X_transform)
        return shap.summary_plot(shap_values, X_transform, plot_type="bar")

0 comments on commit 24af041

Please sign in to comment.